From c466dfd3c6e52f66db6a24377b00bb7a009f246b Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Sat, 9 May 2026 02:30:47 -0600 Subject: [PATCH] change enum --- ...dmin_type_alter_user_problem_permission.py | 22 +++ account/models.py | 28 ++-- account/serializers.py | 18 +-- account/views/oj.py | 71 ++------- ...acmcontestrank_unique_together_and_more.py | 35 +++++ contest/models.py | 22 ++- contest/serializers.py | 2 +- flowchart/tasks.py | 4 +- judge/languages.py | 45 ++---- ..._alter_problem_unique_together_and_more.py | 33 ++++ problem/models.py | 32 ++-- problem/serializers.py | 25 +-- ...blemsetproblem_unique_together_and_more.py | 54 +++++++ problemset/models.py | 109 ++++++------- problemset/serializers.py | 68 ++++---- problemset/views/admin.py | 57 +++---- problemset/views/oj.py | 148 ++++-------------- .../0005_alter_submission_result.py | 33 ++++ submission/models.py | 44 ++---- tutorial/models.py | 30 ++-- tutorial/serializers.py | 6 +- utils/constants.py | 34 ++-- utils/migrate_data.py | 34 ++-- 23 files changed, 451 insertions(+), 503 deletions(-) create mode 100644 account/migrations/0004_alter_user_admin_type_alter_user_problem_permission.py create mode 100644 contest/migrations/0004_alter_acmcontestrank_unique_together_and_more.py create mode 100644 problem/migrations/0008_alter_problem_unique_together_and_more.py create mode 100644 problemset/migrations/0008_alter_problemsetproblem_unique_together_and_more.py create mode 100644 submission/migrations/0005_alter_submission_result.py diff --git a/account/migrations/0004_alter_user_admin_type_alter_user_problem_permission.py b/account/migrations/0004_alter_user_admin_type_alter_user_problem_permission.py new file mode 100644 index 0000000..9ef1ff9 --- /dev/null +++ b/account/migrations/0004_alter_user_admin_type_alter_user_problem_permission.py @@ -0,0 +1,22 @@ +# Generated by Django 6.0.4 on 2026-05-09 08:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("account", "0003_remove_userprofile_class_name_user_class_name"), + ] + + operations = [ + migrations.AlterField( + model_name="user", + name="admin_type", + field=models.TextField(choices=[("Regular User", "Regular User"), ("Admin", "Admin"), ("Super Admin", "Super Admin")], default="Regular User"), + ), + migrations.AlterField( + model_name="user", + name="problem_permission", + field=models.TextField(choices=[("None", "None"), ("Own", "Own"), ("All", "All")], default="None"), + ), + ] diff --git a/account/models.py b/account/models.py index 50ed357..e37cdc0 100644 --- a/account/models.py +++ b/account/models.py @@ -5,16 +5,16 @@ from django.db import models from utils.models import JSONField -class AdminType(object): - REGULAR_USER = "Regular User" - ADMIN = "Admin" - SUPER_ADMIN = "Super Admin" +class AdminType(models.TextChoices): + REGULAR_USER = "Regular User", "Regular User" + ADMIN = "Admin", "Admin" + SUPER_ADMIN = "Super Admin", "Super Admin" -class ProblemPermission(object): - NONE = "None" - OWN = "Own" - ALL = "All" +class ProblemPermission(models.TextChoices): + NONE = "None", "None" + OWN = "Own", "Own" + ALL = "All", "All" class UserManager(models.Manager): @@ -30,8 +30,8 @@ class User(AbstractBaseUser): email = models.TextField(null=True) create_time = models.DateTimeField(auto_now_add=True, null=True) # One of UserType - admin_type = models.TextField(default=AdminType.REGULAR_USER) - problem_permission = models.TextField(default=ProblemPermission.NONE) + admin_type = models.TextField(default=AdminType.REGULAR_USER, choices=AdminType.choices) + problem_permission = models.TextField(default=ProblemPermission.NONE, choices=ProblemPermission.choices) reset_password_token = models.TextField(null=True) reset_password_token_expire_time = models.DateTimeField(null=True) # SSO auth token @@ -43,9 +43,7 @@ class User(AbstractBaseUser): open_api = models.BooleanField(default=False) open_api_appkey = models.TextField(null=True) is_disabled = models.BooleanField(default=False) - raw_password = models.CharField( - max_length=20, null=True, blank=True, verbose_name="明文密码" - ) + raw_password = models.CharField(max_length=20, null=True, blank=True, verbose_name="明文密码") USERNAME_FIELD = "username" REQUIRED_FIELDS = [] @@ -68,9 +66,7 @@ class User(AbstractBaseUser): return self.problem_permission == ProblemPermission.ALL def is_contest_admin(self, contest): - return self.is_authenticated and ( - contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN - ) + return self.is_authenticated and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN) def set_password(self, raw_password): super().set_password(raw_password) diff --git a/account/serializers.py b/account/serializers.py index 2c8b513..05b7560 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -44,9 +44,7 @@ class GenerateUserSerializer(serializers.Serializer): class ImportUserSerializer(serializers.Serializer): - users = serializers.ListField( - child=serializers.ListField(child=serializers.CharField(max_length=64)) - ) + users = serializers.ListField(child=serializers.ListField(child=serializers.CharField(max_length=64))) class UserAdminSerializer(serializers.ModelSerializer): @@ -118,21 +116,16 @@ class EditUserSerializer(serializers.Serializer): id = serializers.IntegerField() username = serializers.CharField(max_length=32) real_name = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) - password = serializers.CharField( - min_length=6, allow_blank=True, required=False, default=None - ) + password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None) email = serializers.EmailField(max_length=64) - admin_type = serializers.ChoiceField( - choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN) - ) - problem_permission = serializers.ChoiceField( - choices=(ProblemPermission.NONE, ProblemPermission.OWN, ProblemPermission.ALL) - ) + admin_type = serializers.ChoiceField(choices=AdminType.choices) + problem_permission = serializers.ChoiceField(choices=ProblemPermission.choices) open_api = serializers.BooleanField() two_factor_auth = serializers.BooleanField() is_disabled = serializers.BooleanField() class_name = serializers.CharField(required=False, allow_null=True, allow_blank=True) + class EditUserProfileSerializer(serializers.Serializer): real_name = serializers.CharField(max_length=32, allow_null=True, required=False) avatar = serializers.CharField(max_length=256, allow_blank=True, required=False) @@ -143,6 +136,7 @@ class EditUserProfileSerializer(serializers.Serializer): major = serializers.CharField(max_length=64, allow_blank=True, required=False) language = serializers.CharField(max_length=32, allow_blank=True, required=False) + class ApplyResetPasswordSerializer(serializers.Serializer): email = serializers.EmailField() captcha = serializers.CharField() diff --git a/account/views/oj.py b/account/views/oj.py index 61337db..93547b6 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -78,9 +78,7 @@ class UserProfileAPI(APIView): show_real_name = True except User.DoesNotExist: return self.error("User does not exist") - return self.success( - UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data - ) + return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data) @validate_serializer(EditUserProfileSerializer) @login_required @@ -90,9 +88,7 @@ class UserProfileAPI(APIView): for k, v in data.items(): setattr(user_profile, k, v) user_profile.save() - return self.success( - UserProfileSerializer(user_profile, show_real_name=True).data - ) + return self.success(UserProfileSerializer(user_profile, show_real_name=True).data) class Metrics(APIView): @@ -157,9 +153,7 @@ class TwoFactorAuthAPI(APIView): user.save() label = f"{SysOptions.website_name_shortcut}:{user.username}" - image = qrcode.make( - _totp_uri(token, label, SysOptions.website_name.replace(" ", "")) - ) + image = qrcode.make(_totp_uri(token, label, SysOptions.website_name.replace(" ", ""))) return self.success(img2base64(image)) @login_required @@ -224,9 +218,7 @@ class UserLoginAPI(APIView): if not user.two_factor_auth: prev_login = user.last_login auth.login(request, user) - request.session["prev_login"] = ( - datetime2str(prev_login) if prev_login else "" - ) + request.session["prev_login"] = datetime2str(prev_login) if prev_login else "" return self.success("Succeeded") # `tfa_code` not in post data @@ -236,9 +228,7 @@ class UserLoginAPI(APIView): if _valid_totp(user.tfa_token, data["tfa_code"]): prev_login = user.last_login auth.login(request, user) - request.session["prev_login"] = ( - datetime2str(prev_login) if prev_login else "" - ) + request.session["prev_login"] = datetime2str(prev_login) if prev_login else "" return self.success("Succeeded") else: return self.error("Invalid two factor verification code") @@ -262,9 +252,7 @@ class UsernameOrEmailCheck(APIView): # True means already exist. result = {"username": False, "email": False} if data.get("username"): - result["username"] = User.objects.filter( - username=data["username"].lower() - ).exists() + result["username"] = User.objects.filter(username=data["username"].lower()).exists() if data.get("email"): result["email"] = User.objects.filter(email=data["email"].lower()).exists() return self.success(result) @@ -301,9 +289,7 @@ class UserChangeEmailAPI(APIView): @login_required def post(self, request): data = request.data - user = auth.authenticate( - username=request.user.username, password=data["password"] - ) + user = auth.authenticate(username=request.user.username, password=data["password"]) if user: if user.two_factor_auth: if "tfa_code" not in data: @@ -356,12 +342,7 @@ class ApplyResetPasswordAPI(APIView): user = User.objects.get(email__iexact=data["email"]) except User.DoesNotExist: return self.error("User does not exist") - if ( - user.reset_password_token_expire_time - and 0 - < int((user.reset_password_token_expire_time - now()).total_seconds()) - < 20 * 60 - ): + if user.reset_password_token_expire_time and 0 < int((user.reset_password_token_expire_time - now()).total_seconds()) < 20 * 60: return self.error("You can only reset password once per 20 minutes") user.reset_password_token = rand_str() user.reset_password_token_expire_time = now() + timedelta(minutes=20) @@ -453,7 +434,7 @@ class UserRankAPI(APIView): n = int(request.GET.get("n", "0")) except ValueError: n = 0 - if rule_type not in ContestRuleType.choices(): + if rule_type not in ContestRuleType.values: rule_type = ContestRuleType.ACM profiles = UserProfile.objects.filter( @@ -462,9 +443,7 @@ class UserRankAPI(APIView): user__username__icontains=username, ).select_related("user") if rule_type == ContestRuleType.ACM: - profiles = profiles.filter(accepted_number__gte=0).order_by( - "-accepted_number", "submission_number" - ) + profiles = profiles.filter(accepted_number__gte=0).order_by("-accepted_number", "submission_number") else: profiles = profiles.filter(total_score__gt=0).order_by("-total_score") if n > 0: @@ -482,19 +461,13 @@ class UserActivityRankAPI(APIView): if cached is not None: return self.success(cached) - hidden_names = User.objects.filter( - Q(admin_type=AdminType.SUPER_ADMIN) | Q(is_disabled=True) - ).values_list("username", flat=True) + hidden_names = User.objects.filter(Q(admin_type=AdminType.SUPER_ADMIN) | Q(is_disabled=True)).values_list("username", flat=True) submissions = Submission.objects.filter( contest_id__isnull=True, create_time__gte=start, result=JudgeStatus.ACCEPTED, ).exclude(username__in=hidden_names) - data = list( - submissions.values("username") - .annotate(count=Count("problem_id", distinct=True)) - .order_by("-count")[:10] - ) + data = list(submissions.values("username").annotate(count=Count("problem_id", distinct=True)).order_by("-count")[:10]) cache.set(cache_key, data, 600) return self.success(data) @@ -506,12 +479,8 @@ class UserProblemRankAPI(APIView): if not user.is_authenticated: return self.error("User is not authenticated") - problem = Problem.objects.get( - _id=problem_id, contest_id__isnull=True, visible=True - ) - submissions = Submission.objects.filter( - problem=problem, result=JudgeStatus.ACCEPTED - ) + problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True) + submissions = Submission.objects.filter(problem=problem, result=JudgeStatus.ACCEPTED) all_ac_count = submissions.values("user_id").distinct().count() @@ -519,9 +488,7 @@ class UserProblemRankAPI(APIView): class_ac_count = 0 if class_name: - users = User.objects.filter( - class_name=user.class_name, is_disabled=False - ).values_list("id", flat=True) + users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True) user_ids = list(users) submissions = submissions.filter(user_id__in=user_ids) class_ac_count = submissions.values("user_id").distinct().count() @@ -539,9 +506,7 @@ class UserProblemRankAPI(APIView): ) my_first_submission = my_submissions.order_by("create_time").first() - rank = submissions.filter( - create_time__lte=my_first_submission.create_time - ).count() + rank = submissions.filter(create_time__lte=my_first_submission.create_time).count() return self.success( { "class_name": class_name, @@ -561,9 +526,7 @@ class ProfileProblemDisplayIDRefreshAPI(APIView): ids = list(acm_problems.keys()) + list(oi_problems.keys()) if not ids: return self.success() - display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list( - "_id", flat=True - ) + display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True) id_map = dict(zip(ids, display_ids)) for k, v in acm_problems.items(): v["_id"] = id_map[k] diff --git a/contest/migrations/0004_alter_acmcontestrank_unique_together_and_more.py b/contest/migrations/0004_alter_acmcontestrank_unique_together_and_more.py new file mode 100644 index 0000000..f81fdf8 --- /dev/null +++ b/contest/migrations/0004_alter_acmcontestrank_unique_together_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 6.0.4 on 2026-05-09 08:18 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("contest", "0003_acmcontestrank_acm_rank_contest_user_idx_and_more"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="acmcontestrank", + unique_together=set(), + ), + migrations.AlterUniqueTogether( + name="oicontestrank", + unique_together=set(), + ), + migrations.AlterField( + model_name="contest", + name="rule_type", + field=models.TextField(choices=[("ACM", "ACM"), ("OI", "OI")]), + ), + migrations.AddConstraint( + model_name="acmcontestrank", + constraint=models.UniqueConstraint(fields=("user", "contest"), name="unique_acm_rank_user_contest"), + ), + migrations.AddConstraint( + model_name="oicontestrank", + constraint=models.UniqueConstraint(fields=("user", "contest"), name="unique_oi_rank_user_contest"), + ), + ] diff --git a/contest/models.py b/contest/models.py index 37b9c17..a8094d4 100644 --- a/contest/models.py +++ b/contest/models.py @@ -15,8 +15,7 @@ class Contest(models.Model): # show real time rank or cached rank real_time_rank = models.BooleanField() password = models.TextField(null=True) - # enum of ContestRuleType - rule_type = models.TextField() + rule_type = models.TextField(choices=ContestRuleType.choices) start_time = models.DateTimeField() end_time = models.DateTimeField() create_time = models.DateTimeField(auto_now_add=True) @@ -46,13 +45,7 @@ class Contest(models.Model): # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等 def problem_details_permission(self, user): - return ( - self.rule_type == ContestRuleType.ACM - or self.status == ContestStatus.CONTEST_ENDED - or user.is_authenticated - and user.is_contest_admin(self) - or self.real_time_rank - ) + return self.rule_type == ContestRuleType.ACM or self.status == ContestStatus.CONTEST_ENDED or user.is_authenticated and user.is_contest_admin(self) or self.real_time_rank class Meta: db_table = "contest" @@ -78,10 +71,11 @@ class ACMContestRank(AbstractContestRank): class Meta: db_table = "acm_contest_rank" - unique_together = (("user", "contest"),) + constraints = [ + models.UniqueConstraint(fields=["user", "contest"], name="unique_acm_rank_user_contest"), + ] indexes = [ - models.Index(fields=["contest", "accepted_number", "total_time"], - name="acm_rank_order_idx"), + models.Index(fields=["contest", "accepted_number", "total_time"], name="acm_rank_order_idx"), models.Index(fields=["contest", "user"], name="acm_rank_contest_user_idx"), ] @@ -94,7 +88,9 @@ class OIContestRank(AbstractContestRank): class Meta: db_table = "oi_contest_rank" - unique_together = (("user", "contest"),) + constraints = [ + models.UniqueConstraint(fields=["user", "contest"], name="unique_oi_rank_user_contest"), + ] indexes = [ models.Index(fields=["contest", "total_score"], name="oi_rank_order_idx"), models.Index(fields=["contest", "user"], name="oi_rank_contest_user_idx"), diff --git a/contest/serializers.py b/contest/serializers.py index f96653e..ae2d3a2 100644 --- a/contest/serializers.py +++ b/contest/serializers.py @@ -9,7 +9,7 @@ class CreateConetestSeriaizer(serializers.Serializer): tag = serializers.CharField() start_time = serializers.DateTimeField() end_time = serializers.DateTimeField() - rule_type = serializers.ChoiceField(choices=[ContestRuleType.ACM, ContestRuleType.OI]) + rule_type = serializers.ChoiceField(choices=ContestRuleType.choices) password = serializers.CharField(allow_blank=True, max_length=32) visible = serializers.BooleanField() real_time_rank = serializers.BooleanField() diff --git a/flowchart/tasks.py b/flowchart/tasks.py index 071ea01..f6570be 100644 --- a/flowchart/tasks.py +++ b/flowchart/tasks.py @@ -3,8 +3,6 @@ import logging import time import dramatiq - -logger = logging.getLogger(__name__) from django.db import transaction from django.utils import timezone @@ -13,6 +11,8 @@ from utils.shortcuts import DRAMATIQ_WORKER_ARGS from .models import FlowchartSubmission, FlowchartSubmissionStatus +logger = logging.getLogger(__name__) + @dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3)) def evaluate_flowchart_task(submission_id): diff --git a/judge/languages.py b/judge/languages.py index 9627d0a..c08c1cc 100644 --- a/judge/languages.py +++ b/judge/languages.py @@ -27,11 +27,7 @@ int main() { "max_memory": 256 * 1024 * 1024, "compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}", }, - "run": { - "command": "{exe_path}", - "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"}, - "env": default_env - } + "run": {"command": "{exe_path}", "seccomp_rule": {ProblemIOMode.STANDARD: "c_cpp", ProblemIOMode.FILE: "c_cpp_file_io"}, "env": default_env}, } @@ -60,11 +56,7 @@ int main() { "max_memory": 1024 * 1024 * 1024, "compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}", }, - "run": { - "command": "{exe_path}", - "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"}, - "env": default_env - } + "run": {"command": "{exe_path}", "seccomp_rule": {ProblemIOMode.STANDARD: "c_cpp", ProblemIOMode.FILE: "c_cpp_file_io"}, "env": default_env}, } @@ -91,14 +83,9 @@ class Main { "max_cpu_time": 5000, "max_real_time": 10000, "max_memory": -1, - "compile_command": "/usr/bin/javac {src_path} -d {exe_dir}" + "compile_command": "/usr/bin/javac {src_path} -d {exe_dir}", }, - "run": { - "command": "/usr/bin/java -cp {exe_dir} -XX:MaxRAM={max_memory}k Main", - "seccomp_rule": None, - "env": default_env, - "memory_limit_check_only": 1 - } + "run": {"command": "/usr/bin/java -cp {exe_dir} -XX:MaxRAM={max_memory}k Main", "seccomp_rule": None, "env": default_env, "memory_limit_check_only": 1}, } _py3_lang_config = { @@ -122,11 +109,7 @@ print(add(1, 2)) "max_memory": 128 * 1024 * 1024, "compile_command": "/usr/bin/python3 -m py_compile {src_path}", }, - "run": { - "command": "/usr/bin/python3 -BS {exe_path}", - "seccomp_rule": "general", - "env": default_env - } + "run": {"command": "/usr/bin/python3 -BS {exe_path}", "seccomp_rule": "general", "env": default_env}, } _go_lang_config = { @@ -154,14 +137,9 @@ func main() { "max_real_time": 5000, "max_memory": 1024 * 1024 * 1024, "compile_command": "/usr/bin/go build -o {exe_path} {src_path}", - "env": ["GOCACHE=/tmp", "GOPATH=/tmp", "GOMAXPROCS=1"] + default_env + "env": ["GOCACHE=/tmp", "GOPATH=/tmp", "GOMAXPROCS=1"] + default_env, }, - "run": { - "command": "{exe_path}", - "seccomp_rule": "golang", - "env": ["GOMAXPROCS=1"] + default_env, - "memory_limit_check_only": 1 - } + "run": {"command": "{exe_path}", "seccomp_rule": "golang", "env": ["GOMAXPROCS=1"] + default_env, "memory_limit_check_only": 1}, } _node_lang_config = { @@ -184,14 +162,9 @@ console.log(add(1, 2)) "max_real_time": 5000, "max_memory": 1024 * 1024 * 1024, "compile_command": "/usr/bin/node --check {src_path}", - "env": default_env - }, - "run": { - "command": "/usr/bin/node {exe_path}", - "seccomp_rule": "node", "env": default_env, - "memory_limit_check_only": 1 - } + }, + "run": {"command": "/usr/bin/node {exe_path}", "seccomp_rule": "node", "env": default_env, "memory_limit_check_only": 1}, } languages = [ diff --git a/problem/migrations/0008_alter_problem_unique_together_and_more.py b/problem/migrations/0008_alter_problem_unique_together_and_more.py new file mode 100644 index 0000000..a7c91f5 --- /dev/null +++ b/problem/migrations/0008_alter_problem_unique_together_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 6.0.4 on 2026-05-09 08:18 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("contest", "0004_alter_acmcontestrank_unique_together_and_more"), + ("problem", "0007_problem_problem_visible_idx"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="problem", + unique_together=set(), + ), + migrations.AlterField( + model_name="problem", + name="difficulty", + field=models.TextField(choices=[("Low", "Low"), ("Mid", "Mid"), ("High", "High")]), + ), + migrations.AlterField( + model_name="problem", + name="rule_type", + field=models.TextField(choices=[("ACM", "ACM"), ("OI", "OI")]), + ), + migrations.AddConstraint( + model_name="problem", + constraint=models.UniqueConstraint(fields=("_id", "contest"), name="unique_problem_id_contest"), + ), + ] diff --git a/problem/models.py b/problem/models.py index bcdeae7..0f8a205 100644 --- a/problem/models.py +++ b/problem/models.py @@ -2,7 +2,7 @@ from django.db import models from account.models import User from contest.models import Contest -from utils.constants import Choices +from utils.constants import Difficulty from utils.models import RichTextField @@ -13,25 +13,19 @@ class ProblemTag(models.Model): db_table = "problem_tag" -class ProblemRuleType(Choices): - ACM = "ACM" - OI = "OI" +class ProblemRuleType(models.TextChoices): + ACM = "ACM", "ACM" + OI = "OI", "OI" -class ProblemDifficulty(object): - High = "High" - Mid = "Mid" - Low = "Low" - - -class ProblemIOMode(Choices): - standard = "Standard IO" - file = "File IO" +class ProblemIOMode(models.TextChoices): + STANDARD = "Standard IO", "Standard IO" + FILE = "File IO", "File IO" def _default_io_mode(): return { - "io_mode": ProblemIOMode.standard, + "io_mode": ProblemIOMode.STANDARD, "input": "input.txt", "output": "output.txt", } @@ -66,9 +60,9 @@ class Problem(models.Model): memory_limit = models.IntegerField() # io mode io_mode = models.JSONField(default=_default_io_mode) - rule_type = models.TextField() + rule_type = models.TextField(choices=ProblemRuleType.choices) visible = models.BooleanField(default=True) - difficulty = models.TextField() + difficulty = models.TextField(choices=Difficulty.choices) tags = models.ManyToManyField(ProblemTag) source = models.TextField(null=True) prompt = models.TextField(null=True) @@ -81,7 +75,7 @@ class Problem(models.Model): # {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count statistic_info = models.JSONField(default=dict) share_submission = models.BooleanField(default=False) - + # 流程图相关字段 allow_flowchart = models.BooleanField(default=False) # 是否允许/需要提交流程图 mermaid_code = models.TextField(null=True, blank=True) # 流程图答案(Mermaid代码) @@ -91,7 +85,9 @@ class Problem(models.Model): class Meta: db_table = "problem" - unique_together = (("_id", "contest"),) + constraints = [ + models.UniqueConstraint(fields=["_id", "contest"], name="unique_problem_id_contest"), + ] ordering = ("create_time",) indexes = [ models.Index(fields=["contest", "visible"], name="problem_contest_visible_idx"), diff --git a/problem/serializers.py b/problem/serializers.py index 38073c8..b62593e 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -38,7 +38,7 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer): class ProblemIOModeSerializer(serializers.Serializer): - io_mode = serializers.ChoiceField(choices=ProblemIOMode.choices()) + io_mode = serializers.ChoiceField(choices=ProblemIOMode.choices) input = serializers.CharField() output = serializers.CharField() @@ -59,22 +59,16 @@ class CreateOrEditProblemSerializer(serializers.Serializer): output_description = serializers.CharField() samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False) test_case_id = serializers.CharField(max_length=32) - test_case_score = serializers.ListField( - child=CreateTestCaseScoreSerializer(), allow_empty=True - ) + test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=True) time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60) memory_limit = serializers.IntegerField(min_value=1, max_value=1024) languages = LanguageNameMultiChoiceField() template = serializers.DictField(child=serializers.CharField(min_length=1)) - rule_type = serializers.ChoiceField( - choices=[ProblemRuleType.ACM, ProblemRuleType.OI] - ) + rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices) io_mode = ProblemIOModeSerializer() visible = serializers.BooleanField() - difficulty = serializers.ChoiceField(choices=Difficulty.choices()) - tags = serializers.ListField( - child=serializers.CharField(max_length=32), allow_empty=False - ) + difficulty = serializers.ChoiceField(choices=Difficulty.choices) + tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) hint = serializers.CharField(allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) prompt = serializers.CharField(allow_blank=True, allow_null=True) @@ -88,13 +82,9 @@ class CreateOrEditProblemSerializer(serializers.Serializer): # 流程图相关字段 allow_flowchart = serializers.BooleanField(required=False, default=False) show_flowchart = serializers.BooleanField(required=False, default=False) - mermaid_code = serializers.CharField( - allow_blank=True, allow_null=True, required=False - ) + mermaid_code = serializers.CharField(allow_blank=True, allow_null=True, required=False) - flowchart_hint = serializers.CharField( - allow_blank=True, allow_null=True, required=False - ) + flowchart_hint = serializers.CharField(allow_blank=True, allow_null=True, required=False) class CreateProblemSerializer(CreateOrEditProblemSerializer): @@ -220,6 +210,7 @@ class ProblemSafeSerializer(BaseProblemSerializer): return None return obj.flowchart_data + class ContestProblemMakePublicSerializer(serializers.Serializer): id = serializers.IntegerField() display_id = serializers.CharField(max_length=32) diff --git a/problemset/migrations/0008_alter_problemsetproblem_unique_together_and_more.py b/problemset/migrations/0008_alter_problemsetproblem_unique_together_and_more.py new file mode 100644 index 0000000..e6b77e8 --- /dev/null +++ b/problemset/migrations/0008_alter_problemsetproblem_unique_together_and_more.py @@ -0,0 +1,54 @@ +# Generated by Django 6.0.4 on 2026-05-09 08:18 + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("problem", "0008_alter_problem_unique_together_and_more"), + ("problemset", "0007_problemset_end_time"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="problemsetproblem", + unique_together=set(), + ), + migrations.AlterUniqueTogether( + name="problemsetprogress", + unique_together=set(), + ), + migrations.AlterUniqueTogether( + name="userbadge", + unique_together=set(), + ), + migrations.AlterField( + model_name="problemset", + name="status", + field=models.TextField(choices=[("draft", "Draft"), ("active", "Active"), ("archived", "Archived")], default="draft", verbose_name="状态"), + ), + migrations.AlterField( + model_name="problemset", + name="difficulty", + field=models.TextField(choices=[("Easy", "Easy"), ("Medium", "Medium"), ("Hard", "Hard")], default="Easy", verbose_name="难度等级"), + ), + migrations.AlterField( + model_name="problemsetbadge", + name="condition_type", + field=models.TextField(choices=[("all_problems", "All Problems"), ("problem_count", "Problem Count"), ("score", "Score")], verbose_name="获得条件类型"), + ), + migrations.AddConstraint( + model_name="problemsetproblem", + constraint=models.UniqueConstraint(fields=("problemset", "problem"), name="unique_problemset_problem"), + ), + migrations.AddConstraint( + model_name="problemsetprogress", + constraint=models.UniqueConstraint(fields=("problemset", "user"), name="unique_problemset_progress_user"), + ), + migrations.AddConstraint( + model_name="userbadge", + constraint=models.UniqueConstraint(fields=("user", "badge"), name="unique_user_badge"), + ), + ] diff --git a/problemset/models.py b/problemset/models.py index 9f2392d..98b74e6 100644 --- a/problemset/models.py +++ b/problemset/models.py @@ -6,15 +6,31 @@ from problem.models import Problem from utils.models import JSONField, RichTextField +class ProblemSetStatus(models.TextChoices): + DRAFT = "draft", "Draft" + ACTIVE = "active", "Active" + ARCHIVED = "archived", "Archived" + + +class ProblemSetDifficulty(models.TextChoices): + EASY = "Easy", "Easy" + MEDIUM = "Medium", "Medium" + HARD = "Hard", "Hard" + + +class BadgeConditionType(models.TextChoices): + ALL_PROBLEMS = "all_problems", "All Problems" + PROBLEM_COUNT = "problem_count", "Problem Count" + SCORE = "score", "Score" + + class ProblemSet(models.Model): """题单模型""" title = models.TextField(verbose_name="题单标题") description = RichTextField(verbose_name="题单描述") # 创建者 - created_by = models.ForeignKey( - User, on_delete=models.CASCADE, verbose_name="创建者" - ) + created_by = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="创建者") # 创建时间 create_time = models.DateTimeField(auto_now_add=True, verbose_name="创建时间") # 更新时间 @@ -22,11 +38,13 @@ class ProblemSet(models.Model): # 是否可见 visible = models.BooleanField(default=True, verbose_name="是否可见") # 题单难度等级 - difficulty = models.TextField(default="Easy", verbose_name="难度等级") + difficulty = models.TextField( + default=ProblemSetDifficulty.EASY, + choices=ProblemSetDifficulty.choices, + verbose_name="难度等级", + ) # 题单状态 - status = models.TextField( - default="draft", verbose_name="状态" - ) # active, archived, draft + status = models.TextField(default=ProblemSetStatus.DRAFT, choices=ProblemSetStatus.choices, verbose_name="状态") # 截止时间(到期后自动解除防作弊隐藏) end_time = models.DateTimeField(null=True, blank=True, verbose_name="截止时间") @@ -43,9 +61,7 @@ class ProblemSet(models.Model): class ProblemSetProblem(models.Model): """题单题目关联模型""" - problemset = models.ForeignKey( - ProblemSet, on_delete=models.CASCADE, verbose_name="题单" - ) + problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单") problem = models.ForeignKey(Problem, on_delete=models.CASCADE, verbose_name="题目") # 在题单中的顺序 order = models.IntegerField(default=0, verbose_name="顺序") @@ -58,7 +74,9 @@ class ProblemSetProblem(models.Model): class Meta: db_table = "problemset_problem" - unique_together = (("problemset", "problem"),) + constraints = [ + models.UniqueConstraint(fields=["problemset", "problem"], name="unique_problemset_problem"), + ] ordering = ("order",) verbose_name = "题单题目" verbose_name_plural = "题单题目" @@ -70,17 +88,13 @@ class ProblemSetProblem(models.Model): class ProblemSetBadge(models.Model): """题单奖章模型""" - problemset = models.ForeignKey( - ProblemSet, on_delete=models.CASCADE, verbose_name="题单" - ) + problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单") name = models.TextField(verbose_name="奖章名称") description = models.TextField(verbose_name="奖章描述") # 奖章图标路径 icon = models.TextField(verbose_name="奖章图标") # 获得条件:完成所有题目、完成指定数量题目、达到指定分数等 - condition_type = models.TextField( - verbose_name="获得条件类型" - ) # all_problems, problem_count, score + condition_type = models.TextField(choices=BadgeConditionType.choices, verbose_name="获得条件类型") condition_value = models.IntegerField(default=0, verbose_name="条件值") class Meta: @@ -90,17 +104,13 @@ class ProblemSetBadge(models.Model): def __str__(self): return f"{self.problemset.title} - {self.name}" - + def recalculate_user_badges(self): """重新计算所有用户的徽章资格""" from django.db import transaction user_progresses = ProblemSetProgress.objects.filter(problemset=self.problemset) - new_badges = [ - UserBadge(user=progress.user, badge=self) - for progress in user_progresses - if self._is_eligible(progress) - ] + new_badges = [UserBadge(user=progress.user, badge=self) for progress in user_progresses if self._is_eligible(progress)] with transaction.atomic(): UserBadge.objects.filter(badge=self).delete() if new_badges: @@ -118,9 +128,7 @@ class ProblemSetBadge(models.Model): def _check_user_badge_eligibility(self, progress): """检查并授予单个用户的徽章(供外部单次调用)""" - if self._is_eligible(progress) and not UserBadge.objects.filter( - user=progress.user, badge=self - ).exists(): + if self._is_eligible(progress) and not UserBadge.objects.filter(user=progress.user, badge=self).exists(): UserBadge.objects.create(user=progress.user, badge=self) return True return False @@ -129,9 +137,7 @@ class ProblemSetBadge(models.Model): class ProblemSetProgress(models.Model): """题单进度模型""" - problemset = models.ForeignKey( - ProblemSet, on_delete=models.CASCADE, verbose_name="题单" - ) + problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单") user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户") # 加入时间 join_time = models.DateTimeField(auto_now_add=True, verbose_name="加入时间") @@ -142,9 +148,7 @@ class ProblemSetProgress(models.Model): # 完成进度百分比 progress_percentage = models.FloatField(default=0.0, verbose_name="完成进度") # 已完成的题目数量 - completed_problems_count = models.IntegerField( - default=0, verbose_name="已完成题目数" - ) + completed_problems_count = models.IntegerField(default=0, verbose_name="已完成题目数") # 总题目数量 total_problems_count = models.IntegerField(default=0, verbose_name="总题目数") # 获得的总分 @@ -155,7 +159,9 @@ class ProblemSetProgress(models.Model): class Meta: db_table = "problemset_progress" - unique_together = (("problemset", "user"),) + constraints = [ + models.UniqueConstraint(fields=["problemset", "user"], name="unique_problemset_progress_user"), + ] verbose_name = "题单进度" verbose_name_plural = "题单进度" @@ -165,9 +171,7 @@ class ProblemSetProgress(models.Model): def update_progress(self): """更新进度信息""" # 获取题单中的所有题目 - problemset_problems = ProblemSetProblem.objects.filter( - problemset=self.problemset - ) + problemset_problems = ProblemSetProblem.objects.filter(problemset=self.problemset) self.total_problems_count = problemset_problems.count() # 获取当前题单中所有题目的ID集合(直接用 problem_id FK 字段,无需额外查询) @@ -199,9 +203,7 @@ class ProblemSetProgress(models.Model): # 计算完成百分比 if self.total_problems_count > 0: - self.progress_percentage = ( - completed_count / self.total_problems_count - ) * 100 + self.progress_percentage = (completed_count / self.total_problems_count) * 100 else: self.progress_percentage = 0 @@ -223,17 +225,11 @@ class ProblemSetProgress(models.Model): class ProblemSetSubmission(models.Model): """题单提交记录模型""" - - problemset = models.ForeignKey( - ProblemSet, on_delete=models.CASCADE, verbose_name="题单" - ) + + problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单") user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户") - submission = models.ForeignKey( - "submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录" - ) - problem = models.ForeignKey( - "problem.Problem", on_delete=models.CASCADE, verbose_name="题目" - ) + submission = models.ForeignKey("submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录") + problem = models.ForeignKey("problem.Problem", on_delete=models.CASCADE, verbose_name="题目") class Meta: db_table = "problemset_submission" @@ -253,34 +249,33 @@ class ProblemSetSubmission(models.Model): def submit_time(self): """提交时间""" return self.submission.create_time - + @property def result(self): """提交结果""" return self.submission.result - + @property def language(self): """编程语言""" return self.submission.language + class UserBadge(models.Model): """用户奖章模型""" user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户") - badge = models.ForeignKey( - ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章" - ) + badge = models.ForeignKey(ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章") # 获得时间 earned_time = models.DateTimeField(auto_now_add=True, verbose_name="获得时间") class Meta: db_table = "user_badge" - unique_together = (("user", "badge"),) + constraints = [ + models.UniqueConstraint(fields=["user", "badge"], name="unique_user_badge"), + ] verbose_name = "用户奖章" verbose_name_plural = "用户奖章" def __str__(self): return f"{self.user.username} - {self.badge.name}" - - diff --git a/problemset/serializers.py b/problemset/serializers.py index 3adab99..8aac2c0 100644 --- a/problemset/serializers.py +++ b/problemset/serializers.py @@ -1,10 +1,13 @@ from utils.api import UsernameSerializer, serializers from .models import ( + BadgeConditionType, ProblemSet, ProblemSetBadge, + ProblemSetDifficulty, ProblemSetProblem, ProblemSetProgress, + ProblemSetStatus, UserBadge, ) @@ -13,9 +16,7 @@ def get_user_progress_data(problemset, request): """获取当前用户在该题单中的进度 - 公共方法""" if request and request.user.is_authenticated: try: - progress = ProblemSetProgress.objects.get( - problemset=problemset, user=request.user - ) + progress = ProblemSetProgress.objects.get(problemset=problemset, user=request.user) return { "is_joined": True, "progress_percentage": progress.progress_percentage, @@ -61,9 +62,7 @@ class ProblemSetSerializer(serializers.ModelSerializer): request = self.context.get("request") if request and request.user.is_authenticated: try: - progress = ProblemSetProgress.objects.get( - problemset=obj, user=request.user - ) + progress = ProblemSetProgress.objects.get(problemset=obj, user=request.user) return progress.completed_problems_count except ProblemSetProgress.DoesNotExist: return 0 @@ -124,22 +123,22 @@ class ProblemSetListSerializer(serializers.ModelSerializer): def get_badges(self, obj): """获取题单的奖章列表,并标记用户已获得的徽章""" request = self.context.get("request") - + # 使用预加载的奖章数据 badges = getattr(obj, "badges", []) badge_data = ProblemSetBadgeSerializer(badges, many=True).data - + # 如果用户已登录,检查哪些徽章已被获得 if request and request.user.is_authenticated and hasattr(request, "_user_earned_badge_ids"): earned_badge_ids = request._user_earned_badge_ids # 为每个徽章添加是否已获得的标记 for badge in badge_data: - badge['is_earned'] = badge['id'] in earned_badge_ids + badge["is_earned"] = badge["id"] in earned_badge_ids else: # 未登录用户或未预加载,所有徽章都标记为未获得 for badge in badge_data: - badge['is_earned'] = False - + badge["is_earned"] = False + return badge_data @@ -148,8 +147,8 @@ class CreateProblemSetSerializer(serializers.Serializer): title = serializers.CharField(max_length=200) description = serializers.CharField() - difficulty = serializers.CharField(default="Easy") - status = serializers.CharField(default="active") + difficulty = serializers.ChoiceField(choices=ProblemSetDifficulty.choices, default=ProblemSetDifficulty.EASY) + status = serializers.ChoiceField(choices=ProblemSetStatus.choices, default=ProblemSetStatus.ACTIVE) end_time = serializers.DateTimeField(required=False) @@ -159,8 +158,8 @@ class EditProblemSetSerializer(serializers.Serializer): id = serializers.IntegerField() title = serializers.CharField(max_length=200, required=False) description = serializers.CharField(required=False) - difficulty = serializers.CharField(required=False) - status = serializers.CharField(required=False) + difficulty = serializers.ChoiceField(choices=ProblemSetDifficulty.choices, required=False) + status = serializers.ChoiceField(choices=ProblemSetStatus.choices, required=False) visible = serializers.BooleanField(required=False) end_time = serializers.DateTimeField(required=False, allow_null=True) @@ -190,9 +189,7 @@ class ProblemSetProblemSerializer(serializers.ModelSerializer): progress = self.context.get("user_progress") if progress is None: try: - progress = ProblemSetProgress.objects.get( - problemset=obj.problemset, user=request.user - ) + progress = ProblemSetProgress.objects.get(problemset=obj.problemset, user=request.user) except ProblemSetProgress.DoesNotExist: return False return str(obj.problem.id) in progress.progress_detail @@ -227,19 +224,21 @@ class ProblemSetBadgeSerializer(serializers.ModelSerializer): class CreateProblemSetBadgeSerializer(serializers.Serializer): """创建题单奖章序列化器""" + name = serializers.CharField(max_length=100) description = serializers.CharField() icon = serializers.CharField() - condition_type = serializers.CharField() # all_problems, problem_count, score + condition_type = serializers.ChoiceField(choices=BadgeConditionType.choices) condition_value = serializers.IntegerField(required=False) class EditProblemSetBadgeSerializer(serializers.Serializer): """编辑题单奖章序列化器""" + name = serializers.CharField(max_length=100, required=False) description = serializers.CharField(required=False) icon = serializers.CharField(required=False) - condition_type = serializers.CharField(required=False) # all_problems, problem_count, score + condition_type = serializers.ChoiceField(choices=BadgeConditionType.choices, required=False) condition_value = serializers.IntegerField(required=False) @@ -252,42 +251,35 @@ class ProblemSetProgressSerializer(serializers.ModelSerializer): class Meta: model = ProblemSetProgress fields = "__all__" - + def get_completed_problems(self, obj): """获取已完成的题目列表""" completed_problems = [] - + # 尝试从 request 中获取预加载的问题字典(用于批量查询优化) problems_dict = {} - request = self.context.get('request') - if request and hasattr(request, '_problems_dict_cache'): + request = self.context.get("request") + if request and hasattr(request, "_problems_dict_cache"): problems_dict = request._problems_dict_cache - + if obj.progress_detail: for problem_id in obj.progress_detail.keys(): # 优先使用预加载的问题字典 if problems_dict: problem = problems_dict.get(problem_id) if problem: - completed_problems.append({ - 'id': problem.id, - '_id': problem._id, - 'title': problem.title - }) + completed_problems.append({"id": problem.id, "_id": problem._id, "title": problem.title}) continue - + # 如果没有预加载字典,则回退到单独查询(向后兼容) from problem.models import Problem + try: problem = Problem.objects.get(id=problem_id) - completed_problems.append({ - 'id': problem.id, - '_id': problem._id, - 'title': problem.title - }) + completed_problems.append({"id": problem.id, "_id": problem._id, "title": problem.title}) except Problem.DoesNotExist: continue - + return completed_problems @@ -313,5 +305,3 @@ class UpdateProgressSerializer(serializers.Serializer): problemset_id = serializers.IntegerField() problem_id = serializers.IntegerField() submission_id = serializers.CharField(required=False) - - diff --git a/problemset/views/admin.py b/problemset/views/admin.py index d36e6ac..d7a2ef0 100644 --- a/problemset/views/admin.py +++ b/problemset/views/admin.py @@ -7,6 +7,7 @@ from problemset.models import ( ProblemSetBadge, ProblemSetProblem, ProblemSetProgress, + ProblemSetStatus, ) from problemset.serializers import ( AddProblemToSetSerializer, @@ -35,9 +36,7 @@ class ProblemSetAdminAPI(APIView): # 过滤条件 keyword = request.GET.get("keyword", "").strip() if keyword: - problem_sets = problem_sets.filter( - Q(title__icontains=keyword) | Q(description__icontains=keyword) - ) + problem_sets = problem_sets.filter(Q(title__icontains=keyword) | Q(description__icontains=keyword)) difficulty = request.GET.get("difficulty") if difficulty: @@ -129,12 +128,8 @@ class ProblemSetProblemAdminAPI(APIView): except ProblemSet.DoesNotExist: return self.error("题单不存在") - problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by( - "order" - ) - serializer = ProblemSetProblemSerializer( - problems, many=True, context={"request": request} - ) + problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by("order") + serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request}) return self.success(serializer.data) @super_admin_required @@ -158,9 +153,7 @@ class ProblemSetProblemAdminAPI(APIView): return self.error("题目不存在或不可见") # 检查题目是否已经在题单中 - if ProblemSetProblem.objects.filter( - problemset=problem_set, problem=problem - ).exists(): + if ProblemSetProblem.objects.filter(problemset=problem_set, problem=problem).exists(): return self.error("题目已在该题单中") ProblemSetProblem.objects.create( @@ -188,9 +181,7 @@ class ProblemSetProblemAdminAPI(APIView): return self.error("题单不存在") try: - problem_set_problem = ProblemSetProblem.objects.get( - id=problem_set_problem_id, problemset=problem_set - ) + problem_set_problem = ProblemSetProblem.objects.get(id=problem_set_problem_id, problemset=problem_set) except ProblemSetProblem.DoesNotExist: return self.error("题目不在该题单中") @@ -206,10 +197,10 @@ class ProblemSetProblemAdminAPI(APIView): problem_set_problem.hint = data["hint"] problem_set_problem.save() - + # 同步所有用户的进度 ProblemSetProgress.sync_all_progress_for_problemset(problem_set) - + return self.success("题目已更新") @super_admin_required @@ -222,14 +213,12 @@ class ProblemSetProblemAdminAPI(APIView): return self.error("题单不存在") try: - problem_set_problem = ProblemSetProblem.objects.get( - id=problem_set_problem_id, problemset=problem_set - ) + problem_set_problem = ProblemSetProblem.objects.get(id=problem_set_problem_id, problemset=problem_set) problem_set_problem.delete() - + # 同步所有用户的进度 ProblemSetProgress.sync_all_progress_for_problemset(problem_set) - + return self.success("题目已从题单中移除") except ProblemSetProblem.DoesNotExist: return self.error("题目不在该题单中") @@ -283,10 +272,10 @@ class ProblemSetBadgeAdminAPI(APIView): return self.error("奖章不存在") data = request.data - + # 记录是否修改了条件相关的字段 condition_changed = False - + # 更新奖章属性 if "name" in data: badge.name = data["name"] @@ -304,7 +293,7 @@ class ProblemSetBadgeAdminAPI(APIView): badge.level = data["level"] badge.save() - + # 如果修改了条件,重新计算所有用户的徽章资格 if condition_changed: try: @@ -312,7 +301,7 @@ class ProblemSetBadgeAdminAPI(APIView): return self.success("奖章已更新,并重新计算了所有用户的徽章资格") except Exception as e: return self.error(f"奖章已更新,但重新计算徽章资格时出错: {str(e)}") - + return self.success("奖章已更新") @super_admin_required @@ -344,9 +333,7 @@ class ProblemSetProgressAdminAPI(APIView): except ProblemSet.DoesNotExist: return self.error("题单不存在") - progress_list = ProblemSetProgress.objects.filter( - problemset=problem_set - ).order_by("-join_time") + progress_list = ProblemSetProgress.objects.filter(problemset=problem_set).order_by("-join_time") serializer = ProblemSetProgressSerializer(progress_list, many=True) return self.success(serializer.data) @@ -360,9 +347,7 @@ class ProblemSetProgressAdminAPI(APIView): return self.error("题单不存在") try: - progress = ProblemSetProgress.objects.get( - problemset=problem_set, user_id=user_id - ) + progress = ProblemSetProgress.objects.get(problemset=problem_set, user_id=user_id) progress.delete() return self.success("用户已从题单中移除") except ProblemSetProgress.DoesNotExist: @@ -371,7 +356,7 @@ class ProblemSetProgressAdminAPI(APIView): class ProblemSetSyncAPI(APIView): """题单同步管理API""" - + @super_admin_required def post(self, request, problem_set_id): """手动同步题单的所有用户进度(管理员)""" @@ -380,10 +365,10 @@ class ProblemSetSyncAPI(APIView): ensure_created_by(problem_set, request.user) except ProblemSet.DoesNotExist: return self.error("题单不存在") - + # 同步所有用户的进度 synced_count = ProblemSetProgress.sync_all_progress_for_problemset(problem_set) - + return self.success(f"已同步 {synced_count} 个用户的进度") @@ -419,7 +404,7 @@ class ProblemSetStatusAPI(APIView): return self.error("题单不存在") status = data.get("status") - if status not in ["active", "archived", "draft"]: + if status not in ProblemSetStatus.values: return self.error("无效的状态") problem_set.status = status diff --git a/problemset/views/oj.py b/problemset/views/oj.py index 8bfb7bc..7e15cd9 100644 --- a/problemset/views/oj.py +++ b/problemset/views/oj.py @@ -32,18 +32,14 @@ class ProblemSetAPI(APIView): """获取题单列表""" # 预加载创建者信息 problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft").select_related("created_by") - + # 使用annotate在查询时计算题目数量,避免N+1查询 - problem_sets = problem_sets.annotate( - problems_count=Count("problemsetproblem", distinct=True) - ) + problem_sets = problem_sets.annotate(problems_count=Count("problemsetproblem", distinct=True)) # 过滤条件 keyword = request.GET.get("keyword", "").strip() if keyword: - problem_sets = problem_sets.filter( - Q(title__icontains=keyword) | Q(description__icontains=keyword) - ) + problem_sets = problem_sets.filter(Q(title__icontains=keyword) | Q(description__icontains=keyword)) difficulty = request.GET.get("difficulty") if difficulty: @@ -67,33 +63,19 @@ class ProblemSetAPI(APIView): if request.user.is_authenticated: # 先获取所有题单ID(不应用prefetch_related,只获取ID) problem_set_ids = list(problem_sets.values_list("id", flat=True)) - + if problem_set_ids: # 批量查询用户在这些题单中的进度 - user_progresses = ProblemSetProgress.objects.filter( - problemset_id__in=problem_set_ids, - user=request.user - ).select_related("problemset") + user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset") # 构建映射:题单ID -> 进度对象 user_progress_map = {progress.problemset_id: progress for progress in user_progresses} - + # 批量查询用户已获得的奖章ID(这些题单相关的) - user_earned_badge_ids = set( - UserBadge.objects.filter( - user=request.user, - badge__problemset_id__in=problem_set_ids - ).values_list('badge_id', flat=True) - ) - + user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True)) + # 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载) - problem_sets = problem_sets.prefetch_related( - Prefetch( - "problemsetbadge_set", - queryset=ProblemSetBadge.objects.all(), - to_attr="badges" - ) - ) - + problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges")) + # 将用户进度映射和已获得的奖章ID集合存储到request中,供序列化器使用 request._user_progress_map = user_progress_map request._user_earned_badge_ids = user_earned_badge_ids @@ -108,11 +90,7 @@ class ProblemSetDetailAPI(APIView): def get(self, request, problem_set_id): """获取题单详情""" try: - problem_set = ( - ProblemSet.objects.filter(id=problem_set_id, visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") @@ -126,32 +104,19 @@ class ProblemSetProblemAPI(APIView): def get(self, request, problem_set_id): """获取题单中的题目列表""" try: - problem_set = ( - ProblemSet.objects.filter(id=problem_set_id, visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") - problems = ( - ProblemSetProblem.objects.filter(problemset=problem_set) - .select_related("problem__created_by") - .prefetch_related("problem__tags") - .order_by("order") - ) + problems = ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem__created_by").prefetch_related("problem__tags").order_by("order") # 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1 user_progress = None if request.user.is_authenticated: try: - user_progress = ProblemSetProgress.objects.get( - problemset=problem_set, user=request.user - ) + user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user) except ProblemSetProgress.DoesNotExist: pass - serializer = ProblemSetProblemSerializer( - problems, many=True, context={"request": request, "user_progress": user_progress} - ) + serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress}) return self.success(serializer.data) @@ -163,23 +128,15 @@ class ProblemSetProgressAPI(APIView): """加入题单""" data = request.data try: - problem_set = ( - ProblemSet.objects.filter(id=data["problemset_id"], visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") - if ProblemSetProgress.objects.filter( - problemset=problem_set, user=request.user - ).exists(): + if ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).exists(): return self.error("已经加入该题单") # 创建进度记录 - progress = ProblemSetProgress.objects.create( - problemset=problem_set, user=request.user - ) + progress = ProblemSetProgress.objects.create(problemset=problem_set, user=request.user) progress.update_progress() return self.success("成功加入题单") @@ -187,18 +144,12 @@ class ProblemSetProgressAPI(APIView): def get(self, request, problem_set_id): """获取题单进度""" try: - problem_set = ( - ProblemSet.objects.filter(id=problem_set_id, visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") try: - progress = ProblemSetProgress.objects.get( - problemset=problem_set, user=request.user - ) + progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user) except ProblemSetProgress.DoesNotExist: return self.error("未加入该题单") @@ -210,18 +161,12 @@ class ProblemSetProgressAPI(APIView): """更新进度""" data = request.data try: - problem_set = ( - ProblemSet.objects.filter(id=data["problemset_id"], visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") try: - progress = ProblemSetProgress.objects.get( - problemset=problem_set, user=request.user - ) + progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user) except ProblemSetProgress.DoesNotExist: return self.error("未加入该题单") @@ -230,9 +175,7 @@ class ProblemSetProgressAPI(APIView): # 获取该题目在题单中的分值 try: - problemset_problem = ProblemSetProblem.objects.get( - problemset=problem_set, problem_id=problem_id - ) + problemset_problem = ProblemSetProblem.objects.get(problemset=problem_set, problem_id=problem_id) problem_score = problemset_problem.score except ProblemSetProblem.DoesNotExist: problem_score = 0 @@ -296,9 +239,7 @@ class UserProgressAPI(APIView): def get(self, request): """获取用户的题单进度列表""" - progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by( - "-join_time" - ) + progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by("-join_time") serializer = ProblemSetProgressSerializer(progress_list, many=True) return self.success(serializer.data) @@ -315,16 +256,12 @@ class UserBadgeAPI(APIView): # 获取指定用户的徽章 try: target_user = User.objects.get(username=username, is_disabled=False) - badges = UserBadge.objects.filter(user=target_user).order_by( - "-earned_time" - ) + badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time") except User.DoesNotExist: return self.error("用户不存在") else: # 获取当前用户的徽章 - badges = UserBadge.objects.filter(user=request.user).order_by( - "-earned_time" - ) + badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time") serializer = UserBadgeSerializer(badges, many=True) return self.success(serializer.data) @@ -336,11 +273,7 @@ class ProblemSetBadgeAPI(APIView): def get(self, request, problem_set_id): """获取题单的奖章列表""" try: - problem_set = ( - ProblemSet.objects.filter(id=problem_set_id, visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") @@ -355,18 +288,12 @@ class ProblemSetUserProgressAPI(APIView): def get(self, request, problem_set_id: int): """获取题单的用户进度列表""" try: - problem_set = ( - ProblemSet.objects.filter(id=problem_set_id, visible=True) - .exclude(status="draft") - .get() - ) + problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get() except ProblemSet.DoesNotExist: return self.error("题单不存在") # 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息 - progresses = ProblemSetProgress.objects.filter( - problemset=problem_set - ).select_related("user") + progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related("user") # 班级过滤 class_name = request.GET.get("class_name", "").strip() @@ -386,9 +313,7 @@ class ProblemSetUserProgressAPI(APIView): progresses = progresses.filter(completed_problems_count=0) # 排序 - progresses = progresses.order_by( - "-is_completed", "-progress_percentage", "join_time" - ) + progresses = progresses.order_by("-is_completed", "-progress_percentage", "join_time") # 计算统计数据(基于所有数据,而非分页数据) # 使用一次查询获取所有统计数据 @@ -416,12 +341,9 @@ class ProblemSetUserProgressAPI(APIView): # 提前获取题单的所有题目(用于前端显示未完成题目和序列化器) # 使用 select_related 和 only 优化查询,只选择需要的字段 all_problemset_problems = ( - ProblemSetProblem.objects.filter(problemset=problem_set) - .select_related("problem") - .only("problem__id", "problem___id", "problem__title", "order") - .order_by("order") + ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem").only("problem__id", "problem___id", "problem__title", "order").order_by("order") ) - + # 构建题单所有题目的数据结构和映射 all_problems_list = [] all_problems_map = {} @@ -444,11 +366,7 @@ class ProblemSetUserProgressAPI(APIView): completed_problem_ids.update(progress.progress_detail.keys()) # 从已加载的题单题目中构建 problems_dict,避免重复查询 - problems_dict = { - pid: all_problems_map[pid] - for pid in completed_problem_ids - if pid in all_problems_map - } + problems_dict = {pid: all_problems_map[pid] for pid in completed_problem_ids if pid in all_problems_map} # 将预加载的问题字典存储到 request 中,供序列化器使用 request._problems_dict_cache = problems_dict diff --git a/submission/migrations/0005_alter_submission_result.py b/submission/migrations/0005_alter_submission_result.py new file mode 100644 index 0000000..48bc301 --- /dev/null +++ b/submission/migrations/0005_alter_submission_result.py @@ -0,0 +1,33 @@ +# Generated by Django 6.0.4 on 2026-05-09 08:18 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("submission", "0004_submission_problem_user_idx"), + ] + + operations = [ + migrations.AlterField( + model_name="submission", + name="result", + field=models.IntegerField( + choices=[ + (-2, "Compile Error"), + (-1, "Wrong Answer"), + (0, "Accepted"), + (1, "CPU Time Limit Exceeded"), + (2, "Real Time Limit Exceeded"), + (3, "Memory Limit Exceeded"), + (4, "Runtime Error"), + (5, "System Error"), + (6, "Pending"), + (7, "Judging"), + (8, "Partially Accepted"), + ], + db_index=True, + default=6, + ), + ), + ] diff --git a/submission/models.py b/submission/models.py index 82238b5..cfe7cdd 100644 --- a/submission/models.py +++ b/submission/models.py @@ -7,18 +7,18 @@ from utils.models import JSONField from utils.shortcuts import rand_str -class JudgeStatus: - COMPILE_ERROR = -2 - WRONG_ANSWER = -1 - ACCEPTED = 0 - CPU_TIME_LIMIT_EXCEEDED = 1 - REAL_TIME_LIMIT_EXCEEDED = 2 - MEMORY_LIMIT_EXCEEDED = 3 - RUNTIME_ERROR = 4 - SYSTEM_ERROR = 5 - PENDING = 6 - JUDGING = 7 - PARTIALLY_ACCEPTED = 8 +class JudgeStatus(models.IntegerChoices): + COMPILE_ERROR = -2, "Compile Error" + WRONG_ANSWER = -1, "Wrong Answer" + ACCEPTED = 0, "Accepted" + CPU_TIME_LIMIT_EXCEEDED = 1, "CPU Time Limit Exceeded" + REAL_TIME_LIMIT_EXCEEDED = 2, "Real Time Limit Exceeded" + MEMORY_LIMIT_EXCEEDED = 3, "Memory Limit Exceeded" + RUNTIME_ERROR = 4, "Runtime Error" + SYSTEM_ERROR = 5, "System Error" + PENDING = 6, "Pending" + JUDGING = 7, "Judging" + PARTIALLY_ACCEPTED = 8, "Partially Accepted" class Submission(models.Model): @@ -29,7 +29,7 @@ class Submission(models.Model): user_id = models.IntegerField(db_index=True) username = models.TextField() code = models.TextField() - result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING) + result = models.IntegerField(choices=JudgeStatus.choices, db_index=True, default=JudgeStatus.PENDING) # 从JudgeServer返回的判题详情 info = JSONField(default=dict) language = models.TextField() @@ -40,11 +40,7 @@ class Submission(models.Model): ip = models.TextField(null=True) def check_user_permission(self, user, check_share=True): - if ( - self.user_id == user.id - or not user.is_regular_user() - or self.problem.created_by_id == user.id - ): + if self.user_id == user.id or not user.is_regular_user() or self.problem.created_by_id == user.id: return True if check_share: @@ -58,15 +54,9 @@ class Submission(models.Model): db_table = "submission" ordering = ("-create_time",) indexes = [ - models.Index( - fields=["user_id", "create_time"], name="user_create_time_idx" - ), - models.Index( - fields=["contest_id", "-create_time"], name="contest_create_time_idx" - ), - models.Index( - fields=["problem_id", "user_id"], name="problem_user_idx" - ), + models.Index(fields=["user_id", "create_time"], name="user_create_time_idx"), + models.Index(fields=["contest_id", "-create_time"], name="contest_create_time_idx"), + models.Index(fields=["problem_id", "user_id"], name="problem_user_idx"), ] def __str__(self): diff --git a/tutorial/models.py b/tutorial/models.py index 350dc73..6e19866 100644 --- a/tutorial/models.py +++ b/tutorial/models.py @@ -3,16 +3,22 @@ from django.db import models from account.models import User +class TutorialType(models.TextChoices): + PYTHON = "python", "Python" + C = "c", "C" + + +class ExerciseType(models.TextChoices): + MCQ = "mcq", "选择题" + SORT = "sort", "代码排序" + FILL = "fill", "代码填空" + + class Tutorial(models.Model): - TYPE_CHOICES = [ - ('python', 'Python'), - ('c', 'C'), - ] - title = models.CharField(max_length=128) content = models.TextField() code = models.TextField(null=True, blank=True) - type = models.CharField(max_length=10, choices=TYPE_CHOICES, default='python') + type = models.CharField(max_length=10, choices=TutorialType.choices, default=TutorialType.PYTHON) created_by = models.ForeignKey(User, on_delete=models.CASCADE) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -21,21 +27,15 @@ class Tutorial(models.Model): class Meta: db_table = "tutorial" - ordering = ['order', '-created_at'] + ordering = ["order", "-created_at"] def __str__(self): return self.title class Exercise(models.Model): - TYPE_CHOICES = [ - ("mcq", "选择题"), - ("sort", "代码排序"), - ("fill", "代码填空"), - ] - tutorial = models.ForeignKey(Tutorial, on_delete=models.CASCADE, related_name="exercises") - type = models.CharField(max_length=16, choices=TYPE_CHOICES) + type = models.CharField(max_length=16, choices=ExerciseType.choices) data = models.JSONField() order = models.IntegerField(default=0) created_at = models.DateTimeField(auto_now_add=True) @@ -45,4 +45,4 @@ class Exercise(models.Model): ordering = ["order", "created_at"] def __str__(self): - return f"{self.get_type_display()} (Order {self.order})" \ No newline at end of file + return f"{self.get_type_display()} (Order {self.order})" diff --git a/tutorial/serializers.py b/tutorial/serializers.py index 8e0ef00..ba81b70 100644 --- a/tutorial/serializers.py +++ b/tutorial/serializers.py @@ -2,7 +2,7 @@ from rest_framework import serializers from account.serializers import UserSerializer -from .models import Exercise, Tutorial +from .models import Exercise, ExerciseType, Tutorial class TutorialListSerializer(serializers.ModelSerializer): @@ -65,13 +65,13 @@ class ExerciseSerializer(serializers.ModelSerializer): class CreateExerciseSerializer(serializers.Serializer): tutorial_id = serializers.IntegerField() - type = serializers.ChoiceField(choices=["mcq", "sort", "fill"]) + type = serializers.ChoiceField(choices=ExerciseType.choices) data = serializers.JSONField() order = serializers.IntegerField(default=0) class EditExerciseSerializer(serializers.Serializer): id = serializers.IntegerField() - type = serializers.ChoiceField(choices=["mcq", "sort", "fill"]) + type = serializers.ChoiceField(choices=ExerciseType.choices) data = serializers.JSONField() order = serializers.IntegerField(default=0) diff --git a/utils/constants.py b/utils/constants.py index e71d991..320ffb6 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -1,24 +1,20 @@ -class Choices: - @classmethod - def choices(cls): - d = cls.__dict__ - return [d[item] for item in d.keys() if not item.startswith("__")] +from django.db import models -class ContestType: - PUBLIC_CONTEST = "Public" - PASSWORD_PROTECTED_CONTEST = "Password Protected" +class ContestType(models.TextChoices): + PUBLIC_CONTEST = "Public", "Public" + PASSWORD_PROTECTED_CONTEST = "Password Protected", "Password Protected" -class ContestStatus: - CONTEST_NOT_START = "1" - CONTEST_ENDED = "-1" - CONTEST_UNDERWAY = "0" +class ContestStatus(models.TextChoices): + CONTEST_NOT_START = "1", "Not Started" + CONTEST_ENDED = "-1", "Ended" + CONTEST_UNDERWAY = "0", "Underway" -class ContestRuleType(Choices): - ACM = "ACM" - OI = "OI" +class ContestRuleType(models.TextChoices): + ACM = "ACM", "ACM" + OI = "OI", "OI" class CacheKey: @@ -31,10 +27,10 @@ class CacheKey: user_activity_rank = "user_activity_rank" -class Difficulty(Choices): - LOW = "Low" - MID = "Mid" - HIGH = "High" +class Difficulty(models.TextChoices): + LOW = "Low", "Low" + MID = "Mid", "Mid" + HIGH = "High", "High" CONTEST_PASSWORD_SESSION_KEY = "contest_password" diff --git a/utils/migrate_data.py b/utils/migrate_data.py index b9b9bd3..ec68824 100644 --- a/utils/migrate_data.py +++ b/utils/migrate_data.py @@ -12,18 +12,11 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") django.setup() from django.conf import settings from account.models import User, UserProfile, AdminType, ProblemPermission -from problem.models import Problem, ProblemTag, ProblemDifficulty, ProblemRuleType +from problem.models import Problem, ProblemTag, ProblemRuleType +from utils.constants import Difficulty -admin_type_map = { - 0: AdminType.REGULAR_USER, - 1: AdminType.ADMIN, - 2: AdminType.SUPER_ADMIN -} -languages_map = { - 1: "C", - 2: "C++", - 3: "Java" -} +admin_type_map = {0: AdminType.REGULAR_USER, 1: AdminType.ADMIN, 2: AdminType.SUPER_ADMIN} +languages_map = {1: "C", 2: "C++", 3: "Java"} email_regex = re.compile(r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)") # pk -> name @@ -46,9 +39,7 @@ def get_input_result(): def set_problem_display_id_prefix(): while True: print("Please input a prefix which will be used in all the imported problem's displayID") - print( - "For example, if your input is 'old'(no quote), the problems' display id will be old1, old2, old3..\ninput:", - end="") + print("For example, if your input is 'old'(no quote), the problems' display id will be old1, old2, old3..\ninput:", end="") resp = input() if resp.strip(): return resp.strip() @@ -60,8 +51,8 @@ def set_problem_display_id_prefix(): def get_stripped_output_md5(test_case_id, output_name): output_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, output_name) - with open(output_path, 'r') as f: - return hashlib.md5(f.read().rstrip().encode('utf-8')).hexdigest() + with open(output_path, "r") as f: + return hashlib.md5(f.read().rstrip().encode("utf-8")).hexdigest() def get_test_case_score(test_case_id): @@ -79,9 +70,7 @@ def get_test_case_score(test_case_id): test_case["stripped_output_md5"] = test_case.pop("striped_output_md5") else: test_case["stripped_output_md5"] = get_stripped_output_md5(test_case_id, test_case["output_name"]) - test_case_score.append({"input_name": test_case["input_name"], - "output_name": test_case.get("output_name", "-"), - "score": 0}) + test_case_score.append({"input_name": test_case["input_name"], "output_name": test_case.get("output_name", "-"), "score": 0}) if need_rewrite: with open(info_path, "w") as f: f.write(json.dumps(info)) @@ -120,7 +109,7 @@ def import_users(): def import_tags(): i = 0 print("\nFind these tags in old data:") - print(", ".join(tags.values()), '\n') + print(", ".join(tags.values()), "\n") print("import tags now? (yes/no)") if get_input_result(): for tagname in tags.values(): @@ -149,14 +138,13 @@ def import_problems(): print("%s has the same display_id with the db problem" % data["title"]) continue try: - creator_id = \ - User.objects.filter(username=users[data["created_by"]]["username"]).values_list("id", flat=True)[0] + creator_id = User.objects.filter(username=users[data["created_by"]]["username"]).values_list("id", flat=True)[0] except (User.DoesNotExist, IndexError): print("The origin creator does not exist, set it to default_creator") creator_id = default_creator.id data["created_by_id"] = creator_id data.pop("created_by") - data["difficulty"] = ProblemDifficulty.Mid + data["difficulty"] = Difficulty.MID if data["spj_language"]: data["spj_language"] = languages_map[data["spj_language"]] data["samples"] = json.loads(data["samples"])