add bloom
This commit is contained in:
@@ -1,11 +1,14 @@
|
|||||||
from typing import List
|
import threading
|
||||||
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from ninja import Router
|
from ninja import Router
|
||||||
|
from ninja.errors import HttpError
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.contrib.auth.decorators import login_required
|
from django.contrib.auth.decorators import login_required
|
||||||
|
|
||||||
from .models import Conversation, Message
|
from .models import Conversation, Message
|
||||||
from .schemas import ConversationOut, MessageOut
|
from .schemas import ConversationOut, MessageOut
|
||||||
|
from account.models import RoleChoices
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
@@ -40,7 +43,53 @@ def list_messages(request, conversation_id: UUID):
|
|||||||
"code_html": m.code_html,
|
"code_html": m.code_html,
|
||||||
"code_css": m.code_css,
|
"code_css": m.code_css,
|
||||||
"code_js": m.code_js,
|
"code_js": m.code_js,
|
||||||
|
"prompt_level": m.prompt_level,
|
||||||
"created": m.created.isoformat(),
|
"created": m.created.isoformat(),
|
||||||
}
|
}
|
||||||
for m in messages
|
for m in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/conversations/{conversation_id}/classify")
|
||||||
|
@login_required
|
||||||
|
def classify_conversation(request, conversation_id: UUID, force: bool = False):
|
||||||
|
"""
|
||||||
|
对对话中所有用户消息进行层级分类(仅管理员和超级管理员可操作,异步执行)
|
||||||
|
"""
|
||||||
|
if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN):
|
||||||
|
raise HttpError(403, "没有权限")
|
||||||
|
|
||||||
|
get_object_or_404(Conversation, id=conversation_id)
|
||||||
|
|
||||||
|
from submission.classifier import classify_conversation_messages
|
||||||
|
threading.Thread(
|
||||||
|
target=classify_conversation_messages,
|
||||||
|
args=(conversation_id,),
|
||||||
|
kwargs={"force": force},
|
||||||
|
daemon=True,
|
||||||
|
).start()
|
||||||
|
|
||||||
|
return {"message": "开始分类"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/classify-batch")
|
||||||
|
@login_required
|
||||||
|
def classify_batch(request, task_id: Optional[int] = None, force: bool = False):
|
||||||
|
"""
|
||||||
|
批量分类所有(或指定任务)对话的用户消息层级(仅管理员和超级管理员,异步执行)
|
||||||
|
"""
|
||||||
|
if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN):
|
||||||
|
raise HttpError(403, "没有权限")
|
||||||
|
|
||||||
|
qs = Message.objects.filter(role="user")
|
||||||
|
if task_id:
|
||||||
|
qs = qs.filter(conversation__task_id=task_id)
|
||||||
|
if not force:
|
||||||
|
qs = qs.filter(prompt_level__isnull=True)
|
||||||
|
|
||||||
|
ids = list(qs.values_list("id", flat=True))
|
||||||
|
|
||||||
|
from submission.classifier import classify_messages_batch
|
||||||
|
threading.Thread(target=classify_messages_batch, args=(ids,), daemon=True).start()
|
||||||
|
|
||||||
|
return {"message": f"开始分类 {len(ids)} 条消息", "count": len(ids)}
|
||||||
|
|||||||
18
prompt/migrations/0002_message_prompt_level.py
Normal file
18
prompt/migrations/0002_message_prompt_level.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 6.0.1 on 2026-03-30 11:17
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('prompt', '0001_initial'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='message',
|
||||||
|
name='prompt_level',
|
||||||
|
field=models.IntegerField(blank=True, db_index=True, default=None, null=True, verbose_name='提示词层级'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -28,6 +28,9 @@ class Message(models.Model):
|
|||||||
code_css = models.TextField(null=True, blank=True)
|
code_css = models.TextField(null=True, blank=True)
|
||||||
code_js = models.TextField(null=True, blank=True)
|
code_js = models.TextField(null=True, blank=True)
|
||||||
created = models.DateTimeField(auto_now_add=True)
|
created = models.DateTimeField(auto_now_add=True)
|
||||||
|
prompt_level = models.IntegerField(
|
||||||
|
null=True, blank=True, default=None, db_index=True, verbose_name="提示词层级"
|
||||||
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
ordering = ("created",)
|
ordering = ("created",)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class MessageOut(Schema):
|
|||||||
code_html: Optional[str] = None
|
code_html: Optional[str] = None
|
||||||
code_css: Optional[str] = None
|
code_css: Optional[str] = None
|
||||||
code_js: Optional[str] = None
|
code_js: Optional[str] = None
|
||||||
|
prompt_level: Optional[int] = None
|
||||||
created: str
|
created: str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import threading
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from ninja import Router, Query
|
from ninja import Router, Query
|
||||||
@@ -44,7 +45,7 @@ def create_submission(request, payload: SubmissionIn):
|
|||||||
conversation.is_active = False
|
conversation.is_active = False
|
||||||
conversation.save(update_fields=["is_active"])
|
conversation.save(update_fields=["is_active"])
|
||||||
|
|
||||||
Submission.objects.create(
|
sub = Submission.objects.create(
|
||||||
user=request.user,
|
user=request.user,
|
||||||
task=task,
|
task=task,
|
||||||
html=payload.html,
|
html=payload.html,
|
||||||
@@ -53,6 +54,10 @@ def create_submission(request, payload: SubmissionIn):
|
|||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if conversation:
|
||||||
|
from .classifier import classify_conversation_messages
|
||||||
|
threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response=List[SubmissionOut])
|
@router.get("/", response=List[SubmissionOut])
|
||||||
@paginate
|
@paginate
|
||||||
@@ -360,3 +365,5 @@ def update_flag(request, submission_id: UUID, payload: FlagIn):
|
|||||||
return {"flag": submission.flag}
|
return {"flag": submission.flag}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
86
submission/classifier.py
Normal file
86
submission/classifier.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import re
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CLASSIFY_SYSTEM_PROMPT = """你是一个教育评估专家。根据布鲁姆认知分类学,分析以下学生在前端学习中发送给AI助手的一条提示词,判断该提示词所体现的认知层级。
|
||||||
|
|
||||||
|
层级定义:
|
||||||
|
- L1 记忆:能背诵HTML标签语法(例:"帮我写一个按钮")
|
||||||
|
- L2 理解:能解释flex布局原理(例:"为什么这里不居中?")
|
||||||
|
- L3 应用:能独立搭建页面结构(例:"用flex做导航栏,间距16px")
|
||||||
|
- L4 分析:能定位跨浏览器兼容性bug(例:"Safari中margin失效,原因?")
|
||||||
|
- L5 评价:能对比并选择方案(例:"对比Grid与Flex方案优劣")
|
||||||
|
- L6 创造:能设计并实现原创交互作品(例:"设计夜间/日间切换效果")
|
||||||
|
|
||||||
|
只返回一个数字(1-6),不要解释。"""
|
||||||
|
|
||||||
|
|
||||||
|
def _call_llm(content: str) -> int | None:
|
||||||
|
"""Call LLM to classify a single message content. Returns level 1-6 or None."""
|
||||||
|
try:
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=settings.LLM_API_KEY,
|
||||||
|
base_url=settings.LLM_BASE_URL,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=settings.LLM_MODEL,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": CLASSIFY_SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
],
|
||||||
|
max_tokens=10,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
text = response.choices[0].message.content or ""
|
||||||
|
match = re.search(r"[1-6]", text)
|
||||||
|
if not match:
|
||||||
|
logger.warning("classify: unexpected LLM response '%s'", text)
|
||||||
|
return None
|
||||||
|
return int(match.group())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("classify LLM call failed: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def classify_message(message_id: int) -> int | None:
|
||||||
|
"""Classify a single user Message by ID. Returns level or None."""
|
||||||
|
from prompt.models import Message
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = Message.objects.get(id=message_id, role="user")
|
||||||
|
except Message.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
level = _call_llm(msg.content)
|
||||||
|
if level is not None:
|
||||||
|
Message.objects.filter(id=message_id).update(prompt_level=level)
|
||||||
|
return level
|
||||||
|
|
||||||
|
|
||||||
|
def classify_conversation_messages(conversation_id: UUID, force: bool = False) -> None:
|
||||||
|
"""Classify all user messages in a conversation."""
|
||||||
|
from prompt.models import Message
|
||||||
|
|
||||||
|
qs = Message.objects.filter(conversation_id=conversation_id, role="user")
|
||||||
|
if not force:
|
||||||
|
qs = qs.filter(prompt_level__isnull=True)
|
||||||
|
|
||||||
|
for msg in qs.order_by("created"):
|
||||||
|
level = _call_llm(msg.content)
|
||||||
|
if level is not None:
|
||||||
|
Message.objects.filter(id=msg.id).update(prompt_level=level)
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_messages_batch(message_ids: list) -> None:
|
||||||
|
"""Classify a list of messages by ID."""
|
||||||
|
for mid in message_ids:
|
||||||
|
classify_message(mid)
|
||||||
|
time.sleep(0.5)
|
||||||
0
submission/management/commands/__init__.py
Normal file
0
submission/management/commands/__init__.py
Normal file
35
submission/management/commands/classify_prompts.py
Normal file
35
submission/management/commands/classify_prompts.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
|
||||||
|
from prompt.models import Message
|
||||||
|
from submission.classifier import classify_message
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
help = "Classify prompt levels (L1-L6) for user messages using LLM"
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
parser.add_argument("--task-id", type=int, help="Only classify messages for this task ID")
|
||||||
|
parser.add_argument("--force", action="store_true", help="Re-classify already classified messages")
|
||||||
|
parser.add_argument("--dry-run", action="store_true", help="Show count without classifying")
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
qs = Message.objects.filter(role="user")
|
||||||
|
if options["task_id"]:
|
||||||
|
qs = qs.filter(conversation__task_id=options["task_id"])
|
||||||
|
if not options["force"]:
|
||||||
|
qs = qs.filter(prompt_level__isnull=True)
|
||||||
|
|
||||||
|
ids = list(qs.values_list("id", flat=True))
|
||||||
|
self.stdout.write(f"Found {len(ids)} message(s) to classify.")
|
||||||
|
|
||||||
|
if options["dry_run"]:
|
||||||
|
self.stdout.write("Dry run — no changes made.")
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, mid in enumerate(ids, 1):
|
||||||
|
level = classify_message(mid)
|
||||||
|
self.stdout.write(
|
||||||
|
f"[{i}/{len(ids)}] msg#{mid} → L{level}" if level else f"[{i}/{len(ids)}] msg#{mid} → (skipped)"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stdout.write(self.style.SUCCESS("Done."))
|
||||||
Reference in New Issue
Block a user