add bloom
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from typing import List
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from ninja import Router
|
||||
from ninja.errors import HttpError
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.contrib.auth.decorators import login_required
|
||||
|
||||
from .models import Conversation, Message
|
||||
from .schemas import ConversationOut, MessageOut
|
||||
from account.models import RoleChoices
|
||||
|
||||
router = Router()
|
||||
|
||||
@@ -40,7 +43,53 @@ def list_messages(request, conversation_id: UUID):
|
||||
"code_html": m.code_html,
|
||||
"code_css": m.code_css,
|
||||
"code_js": m.code_js,
|
||||
"prompt_level": m.prompt_level,
|
||||
"created": m.created.isoformat(),
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
|
||||
|
||||
@router.post("/conversations/{conversation_id}/classify")
|
||||
@login_required
|
||||
def classify_conversation(request, conversation_id: UUID, force: bool = False):
|
||||
"""
|
||||
对对话中所有用户消息进行层级分类(仅管理员和超级管理员可操作,异步执行)
|
||||
"""
|
||||
if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN):
|
||||
raise HttpError(403, "没有权限")
|
||||
|
||||
get_object_or_404(Conversation, id=conversation_id)
|
||||
|
||||
from submission.classifier import classify_conversation_messages
|
||||
threading.Thread(
|
||||
target=classify_conversation_messages,
|
||||
args=(conversation_id,),
|
||||
kwargs={"force": force},
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
return {"message": "开始分类"}
|
||||
|
||||
|
||||
@router.post("/classify-batch")
|
||||
@login_required
|
||||
def classify_batch(request, task_id: Optional[int] = None, force: bool = False):
|
||||
"""
|
||||
批量分类所有(或指定任务)对话的用户消息层级(仅管理员和超级管理员,异步执行)
|
||||
"""
|
||||
if request.user.role not in (RoleChoices.SUPER, RoleChoices.ADMIN):
|
||||
raise HttpError(403, "没有权限")
|
||||
|
||||
qs = Message.objects.filter(role="user")
|
||||
if task_id:
|
||||
qs = qs.filter(conversation__task_id=task_id)
|
||||
if not force:
|
||||
qs = qs.filter(prompt_level__isnull=True)
|
||||
|
||||
ids = list(qs.values_list("id", flat=True))
|
||||
|
||||
from submission.classifier import classify_messages_batch
|
||||
threading.Thread(target=classify_messages_batch, args=(ids,), daemon=True).start()
|
||||
|
||||
return {"message": f"开始分类 {len(ids)} 条消息", "count": len(ids)}
|
||||
|
||||
18
prompt/migrations/0002_message_prompt_level.py
Normal file
18
prompt/migrations/0002_message_prompt_level.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 6.0.1 on 2026-03-30 11:17
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('prompt', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='message',
|
||||
name='prompt_level',
|
||||
field=models.IntegerField(blank=True, db_index=True, default=None, null=True, verbose_name='提示词层级'),
|
||||
),
|
||||
]
|
||||
@@ -28,6 +28,9 @@ class Message(models.Model):
|
||||
code_css = models.TextField(null=True, blank=True)
|
||||
code_js = models.TextField(null=True, blank=True)
|
||||
created = models.DateTimeField(auto_now_add=True)
|
||||
prompt_level = models.IntegerField(
|
||||
null=True, blank=True, default=None, db_index=True, verbose_name="提示词层级"
|
||||
)
|
||||
|
||||
class Meta:
|
||||
ordering = ("created",)
|
||||
|
||||
@@ -10,6 +10,7 @@ class MessageOut(Schema):
|
||||
code_html: Optional[str] = None
|
||||
code_css: Optional[str] = None
|
||||
code_js: Optional[str] = None
|
||||
prompt_level: Optional[int] = None
|
||||
created: str
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user