From e2449869aa4cc311a25406f47c1acc87b7266f75 Mon Sep 17 00:00:00 2001
From: yuetsh <517252939@qq.com>
Date: Tue, 31 Mar 2026 05:50:27 -0600
Subject: [PATCH] change model

---
 api/settings.py     |  4 ++++
 prompt/consumers.py |  3 ++-
 prompt/llm.py       | 38 ++++++++++++++++++++++++++++++--------
 3 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/api/settings.py b/api/settings.py
index 8966b0f..4691518 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -206,3 +206,7 @@ MEDIA_ROOT = BASE_DIR / "media"
 LLM_API_KEY = os.environ.get("LLM_API_KEY", "")
 LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.deepseek.com")
 LLM_MODEL = os.environ.get("LLM_MODEL", "deepseek-chat")
+
+# ARK (Volcengine) LLM Configuration
+ARK_API_KEY = os.environ.get("ARK_API_KEY", "")
+ARK_BASE_URL = os.environ.get("ARK_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
diff --git a/prompt/consumers.py b/prompt/consumers.py
index 881b9fe..a71ca46 100644
--- a/prompt/consumers.py
+++ b/prompt/consumers.py
@@ -39,6 +39,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
             return
 
         prompt = data.get("content", "").strip()
+        model = data.get("model", "")
         if not prompt:
             return
 
@@ -52,7 +53,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
         # Stream AI response
         full_response = ""
         try:
-            async for chunk in stream_chat(history):
+            async for chunk in stream_chat(history, model=model):
                 full_response += chunk
                 await self.send(text_data=json.dumps({
                     "type": "stream",
diff --git a/prompt/llm.py b/prompt/llm.py
index 551fab6..99b5433 100644
--- a/prompt/llm.py
+++ b/prompt/llm.py
@@ -13,6 +13,9 @@ SYSTEM_PROMPT = """你是一个网页生成助手。根据用户的需求描述
 5. 在已有代码基础上修改时,返回完整的修改后代码,不要只返回片段
 6. 由于任何外部链接都被屏蔽,使用纯 HTML、CSS 和 JS 实现功能,不要依赖外部库"""
 
+# Models served by the ARK (Volcengine) endpoint
+ARK_MODELS = {"doubao-seed-2-0-mini-260215"}
+
 
 def build_messages(history: list[dict]) -> list[dict]:
     """Build the message list for the LLM API call."""
@@ -21,16 +24,35 @@ def build_messages(history: list[dict]) -> list[dict]:
     return messages
 
 
-async def stream_chat(history: list[dict]):
+def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
+    """Return (client, model_id) for the given model name."""
+    if model in ARK_MODELS:
+        return (
+            AsyncOpenAI(
+                api_key=settings.ARK_API_KEY,
+                base_url=settings.ARK_BASE_URL,
+                timeout=120.0,
+            ),
+            model,
+        )
+    resolved_model = model or settings.LLM_MODEL
+    return (
+        AsyncOpenAI(
+            api_key=settings.LLM_API_KEY,
+            base_url=settings.LLM_BASE_URL,
+            timeout=120.0,
+        ),
+        resolved_model,
+    )
+
+
+async def stream_chat(history: list[dict], model: str = ""):
     """Stream chat completion from the LLM. Yields content chunks."""
     messages = build_messages(history)
-    async with AsyncOpenAI(
-        api_key=settings.LLM_API_KEY,
-        base_url=settings.LLM_BASE_URL,
-        timeout=120.0,
-    ) as client:
-        stream = await client.chat.completions.create(
-            model=settings.LLM_MODEL,
+    client, resolved_model = _get_client(model or settings.LLM_MODEL)
+    async with client as c:
+        stream = await c.chat.completions.create(
+            model=resolved_model,
             messages=messages,
             stream=True,
         )
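
Editor's note (not part of the patch): on the wire, the consumer now accepts an
optional "model" field in the incoming WebSocket JSON, e.g.
{"content": "Build a countdown page", "model": "doubao-seed-2-0-mini-260215"};
omitting the field (or sending an empty string) keeps the previous default of
settings.LLM_MODEL on the DeepSeek endpoint. The script below is a minimal
sketch for exercising the new routing directly, bypassing the WebSocket layer.
It assumes the Django settings module is importable as "api.settings" and that
ARK_API_KEY (or LLM_API_KEY for the default path) is exported; the module path
and prompt text are illustrative, not taken from the repository.

    import asyncio
    import os

    # Assumed settings module path; adjust to match the project layout.
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "api.settings")

    import django

    django.setup()

    from prompt.llm import stream_chat

    async def main() -> None:
        history = [{"role": "user", "content": "Build a countdown page"}]
        # A name listed in ARK_MODELS routes to the Volcengine endpoint;
        # any other value (or "") falls through to the DeepSeek client.
        async for chunk in stream_chat(
            history, model="doubao-seed-2-0-mini-260215"
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(main())

Because _get_client() keys routing off set membership rather than a URL prefix,
adding another ARK-served model only requires extending ARK_MODELS; unknown
names degrade gracefully to the default provider rather than failing.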