This commit is contained in:
2026-03-31 02:21:44 -06:00
parent 986b48f1de
commit 9c76c98da8
9 changed files with 98 additions and 43 deletions

View File

@@ -6,6 +6,7 @@ from ninja.errors import HttpError
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.db.models import Count
from .models import Conversation, Message from .models import Conversation, Message
from .schemas import ConversationOut, MessageOut from .schemas import ConversationOut, MessageOut
from account.models import RoleChoices from account.models import RoleChoices
@@ -16,7 +17,9 @@ router = Router()
@router.get("/conversations/", response=List[ConversationOut]) @router.get("/conversations/", response=List[ConversationOut])
@login_required @login_required
def list_conversations(request, task_id: int = None, user_id: int = None): def list_conversations(request, task_id: int = None, user_id: int = None):
convs = Conversation.objects.select_related("user", "task") convs = Conversation.objects.select_related("user", "task").annotate(
msg_count=Count("messages")
)
# Normal users can only see their own # Normal users can only see their own
if request.user.role == "normal": if request.user.role == "normal":
convs = convs.filter(user=request.user) convs = convs.filter(user=request.user)
@@ -24,6 +27,7 @@ def list_conversations(request, task_id: int = None, user_id: int = None):
convs = convs.filter(user_id=user_id) convs = convs.filter(user_id=user_id)
if task_id: if task_id:
convs = convs.filter(task_id=task_id) convs = convs.filter(task_id=task_id)
convs = convs.order_by("-msg_count", "-created")
return [ConversationOut.from_conv(c) for c in convs] return [ConversationOut.from_conv(c) for c in convs]

View File

@@ -1,6 +1,7 @@
import json import json
from channels.generic.websocket import AsyncWebsocketConsumer from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async from channels.db import database_sync_to_async
from django.db.models import Count
from .models import Conversation, Message from .models import Conversation, Message
from .llm import stream_chat, extract_code from .llm import stream_chat, extract_code
@@ -35,12 +36,6 @@ class PromptConsumer(AsyncWebsocketConsumer):
msg_type = data.get("type", "message") msg_type = data.get("type", "message")
if msg_type == "new_conversation": if msg_type == "new_conversation":
self.conversation = await self.create_conversation()
await self.send(text_data=json.dumps({
"type": "init",
"conversation_id": str(self.conversation.id),
"messages": [],
}))
return return
prompt = data.get("content", "").strip() prompt = data.get("content", "").strip()
@@ -88,20 +83,16 @@ class PromptConsumer(AsyncWebsocketConsumer):
@database_sync_to_async @database_sync_to_async
def get_or_create_conversation(self): def get_or_create_conversation(self):
conv = Conversation.objects.filter( conv = (
user=self.user, task_id=self.task_id, is_active=True Conversation.objects.filter(user=self.user, task_id=self.task_id)
).first() .annotate(msg_count=Count("messages"))
.order_by("-msg_count", "-created")
.first()
)
if not conv: if not conv:
conv = Conversation.objects.create(user=self.user, task_id=self.task_id) conv = Conversation.objects.create(user=self.user, task_id=self.task_id)
return conv return conv
@database_sync_to_async
def create_conversation(self):
Conversation.objects.filter(
user=self.user, task_id=self.task_id, is_active=True
).update(is_active=False)
return Conversation.objects.create(user=self.user, task_id=self.task_id)
@database_sync_to_async @database_sync_to_async
def delete_message(self, message): def delete_message(self, message):
message.delete() message.delete()
@@ -119,7 +110,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
@database_sync_to_async @database_sync_to_async
def get_history(self): def get_history(self):
messages = self.conversation.messages.all() messages = self.conversation.messages.filter(source="conversation")
return [ return [
{ {
"role": m.role, "role": m.role,
@@ -136,5 +127,5 @@ class PromptConsumer(AsyncWebsocketConsumer):
@database_sync_to_async @database_sync_to_async
def get_history_for_llm(self): def get_history_for_llm(self):
messages = self.conversation.messages.all() messages = self.conversation.messages.filter(source="conversation")
return [{"role": m.role, "content": m.content} for m in messages] return [{"role": m.role, "content": m.content} for m in messages]

View File

@@ -0,0 +1,18 @@
# Generated by Django 6.0.1 on 2026-03-31 07:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('prompt', '0002_message_prompt_level'),
]
operations = [
migrations.AddField(
model_name='message',
name='source',
field=models.CharField(default='ai', max_length=10),
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 6.0.1 on 2026-03-31 07:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('prompt', '0003_add_message_source'),
]
operations = [
migrations.AlterField(
model_name='message',
name='source',
field=models.CharField(default='conversation', max_length=12),
),
]

View File

@@ -23,6 +23,7 @@ class Message(models.Model):
Conversation, on_delete=models.CASCADE, related_name="messages" Conversation, on_delete=models.CASCADE, related_name="messages"
) )
role = models.CharField(max_length=10) # "user" or "assistant" role = models.CharField(max_length=10) # "user" or "assistant"
source = models.CharField(max_length=12, default="conversation") # "conversation" or "manual"
content = models.TextField() content = models.TextField()
code_html = models.TextField(null=True, blank=True) code_html = models.TextField(null=True, blank=True)
code_css = models.TextField(null=True, blank=True) code_css = models.TextField(null=True, blank=True)

View File

@@ -36,28 +36,43 @@ def create_submission(request, payload: SubmissionIn):
创建一个新的提交 创建一个新的提交
""" """
task = get_object_or_404(Task, id=payload.task_id) task = get_object_or_404(Task, id=payload.task_id)
conversation = None
if payload.conversation_id:
from prompt.models import Conversation
conversation = get_object_or_404(
Conversation, id=payload.conversation_id, user=request.user
)
conversation.is_active = False
conversation.save(update_fields=["is_active"])
sub = Submission.objects.create( if payload.prompt:
from prompt.models import Conversation, Message
from django.db.models import Count as _Count
conversation = (
Conversation.objects.filter(user=request.user, task=task)
.annotate(msg_count=_Count("messages"))
.order_by("-msg_count", "-created")
.first()
)
if not conversation:
conversation = Conversation.objects.create(
user=request.user, task=task, is_active=False
)
Message.objects.create(
conversation=conversation, role="user", content=payload.prompt, source="manual"
)
Message.objects.create(
conversation=conversation,
role="assistant",
content="",
code_html=payload.html,
code_css=payload.css,
code_js=payload.js,
source="manual",
)
from .classifier import classify_conversation_messages
threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start()
Submission.objects.create(
user=request.user, user=request.user,
task=task, task=task,
html=payload.html, html=payload.html,
css=payload.css, css=payload.css,
js=payload.js, js=payload.js,
conversation=conversation,
) )
if conversation:
from .classifier import classify_conversation_messages
threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start()
@router.get("/", response=List[SubmissionOut]) @router.get("/", response=List[SubmissionOut])
@paginate @paginate
@@ -69,7 +84,6 @@ def list_submissions(request, filters: SubmissionFilter = Query(...)):
submissions = ( submissions = (
Submission.objects.select_related("task", "user") Submission.objects.select_related("task", "user")
.defer("html", "css", "js") .defer("html", "css", "js")
.exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True)
) )
if filters.task_id: if filters.task_id:
@@ -143,7 +157,6 @@ def list_by_user_task(request, user_id: int, task_id: int):
) )
return ( return (
Submission.objects.filter(user_id=user_id, task_id=task_id) Submission.objects.filter(user_id=user_id, task_id=task_id)
.exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True)
.select_related("task", "user") .select_related("task", "user")
.defer("html", "css", "js") .defer("html", "css", "js")
.annotate(my_score=user_rating_subquery) .annotate(my_score=user_rating_subquery)

View File

@@ -0,0 +1,17 @@
# Generated by Django 6.0.1 on 2026-03-31 06:52
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('submission', '0009_add_view_count'),
]
operations = [
migrations.RemoveField(
model_name='submission',
name='conversation',
),
]

