fix
This commit is contained in:
@@ -6,6 +6,7 @@ from ninja.errors import HttpError
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.contrib.auth.decorators import login_required
|
||||
|
||||
from django.db.models import Count
|
||||
from .models import Conversation, Message
|
||||
from .schemas import ConversationOut, MessageOut
|
||||
from account.models import RoleChoices
|
||||
@@ -16,7 +17,9 @@ router = Router()
|
||||
@router.get("/conversations/", response=List[ConversationOut])
|
||||
@login_required
|
||||
def list_conversations(request, task_id: int = None, user_id: int = None):
|
||||
convs = Conversation.objects.select_related("user", "task")
|
||||
convs = Conversation.objects.select_related("user", "task").annotate(
|
||||
msg_count=Count("messages")
|
||||
)
|
||||
# Normal users can only see their own
|
||||
if request.user.role == "normal":
|
||||
convs = convs.filter(user=request.user)
|
||||
@@ -24,6 +27,7 @@ def list_conversations(request, task_id: int = None, user_id: int = None):
|
||||
convs = convs.filter(user_id=user_id)
|
||||
if task_id:
|
||||
convs = convs.filter(task_id=task_id)
|
||||
convs = convs.order_by("-msg_count", "-created")
|
||||
return [ConversationOut.from_conv(c) for c in convs]
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from channels.db import database_sync_to_async
|
||||
from django.db.models import Count
|
||||
from .models import Conversation, Message
|
||||
from .llm import stream_chat, extract_code
|
||||
|
||||
@@ -35,12 +36,6 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
||||
msg_type = data.get("type", "message")
|
||||
|
||||
if msg_type == "new_conversation":
|
||||
self.conversation = await self.create_conversation()
|
||||
await self.send(text_data=json.dumps({
|
||||
"type": "init",
|
||||
"conversation_id": str(self.conversation.id),
|
||||
"messages": [],
|
||||
}))
|
||||
return
|
||||
|
||||
prompt = data.get("content", "").strip()
|
||||
@@ -88,20 +83,16 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
||||
|
||||
@database_sync_to_async
|
||||
def get_or_create_conversation(self):
|
||||
conv = Conversation.objects.filter(
|
||||
user=self.user, task_id=self.task_id, is_active=True
|
||||
).first()
|
||||
conv = (
|
||||
Conversation.objects.filter(user=self.user, task_id=self.task_id)
|
||||
.annotate(msg_count=Count("messages"))
|
||||
.order_by("-msg_count", "-created")
|
||||
.first()
|
||||
)
|
||||
if not conv:
|
||||
conv = Conversation.objects.create(user=self.user, task_id=self.task_id)
|
||||
return conv
|
||||
|
||||
@database_sync_to_async
|
||||
def create_conversation(self):
|
||||
Conversation.objects.filter(
|
||||
user=self.user, task_id=self.task_id, is_active=True
|
||||
).update(is_active=False)
|
||||
return Conversation.objects.create(user=self.user, task_id=self.task_id)
|
||||
|
||||
@database_sync_to_async
|
||||
def delete_message(self, message):
|
||||
message.delete()
|
||||
@@ -119,7 +110,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
||||
|
||||
@database_sync_to_async
|
||||
def get_history(self):
|
||||
messages = self.conversation.messages.all()
|
||||
messages = self.conversation.messages.filter(source="conversation")
|
||||
return [
|
||||
{
|
||||
"role": m.role,
|
||||
@@ -136,5 +127,5 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
||||
|
||||
@database_sync_to_async
|
||||
def get_history_for_llm(self):
|
||||
messages = self.conversation.messages.all()
|
||||
messages = self.conversation.messages.filter(source="conversation")
|
||||
return [{"role": m.role, "content": m.content} for m in messages]
|
||||
|
||||
18
prompt/migrations/0003_add_message_source.py
Normal file
18
prompt/migrations/0003_add_message_source.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 6.0.1 on 2026-03-31 07:10
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('prompt', '0002_message_prompt_level'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='message',
|
||||
name='source',
|
||||
field=models.CharField(default='ai', max_length=10),
|
||||
),
|
||||
]
|
||||
18
prompt/migrations/0004_update_message_source_default.py
Normal file
18
prompt/migrations/0004_update_message_source_default.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Generated by Django 6.0.1 on 2026-03-31 07:24
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('prompt', '0003_add_message_source'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='message',
|
||||
name='source',
|
||||
field=models.CharField(default='conversation', max_length=12),
|
||||
),
|
||||
]
|
||||
@@ -23,6 +23,7 @@ class Message(models.Model):
|
||||
Conversation, on_delete=models.CASCADE, related_name="messages"
|
||||
)
|
||||
role = models.CharField(max_length=10) # "user" or "assistant"
|
||||
source = models.CharField(max_length=12, default="conversation") # "conversation" or "manual"
|
||||
content = models.TextField()
|
||||
code_html = models.TextField(null=True, blank=True)
|
||||
code_css = models.TextField(null=True, blank=True)
|
||||
|
||||
Reference in New Issue
Block a user