import hashlib
import json
import os
import zipfile
from wsgiref.util import FileWrapper

from django.conf import settings
from django.db.models import Q
from django.http import StreamingHttpResponse

from account.decorators import problem_permission_required, ensure_created_by
from contest.models import Contest, ContestStatus
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.shortcuts import rand_str, natural_sort_key

from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (
    CreateContestProblemSerializer,
    CreateProblemSerializer,
    EditProblemSerializer,
    EditContestProblemSerializer,
    ProblemAdminSerializer,
    ProblemAdminListSerializer,
    TestCaseUploadForm,
    ContestProblemMakePublicSerializer,
    AddContestProblemSerializer,
)


class TestCaseZipProcessor(object):
    """Unpack test-case zip archives into ``settings.TEST_CASE_DIR``."""

    def process_zip(self, uploaded_zip_file, dir=""):
        """Extract the ``N.in`` / ``N.out`` pairs of a zip archive.

        Each extracted file has ``\\r\\n`` normalised to ``\\n``. An ``info``
        JSON file describing every case is written next to them.

        :param uploaded_zip_file: anything accepted by :class:`zipfile.ZipFile`
            (a path or a binary file object).
        :param dir: prefix inside the archive under which the test case files
            live; ``""`` means the archive root.
        :return: ``(info, test_case_id)``. ``info`` is a list of per-case dicts
            (input/output names, sizes, md5 of the right-stripped output).
            ``test_case_id`` is the name of the new directory.
        :raises APIError: if the archive is not a valid zip, or if it does not
            contain at least ``1.in`` and ``1.out``.
        """
        try:
            zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
        except zipfile.BadZipFile:
            raise APIError("Bad zip file") from None

        # The ``with`` closes the archive even if extraction aborts early.
        with zip_file:
            test_case_list = self.filter_name_list(zip_file.namelist(), dir=dir)
            if not test_case_list:
                raise APIError("Empty file")

            test_case_id = rand_str()
            test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
            os.mkdir(test_case_dir)
            os.chmod(test_case_dir, 0o710)

            size_cache = {}
            md5_cache = {}
            for item in test_case_list:
                content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n")
                size_cache[item] = len(content)
                if item.endswith(".out"):
                    # Hash of the output with trailing whitespace stripped.
                    md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
                with open(os.path.join(test_case_dir, item), "wb") as f:
                    f.write(content)

        test_case_info = {"test_cases": {}}
        info = []
        # filter_name_list yields ["1.in", "1.out", "2.in", "2.out", ...];
        # pair them as [("1.in", "1.out"), ("2.in", "2.out"), ...].
        pairs = zip(test_case_list[::2], test_case_list[1::2])
        for index, (in_name, out_name) in enumerate(pairs, start=1):
            data = {
                "stripped_output_md5": md5_cache[out_name],
                "input_size": size_cache[in_name],
                "output_size": size_cache[out_name],
                "input_name": in_name,
                "output_name": out_name,
            }
            info.append(data)
            test_case_info["test_cases"][str(index)] = data

        with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
            f.write(json.dumps(test_case_info, indent=4))

        for item in os.listdir(test_case_dir):
            os.chmod(os.path.join(test_case_dir, item), 0o640)

        return info, test_case_id

    def filter_name_list(self, name_list, dir=""):
        """Return the consecutive ``1.in``/``1.out``, ``2.in``/``2.out``, ...
        names present in *name_list*, each looked up with the *dir* prefix.

        Scanning stops at the first index whose ``.in`` or ``.out`` file is
        missing, so pairs after a gap are ignored. The returned names carry
        no *dir* prefix and are ordered with ``natural_sort_key``.
        """
        # Set lookup keeps the scan linear in the number of archive entries.
        available = set(name_list)
        ret = []
        prefix = 1
        while f"{dir}{prefix}.in" in available and f"{dir}{prefix}.out" in available:
            ret.append(f"{prefix}.in")
            ret.append(f"{prefix}.out")
            prefix += 1
        return sorted(ret, key=natural_sort_key)


