fix
This commit is contained in:
@@ -13,6 +13,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.task_id = int(self.scope["url_route"]["kwargs"]["task_id"])
|
self.task_id = int(self.scope["url_route"]["kwargs"]["task_id"])
|
||||||
|
self.current_user_message = None
|
||||||
await self.accept()
|
await self.accept()
|
||||||
|
|
||||||
# Load or create conversation, send history
|
# Load or create conversation, send history
|
||||||
@@ -25,7 +26,9 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
async def disconnect(self, close_code):
|
async def disconnect(self, close_code):
|
||||||
pass
|
if self.current_user_message:
|
||||||
|
await self.delete_message(self.current_user_message)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
async def receive(self, text_data):
|
async def receive(self, text_data):
|
||||||
data = json.loads(text_data)
|
data = json.loads(text_data)
|
||||||
@@ -45,8 +48,9 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Save user message
|
# Save user message
|
||||||
await self.save_message("user", prompt)
|
self.current_user_message = await self.save_message("user", prompt)
|
||||||
|
|
||||||
|
try:
|
||||||
# Build history for LLM
|
# Build history for LLM
|
||||||
history = await self.get_history_for_llm()
|
history = await self.get_history_for_llm()
|
||||||
task_content = await self.get_task_content()
|
task_content = await self.get_task_content()
|
||||||
@@ -70,6 +74,7 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
# Extract code and save assistant message
|
# Extract code and save assistant message
|
||||||
code = extract_code(full_response)
|
code = extract_code(full_response)
|
||||||
await self.save_message("assistant", full_response, code)
|
await self.save_message("assistant", full_response, code)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
# Send completion with extracted code
|
# Send completion with extracted code
|
||||||
await self.send(text_data=json.dumps({
|
await self.send(text_data=json.dumps({
|
||||||
@@ -77,6 +82,11 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
"code": code,
|
"code": code,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if self.current_user_message:
|
||||||
|
await self.delete_message(self.current_user_message)
|
||||||
|
self.current_user_message = None
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def get_or_create_conversation(self):
|
def get_or_create_conversation(self):
|
||||||
conv = Conversation.objects.filter(
|
conv = Conversation.objects.filter(
|
||||||
@@ -93,6 +103,10 @@ class PromptConsumer(AsyncWebsocketConsumer):
|
|||||||
).update(is_active=False)
|
).update(is_active=False)
|
||||||
return Conversation.objects.create(user=self.user, task_id=self.task_id)
|
return Conversation.objects.create(user=self.user, task_id=self.task_id)
|
||||||
|
|
||||||
|
@database_sync_to_async
|
||||||
|
def delete_message(self, message):
|
||||||
|
message.delete()
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def save_message(self, role, content, code=None):
|
def save_message(self, role, content, code=None):
|
||||||
return Message.objects.create(
|
return Message.objects.create(
|
||||||
|
|||||||
Reference in New Issue
Block a user