From 22c3b65e289d5f2bdbd91a389d0a87b47f0bf572 Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Thu, 7 May 2026 09:51:47 -0600 Subject: [PATCH] add prompt assistant --- prompt/consumers.py | 92 ++++++++++++++++++++++++++++++++++++++++++++- prompt/llm.py | 6 ++- prompt/tests.py | 8 +++- prompt/url.py | 17 +++++++-- 4 files changed, 115 insertions(+), 8 deletions(-) diff --git a/prompt/consumers.py b/prompt/consumers.py index 35d9786..fa74be3 100644 --- a/prompt/consumers.py +++ b/prompt/consumers.py @@ -3,7 +3,7 @@ from channels.generic.websocket import AsyncWebsocketConsumer from channels.db import database_sync_to_async from django.db.models import Count from .models import Conversation, Message -from .llm import stream_chat, extract_code +from .llm import stream_chat, extract_code, stream_guidance, parse_guidance_response class PromptConsumer(AsyncWebsocketConsumer): @@ -126,3 +126,93 @@ class PromptConsumer(AsyncWebsocketConsumer): def get_history_for_llm(self): messages = self.conversation.messages.filter(source="conversation") return [{"role": m.role, "content": m.content} for m in messages] + + +class GuidanceConsumer(AsyncWebsocketConsumer): + async def connect(self): + self.user = self.scope["user"] + if self.user.is_anonymous: + await self.close() + return + + self.task_id = int(self.scope["url_route"]["kwargs"]["task_id"]) + self.current_user_message = None + self.session_messages = [] + await self.accept() + + self.conversation = await self.get_or_create_conversation() + await self.send(text_data=json.dumps({"type": "init", "messages": []})) + + async def disconnect(self, close_code): + if self.current_user_message: + await self.delete_message(self.current_user_message) + self.current_user_message = None + + async def receive(self, text_data): + data = json.loads(text_data) + prompt = data.get("content", "").strip() + if not prompt: + return + + self.current_user_message = await self.save_message("user", prompt) + self.session_messages.append({"role": "user", "content": prompt}) + + try: + full_response = "" + try: + async for chunk in stream_guidance(self.session_messages): + full_response += chunk + await self.send(text_data=json.dumps({ + "type": "stream", + "content": chunk, + })) + except Exception as e: + await self.send(text_data=json.dumps({ + "type": "error", + "content": f"AI 服务出错:{str(e)}", + })) + return + + clean_response, is_ready = parse_guidance_response(full_response) + self.session_messages.append({ + "role": "assistant", + "content": clean_response, + }) + assistant_msg = await self.save_message("assistant", clean_response) + self.current_user_message = None + + await self.send(text_data=json.dumps({ + "type": "complete", + "message_id": assistant_msg.id, + "is_ready": is_ready, + })) + + finally: + if self.current_user_message: + await self.delete_message(self.current_user_message) + self.current_user_message = None + + @database_sync_to_async + def get_or_create_conversation(self): + conv = ( + Conversation.objects.filter(user=self.user, task_id=self.task_id) + .annotate(msg_count=Count("messages")) + .order_by("-msg_count", "-created") + .first() + ) + if not conv: + conv = Conversation.objects.create(user=self.user, task_id=self.task_id) + return conv + + @database_sync_to_async + def delete_message(self, message): + message.delete() + + @database_sync_to_async + def save_message(self, role, content): + return Message.objects.create( + conversation=self.conversation, + role=role, + source="guidance", + content=content, + ) diff --git a/prompt/llm.py b/prompt/llm.py index 5f215ae..a09e83c 100644 --- a/prompt/llm.py +++ b/prompt/llm.py @@ -45,9 +45,11 @@ GUIDANCE_SYSTEM_PROMPT = """你是一个提示词写作教练,帮助学生写 1. 如果提示词不够好,用 1-2 个启发性问题引导学生补充细节,不要直接给出答案 2. 如果提示词已经够好,以 [READY] 开头回复,简短夸奖学生并说明可以生成了 3. 用中文回复,语气鼓励,简洁明了 -4. 不要生成任何代码""" +4. 使用 Markdown 语法高亮关键词,优先突出 **主题**、**视觉**、**交互**、**内容**、**可以生成** 等重点 +5. 如果回复以 [READY] 开头,[READY] 不要加粗,必须保持原始文本 +6. 不要生成任何代码""" -DEFAULT_MODEL = "deepseek-chat" +DEFAULT_MODEL = "deepseek-v4-flash" # Models served by the ARK (Volcengine) endpoint ARK_MODELS = {"doubao-seed-2-0-lite-260215"} diff --git a/prompt/tests.py b/prompt/tests.py index fd7786d..d72c3ce 100644 --- a/prompt/tests.py +++ b/prompt/tests.py @@ -297,10 +297,16 @@ class PromptHistoryTest(TestCase): self.assertIsNone(data[0]["code_js"]) -from prompt.llm import parse_guidance_response +from prompt.llm import GUIDANCE_SYSTEM_PROMPT, parse_guidance_response class ParseGuidanceResponseTest(TestCase): + def test_guidance_prompt_asks_for_bold_keywords(self): + self.assertIn("Markdown", GUIDANCE_SYSTEM_PROMPT) + self.assertIn("[READY] 不要加粗", GUIDANCE_SYSTEM_PROMPT) + for keyword in ("**主题**", "**视觉**", "**交互**", "**内容**", "**可以生成**"): + self.assertIn(keyword, GUIDANCE_SYSTEM_PROMPT) + def test_ready_prefix_with_newline_stripped(self): content, is_ready = parse_guidance_response("[READY]\n很好,可以生成了!") self.assertEqual(content, "很好,可以生成了!") diff --git a/prompt/url.py b/prompt/url.py index 73584f2..53828f9 100644 --- a/prompt/url.py +++ b/prompt/url.py @@ -1,6 +1,15 @@ -from django.urls import path -from .consumers import PromptConsumer +from collections.abc import Callable +from typing import Any, cast -websocket_urlpatterns = [ - path("ws/prompt//", PromptConsumer.as_asgi()), +from django.urls import path +from django.urls.resolvers import URLPattern, URLResolver + +from .consumers import PromptConsumer, GuidanceConsumer + +AsgiApplication = Callable[..., Any] +RoutePattern = URLPattern | URLResolver + +websocket_urlpatterns: list[RoutePattern] = [ + path("ws/prompt//", cast(AsgiApplication, PromptConsumer.as_asgi())), + path("ws/guidance//", cast(AsgiApplication, GuidanceConsumer.as_asgi())), ]