# Review preface. This text sits before the first "diff --git" header and is
# not part of any hunk.
#
# prompt/llm.py
#   * Adds DEEPSEEK_THINKING_MODEL ("deepseek-v4-flash-thinking") and
#     MODEL_ALIASES, which maps that name to DEFAULT_MODEL.
#   * _get_client() now substitutes DEFAULT_MODEL for an empty model name,
#     resolves the name through MODEL_ALIASES, and uses the resolved name both
#     for the ARK_MODELS check and as the returned model id.
#   * Adds _should_disable_thinking(): true when the resolved model is in
#     NON_THINKING_MODELS and the requested name is not a key of MODEL_ALIASES.
#   * Adds _chat_completion_kwargs(): builds the model/messages/stream kwargs
#     and sets extra_body to NON_THINKING_EXTRA_BODY when the check above is
#     true.
#   * stream_chat() and stream_guidance() pass those kwargs to
#     chat.completions.create() instead of spelling the arguments inline.
#
# prompt/tests.py
#   * Adds fake delta/choice/chunk/stream/completions/chat/client classes and
#     an async _collect_stream() helper.
#   * Adds DeepSeekThinkingModeTest with three tests: two assert that
#     extra_body is {"thinking": {"type": "disabled"}} for the default model,
#     one asserts that the "deepseek-v4-flash-thinking" name is sent as
#     DEFAULT_MODEL with no extra_body key.
#
# NOTE(review): line breaks and indentation below were restored from the diff
#   markers and hunk line counts. The position of the blank context line in
#   the first prompt/llm.py hunk (before "# Models served ...") is inferred;
#   confirm against the file.
# NOTE(review): _FakeStream.__anext__ raises StopAsyncIteration inside
#   "except StopIteration" without "from None", so the StopIteration is kept
#   as implicit exception context. Consider "raise StopAsyncIteration from None".
# NOTE(review): NON_THINKING_EXTRA_BODY is a module-level dict assigned by
#   reference into every kwargs dict. Presumably the client does not mutate
#   it; confirm, or pass a copy.
# NOTE(review): "requested_model = model or DEFAULT_MODEL" is computed in both
#   _get_client() and stream_chat(); the two expressions have to stay in sync.
# NOTE(review): no test covers an ARK model or calls _get_client() directly to
#   check alias resolution; the first two tests patch _get_client itself.
#
diff --git a/prompt/llm.py b/prompt/llm.py
index a09e83c..5e7486a 100644
--- a/prompt/llm.py
+++ b/prompt/llm.py
@@ -50,6 +50,12 @@ GUIDANCE_SYSTEM_PROMPT = """你是一个提示词写作教练,帮助学生写
 6. 不要生成任何代码"""
 
 DEFAULT_MODEL = "deepseek-v4-flash"
+DEEPSEEK_THINKING_MODEL = "deepseek-v4-flash-thinking"
+MODEL_ALIASES = {
+    DEEPSEEK_THINKING_MODEL: DEFAULT_MODEL,
+}
+NON_THINKING_MODELS = {"deepseek-v4-flash"}
+NON_THINKING_EXTRA_BODY = {"thinking": {"type": "disabled"}}
 
 # Models served by the ARK (Volcengine) endpoint
 ARK_MODELS = {"doubao-seed-2-0-lite-260215"}
@@ -64,14 +70,16 @@ def build_messages(history: list[dict]) -> list[dict]:
 
 def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
     """Return (client, model_id) for the given model name."""
-    if model in ARK_MODELS:
+    requested_model = model or DEFAULT_MODEL
+    resolved_model = MODEL_ALIASES.get(requested_model, requested_model)
+    if resolved_model in ARK_MODELS:
         return (
             AsyncOpenAI(
                 api_key=settings.ARK_API_KEY,
                 base_url=settings.ARK_BASE_URL,
                 timeout=120.0,
             ),
-            model,
+            resolved_model,
         )
     return (
         AsyncOpenAI(
@@ -79,19 +87,46 @@ def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
             base_url=settings.LLM_BASE_URL,
             timeout=120.0,
         ),
-        model or DEFAULT_MODEL,
+        resolved_model,
     )
 
 
+def _should_disable_thinking(requested_model: str, resolved_model: str) -> bool:
+    return (
+        resolved_model in NON_THINKING_MODELS
+        and requested_model not in MODEL_ALIASES
+    )
+
+
+def _chat_completion_kwargs(
+    requested_model: str,
+    resolved_model: str,
+    messages: list[dict],
+    stream: bool,
+) -> dict:
+    kwargs = {
+        "model": resolved_model,
+        "messages": messages,
+        "stream": stream,
+    }
+    if _should_disable_thinking(requested_model, resolved_model):
+        kwargs["extra_body"] = NON_THINKING_EXTRA_BODY
+    return kwargs
+
+
 async def stream_chat(history: list[dict], model: str = ""):
     """Stream chat completion from the LLM. Yields content chunks."""
     messages = build_messages(history)
     client, resolved_model = _get_client(model)
