diff --git a/prompt/consumers.py b/prompt/consumers.py index be66dc0..760ae9d 100644 --- a/prompt/consumers.py +++ b/prompt/consumers.py @@ -13,6 +13,7 @@ class PromptConsumer(AsyncWebsocketConsumer): return self.task_id = int(self.scope["url_route"]["kwargs"]["task_id"]) + self.current_user_message = None await self.accept() # Load or create conversation, send history @@ -25,7 +26,9 @@ class PromptConsumer(AsyncWebsocketConsumer): })) async def disconnect(self, close_code): - pass + if self.current_user_message: + await self.delete_message(self.current_user_message) + self.current_user_message = None async def receive(self, text_data): data = json.loads(text_data) @@ -45,37 +48,44 @@ class PromptConsumer(AsyncWebsocketConsumer): return # Save user message - await self.save_message("user", prompt) + self.current_user_message = await self.save_message("user", prompt) - # Build history for LLM - history = await self.get_history_for_llm() - task_content = await self.get_task_content() - - # Stream AI response - full_response = "" try: - async for chunk in stream_chat(task_content, history): - full_response += chunk + # Build history for LLM + history = await self.get_history_for_llm() + task_content = await self.get_task_content() + + # Stream AI response + full_response = "" + try: + async for chunk in stream_chat(task_content, history): + full_response += chunk + await self.send(text_data=json.dumps({ + "type": "stream", + "content": chunk, + })) + except Exception as e: await self.send(text_data=json.dumps({ - "type": "stream", - "content": chunk, + "type": "error", + "content": f"AI 服务出错:{str(e)}", })) - except Exception as e: + return + + # Extract code and save assistant message + code = extract_code(full_response) + await self.save_message("assistant", full_response, code) + self.current_user_message = None + + # Send completion with extracted code await self.send(text_data=json.dumps({ - "type": "error", - "content": f"AI 服务出错:{str(e)}", + "type": "complete", + "code": code, })) - return - # Extract code and save assistant message - code = extract_code(full_response) - await self.save_message("assistant", full_response, code) - - # Send completion with extracted code - await self.send(text_data=json.dumps({ - "type": "complete", - "code": code, - })) + finally: + if self.current_user_message: + await self.delete_message(self.current_user_message) + self.current_user_message = None @database_sync_to_async def get_or_create_conversation(self): @@ -93,6 +103,10 @@ class PromptConsumer(AsyncWebsocketConsumer): ).update(is_active=False) return Conversation.objects.create(user=self.user, task_id=self.task_id) + @database_sync_to_async + def delete_message(self, message): + message.delete() + @database_sync_to_async def save_message(self, role, content, code=None): return Message.objects.create(