删除无用代码并且新增流程图相关内容

This commit is contained in:
2025-10-11 23:29:56 +08:00
parent 0f3f2d256f
commit 4168d41a16
33 changed files with 776 additions and 722 deletions

View File

@@ -0,0 +1,38 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0003_problem_answers'),
]
operations = [
migrations.AddField(
model_name='problem',
name='allow_flowchart',
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name='problem',
name='flowchart_data',
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name='problem',
name='flowchart_hint',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='mermaid_code',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='show_flowchart',
field=models.BooleanField(default=False),
),
]

View File

@@ -0,0 +1,33 @@
# Generated by Django 5.2.3 on 2025-10-11 15:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
]
operations = [
migrations.RemoveField(
model_name='problem',
name='spj',
),
migrations.RemoveField(
model_name='problem',
name='spj_code',
),
migrations.RemoveField(
model_name='problem',
name='spj_compile_ok',
),
migrations.RemoveField(
model_name='problem',
name='spj_language',
),
migrations.RemoveField(
model_name='problem',
name='spj_version',
),
]

View File

@@ -1,4 +1,3 @@
from django.conf import settings
from django.db import models
from account.models import User
@@ -67,12 +66,6 @@ class Problem(models.Model):
memory_limit = models.IntegerField()
# io mode
io_mode = models.JSONField(default=_default_io_mode)
# special judge related
spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True)
spj_code = models.TextField(null=True)
spj_version = models.TextField(null=True)
spj_compile_ok = models.BooleanField(default=False)
rule_type = models.TextField()
visible = models.BooleanField(default=True)
difficulty = models.TextField()
@@ -88,6 +81,13 @@ class Problem(models.Model):
# {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count
statistic_info = models.JSONField(default=dict)
share_submission = models.BooleanField(default=False)
# 流程图相关字段
allow_flowchart = models.BooleanField(default=False) # 是否允许/需要提交流程图
mermaid_code = models.TextField(null=True, blank=True) # 流程图答案(Mermaid代码)
flowchart_data = models.JSONField(default=dict) # 流程图答案元数据(JSON格式)
flowchart_hint = models.TextField(null=True, blank=True) # 流程图提示信息
show_flowchart = models.BooleanField(default=False) # 是否显示流程图答案数据如果True这样就不需要提交流程图了说明就是给学生看的
class Meta:
db_table = "problem"

View File

@@ -2,12 +2,10 @@ import re
from django import forms
from options.options import SysOptions
from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty
from utils.serializers import (
LanguageNameMultiChoiceField,
SPJLanguageNameChoiceField,
LanguageNameChoiceField,
)
@@ -16,7 +14,6 @@ from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form):
spj = forms.CharField(max_length=12)
file = forms.FileField()
@@ -73,10 +70,6 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
choices=[ProblemRuleType.ACM, ProblemRuleType.OI]
)
io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=Difficulty.choices())
tags = serializers.ListField(
@@ -92,6 +85,17 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
)
share_submission = serializers.BooleanField()
# 流程图相关字段
allow_flowchart = serializers.BooleanField(required=False, default=False)
show_flowchart = serializers.BooleanField(required=False, default=False)
mermaid_code = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
flowchart_hint = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
class CreateProblemSerializer(CreateOrEditProblemSerializer):
pass
@@ -116,11 +120,6 @@ class TagSerializer(serializers.ModelSerializer):
fields = "__all__"
class CompileSPJSerializer(serializers.Serializer):
spj_language = SPJLanguageNameChoiceField()
spj_code = serializers.CharField()
class BaseProblemSerializer(serializers.ModelSerializer):
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
created_by = UsernameSerializer()
@@ -154,9 +153,6 @@ class ProblemSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"answers",
)
@@ -188,9 +184,6 @@ class ProblemSafeSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"difficulty",
"submission_number",
"accepted_number",
@@ -204,101 +197,12 @@ class ContestProblemMakePublicSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=32)
class ExportProblemSerializer(serializers.ModelSerializer):
display_id = serializers.SerializerMethodField()
description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField()
source = serializers.SerializerMethodField()
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
def get_display_id(self, obj):
return obj._id
def _html_format_value(self, value):
return {"format": "html", "value": value}
def get_description(self, obj):
return self._html_format_value(obj.description)
def get_input_description(self, obj):
return self._html_format_value(obj.input_description)
def get_output_description(self, obj):
return self._html_format_value(obj.output_description)
def get_hint(self, obj):
return self._html_format_value(obj.hint)
def get_test_case_score(self, obj):
return [
{
"score": item["score"] if obj.rule_type == ProblemRuleType.OI else 100,
"input_name": item["input_name"],
"output_name": item["output_name"],
}
for item in obj.test_case_score
]
def get_spj(self, obj):
return {"code": obj.spj_code, "language": obj.spj_language} if obj.spj else None
def get_template(self, obj):
ret = {}
for k, v in obj.template.items():
ret[k] = parse_problem_template(v)
return ret
def get_source(self, obj):
return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}"
class Meta:
model = Problem
fields = (
"display_id",
"title",
"description",
"tags",
"input_description",
"output_description",
"test_case_score",
"hint",
"time_limit",
"memory_limit",
"samples",
"template",
"spj",
"rule_type",
"source",
"template",
)
class AddContestProblemSerializer(serializers.Serializer):
contest_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
display_id = serializers.CharField()
class ExportProblemRequestSerializer(serializers.Serializer):
problem_id = serializers.ListField(
child=serializers.IntegerField(), allow_empty=False
)
class UploadProblemForm(forms.Form):
file = forms.FileField()
class FormatValueSerializer(serializers.Serializer):
format = serializers.ChoiceField(choices=["html", "markdown"])
value = serializers.CharField(allow_blank=True)
class TestCaseScoreSerializer(serializers.Serializer):
score = serializers.IntegerField(min_value=1)
input_name = serializers.CharField(max_length=32)
@@ -311,58 +215,6 @@ class TemplateSerializer(serializers.Serializer):
append = serializers.CharField()
class SPJSerializer(serializers.Serializer):
code = serializers.CharField()
language = SPJLanguageNameChoiceField()
class AnswerSerializer(serializers.Serializer):
code = serializers.CharField()
language = LanguageNameChoiceField()
class ImportProblemSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=128)
title = serializers.CharField(max_length=128)
description = FormatValueSerializer()
input_description = FormatValueSerializer()
output_description = FormatValueSerializer()
hint = FormatValueSerializer()
test_case_score = serializers.ListField(
child=TestCaseScoreSerializer(), allow_null=True
)
time_limit = serializers.IntegerField(min_value=1, max_value=60000)
memory_limit = serializers.IntegerField(min_value=1, max_value=10240)
samples = serializers.ListField(child=CreateSampleSerializer())
template = serializers.DictField(child=TemplateSerializer())
spj = SPJSerializer(allow_null=True)
rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
answers = serializers.ListField(child=AnswerSerializer())
tags = serializers.ListField(child=serializers.CharField())
class FPSProblemSerializer(serializers.Serializer):
class UnitSerializer(serializers.Serializer):
unit = serializers.ChoiceField(choices=["MB", "s", "ms"])
value = serializers.IntegerField(min_value=1, max_value=60000)
title = serializers.CharField(max_length=128)
description = serializers.CharField()
input = serializers.CharField()
output = serializers.CharField()
hint = serializers.CharField(allow_blank=True, allow_null=True)
time_limit = UnitSerializer()
memory_limit = UnitSerializer()
samples = serializers.ListField(child=CreateSampleSerializer())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
spj = SPJSerializer(allow_null=True)
template = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
append = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
prepend = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)

