删除无用代码并且新增流程图相关内容

This commit is contained in:
2025-10-11 23:29:56 +08:00
parent 0f3f2d256f
commit 4168d41a16
33 changed files with 776 additions and 722 deletions

View File

@@ -210,7 +210,6 @@ class LanguagesAPI(APIView):
return self.success(
{
"languages": SysOptions.languages,
"spj_languages": SysOptions.spj_languages,
}
)

0
flowchart/__init__.py Normal file
View File

0
flowchart/admin.py Normal file
View File

7
flowchart/apps.py Normal file
View File

@@ -0,0 +1,7 @@
from django.apps import AppConfig
class FlowchartConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'flowchart'
verbose_name = '流程图管理'

83
flowchart/consumers.py Normal file
View File

@@ -0,0 +1,83 @@
"""
WebSocket consumers for flowchart evaluation updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class FlowchartConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time flowchart evaluation updates
当用户提交流程图后,通过 WebSocket 实时接收AI评分状态更新
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用用户 ID 作为组名,这样可以向特定用户推送消息
self.group_name = f"flowchart_user_{self.user.id}"
# 加入用户专属的组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"Flowchart WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"Flowchart WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或订阅特定流程图提交
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "subscribe":
# 订阅特定流程图提交的更新
submission_id = data.get("submission_id")
if submission_id:
logger.info(f"User {self.user.id} subscribed to flowchart submission {submission_id}")
# 可以在这里做额外的订阅逻辑
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def flowchart_evaluation_update(self, event):
"""
接收来自 channel layer 的流程图评分更新消息并发送给客户端
这个方法名对应 push_flowchart_evaluation_update 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent flowchart evaluation update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending flowchart evaluation update to user {self.user.id}: {str(e)}")

View File

@@ -0,0 +1,45 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
import django.db.models.deletion
import utils.shortcuts
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='FlowchartSubmission',
fields=[
('id', models.TextField(db_index=True, default=utils.shortcuts.rand_str, primary_key=True, serialize=False)),
('mermaid_code', models.TextField()),
('flowchart_data', models.JSONField(default=dict)),
('status', models.IntegerField(default=0)),
('create_time', models.DateTimeField(auto_now_add=True)),
('ai_score', models.FloatField(blank=True, null=True)),
('ai_grade', models.CharField(blank=True, max_length=10, null=True)),
('ai_feedback', models.TextField(blank=True, null=True)),
('ai_suggestions', models.TextField(blank=True, null=True)),
('ai_criteria_details', models.JSONField(default=dict)),
('ai_provider', models.CharField(default='deepseek', max_length=50)),
('ai_model', models.CharField(default='deepseek-chat', max_length=50)),
('processing_time', models.FloatField(blank=True, null=True)),
('evaluation_time', models.DateTimeField(blank=True, null=True)),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to='problem.problem')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to=settings.AUTH_USER_MODEL)),
],
options={
'db_table': 'flowchart_submission',
'ordering': ['-create_time'],
'indexes': [models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'), models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'), models.Index(fields=['status'], name='flowchart_status_idx')],
},
),
]

View File

65
flowchart/models.py Normal file
View File

@@ -0,0 +1,65 @@
from django.db import models
from django.contrib.auth import get_user_model
from utils.shortcuts import rand_str
from problem.models import Problem
User = get_user_model()
class FlowchartSubmissionStatus:
PENDING = 0 # 等待AI评分
PROCESSING = 1 # AI评分中
COMPLETED = 2 # 评分完成
FAILED = 3 # 评分失败
class FlowchartSubmission(models.Model):
"""流程图提交模型"""
id = models.TextField(default=rand_str, primary_key=True, db_index=True)
# 基础信息
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='flowchart_submissions')
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, related_name='flowchart_submissions')
# 提交内容
mermaid_code = models.TextField() # Mermaid代码
flowchart_data = models.JSONField(default=dict) # 流程图元数据
# 状态信息
status = models.IntegerField(default=FlowchartSubmissionStatus.PENDING)
create_time = models.DateTimeField(auto_now_add=True)
# AI评分结果
ai_score = models.FloatField(null=True, blank=True) # AI评分 (0-100)
ai_grade = models.CharField(max_length=10, null=True, blank=True) # 等级 (S/A/B/C)
ai_feedback = models.TextField(null=True, blank=True) # AI反馈
ai_suggestions = models.TextField(null=True, blank=True) # AI建议
ai_criteria_details = models.JSONField(default=dict) # 详细评分标准
# 处理信息
ai_provider = models.CharField(max_length=50, default='deepseek')
ai_model = models.CharField(max_length=50, default='deepseek-chat')
processing_time = models.FloatField(null=True, blank=True) # AI处理耗时(秒)
evaluation_time = models.DateTimeField(null=True, blank=True) # 评分完成时间
class Meta:
db_table = 'flowchart_submission'
ordering = ['-create_time']
indexes = [
models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'),
models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'),
models.Index(fields=['status'], name='flowchart_status_idx'),
]
def __str__(self):
return f"FlowchartSubmission {self.id}"
def check_user_permission(self, user, check_share=True):
"""检查用户权限"""
if (
self.user_id == user.id
or not user.is_regular_user()
or self.problem.created_by_id == user.id
):
return True
return False

