fix
This commit is contained in:
@@ -6,6 +6,7 @@ from ninja.errors import HttpError
|
|||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.contrib.auth.decorators import login_required
|
from django.contrib.auth.decorators import login_required
|
||||||
|
|
||||||
|
from django.db.models import Count
|
||||||
from .models import Conversation, Message
|
from .models import Conversation, Message
|
||||||
from .schemas import ConversationOut, MessageOut
|
from .schemas import ConversationOut, MessageOut
|
||||||
from account.models import RoleChoices
|
from account.models import RoleChoices
|
||||||
@@ -16,7 +17,9 @@ router = Router()
|
|||||||
@router.get("/conversations/", response=List[ConversationOut])
|
@router.get("/conversations/", response=List[ConversationOut])
|
||||||
@login_required
|
@login_required
|
||||||
def list_conversations(request, task_id: int = None, user_id: int = None):
|
def list_conversations(request, task_id: int = None, user_id: int = None):
|
||||||
convs = Conversation.objects.select_related("user", "task")
|
convs = Conversation.objects.select_related("user", "task").annotate(
|
||||||
|
msg_count=Count("messages")
|
||||||
|
)
|
||||||
# Normal users can only see their own
|
# Normal users can only see their own
|
||||||
if request.user.role == "normal":
|
if request.user.role == "normal":
|
||||||
convs = convs.filter(user=request.user)
|
convs = convs.filter(user=request.user)
|
||||||
@@ -24,6 +27,7 @@ def list_conversations(request, task_id: int = None, user_id: int = None):
|
|||||||
convs = convs.filter(user_id=user_id)
|
convs = convs.filter(user_id=user_id)
|
||||||
if task_id:
|
if task_id:
|
||||||
convs = convs.filter(task_id=task_id)
|
convs = convs.filter(task_id=task_id)
|
||||||
|
convs = convs.order_by("-msg_count", "-created")
|
||||||
return [ConversationOut.from_conv(c) for c in convs]
|
return [ConversationOut.from_conv(c) for c in convs]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||||
from channels.db import database_sync_to_async
|
from channels.db import database_sync_to_async
|
||||||
|
from django.db.models import Count
|
||||||
from .models import Conversation, Message
|
from .models import Conversation, Message
|
||||||
from .llm import stream_chat, extract_code
|
from .llm import stream_chat, extract_code
|
||||||
|
|
||||||
@@ -35,12 +36,6 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
msg_type = data.get("type", "message")
|
msg_type = data.get("type", "message")
|
||||||
|
|
||||||
if msg_type == "new_conversation":
|
if msg_type == "new_conversation":
|
||||||
self.conversation = await self.create_conversation()
|
|
||||||
await self.send(text_data=json.dumps({
|
|
||||||
"type": "init",
|
|
||||||
"conversation_id": str(self.conversation.id),
|
|
||||||
"messages": [],
|
|
||||||
}))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt = data.get("content", "").strip()
|
prompt = data.get("content", "").strip()
|
||||||
@@ -88,20 +83,16 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def get_or_create_conversation(self):
|
def get_or_create_conversation(self):
|
||||||
conv = Conversation.objects.filter(
|
conv = (
|
||||||
user=self.user, task_id=self.task_id, is_active=True
|
Conversation.objects.filter(user=self.user, task_id=self.task_id)
|
||||||
).first()
|
.annotate(msg_count=Count("messages"))
|
||||||
|
.order_by("-msg_count", "-created")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not conv:
|
if not conv:
|
||||||
conv = Conversation.objects.create(user=self.user, task_id=self.task_id)
|
conv = Conversation.objects.create(user=self.user, task_id=self.task_id)
|
||||||
return conv
|
return conv
|
||||||
|
|
||||||
@database_sync_to_async
|
|
||||||
def create_conversation(self):
|
|
||||||
Conversation.objects.filter(
|
|
||||||
user=self.user, task_id=self.task_id, is_active=True
|
|
||||||
).update(is_active=False)
|
|
||||||
return Conversation.objects.create(user=self.user, task_id=self.task_id)
|
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def delete_message(self, message):
|
def delete_message(self, message):
|
||||||
message.delete()
|
message.delete()
|
||||||
@@ -119,7 +110,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def get_history(self):
|
def get_history(self):
|
||||||
messages = self.conversation.messages.all()
|
messages = self.conversation.messages.filter(source="conversation")
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"role": m.role,
|
"role": m.role,
|
||||||
@@ -136,5 +127,5 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def get_history_for_llm(self):
|
def get_history_for_llm(self):
|
||||||
messages = self.conversation.messages.all()
|
messages = self.conversation.messages.filter(source="conversation")
|
||||||
return [{"role": m.role, "content": m.content} for m in messages]
|
return [{"role": m.role, "content": m.content} for m in messages]
|
||||||
|
|||||||
18
prompt/migrations/0003_add_message_source.py
Normal file
18
prompt/migrations/0003_add_message_source.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 6.0.1 on 2026-03-31 07:10
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('prompt', '0002_message_prompt_level'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='message',
|
||||||
|
name='source',
|
||||||
|
field=models.CharField(default='ai', max_length=10),
|
||||||
|
),
|
||||||
|
]
|
||||||
18
prompt/migrations/0004_update_message_source_default.py
Normal file
18
prompt/migrations/0004_update_message_source_default.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 6.0.1 on 2026-03-31 07:24
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('prompt', '0003_add_message_source'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='message',
|
||||||
|
name='source',
|
||||||
|
field=models.CharField(default='conversation', max_length=12),
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -23,6 +23,7 @@ class Message(models.Model):
|
|||||||
Conversation, on_delete=models.CASCADE, related_name="messages"
|
Conversation, on_delete=models.CASCADE, related_name="messages"
|
||||||
)
|
)
|
||||||
role = models.CharField(max_length=10) # "user" or "assistant"
|
role = models.CharField(max_length=10) # "user" or "assistant"
|
||||||
|
source = models.CharField(max_length=12, default="conversation") # "conversation" or "manual"
|
||||||
content = models.TextField()
|
content = models.TextField()
|
||||||
code_html = models.TextField(null=True, blank=True)
|
code_html = models.TextField(null=True, blank=True)
|
||||||
code_css = models.TextField(null=True, blank=True)
|
code_css = models.TextField(null=True, blank=True)
|
||||||
|
|||||||
@@ -36,28 +36,43 @@ def create_submission(request, payload: SubmissionIn):
|
|||||||
创建一个新的提交
|
创建一个新的提交
|
||||||
"""
|
"""
|
||||||
task = get_object_or_404(Task, id=payload.task_id)
|
task = get_object_or_404(Task, id=payload.task_id)
|
||||||
conversation = None
|
|
||||||
if payload.conversation_id:
|
|
||||||
from prompt.models import Conversation
|
|
||||||
conversation = get_object_or_404(
|
|
||||||
Conversation, id=payload.conversation_id, user=request.user
|
|
||||||
)
|
|
||||||
conversation.is_active = False
|
|
||||||
conversation.save(update_fields=["is_active"])
|
|
||||||
|
|
||||||
sub = Submission.objects.create(
|
if payload.prompt:
|
||||||
|
from prompt.models import Conversation, Message
|
||||||
|
from django.db.models import Count as _Count
|
||||||
|
conversation = (
|
||||||
|
Conversation.objects.filter(user=request.user, task=task)
|
||||||
|
.annotate(msg_count=_Count("messages"))
|
||||||
|
.order_by("-msg_count", "-created")
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if not conversation:
|
||||||
|
conversation = Conversation.objects.create(
|
||||||
|
user=request.user, task=task, is_active=False
|
||||||
|
)
|
||||||
|
Message.objects.create(
|
||||||
|
conversation=conversation, role="user", content=payload.prompt, source="manual"
|
||||||
|
)
|
||||||
|
Message.objects.create(
|
||||||
|
conversation=conversation,
|
||||||
|
role="assistant",
|
||||||
|
content="",
|
||||||
|
code_html=payload.html,
|
||||||
|
code_css=payload.css,
|
||||||
|
code_js=payload.js,
|
||||||
|
source="manual",
|
||||||
|
)
|
||||||
|
from .classifier import classify_conversation_messages
|
||||||
|
threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start()
|
||||||
|
|
||||||
|
Submission.objects.create(
|
||||||
user=request.user,
|
user=request.user,
|
||||||
task=task,
|
task=task,
|
||||||
html=payload.html,
|
html=payload.html,
|
||||||
css=payload.css,
|
css=payload.css,
|
||||||
js=payload.js,
|
js=payload.js,
|
||||||
conversation=conversation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation:
|
|
||||||
from .classifier import classify_conversation_messages
|
|
||||||
threading.Thread(target=classify_conversation_messages, args=(conversation.id,), daemon=True).start()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response=List[SubmissionOut])
|
@router.get("/", response=List[SubmissionOut])
|
||||||
@paginate
|
@paginate
|
||||||
@@ -69,7 +84,6 @@ def list_submissions(request, filters: SubmissionFilter = Query(...)):
|
|||||||
submissions = (
|
submissions = (
|
||||||
Submission.objects.select_related("task", "user")
|
Submission.objects.select_related("task", "user")
|
||||||
.defer("html", "css", "js")
|
.defer("html", "css", "js")
|
||||||
.exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if filters.task_id:
|
if filters.task_id:
|
||||||
@@ -143,7 +157,6 @@ def list_by_user_task(request, user_id: int, task_id: int):
|
|||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
Submission.objects.filter(user_id=user_id, task_id=task_id)
|
Submission.objects.filter(user_id=user_id, task_id=task_id)
|
||||||
.exclude(conversation__isnull=False, html__isnull=True, css__isnull=True, js__isnull=True)
|
|
||||||
.select_related("task", "user")
|
.select_related("task", "user")
|
||||||
.defer("html", "css", "js")
|
.defer("html", "css", "js")
|
||||||
.annotate(my_score=user_rating_subquery)
|
.annotate(my_score=user_rating_subquery)
|
||||||
|
|||||||
17
submission/migrations/0010_remove_conversation_fk.py
Normal file
17
submission/migrations/0010_remove_conversation_fk.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Generated by Django 6.0.1 on 2026-03-31 06:52
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('submission', '0009_add_view_count'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='submission',
|
||||||
|
name='conversation',
|
||||||
|
),
|
||||||
|
]
|
||||||
@@ -9,7 +9,6 @@ from django.dispatch import receiver # 导入receiver
|
|||||||
|
|
||||||
from account.models import RoleChoices, User
|
from account.models import RoleChoices, User
|
||||||
from task.models import Task
|
from task.models import Task
|
||||||
from prompt.models import Conversation
|
|
||||||
|
|
||||||
|
|
||||||
class FlagChoices(models.TextChoices):
|
class FlagChoices(models.TextChoices):
|
||||||
@@ -41,10 +40,6 @@ class Submission(TimeStampedModel):
|
|||||||
html = models.TextField(null=True, blank=True, verbose_name="HTML代码")
|
html = models.TextField(null=True, blank=True, verbose_name="HTML代码")
|
||||||
css = models.TextField(null=True, blank=True, verbose_name="CSS代码")
|
css = models.TextField(null=True, blank=True, verbose_name="CSS代码")
|
||||||
js = models.TextField(null=True, blank=True, verbose_name="JS代码")
|
js = models.TextField(null=True, blank=True, verbose_name="JS代码")
|
||||||
conversation = models.ForeignKey(
|
|
||||||
Conversation, on_delete=models.SET_NULL, null=True, blank=True,
|
|
||||||
related_name="submissions", verbose_name="对话"
|
|
||||||
)
|
|
||||||
flag = models.CharField(
|
flag = models.CharField(
|
||||||
max_length=10,
|
max_length=10,
|
||||||
choices=FlagChoices.choices,
|
choices=FlagChoices.choices,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class SubmissionIn(Schema):
|
|||||||
html: Optional[str] = None
|
html: Optional[str] = None
|
||||||
css: Optional[str] = None
|
css: Optional[str] = None
|
||||||
js: Optional[str] = None
|
js: Optional[str] = None
|
||||||
conversation_id: Optional[UUID] = None
|
prompt: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class SubmissionOut(Schema):
|
class SubmissionOut(Schema):
|
||||||
@@ -24,7 +24,6 @@ class SubmissionOut(Schema):
|
|||||||
html: Optional[str] = None
|
html: Optional[str] = None
|
||||||
css: Optional[str] = None
|
css: Optional[str] = None
|
||||||
js: Optional[str] = None
|
js: Optional[str] = None
|
||||||
conversation_id: Optional[UUID] = None
|
|
||||||
flag: Optional[str] = None
|
flag: Optional[str] = None
|
||||||
zone: Optional[str] = None
|
zone: Optional[str] = None
|
||||||
submit_count: int = 0
|
submit_count: int = 0
|
||||||
@@ -87,7 +86,6 @@ class SubmissionOut(Schema):
|
|||||||
"html": submission.html,
|
"html": submission.html,
|
||||||
"css": submission.css,
|
"css": submission.css,
|
||||||
"js": submission.js,
|
"js": submission.js,
|
||||||
"conversation_id": submission.conversation_id,
|
|
||||||
"flag": submission.flag,
|
"flag": submission.flag,
|
||||||
"created": submission.created.isoformat(),
|
"created": submission.created.isoformat(),
|
||||||
"modified": submission.modified.isoformat(),
|
"modified": submission.modified.isoformat(),
|
||||||
|
|||||||
Reference in New Issue
Block a user