View File

@@ -20,8 +20,7 @@ from .utils import parse_problem_template
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
"samples": [{"input": "test", "output": "test"}], "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
@@ -34,14 +33,6 @@ class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
@@ -81,12 +72,9 @@ class TestCaseUploadAPITest(APITestCase):
self.create_super_admin()
def test_filter_file_name(self):
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"]),
["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), [])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"]), [])
def make_test_case_zip(self):
base_dir = os.path.join("/tmp", "test_case")
@@ -102,27 +90,13 @@ class TestCaseUploadAPITest(APITestCase):
f.write(os.path.join(base_dir, item), item)
return zip_file
def test_upload_spj_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "true", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], True)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
def test_upload_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "false", "file": f}, format="multipart")
data={"file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], False)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
@@ -148,16 +122,6 @@ class ProblemAdminAPITest(APITestCase):
resp = self.client.post(self.url, data=self.data)
self.assertFailed(resp, "Display ID already exists")
def test_spj(self):
data = copy.deepcopy(self.data)
data["spj"] = True
resp = self.client.post(self.url, data)
self.assertFailed(resp, "Invalid spj")
data["spj_code"] = "test"
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
def test_get_problem(self):
self.test_create_problem()

View File

@@ -1,18 +1,19 @@
from django.urls import path
from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView,
CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI,
FPSProblemImport, ProblemVisibleAPI)
from ..views.admin import (
ContestProblemAPI,
ProblemAPI,
TestCaseAPI,
MakeContestProblemPublicAPIView,
AddContestProblemAPI,
ProblemVisibleAPI,
)
urlpatterns = [
path("test_case", TestCaseAPI.as_view()),
path("compile_spj", CompileSPJAPI.as_view()),
path("problem", ProblemAPI.as_view()),
path("problem/visible", ProblemVisibleAPI.as_view()),
path("contest/problem", ContestProblemAPI.as_view()),
path("contest_problem/make_public", MakeContestProblemPublicAPIView.as_view()),
path("contest/add_problem_from_public", AddContestProblemAPI.as_view()),
path("export_problem", ExportProblemAPI.as_view()),
path("import_problem", ImportProblemAPI.as_view()),
path("import_fps", FPSProblemImport.as_view()),
]

View File

