fix chain modal
This commit is contained in:
@@ -489,6 +489,43 @@ def get_showcase_detail(request, submission_id: UUID):
|
||||
}
|
||||
|
||||
|
||||
def _build_prompt_rounds(source_msg: Message):
|
||||
messages = list(source_msg.conversation.messages.all().order_by("created", "id"))
|
||||
try:
|
||||
source_index = messages.index(source_msg)
|
||||
except ValueError:
|
||||
source_index = len(messages) - 1
|
||||
messages = messages[: source_index + 1]
|
||||
|
||||
rounds = []
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "user":
|
||||
continue
|
||||
html = css = js = None
|
||||
assistant_msg_id = None
|
||||
for reply in messages[i + 1:]:
|
||||
if reply.role == "user":
|
||||
break
|
||||
if reply.role == "assistant":
|
||||
assistant_msg_id = reply.id
|
||||
html = reply.code_html
|
||||
css = reply.code_css
|
||||
js = reply.code_js
|
||||
break
|
||||
rounds.append(
|
||||
{
|
||||
"question": msg.content,
|
||||
"source": msg.source,
|
||||
"prompt_level": msg.prompt_level,
|
||||
"assistant_msg_id": assistant_msg_id,
|
||||
"html": html,
|
||||
"css": css,
|
||||
"js": js,
|
||||
}
|
||||
)
|
||||
return rounds
|
||||
|
||||
|
||||
@router.get("/showcase/{submission_id}/prompt-chain/", response=List[PromptRoundOut])
|
||||
@login_required
|
||||
def get_showcase_prompt_chain(request, submission_id: UUID):
|
||||
@@ -501,32 +538,19 @@ def get_showcase_prompt_chain(request, submission_id: UUID):
|
||||
except Message.DoesNotExist:
|
||||
raise HttpError(404, "该作品没有关联提示词链")
|
||||
|
||||
messages = list(source_msg.conversation.messages.all().order_by("created"))
|
||||
rounds = []
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "user":
|
||||
continue
|
||||
html = css = js = None
|
||||
for reply in messages[i + 1:]:
|
||||
if reply.role == "user":
|
||||
break
|
||||
if reply.role == "assistant":
|
||||
html = reply.code_html
|
||||
css = reply.code_css
|
||||
js = reply.code_js
|
||||
break
|
||||
rounds.append(
|
||||
{
|
||||
"question": msg.content,
|
||||
"source": msg.source,
|
||||
"prompt_level": msg.prompt_level,
|
||||
"html": html,
|
||||
"css": css,
|
||||
"js": js,
|
||||
}
|
||||
)
|
||||
return _build_prompt_rounds(source_msg)
|
||||
|
||||
return rounds
|
||||
|
||||
@router.get("/{submission_id}/prompt-chain", response=List[PromptRoundOut])
|
||||
@login_required
|
||||
def get_submission_prompt_chain(request, submission_id: UUID):
|
||||
sub = get_object_or_404(Submission, id=submission_id)
|
||||
try:
|
||||
source_msg = Message.objects.select_related("conversation").get(submission=sub)
|
||||
except Message.DoesNotExist:
|
||||
raise HttpError(404, "该提交没有关联提示词链")
|
||||
|
||||
return _build_prompt_rounds(source_msg)
|
||||
|
||||
|
||||
@router.get("/{submission_id}", response=SubmissionOut)
|
||||
@@ -601,4 +625,3 @@ def update_flag(request, submission_id: UUID, payload: FlagIn):
|
||||
submission.save(update_fields=["flag"])
|
||||
return {"flag": submission.flag}
|
||||
|
||||
|
||||
|
||||
@@ -198,6 +198,7 @@ class PromptRoundOut(Schema):
|
||||
question: str
|
||||
source: str
|
||||
prompt_level: Optional[int] = None
|
||||
assistant_msg_id: Optional[int] = None
|
||||
html: Optional[str] = None
|
||||
css: Optional[str] = None
|
||||
js: Optional[str] = None
|
||||
|
||||
76
submission/tests.py
Normal file
76
submission/tests.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import TestCase
|
||||
|
||||
from prompt.models import Conversation, Message
|
||||
from task.models import Task
|
||||
|
||||
from .models import Submission
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def _make_user(username):
|
||||
return User.objects.create_user(username=username, password="pw")
|
||||
|
||||
|
||||
def _make_task():
|
||||
return Task.objects.create(
|
||||
title="Test Challenge",
|
||||
task_type="challenge",
|
||||
display=1,
|
||||
content="",
|
||||
)
|
||||
|
||||
|
||||
class SubmissionPromptChainTest(TestCase):
|
||||
def setUp(self):
|
||||
self.viewer = _make_user("viewer")
|
||||
self.author = _make_user("author")
|
||||
self.task = _make_task()
|
||||
|
||||
viewer_conv = Conversation.objects.create(user=self.viewer, task=self.task)
|
||||
Message.objects.create(
|
||||
conversation=viewer_conv,
|
||||
role="user",
|
||||
content="viewer prompt",
|
||||
)
|
||||
Message.objects.create(
|
||||
conversation=viewer_conv,
|
||||
role="assistant",
|
||||
content="viewer answer",
|
||||
code_html="<p>viewer</p>",
|
||||
)
|
||||
|
||||
author_conv = Conversation.objects.create(user=self.author, task=self.task)
|
||||
Message.objects.create(
|
||||
conversation=author_conv,
|
||||
role="user",
|
||||
content="author prompt",
|
||||
)
|
||||
self.submission = Submission.objects.create(
|
||||
user=self.author,
|
||||
task=self.task,
|
||||
html="<button>author</button>",
|
||||
css="button { color: red; }",
|
||||
js="",
|
||||
)
|
||||
Message.objects.create(
|
||||
conversation=author_conv,
|
||||
role="assistant",
|
||||
content="author answer",
|
||||
code_html="<button>author</button>",
|
||||
code_css="button { color: red; }",
|
||||
code_js="",
|
||||
submission=self.submission,
|
||||
)
|
||||
|
||||
def test_normal_user_can_view_prompt_chain_for_another_users_submission(self):
|
||||
self.client.force_login(self.viewer)
|
||||
|
||||
resp = self.client.get(f"/api/submission/{self.submission.id}/prompt-chain")
|
||||
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
data = resp.json()
|
||||
self.assertEqual(len(data), 1)
|
||||
self.assertEqual(data[0]["question"], "author prompt")
|
||||
self.assertEqual(data[0]["html"], "<button>author</button>")
|
||||
Reference in New Issue
Block a user