change model
@@ -206,3 +206,7 @@ MEDIA_ROOT = BASE_DIR / "media"
 LLM_API_KEY = os.environ.get("LLM_API_KEY", "")
 LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "https://api.deepseek.com")
 LLM_MODEL = os.environ.get("LLM_MODEL", "deepseek-chat")
+
+# ARK (Volcengine) LLM Configuration
+ARK_API_KEY = os.environ.get("ARK_API_KEY", "")
+ARK_BASE_URL = os.environ.get("ARK_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
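
The settings hunk mirrors the existing DeepSeek configuration: every value is read from the environment, and the empty-string default for ARK_API_KEY keeps the ARK backend inert unless a key is actually provided. A minimal sketch of wiring this up for local testing, before Django loads settings (the key value is a placeholder, not part of the commit):

import os

# Placeholder values for local testing only; real credentials come from the
# deployment environment. These must be set before Django reads settings.
os.environ.setdefault("ARK_API_KEY", "ak-placeholder")
os.environ.setdefault("ARK_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
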
@@ -39,6 +39,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
             return
 
         prompt = data.get("content", "").strip()
+        model = data.get("model", "")
         if not prompt:
             return
 
@@ -52,7 +53,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
         # Stream AI response
         full_response = ""
         try:
-            async for chunk in stream_chat(history):
+            async for chunk in stream_chat(history, model=model):
                 full_response += chunk
                 await self.send(text_data=json.dumps({
                     "type": "stream",
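
With the two consumer hunks applied, the websocket payload gains an optional "model" field, read with data.get("model", ""), so older clients that omit it keep the previous default behavior. A sketch of the message shape a client might now send (field names come from the diff; the model value is the one registered in ARK_MODELS below):

import json

# Payload as parsed by PromptConsumer.receive in the hunks above.
payload = json.dumps({
    "content": "Build a landing page with a pricing table",
    "model": "doubao-seed-2-0-mini-260215",  # omit or send "" to use the default model
})
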
@@ -13,6 +13,9 @@ SYSTEM_PROMPT = """You are a web page generation assistant. Based on the user's
 5. When modifying existing code, return the complete modified code, never just a fragment
 6. Because all external links are blocked, implement features in pure HTML, CSS, and JS; do not rely on external libraries"""
 
+# Models served by the ARK (Volcengine) endpoint
+ARK_MODELS = {"doubao-seed-2-0-mini-260215"}
+
 
 def build_messages(history: list[dict]) -> list[dict]:
     """Build the message list for the LLM API call."""
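
ARK_MODELS is a plain set, so deciding which backend serves a request is a constant-time membership test, and supporting another ARK-hosted model is a one-line addition. A sketch under that assumption (the second entry is hypothetical, not part of the commit):

# Hypothetical extension; only the first entry exists in this commit.
ARK_MODELS = {
    "doubao-seed-2-0-mini-260215",
    "doubao-some-newer-model",  # hypothetical additional ARK model ID
}

print("doubao-seed-2-0-mini-260215" in ARK_MODELS)  # True -> routed to ARK
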
@@ -21,16 +24,35 @@ def build_messages(history: list[dict]) -> list[dict]:
     return messages
 
 
-async def stream_chat(history: list[dict]):
+def _get_client(model: str) -> tuple[AsyncOpenAI, str]:
+    """Return (client, model_id) for the given model name."""
+    if model in ARK_MODELS:
+        return (
+            AsyncOpenAI(
+                api_key=settings.ARK_API_KEY,
+                base_url=settings.ARK_BASE_URL,
+                timeout=120.0,
+            ),
+            model,
+        )
+    resolved_model = model or settings.LLM_MODEL
+    return (
+        AsyncOpenAI(
+            api_key=settings.LLM_API_KEY,
+            base_url=settings.LLM_BASE_URL,
+            timeout=120.0,
+        ),
+        resolved_model,
+    )
+
+
+async def stream_chat(history: list[dict], model: str = ""):
     """Stream chat completion from the LLM. Yields content chunks."""
     messages = build_messages(history)
-    async with AsyncOpenAI(
-        api_key=settings.LLM_API_KEY,
-        base_url=settings.LLM_BASE_URL,
-        timeout=120.0,
-    ) as client:
-        stream = await client.chat.completions.create(
-            model=settings.LLM_MODEL,
+    client, resolved_model = _get_client(model or settings.LLM_MODEL)
+    async with client as c:
+        stream = await c.chat.completions.create(
+            model=resolved_model,
             messages=messages,
             stream=True,
         )
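
Taken together, the service hunk makes client selection a per-call decision: stream_chat passes model or settings.LLM_MODEL into _get_client, so an empty model falls back to the default, and note that if LLM_MODEL itself names an entry in ARK_MODELS, even default traffic routes to ARK. A minimal usage sketch, assuming Django settings are already configured and this module is importable (the history content is illustrative):

import asyncio

async def main():
    history = [{"role": "user", "content": "Make a simple counter page"}]

    # Explicit ARK model: _get_client returns an ARK-backed AsyncOpenAI client.
    async for chunk in stream_chat(history, model="doubao-seed-2-0-mini-260215"):
        print(chunk, end="", flush=True)

    # No model given: falls back to settings.LLM_MODEL on the default endpoint.
    async for chunk in stream_chat(history):
        print(chunk, end="", flush=True)

asyncio.run(main())
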