@@ -3,29 +3,21 @@ import json
import os
# import shutil
import tempfile
import zipfile
from wsgiref.util import FileWrapper
from django.conf import settings
from django.db import transaction
from django.db.models import Q
from django.http import StreamingHttpResponse, FileResponse
from django.http import StreamingHttpResponse
from account.decorators import problem_permission_required, ensure_created_by
from contest.models import Contest, ContestStatus
from fps.parser import FPSHelper, FPSParser
from judge.dispatcher import SPJCompiler
from options.options import SysOptions
from submission.models import Submission, JudgeStatus
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.constants import Difficulty
from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (
CreateContestProblemSerializer,
CompileSPJSerializer,
CreateProblemSerializer,
EditProblemSerializer,
EditContestProblemSerializer,
@@ -34,23 +26,17 @@ from ..serializers import (
TestCaseUploadForm,
ContestProblemMakePublicSerializer,
AddContestProblemSerializer,
ExportProblemSerializer,
ExportProblemRequestSerializer,
UploadProblemForm,
ImportProblemSerializer,
FPSProblemSerializer,
)
from ..utils import TEMPLATE_BASE, build_problem_template
class TestCaseZipProcessor(object):
def process_zip(self, uploaded_zip_file, spj, dir=""):
def process_zip(self, uploaded_zip_file, dir=""):
try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile:
raise APIError("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
test_case_list = self.filter_name_list(name_list, dir=dir)
if not test_case_list:
raise APIError("Empty file")
@@ -69,28 +55,22 @@ class TestCaseZipProcessor(object):
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
test_case_info = {"test_cases": {}}
info = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {
"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1],
}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {
"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1],
}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
@@ -100,29 +80,19 @@ class TestCaseZipProcessor(object):
return info, test_case_id
def filter_name_list(self, name_list, spj, dir=""):
def filter_name_list(self, name_list, dir=""):
ret = []
prefix = 1
if spj:
while True:
in_name = f"{prefix}.in"
if f"{dir}{in_name}" in name_list:
ret.append(in_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
else:
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
@@ -145,7 +115,7 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if not os.path.isdir(test_case_dir):
return self.error("Test case does not exists")
name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
name_list = self.filter_name_list(os.listdir(test_case_dir))
name_list.append("info")
file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
with zipfile.ZipFile(file_name, "w") as file:
@@ -164,7 +134,6 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
def post(self, request):
form = TestCaseUploadForm(request.POST, request.FILES)
if form.is_valid():
spj = form.cleaned_data["spj"] == "true"
file = form.cleaned_data["file"]
else:
return self.error("Upload failed")
@@ -172,39 +141,14 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
with open(zip_file, "wb") as f:
for chunk in file:
f.write(chunk)
info, test_case_id = self.process_zip(zip_file, spj=spj)
info, test_case_id = self.process_zip(zip_file)
os.remove(zip_file)
return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
@validate_serializer(CompileSPJSerializer)
def post(self, request):
data = request.data
spj_version = rand_str(8)
error = SPJCompiler(
data["spj_code"], spj_version, data["spj_language"]
).compile_spj()
if error:
return self.error(error)
else:
return self.success()
return self.success({"id": test_case_id, "info": info})
class ProblemBase(APIView):
def common_checks(self, request):
data = request.data
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return "Invalid spj"
if not data["spj_compile_ok"]:
return "SPJ code must be compiled successfully"
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")
).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
@@ -529,257 +473,6 @@ class AddContestProblemAPI(APIView):
return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = (
Submission.objects.filter(
problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED,
)
.order_by("-create_time")
.first()
)
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(
zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression,
)
problem_test_case_dir = os.path.join(
settings.TEST_CASE_DIR, problem.test_case_id
)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(
filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression,
)
if not info["spj"]:
zip_file.write(
filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression,
)
@validate_serializer(ExportProblemRequestSerializer)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(
zip_file=zip_file,
user=request.user,
problem=problem,
index=index + 1,
)
delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = "attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(
f"Invalid problem format, error is {serializer.errors}"
)
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in SysOptions.language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(
v["prepend"], v["template"], v["append"]
)
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(
tmp_file, spj=spj, dir=f"{i}/testcase/"
)
problem_obj = Problem.objects.create(
_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"
],
output_description=problem_info["output_description"][
"value"
],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"]["language"]
if spj
else None,
spj_version=rand_str(8) if spj else "",
languages=SysOptions.language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI
else 0,
test_case_id=test_case_id,
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(
prepend.get(lang, ""), t["code"], append.get(lang, "")
)
spj = problem_data["spj"] is not None
Problem.objects.create(
_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=problem_data["test_case_score"],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=SysOptions.language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"],
)
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
tf.file.flush()
os.fsync(tf.file)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
score = []
for item in helper.save_test_case(_problem, test_case_dir)[
"test_cases"
].values():
score.append(
{
"score": 0,
"input_name": item["input_name"],
"output_name": item.get("output_name"),
}
)
problem_data = helper.save_image(
_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX
)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
problem_data["test_case_score"] = score
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})
class ProblemVisibleAPI(APIView):
@problem_permission_required
def put(self, request):