add prompt assistant
This commit is contained in:
@@ -3,7 +3,7 @@ from channels.generic.websocket import AsyncWebsocketConsumer
|
|||||||
from channels.db import database_sync_to_async
|
from channels.db import database_sync_to_async
|
||||||
from django.db.models import Count
|
from django.db.models import Count
|
||||||
from .models import Conversation, Message
|
from .models import Conversation, Message
|
||||||
from .llm import stream_chat, extract_code
|
from .llm import stream_chat, extract_code, stream_guidance, parse_guidance_response
|
||||||
|
|
||||||
|
|
||||||
class PromptConsumer(AsyncWebsocketConsumer):
|
class PromptConsumer(AsyncWebsocketConsumer):
|
||||||
@@ -126,3 +126,93 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
def get_history_for_llm(self):
|
def get_history_for_llm(self):
|
||||||
messages = self.conversation.messages.filter(source="conversation")
|
messages = self.conversation.messages.filter(source="conversation")
|
||||||
return [{"role": m.role, "content": m.content} for m in messages]
|
return [{"role": m.role, "content": m.content} for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
class GuidanceConsumer(AsyncWebsocketConsumer):
|
||||||
|
async def connect(self):
|
||||||
|
self.user = self.scope["user"]
|
||||||
|
if self.user.is_anonymous:
|
||||||
|
await self.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
self.task_id = int(self.scope["url_route"]["kwargs"]["task_id"])
|
||||||
|
self.current_user_message = None
|
||||||
|
self.session_messages = []
|
||||||
|
await self.accept()
|
||||||
|
|
||||||
|
self.conversation = await self.get_or_create_conversation()
|
||||||
|
await self.send(text_data=json.dumps({"type": "init", "messages": []}))
|
||||||
|
|
||||||
|
async def disconnect(self, close_code):
|
||||||
|
if self.current_user_message:
|
||||||
|
await self.delete_message(self.current_user_message)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
|
async def receive(self, text_data):
|
||||||
|
data = json.loads(text_data)
|
||||||
|
prompt = data.get("content", "").strip()
|
||||||
|
if not prompt:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_user_message = await self.save_message("user", prompt)
|
||||||
|
self.session_messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
|
try:
|
||||||
|
full_response = ""
|
||||||
|
try:
|
||||||
|
async for chunk in stream_guidance(self.session_messages):
|
||||||
|
full_response += chunk
|
||||||
|
await self.send(text_data=json.dumps({
|
||||||
|
"type": "stream",
|
||||||
|
"content": chunk,
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
await self.send(text_data=json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"content": f"AI 服务出错:{str(e)}",
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_response, is_ready = parse_guidance_response(full_response)
|
||||||
|
self.session_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": clean_response,
|
||||||
|
})
|
||||||
|
assistant_msg = await self.save_message("assistant", clean_response)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
|
await self.send(text_data=json.dumps({
|
||||||
|
"type": "complete",
|
||||||
|
"message_id": assistant_msg.id,
|
||||||
|
"is_ready": is_ready,
|
||||||
|
}))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if self.current_user_message:
|
||||||
|
await self.delete_message(self.current_user_message)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_or_create_conversation(self):
|
||||||
|
conv = (
|
||||||
|
Conversation.objects.filter(user=self.user, task_id=self.task_id)
|
||||||
|
.annotate(msg_count=Count("messages"))
|
||||||
|
.order_by("-msg_count", "-created")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not conv:
|
||||||
|
conv = Conversation.objects.create(user=self.user, task_id=self.task_id)
|
||||||
|
return conv
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def delete_message(self, message):
|
||||||
|
message.delete()
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def save_message(self, role, content):
|
||||||
|
return Message.objects.create(
|
||||||
|
conversation=self.conversation,
|
||||||
|
role=role,
|
||||||
|
source="guidance",
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|||||||
@@ -45,9 +45,11 @@ GUIDANCE_SYSTEM_PROMPT = """你是一个提示词写作教练,帮助学生写
|
|||||||
1. 如果提示词不够好,用 1-2 个启发性问题引导学生补充细节,不要直接给出答案
|
1. 如果提示词不够好,用 1-2 个启发性问题引导学生补充细节,不要直接给出答案
|
||||||
2. 如果提示词已经够好,以 [READY] 开头回复,简短夸奖学生并说明可以生成了
|
2. 如果提示词已经够好,以 [READY] 开头回复,简短夸奖学生并说明可以生成了
|
||||||
3. 用中文回复,语气鼓励,简洁明了
|
3. 用中文回复,语气鼓励,简洁明了
|
||||||
4. 不要生成任何代码"""
|
4. 使用 Markdown 语法高亮关键词,优先突出 **主题**、**视觉**、**交互**、**内容**、**可以生成** 等重点
|
||||||
|
5. 如果回复以 [READY] 开头,[READY] 不要加粗,必须保持原始文本
|
||||||
|
6. 不要生成任何代码"""
|
||||||
|
|
||||||
DEFAULT_MODEL = "deepseek-chat"
|
DEFAULT_MODEL = "deepseek-v4-flash"
|
||||||
|
|
||||||
# Models served by the ARK (Volcengine) endpoint
|
# Models served by the ARK (Volcengine) endpoint
|
||||||
ARK_MODELS = {"doubao-seed-2-0-lite-260215"}
|
ARK_MODELS = {"doubao-seed-2-0-lite-260215"}
|
||||||
|
|||||||
@@ -297,10 +297,16 @@ class PromptHistoryTest(TestCase):
|
|||||||
self.assertIsNone(data[0]["code_js"])
|
self.assertIsNone(data[0]["code_js"])
|
||||||
|
|
||||||
|
|
||||||
from prompt.llm import parse_guidance_response
|
from prompt.llm import GUIDANCE_SYSTEM_PROMPT, parse_guidance_response
|
||||||
|
|
||||||
|
|
||||||
class ParseGuidanceResponseTest(TestCase):
|
class ParseGuidanceResponseTest(TestCase):
|
||||||
|
def test_guidance_prompt_asks_for_bold_keywords(self):
|
||||||
|
self.assertIn("Markdown", GUIDANCE_SYSTEM_PROMPT)
|
||||||
|
self.assertIn("[READY] 不要加粗", GUIDANCE_SYSTEM_PROMPT)
|
||||||
|
for keyword in ("**主题**", "**视觉**", "**交互**", "**内容**", "**可以生成**"):
|
||||||
|
self.assertIn(keyword, GUIDANCE_SYSTEM_PROMPT)
|
||||||
|
|
||||||
def test_ready_prefix_with_newline_stripped(self):
|
def test_ready_prefix_with_newline_stripped(self):
|
||||||
content, is_ready = parse_guidance_response("[READY]\n很好,可以生成了!")
|
content, is_ready = parse_guidance_response("[READY]\n很好,可以生成了!")
|
||||||
self.assertEqual(content, "很好,可以生成了!")
|
self.assertEqual(content, "很好,可以生成了!")
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
from django.urls import path
|
from collections.abc import Callable
|
||||||
from .consumers import PromptConsumer
|
from typing import Any, cast
|
||||||
|
|
||||||
websocket_urlpatterns = [
|
from django.urls import path
|
||||||
path("ws/prompt/<int:task_id>/", PromptConsumer.as_asgi()),
|
from django.urls.resolvers import URLPattern, URLResolver
|
||||||
|
|
||||||
|
from .consumers import PromptConsumer, GuidanceConsumer
|
||||||
|
|
||||||
|
AsgiApplication = Callable[..., Any]
|
||||||
|
RoutePattern = URLPattern | URLResolver
|
||||||
|
|
||||||
|
websocket_urlpatterns: list[RoutePattern] = [
|
||||||
|
path("ws/prompt/<int:task_id>/", cast(AsgiApplication, PromptConsumer.as_asgi())),
|
||||||
|
path("ws/guidance/<int:task_id>/", cast(AsgiApplication, GuidanceConsumer.as_asgi())),
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user