class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
    """Download (GET) or upload (POST) a problem's test case archive."""

    request_parsers = ()

    def get(self, request):
        """Stream a zip containing the problem's test cases and ``info`` file."""
        problem_id = request.GET.get("problem_id")
        if not problem_id:
            return self.error("Parameter error, problem_id is required")
        try:
            problem = Problem.objects.get(id=problem_id)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")

        # Ownership of a contest problem is checked through its contest.
        if problem.contest:
            ensure_created_by(problem.contest, request.user)
        else:
            ensure_created_by(problem, request.user)

        test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if not os.path.isdir(test_case_dir):
            return self.error("Test case does not exists")
        name_list = self.filter_name_list(os.listdir(test_case_dir))
        name_list.append("info")
        file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
        with zipfile.ZipFile(file_name, "w") as archive:
            for test_case in name_list:
                archive.write(os.path.join(test_case_dir, test_case), test_case)
        response = StreamingHttpResponse(
            FileWrapper(open(file_name, "rb")), content_type="application/octet-stream"
        )

        response["Content-Disposition"] = (
            f"attachment; filename=problem_{problem.id}_test_cases.zip"
        )
        response["Content-Length"] = os.path.getsize(file_name)
        return response

    def post(self, request):
        """Save an uploaded zip, unpack it, and return the new test case id."""
        form = TestCaseUploadForm(request.POST, request.FILES)
        if not form.is_valid():
            return self.error("Upload failed")
        uploaded = form.cleaned_data["file"]

        zip_path = f"/tmp/{rand_str()}.zip"
        try:
            with open(zip_path, "wb") as f:
                for chunk in uploaded:
                    f.write(chunk)
            info, test_case_id = self.process_zip(zip_path)
        finally:
            # Remove the temporary copy even when process_zip raises APIError.
            if os.path.exists(zip_path):
                os.remove(zip_path)
        return self.success({"id": test_case_id, "info": info})


class ProblemBase(APIView):
    """Shared validation and helpers for the problem admin endpoints."""

    def common_checks(self, request):
        """Validate and normalise ``request.data`` in place.

        For OI problems every ``test_case_score`` entry must have a positive
        ``score``; their sum is stored in ``data["total_score"]``.
        ``data["languages"]`` is always converted to a list.

        :return: an error message string, or ``None`` when the data is valid.
        """
        data = request.data
        if data["rule_type"] == ProblemRuleType.OI:
            total_score = 0
            for item in data["test_case_score"]:
                if item["score"] <= 0:
                    return "Invalid score"
                total_score += item["score"]
            data["total_score"] = total_score
        data["languages"] = list(data["languages"])

    @staticmethod
    def _add_tags(problem, tag_names):
        """Attach a ProblemTag for each name to *problem*, creating missing tags."""
        for name in tag_names:
            tag, _ = ProblemTag.objects.get_or_create(name=name)
            problem.tags.add(tag)


