support fps problems import; fix qduoj problems import

This commit is contained in:
virusdefender
2018-01-07 14:37:14 +08:00
parent 79724b0463
commit 82890a92b4
8 changed files with 483 additions and 119 deletions

View File

@@ -3,35 +3,91 @@ import json
import os
import shutil
import zipfile
import tempfile
from wsgiref.util import FileWrapper
from django.conf import settings
from django.http import StreamingHttpResponse, HttpResponse
from django.http import StreamingHttpResponse, HttpResponse, FileResponse
from django.db import transaction
from account.decorators import problem_permission_required, ensure_created_by
from judge.dispatcher import SPJCompiler
from judge.languages import language_names
from contest.models import Contest, ContestStatus
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from submission.models import Submission, JudgeStatus
from fps.parser import FPSHelper, FPSParser
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from utils.constants import Difficulty
from ..utils import TEMPLATE_BASE, build_problem_template
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer,
AddContestProblemSerializer)
AddContestProblemSerializer, ExportProblemSerializer,
ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer,
FPSProblemSerializer)
class TestCaseAPI(CSRFExemptAPIView):
request_parsers = ()
class TestCaseZipProcessor(object):
def process_zip(self, uploaded_zip_file, spj, dir=""):
try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile:
raise APIError("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
if not test_case_list:
raise APIError("Empty file")
def filter_name_list(self, name_list, spj):
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
size_cache = {}
md5_cache = {}
for item in test_case_list:
with open(os.path.join(test_case_dir, item), "wb") as f:
content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n")
size_cache[item] = len(content)
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
info = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
return info, test_case_id
def filter_name_list(self, name_list, spj, dir=""):
ret = []
prefix = 1
if spj:
while True:
in_name = str(prefix) + ".in"
if in_name in name_list:
in_name = f"{prefix}.in"
if f"{dir}{in_name}" in name_list:
ret.append(in_name)
prefix += 1
continue
@@ -39,9 +95,9 @@ class TestCaseAPI(CSRFExemptAPIView):
return sorted(ret, key=natural_sort_key)
else:
while True:
in_name = str(prefix) + ".in"
out_name = str(prefix) + ".out"
if in_name in name_list and out_name in name_list:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
@@ -49,6 +105,10 @@ class TestCaseAPI(CSRFExemptAPIView):
else:
return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def get(self, request):
problem_id = request.GET.get("problem_id")
if not problem_id:
@@ -90,62 +150,13 @@ class TestCaseAPI(CSRFExemptAPIView):
file = form.cleaned_data["file"]
else:
return self.error("Upload failed")
tmp_file = os.path.join("/tmp", rand_str() + ".zip")
with open(tmp_file, "wb") as f:
zip_file = f"/tmp/{rand_str()}.zip"
with open(zip_file, "wb") as f:
for chunk in file:
f.write(chunk)
try:
zip_file = zipfile.ZipFile(tmp_file)
except zipfile.BadZipFile:
return self.error("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj)
if not test_case_list:
return self.error("Empty file")
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
size_cache = {}
md5_cache = {}
for item in test_case_list:
with open(os.path.join(test_case_dir, item), "wb") as f:
content = zip_file.read(item).replace(b"\r\n", b"\n")
size_cache[item] = len(content)
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
hint = None
diff = set(name_list).difference(set(test_case_list))
if diff:
hint = ", ".join(diff) + " are ignored"
ret = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
ret.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
ret.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj})
info, test_case_id = self.process_zip(zip_file, spj=spj)
os.remove(zip_file)
return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
@@ -466,3 +477,204 @@ class AddContestProblemAPI(APIView):
problem.save()
problem.tags.set(tags)
return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = Submission.objects.filter(problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED).order_by("-create_time").first()
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression)
problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression)
if not info["spj"]:
zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression)
@validate_serializer(ExportProblemRequestSerialzier)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.apply_async((path,), countdown=300)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(f"Invalid problem format, error is {serializer.errors}")
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(v["prepend"], v["template"],
v["append"])
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/")
problem_obj = Problem.objects.create(_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"],
output_description=problem_info["output_description"][
"value"],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"][
"language"] if spj else None,
spj_version=rand_str(8) if spj else "",
languages=language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI else 0,
test_case_id=test_case_id
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, ""))
spj = problem_data["spj"] is not None
Problem.objects.create(_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=[],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"])
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
helper.save_test_case(_problem, test_case_dir)
problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})