60
flowchart/serializers.py Normal file
View File

@@ -0,0 +1,60 @@
from rest_framework import serializers
from .models import FlowchartSubmission
class CreateFlowchartSubmissionSerializer(serializers.Serializer):
problem_id = serializers.IntegerField()
mermaid_code = serializers.CharField()
flowchart_data = serializers.JSONField(required=False, default=dict)
def validate_mermaid_code(self, value):
if not value.strip():
raise serializers.ValidationError("Mermaid代码不能为空")
return value
class FlowchartSubmissionSerializer(serializers.ModelSerializer):
class Meta:
model = FlowchartSubmission
fields = [
"id",
"user",
"problem",
"mermaid_code",
"flowchart_data",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_feedback",
"ai_suggestions",
"ai_criteria_details",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]
read_only_fields = ["id", "create_time", "evaluation_time"]
class FlowchartSubmissionListSerializer(serializers.ModelSerializer):
"""用于列表显示的简化序列化器"""
username = serializers.CharField(source="user.username")
problem_title = serializers.CharField(source="problem.title")
class Meta:
model = FlowchartSubmission
fields = [
"id",
"username",
"problem_title",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]

186
flowchart/tasks.py Normal file
View File

@@ -0,0 +1,186 @@
import dramatiq
import json
import time
from openai import OpenAI
from django.db import transaction
from django.utils import timezone
from utils.shortcuts import get_env, DRAMATIQ_WORKER_ARGS
from .models import FlowchartSubmission, FlowchartSubmissionStatus
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3))
def evaluate_flowchart_task(submission_id):
"""异步AI评分任务"""
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
# 更新状态为处理中
submission.status = FlowchartSubmissionStatus.PROCESSING
submission.save()
start_time = time.time()
# 使用固定评分标准
system_prompt = build_evaluation_prompt(submission.problem)
# 构建用户提示词,包含标准答案对比
user_prompt = f"""
请对以下Mermaid流程图进行评分
学生提交的流程图:
```mermaid
{submission.mermaid_code}
```
标准答案参考:
```mermaid
{submission.problem.mermaid_code}
```
"""
# 如果有流程图提示,添加到提示词中
if submission.problem.flowchart_hint:
user_prompt += f"""
设计提示:{submission.problem.flowchart_hint}
"""
user_prompt += """
请按照评分标准进行详细评估并给出0-100的分数。
"""
# 调用AI进行评分
api_key = get_env("AI_KEY")
if not api_key:
raise Exception("AI_KEY is not set")
client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3,
)
ai_response = response.choices[0].message.content
score_data = parse_ai_evaluation_response(ai_response)
processing_time = time.time() - start_time
# 保存评分结果
with transaction.atomic():
submission.ai_score = score_data['score']
submission.ai_grade = score_data['grade']
submission.ai_feedback = score_data['feedback']
submission.ai_suggestions = score_data.get('suggestions', '')
submission.ai_criteria_details = score_data.get('criteria_details', {})
submission.ai_provider = 'deepseek'
submission.ai_model = 'deepseek-chat'
submission.processing_time = processing_time
submission.status = FlowchartSubmissionStatus.COMPLETED
submission.evaluation_time = timezone.now()
submission.save()
# 推送评分完成通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_completed",
"submission_id": str(submission.id),
"score": score_data['score'],
"grade": score_data['grade'],
"feedback": score_data['feedback']
}
)
except Exception as e:
# 处理失败
submission.status = FlowchartSubmissionStatus.FAILED
submission.save()
# 推送错误通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_failed",
"submission_id": str(submission.id),
"error": str(e)
}
)
raise e
def build_evaluation_prompt(problem):
"""构建AI评分提示词 - 使用固定标准"""
# 使用固定的评分标准
criteria_text = """
- 逻辑正确性 (权重: 1.0, 最高分: 40): 检查流程图的逻辑是否正确,包括条件判断、循环结构等
- 完整性 (权重: 0.8, 最高分: 30): 检查流程图是否包含所有必要的步骤和分支
- 规范性 (权重: 0.6, 最高分: 20): 检查流程图符号使用是否规范,是否符合标准
- 清晰度 (权重: 0.4, 最高分: 10): 评估流程图的整体布局和可读性
"""
return f"""
你是一个专业的编程教学助手负责评估学生提交的Mermaid流程图。
评分标准:
{criteria_text}
评分要求:
1. 仔细分析流程图的逻辑正确性、完整性和清晰度
2. 检查是否涵盖了题目的所有要求
3. 评估流程图的规范性和可读性
4. 给出0-100的分数
5. 提供详细的反馈和改进建议
评分等级:
- S级 (90-100分): 优秀,逻辑清晰,完全符合要求
- A级 (80-89分): 良好,基本符合要求,有少量改进空间
- B级 (70-79分): 及格,基本正确但存在一些问题
- C级 (0-69分): 需要改进,存在明显问题
请以JSON格式返回评分结果
{{
"score": 85,
"grade": "A",
"feedback": "详细的反馈内容",
"suggestions": "改进建议",
"criteria_details": {{
"逻辑正确性": {{"score": 35, "max": 40, "comment": "逻辑基本正确"}},
"完整性": {{"score": 25, "max": 30, "comment": "缺少部分步骤"}},
"规范性": {{"score": 18, "max": 20, "comment": "符号使用规范"}},
"清晰度": {{"score": 8, "max": 10, "comment": "布局清晰"}}
}}
}}
"""
def parse_ai_evaluation_response(ai_response):
"""解析AI评分响应"""
try:
import re
json_match = re.search(r'\{.*\}', ai_response, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
else:
data = {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}
return data
except Exception:
return {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}

View File

@@ -0,0 +1 @@
# URLs package

12
flowchart/urls/oj.py Normal file
View File

@@ -0,0 +1,12 @@
from django.urls import path
from ..views.oj import (
FlowchartSubmissionAPI,
FlowchartSubmissionListAPI,
FlowchartSubmissionRetryAPI
)
urlpatterns = [
path('flowchart/submission', FlowchartSubmissionAPI.as_view()),
path('flowchart/submissions', FlowchartSubmissionListAPI.as_view()),
path('flowchart/submission/retry', FlowchartSubmissionRetryAPI.as_view()),
]

3
flowchart/views.py Normal file
View File

@@ -0,0 +1,3 @@
from django.shortcuts import render
# Create your views here.

View File

@@ -0,0 +1 @@
# Views package

138
flowchart/views/oj.py Normal file
View File

@@ -0,0 +1,138 @@
from utils.api import APIView
from account.decorators import login_required
from flowchart.models import FlowchartSubmission, FlowchartSubmissionStatus
from flowchart.serializers import (
CreateFlowchartSubmissionSerializer,
FlowchartSubmissionSerializer,
FlowchartSubmissionListSerializer
)
from flowchart.tasks import evaluate_flowchart_task
class FlowchartSubmissionAPI(APIView):
@login_required
def post(self, request):
"""创建流程图提交"""
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
if not serializer.is_valid():
return self.error(serializer.errors)
data = serializer.validated_data
# 验证题目存在
try:
from problem.models import Problem
problem = Problem.objects.get(_id=data['problem_id'])
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
# 验证题目是否允许流程图提交
if not problem.allow_flowchart:
return self.error("This problem does not allow flowchart submission")
# 创建提交记录
submission = FlowchartSubmission.objects.create(
user=request.user,
problem=problem,
mermaid_code=data['mermaid_code'],
flowchart_data=data.get('flowchart_data', {})
)
# 启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success({
'submission_id': submission.id,
'status': 'pending'
})
@login_required
def get(self, request):
"""获取流程图提交详情"""
submission_id = request.GET.get('id')
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
serializer = FlowchartSubmissionSerializer(submission)
return self.success(serializer.data)
class FlowchartSubmissionListAPI(APIView):
@login_required
def get(self, request):
"""获取流程图提交列表"""
user_id = request.GET.get('user_id')
problem_id = request.GET.get('problem_id')
offset = int(request.GET.get('offset', 0))
limit = int(request.GET.get('limit', 20))
queryset = FlowchartSubmission.objects.select_related('user', 'problem')
# 权限过滤
if not request.user.is_admin_role():
queryset = queryset.filter(user=request.user)
# 其他过滤条件
if user_id:
queryset = queryset.filter(user_id=user_id)
if problem_id:
queryset = queryset.filter(problem_id=problem_id)
total = queryset.count()
submissions = queryset[offset:offset + limit]
serializer = FlowchartSubmissionListSerializer(submissions, many=True)
return self.success({
'results': serializer.data,
'total': total
})
class FlowchartSubmissionRetryAPI(APIView):
@login_required
def post(self, request):
"""重新触发AI评分"""
submission_id = request.data.get('submission_id')
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
# 检查权限
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
# 检查是否可以重新评分
if submission.status not in [FlowchartSubmissionStatus.FAILED, FlowchartSubmissionStatus.COMPLETED]:
return self.error("Submission is not in a state that allows retry")
# 重置状态并重新启动AI评分
submission.status = FlowchartSubmissionStatus.PENDING
submission.ai_score = None
submission.ai_grade = None
submission.ai_feedback = None
submission.ai_suggestions = None
submission.ai_criteria_details = {}
submission.processing_time = None
submission.evaluation_time = None
submission.save()
# 重新启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success({
'submission_id': submission.id,
'status': 'pending',
'message': 'AI evaluation restarted'
})