class ProblemAPI(ProblemBase):
    """CRUD for problems that do not belong to a contest."""

    @problem_permission_required
    @validate_serializer(CreateProblemSerializer)
    def post(self, request):
        """Create a problem; ``_id`` must be unique among non-contest problems."""
        data = request.data
        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists():
            return self.error("Display ID already exists")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")
        data["created_by"] = request.user
        problem = Problem.objects.create(**data)
        self._add_tags(problem, tags)
        return self.success(ProblemAdminSerializer(problem).data)

    @problem_permission_required
    def get(self, request):
        """Return one problem (``?id=``) or a paginated, filterable list."""
        problem_id = request.GET.get("id")
        user = request.user
        if problem_id:
            try:
                problem = Problem.objects.get(id=problem_id)
                ensure_created_by(problem, request.user)
                return self.success(ProblemAdminSerializer(problem).data)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")

        problems = Problem.objects.filter(contest_id__isnull=True).order_by(
            "-create_time"
        )
        author = request.GET.get("author", "")
        if author:
            problems = problems.filter(created_by__username=author)
        keyword = request.GET.get("keyword", "").strip()
        if keyword:
            problems = problems.filter(
                Q(title__icontains=keyword) | Q(_id__icontains=keyword)
            )
        if not user.can_mgmt_all_problem():
            problems = problems.filter(created_by=user)
        return self.success(
            self.paginate_data(request, problems, ProblemAdminListSerializer)
        )

    @problem_permission_required
    @validate_serializer(EditProblemSerializer)
    def put(self, request):
        """Update a problem's fields and replace its tags."""
        data = request.data
        problem_id = data.pop("id")

        try:
            problem = Problem.objects.get(id=problem_id)
            ensure_created_by(problem, request.user)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if (
            Problem.objects.exclude(id=problem_id)
            .filter(_id=_id, contest_id__isnull=True)
            .exists()
        ):
            return self.error("Display ID already exists")

        # common_checks also converts data["languages"] to a list.
        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)
        # todo check filename and score info
        tags = data.pop("tags")

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        problem.tags.remove(*problem.tags.all())
        self._add_tags(problem, tags)

        return self.success()

    @problem_permission_required
    def delete(self, request):
        """Delete a non-contest problem (``?id=``)."""
        problem_id = request.GET.get("id")
        if not problem_id:
            return self.error("Invalid parameter, id is required")
        try:
            problem = Problem.objects.get(id=problem_id, contest_id__isnull=True)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")
        ensure_created_by(problem, request.user)
        # NOTE: the test case directory is deliberately left on disk. Copies
        # made by MakeContestProblemPublicAPIView / AddContestProblemAPI keep
        # the same test_case_id and would lose their data if it were removed.
        problem.delete()
        return self.success()


class ContestProblemAPI(ProblemBase):
    """CRUD for problems that belong to a contest."""

    @validate_serializer(CreateContestProblemSerializer)
    def post(self, request):
        """Create a problem inside a contest owned by the caller."""
        data = request.data
        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, request.user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")

        if Problem.objects.filter(_id=_id, contest=contest).exists():
            return self.error("Duplicate Display id")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        data["contest"] = contest
        tags = data.pop("tags")
        data["created_by"] = request.user
        problem = Problem.objects.create(**data)
        self._add_tags(problem, tags)
        return self.success(ProblemAdminSerializer(problem).data)

    def get(self, request):
        """Return one contest problem (``?id=``) or a contest's problem list."""
        problem_id = request.GET.get("id")
        contest_id = request.GET.get("contest_id")
        user = request.user
        if problem_id:
            try:
                # Restrict to contest problems so problem.contest is never
                # None when handed to ensure_created_by (matches delete()).
                problem = Problem.objects.get(id=problem_id, contest_id__isnull=False)
                ensure_created_by(problem.contest, user)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")
            return self.success(ProblemAdminSerializer(problem).data)

        if not contest_id:
            return self.error("Contest id is required")
        try:
            contest = Contest.objects.get(id=contest_id)
            ensure_created_by(contest, user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")
        problems = Problem.objects.filter(contest=contest).order_by("-create_time")
        if user.is_admin():
            problems = problems.filter(contest__created_by=user)
        keyword = request.GET.get("keyword")
        if keyword:
            problems = problems.filter(title__contains=keyword)
        return self.success(
            self.paginate_data(request, problems, ProblemAdminListSerializer)
        )

    @validate_serializer(EditContestProblemSerializer)
    def put(self, request):
        """Update a contest problem's fields and replace its tags."""
        data = request.data
        user = request.user

        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        problem_id = data.pop("id")

        try:
            problem = Problem.objects.get(id=problem_id, contest=contest)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if (
            Problem.objects.exclude(id=problem_id)
            .filter(_id=_id, contest=contest)
            .exists()
        ):
            return self.error("Display ID already exists")

        # common_checks also converts data["languages"] to a list.
        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)
        # todo check filename and score info
        tags = data.pop("tags")

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        problem.tags.remove(*problem.tags.all())
        self._add_tags(problem, tags)
        return self.success()

    def delete(self, request):
        """Delete a contest problem (``?id=``) unless it has submissions."""
        problem_id = request.GET.get("id")
        if not problem_id:
            return self.error("Invalid parameter, id is required")
        try:
            problem = Problem.objects.get(id=problem_id, contest_id__isnull=False)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")
        ensure_created_by(problem.contest, request.user)
        if Submission.objects.filter(problem=problem).exists():
            return self.error("Can't delete the problem as it has submissions")
        # NOTE: the test case directory is deliberately left on disk; copies of
        # this problem share its test_case_id.
        problem.delete()
        return self.success()


