From 9c76c98da8bb3edfa32fbb90e4f2d683baa60598 Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Tue, 31 Mar 2026 02:21:44 -0600 Subject: [PATCH] fix --- prompt/api.py | 6 ++- prompt/consumers.py | 27 ++++------- prompt/migrations/0003_add_message_source.py | 18 ++++++++ .../0004_update_message_source_default.py | 18 ++++++++ prompt/models.py | 1 + submission/api.py | 45 ++++++++++++------- .../migrations/0010_remove_conversation_fk.py | 17 +++++++ submission/models.py | 5 --- submission/schemas.py | 4 +- 9 files changed, 98 insertions(+), 43 deletions(-) create mode 100644 prompt/migrations/0003_add_message_source.py create mode 100644 prompt/migrations/0004_update_message_source_default.py create mode 100644 submission/migrations/0010_remove_conversation_fk.py diff --git a/prompt/api.py b/prompt/api.py index 83891e6..579e213 100644 --- a/prompt/api.py +++ b/prompt/api.py @@ -6,6 +6,7 @@ from ninja.errors import HttpError from django.shortcuts import get_object_or_404 from django.contrib.auth.decorators import login_required +from django.db.models import Count from .models import Conversation, Message from .schemas import ConversationOut, MessageOut from account.models import RoleChoices @@ -16,7 +17,9 @@ router = Router() @router.get("/conversations/", response=List[ConversationOut]) @login_required def list_conversations(request, task_id: int = None, user_id: int = None): - convs = Conversation.objects.select_related("user", "task") + convs = Conversation.objects.select_related("user", "task").annotate( + msg_count=Count("messages") + ) # Normal users can only see their own if request.user.role == "normal": convs = convs.filter(user=request.user) @@ -24,6 +27,7 @@ def list_conversations(request, task_id: int = None, user_id: int = None): convs = convs.filter(user_id=user_id) if task_id: convs = convs.filter(task_id=task_id) + convs = convs.order_by("-msg_count", "-created") return [ConversationOut.from_conv(c) for c in convs] diff --git a/prompt/consumers.py b/prompt/consumers.py index 2c2e1f1..881b9fe 100644 --- a/prompt/consumers.py +++ b/prompt/consumers.py @@ -1,6 +1,7 @@ import json from channels.generic.websocket import AsyncWebsocketConsumer from channels.db import database_sync_to_async +from django.db.models import Count from .models import Conversation, Message from .llm import stream_chat, extract_code @@ -35,12 +36,6 @@ class PromptConsumer(AsyncWebsocketConsumer): msg_type = data.get("type", "message") if msg_type == "new_conversation": - self.conversation = await self.create_conversation() - await self.send(text_data=json.dumps({ - "type": "init", - "conversation_id": str(self.conversation.id), - "messages": [], - })) return prompt = data.get("content", "").strip() @@ -88,20 +83,16 @@ class PromptConsumer(AsyncWebsocketConsumer): @database_sync_to_async def get_or_create_conversation(self): - conv = Conversation.objects.filter( - user=self.user, task_id=self.task_id, is_active=True - ).first() + conv = ( + Conversation.objects.filter(user=self.user, task_id=self.task_id) + .annotate(msg_count=Count("messages")) + .order_by("-msg_count", "-created") + .first() + ) if not conv: conv = Conversation.objects.create(user=self.user, task_id=self.task_id) return conv - @database_sync_to_async - def create_conversation(self): - Conversation.objects.filter( - user=self.user, task_id=self.task_id, is_active=True - ).update(is_active=False) - return Conversation.objects.create(user=self.user, task_id=self.task_id) - @database_sync_to_async def delete_message(self, message): message.delete() @@ -119,7 +110,7 @@ class PromptConsumer(AsyncWebsocketConsumer): @database_sync_to_async def get_history(self): - messages = self.conversation.messages.all() + messages = self.conversation.messages.filter(source="conversation") return [ { "role": m.role, @@ -136,5 +127,5 @@ class PromptConsumer(AsyncWebsocketConsumer): @database_sync_to_async def get_history_for_llm(self): - messages = self.conversation.messages.all() + messages = self.conversation.messages.filter(source="conversation") return [{"role": m.role, "content": m.content} for m in messages] diff --git a/prompt/migrations/0003_add_message_source.py b/prompt/migrations/0003_add_message_source.py new file mode 100644 index 0000000..b8c547d --- /dev/null +++ b/prompt/migrations/0003_add_message_source.py @@ -0,0 +1,18 @@ +# Generated by Django 6.0.1 on 2026-03-31 07:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('prompt', '0002_message_prompt_level'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='source', + field=models.CharField(default='ai', max_length=10), + ), + ] diff --git a/prompt/migrations/0004_update_message_source_default.py b/prompt/migrations/0004_update_message_source_default.py new file mode 100644 index 0000000..e1f38ed --- /dev/null +++ b/prompt/migrations/0004_update_message_source_default.py @@ -0,0 +1,18 @@ +# Generated by Django 6.0.1 on 2026-03-31 07:24 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('prompt', '0003_add_message_source'), + ] + + operations = [ + migrations.AlterField( + model_name='message', + name='source', + field=models.CharField(default='conversation', max_length=12), + ), + ] diff --git a/prompt/models.py b/prompt/models.py index 61f1aa8..99e13a5 100644 --- a/prompt/models.py +++ b/prompt/models.py @@ -23,6 +23,7 @@ class Message(models.Model): Conversation, on_delete=models.CASCADE, related_name="messages" ) role = models.CharField(max_length=10) # "user" or "assistant" + source = models.CharField(max_length=12, default="conversation") # "conversation" or "manual" content = models.TextField() code_html = models.TextField(null=True, blank=True) code_css = models.TextField(null=True, blank=True) diff --git a/submission/api.py b/submission/api.py index 1822108..e48015d 100644 --- a/submission/api.py +++ b/submission/api.py @@ -36,28 +36,43 @@ def create_submission(request, payload: SubmissionIn): 创建一个新的提交 """ task = get_object_or_404(Task, id=payload.task_id) - conversation = None - if payload.conversation_id: - from prompt.models import Conversation - conversation = get_object_or_404( - Conversation, id=payload.conversation_id, user=request.user - ) - conversation.is_active = False - conversation.save(update_fields=["is_active"]) - sub = Submission.objects.create( + if payload.prompt: + from prompt.models import Conversation, Message + from django.db.models import Count as _Count + conversation = ( + Conversation.objects.filter(user=request.user, task=task) + .annotate(msg_count=_Count("messages")) + .order_by("-msg_count", "-created") + .first() + ) + if not conversation: + conversation = Conversation.objects.create( + user=request.user, task=task, is_active=False + ) + Message.objects.create( + conversation=conversation, role="user", content=payload.prompt, source="manual" + ) + Message.objects.create( + conversation=conversation, + role="assistant", + content="", + code_html=payload.html, + code_css=payload.css, + code_js=payload.js, + source="manual", + ) + from .classifier import classify_conversation_messages + threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() + + Submission.objects.create( user=request.user, task=task, html=payload.html, css=payload.css, js=payload.js, - conversation=conversation, ) - if conversation: - from .classifier import classify_conversation_messages - threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start() - @router.get("/", response=List[SubmissionOut]) @paginate @@ -69,7 +84,6 @@ def list_submissions(request, filters: SubmissionFilter = Query(...)): submissions = ( Submission.objects.select_related("task", "user") .defer("html", "css", "js") - .exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True) ) if filters.task_id: @@ -143,7 +157,6 @@ def list_by_user_task(request, user_id: int, task_id: int): ) return ( Submission.objects.filter(user_id=user_id, task_id=task_id) - .exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True) .select_related("task", "user") .defer("html", "css", "js") .annotate(my_score=user_rating_subquery) diff --git a/submission/migrations/0010_remove_conversation_fk.py b/submission/migrations/0010_remove_conversation_fk.py new file mode 100644 index 0000000..976d72a --- /dev/null +++ b/submission/migrations/0010_remove_conversation_fk.py @@ -0,0 +1,17 @@ +# Generated by Django 6.0.1 on 2026-03-31 06:52 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('submission', '0009_add_view_count'), + ] + + operations = [ + migrations.RemoveField( + model_name='submission', + name='conversation', + ), + ] diff --git a/submission/models.py b/submission/models.py index 72041f5..ab328b5 100644 --- a/submission/models.py +++ b/submission/models.py @@ -9,7 +9,6 @@ from django.dispatch import receiver # 导入receiver from account.models import RoleChoices, User from task.models import Task -from prompt.models import Conversation class FlagChoices(models.TextChoices): @@ -41,10 +40,6 @@ class Submission(TimeStampedModel): html = models.TextField(null=True, blank=True, verbose_name="HTML代码") css = models.TextField(null=True, blank=True, verbose_name="CSS代码") js = models.TextField(null=True, blank=True, verbose_name="JS代码") - conversation = models.ForeignKey( - Conversation, on_delete=models.SET_NULL, null=True, blank=True, - related_name="submissions", verbose_name="对话" - ) flag = models.CharField( max_length=10, choices=FlagChoices.choices, diff --git a/submission/schemas.py b/submission/schemas.py index f4a86f9..b5d79ca 100644 --- a/submission/schemas.py +++ b/submission/schemas.py @@ -8,7 +8,7 @@ class SubmissionIn(Schema): html: Optional[str] = None css: Optional[str] = None js: Optional[str] = None - conversation_id: Optional[UUID] = None + prompt: Optional[str] = None class SubmissionOut(Schema): @@ -24,7 +24,6 @@ class SubmissionOut(Schema): html: Optional[str] = None css: Optional[str] = None js: Optional[str] = None - conversation_id: Optional[UUID] = None flag: Optional[str] = None zone: Optional[str] = None submit_count: int = 0 @@ -87,7 +86,6 @@ class SubmissionOut(Schema): "html": submission.html, "css": submission.css, "js": submission.js, - "conversation_id": submission.conversation_id, "flag": submission.flag, "created": submission.created.isoformat(), "modified": submission.modified.isoformat(),