View File

@@ -67,26 +67,6 @@ class DispatcherBase(object):
logger.exception(e)
class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language):
super().__init__()
spj_compile_config = list(filter(lambda config: spj_language == config["name"], SysOptions.spj_languages))[0]["spj"][
"compile"]
self.data = {
"src": spj_code,
"spj_version": spj_version,
"spj_compile_config": spj_compile_config
}
def compile_spj(self):
with ChooseJudgeServer() as server:
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
if not result:
return "Failed to call judge server"
if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase):
@@ -126,12 +106,6 @@ class JudgeDispatcher(DispatcherBase):
def judge(self):
language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {}
if self.problem.spj_code:
for lang in SysOptions.spj_languages:
if lang["name"] == self.problem.spj_language:
spj_config = lang["spj"]
break
if language in self.problem.template:
template = parse_problem_template(self.problem.template[language])
@@ -146,10 +120,6 @@ class JudgeDispatcher(DispatcherBase):
"max_memory": 1024 * 1024 * self.problem.memory_limit,
"test_case_id": self.problem.test_case_id,
"output": False,
"spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code,
"io_mode": self.problem.io_mode
}

View File

@@ -35,20 +35,6 @@ int main() {
}
}
_c_lang_spj_compile = {
"src_name": "spj-{spj_version}.c",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 3000,
"max_real_time": 10000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}"
}
_c_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_cpp_lang_config = {
"template": """//PREPEND BEGIN
@@ -82,20 +68,6 @@ int main() {
}
}
_cpp_lang_spj_compile = {
"src_name": "spj-{spj_version}.cpp",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 10000,
"max_real_time": 20000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}"
}
_cpp_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_java_lang_config = {
"template": """//PREPEND BEGIN
@@ -224,10 +196,8 @@ console.log(add(1, 2))
}
languages = [
{"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc",
"spj": {"compile": _c_lang_spj_compile, "config": _c_lang_spj_config}},
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src",
"spj": {"compile": _cpp_lang_spj_compile, "config": _cpp_lang_spj_config}},
{"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc"},
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src"},
{"config": _java_lang_config, "name": "Java", "description": "Temurin 21", "content_type": "text/x-java"},
{"config": _py3_lang_config, "name": "Python3", "description": "Python 3.12", "content_type": "text/x-python"},
{"config": _go_lang_config, "name": "Golang", "description": "Golang 1.22", "content_type": "text/x-go"},

