diff --git a/prompt/api.py b/prompt/api.py index fbe4fb5..980d62f 100644 --- a/prompt/api.py +++ b/prompt/api.py @@ -6,7 +6,7 @@ from ninja.errors import HttpError from django.shortcuts import get_object_or_404 from django.contrib.auth.decorators import login_required -from django.db.models import Count +from django.db.models import Count, Prefetch from .models import Conversation, Message from .schemas import ConversationOut, MessageOut, PromptHistoryItemOut from account.models import RoleChoices @@ -66,11 +66,13 @@ def list_prompt_history(request, task_id: int): conversations = Conversation.objects.filter( user=request.user, task_id=task_id, - ).prefetch_related("messages") + ).prefetch_related( + Prefetch("messages", queryset=Message.objects.order_by("created", "id")) + ) items = [] for conv in conversations: - messages = list(conv.messages.all().order_by("created", "id")) + messages = list(conv.messages.all()) for idx, user_msg in enumerate(messages): if user_msg.role != "user": continue diff --git a/prompt/llm.py b/prompt/llm.py index 5e7486a..2f0909a 100644 --- a/prompt/llm.py +++ b/prompt/llm.py @@ -3,32 +3,32 @@ from django.conf import settings from openai import AsyncOpenAI -SYSTEM_PROMPT = """你是一个网页生成助手。根据用户的需求描述,生成 HTML、CSS 和 JavaScript 代码。 +SYSTEM_PROMPT = """你是一个网页生成助手。根据用户的需求描述,生成网页代码。 规则: -1. 始终使用三个独立的代码块返回代码,分别用 ```html、```css、```js 标记 +1. 使用一个 ```html 代码块返回所有代码 2. HTML 代码只需要 body 内的内容,不需要完整的 HTML 文档结构 -3. CSS 和 JS 可以为空,但仍然需要返回空的代码块 +3. CSS 样式写在 -```js + + + ```""" GUIDANCE_SYSTEM_PROMPT = """你是一个提示词写作教练,帮助学生写出清晰、具体的网页需求描述。 @@ -61,13 +61,6 @@ NON_THINKING_EXTRA_BODY = {"thinking": {"type": "disabled"}} ARK_MODELS = {"doubao-seed-2-0-lite-260215"} -def build_messages(history: list[dict]) -> list[dict]: - """Build the message list for the LLM API call.""" - messages = [{"role": "system", "content": SYSTEM_PROMPT}] - messages.extend(history) - return messages - - def _get_client(model: str) -> tuple[AsyncOpenAI, str]: """Return (client, model_id) for the given model name.""" requested_model = model or DEFAULT_MODEL @@ -114,19 +107,12 @@ def _chat_completion_kwargs( return kwargs -async def stream_chat(history: list[dict], model: str = ""): - """Stream chat completion from the LLM. Yields content chunks.""" - messages = build_messages(history) +async def _stream_completion(messages: list[dict], model: str = ""): client, resolved_model = _get_client(model) requested_model = model or DEFAULT_MODEL async with client as c: stream = await c.chat.completions.create( - **_chat_completion_kwargs( - requested_model, - resolved_model, - messages, - stream=True, - ), + **_chat_completion_kwargs(requested_model, resolved_model, messages, stream=True), ) async for chunk in stream: delta = chunk.choices[0].delta @@ -134,8 +120,15 @@ async def stream_chat(history: list[dict], model: str = ""): yield delta.content +async def stream_chat(history: list[dict], model: str = ""): + """Stream chat completion from the LLM. Yields content chunks.""" + messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history] + async for chunk in _stream_completion(messages, model): + yield chunk + + def extract_code(text: str) -> dict: - """Extract HTML, CSS, JS code blocks from AI response text.""" + """Extract code from AI response. Supports single HTML block (new) or separate html/css/js blocks (legacy).""" result = {"html": None, "css": None, "js": None} pattern = r"```(html|css|js|javascript|typescript|ts|jsx|tsx)\s*\n(.*?)```" matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE) @@ -146,17 +139,21 @@ def extract_code(text: str) -> dict: if lang in result and result[lang] is None: result[lang] = code.strip() - # Fallback: extract ", result["html"], re.DOTALL | re.IGNORECASE) - if style_match: - result["css"] = style_match.group(1).strip() + # Single HTML block: extract ", html, re.DOTALL | re.IGNORECASE) + if style_match: + result["css"] = style_match.group(1).strip() + html = re.sub(r"]*>.*?", "", html, flags=re.DOTALL | re.IGNORECASE) + + script_match = re.search(r"]*>(.*?)", html, re.DOTALL | re.IGNORECASE) + if script_match: + result["js"] = script_match.group(1).strip() + html = re.sub(r"]*>.*?", "", html, flags=re.DOTALL | re.IGNORECASE) + + result["html"] = html.strip() return result @@ -169,20 +166,6 @@ def parse_guidance_response(full_response: str) -> tuple[str, bool]: async def stream_guidance(history: list[dict]): """Stream guidance coaching response. Yields content chunks.""" - messages = [{"role": "system", "content": GUIDANCE_SYSTEM_PROMPT}] - messages.extend(history) - client, model = _get_client("") - requested_model = DEFAULT_MODEL - async with client as c: - stream = await c.chat.completions.create( - **_chat_completion_kwargs( - requested_model, - model, - messages, - stream=True, - ), - ) - async for chunk in stream: - delta = chunk.choices[0].delta - if delta.content: - yield delta.content + messages = [{"role": "system", "content": GUIDANCE_SYSTEM_PROMPT}, *history] + async for chunk in _stream_completion(messages): + yield chunk diff --git a/prompt/schemas.py b/prompt/schemas.py index 9cc9c5f..e7aea49 100644 --- a/prompt/schemas.py +++ b/prompt/schemas.py @@ -47,6 +47,6 @@ class ConversationOut(Schema): "task_id": conv.task_id, "task_title": conv.task.title, "is_active": conv.is_active, - "message_count": conv.messages.count(), + "message_count": conv.msg_count if hasattr(conv, "msg_count") else conv.messages.count(), "created": conv.created.isoformat(), } diff --git a/submission/api.py b/submission/api.py index 89b4f3a..f945969 100644 --- a/submission/api.py +++ b/submission/api.py @@ -22,6 +22,7 @@ from django.db.models import ( ) from account.decorators import admin_required from prompt.models import Conversation, Message +from .classifier import classify_conversation_messages from .schemas import ( @@ -150,8 +151,6 @@ def create_submission(request, payload: SubmissionIn): code_js=payload.js, source="manual", ) - from .classifier import classify_conversation_messages - threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() else: conversation = ( Conversation.objects.filter(user=request.user, task=task) @@ -159,9 +158,9 @@ def create_submission(request, payload: SubmissionIn): .order_by("-msg_count", "-created") .first() ) - if conversation: - from .classifier import classify_conversation_messages - threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() + + if conversation: + threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() submission = Submission.objects.create( user=request.user,