add ds thinking
This commit is contained in:
@@ -50,6 +50,12 @@ GUIDANCE_SYSTEM_PROMPT = """你是一个提示词写作教练,帮助学生写
|
||||
6. 不要生成任何代码"""
|
||||
|
||||
DEFAULT_MODEL = "deepseek-v4-flash"
|
||||
DEEPSEEK_THINKING_MODEL = "deepseek-v4-flash-thinking"
|
||||
MODEL_ALIASES = {
|
||||
DEEPSEEK_THINKING_MODEL: DEFAULT_MODEL,
|
||||
}
|
||||
NON_THINKING_MODELS = {"deepseek-v4-flash"}
|
||||
NON_THINKING_EXTRA_BODY = {"thinking": {"type": "disabled"}}
|
||||
|
||||
# Models served by the ARK (Volcengine) endpoint
|
||||
ARK_MODELS = {"doubao-seed-2-0-lite-260215"}
|
||||
@@ -64,14 +70,16 @@ def build_messages(history: list[dict]) -> list[dict]:
|
||||
|
||||
def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
|
||||
"""Return (client, model_id) for the given model name."""
|
||||
if model in ARK_MODELS:
|
||||
requested_model = model or DEFAULT_MODEL
|
||||
resolved_model = MODEL_ALIASES.get(requested_model, requested_model)
|
||||
if resolved_model in ARK_MODELS:
|
||||
return (
|
||||
AsyncOpenAI(
|
||||
api_key=settings.ARK_API_KEY,
|
||||
base_url=settings.ARK_BASE_URL,
|
||||
timeout=120.0,
|
||||
),
|
||||
model,
|
||||
resolved_model,
|
||||
)
|
||||
return (
|
||||
AsyncOpenAI(
|
||||
@@ -79,19 +87,46 @@ def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
timeout=120.0,
|
||||
),
|
||||
model or DEFAULT_MODEL,
|
||||
resolved_model,
|
||||
)
|
||||
|
||||
|
||||
def _should_disable_thinking(requested_model: str, resolved_model: str) -> bool:
|
||||
return (
|
||||
resolved_model in NON_THINKING_MODELS
|
||||
and requested_model not in MODEL_ALIASES
|
||||
)
|
||||
|
||||
|
||||
def _chat_completion_kwargs(
|
||||
requested_model: str,
|
||||
resolved_model: str,
|
||||
messages: list[dict],
|
||||
stream: bool,
|
||||
) -> dict:
|
||||
kwargs = {
|
||||
"model": resolved_model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
if _should_disable_thinking(requested_model, resolved_model):
|
||||
kwargs["extra_body"] = NON_THINKING_EXTRA_BODY
|
||||
return kwargs
|
||||
|
||||
|
||||
async def stream_chat(history: list[dict], model: str = ""):
|
||||
"""Stream chat completion from the LLM. Yields content chunks."""
|
||||
messages = build_messages(history)
|
||||
client, resolved_model = _get_client(model)
|
||||
requested_model = model or DEFAULT_MODEL
|
||||
async with client as c:
|
||||
stream = await c.chat.completions.create(
|
||||
model=resolved_model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**_chat_completion_kwargs(
|
||||
requested_model,
|
||||
resolved_model,
|
||||
messages,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
@@ -137,11 +172,15 @@ async def stream_guidance(history: list[dict]):
|
||||
messages = [{"role": "system", "content": GUIDANCE_SYSTEM_PROMPT}]
|
||||
messages.extend(history)
|
||||
client, model = _get_client("")
|
||||
requested_model = DEFAULT_MODEL
|
||||
async with client as c:
|
||||
stream = await c.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
**_chat_completion_kwargs(
|
||||
requested_model,
|
||||
model,
|
||||
messages,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
115
prompt/tests.py
115
prompt/tests.py
@@ -1,4 +1,7 @@
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
@@ -297,7 +300,117 @@ class PromptHistoryTest(TestCase):
|
||||
self.assertIsNone(data[0]["code_js"])
|
||||
|
||||
|
||||
from prompt.llm import GUIDANCE_SYSTEM_PROMPT, parse_guidance_response
|
||||
from prompt.llm import (
|
||||
DEFAULT_MODEL,
|
||||
GUIDANCE_SYSTEM_PROMPT,
|
||||
parse_guidance_response,
|
||||
stream_chat,
|
||||
stream_guidance,
|
||||
)
|
||||
|
||||
|
||||
class _FakeDelta:
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
|
||||
class _FakeChoice:
|
||||
def __init__(self, content):
|
||||
self.delta = _FakeDelta(content)
|
||||
|
||||
|
||||
class _FakeChunk:
|
||||
def __init__(self, content):
|
||||
self.choices = [_FakeChoice(content)]
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
def __init__(self, chunks):
|
||||
self._chunks = iter(chunks)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return _FakeChunk(next(self._chunks))
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
class _FakeCompletions:
|
||||
def __init__(self):
|
||||
self.kwargs = None
|
||||
|
||||
async def create(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
return _FakeStream(["ok"])
|
||||
|
||||
|
||||
class _FakeChat:
|
||||
def __init__(self, completions):
|
||||
self.completions = completions
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self):
|
||||
self.completions = _FakeCompletions()
|
||||
self.chat = _FakeChat(self.completions)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
async def _collect_stream(stream):
|
||||
return [chunk async for chunk in stream]
|
||||
|
||||
|
||||
class DeepSeekThinkingModeTest(TestCase):
|
||||
def test_stream_chat_disables_thinking_for_deepseek_flash(self):
|
||||
client = _FakeClient()
|
||||
|
||||
with patch("prompt.llm._get_client", return_value=(client, DEFAULT_MODEL)):
|
||||
chunks = async_to_sync(_collect_stream)(
|
||||
stream_chat([{"role": "user", "content": "做一个按钮"}])
|
||||
)
|
||||
|
||||
self.assertEqual(chunks, ["ok"])
|
||||
self.assertEqual(
|
||||
client.completions.kwargs["extra_body"],
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
|
||||
def test_stream_guidance_disables_thinking_for_deepseek_flash(self):
|
||||
client = _FakeClient()
|
||||
|
||||
with patch("prompt.llm._get_client", return_value=(client, DEFAULT_MODEL)):
|
||||
chunks = async_to_sync(_collect_stream)(
|
||||
stream_guidance([{"role": "user", "content": "做一个页面"}])
|
||||
)
|
||||
|
||||
self.assertEqual(chunks, ["ok"])
|
||||
self.assertEqual(
|
||||
client.completions.kwargs["extra_body"],
|
||||
{"thinking": {"type": "disabled"}},
|
||||
)
|
||||
|
||||
def test_stream_chat_thinking_option_uses_deepseek_flash_without_disabling_thinking(self):
|
||||
client = _FakeClient()
|
||||
|
||||
with patch("prompt.llm.AsyncOpenAI", return_value=client):
|
||||
chunks = async_to_sync(_collect_stream)(
|
||||
stream_chat(
|
||||
[{"role": "user", "content": "做一个按钮"}],
|
||||
model="deepseek-v4-flash-thinking",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(chunks, ["ok"])
|
||||
self.assertEqual(client.completions.kwargs["model"], DEFAULT_MODEL)
|
||||
self.assertNotIn("extra_body", client.completions.kwargs)
|
||||
|
||||
|
||||
class ParseGuidanceResponseTest(TestCase):
|
||||
|
||||
Reference in New Issue
Block a user