View File

@@ -6,8 +6,8 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
"default": {
"ENGINE": "django.db.backends.postgresql",
"HOST": "10.13.114.114",
"PORT": "5433",
"HOST": "150.158.29.156",
"PORT": "5432",
"NAME": "onlinejudge",
"USER": "onlinejudge",
"PASSWORD": "onlinejudge",
@@ -15,7 +15,7 @@ DATABASES = {
}
REDIS_CONF = {
"host": "10.13.114.114",
"host": "150.158.29.156",
"port": 6379,
}

View File

@@ -5,9 +5,11 @@ WebSocket URL Configuration for oj project.
from django.urls import path
from submission.consumers import SubmissionConsumer
from conf.consumers import ConfigConsumer
from flowchart.consumers import FlowchartConsumer
websocket_urlpatterns = [
path("ws/submission/", SubmissionConsumer.as_asgi()),
path("ws/config/", ConfigConsumer.as_asgi()),
path("ws/flowchart/", FlowchartConsumer.as_asgi()),
]

View File

@@ -58,6 +58,7 @@ LOCAL_APPS = [
"comment",
"tutorial",
"ai",
"flowchart",
]
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS

View File

@@ -20,4 +20,5 @@ urlpatterns = [
path("api/", include("tutorial.urls.tutorial")),
path("api/admin/", include("tutorial.urls.admin")),
path("api/", include("ai.urls.oj")),
path("api/", include("flowchart.urls.oj")),
]

View File

@@ -273,18 +273,10 @@ class _SysOptionsMeta(type):
def languages(cls, value):
cls._set_option(OptionKeys.languages, value)
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_languages(cls):
return [item for item in cls.languages if "spj" in item]
@my_property(ttl=DEFAULT_SHORT_TTL)
def language_names(cls):
return [item["name"] for item in cls.languages]
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_language_names(cls):
return [item["name"] for item in cls.languages if "spj" in item]
@my_property(ttl=DEFAULT_SHORT_TTL)
def enable_maxkb(cls):
return cls._get_option(OptionKeys.enable_maxkb)

View File

@@ -0,0 +1,38 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0003_problem_answers'),
]
operations = [
migrations.AddField(
model_name='problem',
name='allow_flowchart',
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name='problem',
name='flowchart_data',
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name='problem',
name='flowchart_hint',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='mermaid_code',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='show_flowchart',
field=models.BooleanField(default=False),
),
]