View File

@@ -9,7 +9,6 @@ from django.dispatch import receiver # 导入receiver
from account.models import RoleChoices, User from account.models import RoleChoices, User
from task.models import Task from task.models import Task
from prompt.models import Conversation
class FlagChoices(models.TextChoices): class FlagChoices(models.TextChoices):
@@ -41,10 +40,6 @@ class Submission(TimeStampedModel):
html = models.TextField(null=True, blank=True, verbose_name="HTML代码") html = models.TextField(null=True, blank=True, verbose_name="HTML代码")
css = models.TextField(null=True, blank=True, verbose_name="CSS代码") css = models.TextField(null=True, blank=True, verbose_name="CSS代码")
js = models.TextField(null=True, blank=True, verbose_name="JS代码") js = models.TextField(null=True, blank=True, verbose_name="JS代码")
conversation = models.ForeignKey(
Conversation, on_delete=models.SET_NULL, null=True, blank=True,
related_name="submissions", verbose_name="对话"
)
flag = models.CharField( flag = models.CharField(
max_length=10, max_length=10,
choices=FlagChoices.choices, choices=FlagChoices.choices,

View File

@@ -8,7 +8,7 @@ class SubmissionIn(Schema):
html: Optional[str] = None html: Optional[str] = None
css: Optional[str] = None css: Optional[str] = None
js: Optional[str] = None js: Optional[str] = None
conversation_id: Optional[UUID] = None prompt: Optional[str] = None
class SubmissionOut(Schema): class SubmissionOut(Schema):
@@ -24,7 +24,6 @@ class SubmissionOut(Schema):
html: Optional[str] = None html: Optional[str] = None
css: Optional[str] = None css: Optional[str] = None
js: Optional[str] = None js: Optional[str] = None
conversation_id: Optional[UUID] = None
flag: Optional[str] = None flag: Optional[str] = None
zone: Optional[str] = None zone: Optional[str] = None
submit_count: int = 0 submit_count: int = 0
@@ -87,7 +86,6 @@ class SubmissionOut(Schema):
"html": submission.html, "html": submission.html,
"css": submission.css, "css": submission.css,
"js": submission.js, "js": submission.js,
"conversation_id": submission.conversation_id,
"flag": submission.flag, "flag": submission.flag,
"created": submission.created.isoformat(), "created": submission.created.isoformat(),
"modified": submission.modified.isoformat(), "modified": submission.modified.isoformat(),