class MakeContestProblemPublicAPIView(APIView):
    """Copy a contest problem into the public (non-contest) problem set."""

    @problem_permission_required
    @validate_serializer(ContestProblemMakePublicSerializer)
    def post(self, request):
        """Mark the contest problem public and save a hidden, reset copy of it
        under ``display_id``. The copy shares the original's test_case_id.
        """
        data = request.data
        display_id = data.get("display_id")
        if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists():
            return self.error("Duplicate display ID")

        try:
            problem = Problem.objects.get(id=data["id"])
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        if not problem.contest or problem.is_public:
            return self.error("Already be a public problem")
        problem.is_public = True
        problem.save()
        # Snapshot the tags before the pk is cleared to create the copy.
        tags = list(problem.tags.all())
        problem.pk = None
        problem.contest = None
        problem._id = display_id
        problem.visible = False
        problem.submission_number = problem.accepted_number = 0
        problem.statistic_info = {}
        problem.save()
        problem.tags.set(tags)
        return self.success()


class AddContestProblemAPI(APIView):
    """Copy an existing problem into a contest."""

    @validate_serializer(AddContestProblemSerializer)
    def post(self, request):
        """Save a copy of ``problem_id`` into ``contest_id`` as ``display_id``.

        Rejected if the contest has ended or the display id is already used
        in that contest. The copy is public, visible, and has reset counters.
        """
        data = request.data
        try:
            contest = Contest.objects.get(id=data["contest_id"])
            problem = Problem.objects.get(id=data["problem_id"])
        except (Contest.DoesNotExist, Problem.DoesNotExist):
            return self.error("Contest or Problem does not exist")

        # Same ownership check ContestProblemAPI.post applies before adding
        # a problem to a contest.
        ensure_created_by(contest, request.user)

        if contest.status == ContestStatus.CONTEST_ENDED:
            return self.error("Contest has ended")
        if Problem.objects.filter(contest=contest, _id=data["display_id"]).exists():
            return self.error("Duplicate display id in this contest")

        # Snapshot the tags before the pk is cleared to create the copy.
        tags = list(problem.tags.all())
        problem.pk = None
        problem.contest = contest
        problem.is_public = True
        problem.visible = True
        problem._id = data["display_id"]
        problem.submission_number = problem.accepted_number = 0
        problem.statistic_info = {}
        problem.save()
        problem.tags.set(tags)
        return self.success()


class ProblemVisibleAPI(APIView):
    """Toggle a problem's ``visible`` flag."""

    @problem_permission_required
    def put(self, request):
        """Flip ``visible`` on the problem identified by ``data["id"]``."""
        data = request.data
        try:
            problem = Problem.objects.get(id=data.get("id"))
        except Problem.DoesNotExist:
            # Must return here; falling through would use an unbound `problem`.
            return self.error("problem does not exists")
        problem.visible = not problem.visible
        problem.save()
        return self.success()