View File

@@ -0,0 +1,33 @@
# Generated by Django 5.2.3 on 2025-10-11 15:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
]
operations = [
migrations.RemoveField(
model_name='problem',
name='spj',
),
migrations.RemoveField(
model_name='problem',
name='spj_code',
),
migrations.RemoveField(
model_name='problem',
name='spj_compile_ok',
),
migrations.RemoveField(
model_name='problem',
name='spj_language',
),
migrations.RemoveField(
model_name='problem',
name='spj_version',
),
]

View File

@@ -1,4 +1,3 @@
from django.conf import settings
from django.db import models
from account.models import User
@@ -67,12 +66,6 @@ class Problem(models.Model):
memory_limit = models.IntegerField()
# io mode
io_mode = models.JSONField(default=_default_io_mode)
# special judge related
spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True)
spj_code = models.TextField(null=True)
spj_version = models.TextField(null=True)
spj_compile_ok = models.BooleanField(default=False)
rule_type = models.TextField()
visible = models.BooleanField(default=True)
difficulty = models.TextField()
@@ -88,6 +81,13 @@ class Problem(models.Model):
# {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count
statistic_info = models.JSONField(default=dict)
share_submission = models.BooleanField(default=False)
# 流程图相关字段
allow_flowchart = models.BooleanField(default=False) # 是否允许/需要提交流程图
mermaid_code = models.TextField(null=True, blank=True) # 流程图答案(Mermaid代码)
flowchart_data = models.JSONField(default=dict) # 流程图答案元数据(JSON格式)
flowchart_hint = models.TextField(null=True, blank=True) # 流程图提示信息
show_flowchart = models.BooleanField(default=False) # 是否显示流程图答案数据如果True这样就不需要提交流程图了说明就是给学生看的
class Meta:
db_table = "problem"

View File

@@ -2,12 +2,10 @@ import re
from django import forms
from options.options import SysOptions
from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty
from utils.serializers import (
LanguageNameMultiChoiceField,
SPJLanguageNameChoiceField,
LanguageNameChoiceField,
)
@@ -16,7 +14,6 @@ from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form):
spj = forms.CharField(max_length=12)
file = forms.FileField()
@@ -73,10 +70,6 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
choices=[ProblemRuleType.ACM, ProblemRuleType.OI]
)
io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=Difficulty.choices())
tags = serializers.ListField(
@@ -92,6 +85,17 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
)
share_submission = serializers.BooleanField()
# 流程图相关字段
allow_flowchart = serializers.BooleanField(required=False, default=False)
show_flowchart = serializers.BooleanField(required=False, default=False)
mermaid_code = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
flowchart_hint = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
class CreateProblemSerializer(CreateOrEditProblemSerializer):
pass
@@ -116,11 +120,6 @@ class TagSerializer(serializers.ModelSerializer):
fields = "__all__"
class CompileSPJSerializer(serializers.Serializer):
spj_language = SPJLanguageNameChoiceField()
spj_code = serializers.CharField()
class BaseProblemSerializer(serializers.ModelSerializer):
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
created_by = UsernameSerializer()
@@ -154,9 +153,6 @@ class ProblemSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"answers",
)
@@ -188,9 +184,6 @@ class ProblemSafeSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"difficulty",
"submission_number",
"accepted_number",
@@ -204,101 +197,12 @@ class ContestProblemMakePublicSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=32)
class ExportProblemSerializer(serializers.ModelSerializer):
display_id = serializers.SerializerMethodField()
description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField()
source = serializers.SerializerMethodField()
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
def get_display_id(self, obj):
return obj._id
def _html_format_value(self, value):
return {"format": "html", "value": value}
def get_description(self, obj):
return self._html_format_value(obj.description)
def get_input_description(self, obj):
return self._html_format_value(obj.input_description)
def get_output_description(self, obj):
return self._html_format_value(obj.output_description)
def get_hint(self, obj):
return self._html_format_value(obj.hint)
def get_test_case_score(self, obj):
return [
{
"score": item["score"] if obj.rule_type == ProblemRuleType.OI else 100,
"input_name": item["input_name"],
"output_name": item["output_name"],
}
for item in obj.test_case_score
]
def get_spj(self, obj):
return {"code": obj.spj_code, "language": obj.spj_language} if obj.spj else None
def get_template(self, obj):
ret = {}
for k, v in obj.template.items():
ret[k] = parse_problem_template(v)
return ret
def get_source(self, obj):
return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}"
class Meta:
model = Problem
fields = (
"display_id",
"title",
"description",
"tags",
"input_description",
"output_description",
"test_case_score",
"hint",
"time_limit",
"memory_limit",
"samples",
"template",
"spj",
"rule_type",
"source",
"template",
)
class AddContestProblemSerializer(serializers.Serializer):
contest_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
display_id = serializers.CharField()
class ExportProblemRequestSerializer(serializers.Serializer):
problem_id = serializers.ListField(
child=serializers.IntegerField(), allow_empty=False
)
class UploadProblemForm(forms.Form):
file = forms.FileField()
class FormatValueSerializer(serializers.Serializer):
format = serializers.ChoiceField(choices=["html", "markdown"])
value = serializers.CharField(allow_blank=True)
class TestCaseScoreSerializer(serializers.Serializer):
score = serializers.IntegerField(min_value=1)
input_name = serializers.CharField(max_length=32)
@@ -311,58 +215,6 @@ class TemplateSerializer(serializers.Serializer):
append = serializers.CharField()
class SPJSerializer(serializers.Serializer):
code = serializers.CharField()
language = SPJLanguageNameChoiceField()
class AnswerSerializer(serializers.Serializer):
code = serializers.CharField()
language = LanguageNameChoiceField()
class ImportProblemSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=128)
title = serializers.CharField(max_length=128)
description = FormatValueSerializer()
input_description = FormatValueSerializer()
output_description = FormatValueSerializer()
hint = FormatValueSerializer()
test_case_score = serializers.ListField(
child=TestCaseScoreSerializer(), allow_null=True
)
time_limit = serializers.IntegerField(min_value=1, max_value=60000)
memory_limit = serializers.IntegerField(min_value=1, max_value=10240)
samples = serializers.ListField(child=CreateSampleSerializer())
template = serializers.DictField(child=TemplateSerializer())
spj = SPJSerializer(allow_null=True)
rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
answers = serializers.ListField(child=AnswerSerializer())
tags = serializers.ListField(child=serializers.CharField())
class FPSProblemSerializer(serializers.Serializer):
class UnitSerializer(serializers.Serializer):
unit = serializers.ChoiceField(choices=["MB", "s", "ms"])
value = serializers.IntegerField(min_value=1, max_value=60000)
title = serializers.CharField(max_length=128)
description = serializers.CharField()
input = serializers.CharField()
output = serializers.CharField()
hint = serializers.CharField(allow_blank=True, allow_null=True)
time_limit = UnitSerializer()
memory_limit = UnitSerializer()
samples = serializers.ListField(child=CreateSampleSerializer())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
spj = SPJSerializer(allow_null=True)
template = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
append = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
prepend = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)