+    requested_model = model or DEFAULT_MODEL
     async with client as c:
         stream = await c.chat.completions.create(
-            model=resolved_model,
-            messages=messages,
-            stream=True,
+            **_chat_completion_kwargs(
+                requested_model,
+                resolved_model,
+                messages,
+                stream=True,
+            ),
         )
         async for chunk in stream:
             delta = chunk.choices[0].delta
@@ -137,11 +172,15 @@ async def stream_guidance(history: list[dict]):
     messages = [{"role": "system", "content": GUIDANCE_SYSTEM_PROMPT}]
     messages.extend(history)
     client, model = _get_client("")
+    requested_model = DEFAULT_MODEL
     async with client as c:
         stream = await c.chat.completions.create(
-            model=model,
-            messages=messages,
-            stream=True,
+            **_chat_completion_kwargs(
+                requested_model,
+                model,
+                messages,
+                stream=True,
+            ),
         )
         async for chunk in stream:
             delta = chunk.choices[0].delta
diff --git a/prompt/tests.py b/prompt/tests.py
index d72c3ce..593b1a3 100644
--- a/prompt/tests.py
+++ b/prompt/tests.py
@@ -1,4 +1,7 @@
 from datetime import timedelta
+from unittest.mock import patch
+
+from asgiref.sync import async_to_sync
 from django.test import TestCase
 from django.contrib.auth import get_user_model
 from django.utils import timezone
@@ -297,7 +300,117 @@ class PromptHistoryTest(TestCase):
         self.assertIsNone(data[0]["code_js"])
 
 
-from prompt.llm import GUIDANCE_SYSTEM_PROMPT, parse_guidance_response
+from prompt.llm import (
+    DEFAULT_MODEL,
+    GUIDANCE_SYSTEM_PROMPT,
+    parse_guidance_response,
+    stream_chat,
+    stream_guidance,
+)
+
+
+class _FakeDelta:
+    def __init__(self, content):
+        self.content = content
+
+
+class _FakeChoice:
+    def __init__(self, content):
+        self.delta = _FakeDelta(content)
+
+
+class _FakeChunk:
+    def __init__(self, content):
+        self.choices = [_FakeChoice(content)]
+
+
+class _FakeStream:
+    def __init__(self, chunks):
+        self._chunks = iter(chunks)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return _FakeChunk(next(self._chunks))
+        except StopIteration:
+            raise StopAsyncIteration
+
+
+class _FakeCompletions:
+    def __init__(self):
+        self.kwargs = None
+
+    async def create(self, **kwargs):
+        self.kwargs = kwargs
+        return _FakeStream(["ok"])
+
+
+class _FakeChat:
+    def __init__(self, completions):
+        self.completions = completions
+
+
+class _FakeClient:
+    def __init__(self):
+        self.completions = _FakeCompletions()
+        self.chat = _FakeChat(self.completions)
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return None
+
+
+async def _collect_stream(stream):
+    return [chunk async for chunk in stream]
+
+
+class DeepSeekThinkingModeTest(TestCase):
+    def test_stream_chat_disables_thinking_for_deepseek_flash(self):
+        client = _FakeClient()
+
+        with patch("prompt.llm._get_client", return_value=(client, DEFAULT_MODEL)):
+            chunks = async_to_sync(_collect_stream)(
+                stream_chat([{"role": "user", "content": "做一个按钮"}])
+            )
+
+        self.assertEqual(chunks, ["ok"])
+        self.assertEqual(
+            client.completions.kwargs["extra_body"],
+            {"thinking": {"type": "disabled"}},
+        )
+
+    def test_stream_guidance_disables_thinking_for_deepseek_flash(self):
+        client = _FakeClient()
+
+        with patch("prompt.llm._get_client", return_value=(client, DEFAULT_MODEL)):
+            chunks = async_to_sync(_collect_stream)(
+                stream_guidance([{"role": "user", "content": "做一个页面"}])
+            )
+
+        self.assertEqual(chunks, ["ok"])
+        self.assertEqual(
+            client.completions.kwargs["extra_body"],
+            {"thinking": {"type": "disabled"}},
+        )
+
+    def test_stream_chat_thinking_option_uses_deepseek_flash_without_disabling_thinking(self):
+        client = _FakeClient()
+
+        with patch("prompt.llm.AsyncOpenAI", return_value=client):
+            chunks = async_to_sync(_collect_stream)(
+                stream_chat(
+                    [{"role": "user", "content": "做一个按钮"}],
+                    model="deepseek-v4-flash-thinking",
+                )
+            )
+
+        self.assertEqual(chunks, ["ok"])
+        self.assertEqual(client.completions.kwargs["model"], DEFAULT_MODEL)
+        self.assertNotIn("extra_body", client.completions.kwargs)
 
 
 class ParseGuidanceResponseTest(TestCase):