From d05c05757d60f04bb4ee873d71600d8093f1fbe9 Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Mon, 30 Mar 2026 05:33:14 -0600 Subject: [PATCH] add bloom --- prompt/api.py | 51 ++++++++++- .../migrations/0002_message_prompt_level.py | 18 ++++ prompt/models.py | 3 + prompt/schemas.py | 1 + submission/api.py | 9 +- submission/classifier.py | 86 +++++++++++++++++++ submission/management/commands/__init__.py | 0 .../management/commands/classify_prompts.py | 35 ++++++++ 8 files changed, 201 insertions(+), 2 deletions(-) create mode 100644 prompt/migrations/0002_message_prompt_level.py create mode 100644 submission/classifier.py create mode 100644 submission/management/commands/__init__.py create mode 100644 submission/management/commands/classify_prompts.py diff --git a/prompt/api.py b/prompt/api.py index 54bff34..83891e6 100644 --- a/prompt/api.py +++ b/prompt/api.py @@ -1,11 +1,14 @@ -from typing import List +import threading +from typing import List, Optional from uuid import UUID from ninja import Router +from ninja.errors import HttpError from django.shortcuts import get_object_or_404 from django.contrib.auth.decorators import login_required from .models import Conversation, Message from .schemas import ConversationOut, MessageOut +from account.models import RoleChoices router = Router() @@ -40,7 +43,53 @@ def list_messages(request, conversation_id: UUID): "code_html": m.code_html, "code_css": m.code_css, "code_js": m.code_js, + "prompt_level": m.prompt_level, "created": m.created.isoformat(), } for m in messages ] + + +@router.post("/conversations/{conversation_id}/classify") +@login_required +def classify_conversation(request, conversation_id: UUID, force: bool = False): + """ + 对对话中所有用户消息进行层级分类(仅管理员和超级管理员可操作,异步执行) + """ + if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN): + raise HttpError(403, "没有权限") + + get_object_or_404(Conversation, id=conversation_id) + + from submission.classifier import classify_conversation_messages + threading.Thread( + target=classify_conversation_messages, + args=(conversation_id,), + kwargs={"force": force}, + daemon=True, + ).start() + + return {"message": "开始分类"} + + +@router.post("/classify-batch") +@login_required +def classify_batch(request, task_id: Optional[int] = None, force: bool = False): + """ + 批量分类所有(或指定任务)对话的用户消息层级(仅管理员和超级管理员,异步执行) + """ + if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN): + raise HttpError(403, "没有权限") + + qs = Message.objects.filter(role="user") + if task_id: + qs = qs.filter(conversation__task_id=task_id) + if not force: + qs = qs.filter(prompt_level__isnull=True) + + ids = list(qs.values_list("id", flat=True)) + + from submission.classifier import classify_messages_batch + threading.Thread(target=classify_messages_batch, args=(ids,), daemon=True).start() + + return {"message": f"开始分类 {len(ids)} 条消息", "count": len(ids)} diff --git a/prompt/migrations/0002_message_prompt_level.py b/prompt/migrations/0002_message_prompt_level.py new file mode 100644 index 0000000..c4de1ed --- /dev/null +++ b/prompt/migrations/0002_message_prompt_level.py @@ -0,0 +1,18 @@ +# Generated by Django 6.0.1 on 2026-03-30 11:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('prompt', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='prompt_level', + field=models.IntegerField(blank=True, db_index=True, default=None, null=True, verbose_name='提示词层级'), + ), + ] diff --git a/prompt/models.py b/prompt/models.py index c57ab1d..61f1aa8 100644 --- a/prompt/models.py +++ b/prompt/models.py @@ -28,6 +28,9 @@ class Message(models.Model): code_css = models.TextField(null=True, blank=True) code_js = models.TextField(null=True, blank=True) created = models.DateTimeField(auto_now_add=True) + prompt_level = models.IntegerField( + null=True, blank=True, default=None, db_index=True, verbose_name="提示词层级" + ) class Meta: ordering = ("created",) diff --git a/prompt/schemas.py b/prompt/schemas.py index 6cb77d1..6b67c4a 100644 --- a/prompt/schemas.py +++ b/prompt/schemas.py @@ -10,6 +10,7 @@ class MessageOut(Schema): code_html: Optional[str] = None code_css: Optional[str] = None code_js: Optional[str] = None + prompt_level: Optional[int] = None created: str diff --git a/submission/api.py b/submission/api.py index cac2238..4680a45 100644 --- a/submission/api.py +++ b/submission/api.py @@ -1,3 +1,4 @@ +import threading from typing import List, Optional from uuid import UUID from ninja import Router, Query @@ -44,7 +45,7 @@ def create_submission(request, payload: SubmissionIn): conversation.is_active = False conversation.save(update_fields=["is_active"]) - Submission.objects.create( + sub = Submission.objects.create( user=request.user, task=task, html=payload.html, @@ -53,6 +54,10 @@ def create_submission(request, payload: SubmissionIn): conversation=conversation, ) + if conversation: + from .classifier import classify_conversation_messages + threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() + @router.get("/", response=List[SubmissionOut]) @paginate @@ -360,3 +365,5 @@ def update_flag(request, submission_id: UUID, payload: FlagIn): return {"flag": submission.flag} + + diff --git a/submission/classifier.py b/submission/classifier.py new file mode 100644 index 0000000..601d79b --- /dev/null +++ b/submission/classifier.py @@ -0,0 +1,86 @@ +import re +import time +import logging +from uuid import UUID + +from django.conf import settings +from openai import OpenAI + +logger = logging.getLogger(__name__) + +CLASSIFY_SYSTEM_PROMPT = """你是一个教育评估专家。根据布鲁姆认知分类学,分析以下学生在前端学习中发送给AI助手的一条提示词,判断该提示词所体现的认知层级。 + +层级定义: +- L1 记忆:能背诵HTML标签语法(例:"帮我写一个按钮") +- L2 理解:能解释flex布局原理(例:"为什么这里不居中?") +- L3 应用:能独立搭建页面结构(例:"用flex做导航栏,间距16px") +- L4 分析:能定位跨浏览器兼容性bug(例:"Safari中margin失效,原因?") +- L5 评价:能对比并选择方案(例:"对比Grid与Flex方案优劣") +- L6 创造:能设计并实现原创交互作品(例:"设计夜间/日间切换效果") + +只返回一个数字(1-6),不要解释。""" + + +def _call_llm(content: str) -> int | None: + """Call LLM to classify a single message content. Returns level 1-6 or None.""" + try: + client = OpenAI( + api_key=settings.LLM_API_KEY, + base_url=settings.LLM_BASE_URL, + timeout=30.0, + ) + response = client.chat.completions.create( + model=settings.LLM_MODEL, + messages=[ + {"role": "system", "content": CLASSIFY_SYSTEM_PROMPT}, + {"role": "user", "content": content}, + ], + max_tokens=10, + stream=False, + ) + text = response.choices[0].message.content or "" + match = re.search(r"[1-6]", text) + if not match: + logger.warning("classify: unexpected LLM response '%s'", text) + return None + return int(match.group()) + except Exception as e: + logger.error("classify LLM call failed: %s", e) + return None + + +def classify_message(message_id: int) -> int | None: + """Classify a single user Message by ID. Returns level or None.""" + from prompt.models import Message + + try: + msg = Message.objects.get(id=message_id, role="user") + except Message.DoesNotExist: + return None + + level = _call_llm(msg.content) + if level is not None: + Message.objects.filter(id=message_id).update(prompt_level=level) + return level + + +def classify_conversation_messages(conversation_id: UUID, force: bool = False) -> None: + """Classify all user messages in a conversation.""" + from prompt.models import Message + + qs = Message.objects.filter(conversation_id=conversation_id, role="user") + if not force: + qs = qs.filter(prompt_level__isnull=True) + + for msg in qs.order_by("created"): + level = _call_llm(msg.content) + if level is not None: + Message.objects.filter(id=msg.id).update(prompt_level=level) + time.sleep(0.3) + + +def classify_messages_batch(message_ids: list) -> None: + """Classify a list of messages by ID.""" + for mid in message_ids: + classify_message(mid) + time.sleep(0.5) diff --git a/submission/management/commands/__init__.py b/submission/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/submission/management/commands/classify_prompts.py b/submission/management/commands/classify_prompts.py new file mode 100644 index 0000000..6c18d58 --- /dev/null +++ b/submission/management/commands/classify_prompts.py @@ -0,0 +1,35 @@ +from django.core.management.base import BaseCommand + +from prompt.models import Message +from submission.classifier import classify_message + + +class Command(BaseCommand): + help = "Classify prompt levels (L1-L6) for user messages using LLM" + + def add_arguments(self, parser): + parser.add_argument("--task-id", type=int, help="Only classify messages for this task ID") + parser.add_argument("--force", action="store_true", help="Re-classify already classified messages") + parser.add_argument("--dry-run", action="store_true", help="Show count without classifying") + + def handle(self, *args, **options): + qs = Message.objects.filter(role="user") + if options["task_id"]: + qs = qs.filter(conversation__task_id=options["task_id"]) + if not options["force"]: + qs = qs.filter(prompt_level__isnull=True) + + ids = list(qs.values_list("id", flat=True)) + self.stdout.write(f"Found {len(ids)} message(s) to classify.") + + if options["dry_run"]: + self.stdout.write("Dry run — no changes made.") + return + + for i, mid in enumerate(ids, 1): + level = classify_message(mid) + self.stdout.write( + f"[{i}/{len(ids)}] msg#{mid} → L{level}" if level else f"[{i}/{len(ids)}] msg#{mid} → (skipped)" + ) + + self.stdout.write(self.style.SUCCESS("Done."))