View File

@@ -20,8 +20,7 @@ from .utils import parse_problem_template
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
"samples": [{"input": "test", "output": "test"}], "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
@@ -34,14 +33,6 @@ class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
@@ -81,12 +72,9 @@ class TestCaseUploadAPITest(APITestCase):
self.create_super_admin()
def test_filter_file_name(self):
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"]),
["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), [])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"]), [])
def make_test_case_zip(self):
base_dir = os.path.join("/tmp", "test_case")
@@ -102,27 +90,13 @@ class TestCaseUploadAPITest(APITestCase):
f.write(os.path.join(base_dir, item), item)
return zip_file
def test_upload_spj_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "true", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], True)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
def test_upload_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "false", "file": f}, format="multipart")
data={"file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], False)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
@@ -148,16 +122,6 @@ class ProblemAdminAPITest(APITestCase):
resp = self.client.post(self.url, data=self.data)
self.assertFailed(resp, "Display ID already exists")
def test_spj(self):
data = copy.deepcopy(self.data)
data["spj"] = True
resp = self.client.post(self.url, data)
self.assertFailed(resp, "Invalid spj")
data["spj_code"] = "test"
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
def test_get_problem(self):
self.test_create_problem()

View File

@@ -1,18 +1,19 @@
from django.urls import path
from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView,
CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI,
FPSProblemImport, ProblemVisibleAPI)
from ..views.admin import (
ContestProblemAPI,
ProblemAPI,
TestCaseAPI,
MakeContestProblemPublicAPIView,
AddContestProblemAPI,
ProblemVisibleAPI,
)
urlpatterns = [
path("test_case", TestCaseAPI.as_view()),
path("compile_spj", CompileSPJAPI.as_view()),
path("problem", ProblemAPI.as_view()),
path("problem/visible", ProblemVisibleAPI.as_view()),
path("contest/problem", ContestProblemAPI.as_view()),
path("contest_problem/make_public", MakeContestProblemPublicAPIView.as_view()),
path("contest/add_problem_from_public", AddContestProblemAPI.as_view()),
path("export_problem", ExportProblemAPI.as_view()),
path("import_problem", ImportProblemAPI.as_view()),
path("import_fps", FPSProblemImport.as_view()),
]

View File

@@ -3,29 +3,21 @@ import json
import os
# import shutil
import tempfile
import zipfile
from wsgiref.util import FileWrapper
from django.conf import settings
from django.db import transaction
from django.db.models import Q
from django.http import StreamingHttpResponse, FileResponse
from django.http import StreamingHttpResponse
from account.decorators import problem_permission_required, ensure_created_by
from contest.models import Contest, ContestStatus
from fps.parser import FPSHelper, FPSParser
from judge.dispatcher import SPJCompiler
from options.options import SysOptions
from submission.models import Submission, JudgeStatus
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.constants import Difficulty
from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (
CreateContestProblemSerializer,
CompileSPJSerializer,
CreateProblemSerializer,
EditProblemSerializer,
EditContestProblemSerializer,
@@ -34,23 +26,17 @@ from ..serializers import (
TestCaseUploadForm,
ContestProblemMakePublicSerializer,
AddContestProblemSerializer,
ExportProblemSerializer,
ExportProblemRequestSerializer,
UploadProblemForm,
ImportProblemSerializer,
FPSProblemSerializer,
)
from ..utils import TEMPLATE_BASE, build_problem_template
class TestCaseZipProcessor(object):
def process_zip(self, uploaded_zip_file, spj, dir=""):
def process_zip(self, uploaded_zip_file, dir=""):
try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile:
raise APIError("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
test_case_list = self.filter_name_list(name_list, dir=dir)
if not test_case_list:
raise APIError("Empty file")
@@ -69,28 +55,22 @@ class TestCaseZipProcessor(object):
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
test_case_info = {"test_cases": {}}
info = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {
"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1],
}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {
"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1],
}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
@@ -100,29 +80,19 @@ class TestCaseZipProcessor(object):
return info, test_case_id
def filter_name_list(self, name_list, spj, dir=""):
def filter_name_list(self, name_list, dir=""):
ret = []
prefix = 1
if spj:
while True:
in_name = f"{prefix}.in"
if f"{dir}{in_name}" in name_list:
ret.append(in_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
else:
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
@@ -145,7 +115,7 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if not os.path.isdir(test_case_dir):
return self.error("Test case does not exists")
name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
name_list = self.filter_name_list(os.listdir(test_case_dir))
name_list.append("info")
file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
with zipfile.ZipFile(file_name, "w") as file:
@@ -164,7 +134,6 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
def post(self, request):
form = TestCaseUploadForm(request.POST, request.FILES)
if form.is_valid():
spj = form.cleaned_data["spj"] == "true"
file = form.cleaned_data["file"]
else:
return self.error("Upload failed")
@@ -172,39 +141,14 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
with open(zip_file, "wb") as f:
for chunk in file:
f.write(chunk)
info, test_case_id = self.process_zip(zip_file, spj=spj)
info, test_case_id = self.process_zip(zip_file)
os.remove(zip_file)
return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
@validate_serializer(CompileSPJSerializer)
def post(self, request):
data = request.data
spj_version = rand_str(8)
error = SPJCompiler(
data["spj_code"], spj_version, data["spj_language"]
).compile_spj()
if error:
return self.error(error)
else:
return self.success()
return self.success({"id": test_case_id, "info": info})
class ProblemBase(APIView):
def common_checks(self, request):
data = request.data
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return "Invalid spj"
if not data["spj_compile_ok"]:
return "SPJ code must be compiled successfully"
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")
).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
@@ -529,257 +473,6 @@ class AddContestProblemAPI(APIView):
return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = (
Submission.objects.filter(
problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED,
)
.order_by("-create_time")
.first()
)
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(
zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression,
)
problem_test_case_dir = os.path.join(
settings.TEST_CASE_DIR, problem.test_case_id
)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(
filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression,
)
if not info["spj"]:
zip_file.write(
filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression,
)
@validate_serializer(ExportProblemRequestSerializer)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(
zip_file=zip_file,
user=request.user,
problem=problem,
index=index + 1,
)
delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = "attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(
f"Invalid problem format, error is {serializer.errors}"
)
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in SysOptions.language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(
v["prepend"], v["template"], v["append"]
)
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(
tmp_file, spj=spj, dir=f"{i}/testcase/"
)
problem_obj = Problem.objects.create(
_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"
],
output_description=problem_info["output_description"][
"value"
],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"]["language"]
if spj
else None,
spj_version=rand_str(8) if spj else "",
languages=SysOptions.language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI
else 0,
test_case_id=test_case_id,
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(
prepend.get(lang, ""), t["code"], append.get(lang, "")
)
spj = problem_data["spj"] is not None
Problem.objects.create(
_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=problem_data["test_case_score"],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=SysOptions.language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"],
)
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
tf.file.flush()
os.fsync(tf.file)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
score = []
for item in helper.save_test_case(_problem, test_case_dir)[
"test_cases"
].values():
score.append(
{
"score": 0,
"input_name": item["input_name"],
"output_name": item.get("output_name"),
}
)
problem_data = helper.save_image(
_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX
)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
problem_data["test_case_score"] = score
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})
class ProblemVisibleAPI(APIView):
@problem_permission_required
def put(self, request):

View File

@@ -4,7 +4,6 @@ WebSocket consumers for submission updates
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
from channels.db import database_sync_to_async
logger = logging.getLogger(__name__)
@@ -73,7 +72,7 @@ class SubmissionConsumer(AsyncWebsocketConsumer):
async def submission_update(self, event):
"""
接收来自 channel layer 的提交更新消息并发送给客户端
接收来自 channel layer 的代码提交更新消息并发送给客户端
这个方法名对应 push_submission_update 中的 type 字段
"""
try:

View File

@@ -1,78 +0,0 @@
from copy import deepcopy
from unittest import mock
from problem.models import Problem, ProblemTag
from utils.api.tests import APITestCase
from .models import Submission
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
DEFAULT_SUBMISSION_DATA = {
"problem_id": "1",
"user_id": 1,
"username": "test",
"code": "xxxxxxxxxxxxxx",
"result": -2,
"info": {},
"language": "C",
"statistic_info": {}
}
# todo contest submission
class SubmissionPrepare(APITestCase):
def _create_problem_and_submission(self):
user = self.create_admin("test", "test123", login=False)
problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
tags = problem_data.pop("tags")
problem_data["created_by"] = user
self.problem = Problem.objects.create(**problem_data)
for tag in tags:
tag = ProblemTag.objects.create(name=tag)
self.problem.tags.add(tag)
self.problem.save()
self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
self.submission_data["problem_id"] = self.problem.id
self.submission = Submission.objects.create(**self.submission_data)
class SubmissionListTest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.create_user("123", "345")
self.url = self.reverse("submission_list_api")
def test_get_submission_list(self):
resp = self.client.get(self.url, data={"limit": "10"})
self.assertSuccess(resp)
@mock.patch("submission.views.oj.judge_task.send")
class SubmissionAPITest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.user = self.create_user("123", "test123")
self.url = self.reverse("submission_api")
def test_create_submission(self, judge_task):
resp = self.client.post(self.url, self.submission_data)
self.assertSuccess(resp)
judge_task.assert_called()
def test_create_submission_with_wrong_language(self, judge_task):
self.submission_data.update({"language": "Python3"})
resp = self.client.post(self.url, self.submission_data)
self.assertFailed(resp)
self.assertDictEqual(resp.data, {"error": "error",
"data": "Python3 is now allowed in the problem"})
judge_task.assert_not_called()

View File

@@ -16,14 +16,6 @@ class LanguageNameChoiceField(serializers.CharField):
return data
class SPJLanguageNameChoiceField(serializers.CharField):
def to_internal_value(self, data):
data = super().to_internal_value(data)
if data and data not in SysOptions.spj_language_names:
raise InvalidLanguage(data)
return data
class LanguageNameMultiChoiceField(serializers.ListField):
def to_internal_value(self, data):
data = super().to_internal_value(data)
@@ -31,12 +23,3 @@ class LanguageNameMultiChoiceField(serializers.ListField):
if item not in SysOptions.language_names:
raise InvalidLanguage(item)
return data
class SPJLanguageNameMultiChoiceField(serializers.ListField):
def to_internal_value(self, data):
data = super().to_internal_value(data)
for item in data:
if item not in SysOptions.spj_language_names:
raise InvalidLanguage(item)
return data

View File

@@ -74,6 +74,39 @@ def push_to_user(user_id: int, message_type: str, data: dict):
logger.error(f"Failed to push message to user {user_id}: error={str(e)}")
def push_flowchart_evaluation_update(submission_id: str, user_id: int, data: dict):
"""
推送流程图评分状态更新到指定用户的 WebSocket 连接
Args:
submission_id: 流程图提交 ID
user_id: 用户 ID
data: 要发送的数据,应该包含 type, submission_id, score, grade, feedback 等字段
"""
channel_layer = get_channel_layer()
if channel_layer is None:
logger.warning("Channel layer is not configured, cannot push flowchart evaluation update")
return
# 构建组名,与 SubmissionConsumer 中的组名一致
group_name = f"submission_user_{user_id}"
try:
# 向指定用户组发送消息
# type 字段对应 consumer 中的方法名flowchart_evaluation_update
async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "flowchart_evaluation_update", # 对应 SubmissionConsumer.flowchart_evaluation_update 方法
"data": data,
}
)
logger.info(f"Pushed flowchart evaluation update: submission_id={submission_id}, user_id={user_id}, type={data.get('type')}")
except Exception as e:
logger.error(f"Failed to push flowchart evaluation update: submission_id={submission_id}, user_id={user_id}, error={str(e)}")
def push_config_update(key: str, value):
"""
推送配置更新到所有连接的客户端