Compare commits

..

10 Commits

Author SHA1 Message Date
a1b51ebb9e update 2025-07-14 21:41:27 +08:00
a9a6b87fef test for asgi 2025-07-14 21:33:03 +08:00
2d3588c755 revert 2025-06-15 20:26:43 +08:00
a2bfc28ac7 test 2025-06-15 20:21:37 +08:00
6aac767641 test 2025-06-15 20:18:24 +08:00
73af9d96b2 test 2025-06-15 20:15:49 +08:00
8a2fa11afc test 2025-06-15 20:12:48 +08:00
3f1c7250bd test 2025-06-15 20:06:50 +08:00
bd0a7f30f8 test 2025-06-15 19:35:11 +08:00
8a043d2ffa test 2025-06-15 19:26:45 +08:00
94 changed files with 1354 additions and 5857 deletions

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-19 06:11
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='userprofile',
name='class_name',
field=models.TextField(null=True),
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-19 06:14
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0002_userprofile_class_name'),
]
operations = [
migrations.RemoveField(
model_name='userprofile',
name='class_name',
),
migrations.AddField(
model_name='user',
name='class_name',
field=models.TextField(null=True),
),
]

View File

@@ -25,7 +25,6 @@ class UserManager(models.Manager):
class User(AbstractBaseUser): class User(AbstractBaseUser):
username = models.TextField(unique=True) username = models.TextField(unique=True)
class_name = models.TextField(null=True)
email = models.TextField(null=True) email = models.TextField(null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True) create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType # One of UserType
@@ -51,9 +50,6 @@ class User(AbstractBaseUser):
objects = UserManager() objects = UserManager()
def is_regular_user(self):
return self.admin_type == AdminType.REGULAR_USER
def is_admin(self): def is_admin(self):
return self.admin_type == AdminType.ADMIN return self.admin_type == AdminType.ADMIN

View File

@@ -67,7 +67,6 @@ class UserAdminSerializer(serializers.ModelSerializer):
"open_api", "open_api",
"is_disabled", "is_disabled",
"raw_password", "raw_password",
"class_name",
] ]
def get_real_name(self, obj): def get_real_name(self, obj):
@@ -94,7 +93,6 @@ class UserSerializer(serializers.ModelSerializer):
"two_factor_auth", "two_factor_auth",
"open_api", "open_api",
"is_disabled", "is_disabled",
"class_name",
] ]
@@ -131,7 +129,7 @@ class EditUserSerializer(serializers.Serializer):
open_api = serializers.BooleanField() open_api = serializers.BooleanField()
two_factor_auth = serializers.BooleanField() two_factor_auth = serializers.BooleanField()
is_disabled = serializers.BooleanField() is_disabled = serializers.BooleanField()
class_name = serializers.CharField(required=False, allow_null=True, allow_blank=True)
class EditUserProfileSerializer(serializers.Serializer): class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=32, allow_null=True, required=False) real_name = serializers.CharField(max_length=32, allow_null=True, required=False)
@@ -143,6 +141,7 @@ class EditUserProfileSerializer(serializers.Serializer):
major = serializers.CharField(max_length=64, allow_blank=True, required=False) major = serializers.CharField(max_length=64, allow_blank=True, required=False)
language = serializers.CharField(max_length=32, allow_blank=True, required=False) language = serializers.CharField(max_length=32, allow_blank=True, required=False)
class ApplyResetPasswordSerializer(serializers.Serializer): class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
captcha = serializers.CharField() captcha = serializers.CharField()

View File

@@ -1,9 +1,8 @@
from django.urls import path from django.urls import path
from ..views.admin import UserAdminAPI, GenerateUserAPI, ResetUserPasswordAPI from ..views.admin import UserAdminAPI, GenerateUserAPI
urlpatterns = [ urlpatterns = [
path("user", UserAdminAPI.as_view()), path("user", UserAdminAPI.as_view()),
path("generate_user", GenerateUserAPI.as_view()), path("generate_user", GenerateUserAPI.as_view()),
path("reset_password", ResetUserPasswordAPI.as_view()),
] ]

View File

@@ -15,7 +15,6 @@ from ..views.oj import (
UserProfileAPI, UserProfileAPI,
UserRankAPI, UserRankAPI,
UserActivityRankAPI, UserActivityRankAPI,
UserProblemRankAPI,
CheckTFARequiredAPI, CheckTFARequiredAPI,
SessionManagementAPI, SessionManagementAPI,
ProfileProblemDisplayIDRefreshAPI, ProfileProblemDisplayIDRefreshAPI,
@@ -46,7 +45,6 @@ urlpatterns = [
), ),
path("user_rank", UserRankAPI.as_view()), path("user_rank", UserRankAPI.as_view()),
path("user_activity_rank", UserActivityRankAPI.as_view()), path("user_activity_rank", UserActivityRankAPI.as_view()),
path("user_problem_rank", UserProblemRankAPI.as_view()),
path("sessions", SessionManagementAPI.as_view()), path("sessions", SessionManagementAPI.as_view()),
path( path(
"open_api_appkey", "open_api_appkey",

View File

@@ -3,10 +3,9 @@ import re
import xlsxwriter import xlsxwriter
from django.db import transaction, IntegrityError from django.db import transaction, IntegrityError
from django.db.models import Q, F from django.db.models import Q
from django.http import HttpResponse from django.http import HttpResponse
from django.contrib.auth.hashers import make_password from django.contrib.auth.hashers import make_password
from django.utils.crypto import get_random_string
from submission.models import Submission from submission.models import Submission
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
@@ -22,19 +21,6 @@ from ..serializers import (
from ..serializers import ImportUserSerializer from ..serializers import ImportUserSerializer
# ks251XXX 或者 ks2510XX 返回 251 或者 2510
# 其他返回 None
def get_class_name(username):
if username.startswith("ks"):
result = re.search(r"ks\d+", username)
if result:
return result.group(0)[2:]
else:
return None
else:
return None
class UserAdminAPI(APIView): class UserAdminAPI(APIView):
@validate_serializer(ImportUserSerializer) @validate_serializer(ImportUserSerializer)
@super_admin_required @super_admin_required
@@ -54,7 +40,6 @@ class UserAdminAPI(APIView):
password=make_password(user_data[1]), password=make_password(user_data[1]),
email=user_data[2], email=user_data[2],
raw_password=user_data[1], raw_password=user_data[1],
class_name=get_class_name(user_data[0]),
) )
) )
@@ -100,13 +85,12 @@ class UserAdminAPI(APIView):
pre_username = user.username pre_username = user.username
user.username = data["username"].lower() user.username = data["username"].lower()
user.class_name = get_class_name(data["username"])
user.email = data["email"].lower() user.email = data["email"].lower()
user.admin_type = data["admin_type"] user.admin_type = data["admin_type"]
user.is_disabled = data["is_disabled"] user.is_disabled = data["is_disabled"]
if data["admin_type"] == AdminType.ADMIN: if data["admin_type"] == AdminType.ADMIN:
user.problem_permission = data["problem_permission"] or ProblemPermission.OWN user.problem_permission = data["problem_permission"]
elif data["admin_type"] == AdminType.SUPER_ADMIN: elif data["admin_type"] == AdminType.SUPER_ADMIN:
user.problem_permission = ProblemPermission.ALL user.problem_permission = ProblemPermission.ALL
else: else:
@@ -154,24 +138,12 @@ class UserAdminAPI(APIView):
return self.error("User does not exist") return self.error("User does not exist")
return self.success(UserAdminSerializer(user).data) return self.success(UserAdminSerializer(user).data)
# 获取排序参数 user = User.objects.all().order_by("-create_time")
order_by = request.GET.get("order_by", "")
# 根据排序参数设置排序规则
if order_by == "-last_login":
# 最近登录,将 None 值放在最后
user = User.objects.all().order_by(F("last_login").desc(nulls_last=True))
elif order_by == "last_login":
# 最早登录,将 None 值放在最后
user = User.objects.all().order_by(F("last_login").asc(nulls_last=True))
else:
# 默认按创建时间倒序
user = User.objects.all().order_by("-create_time")
type = request.GET.get("type", "") is_admin = request.GET.get("admin", "0")
if type: if is_admin == "1":
user = user.filter(admin_type=type) user = user.exclude(admin_type=AdminType.REGULAR_USER)
keyword = request.GET.get("keyword", None) keyword = request.GET.get("keyword", None)
if keyword: if keyword:
@@ -267,27 +239,3 @@ class GenerateUserAPI(APIView):
# duplicate key value violates unique constraint "user_username_key" # duplicate key value violates unique constraint "user_username_key"
# DETAIL: Key (username)=(root11) already exists. # DETAIL: Key (username)=(root11) already exists.
return self.error(str(e).split("\n")[1]) return self.error(str(e).split("\n")[1])
class ResetUserPasswordAPI(APIView):
@super_admin_required
def post(self, request):
"""
重置用户密码为随机6位数字(不包括0)
"""
data = request.data
user_id = data["id"]
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
return self.error("User does not exist")
# 生成6位随机数字密码(不包括0)
new_password = get_random_string(6, allowed_chars="123456789")
# 设置新密码
user.set_password(new_password)
user.save()
return self.success(new_password)

View File

@@ -12,7 +12,7 @@ from django.db.models import Count, Q
from django.utils import timezone from django.utils import timezone
import qrcode import qrcode
from otpauth import OtpAuth from otpauth import TOTP
from problem.models import Problem from problem.models import Problem
from submission.models import Submission, JudgeStatus from submission.models import Submission, JudgeStatus
@@ -143,7 +143,7 @@ class TwoFactorAuthAPI(APIView):
label = f"{SysOptions.website_name_shortcut}:{user.username}" label = f"{SysOptions.website_name_shortcut}:{user.username}"
image = qrcode.make( image = qrcode.make(
OtpAuth(token).to_uri( TOTP(token).to_uri(
"totp", label, SysOptions.website_name.replace(" ", "") "totp", label, SysOptions.website_name.replace(" ", "")
) )
) )
@@ -157,7 +157,7 @@ class TwoFactorAuthAPI(APIView):
""" """
code = request.data["code"] code = request.data["code"]
user = request.user user = request.user
if OtpAuth(user.tfa_token).valid_totp(code): if TOTP(user.tfa_token).verify(code):
user.two_factor_auth = True user.two_factor_auth = True
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
@@ -171,7 +171,7 @@ class TwoFactorAuthAPI(APIView):
user = request.user user = request.user
if not user.two_factor_auth: if not user.two_factor_auth:
return self.error("2FA is already turned off") return self.error("2FA is already turned off")
if OtpAuth(user.tfa_token).valid_totp(code): if TOTP(user.tfa_token).verify(code):
user.two_factor_auth = False user.two_factor_auth = False
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
@@ -216,7 +216,7 @@ class UserLoginAPI(APIView):
if user.two_factor_auth and "tfa_code" not in data: if user.two_factor_auth and "tfa_code" not in data:
return self.error("tfa_required") return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): if TOTP(user.tfa_token).verify(data["tfa_code"]):
auth.login(request, user) auth.login(request, user)
return self.success("Succeeded") return self.success("Succeeded")
else: else:
@@ -287,7 +287,7 @@ class UserChangeEmailAPI(APIView):
if user.two_factor_auth: if user.two_factor_auth:
if "tfa_code" not in data: if "tfa_code" not in data:
return self.error("tfa_required") return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): if not TOTP(user.tfa_token).verify(data["tfa_code"]):
return self.error("Invalid two factor verification code") return self.error("Invalid two factor verification code")
data["new_email"] = data["new_email"].lower() data["new_email"] = data["new_email"].lower()
if User.objects.filter(email=data["new_email"]).exists(): if User.objects.filter(email=data["new_email"]).exists():
@@ -313,7 +313,7 @@ class UserChangePasswordAPI(APIView):
if user.two_factor_auth: if user.two_factor_auth:
if "tfa_code" not in data: if "tfa_code" not in data:
return self.error("tfa_required") return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): if not TOTP(user.tfa_token).verify(data["tfa_code"]):
return self.error("Invalid two factor verification code") return self.error("Invalid two factor verification code")
user.set_password(data["new_password"]) user.set_password(data["new_password"])
user.save() user.save()
@@ -434,9 +434,8 @@ class UserRankAPI(APIView):
n = 0 n = 0
if rule_type not in ContestRuleType.choices(): if rule_type not in ContestRuleType.choices():
rule_type = ContestRuleType.ACM rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.filter( profiles = UserProfile.objects.filter(
user__admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN], user__admin_type=AdminType.REGULAR_USER,
user__is_disabled=False, user__is_disabled=False,
user__username__icontains=username, user__username__icontains=username,
).select_related("user") ).select_related("user")
@@ -457,72 +456,23 @@ class UserActivityRankAPI(APIView):
if not start: if not start:
return self.error("start time is required") return self.error("start time is required")
hidden_names = User.objects.filter( hidden_names = User.objects.filter(
Q(admin_type=AdminType.SUPER_ADMIN) | Q(is_disabled=True) Q(admin_type=AdminType.SUPER_ADMIN)
| Q(admin_type=AdminType.ADMIN)
| Q(is_disabled=True)
).values_list("username", flat=True) ).values_list("username", flat=True)
submissions = Submission.objects.filter( submissions = Submission.objects.filter(
contest_id__isnull=True, contest_id__isnull=True, create_time__gte=start, result=JudgeStatus.ACCEPTED
create_time__gte=start, )
result=JudgeStatus.ACCEPTED, counts = (
).exclude(username__in=hidden_names)
data = list(
submissions.values("username") submissions.values("username")
.annotate(count=Count("problem_id", distinct=True)) .annotate(count=Count("problem_id", distinct=True))
.order_by("-count")[:10] .order_by("-count")[: 10 + len(hidden_names)]
)
return self.success(data)
class UserProblemRankAPI(APIView):
def get(self, request):
problem_id = request.GET.get("problem_id")
user = request.user
if not user.is_authenticated:
return self.error("User is not authenticated")
problem = Problem.objects.get(
_id=problem_id, contest_id__isnull=True, visible=True
)
submissions = Submission.objects.filter(
problem=problem, result=JudgeStatus.ACCEPTED
)
all_ac_count = submissions.values("user_id").distinct().count()
class_name = user.class_name or ""
class_ac_count = 0
if class_name:
users = User.objects.filter(
class_name=user.class_name, is_disabled=False
).values_list("id", flat=True)
user_ids = list(users)
submissions = submissions.filter(user_id__in=user_ids)
class_ac_count = submissions.values("user_id").distinct().count()
my_submissions = submissions.filter(user_id=user.id)
if len(my_submissions) == 0:
return self.success(
{
"class_name": class_name,
"rank": -1,
"class_ac_count": class_ac_count,
"all_ac_count": all_ac_count,
}
)
my_first_submission = my_submissions.order_by("create_time").first()
rank = submissions.filter(
create_time__lte=my_first_submission.create_time
).count()
return self.success(
{
"class_name": class_name,
"rank": rank,
"class_ac_count": class_ac_count,
"all_ac_count": all_ac_count,
}
) )
data = []
for count in counts:
if count["username"] not in hidden_names:
data.append(count)
return self.success(data[:10])
class ProfileProblemDisplayIDRefreshAPI(APIView): class ProfileProblemDisplayIDRefreshAPI(APIView):

View File

View File

@@ -1,6 +0,0 @@
from django.apps import AppConfig
class AiConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'ai'

View File

@@ -1,34 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-24 12:59
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='AIAnalysis',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('provider', models.TextField(default='deepseek')),
('data', models.JSONField()),
('system_prompt', models.TextField()),
('user_prompt', models.TextField()),
('analysis', models.TextField()),
('create_time', models.DateTimeField(auto_now_add=True)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
'db_table': 'ai_analysis',
'ordering': ['-create_time'],
},
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-24 13:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('ai', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='aianalysis',
name='model',
field=models.TextField(default='deepseek-chat'),
),
]

View File

@@ -1,18 +0,0 @@
from django.db import models
from account.models import User
class AIAnalysis(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.TextField(default="deepseek")
model = models.TextField(default="deepseek-chat")
data = models.JSONField()
system_prompt = models.TextField()
user_prompt = models.TextField()
analysis = models.TextField()
create_time = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "ai_analysis"
ordering = ["-create_time"]

View File

View File

View File

@@ -1,15 +0,0 @@
from django.urls import path
from ..views.oj import (
AIAnalysisAPI,
AIDetailDataAPI,
AIDurationDataAPI,
AIHeatmapDataAPI,
)
urlpatterns = [
path("ai/detail", AIDetailDataAPI.as_view()),
path("ai/duration", AIDurationDataAPI.as_view()),
path("ai/analysis", AIAnalysisAPI.as_view()),
path("ai/heatmap", AIHeatmapDataAPI.as_view()),
]

View File

View File

@@ -1,655 +0,0 @@
from collections import defaultdict
from datetime import datetime, timedelta
import hashlib
import json
from dateutil.relativedelta import relativedelta
from django.core.cache import cache
from django.db.models import Min, Count
from django.db.models.functions import TruncDate
from django.http import StreamingHttpResponse
from django.utils import timezone
from openai import OpenAI
from utils.api import APIView
from utils.shortcuts import get_env
from account.models import User
from problem.models import Problem
from submission.models import Submission, JudgeStatus
from flowchart.models import FlowchartSubmission, FlowchartSubmissionStatus
from account.decorators import login_required
from ai.models import AIAnalysis
CACHE_TIMEOUT = 300
DIFFICULTY_MAP = {"Low": "简单", "Mid": "中等", "High": "困难"}
DEFAULT_CLASS_SIZE = 45
# 评级阈值配置:(百分位上限, 评级)
GRADE_THRESHOLDS = [
(10, "S"), # 前10%: S级 - 卓越
(35, "A"), # 前35%: A级 - 优秀
(75, "B"), # 前75%: B级 - 良好
(100, "C"), # 其余: C级 - 及格
]
# 小规模参与惩罚配置:(最小人数, 等级降级映射)
SMALL_SCALE_PENALTY = {
"threshold": 10,
"downgrade": {"S": "A", "A": "B"},
}
def get_cache_key(prefix, *args):
return hashlib.md5(f"{prefix}:{'_'.join(map(str, args))}".encode()).hexdigest()
def get_difficulty(difficulty):
return DIFFICULTY_MAP.get(difficulty, "中等")
def get_grade(rank, submission_count):
"""
计算题目完成评级
评级标准:
- S级前10%卓越水平10%的人)
- A级前35%优秀水平25%的人)
- B级前75%良好水平40%的人)
- C级75%之后及格水平25%的人)
特殊规则:
- 参与人数少于10人时S级降为A级A级降为B级避免因人少而评级虚高
Args:
rank: 用户排名1表示第一名
submission_count: 总AC人数
Returns:
评级字符串 ("S", "A", "B", "C")
"""
# 边界检查
if not rank or rank <= 0 or submission_count <= 0:
return "C"
# 计算百分位0-100使用 (rank-1) 使第一名的百分位为0
percentile = (rank - 1) / submission_count * 100
# 根据百分位确定基础评级
base_grade = "C"
for threshold, grade in GRADE_THRESHOLDS:
if percentile < threshold:
base_grade = grade
break
# 小规模参与惩罚:人数太少时降低评级
if submission_count < SMALL_SCALE_PENALTY["threshold"]:
base_grade = SMALL_SCALE_PENALTY["downgrade"].get(base_grade, base_grade)
return base_grade
def get_class_user_ids(user):
if not user.class_name:
return []
cache_key = get_cache_key("class_users", user.class_name)
user_ids = cache.get(cache_key)
if user_ids is None:
user_ids = list(
User.objects.filter(class_name=user.class_name).values_list("id", flat=True)
)
cache.set(cache_key, user_ids, CACHE_TIMEOUT)
return user_ids
def get_user_first_ac_submissions(
user_id, start, end, class_user_ids=None, use_class_scope=False
):
base_qs = Submission.objects.filter(
result=JudgeStatus.ACCEPTED, create_time__gte=start, create_time__lte=end
)
if use_class_scope and class_user_ids:
base_qs = base_qs.filter(user_id__in=class_user_ids)
user_first_ac = list(
base_qs.filter(user_id=user_id)
.values("problem_id")
.annotate(first_ac_time=Min("create_time"))
)
if not user_first_ac:
return [], {}, []
problem_ids = [item["problem_id"] for item in user_first_ac]
ranked_first_ac = list(
base_qs.filter(problem_id__in=problem_ids)
.values("user_id", "problem_id")
.annotate(first_ac_time=Min("create_time"))
)
by_problem = defaultdict(list)
for item in ranked_first_ac:
by_problem[item["problem_id"]].append(item)
for submissions in by_problem.values():
submissions.sort(key=lambda x: (x["first_ac_time"], x["user_id"]))
return user_first_ac, by_problem, problem_ids
class AIDetailDataAPI(APIView):
@login_required
def get(self, request):
start = request.GET.get("start")
end = request.GET.get("end")
user = request.user
cache_key = get_cache_key(
"ai_detail", user.id, user.class_name or "", start, end
)
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
class_user_ids = get_class_user_ids(user)
use_class_scope = bool(user.class_name) and len(class_user_ids) > 1
user_first_ac, by_problem, problem_ids = get_user_first_ac_submissions(
user.id, start, end, class_user_ids, use_class_scope
)
result = {
"user": user.username,
"class_name": user.class_name,
"start": start,
"end": end,
"solved": [],
"flowcharts": [],
"grade": "",
"tags": {},
"difficulty": {},
"contest_count": 0,
}
if user_first_ac:
problems = {
p.id: p
for p in Problem.objects.filter(id__in=problem_ids)
.select_related("contest")
.prefetch_related("tags")
}
solved, contest_ids = self._build_solved_records(
user_first_ac, by_problem, problems, user.id
)
# 查找 flowchart submissions
flowcharts_query = FlowchartSubmission.objects.filter(
user_id=user,
status=FlowchartSubmissionStatus.COMPLETED,
)
# 添加时间范围过滤
if start:
flowcharts_query = flowcharts_query.filter(create_time__gte=start)
if end:
flowcharts_query = flowcharts_query.filter(create_time__lte=end)
flowcharts = flowcharts_query.select_related("problem").only(
"id",
"create_time",
"ai_score",
"ai_grade",
"problem___id",
"problem__title",
)
# 按problem分组
problem_groups = defaultdict(list)
for flowchart in flowcharts:
problem_id = flowchart.problem._id
problem_groups[problem_id].append(flowchart)
flowcharts_data = []
for problem_id, submissions in problem_groups.items():
if not submissions:
continue
# 获取第一个提交的基本信息
first_submission = submissions[0]
# 计算统计数据
scores = [s.ai_score for s in submissions if s.ai_score is not None]
times = [s.create_time for s in submissions]
# 找到最高分和对应的等级
best_score = max(scores) if scores else 0
best_submission = next(
(s for s in submissions if s.ai_score == best_score), submissions[0]
)
best_grade = best_submission.ai_grade or ""
# 计算平均分
avg_score = sum(scores) / len(scores) if scores else 0
# 最新提交时间
latest_time = max(times) if times else first_submission.create_time
merged_item = {
"problem__id": problem_id,
"problem_title": first_submission.problem.title,
"submission_count": len(submissions),
"best_score": best_score,
"best_grade": best_grade,
"latest_submission_time": latest_time.isoformat() if latest_time else None,
"avg_score": round(avg_score, 0),
}
flowcharts_data.append(merged_item)
# 按最新提交时间排序
flowcharts_data.sort(
key=lambda x: x["latest_submission_time"] or "", reverse=True
)
result.update(
{
"solved": solved,
"flowcharts": flowcharts_data,
"grade": self._calculate_average_grade(solved),
"tags": self._calculate_top_tags(problems.values()),
"difficulty": self._calculate_difficulty_distribution(
problems.values()
),
"contest_count": len(set(contest_ids)),
}
)
cache.set(cache_key, result, CACHE_TIMEOUT)
return self.success(result)
def _build_solved_records(self, user_first_ac, by_problem, problems, user_id):
solved, contest_ids = [], []
for item in user_first_ac:
pid = item["problem_id"]
problem = problems.get(pid)
if not problem:
continue
ranking_list = by_problem.get(pid, [])
rank = next(
(
idx + 1
for idx, rec in enumerate(ranking_list)
if rec["user_id"] == user_id
),
None,
)
if problem.contest_id:
contest_ids.append(problem.contest_id)
solved.append(
{
"problem": {
"display_id": problem._id,
"title": problem.title,
"contest_id": problem.contest_id,
"contest_title": getattr(problem.contest, "title", ""),
},
"ac_time": timezone.localtime(item["first_ac_time"]).isoformat(),
"rank": rank,
"ac_count": len(ranking_list),
"grade": get_grade(rank, len(ranking_list)),
"difficulty": get_difficulty(problem.difficulty),
}
)
return sorted(solved, key=lambda x: x["ac_time"]), contest_ids
def _calculate_average_grade(self, solved):
"""
计算平均等级,使用加权平均方法
等级权重S=4, A=3, B=2, C=1
计算加权平均后,根据阈值确定最终等级
Args:
solved: 已解决的题目列表每个包含grade字段
Returns:
平均等级字符串 ("S", "A", "B", "C")
"""
if not solved:
return ""
# 等级权重映射
grade_weights = {"S": 4, "A": 3, "B": 2, "C": 1}
# 计算加权总分
total_weight = 0
total_score = 0
for s in solved:
grade = s["grade"]
if grade in grade_weights:
total_score += grade_weights[grade]
total_weight += 1
if total_weight == 0:
return ""
# 计算平均权重
average_weight = total_score / total_weight
# 根据平均权重确定等级
# S级: 3.5-4.0, A级: 2.5-3.5, B级: 1.5-2.5, C级: 1.0-1.5
if average_weight >= 3.5:
return "S"
elif average_weight >= 2.5:
return "A"
elif average_weight >= 1.5:
return "B"
else:
return "C"
def _calculate_top_tags(self, problems):
tags_counter = defaultdict(int)
for problem in problems:
for tag in problem.tags.all():
if tag.name:
tags_counter[tag.name] += 1
return dict(sorted(tags_counter.items(), key=lambda x: x[1], reverse=True)[:5])
def _calculate_difficulty_distribution(self, problems):
diff_counter = {"Low": 0, "Mid": 0, "High": 0}
for problem in problems:
diff_counter[
problem.difficulty if problem.difficulty in diff_counter else "Mid"
] += 1
return {
get_difficulty(k): v
for k, v in sorted(diff_counter.items(), key=lambda x: x[1], reverse=True)
}
class AIDurationDataAPI(APIView):
@login_required
def get(self, request):
end_iso = request.GET.get("end")
duration = request.GET.get("duration")
user = request.user
cache_key = get_cache_key(
"ai_duration", user.id, user.class_name or "", end_iso, duration
)
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
class_user_ids = get_class_user_ids(user)
use_class_scope = bool(user.class_name) and len(class_user_ids) > 1
time_config = self._parse_duration(duration)
start = datetime.fromisoformat(end_iso) - time_config["total_delta"]
duration_data = []
for i in range(time_config["show_count"]):
start = start + time_config["delta"]
period_end = start + time_config["delta"]
submission_count = Submission.objects.filter(
user_id=user.id, create_time__gte=start, create_time__lte=period_end
).count()
period_data = {
"unit": time_config["show_unit"],
"index": time_config["show_count"] - 1 - i,
"start": start.isoformat(),
"end": period_end.isoformat(),
"problem_count": 0,
"submission_count": submission_count,
"grade": "",
}
if submission_count > 0:
user_first_ac, by_problem, problem_ids = get_user_first_ac_submissions(
user.id,
start.isoformat(),
period_end.isoformat(),
class_user_ids,
use_class_scope,
)
if user_first_ac:
period_data["problem_count"] = len(problem_ids)
period_data["grade"] = self._calculate_period_grade(
user_first_ac, by_problem, user.id
)
duration_data.append(period_data)
cache.set(cache_key, duration_data, CACHE_TIMEOUT)
return self.success(duration_data)
def _parse_duration(self, duration):
unit, count = duration.split(":")
count = int(count)
configs = {
("months", 2): {
"show_count": 8,
"show_unit": "weeks",
"total_delta": timedelta(weeks=9),
"delta": timedelta(weeks=1),
},
("months", 6): {
"show_count": 6,
"show_unit": "months",
"total_delta": relativedelta(months=7),
"delta": relativedelta(months=1),
},
("years", 1): {
"show_count": 12,
"show_unit": "months",
"total_delta": relativedelta(months=13),
"delta": relativedelta(months=1),
},
}
return configs.get(
(unit, count),
{
"show_count": 4,
"show_unit": "weeks",
"total_delta": timedelta(weeks=5),
"delta": timedelta(weeks=1),
},
)
def _calculate_period_grade(self, user_first_ac, by_problem, user_id):
"""
计算时间段内的平均等级,使用加权平均方法
等级权重S=4, A=3, B=2, C=1
计算加权平均后,根据阈值确定最终等级
Args:
user_first_ac: 用户首次AC的提交记录
by_problem: 按题目分组的排名数据
user_id: 用户ID
Returns:
平均等级字符串 ("S", "A", "B", "C")
"""
if not user_first_ac:
return ""
# 等级权重映射
grade_weights = {"S": 4, "A": 3, "B": 2, "C": 1}
# 计算加权总分
total_weight = 0
total_score = 0
for item in user_first_ac:
ranking_list = by_problem.get(item["problem_id"], [])
rank = next(
(
idx + 1
for idx, rec in enumerate(ranking_list)
if rec["user_id"] == user_id
),
None,
)
if rank:
grade = get_grade(rank, len(ranking_list))
if grade in grade_weights:
total_score += grade_weights[grade]
total_weight += 1
if total_weight == 0:
return ""
# 计算平均权重
average_weight = total_score / total_weight
# 根据平均权重确定等级
# S级: 3.5-4.0, A级: 2.5-3.5, B级: 1.5-2.5, C级: 1.0-1.5
if average_weight >= 3.5:
return "S"
elif average_weight >= 2.5:
return "A"
elif average_weight >= 1.5:
return "B"
else:
return "C"
class AIAnalysisAPI(APIView):
@login_required
def post(self, request):
details = request.data.get("details")
duration = request.data.get("duration")
api_key = get_env("AI_KEY")
if not api_key:
return self.error("API_KEY is not set")
client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
system_prompt = "你是一个风趣的编程老师,学生使用判题狗平台进行编程练习。请根据学生提供的详细数据和每周数据,给出用户的学习建议,最后写一句鼓励学生的话。请使用 markdown 格式输出,不要在代码块中输出。"
user_prompt = f"这段时间内的详细数据: {details}\n(其中部分字段含义是 flowcharts:流程图的提交,solved:代码的提交)\n每周或每月的数据: {duration}"
analysis_chunks = []
saved_instance = None
completed = False
def save_analysis():
nonlocal saved_instance
if analysis_chunks and not saved_instance:
saved_instance = AIAnalysis.objects.create(
user=request.user,
provider="deepseek",
model="deepseek-chat",
data={"details": details, "duration": duration},
system_prompt=system_prompt,
user_prompt="这段时间内的详细数据,每周或每月的数据。",
analysis="".join(analysis_chunks).strip(),
)
def stream_generator():
nonlocal completed
try:
stream = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
stream=True,
)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
yield "event: end\n\n"
return
yield "event: start\n\n"
try:
for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
if choice.finish_reason:
completed = True
save_analysis()
yield f"data: {json.dumps({'type': 'done'})}\n\n"
break
content = choice.delta.content
if content:
analysis_chunks.append(content)
yield f"data: {json.dumps({'type': 'delta', 'content': content})}\n\n"
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally:
save_analysis()
if saved_instance and not completed:
try:
saved_instance.delete()
except Exception:
pass
yield "event: end\n\n"
response = StreamingHttpResponse(
streaming_content=stream_generator(),
content_type="text/event-stream",
)
response["Cache-Control"] = "no-cache"
return response
class AIHeatmapDataAPI(APIView):
@login_required
def get(self, request):
user = request.user
cache_key = get_cache_key("ai_heatmap", user.id, user.class_name or "")
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
end = datetime.now()
start = end - timedelta(days=365)
# 使用单次查询获取所有数据,按日期分组统计
submission_counts = (
Submission.objects.filter(
user_id=user.id, create_time__gte=start, create_time__lte=end
)
.annotate(date=TruncDate("create_time"))
.values("date")
.annotate(count=Count("id"))
.order_by("date")
)
# 将查询结果转换为字典,便于快速查找
submission_dict = {item["date"]: item["count"] for item in submission_counts}
# 生成365天的热力图数据
heatmap_data = []
current_date = start.date()
for i in range(365):
day_date = current_date + timedelta(days=i)
submission_count = submission_dict.get(day_date, 0)
heatmap_data.append(
{
"timestamp": int(
datetime.combine(day_date, datetime.min.time()).timestamp()
* 1000
),
"value": submission_count,
}
)
cache.set(cache_key, heatmap_data, CACHE_TIMEOUT)
return self.success(heatmap_data)

View File

@@ -1,97 +0,0 @@
"""
WebSocket consumers for configuration updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class ConfigConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time configuration updates
当管理员修改配置后,通过 WebSocket 实时推送配置变化
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用全局配置组名,所有用户都能接收配置更新
self.group_name = "config_updates"
# 加入配置更新组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"Config WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"Config WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或配置更新请求
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "config_update":
# 处理配置更新请求
key = data.get("key")
value = data.get("value")
if key and value is not None:
logger.info(f"User {self.user.id} requested config update: {key}={value}")
# 这里可以添加权限检查,只有管理员才能发送配置更新
if self.user.is_superuser:
# 广播配置更新给所有连接的客户端
await self.channel_layer.group_send(
self.group_name,
{
"type": "config_update",
"data": {
"type": "config_update",
"key": key,
"value": value
}
}
)
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def config_update(self, event):
"""
接收来自 channel layer 的配置更新消息并发送给客户端
这个方法名对应 group_send 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent config update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending config update to user {self.user.id}: {str(e)}")

View File

@@ -27,7 +27,6 @@ class CreateEditWebsiteConfigSerializer(serializers.Serializer):
allow_register = serializers.BooleanField() allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField() submission_list_show_all = serializers.BooleanField()
class_list = serializers.ListField(child=serializers.CharField(max_length=64)) class_list = serializers.ListField(child=serializers.CharField(max_length=64))
enable_maxkb = serializers.BooleanField()
class JudgeServerSerializer(serializers.ModelSerializer): class JudgeServerSerializer(serializers.ModelSerializer):

View File

@@ -24,7 +24,6 @@ from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.cache import JsonDataLoader from utils.cache import JsonDataLoader
from utils.shortcuts import send_email, get_env from utils.shortcuts import send_email, get_env
from utils.xss_filter import XSSHtml from utils.xss_filter import XSSHtml
from utils.websocket import push_config_update
from .models import JudgeServer from .models import JudgeServer
from .serializers import ( from .serializers import (
CreateEditWebsiteConfigSerializer, CreateEditWebsiteConfigSerializer,
@@ -108,7 +107,6 @@ class WebsiteConfigAPI(APIView):
"allow_register", "allow_register",
"submission_list_show_all", "submission_list_show_all",
"class_list", "class_list",
"enable_maxkb",
] ]
} }
return self.success(ret) return self.success(ret)
@@ -121,10 +119,6 @@ class WebsiteConfigAPI(APIView):
with XSSHtml() as parser: with XSSHtml() as parser:
v = parser.clean(v) v = parser.clean(v)
setattr(SysOptions, k, v) setattr(SysOptions, k, v)
# 推送配置更新到所有连接的客户端
push_config_update(k, v)
return self.success() return self.success()
@@ -210,6 +204,7 @@ class LanguagesAPI(APIView):
return self.success( return self.success(
{ {
"languages": SysOptions.languages, "languages": SysOptions.languages,
"spj_languages": SysOptions.spj_languages,
} }
) )
@@ -315,11 +310,8 @@ class RandomUsernameAPI(APIView):
class HitokotoAPI(APIView): class HitokotoAPI(APIView):
def get(self, request): def get(self, request):
try: categories = JsonDataLoader.load_data(settings.HITOKOTO_DIR, "categories.json")
categories = JsonDataLoader.load_data(settings.HITOKOTO_DIR, "categories.json") path = random.choice(categories).get("path")
path = random.choice(categories).get("path") sentences = JsonDataLoader.load_data(settings.HITOKOTO_DIR, path)
sentences = JsonDataLoader.load_data(settings.HITOKOTO_DIR, path) sentence = random.choice(sentences)
sentence = random.choice(sentences) return self.success(sentence)
return self.success(sentence)
except Exception:
return self.error("获取一言失败,请稍后再试")

View File

@@ -46,13 +46,10 @@ class Contest(models.Model):
# 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等 # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
def problem_details_permission(self, user): def problem_details_permission(self, user):
return ( return self.rule_type == ContestRuleType.ACM or \
self.rule_type == ContestRuleType.ACM self.status == ContestStatus.CONTEST_ENDED or \
or self.status == ContestStatus.CONTEST_ENDED user.is_authenticated and user.is_contest_admin(self) or \
or user.is_authenticated self.real_time_rank
and user.is_contest_admin(self)
or self.real_time_rank
)
class Meta: class Meta:
db_table = "contest" db_table = "contest"

View File

@@ -6,9 +6,7 @@ from ipaddress import ip_network
import dateutil.parser import dateutil.parser
from django.http import FileResponse from django.http import FileResponse
from problem.models import Problem from account.decorators import check_contest_permission, ensure_created_by
from account.decorators import super_admin_required
from account.models import User from account.models import User
from submission.models import Submission, JudgeStatus from submission.models import Submission, JudgeStatus
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
@@ -17,20 +15,14 @@ from utils.constants import CacheKey
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from utils.tasks import delete_files from utils.tasks import delete_files
from ..models import Contest, ContestAnnouncement, ACMContestRank from ..models import Contest, ContestAnnouncement, ACMContestRank
from ..serializers import ( from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer,
ContestAnnouncementSerializer, CreateConetestSeriaizer, CreateContestAnnouncementSerializer,
ContestAdminSerializer, EditConetestSeriaizer, EditContestAnnouncementSerializer,
CreateConetestSeriaizer, ACMContesHelperSerializer, )
CreateContestAnnouncementSerializer,
EditConetestSeriaizer,
EditContestAnnouncementSerializer,
ACMContesHelperSerializer,
)
class ContestAPI(APIView): class ContestAPI(APIView):
@validate_serializer(CreateConetestSeriaizer) @validate_serializer(CreateConetestSeriaizer)
@super_admin_required
def post(self, request): def post(self, request):
data = request.data data = request.data
data["start_time"] = dateutil.parser.parse(data["start_time"]) data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -49,11 +41,11 @@ class ContestAPI(APIView):
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
@validate_serializer(EditConetestSeriaizer) @validate_serializer(EditConetestSeriaizer)
@super_admin_required
def put(self, request): def put(self, request):
data = request.data data = request.data
try: try:
contest = Contest.objects.get(id=data.pop("id")) contest = Contest.objects.get(id=data.pop("id"))
ensure_created_by(contest, request.user)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
data["start_time"] = dateutil.parser.parse(data["start_time"]) data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -76,29 +68,28 @@ class ContestAPI(APIView):
contest.save() contest.save()
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
@super_admin_required
def get(self, request): def get(self, request):
contest_id = request.GET.get("id") contest_id = request.GET.get("id")
if contest_id: if contest_id:
try: try:
contest = Contest.objects.get(id=contest_id) contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
contests = Contest.objects.all().order_by("-create_time") contests = Contest.objects.all().order_by("-create_time")
if request.user.is_admin():
contests = contests.filter(created_by=request.user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
contests = contests.filter(title__contains=keyword) contests = contests.filter(title__contains=keyword)
return self.success( return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
self.paginate_data(request, contests, ContestAdminSerializer)
)
class ContestAnnouncementAPI(APIView): class ContestAnnouncementAPI(APIView):
@validate_serializer(CreateContestAnnouncementSerializer) @validate_serializer(CreateContestAnnouncementSerializer)
@super_admin_required
def post(self, request): def post(self, request):
""" """
Create one contest_announcement. Create one contest_announcement.
@@ -106,6 +97,7 @@ class ContestAnnouncementAPI(APIView):
data = request.data data = request.data
try: try:
contest = Contest.objects.get(id=data.pop("contest_id")) contest = Contest.objects.get(id=data.pop("contest_id"))
ensure_created_by(contest, request.user)
data["contest"] = contest data["contest"] = contest
data["created_by"] = request.user data["created_by"] = request.user
except Contest.DoesNotExist: except Contest.DoesNotExist:
@@ -114,7 +106,6 @@ class ContestAnnouncementAPI(APIView):
return self.success(ContestAnnouncementSerializer(announcement).data) return self.success(ContestAnnouncementSerializer(announcement).data)
@validate_serializer(EditContestAnnouncementSerializer) @validate_serializer(EditContestAnnouncementSerializer)
@super_admin_required
def put(self, request): def put(self, request):
""" """
update contest_announcement update contest_announcement
@@ -122,6 +113,7 @@ class ContestAnnouncementAPI(APIView):
data = request.data data = request.data
try: try:
contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id")) contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id"))
ensure_created_by(contest_announcement, request.user)
except ContestAnnouncement.DoesNotExist: except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist") return self.error("Contest announcement does not exist")
for k, v in data.items(): for k, v in data.items():
@@ -129,17 +121,19 @@ class ContestAnnouncementAPI(APIView):
contest_announcement.save() contest_announcement.save()
return self.success() return self.success()
@super_admin_required
def delete(self, request): def delete(self, request):
""" """
Delete one contest_announcement. Delete one contest_announcement.
""" """
contest_announcement_id = request.GET.get("id") contest_announcement_id = request.GET.get("id")
if contest_announcement_id: if contest_announcement_id:
ContestAnnouncement.objects.filter(id=contest_announcement_id).delete() if request.user.is_admin():
ContestAnnouncement.objects.filter(id=contest_announcement_id,
contest__created_by=request.user).delete()
else:
ContestAnnouncement.objects.filter(id=contest_announcement_id).delete()
return self.success() return self.success()
@super_admin_required
def get(self, request): def get(self, request):
""" """
Get one contest_announcement or contest_announcement list. Get one contest_announcement or contest_announcement list.
@@ -147,71 +141,45 @@ class ContestAnnouncementAPI(APIView):
contest_announcement_id = request.GET.get("id") contest_announcement_id = request.GET.get("id")
if contest_announcement_id: if contest_announcement_id:
try: try:
contest_announcement = ContestAnnouncement.objects.get( contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id)
id=contest_announcement_id ensure_created_by(contest_announcement, request.user)
) return self.success(ContestAnnouncementSerializer(contest_announcement).data)
return self.success(
ContestAnnouncementSerializer(contest_announcement).data
)
except ContestAnnouncement.DoesNotExist: except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist") return self.error("Contest announcement does not exist")
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
return self.error("Parameter error") return self.error("Parameter error")
contest_announcements = ContestAnnouncement.objects.filter( contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id)
contest_id=contest_id if request.user.is_admin():
) contest_announcements = contest_announcements.filter(created_by=request.user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
contest_announcements = contest_announcements.filter( contest_announcements = contest_announcements.filter(title__contains=keyword)
title__contains=keyword return self.success(ContestAnnouncementSerializer(contest_announcements, many=True).data)
)
return self.success(
ContestAnnouncementSerializer(contest_announcements, many=True).data
)
class ACMContestHelper(APIView): class ACMContestHelper(APIView):
@super_admin_required @check_contest_permission(check_type="ranks")
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") ranks = ACMContestRank.objects.filter(contest=self.contest, accepted_number__gt=0) \
if not contest_id: .values("id", "user__username", "user__userprofile__real_name", "submission_info")
return self.error("Parameter error, contest_id is required")
try:
contest = Contest.objects.get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
problems = Problem.objects.filter(contest=contest).values("id", "_id")
problem_id_map = {str(p["id"]): p["_id"] for p in problems}
ranks = ACMContestRank.objects.filter(
contest=contest, accepted_number__gt=0
).values(
"id", "user__username", "user__userprofile__real_name", "submission_info"
)
results = [] results = []
for rank in ranks: for rank in ranks:
for problem_id, info in rank["submission_info"].items(): for problem_id, info in rank["submission_info"].items():
if info["is_ac"]: if info["is_ac"]:
results.append( results.append({
{ "id": rank["id"],
"id": rank["id"], "username": rank["user__username"],
"username": rank["user__username"], "real_name": rank["user__userprofile__real_name"],
"real_name": rank["user__userprofile__real_name"], "problem_id": problem_id,
"problem_id": problem_id, "ac_info": info,
"problem_display_id": problem_id_map.get( "checked": info.get("checked", False)
problem_id, problem_id })
),
"ac_info": info,
"checked": info.get("checked", False),
}
)
results.sort(key=lambda x: -x["ac_info"]["ac_time"]) results.sort(key=lambda x: -x["ac_info"]["ac_time"])
return self.success(results) return self.success(results)
@super_admin_required @check_contest_permission(check_type="ranks")
@validate_serializer(ACMContesHelperSerializer) @validate_serializer(ACMContesHelperSerializer)
def put(self, request): def put(self, request):
data = request.data data = request.data
@@ -232,9 +200,7 @@ class DownloadContestSubmissions(APIView):
problem_ids = contest.problem_set.all().values_list("id", "_id") problem_ids = contest.problem_set.all().values_list("id", "_id")
id2display_id = {k[0]: k[1] for k in problem_ids} id2display_id = {k[0]: k[1] for k in problem_ids}
ac_map = {k[0]: False for k in problem_ids} ac_map = {k[0]: False for k in problem_ids}
submissions = Submission.objects.filter( submissions = Submission.objects.filter(contest=contest, result=JudgeStatus.ACCEPTED).order_by("-create_time")
contest=contest, result=JudgeStatus.ACCEPTED
).order_by("-create_time")
user_ids = submissions.values_list("user_id", flat=True) user_ids = submissions.values_list("user_id", flat=True)
users = User.objects.filter(id__in=user_ids) users = User.objects.filter(id__in=user_ids)
path = f"/tmp/{rand_str()}.zip" path = f"/tmp/{rand_str()}.zip"
@@ -248,25 +214,21 @@ class DownloadContestSubmissions(APIView):
problem_id = submission.problem_id problem_id = submission.problem_id
if user_ac_map[problem_id]: if user_ac_map[problem_id]:
continue continue
file_name = ( file_name = f"{user.username}_{id2display_id[submission.problem_id]}.txt"
f"{user.username}_{id2display_id[submission.problem_id]}.txt"
)
compression = zipfile.ZIP_DEFLATED compression = zipfile.ZIP_DEFLATED
zip_file.writestr( zip_file.writestr(zinfo_or_arcname=f"{file_name}",
zinfo_or_arcname=f"{file_name}", data=submission.code,
data=submission.code, compress_type=compression)
compress_type=compression,
)
user_ac_map[problem_id] = True user_ac_map[problem_id] = True
return path return path
@super_admin_required
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
return self.error("Parameter error") return self.error("Parameter error")
try: try:
contest = Contest.objects.get(id=contest_id) contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
@@ -275,7 +237,5 @@ class DownloadContestSubmissions(APIView):
delete_files.send_with_options(args=(zip_path,), delay=300_000) delete_files.send_with_options(args=(zip_path,), delay=300_000)
resp = FileResponse(open(zip_path, "rb")) resp = FileResponse(open(zip_path, "rb"))
resp["Content-Type"] = "application/zip" resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = ( resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}"
f"attachment;filename={os.path.basename(zip_path)}"
)
return resp return resp

View File

@@ -2,23 +2,6 @@ location /public {
root /data; root /data;
} }
# WebSocket 支持
location /ws/ {
proxy_pass http://websocket;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $http_host;
proxy_set_header X-Real-IP __IP_HEADER__;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# WebSocket 超时设置
proxy_connect_timeout 7d;
proxy_send_timeout 7d;
proxy_read_timeout 7d;
}
location /api { location /api {
include api_proxy.conf; include api_proxy.conf;
} }

View File

@@ -38,11 +38,6 @@ http {
keepalive 32; keepalive 32;
} }
upstream websocket {
server 127.0.0.1:8001;
keepalive 32;
}
add_header X-XSS-Protection "1; mode=block" always; add_header X-XSS-Protection "1; mode=block" always;
add_header X-Frame-Options SAMEORIGIN always; add_header X-Frame-Options SAMEORIGIN always;
add_header X-Content-Type-Options nosniff always; add_header X-Content-Type-Options nosniff always;
@@ -51,7 +46,7 @@ http {
listen 8000 default_server; listen 8000 default_server;
server_name _; server_name _;
include locations.conf; include http_locations.conf;
} }
# server { # server {

View File

@@ -1,76 +1,31 @@
annotated-types==0.7.0 asgiref==3.8.1
anyio==4.12.0 certifi==2025.6.15
asgiref==3.11.0 charset-normalizer==3.4.2
attrs==25.4.0 click==8.2.1
autobahn==25.12.2 django==5.2.3
automat==25.4.16
cbor2==5.7.1
certifi==2025.11.12
cffi==2.0.0
channels==4.3.2
channels-redis==4.3.0
charset-normalizer==3.4.4
constantly==23.10.4
coverage==6.5.0
cryptography==46.0.3
daphne==4.2.1
distro==1.9.0
django==6.0
django-cas-ng==5.0.1
django-dbconn-retry==0.1.8 django-dbconn-retry==0.1.8
django-dramatiq==0.13.0 django-dramatiq==0.13.0
django-redis==5.4.0 django-redis==5.4.0
djangorestframework==3.16.0 djangorestframework==3.16.0
dramatiq==1.17.0 dramatiq==1.18.0
entrypoints==0.4
envelopes==0.4 envelopes==0.4
flake8==7.0.0 gunicorn==23.0.0
flake8-coding==1.3.2
flake8-quotes==3.3.2
gunicorn==22.0.0
h11==0.16.0 h11==0.16.0
httpcore==1.0.9 idna==3.10
httpx==0.28.1 otpauth==2.2.1
hyperlink==21.0.0
idna==3.11
incremental==24.11.0
jiter==0.12.0
jsonfield==3.1.0
lxml==6.0.2
mccabe==0.7.0
msgpack==1.1.2
openai==2.14.0
otpauth==1.0.1
packaging==25.0 packaging==25.0
pillow==10.2.0 pillow==11.2.1
prometheus-client==0.23.1 prometheus-client==0.22.1
psycopg==3.2.9 psycopg==3.2.9
psycopg-binary==3.2.9 psycopg-binary==3.2.9
py-ubjson==0.16.1 python-dateutil==2.9.0.post0
pyasn1==0.6.1
pyasn1-modules==0.4.2
pycodestyle==2.11.1
pycparser==2.23
pydantic==2.12.5
pydantic-core==2.41.5
pyflakes==3.2.0
pyopenssl==25.3.0
python-cas==1.7.1
python-dateutil==2.8.2
qrcode==8.2 qrcode==8.2
raven==6.10.0 raven==6.10.0
redis==7.1.0 redis==6.2.0
requests==2.32.5 requests==2.32.4
service-identity==24.2.0
six==1.17.0 six==1.17.0
sniffio==1.3.1 sqlparse==0.5.3
sqlparse==0.5.5 typing-extensions==4.14.0
tqdm==4.67.1 urllib3==2.4.0
twisted==25.5.0 uvicorn==0.35.0
txaio==25.12.2 xlsxwriter==3.2.5
typing-extensions==4.15.0
typing-inspection==0.4.2
ujson==5.11.0
urllib3==2.6.2
xlsxwriter==3.2.0
zope-interface==8.1.1

View File

@@ -28,7 +28,7 @@ stopwaitsecs = 5
killasgroup=true killasgroup=true
[program:gunicorn] [program:gunicorn]
command=gunicorn oj.wsgi --user server --group spj --bind 127.0.0.1:8080 --workers %(ENV_MAX_WORKER_NUM)s --threads 4 --max-requests-jitter 10000 --max-requests 1000000 --keep-alive 32 command=gunicorn oj.asgi --user server --group spj --bind 127.0.0.1:8080 --workers %(ENV_MAX_WORKER_NUM)s --threads 4 --max-requests-jitter 10000 --max-requests 1000000 --keep-alive 32 --worker-class uvicorn.workers.UvicornWorker
directory=/app/ directory=/app/
stdout_logfile=/data/log/gunicorn.log stdout_logfile=/data/log/gunicorn.log
stderr_logfile=/data/log/gunicorn.log stderr_logfile=/data/log/gunicorn.log
@@ -38,18 +38,6 @@ startsecs=5
stopwaitsecs = 5 stopwaitsecs = 5
killasgroup=true killasgroup=true
[program:daphne]
command=daphne -b 127.0.0.1 -p 8001 --access-log /data/log/daphne_access.log oj.asgi:application
directory=/app/
user=server
stdout_logfile=/data/log/daphne.log
stderr_logfile=/data/log/daphne.log
autostart=true
autorestart=true
startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:dramatiq] [program:dramatiq]
command=python3 manage.py rundramatiq --processes %(ENV_MAX_WORKER_NUM)s --threads 4 command=python3 manage.py rundramatiq --processes %(ENV_MAX_WORKER_NUM)s --threads 4
directory=/app/ directory=/app/

173
dev.py
View File

@@ -1,173 +0,0 @@
#!/usr/bin/env python
"""
WebSocket 开发服务器启动脚本
同时启动 Daphne (WebSocket) 和 Django runserver (开发服务器)
支持 Windows 和 Linux
"""
import os
import sys
import subprocess
import platform
import signal
from pathlib import Path
from threading import Thread
import time
def main():
# 获取项目根目录
base_dir = Path(__file__).resolve().parent
os.chdir(base_dir)
print("=" * 70)
print("启动 Django 开发服务器 + WebSocket 服务器")
print("=" * 70)
print()
# 检测操作系统
is_windows = platform.system() == "Windows"
# 检查虚拟环境(跨平台)
if is_windows:
# Windows: .venv/Scripts/python.exe
venv_python = base_dir / ".venv" / "Scripts" / "python.exe"
else:
# Linux/Mac: .venv/bin/python
venv_python = base_dir / ".venv" / "bin" / "python"
if venv_python.exists():
print("[✓] 使用虚拟环境: .venv")
python_exec = str(venv_python)
else:
print("[!] 未找到 .venv 虚拟环境,使用全局 Python")
print("[!] 建议创建虚拟环境: python -m venv .venv")
python_exec = sys.executable
# 检查 daphne 是否安装
try:
result = subprocess.run(
[python_exec, "-m", "daphne", "--version"], capture_output=True, text=True
)
if result.returncode != 0 and result.returncode != 2:
print("[✗] 错误: Daphne 未安装")
print("请运行: pip install daphne channels channels-redis")
sys.exit(1)
except FileNotFoundError:
print("[✗] 错误: 无法找到 Python 解释器")
sys.exit(1)
# 进程列表
processes = []
# 启动两个服务器
try:
# 启动 Django runserver (端口 8000)
print("[*] 启动 Django 开发服务器 (端口 8000)...")
runserver_cmd = ["uv", "run", "manage.py", "runserver", "0.0.0.0:8000"]
runserver_process = subprocess.Popen(
runserver_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
)
processes.append(("Django Runserver", runserver_process))
# 等待一下,让 runserver 先启动
time.sleep(1)
# 启动 Daphne (端口 8001)
print("[*] 启动 Daphne WebSocket 服务器 (端口 8001)...")
daphne_cmd = [
python_exec,
"-m",
"daphne",
"-b",
"0.0.0.0",
"-p",
"8001",
"oj.asgi:application",
]
daphne_process = subprocess.Popen(
daphne_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
)
processes.append(("Daphne", daphne_process))
print()
print("[✓] 所有服务器已启动")
print()
# 创建输出线程
def print_output(name, process):
"""打印进程输出"""
for line in process.stdout:
print(f"[{name}] {line}", end="")
# 启动输出线程
threads = []
for name, process in processes:
thread = Thread(target=print_output, args=(name, process), daemon=True)
thread.start()
threads.append(thread)
# 等待进程(任意一个退出就退出)
while True:
for name, process in processes:
if process.poll() is not None:
print(f"\n[!] {name} 已退出")
raise KeyboardInterrupt
time.sleep(0.5)
except KeyboardInterrupt:
print()
print()
print("[*] 正在停止所有服务器...")
# 终止所有进程
for name, process in processes:
try:
if process.poll() is None: # 如果进程还在运行
print(f"[*] 停止 {name}...")
if is_windows:
# Windows 使用 CTRL_C_EVENT
process.send_signal(signal.CTRL_C_EVENT)
else:
# Unix 使用 SIGTERM
process.terminate()
# 等待进程结束(最多 5 秒)
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(f"[!] {name} 未响应,强制终止...")
process.kill()
process.wait()
except Exception as e:
print(f"[!] 停止 {name} 时出错: {e}")
print()
print("[✓] 所有服务器已停止")
except Exception as e:
print(f"[✗] 错误: {e}")
# 清理所有进程
for name, process in processes:
try:
if process.poll() is None:
process.kill()
process.wait()
except Exception:
pass
sys.exit(1)
if __name__ == "__main__":
main()

View File

View File

View File

@@ -1,7 +0,0 @@
from django.apps import AppConfig
class FlowchartConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'flowchart'
verbose_name = '流程图管理'

View File

@@ -1,83 +0,0 @@
"""
WebSocket consumers for flowchart evaluation updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class FlowchartConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time flowchart evaluation updates
当用户提交流程图后,通过 WebSocket 实时接收AI评分状态更新
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用用户 ID 作为组名,这样可以向特定用户推送消息
self.group_name = f"flowchart_user_{self.user.id}"
# 加入用户专属的组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"Flowchart WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"Flowchart WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或订阅特定流程图提交
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "subscribe":
# 订阅特定流程图提交的更新
submission_id = data.get("submission_id")
if submission_id:
logger.info(f"User {self.user.id} subscribed to flowchart submission {submission_id}")
# 可以在这里做额外的订阅逻辑
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def flowchart_evaluation_update(self, event):
"""
接收来自 channel layer 的流程图评分更新消息并发送给客户端
这个方法名对应 push_flowchart_evaluation_update 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent flowchart evaluation update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending flowchart evaluation update to user {self.user.id}: {str(e)}")

View File

@@ -1,45 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
import django.db.models.deletion
import utils.shortcuts
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='FlowchartSubmission',
fields=[
('id', models.TextField(db_index=True, default=utils.shortcuts.rand_str, primary_key=True, serialize=False)),
('mermaid_code', models.TextField()),
('flowchart_data', models.JSONField(default=dict)),
('status', models.IntegerField(default=0)),
('create_time', models.DateTimeField(auto_now_add=True)),
('ai_score', models.FloatField(blank=True, null=True)),
('ai_grade', models.CharField(blank=True, max_length=10, null=True)),
('ai_feedback', models.TextField(blank=True, null=True)),
('ai_suggestions', models.TextField(blank=True, null=True)),
('ai_criteria_details', models.JSONField(default=dict)),
('ai_provider', models.CharField(default='deepseek', max_length=50)),
('ai_model', models.CharField(default='deepseek-chat', max_length=50)),
('processing_time', models.FloatField(blank=True, null=True)),
('evaluation_time', models.DateTimeField(blank=True, null=True)),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to='problem.problem')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to=settings.AUTH_USER_MODEL)),
],
options={
'db_table': 'flowchart_submission',
'ordering': ['-create_time'],
'indexes': [models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'), models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'), models.Index(fields=['status'], name='flowchart_status_idx')],
},
),
]

View File

@@ -1,65 +0,0 @@
from django.db import models
from django.contrib.auth import get_user_model
from utils.shortcuts import rand_str
from problem.models import Problem
User = get_user_model()
class FlowchartSubmissionStatus:
PENDING = 0 # 等待AI评分
PROCESSING = 1 # AI评分中
COMPLETED = 2 # 评分完成
FAILED = 3 # 评分失败
class FlowchartSubmission(models.Model):
"""流程图提交模型"""
id = models.TextField(default=rand_str, primary_key=True, db_index=True)
# 基础信息
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='flowchart_submissions')
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, related_name='flowchart_submissions')
# 提交内容
mermaid_code = models.TextField() # Mermaid代码
flowchart_data = models.JSONField(default=dict) # 流程图元数据
# 状态信息
status = models.IntegerField(default=FlowchartSubmissionStatus.PENDING)
create_time = models.DateTimeField(auto_now_add=True)
# AI评分结果
ai_score = models.FloatField(null=True, blank=True) # AI评分 (0-100)
ai_grade = models.CharField(max_length=10, null=True, blank=True) # 等级 (S/A/B/C)
ai_feedback = models.TextField(null=True, blank=True) # AI反馈
ai_suggestions = models.TextField(null=True, blank=True) # AI建议
ai_criteria_details = models.JSONField(default=dict) # 详细评分标准
# 处理信息
ai_provider = models.CharField(max_length=50, default='deepseek')
ai_model = models.CharField(max_length=50, default='deepseek-chat')
processing_time = models.FloatField(null=True, blank=True) # AI处理耗时(秒)
evaluation_time = models.DateTimeField(null=True, blank=True) # 评分完成时间
class Meta:
db_table = 'flowchart_submission'
ordering = ['-create_time']
indexes = [
models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'),
models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'),
models.Index(fields=['status'], name='flowchart_status_idx'),
]
def __str__(self):
return f"FlowchartSubmission {self.id}"
def check_user_permission(self, user, check_share=True):
"""检查用户权限"""
if (
self.user_id == user.id
or not user.is_regular_user()
or self.problem.created_by_id == user.id
):
return True
return False

View File

@@ -1,91 +0,0 @@
from rest_framework import serializers
from .models import FlowchartSubmission
class CreateFlowchartSubmissionSerializer(serializers.Serializer):
problem_id = serializers.IntegerField()
mermaid_code = serializers.CharField()
flowchart_data = serializers.JSONField(required=False, default=dict)
def validate_mermaid_code(self, value):
if not value.strip():
raise serializers.ValidationError("Mermaid代码不能为空")
return value
class FlowchartSubmissionSerializer(serializers.ModelSerializer):
class Meta:
model = FlowchartSubmission
fields = [
"id",
"user",
"problem",
"mermaid_code",
"flowchart_data",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_feedback",
"ai_suggestions",
"ai_criteria_details",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]
read_only_fields = ["id", "create_time", "evaluation_time"]
class FlowchartSubmissionListSerializer(serializers.ModelSerializer):
"""用于列表显示的简化序列化器"""
username = serializers.CharField(source="user.username")
problem = serializers.CharField(source="problem._id")
problem_title = serializers.CharField(source="problem.title")
class Meta:
model = FlowchartSubmission
fields = [
"id",
"username",
"problem_title",
"problem",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]
class FlowchartSubmissionSummarySerializer(serializers.ModelSerializer):
"""用于AI详情页面的极简序列化器只包含必要字段"""
problem_title = serializers.CharField(source="problem.title")
problem__id = serializers.CharField(source="problem._id")
class Meta:
model = FlowchartSubmission
fields = [
"id",
"problem__id",
"problem_title",
"ai_score",
"ai_grade",
"create_time",
]
class FlowchartSubmissionMergedSerializer(serializers.Serializer):
"""合并后的流程图提交序列化器"""
problem__id = serializers.CharField()
problem_title = serializers.CharField()
submission_count = serializers.IntegerField()
best_score = serializers.FloatField()
best_grade = serializers.CharField()
latest_submission_time = serializers.DateTimeField()
avg_score = serializers.FloatField()

View File

@@ -1,184 +0,0 @@
import dramatiq
import json
import time
from openai import OpenAI
from django.db import transaction
from django.utils import timezone
from utils.shortcuts import get_env, DRAMATIQ_WORKER_ARGS
from .models import FlowchartSubmission, FlowchartSubmissionStatus
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3))
def evaluate_flowchart_task(submission_id):
"""异步AI评分任务"""
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
# 更新状态为处理中
submission.status = FlowchartSubmissionStatus.PROCESSING
submission.save()
start_time = time.time()
# 使用固定评分标准
system_prompt = build_evaluation_prompt(submission.problem)
# 构建用户提示词,包含标准答案对比
user_prompt = f"""
请对以下Mermaid流程图进行评分
学生提交的流程图:
```mermaid
{submission.mermaid_code}
```
标准答案参考:
```mermaid
{submission.problem.mermaid_code}
```
"""
# 如果有流程图提示,添加到提示词中
if submission.problem.flowchart_hint:
user_prompt += f"""
设计提示:{submission.problem.flowchart_hint}
"""
user_prompt += """
请按照评分标准进行详细评估并给出0-100的分数。
"""
# 调用AI进行评分
api_key = get_env("AI_KEY")
if not api_key:
raise Exception("AI_KEY is not set")
client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3,
)
ai_response = response.choices[0].message.content
score_data = parse_ai_evaluation_response(ai_response)
processing_time = time.time() - start_time
# 保存评分结果
with transaction.atomic():
submission.ai_score = score_data['score']
submission.ai_grade = score_data['grade']
submission.ai_feedback = score_data['feedback']
submission.ai_suggestions = score_data.get('suggestions', '')
submission.ai_criteria_details = score_data.get('criteria_details', {})
submission.ai_provider = 'deepseek'
submission.ai_model = 'deepseek-chat'
submission.processing_time = processing_time
submission.status = FlowchartSubmissionStatus.COMPLETED
submission.evaluation_time = timezone.now()
submission.save()
# 推送评分完成通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_completed",
"score": score_data['score'],
"grade": score_data['grade'],
}
)
except Exception as e:
# 处理失败
submission.status = FlowchartSubmissionStatus.FAILED
submission.save()
# 推送错误通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_failed",
"submission_id": str(submission.id),
"error": str(e)
}
)
raise e
def build_evaluation_prompt(problem):
"""构建AI评分提示词 - 使用固定标准"""
# 使用固定的评分标准
criteria_text = """
- 逻辑正确性 (权重: 1.0, 最高分: 40): 检查流程图的逻辑是否正确,包括条件判断、循环结构等
- 完整性 (权重: 0.8, 最高分: 30): 检查流程图是否包含所有必要的步骤和分支
- 规范性 (权重: 0.6, 最高分: 20): 检查流程图符号使用是否规范,是否符合标准
- 清晰度 (权重: 0.4, 最高分: 10): 评估流程图的整体布局和连线情况(不用考虑节点ID是否复杂)
"""
return f"""
你是一个专业的编程教学助手负责评估学生提交的Mermaid流程图。
评分标准:
{criteria_text}
评分要求:
1. 仔细分析流程图的逻辑正确性、完整性和清晰度
2. 检查是否涵盖了题目的所有要求
3. 评估流程图的规范性和可读性(不用考虑节点ID是否复杂)
4. 给出0-100的分数
5. 提供详细的反馈和改进建议
评分等级:
- S级 (90-100分): 优秀,逻辑清晰,完全符合要求
- A级 (80-89分): 良好,基本符合要求,有少量改进空间
- B级 (70-79分): 及格,基本正确但存在一些问题
- C级 (0-69分): 需要改进,存在明显问题
请以JSON格式返回评分结果
{{
"score": 85,
"grade": "A",
"feedback": "详细的反馈内容",
"suggestions": "改进建议",
"criteria_details": {{
"逻辑正确性": {{"score": 35, "max": 40, "comment": "逻辑基本正确"}},
"完整性": {{"score": 25, "max": 30, "comment": "缺少部分步骤"}},
"规范性": {{"score": 18, "max": 20, "comment": "符号使用规范"}},
"清晰度": {{"score": 8, "max": 10, "comment": "布局清晰"}}
}}
}}
"""
def parse_ai_evaluation_response(ai_response):
"""解析AI评分响应"""
try:
import re
json_match = re.search(r'\{.*\}', ai_response, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
else:
data = {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}
return data
except Exception:
return {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}

View File

@@ -1 +0,0 @@
# URLs package

View File

@@ -1,14 +0,0 @@
from django.urls import path
from ..views.oj import (
FlowchartSubmissionAPI,
FlowchartSubmissionListAPI,
FlowchartSubmissionRetryAPI,
FlowchartSubmissionCurrentAPI
)
urlpatterns = [
path('flowchart/submission', FlowchartSubmissionAPI.as_view()),
path('flowchart/submissions', FlowchartSubmissionListAPI.as_view()),
path('flowchart/submission/retry', FlowchartSubmissionRetryAPI.as_view()),
path('flowchart/submission/current', FlowchartSubmissionCurrentAPI.as_view()),
]

View File

@@ -1,3 +0,0 @@
from django.shortcuts import render
# Create your views here.

View File

@@ -1 +0,0 @@
# Views package

View File

@@ -1,163 +0,0 @@
from utils.api import APIView
from account.decorators import login_required
from flowchart.models import FlowchartSubmission, FlowchartSubmissionStatus
from flowchart.serializers import (
CreateFlowchartSubmissionSerializer,
FlowchartSubmissionSerializer,
FlowchartSubmissionListSerializer,
)
from flowchart.tasks import evaluate_flowchart_task
from problem.models import Problem
class FlowchartSubmissionAPI(APIView):
@login_required
def post(self, request):
"""创建流程图提交"""
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
if not serializer.is_valid():
return self.error(serializer.errors)
data = serializer.validated_data
# 验证题目存在
try:
from problem.models import Problem
problem = Problem.objects.get(id=data["problem_id"])
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
# 验证题目是否允许流程图提交
if not problem.allow_flowchart:
return self.error("This problem does not allow flowchart submission")
# 创建提交记录
submission = FlowchartSubmission.objects.create(
user=request.user,
problem=problem,
mermaid_code=data["mermaid_code"],
flowchart_data=data.get("flowchart_data", {}),
)
# 启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success({"submission_id": submission.id, "status": "pending"})
@login_required
def get(self, request):
"""获取流程图提交详情"""
submission_id = request.GET.get("id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
serializer = FlowchartSubmissionSerializer(submission)
return self.success(serializer.data)
class FlowchartSubmissionListAPI(APIView):
def get(self, request):
"""获取流程图提交列表"""
username = request.GET.get("username")
problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself")
queryset = FlowchartSubmission.objects.select_related("user", "problem")
if problem_id:
try:
problem = Problem.objects.get(
_id=problem_id, contest_id__isnull=True, visible=True
)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
queryset = queryset.filter(problem=problem)
if myself and myself == "1":
queryset = queryset.filter(user=request.user)
if username:
queryset = queryset.filter(user__username__icontains=username)
data = self.paginate_data(request, queryset)
data["results"] = FlowchartSubmissionListSerializer(
data["results"], many=True
).data
return self.success(data)
class FlowchartSubmissionRetryAPI(APIView):
@login_required
def post(self, request):
"""重新触发AI评分"""
submission_id = request.data.get("submission_id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
# 检查权限
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
# 检查是否可以重新评分
if submission.status not in [
FlowchartSubmissionStatus.FAILED,
FlowchartSubmissionStatus.COMPLETED,
]:
return self.error("Submission is not in a state that allows retry")
# 重置状态并重新启动AI评分
submission.status = FlowchartSubmissionStatus.PENDING
submission.ai_score = None
submission.ai_grade = None
submission.ai_feedback = None
submission.ai_suggestions = None
submission.ai_criteria_details = {}
submission.processing_time = None
submission.evaluation_time = None
submission.save()
# 重新启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success(
{
"submission_id": submission.id,
"status": "pending",
"message": "AI evaluation restarted",
}
)
class FlowchartSubmissionCurrentAPI(APIView):
@login_required
def get(self, request):
"""获取当前用户对指定题目的最新流程图提交"""
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
submissions = FlowchartSubmission.objects.filter(
user=request.user, problem=problem
).order_by("-create_time")
count = submissions.count()
if count == 0:
return self.success({"submission": None, "count": 0})
first_submission = submissions[0]
serializer = FlowchartSubmissionSerializer(first_submission)
return self.success({"submission": serializer.data, "count": count})

View File

@@ -16,7 +16,6 @@ from problem.utils import parse_problem_template
from submission.models import JudgeStatus, Submission from submission.models import JudgeStatus, Submission
from utils.cache import cache from utils.cache import cache
from utils.constants import CacheKey from utils.constants import CacheKey
from utils.websocket import push_submission_update
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -67,6 +66,26 @@ class DispatcherBase(object):
logger.exception(e) logger.exception(e)
class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language):
super().__init__()
spj_compile_config = list(filter(lambda config: spj_language == config["name"], SysOptions.spj_languages))[0]["spj"][
"compile"]
self.data = {
"src": spj_code,
"spj_version": spj_version,
"spj_compile_config": spj_compile_config
}
def compile_spj(self):
with ChooseJudgeServer() as server:
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
if not result:
return "Failed to call judge server"
if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase): class JudgeDispatcher(DispatcherBase):
@@ -106,6 +125,12 @@ class JudgeDispatcher(DispatcherBase):
def judge(self): def judge(self):
language = self.submission.language language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0] sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {}
if self.problem.spj_code:
for lang in SysOptions.spj_languages:
if lang["name"] == self.problem.spj_language:
spj_config = lang["spj"]
break
if language in self.problem.template: if language in self.problem.template:
template = parse_problem_template(self.problem.template[language]) template = parse_problem_template(self.problem.template[language])
@@ -120,6 +145,10 @@ class JudgeDispatcher(DispatcherBase):
"max_memory": 1024 * 1024 * self.problem.memory_limit, "max_memory": 1024 * 1024 * self.problem.memory_limit,
"test_case_id": self.problem.test_case_id, "test_case_id": self.problem.test_case_id,
"output": False, "output": False,
"spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code,
"io_mode": self.problem.io_mode "io_mode": self.problem.io_mode
} }
@@ -127,56 +156,12 @@ class JudgeDispatcher(DispatcherBase):
if not server: if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id} data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data)) cache.lpush(CacheKey.waiting_queue, json.dumps(data))
# 推送排队状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.PENDING,
"status": "pending",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
return return
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
# 推送判题中状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.JUDGING,
"status": "judging",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
resp = self._request(urljoin(server.service_url, "/judge"), data=data) resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if not resp: if not resp:
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR) Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR)
# 推送系统错误状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.SYSTEM_ERROR,
"status": "error",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
return return
if resp["err"]: if resp["err"]:
@@ -197,24 +182,6 @@ class JudgeDispatcher(DispatcherBase):
else: else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save() self.submission.save()
# 推送判题完成状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": self.submission.result,
"status": "finished",
"time_cost": self.submission.statistic_info.get("time_cost"),
"memory_cost": self.submission.statistic_info.get("memory_cost"),
"score": self.submission.statistic_info.get("score", 0),
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
if self.contest_id: if self.contest_id:
if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \ if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \

View File

@@ -35,6 +35,20 @@ int main() {
} }
} }
_c_lang_spj_compile = {
"src_name": "spj-{spj_version}.c",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 3000,
"max_real_time": 10000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}"
}
_c_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_cpp_lang_config = { _cpp_lang_config = {
"template": """//PREPEND BEGIN "template": """//PREPEND BEGIN
@@ -68,6 +82,20 @@ int main() {
} }
} }
_cpp_lang_spj_compile = {
"src_name": "spj-{spj_version}.cpp",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 10000,
"max_real_time": 20000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}"
}
_cpp_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_java_lang_config = { _java_lang_config = {
"template": """//PREPEND BEGIN "template": """//PREPEND BEGIN
@@ -196,8 +224,10 @@ console.log(add(1, 2))
} }
languages = [ languages = [
{"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc"}, {"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc",
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src"}, "spj": {"compile": _c_lang_spj_compile, "config": _c_lang_spj_config}},
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src",
"spj": {"compile": _cpp_lang_spj_compile, "config": _cpp_lang_spj_config}},
{"config": _java_lang_config, "name": "Java", "description": "Temurin 21", "content_type": "text/x-java"}, {"config": _java_lang_config, "name": "Java", "description": "Temurin 21", "content_type": "text/x-java"},
{"config": _py3_lang_config, "name": "Python3", "description": "Python 3.12", "content_type": "text/x-python"}, {"config": _py3_lang_config, "name": "Python3", "description": "Python 3.12", "content_type": "text/x-python"},
{"config": _go_lang_config, "name": "Golang", "description": "Golang 1.22", "content_type": "text/x-go"}, {"config": _go_lang_config, "name": "Golang", "description": "Golang 1.22", "content_type": "text/x-go"},

View File

@@ -1,30 +1,7 @@
""" import os
ASGI config for oj project.
from django.core.asgi import get_asgi_application
It exposes the ASGI callable as a module-level variable named ``application``.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
For more information on this file, see
https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/ application = get_asgi_application()
"""
import os
from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
# Initialize Django ASGI application early to ensure the AppRegistry
# is populated before importing code that may import ORM models.
django_asgi_app = get_asgi_application()
# Import routing after Django setup
from oj.routing import websocket_urlpatterns
application = ProtocolTypeRouter(
{
"http": django_asgi_app,
"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns)),
}
)

View File

@@ -1,22 +1,19 @@
# coding=utf-8 # coding=utf-8
import os import os
from utils.shortcuts import get_env
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = { DATABASES = {
"default": { "default": {
"ENGINE": "django.db.backends.postgresql", "ENGINE": "django.db.backends.sqlite3",
"HOST": "150.158.29.156", "NAME": os.path.join(BASE_DIR, "db.sqlite3"),
"PORT": "5455",
"NAME": "onlinejudge",
"USER": "onlinejudge",
"PASSWORD": "onlinejudge",
} }
} }
REDIS_CONF = { REDIS_CONF = {
"host": "150.158.29.156", "host": get_env("REDIS_HOST", "127.0.0.1"),
"port": 5456, "port": get_env("REDIS_PORT", "6380"),
} }

View File

@@ -1,15 +0,0 @@
"""
WebSocket URL Configuration for oj project.
"""
from django.urls import path
from submission.consumers import SubmissionConsumer
from conf.consumers import ConfigConsumer
from flowchart.consumers import FlowchartConsumer
websocket_urlpatterns = [
path("ws/submission/", SubmissionConsumer.as_asgi()),
path("ws/config/", ConfigConsumer.as_asgi()),
path("ws/flowchart/", FlowchartConsumer.as_asgi()),
]

View File

@@ -28,14 +28,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Applications # Applications
VENDOR_APPS = [ VENDOR_APPS = [
"daphne", # Channels ASGI server - must be first
"django.contrib.auth", "django.contrib.auth",
"django.contrib.sessions", "django.contrib.sessions",
"django.contrib.contenttypes", "django.contrib.contenttypes",
"django.contrib.messages", "django.contrib.messages",
"django.contrib.staticfiles", "django.contrib.staticfiles",
"rest_framework", "rest_framework",
"channels",
"django_dramatiq", "django_dramatiq",
"django_dbconn_retry", "django_dbconn_retry",
] ]
@@ -57,9 +55,6 @@ LOCAL_APPS = [
"message", "message",
"comment", "comment",
"tutorial", "tutorial",
"ai",
"flowchart",
"problemset",
] ]
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
@@ -96,9 +91,6 @@ TEMPLATES = [
] ]
WSGI_APPLICATION = "oj.wsgi.application" WSGI_APPLICATION = "oj.wsgi.application"
# ASGI Application for WebSocket support
ASGI_APPLICATION = "oj.asgi.application"
# Password validation # Password validation
# https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators # https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators
@@ -120,9 +112,13 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/1.8/topics/i18n/ # https://docs.djangoproject.com/en/1.8/topics/i18n/
LANGUAGE_CODE = "zh-cn" LANGUAGE_CODE = "en-us"
TIME_ZONE = "Asia/Shanghai" TIME_ZONE = "UTC"
USE_I18N = True
USE_L10N = True
USE_TZ = True USE_TZ = True
@@ -214,23 +210,12 @@ def redis_config(db):
} }
CACHES = {"default": redis_config(db=1)} if production_env:
CACHES = {"default": redis_config(db=1)}
SESSION_ENGINE = "django.contrib.sessions.backends.cache" SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default" SESSION_CACHE_ALIAS = "default"
# Channels Configuration
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": [(REDIS_CONF["host"], REDIS_CONF["port"])],
"capacity": 1500, # 每个频道的最大消息数
"expiry": 10, # 消息过期时间(秒)
},
},
}
DRAMATIQ_BROKER = { DRAMATIQ_BROKER = {
"BROKER": "dramatiq.brokers.redis.RedisBroker", "BROKER": "dramatiq.brokers.redis.RedisBroker",
"OPTIONS": { "OPTIONS": {

View File

@@ -19,8 +19,4 @@ urlpatterns = [
path("api/admin/", include("comment.urls.admin")), path("api/admin/", include("comment.urls.admin")),
path("api/", include("tutorial.urls.tutorial")), path("api/", include("tutorial.urls.tutorial")),
path("api/admin/", include("tutorial.urls.admin")), path("api/admin/", include("tutorial.urls.admin")),
path("api/", include("ai.urls.oj")),
path("api/", include("flowchart.urls.oj")),
path("api/", include("problemset.urls.oj")),
path("api/admin/", include("problemset.urls.admin")),
] ]

View File

@@ -104,7 +104,6 @@ class OptionKeys:
judge_server_token = "judge_server_token" judge_server_token = "judge_server_token"
throttling = "throttling" throttling = "throttling"
languages = "languages" languages = "languages"
enable_maxkb = "enable_maxkb"
class OptionDefaultValue: class OptionDefaultValue:
@@ -120,7 +119,6 @@ class OptionDefaultValue:
throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50}, throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50},
"user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}} "user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}}
languages = languages languages = languages
enable_maxkb = True
class _SysOptionsMeta(type): class _SysOptionsMeta(type):
@@ -273,18 +271,17 @@ class _SysOptionsMeta(type):
def languages(cls, value): def languages(cls, value):
cls._set_option(OptionKeys.languages, value) cls._set_option(OptionKeys.languages, value)
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_languages(cls):
return [item for item in cls.languages if "spj" in item]
@my_property(ttl=DEFAULT_SHORT_TTL) @my_property(ttl=DEFAULT_SHORT_TTL)
def language_names(cls): def language_names(cls):
return [item["name"] for item in cls.languages] return [item["name"] for item in cls.languages]
@my_property(ttl=DEFAULT_SHORT_TTL) @my_property(ttl=DEFAULT_SHORT_TTL)
def enable_maxkb(cls): def spj_language_names(cls):
return cls._get_option(OptionKeys.enable_maxkb) return [item["name"] for item in cls.languages if "spj" in item]
@enable_maxkb.setter
def enable_maxkb(cls, value):
cls._set_option(OptionKeys.enable_maxkb, value)
def reset_languages(cls): def reset_languages(cls):
cls.languages = languages cls.languages = languages

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-03 16:31
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='problem',
name='prompt',
field=models.TextField(null=True),
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-03 16:56
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0002_problem_prompt'),
]
operations = [
migrations.AddField(
model_name='problem',
name='answers',
field=models.JSONField(null=True),
),
]

View File

@@ -1,38 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0003_problem_answers'),
]
operations = [
migrations.AddField(
model_name='problem',
name='allow_flowchart',
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name='problem',
name='flowchart_data',
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name='problem',
name='flowchart_hint',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='mermaid_code',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='show_flowchart',
field=models.BooleanField(default=False),
),
]

View File

@@ -1,33 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 15:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
]
operations = [
migrations.RemoveField(
model_name='problem',
name='spj',
),
migrations.RemoveField(
model_name='problem',
name='spj_code',
),
migrations.RemoveField(
model_name='problem',
name='spj_compile_ok',
),
migrations.RemoveField(
model_name='problem',
name='spj_language',
),
migrations.RemoveField(
model_name='problem',
name='spj_version',
),
]

View File

@@ -1,4 +1,6 @@
from django.conf import settings
from django.db import models from django.db import models
from utils.models import JSONField
from account.models import User from account.models import User
from contest.models import Contest from contest.models import Contest
@@ -49,13 +51,13 @@ class Problem(models.Model):
input_description = RichTextField() input_description = RichTextField()
output_description = RichTextField() output_description = RichTextField()
# [{input: "test", output: "123"}, {input: "test123", output: "456"}] # [{input: "test", output: "123"}, {input: "test123", output: "456"}]
samples = models.JSONField() samples = JSONField()
test_case_id = models.TextField() test_case_id = models.TextField()
# [{"input_name": "1.in", "output_name": "1.out", "score": 0}] # [{"input_name": "1.in", "output_name": "1.out", "score": 0}]
test_case_score = models.JSONField() test_case_score = JSONField()
hint = RichTextField(null=True) hint = RichTextField(null=True)
languages = models.JSONField() languages = JSONField()
template = models.JSONField() template = JSONField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
# we can not use auto_now here # we can not use auto_now here
last_update_time = models.DateTimeField(auto_now=True, null=True) last_update_time = models.DateTimeField(auto_now=True, null=True)
@@ -65,29 +67,25 @@ class Problem(models.Model):
# MB # MB
memory_limit = models.IntegerField() memory_limit = models.IntegerField()
# io mode # io mode
io_mode = models.JSONField(default=_default_io_mode) io_mode = JSONField(default=_default_io_mode)
# special judge related
spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True)
spj_code = models.TextField(null=True)
spj_version = models.TextField(null=True)
spj_compile_ok = models.BooleanField(default=False)
rule_type = models.TextField() rule_type = models.TextField()
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)
difficulty = models.TextField() difficulty = models.TextField()
tags = models.ManyToManyField(ProblemTag) tags = models.ManyToManyField(ProblemTag)
source = models.TextField(null=True) source = models.TextField(null=True)
prompt = models.TextField(null=True)
# [{language: "python", code: "..."}]
answers = models.JSONField(null=True)
# for OI mode # for OI mode
total_score = models.IntegerField(default=0) total_score = models.IntegerField(default=0)
submission_number = models.BigIntegerField(default=0) submission_number = models.BigIntegerField(default=0)
accepted_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0)
# {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count # {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count
statistic_info = models.JSONField(default=dict) statistic_info = JSONField(default=dict)
share_submission = models.BooleanField(default=False) share_submission = models.BooleanField(default=False)
# 流程图相关字段
allow_flowchart = models.BooleanField(default=False) # 是否允许/需要提交流程图
mermaid_code = models.TextField(null=True, blank=True) # 流程图答案(Mermaid代码)
flowchart_data = models.JSONField(default=dict) # 流程图答案元数据(JSON格式)
flowchart_hint = models.TextField(null=True, blank=True) # 流程图提示信息
show_flowchart = models.BooleanField(default=False) # 是否显示流程图答案数据如果True这样就不需要提交流程图了说明就是给学生看的
class Meta: class Meta:
db_table = "problem" db_table = "problem"

View File

@@ -2,10 +2,12 @@ import re
from django import forms from django import forms
from options.options import SysOptions
from utils.api import UsernameSerializer, serializers from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty from utils.constants import Difficulty
from utils.serializers import ( from utils.serializers import (
LanguageNameMultiChoiceField, LanguageNameMultiChoiceField,
SPJLanguageNameChoiceField,
LanguageNameChoiceField, LanguageNameChoiceField,
) )
@@ -14,6 +16,7 @@ from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form): class TestCaseUploadForm(forms.Form):
spj = forms.CharField(max_length=12)
file = forms.FileField() file = forms.FileField()
@@ -22,11 +25,6 @@ class CreateSampleSerializer(serializers.Serializer):
output = serializers.CharField(trim_whitespace=False) output = serializers.CharField(trim_whitespace=False)
class CreateAnswerSerializer(serializers.Serializer):
language = serializers.CharField()
code = serializers.CharField()
class CreateTestCaseScoreSerializer(serializers.Serializer): class CreateTestCaseScoreSerializer(serializers.Serializer):
input_name = serializers.CharField(max_length=32) input_name = serializers.CharField(max_length=32)
output_name = serializers.CharField(max_length=32) output_name = serializers.CharField(max_length=32)
@@ -70,6 +68,10 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
choices=[ProblemRuleType.ACM, ProblemRuleType.OI] choices=[ProblemRuleType.ACM, ProblemRuleType.OI]
) )
io_mode = ProblemIOModeSerializer() io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField() visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=Difficulty.choices()) difficulty = serializers.ChoiceField(choices=Difficulty.choices())
tags = serializers.ListField( tags = serializers.ListField(
@@ -77,25 +79,8 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
) )
hint = serializers.CharField(allow_blank=True, allow_null=True) hint = serializers.CharField(allow_blank=True, allow_null=True)
source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True)
prompt = serializers.CharField(allow_blank=True, allow_null=True)
answers = serializers.ListField(
child=CreateAnswerSerializer(),
allow_empty=True,
allow_null=True,
)
share_submission = serializers.BooleanField() share_submission = serializers.BooleanField()
# 流程图相关字段
allow_flowchart = serializers.BooleanField(required=False, default=False)
show_flowchart = serializers.BooleanField(required=False, default=False)
mermaid_code = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
flowchart_hint = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
class CreateProblemSerializer(CreateOrEditProblemSerializer): class CreateProblemSerializer(CreateOrEditProblemSerializer):
pass pass
@@ -120,6 +105,11 @@ class TagSerializer(serializers.ModelSerializer):
fields = "__all__" fields = "__all__"
class CompileSPJSerializer(serializers.Serializer):
spj_language = SPJLanguageNameChoiceField()
spj_code = serializers.CharField()
class BaseProblemSerializer(serializers.ModelSerializer): class BaseProblemSerializer(serializers.ModelSerializer):
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True) tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
created_by = UsernameSerializer() created_by = UsernameSerializer()
@@ -145,8 +135,6 @@ class ProblemAdminListSerializer(BaseProblemSerializer):
class ProblemSerializer(BaseProblemSerializer): class ProblemSerializer(BaseProblemSerializer):
template = serializers.SerializerMethodField("get_public_template") template = serializers.SerializerMethodField("get_public_template")
mermaid_code = serializers.SerializerMethodField()
flowchart_data = serializers.SerializerMethodField()
class Meta: class Meta:
model = Problem model = Problem
@@ -155,21 +143,11 @@ class ProblemSerializer(BaseProblemSerializer):
"test_case_id", "test_case_id",
"visible", "visible",
"is_public", "is_public",
"answers", "spj_code",
"spj_version",
"spj_compile_ok",
) )
def get_mermaid_code(self, obj):
# 当 allow_flowchart 为 True 时,不返回 mermaid_code
if obj.allow_flowchart:
return None
return obj.mermaid_code
def get_flowchart_data(self, obj):
# 当 allow_flowchart 为 True 时,不返回 flowchart_data
if obj.allow_flowchart:
return None
return obj.flowchart_data
class ProblemListSerializer(BaseProblemSerializer): class ProblemListSerializer(BaseProblemSerializer):
class Meta: class Meta:
@@ -184,14 +162,12 @@ class ProblemListSerializer(BaseProblemSerializer):
"created_by", "created_by",
"tags", "tags",
"contest", "contest",
"allow_flowchart", "rule_type",
] ]
class ProblemSafeSerializer(BaseProblemSerializer): class ProblemSafeSerializer(BaseProblemSerializer):
template = serializers.SerializerMethodField("get_public_template") template = serializers.SerializerMethodField("get_public_template")
mermaid_code = serializers.SerializerMethodField()
flowchart_data = serializers.SerializerMethodField()
class Meta: class Meta:
model = Problem model = Problem
@@ -200,36 +176,116 @@ class ProblemSafeSerializer(BaseProblemSerializer):
"test_case_id", "test_case_id",
"visible", "visible",
"is_public", "is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"difficulty", "difficulty",
"submission_number", "submission_number",
"accepted_number", "accepted_number",
"statistic_info", "statistic_info",
"answers",
) )
def get_mermaid_code(self, obj):
# 当 allow_flowchart 为 True 时,不返回 mermaid_code
if obj.allow_flowchart:
return None
return obj.mermaid_code
def get_flowchart_data(self, obj):
# 当 allow_flowchart 为 True 时,不返回 flowchart_data
if obj.allow_flowchart:
return None
return obj.flowchart_data
class ContestProblemMakePublicSerializer(serializers.Serializer): class ContestProblemMakePublicSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
display_id = serializers.CharField(max_length=32) display_id = serializers.CharField(max_length=32)
class ExportProblemSerializer(serializers.ModelSerializer):
display_id = serializers.SerializerMethodField()
description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField()
source = serializers.SerializerMethodField()
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
def get_display_id(self, obj):
return obj._id
def _html_format_value(self, value):
return {"format": "html", "value": value}
def get_description(self, obj):
return self._html_format_value(obj.description)
def get_input_description(self, obj):
return self._html_format_value(obj.input_description)
def get_output_description(self, obj):
return self._html_format_value(obj.output_description)
def get_hint(self, obj):
return self._html_format_value(obj.hint)
def get_test_case_score(self, obj):
return [
{
"score": item["score"] if obj.rule_type == ProblemRuleType.OI else 100,
"input_name": item["input_name"],
"output_name": item["output_name"],
}
for item in obj.test_case_score
]
def get_spj(self, obj):
return {"code": obj.spj_code, "language": obj.spj_language} if obj.spj else None
def get_template(self, obj):
ret = {}
for k, v in obj.template.items():
ret[k] = parse_problem_template(v)
return ret
def get_source(self, obj):
return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}"
class Meta:
model = Problem
fields = (
"display_id",
"title",
"description",
"tags",
"input_description",
"output_description",
"test_case_score",
"hint",
"time_limit",
"memory_limit",
"samples",
"template",
"spj",
"rule_type",
"source",
"template",
)
class AddContestProblemSerializer(serializers.Serializer): class AddContestProblemSerializer(serializers.Serializer):
contest_id = serializers.IntegerField() contest_id = serializers.IntegerField()
problem_id = serializers.IntegerField() problem_id = serializers.IntegerField()
display_id = serializers.CharField() display_id = serializers.CharField()
class ExportProblemRequestSerializer(serializers.Serializer):
problem_id = serializers.ListField(
child=serializers.IntegerField(), allow_empty=False
)
class UploadProblemForm(forms.Form):
file = forms.FileField()
class FormatValueSerializer(serializers.Serializer):
format = serializers.ChoiceField(choices=["html", "markdown"])
value = serializers.CharField(allow_blank=True)
class TestCaseScoreSerializer(serializers.Serializer): class TestCaseScoreSerializer(serializers.Serializer):
score = serializers.IntegerField(min_value=1) score = serializers.IntegerField(min_value=1)
input_name = serializers.CharField(max_length=32) input_name = serializers.CharField(max_length=32)
@@ -242,6 +298,58 @@ class TemplateSerializer(serializers.Serializer):
append = serializers.CharField() append = serializers.CharField()
class SPJSerializer(serializers.Serializer):
code = serializers.CharField()
language = SPJLanguageNameChoiceField()
class AnswerSerializer(serializers.Serializer): class AnswerSerializer(serializers.Serializer):
code = serializers.CharField() code = serializers.CharField()
language = LanguageNameChoiceField() language = LanguageNameChoiceField()
class ImportProblemSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=128)
title = serializers.CharField(max_length=128)
description = FormatValueSerializer()
input_description = FormatValueSerializer()
output_description = FormatValueSerializer()
hint = FormatValueSerializer()
test_case_score = serializers.ListField(
child=TestCaseScoreSerializer(), allow_null=True
)
time_limit = serializers.IntegerField(min_value=1, max_value=60000)
memory_limit = serializers.IntegerField(min_value=1, max_value=10240)
samples = serializers.ListField(child=CreateSampleSerializer())
template = serializers.DictField(child=TemplateSerializer())
spj = SPJSerializer(allow_null=True)
rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
answers = serializers.ListField(child=AnswerSerializer())
tags = serializers.ListField(child=serializers.CharField())
class FPSProblemSerializer(serializers.Serializer):
class UnitSerializer(serializers.Serializer):
unit = serializers.ChoiceField(choices=["MB", "s", "ms"])
value = serializers.IntegerField(min_value=1, max_value=60000)
title = serializers.CharField(max_length=128)
description = serializers.CharField()
input = serializers.CharField()
output = serializers.CharField()
hint = serializers.CharField(allow_blank=True, allow_null=True)
time_limit = UnitSerializer()
memory_limit = UnitSerializer()
samples = serializers.ListField(child=CreateSampleSerializer())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
spj = SPJSerializer(allow_null=True)
template = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
append = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
prepend = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)

324
problem/tests.py Normal file
View File

@@ -0,0 +1,324 @@
import copy
import hashlib
import os
import shutil
from datetime import timedelta
from zipfile import ZipFile
from django.conf import settings
from utils.api.tests import APITestCase
from .models import ProblemTag, ProblemIOMode
from .models import Problem, ProblemRuleType
from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA
from .views.admin import TestCaseAPI
from .utils import parse_problem_template
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"},
"share_submission": False,
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
raise ValueError("invalid score")
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = created_by
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
for item in tags:
try:
tag = ProblemTag.objects.get(name=item)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
return problem
class ProblemTagListAPITest(APITestCase):
def test_get_tag_list(self):
ProblemTag.objects.create(name="name1")
ProblemTag.objects.create(name="name2")
resp = self.client.get(self.reverse("problem_tag_list_api"))
self.assertSuccess(resp)
class TestCaseUploadAPITest(APITestCase):
def setUp(self):
self.api = TestCaseAPI()
self.url = self.reverse("test_case_api")
self.create_super_admin()
def test_filter_file_name(self):
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), [])
def make_test_case_zip(self):
base_dir = os.path.join("/tmp", "test_case")
shutil.rmtree(base_dir, ignore_errors=True)
os.mkdir(base_dir)
file_names = ["1.in", "1.out", "2.in", ".DS_Store"]
for item in file_names:
with open(os.path.join(base_dir, item), "w", encoding="utf-8") as f:
f.write(item + "\n" + item + "\r\n" + "end")
zip_file = os.path.join(base_dir, "test_case.zip")
with ZipFile(os.path.join(base_dir, "test_case.zip"), "w") as f:
for item in file_names:
f.write(os.path.join(base_dir, item), item)
return zip_file
def test_upload_spj_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "true", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], True)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
def test_upload_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "false", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], False)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
class ProblemAdminAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("problem_admin_api")
self.create_super_admin()
self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
def test_create_problem(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
return resp
def test_duplicate_display_id(self):
self.test_create_problem()
resp = self.client.post(self.url, data=self.data)
self.assertFailed(resp, "Display ID already exists")
def test_spj(self):
data = copy.deepcopy(self.data)
data["spj"] = True
resp = self.client.post(self.url, data)
self.assertFailed(resp, "Invalid spj")
data["spj_code"] = "test"
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
def test_get_problem(self):
self.test_create_problem()
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_get_one_problem(self):
problem_id = self.test_create_problem().data["data"]["id"]
resp = self.client.get(self.url + "?id=" + str(problem_id))
self.assertSuccess(resp)
def test_edit_problem(self):
problem_id = self.test_create_problem().data["data"]["id"]
data = copy.deepcopy(self.data)
data["id"] = problem_id
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
class ProblemAPITest(ProblemCreateTestBase):
def setUp(self):
self.url = self.reverse("problem_api")
admin = self.create_admin(login=False)
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.create_user("test", "test123")
def test_get_problem_list(self):
resp = self.client.get(f"{self.url}?limit=10")
self.assertSuccess(resp)
def get_one_problem(self):
resp = self.client.get(self.url + "?id=" + self.problem._id)
self.assertSuccess(resp)
class ContestProblemAdminTest(APITestCase):
def setUp(self):
self.url = self.reverse("contest_problem_admin_api")
self.create_admin()
self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
def test_create_contest_problem(self):
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = self.contest["id"]
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
return resp.data["data"]
def test_get_contest_problem(self):
self.test_create_contest_problem()
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["results"]), 1)
def test_get_one_contest_problem(self):
contest_problem = self.test_create_contest_problem()
contest_id = self.contest["id"]
problem_id = contest_problem["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp)
class ContestProblemTest(ProblemCreateTestBase):
def setUp(self):
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.problem.contest_id = self.contest["id"]
self.problem.save()
self.url = self.reverse("contest_problem_api")
def test_admin_get_contest_problem_list(self):
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1)
def test_admin_get_one_contest_problem(self):
contest_id = self.contest["id"]
problem_id = self.problem._id
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
self.assertSuccess(resp)
def test_regular_user_get_not_started_contest_problem(self):
self.create_user("test", "test123")
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."})
def test_reguar_user_get_started_contest_problem(self):
self.create_user("test", "test123")
contest = Contest.objects.first()
contest.start_time = contest.start_time - timedelta(hours=1)
contest.save()
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertSuccess(resp)
class AddProblemFromPublicProblemAPITest(ProblemCreateTestBase):
def setUp(self):
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.url = self.reverse("add_contest_problem_from_public_api")
self.data = {
"display_id": "1000",
"contest_id": self.contest["id"],
"problem_id": self.problem.id
}
def test_add_contest_problem(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
self.assertTrue(Problem.objects.all().exists())
self.assertTrue(Problem.objects.filter(contest_id=self.contest["id"]).exists())
class ParseProblemTemplateTest(APITestCase):
def test_parse(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//TEMPLATE BEGIN
bbb
//TEMPLATE END
//APPEND BEGIN
ccc
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "bbb\n")
self.assertEqual(ret["append"], "ccc\n")
def test_parse1(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//APPEND BEGIN
ccc
//APPEND END
//APPEND BEGIN
ddd
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "")
self.assertEqual(ret["append"], "ccc\n")

View File

@@ -1,19 +1,18 @@
from django.urls import path from django.urls import path
from ..views.admin import ( from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView,
ContestProblemAPI, CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI,
ProblemAPI, FPSProblemImport, ProblemVisibleAPI)
TestCaseAPI,
MakeContestProblemPublicAPIView,
AddContestProblemAPI,
ProblemVisibleAPI,
)
urlpatterns = [ urlpatterns = [
path("test_case", TestCaseAPI.as_view()), path("test_case", TestCaseAPI.as_view()),
path("compile_spj", CompileSPJAPI.as_view()),
path("problem", ProblemAPI.as_view()), path("problem", ProblemAPI.as_view()),
path("problem/visible", ProblemVisibleAPI.as_view()), path("problem/visible", ProblemVisibleAPI.as_view()),
path("contest/problem", ContestProblemAPI.as_view()), path("contest/problem", ContestProblemAPI.as_view()),
path("contest_problem/make_public", MakeContestProblemPublicAPIView.as_view()), path("contest_problem/make_public", MakeContestProblemPublicAPIView.as_view()),
path("contest/add_problem_from_public", AddContestProblemAPI.as_view()), path("contest/add_problem_from_public", AddContestProblemAPI.as_view()),
path("export_problem", ExportProblemAPI.as_view()),
path("import_problem", ImportProblemAPI.as_view()),
path("import_fps", FPSProblemImport.as_view()),
] ]

View File

@@ -6,14 +6,12 @@ from ..views.oj import (
ProblemAPI, ProblemAPI,
ContestProblemAPI, ContestProblemAPI,
PickOneAPI, PickOneAPI,
ProblemAuthorAPI,
) )
urlpatterns = [ urlpatterns = [
path("problem/tags", ProblemTagAPI.as_view()), path("problem/tags", ProblemTagAPI.as_view()),
path("problem", ProblemAPI.as_view()), path("problem", ProblemAPI.as_view()),
path("problem/beat_count", ProblemSolvedPeopleCount.as_view()), path("problem/beat_count", ProblemSolvedPeopleCount.as_view()),
path("problem/author", ProblemAuthorAPI.as_view()),
path("pickone", PickOneAPI.as_view()), path("pickone", PickOneAPI.as_view()),
path("contest/problem", ContestProblemAPI.as_view()), path("contest/problem", ContestProblemAPI.as_view()),
] ]

View File

@@ -1,42 +1,44 @@
import hashlib import hashlib
import json import json
import os import os
# import shutil # import shutil
import tempfile
import zipfile import zipfile
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
from django.conf import settings from django.conf import settings
from django.db import transaction
from django.db.models import Q from django.db.models import Q
from django.http import StreamingHttpResponse from django.http import StreamingHttpResponse, FileResponse
from account.decorators import problem_permission_required, ensure_created_by from account.decorators import problem_permission_required, ensure_created_by, super_admin_required
from contest.models import Contest, ContestStatus from contest.models import Contest, ContestStatus
from submission.models import Submission from fps.parser import FPSHelper, FPSParser
from judge.dispatcher import SPJCompiler
from options.options import SysOptions
from submission.models import Submission, JudgeStatus
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.constants import Difficulty
from utils.shortcuts import rand_str, natural_sort_key from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from ..models import Problem, ProblemRuleType, ProblemTag from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import ( from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
CreateContestProblemSerializer, CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
CreateProblemSerializer, ProblemAdminSerializer, ProblemAdminListSerializer, TestCaseUploadForm,
EditProblemSerializer, ContestProblemMakePublicSerializer, AddContestProblemSerializer, ExportProblemSerializer,
EditContestProblemSerializer, ExportProblemRequestSerializer, UploadProblemForm, ImportProblemSerializer,
ProblemAdminSerializer, FPSProblemSerializer)
ProblemAdminListSerializer, from ..utils import TEMPLATE_BASE, build_problem_template
TestCaseUploadForm,
ContestProblemMakePublicSerializer,
AddContestProblemSerializer,
)
class TestCaseZipProcessor(object): class TestCaseZipProcessor(object):
def process_zip(self, uploaded_zip_file, dir=""): def process_zip(self, uploaded_zip_file, spj, dir=""):
try: try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r") zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile: except zipfile.BadZipFile:
raise APIError("Bad zip file") raise APIError("Bad zip file")
name_list = zip_file.namelist() name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, dir=dir) test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
if not test_case_list: if not test_case_list:
raise APIError("Empty file") raise APIError("Empty file")
@@ -55,22 +57,26 @@ class TestCaseZipProcessor(object):
if item.endswith(".out"): if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content) f.write(content)
test_case_info = {"test_cases": {}} test_case_info = {"spj": spj, "test_cases": {}}
info = [] info = []
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] if spj:
test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) for index, item in enumerate(test_case_list):
for index, item in enumerate(test_case_list): data = {"input_name": item, "input_size": size_cache[item]}
data = { info.append(data)
"stripped_output_md5": md5_cache[item[1]], test_case_info["test_cases"][str(index + 1)] = data
"input_size": size_cache[item[0]], else:
"output_size": size_cache[item[1]], # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
"input_name": item[0], test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
"output_name": item[1], for index, item in enumerate(test_case_list):
} data = {"stripped_output_md5": md5_cache[item[1]],
info.append(data) "input_size": size_cache[item[0]],
test_case_info["test_cases"][str(index + 1)] = data "output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4)) f.write(json.dumps(test_case_info, indent=4))
@@ -80,19 +86,29 @@ class TestCaseZipProcessor(object):
return info, test_case_id return info, test_case_id
def filter_name_list(self, name_list, dir=""): def filter_name_list(self, name_list, spj, dir=""):
ret = [] ret = []
prefix = 1 prefix = 1
while True: if spj:
in_name = f"{prefix}.in" while True:
out_name = f"{prefix}.out" in_name = f"{prefix}.in"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list: if f"{dir}{in_name}" in name_list:
ret.append(in_name) ret.append(in_name)
ret.append(out_name) prefix += 1
prefix += 1 continue
continue else:
else: return sorted(ret, key=natural_sort_key)
return sorted(ret, key=natural_sort_key) else:
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor): class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
@@ -115,25 +131,23 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if not os.path.isdir(test_case_dir): if not os.path.isdir(test_case_dir):
return self.error("Test case does not exists") return self.error("Test case does not exists")
name_list = self.filter_name_list(os.listdir(test_case_dir)) name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
name_list.append("info") name_list.append("info")
file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip") file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
with zipfile.ZipFile(file_name, "w") as file: with zipfile.ZipFile(file_name, "w") as file:
for test_case in name_list: for test_case in name_list:
file.write(f"{test_case_dir}/{test_case}", test_case) file.write(f"{test_case_dir}/{test_case}", test_case)
response = StreamingHttpResponse( response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")),
FileWrapper(open(file_name, "rb")), content_type="application/octet-stream" content_type="application/octet-stream")
)
response["Content-Disposition"] = ( response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
f"attachment; filename=problem_{problem.id}_test_cases.zip"
)
response["Content-Length"] = os.path.getsize(file_name) response["Content-Length"] = os.path.getsize(file_name)
return response return response
def post(self, request): def post(self, request):
form = TestCaseUploadForm(request.POST, request.FILES) form = TestCaseUploadForm(request.POST, request.FILES)
if form.is_valid(): if form.is_valid():
spj = form.cleaned_data["spj"] == "true"
file = form.cleaned_data["file"] file = form.cleaned_data["file"]
else: else:
return self.error("Upload failed") return self.error("Upload failed")
@@ -141,14 +155,36 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
with open(zip_file, "wb") as f: with open(zip_file, "wb") as f:
for chunk in file: for chunk in file:
f.write(chunk) f.write(chunk)
info, test_case_id = self.process_zip(zip_file) info, test_case_id = self.process_zip(zip_file, spj=spj)
os.remove(zip_file) os.remove(zip_file)
return self.success({"id": test_case_id, "info": info}) return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
@validate_serializer(CompileSPJSerializer)
def post(self, request):
data = request.data
spj_version = rand_str(8)
error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj()
if error:
return self.error(error)
else:
return self.success()
class ProblemBase(APIView): class ProblemBase(APIView):
def common_checks(self, request): def common_checks(self, request):
data = request.data data = request.data
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return "Invalid spj"
if not data["spj_compile_ok"]:
return "SPJ code must be compiled successfully"
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI: if data["rule_type"] == ProblemRuleType.OI:
total_score = 0 total_score = 0
for item in data["test_case_score"]: for item in data["test_case_score"]:
@@ -191,6 +227,7 @@ class ProblemAPI(ProblemBase):
@problem_permission_required @problem_permission_required
def get(self, request): def get(self, request):
problem_id = request.GET.get("id") problem_id = request.GET.get("id")
rule_type = request.GET.get("rule_type")
user = request.user user = request.user
if problem_id: if problem_id:
try: try:
@@ -200,24 +237,19 @@ class ProblemAPI(ProblemBase):
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist") return self.error("Problem does not exist")
problems = Problem.objects.filter(contest_id__isnull=True).order_by( problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
"-create_time" if rule_type:
) if rule_type not in ProblemRuleType.choices():
return self.error("Invalid rule_type")
author = request.GET.get("author", "") else:
if author: problems = problems.filter(rule_type=rule_type)
problems = problems.filter(created_by__username=author)
keyword = request.GET.get("keyword", "").strip() keyword = request.GET.get("keyword", "").strip()
if keyword: if keyword:
problems = problems.filter( problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
Q(title__icontains=keyword) | Q(_id__icontains=keyword)
)
if not user.can_mgmt_all_problem(): if not user.can_mgmt_all_problem():
problems = problems.filter(created_by=user) problems = problems.filter(created_by=user)
return self.success( return self.success(self.paginate_data(request, problems, ProblemAdminListSerializer))
self.paginate_data(request, problems, ProblemAdminListSerializer)
)
@problem_permission_required @problem_permission_required
@validate_serializer(EditProblemSerializer) @validate_serializer(EditProblemSerializer)
@@ -234,11 +266,7 @@ class ProblemAPI(ProblemBase):
_id = data["_id"] _id = data["_id"]
if not _id: if not _id:
return self.error("Display ID is required") return self.error("Display ID is required")
if ( if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
Problem.objects.exclude(id=problem_id)
.filter(_id=_id, contest_id__isnull=True)
.exists()
):
return self.error("Display ID already exists") return self.error("Display ID already exists")
error_info = self.common_checks(request) error_info = self.common_checks(request)
@@ -342,9 +370,7 @@ class ContestProblemAPI(ProblemBase):
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
problems = problems.filter(title__contains=keyword) problems = problems.filter(title__contains=keyword)
return self.success( return self.success(self.paginate_data(request, problems, ProblemAdminListSerializer))
self.paginate_data(request, problems, ProblemAdminListSerializer)
)
@validate_serializer(EditContestProblemSerializer) @validate_serializer(EditContestProblemSerializer)
def put(self, request): def put(self, request):
@@ -370,11 +396,7 @@ class ContestProblemAPI(ProblemBase):
_id = data["_id"] _id = data["_id"]
if not _id: if not _id:
return self.error("Display ID is required") return self.error("Display ID is required")
if ( if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
Problem.objects.exclude(id=problem_id)
.filter(_id=_id, contest=contest)
.exists()
):
return self.error("Display ID already exists") return self.error("Display ID already exists")
error_info = self.common_checks(request) error_info = self.common_checks(request)
@@ -433,6 +455,7 @@ class MakeContestProblemPublicAPIView(APIView):
return self.error("Already be a public problem") return self.error("Already be a public problem")
problem.is_public = True problem.is_public = True
problem.save() problem.save()
# https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
tags = problem.tags.all() tags = problem.tags.all()
problem.pk = None problem.pk = None
problem.contest = None problem.contest = None
@@ -473,6 +496,215 @@ class AddContestProblemAPI(APIView):
return self.success() return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = Submission.objects.filter(problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED).order_by("-create_time").first()
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression)
problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression)
if not info["spj"]:
zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression)
@validate_serializer(ExportProblemRequestSerializer)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = "attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(f"Invalid problem format, error is {serializer.errors}")
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in SysOptions.language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(v["prepend"], v["template"],
v["append"])
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/")
problem_obj = Problem.objects.create(_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"],
output_description=problem_info["output_description"][
"value"],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"][
"language"] if spj else None,
spj_version=rand_str(8) if spj else "",
languages=SysOptions.language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI else 0,
test_case_id=test_case_id
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, ""))
spj = problem_data["spj"] is not None
Problem.objects.create(_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=problem_data["test_case_score"],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=SysOptions.language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"])
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
tf.file.flush()
os.fsync(tf.file)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
score = []
for item in helper.save_test_case(_problem, test_case_dir)["test_cases"].values():
score.append({"score": 0, "input_name": item["input_name"],
"output_name": item.get("output_name")})
problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
problem_data["test_case_score"] = score
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})
class ProblemVisibleAPI(APIView): class ProblemVisibleAPI(APIView):
@problem_permission_required @problem_permission_required
def put(self, request): def put(self, request):
@@ -483,4 +715,4 @@ class ProblemVisibleAPI(APIView):
self.error("problem does not exists") self.error("problem does not exists")
problem.visible = not problem.visible problem.visible = not problem.visible
problem.save() problem.save()
return self.success() return self.success()

View File

@@ -1,12 +1,10 @@
from datetime import datetime from datetime import datetime
import random import random
from django.db.models import Q, Count from django.db.models import Q, Count
from django.core.cache import cache
from account.models import User from account.models import User
from submission.models import Submission, JudgeStatus from submission.models import Submission, JudgeStatus
from utils.api import APIView from utils.api import APIView
from account.decorators import check_contest_permission from account.decorators import check_contest_permission
from utils.constants import CacheKey
from ..models import ProblemTag, Problem, ProblemRuleType from ..models import ProblemTag, Problem, ProblemRuleType
from ..serializers import ( from ..serializers import (
ProblemSerializer, ProblemSerializer,
@@ -42,16 +40,24 @@ class ProblemAPI(APIView):
if request.user.is_authenticated: if request.user.is_authenticated:
profile = request.user.userprofile profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {}) acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
# paginate data # paginate data
results = queryset_values.get("results") results = queryset_values.get("results")
if results is not None: if results is not None:
problems = results problems = results
else: else:
problems = [queryset_values] problems = [
queryset_values,
]
for problem in problems: for problem in problems:
problem["my_status"] = acm_problems_status.get( if problem["rule_type"] == ProblemRuleType.ACM:
str(problem["id"]), {} problem["my_status"] = acm_problems_status.get(
).get("status") str(problem["id"]), {}
).get("status")
else:
problem["my_status"] = oi_problems_status.get(
str(problem["id"]), {}
).get("status")
def get(self, request): def get(self, request):
# 问题详情页 # 问题详情页
@@ -76,11 +82,6 @@ class ProblemAPI(APIView):
.filter(contest_id__isnull=True, visible=True) .filter(contest_id__isnull=True, visible=True)
.order_by("-create_time") .order_by("-create_time")
) )
author = request.GET.get("author")
if author:
problems = problems.filter(created_by__username=author)
# 按照标签筛选 # 按照标签筛选
tag_text = request.GET.get("tag") tag_text = request.GET.get("tag")
if tag_text: if tag_text:
@@ -97,12 +98,6 @@ class ProblemAPI(APIView):
difficulty = request.GET.get("difficulty") difficulty = request.GET.get("difficulty")
if difficulty: if difficulty:
problems = problems.filter(difficulty=difficulty) problems = problems.filter(difficulty=difficulty)
# 排序
sort = request.GET.get("sort")
if sort:
problems = problems.order_by(sort)
# 根据profile 为做过的题目添加标记 # 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemListSerializer) data = self.paginate_data(request, problems, ProblemListSerializer)
self._add_problem_status(request, data) self._add_problem_status(request, data)
@@ -171,48 +166,17 @@ class ProblemSolvedPeopleCount(APIView):
if submission_count == 0: if submission_count == 0:
return self.success(rate) return self.success(rate)
today = datetime.today() today = datetime.today()
years_ago = datetime(today.year - 2, today.month, today.day, 0, 0) twoYearAge = datetime(today.year - 2, today.month, today.day, 0, 0)
total_count = User.objects.filter( total_count = User.objects.filter(
is_disabled=False, last_login__gte=years_ago is_disabled=False, last_login__gte=twoYearAge
).count() ).count()
accepted_count = Submission.objects.filter( accepted_count = Submission.objects.filter(
problem_id=problem_id, problem_id=problem_id,
result=JudgeStatus.ACCEPTED, result=JudgeStatus.ACCEPTED,
create_time__gte=years_ago, create_time__gte=twoYearAge,
).aggregate(user_count=Count("user_id", distinct=True))["user_count"] ).aggregate(user_count=Count("user_id", distinct=True))["user_count"]
if accepted_count < total_count: if accepted_count < total_count:
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100) rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
else: else:
rate = "0" rate = "0"
return self.success(rate) return self.success(rate)
class ProblemAuthorAPI(APIView):
def get(self, request):
show_all = request.GET.get("all", "0") == "1"
cached_data = cache.get(
f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
)
if cached_data:
return self.success(cached_data)
problem_filter = {"contest_id__isnull": True, "created_by__is_disabled": False}
if not show_all:
problem_filter["visible"] = True
authors = (
Problem.objects.filter(**problem_filter)
.values("created_by__username")
.annotate(problem_count=Count("id"))
.order_by("-problem_count")
)
result = [
{
"username": author["created_by__username"],
"problem_count": author["problem_count"],
}
for author in authors
]
cache.set(CacheKey.problem_authors, result, 7200)
return self.success(result)

View File

View File

@@ -1,9 +0,0 @@
from django.apps import AppConfig
class ProblemsetConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'problemset'
def ready(self):
import problemset.signals

View File

@@ -1,115 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 10:27
import django.db.models.deletion
import utils.models
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('problem', '0005_remove_spj_fields'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='ProblemSet',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.TextField(verbose_name='题单标题')),
('description', utils.models.RichTextField(verbose_name='题单描述')),
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('last_update_time', models.DateTimeField(auto_now=True, verbose_name='更新时间')),
('visible', models.BooleanField(default=True, verbose_name='是否可见')),
('is_public', models.BooleanField(default=True, verbose_name='是否公开')),
('difficulty', models.TextField(default='Easy', verbose_name='难度等级')),
('status', models.TextField(default='active', verbose_name='状态')),
('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='创建者')),
],
options={
'verbose_name': '题单',
'verbose_name_plural': '题单',
'db_table': 'problemset',
'ordering': ('-create_time',),
},
),
migrations.CreateModel(
name='ProblemSetBadge',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.TextField(verbose_name='奖章名称')),
('description', models.TextField(verbose_name='奖章描述')),
('icon', models.TextField(verbose_name='奖章图标')),
('condition_type', models.TextField(verbose_name='获得条件类型')),
('condition_value', models.IntegerField(default=0, verbose_name='条件值')),
('level', models.IntegerField(default=1, verbose_name='奖章等级')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
],
options={
'verbose_name': '题单奖章',
'verbose_name_plural': '题单奖章',
'db_table': 'problemset_badge',
},
),
migrations.CreateModel(
name='ProblemSetProblem',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('order', models.IntegerField(default=0, verbose_name='顺序')),
('is_required', models.BooleanField(default=True, verbose_name='是否必做')),
('score', models.IntegerField(default=0, verbose_name='分值')),
('hint', models.TextField(blank=True, null=True, verbose_name='提示')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.problem', verbose_name='题目')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
],
options={
'verbose_name': '题单题目',
'verbose_name_plural': '题单题目',
'db_table': 'problemset_problem',
'ordering': ('order',),
'unique_together': {('problemset', 'problem')},
},
),
migrations.CreateModel(
name='ProblemSetProgress',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('join_time', models.DateTimeField(auto_now_add=True, verbose_name='加入时间')),
('complete_time', models.DateTimeField(blank=True, null=True, verbose_name='完成时间')),
('is_completed', models.BooleanField(default=False, verbose_name='是否完成')),
('progress_percentage', models.FloatField(default=0.0, verbose_name='完成进度')),
('completed_problems_count', models.IntegerField(default=0, verbose_name='已完成题目数')),
('total_problems_count', models.IntegerField(default=0, verbose_name='总题目数')),
('total_score', models.IntegerField(default=0, verbose_name='总分')),
('progress_detail', models.JSONField(default=dict, verbose_name='详细进度')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '题单进度',
'verbose_name_plural': '题单进度',
'db_table': 'problemset_progress',
'unique_together': {('problemset', 'user')},
},
),
migrations.CreateModel(
name='UserBadge',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('earned_time', models.DateTimeField(auto_now_add=True, verbose_name='获得时间')),
('is_displayed', models.BooleanField(default=False, verbose_name='是否已展示')),
('badge', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemsetbadge', verbose_name='奖章')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '用户奖章',
'verbose_name_plural': '用户奖章',
'db_table': 'user_badge',
'unique_together': {('user', 'badge')},
},
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 11:04
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0001_initial'),
]
operations = [
migrations.RemoveField(
model_name='problemset',
name='is_public',
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 12:04
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0002_remove_is_public_field'),
]
operations = [
migrations.RemoveField(
model_name='problemsetbadge',
name='level',
),
]

View File

@@ -1,47 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 16:49
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0005_remove_spj_fields'),
('problemset', '0003_remove_badge_level'),
('submission', '0002_submission_user_create_time_idx'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterField(
model_name='problemset',
name='status',
field=models.TextField(default='draft', verbose_name='状态'),
),
migrations.CreateModel(
name='ProblemSetSubmission',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('submit_time', models.DateTimeField(auto_now_add=True, verbose_name='提交时间')),
('result', models.IntegerField(verbose_name='提交结果')),
('score', models.IntegerField(default=0, verbose_name='得分')),
('language', models.CharField(max_length=20, verbose_name='编程语言')),
('code_length', models.IntegerField(default=0, verbose_name='代码长度')),
('execution_time', models.IntegerField(default=0, verbose_name='执行时间')),
('memory_usage', models.IntegerField(default=0, verbose_name='内存使用')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.problem', verbose_name='题目')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
('submission', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='submission.submission', verbose_name='提交记录')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '题单提交记录',
'verbose_name_plural': '题单提交记录',
'db_table': 'problemset_submission',
'ordering': ('-submit_time',),
'indexes': [models.Index(fields=['problemset', 'user'], name='problemset__problem_1f39fa_idx'), models.Index(fields=['problemset', 'problem'], name='problemset__problem_22f053_idx'), models.Index(fields=['user', 'submit_time'], name='problemset__user_id_63c1d0_idx')],
},
),
]

View File

@@ -1,57 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-23 01:34
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0005_remove_spj_fields'),
('problemset', '0004_alter_problemset_status_problemsetsubmission'),
('submission', '0002_submission_user_create_time_idx'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterModelOptions(
name='problemsetsubmission',
options={'ordering': ('-submission__create_time',), 'verbose_name': '题单提交记录', 'verbose_name_plural': '题单提交记录'},
),
migrations.RemoveIndex(
model_name='problemsetsubmission',
name='problemset__user_id_63c1d0_idx',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='code_length',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='execution_time',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='language',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='memory_usage',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='result',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='score',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='submit_time',
),
migrations.AddIndex(
model_name='problemsetsubmission',
index=models.Index(fields=['user'], name='problemset__user_id_2f1501_idx'),
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-23 03:41
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0005_alter_problemsetsubmission_options_and_more'),
]
operations = [
migrations.RemoveField(
model_name='userbadge',
name='is_displayed',
),
]

View File

@@ -1,284 +0,0 @@
from django.db import models
from django.utils.timezone import now
from account.models import User
from problem.models import Problem
from utils.models import RichTextField, JSONField
class ProblemSet(models.Model):
"""题单模型"""
title = models.TextField(verbose_name="题单标题")
description = RichTextField(verbose_name="题单描述")
# 创建者
created_by = models.ForeignKey(
User, on_delete=models.CASCADE, verbose_name="创建者"
)
# 创建时间
create_time = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
# 更新时间
last_update_time = models.DateTimeField(auto_now=True, verbose_name="更新时间")
# 是否可见
visible = models.BooleanField(default=True, verbose_name="是否可见")
# 题单难度等级
difficulty = models.TextField(default="Easy", verbose_name="难度等级")
# 题单状态
status = models.TextField(
default="draft", verbose_name="状态"
) # active, archived, draft
class Meta:
db_table = "problemset"
ordering = ("-create_time",)
verbose_name = "题单"
verbose_name_plural = "题单"
def __str__(self):
return self.title
class ProblemSetProblem(models.Model):
"""题单题目关联模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, verbose_name="题目")
# 在题单中的顺序
order = models.IntegerField(default=0, verbose_name="顺序")
# 是否为必做题
is_required = models.BooleanField(default=True, verbose_name="是否必做")
# 题目在题单中的分值
score = models.IntegerField(default=0, verbose_name="分值")
# 题目提示信息
hint = models.TextField(null=True, blank=True, verbose_name="提示")
class Meta:
db_table = "problemset_problem"
unique_together = (("problemset", "problem"),)
ordering = ("order",)
verbose_name = "题单题目"
verbose_name_plural = "题单题目"
def __str__(self):
return f"{self.problemset.title} - {self.problem.title}"
class ProblemSetBadge(models.Model):
"""题单奖章模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
name = models.TextField(verbose_name="奖章名称")
description = models.TextField(verbose_name="奖章描述")
# 奖章图标路径
icon = models.TextField(verbose_name="奖章图标")
# 获得条件:完成所有题目、完成指定数量题目、达到指定分数等
condition_type = models.TextField(
verbose_name="获得条件类型"
) # all_problems, problem_count, score
condition_value = models.IntegerField(default=0, verbose_name="条件值")
class Meta:
db_table = "problemset_badge"
verbose_name = "题单奖章"
verbose_name_plural = "题单奖章"
def __str__(self):
return f"{self.problemset.title} - {self.name}"
def recalculate_user_badges(self):
"""重新计算所有用户的徽章资格"""
# 获取所有已加入该题单的用户进度
user_progresses = ProblemSetProgress.objects.filter(problemset=self.problemset)
# 删除该徽章的所有现有用户徽章记录
UserBadge.objects.filter(badge=self).delete()
# 重新评估每个用户的徽章资格
for progress in user_progresses:
self._check_user_badge_eligibility(progress)
def _check_user_badge_eligibility(self, progress):
"""检查用户是否符合该徽章的条件"""
# 检查是否已经拥有该徽章
if UserBadge.objects.filter(user=progress.user, badge=self).exists():
return False
# 根据条件类型检查用户是否符合条件
if self.condition_type == "all_problems":
if progress.completed_problems_count == progress.total_problems_count:
UserBadge.objects.create(user=progress.user, badge=self)
return True
elif self.condition_type == "problem_count":
if progress.completed_problems_count >= self.condition_value:
UserBadge.objects.create(user=progress.user, badge=self)
return True
elif self.condition_type == "score":
if progress.total_score >= self.condition_value:
UserBadge.objects.create(user=progress.user, badge=self)
return True
return False
class ProblemSetProgress(models.Model):
"""题单进度模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
# 加入时间
join_time = models.DateTimeField(auto_now_add=True, verbose_name="加入时间")
# 完成时间
complete_time = models.DateTimeField(null=True, blank=True, verbose_name="完成时间")
# 是否完成
is_completed = models.BooleanField(default=False, verbose_name="是否完成")
# 完成进度百分比
progress_percentage = models.FloatField(default=0.0, verbose_name="完成进度")
# 已完成的题目数量
completed_problems_count = models.IntegerField(
default=0, verbose_name="已完成题目数"
)
# 总题目数量
total_problems_count = models.IntegerField(default=0, verbose_name="总题目数")
# 获得的总分
total_score = models.IntegerField(default=0, verbose_name="总分")
# 用户在该题单中的详细进度信息
# {"problem_id": {"score": 20, "submit_time": "2024-01-01T00:00:00Z"}}
progress_detail = JSONField(default=dict, verbose_name="详细进度")
class Meta:
db_table = "problemset_progress"
unique_together = (("problemset", "user"),)
verbose_name = "题单进度"
verbose_name_plural = "题单进度"
def __str__(self):
return f"{self.user.username} - {self.problemset.title}"
def update_progress(self):
"""更新进度信息"""
# 获取题单中的所有题目
problemset_problems = ProblemSetProblem.objects.filter(
problemset=self.problemset
)
self.total_problems_count = problemset_problems.count()
# 获取当前题单中所有题目的ID集合
current_problem_ids = {str(psp.problem.id) for psp in problemset_problems}
# 清理已删除题目的进度记录
progress_detail_to_remove = []
for problem_id in self.progress_detail.keys():
if problem_id not in current_problem_ids:
progress_detail_to_remove.append(problem_id)
for problem_id in progress_detail_to_remove:
del self.progress_detail[problem_id]
# 计算已完成题目数
completed_count = 0
total_score = 0
for psp in problemset_problems:
problem_id = str(psp.problem.id)
if problem_id in self.progress_detail:
problem_progress = self.progress_detail[problem_id]
completed_count += 1
total_score += psp.score
problem_progress["score"] = psp.score
self.completed_problems_count = completed_count
self.total_score = total_score
# 计算完成百分比
if self.total_problems_count > 0:
self.progress_percentage = (
completed_count / self.total_problems_count
) * 100
else:
self.progress_percentage = 0
# 检查是否完成
self.is_completed = completed_count == self.total_problems_count
if self.is_completed and not self.complete_time:
self.complete_time = now()
self.save()
@classmethod
def sync_all_progress_for_problemset(cls, problemset):
"""同步指定题单的所有用户进度"""
progresses = cls.objects.filter(problemset=problemset)
for progress in progresses:
progress.update_progress()
return progresses.count()
class ProblemSetSubmission(models.Model):
"""题单提交记录模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
submission = models.ForeignKey(
"submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录"
)
problem = models.ForeignKey(
"problem.Problem", on_delete=models.CASCADE, verbose_name="题目"
)
class Meta:
db_table = "problemset_submission"
ordering = ("-submission__create_time",)
verbose_name = "题单提交记录"
verbose_name_plural = "题单提交记录"
indexes = [
models.Index(fields=["problemset", "user"]),
models.Index(fields=["problemset", "problem"]),
models.Index(fields=["user"]),
]
def __str__(self):
return f"{self.user.username} - {self.problemset.title} - {self.problem.title}"
@property
def submit_time(self):
"""提交时间"""
return self.submission.create_time
@property
def result(self):
"""提交结果"""
return self.submission.result
@property
def language(self):
"""编程语言"""
return self.submission.language
class UserBadge(models.Model):
"""用户奖章模型"""
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
badge = models.ForeignKey(
ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章"
)
# 获得时间
earned_time = models.DateTimeField(auto_now_add=True, verbose_name="获得时间")
class Meta:
db_table = "user_badge"
unique_together = (("user", "badge"),)
verbose_name = "用户奖章"
verbose_name_plural = "用户奖章"
def __str__(self):
return f"{self.user.username} - {self.badge.name}"

View File

@@ -1,285 +0,0 @@
from utils.api import UsernameSerializer, serializers
from .models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
UserBadge,
)
def get_user_progress_data(problemset, request):
"""获取当前用户在该题单中的进度 - 公共方法"""
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=problemset, user=request.user
)
return {
"is_joined": True,
"progress_percentage": progress.progress_percentage,
"completed_count": progress.completed_problems_count,
"total_count": progress.total_problems_count,
"is_completed": progress.is_completed,
}
except ProblemSetProgress.DoesNotExist:
return {
"is_joined": False,
"progress_percentage": 0,
"completed_count": 0,
"total_count": 0,
"is_completed": False,
}
return {
"is_joined": False,
"progress_percentage": 0,
"completed_count": 0,
"total_count": 0,
"is_completed": False,
}
class ProblemSetSerializer(serializers.ModelSerializer):
"""题单序列化器"""
created_by = UsernameSerializer()
problems_count = serializers.SerializerMethodField()
completed_count = serializers.SerializerMethodField()
user_progress = serializers.SerializerMethodField()
class Meta:
model = ProblemSet
fields = "__all__"
def get_problems_count(self, obj):
"""获取题单中的题目数量"""
return ProblemSetProblem.objects.filter(problemset=obj).count()
def get_completed_count(self, obj):
"""获取当前用户在该题单中完成的题目数量"""
request = self.context.get("request")
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj, user=request.user
)
return progress.completed_problems_count
except ProblemSetProgress.DoesNotExist:
return 0
return 0
def get_user_progress(self, obj):
"""获取当前用户在该题单中的进度"""
request = self.context.get("request")
return get_user_progress_data(obj, request)
class ProblemSetListSerializer(serializers.ModelSerializer):
"""题单列表序列化器"""
created_by = UsernameSerializer()
problems_count = serializers.SerializerMethodField()
user_progress = serializers.SerializerMethodField()
badges = serializers.SerializerMethodField()
class Meta:
model = ProblemSet
fields = [
"id",
"title",
"description",
"created_by",
"create_time",
"difficulty",
"status",
"problems_count",
"user_progress",
"badges",
"visible",
]
def get_problems_count(self, obj):
"""获取题单中的题目数量"""
return ProblemSetProblem.objects.filter(problemset=obj).count()
def get_user_progress(self, obj):
"""获取当前用户在该题单中的进度"""
request = self.context.get("request")
return get_user_progress_data(obj, request)
def get_badges(self, obj):
"""获取题单的奖章列表,并标记用户已获得的徽章"""
request = self.context.get("request")
badges = ProblemSetBadge.objects.filter(problemset=obj)
badge_data = ProblemSetBadgeSerializer(badges, many=True).data
# 如果用户已登录,检查哪些徽章已被获得
if request and request.user.is_authenticated:
earned_badge_ids = set(
UserBadge.objects.filter(
user=request.user,
badge__problemset=obj
).values_list('badge_id', flat=True)
)
# 为每个徽章添加是否已获得的标记
for badge in badge_data:
badge['is_earned'] = badge['id'] in earned_badge_ids
else:
# 未登录用户,所有徽章都标记为未获得
for badge in badge_data:
badge['is_earned'] = False
return badge_data
class CreateProblemSetSerializer(serializers.Serializer):
"""创建题单序列化器"""
title = serializers.CharField(max_length=200)
description = serializers.CharField()
difficulty = serializers.CharField(default="Easy")
status = serializers.CharField(default="active")
class EditProblemSetSerializer(serializers.Serializer):
"""编辑题单序列化器"""
id = serializers.IntegerField()
title = serializers.CharField(max_length=200, required=False)
description = serializers.CharField(required=False)
difficulty = serializers.CharField(required=False)
status = serializers.CharField(required=False)
visible = serializers.BooleanField(required=False)
class ProblemSetProblemSerializer(serializers.ModelSerializer):
"""题单题目序列化器"""
problem = serializers.SerializerMethodField()
is_completed = serializers.SerializerMethodField()
class Meta:
model = ProblemSetProblem
fields = "__all__"
def get_problem(self, obj):
"""获取题目详细信息"""
from problem.serializers import ProblemListSerializer
return ProblemListSerializer(obj.problem, context=self.context).data
def get_is_completed(self, obj):
"""获取当前用户是否已完成该题目"""
request = self.context.get("request")
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj.problemset, user=request.user
)
problem_id = str(obj.problem.id)
return problem_id in progress.progress_detail
except ProblemSetProgress.DoesNotExist:
return False
return False
class AddProblemToSetSerializer(serializers.Serializer):
"""添加题目到题单序列化器"""
problem_id = serializers.CharField()
order = serializers.IntegerField(default=0)
is_required = serializers.BooleanField(default=True)
score = serializers.IntegerField(default=0)
hint = serializers.CharField(required=False, allow_blank=True)
class EditProblemInSetSerializer(serializers.Serializer):
"""编辑题单中的题目序列化器"""
order = serializers.IntegerField(required=False)
is_required = serializers.BooleanField(required=False)
score = serializers.IntegerField(required=False)
hint = serializers.CharField(required=False, allow_blank=True)
class ProblemSetBadgeSerializer(serializers.ModelSerializer):
"""题单奖章序列化器"""
class Meta:
model = ProblemSetBadge
fields = "__all__"
class CreateProblemSetBadgeSerializer(serializers.Serializer):
"""创建题单奖章序列化器"""
name = serializers.CharField(max_length=100)
description = serializers.CharField()
icon = serializers.CharField()
condition_type = serializers.CharField() # all_problems, problem_count, score
condition_value = serializers.IntegerField(required=False)
class EditProblemSetBadgeSerializer(serializers.Serializer):
"""编辑题单奖章序列化器"""
name = serializers.CharField(max_length=100, required=False)
description = serializers.CharField(required=False)
icon = serializers.CharField(required=False)
condition_type = serializers.CharField(required=False) # all_problems, problem_count, score
condition_value = serializers.IntegerField(required=False)
class ProblemSetProgressSerializer(serializers.ModelSerializer):
"""题单进度序列化器"""
user = UsernameSerializer()
completed_problems = serializers.SerializerMethodField()
class Meta:
model = ProblemSetProgress
fields = "__all__"
def get_completed_problems(self, obj):
"""获取已完成的题目列表"""
from problem.models import Problem
completed_problems = []
if obj.progress_detail:
for problem_id in obj.progress_detail.keys():
try:
problem = Problem.objects.get(id=problem_id)
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
except Problem.DoesNotExist:
continue
return completed_problems
class UserBadgeSerializer(serializers.ModelSerializer):
"""用户奖章序列化器"""
badge = ProblemSetBadgeSerializer()
class Meta:
model = UserBadge
fields = "__all__"
class JoinProblemSetSerializer(serializers.Serializer):
"""加入题单序列化器"""
problemset_id = serializers.IntegerField()
class UpdateProgressSerializer(serializers.Serializer):
"""更新进度序列化器"""
problemset_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
submission_id = serializers.CharField(required=False)

View File

@@ -1,91 +0,0 @@
# 题单应用信号处理
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from .models import ProblemSetProblem, ProblemSetProgress, ProblemSetBadge, UserBadge
from django.db import transaction
import logging
logger = logging.getLogger(__name__)
@receiver(post_save, sender=ProblemSetProblem)
def sync_progress_on_problem_change(sender, instance, created, **kwargs):
"""当题单题目发生变化时,同步所有用户的进度"""
try:
with transaction.atomic():
# 获取该题单的所有用户进度
progresses = ProblemSetProgress.objects.filter(
problemset=instance.problemset
)
# 批量更新所有用户的进度
for progress in progresses:
progress.update_progress()
# 重新计算该题单的所有徽章资格
badges = ProblemSetBadge.objects.filter(problemset=instance.problemset)
for badge in badges:
badge.recalculate_user_badges()
logger.info(f"已同步题单 {instance.problemset.id} 的所有用户进度和徽章资格")
except Exception as e:
logger.error(f"同步题单进度时出错: {e}")
@receiver(post_delete, sender=ProblemSetProblem)
def sync_progress_on_problem_delete(sender, instance, **kwargs):
"""当题单题目被删除时,同步所有用户的进度并清理相关提交记录"""
try:
with transaction.atomic():
# 清理该题目在题单中的所有提交记录
from .models import ProblemSetSubmission
ProblemSetSubmission.objects.filter(
problemset=instance.problemset,
problem=instance.problem
).delete()
# 获取该题单的所有用户进度
progresses = ProblemSetProgress.objects.filter(
problemset=instance.problemset
)
# 批量更新所有用户的进度
for progress in progresses:
progress.update_progress()
# 重新计算该题单的所有徽章资格
badges = ProblemSetBadge.objects.filter(problemset=instance.problemset)
for badge in badges:
badge.recalculate_user_badges()
logger.info(f"已同步题单 {instance.problemset.id} 的所有用户进度和徽章资格(删除题目后)")
except Exception as e:
logger.error(f"同步题单进度时出错: {e}")
@receiver(post_save, sender=ProblemSetBadge)
def sync_badges_on_badge_change(sender, instance, created, **kwargs):
"""当题单奖章发生变化时,重新计算所有用户的奖章资格"""
try:
with transaction.atomic():
# 重新计算该奖章的所有用户资格
instance.recalculate_user_badges()
logger.info(f"已重新计算题单 {instance.problemset.id} 的奖章 {instance.id} 的用户资格")
except Exception as e:
logger.error(f"重新计算奖章资格时出错: {e}")
@receiver(post_delete, sender=ProblemSetBadge)
def cleanup_badges_on_badge_delete(sender, instance, **kwargs):
"""当题单奖章被删除时,清理相关的用户奖章记录"""
try:
with transaction.atomic():
# 删除该奖章的所有用户奖章记录
UserBadge.objects.filter(badge=instance).delete()
logger.info(f"已清理奖章 {instance.id} 的所有用户奖章记录")
except Exception as e:
logger.error(f"清理用户奖章记录时出错: {e}")

View File

@@ -1,71 +0,0 @@
from django.urls import path
from problemset.views.admin import (
ProblemSetAdminAPI,
ProblemSetBadgeAdminAPI,
ProblemSetDetailAdminAPI,
ProblemSetProblemAdminAPI,
ProblemSetProgressAdminAPI,
ProblemSetStatusAPI,
ProblemSetSyncAPI,
ProblemSetVisibleAPI,
)
urlpatterns = [
# 管理员题单管理API
path("problemset", ProblemSetAdminAPI.as_view(), name="admin_problemset_api"),
path(
"problemset/<int:problem_set_id>",
ProblemSetDetailAdminAPI.as_view(),
name="admin_problemset_detail_api",
),
path(
"problemset/<int:problem_set_id>/problems",
ProblemSetProblemAdminAPI.as_view(),
name="admin_problemset_problems_api",
),
path(
"problemset/<int:problem_set_id>/problems/<int:problem_set_problem_id>",
ProblemSetProblemAdminAPI.as_view(),
name="admin_problemset_problem_detail_api",
),
# 管理员奖章管理API
path(
"problemset/<int:problem_set_id>/badges",
ProblemSetBadgeAdminAPI.as_view(),
name="admin_problemset_badges_api",
),
path(
"problemset/<int:problem_set_id>/badges/<int:badge_id>",
ProblemSetBadgeAdminAPI.as_view(),
name="admin_problemset_badge_detail_api",
),
# 管理员进度管理API
path(
"problemset/<int:problem_set_id>/progress",
ProblemSetProgressAdminAPI.as_view(),
name="admin_problemset_progress_api",
),
path(
"problemset/<int:problem_set_id>/progress/<int:user_id>",
ProblemSetProgressAdminAPI.as_view(),
name="admin_problemset_progress_detail_api",
),
# 题单同步管理API
path(
"problemset/<int:problem_set_id>/sync",
ProblemSetSyncAPI.as_view(),
name="admin_problemset_sync_api",
),
# 题单状态管理API
path(
"problemset/visible",
ProblemSetVisibleAPI.as_view(),
name="admin_problemset_visible_api",
),
path(
"problemset/status",
ProblemSetStatusAPI.as_view(),
name="admin_problemset_status_api",
),
]

View File

@@ -1,55 +0,0 @@
from django.urls import path
from problemset.views.oj import (
ProblemSetAPI,
ProblemSetDetailAPI,
ProblemSetProblemAPI,
ProblemSetProgressAPI,
UserBadgeAPI,
UserProgressAPI,
ProblemSetBadgeAPI,
ProblemSetUserProgressAPI,
)
urlpatterns = [
# 题单相关API
path("problemset", ProblemSetAPI.as_view(), name="problemset_api"),
path(
"problemset/<int:problem_set_id>",
ProblemSetDetailAPI.as_view(),
name="problemset_detail_api",
),
path(
"problemset/<int:problem_set_id>/problems",
ProblemSetProblemAPI.as_view(),
name="problemset_problems_api",
),
path(
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
ProblemSetProblemAPI.as_view(),
name="problemset_problem_detail_api",
),
# 进度相关API
path(
"problemset/progress",
ProblemSetProgressAPI.as_view(),
name="problemset_progress_api",
),
path(
"problemset/<int:problem_set_id>/progress",
ProblemSetProgressAPI.as_view(),
name="problemset_progress_detail_api",
),
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"),
# 奖章相关API
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
path(
"problemset/<int:problem_set_id>/badges",
ProblemSetBadgeAPI.as_view(),
name="problemset_badges_api",
),
path(
"problemset/<int:problem_set_id>/users_progress",
ProblemSetUserProgressAPI.as_view(),
name="problemset_user_progress_api",
),
]

View File

@@ -1,428 +0,0 @@
from django.db.models import Q
from utils.api import APIView, validate_serializer
from account.decorators import super_admin_required, ensure_created_by
from problemset.models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
)
from problemset.serializers import (
ProblemSetSerializer,
ProblemSetListSerializer,
CreateProblemSetSerializer,
EditProblemSetSerializer,
ProblemSetProblemSerializer,
AddProblemToSetSerializer,
EditProblemInSetSerializer,
ProblemSetBadgeSerializer,
CreateProblemSetBadgeSerializer,
EditProblemSetBadgeSerializer,
ProblemSetProgressSerializer,
)
from problem.models import Problem
class ProblemSetAdminAPI(APIView):
"""题单管理API"""
@super_admin_required
def get(self, request):
"""获取题单列表(管理员)"""
problem_sets = ProblemSet.objects.all().order_by("-create_time")
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
difficulty = request.GET.get("difficulty")
if difficulty:
problem_sets = problem_sets.filter(difficulty=difficulty)
status = request.GET.get("status")
if status:
problem_sets = problem_sets.filter(status=status)
# 使用统一的分页方法
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
return self.success(data)
@super_admin_required
@validate_serializer(CreateProblemSetSerializer)
def post(self, request):
"""创建题单"""
data = request.data
data["created_by"] = request.user
problem_set = ProblemSet.objects.create(**data)
return self.success(ProblemSetSerializer(problem_set).data)
@super_admin_required
@validate_serializer(EditProblemSetSerializer)
def put(self, request):
"""编辑题单"""
data = request.data
problem_set_id = data.pop("id")
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 更新题单信息
for key, value in data.items():
if key != "id":
setattr(problem_set, key, value)
problem_set.save()
return self.success(ProblemSetSerializer(problem_set).data)
@super_admin_required
def delete(self, request):
"""删除题单"""
problem_set_id = request.GET.get("id")
if not problem_set_id:
return self.error("题单ID是必需的")
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 软删除:设置为不可见
problem_set.visible = False
problem_set.save()
return self.success("题单已删除")
class ProblemSetDetailAdminAPI(APIView):
"""题单详情管理API"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单详情(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
serializer = ProblemSetSerializer(problem_set, context={"request": request})
return self.success(serializer.data)
class ProblemSetProblemAdminAPI(APIView):
"""题单题目管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单中的题目列表(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
)
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
)
return self.success(serializer.data)
@super_admin_required
@validate_serializer(AddProblemToSetSerializer)
def post(self, request, problem_set_id):
"""添加题目到题单(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
data = request.data
try:
problem = Problem.objects.filter(
_id=data["problem_id"],
visible=True,
contest_id__isnull=True,
).get()
except Problem.DoesNotExist:
return self.error("题目不存在或不可见")
# 检查题目是否已经在题单中
if ProblemSetProblem.objects.filter(
problemset=problem_set, problem=problem
).exists():
return self.error("题目已在该题单中")
ProblemSetProblem.objects.create(
problemset=problem_set,
problem=problem,
order=data.get("order", 0),
is_required=data.get("is_required", True),
score=data.get("score", 0),
hint=data.get("hint", ""),
)
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已添加到题单")
@super_admin_required
@validate_serializer(EditProblemInSetSerializer)
def put(self, request, problem_set_id, problem_set_problem_id):
"""编辑题单中的题目(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
data = request.data
# 更新题目属性
if "order" in data:
problem_set_problem.order = data["order"]
if "is_required" in data:
problem_set_problem.is_required = data["is_required"]
if "score" in data:
problem_set_problem.score = data["score"]
if "hint" in data:
problem_set_problem.hint = data["hint"]
problem_set_problem.save()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已更新")
@super_admin_required
def delete(self, request, problem_set_id, problem_set_problem_id):
"""从题单中移除题目(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
problem_set_problem.delete()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已从题单中移除")
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
class ProblemSetBadgeAdminAPI(APIView):
"""题单奖章管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单的奖章列表(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
serializer = ProblemSetBadgeSerializer(badges, many=True)
return self.success(serializer.data)
@super_admin_required
@validate_serializer(CreateProblemSetBadgeSerializer)
def post(self, request, problem_set_id):
"""创建题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
data = request.data
data["problemset"] = problem_set
badge = ProblemSetBadge.objects.create(**data)
return self.success(ProblemSetBadgeSerializer(badge).data)
@super_admin_required
@validate_serializer(EditProblemSetBadgeSerializer)
def put(self, request, problem_set_id, badge_id):
"""编辑题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
badge = ProblemSetBadge.objects.get(id=badge_id, problemset=problem_set)
except ProblemSetBadge.DoesNotExist:
return self.error("奖章不存在")
data = request.data
# 记录是否修改了条件相关的字段
condition_changed = False
# 更新奖章属性
if "name" in data:
badge.name = data["name"]
if "description" in data:
badge.description = data["description"]
if "icon" in data:
badge.icon = data["icon"]
if "condition_type" in data:
badge.condition_type = data["condition_type"]
condition_changed = True
if "condition_value" in data:
badge.condition_value = data["condition_value"]
condition_changed = True
if "level" in data:
badge.level = data["level"]
badge.save()
# 如果修改了条件,重新计算所有用户的徽章资格
if condition_changed:
try:
badge.recalculate_user_badges()
return self.success("奖章已更新,并重新计算了所有用户的徽章资格")
except Exception as e:
return self.error(f"奖章已更新,但重新计算徽章资格时出错: {str(e)}")
return self.success("奖章已更新")
@super_admin_required
def delete(self, request, problem_set_id, badge_id):
"""删除题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
badge = ProblemSetBadge.objects.get(id=badge_id, problemset=problem_set)
badge.delete()
return self.success("奖章已删除")
except ProblemSetBadge.DoesNotExist:
return self.error("奖章不存在")
class ProblemSetProgressAdminAPI(APIView):
"""题单进度管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单的所有用户进度(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
progress_list = ProblemSetProgress.objects.filter(
problemset=problem_set
).order_by("-join_time")
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
@super_admin_required
def delete(self, request, problem_set_id, user_id):
"""移除用户从题单(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user_id=user_id
)
progress.delete()
return self.success("用户已从题单中移除")
except ProblemSetProgress.DoesNotExist:
return self.error("用户未加入该题单")
class ProblemSetSyncAPI(APIView):
"""题单同步管理API"""
@super_admin_required
def post(self, request, problem_set_id):
"""手动同步题单的所有用户进度(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 同步所有用户的进度
synced_count = ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success(f"已同步 {synced_count} 个用户的进度")
class ProblemSetVisibleAPI(APIView):
"""题单可见性管理API"""
@super_admin_required
def put(self, request):
"""切换题单可见性"""
data = request.data
try:
problem_set = ProblemSet.objects.get(id=data["id"])
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problem_set.visible = not problem_set.visible
problem_set.save()
return self.success()
class ProblemSetStatusAPI(APIView):
"""题单状态管理API"""
@super_admin_required
def put(self, request):
"""更新题单状态"""
data = request.data
try:
problem_set = ProblemSet.objects.get(id=data["id"])
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
status = data.get("status")
if status not in ["active", "archived", "draft"]:
return self.error("无效的状态")
problem_set.status = status
problem_set.save()
return self.success()

View File

@@ -1,337 +0,0 @@
from django.db.models import Q, Avg
from django.utils import timezone
from utils.api import APIView, validate_serializer
from account.models import User
from problemset.models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
ProblemSetSubmission,
UserBadge,
)
from problemset.serializers import (
ProblemSetSerializer,
ProblemSetListSerializer,
ProblemSetProblemSerializer,
ProblemSetBadgeSerializer,
ProblemSetProgressSerializer,
UserBadgeSerializer,
JoinProblemSetSerializer,
UpdateProgressSerializer,
)
from submission.models import Submission
from problem.models import Problem
class ProblemSetAPI(APIView):
"""题单API - 用户端"""
def get(self, request):
"""获取题单列表"""
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft")
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
difficulty = request.GET.get("difficulty")
if difficulty:
problem_sets = problem_sets.filter(difficulty=difficulty)
status_filter = request.GET.get("status")
if status_filter:
problem_sets = problem_sets.filter(status=status_filter)
# 排序
sort = request.GET.get("sort")
if sort:
problem_sets = problem_sets.order_by(sort)
else:
problem_sets = problem_sets.order_by("-create_time")
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
return self.success(data)
class ProblemSetDetailAPI(APIView):
"""题单详情API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单详情"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
serializer = ProblemSetSerializer(problem_set, context={"request": request})
return self.success(serializer.data)
class ProblemSetProblemAPI(APIView):
"""题单题目API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单中的题目列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
)
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
)
return self.success(serializer.data)
class ProblemSetProgressAPI(APIView):
"""题单进度API"""
@validate_serializer(JoinProblemSetSerializer)
def post(self, request):
"""加入题单"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
if ProblemSetProgress.objects.filter(
problemset=problem_set, user=request.user
).exists():
return self.error("已经加入该题单")
# 创建进度记录
progress = ProblemSetProgress.objects.create(
problemset=problem_set, user=request.user
)
progress.update_progress()
return self.success("成功加入题单")
def get(self, request, problem_set_id):
"""获取题单进度"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
serializer = ProblemSetProgressSerializer(progress)
return self.success(serializer.data)
@validate_serializer(UpdateProgressSerializer)
def put(self, request):
"""更新进度"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
# 更新详细进度
problem_id = str(data["problem_id"])
# 获取该题目在题单中的分值
try:
problemset_problem = ProblemSetProblem.objects.get(
problemset=problem_set, problem_id=problem_id
)
problem_score = problemset_problem.score
except ProblemSetProblem.DoesNotExist:
problem_score = 0
progress.progress_detail[problem_id] = {
"score": problem_score, # 题单中设置的分值
"submit_time": data.get("submit_time", timezone.now().isoformat()),
}
# 更新进度
progress.update_progress()
# 只有当提供了submission_id时才创建ProblemSetSubmission记录
if "submission_id" in data and data["submission_id"]:
try:
submission = Submission.objects.get(id=data["submission_id"])
problem = Problem.objects.get(id=problem_id)
has_accepted = ProblemSetSubmission.objects.filter(
problemset=problem_set,
user=request.user,
problem=problem,
).exists()
if not has_accepted:
ProblemSetSubmission.objects.create(
problemset=problem_set,
user=request.user,
submission=submission,
problem=problem,
)
except Submission.DoesNotExist:
# 如果提交记录不存在,记录错误但不中断流程
pass
# 检查是否获得奖章
self._check_badges(progress)
return self.success("进度已更新")
def _check_badges(self, progress):
"""检查是否获得奖章"""
badges = ProblemSetBadge.objects.filter(problemset=progress.problemset)
for badge in badges:
if UserBadge.objects.filter(user=progress.user, badge=badge).exists():
continue
if badge.condition_type == "all_problems":
if progress.completed_problems_count == progress.total_problems_count:
UserBadge.objects.create(user=progress.user, badge=badge)
elif badge.condition_type == "problem_count":
if progress.completed_problems_count >= badge.condition_value:
UserBadge.objects.create(user=progress.user, badge=badge)
elif badge.condition_type == "score":
if progress.total_score >= badge.condition_value:
UserBadge.objects.create(user=progress.user, badge=badge)
class UserProgressAPI(APIView):
"""用户进度API"""
def get(self, request):
"""获取用户的题单进度列表"""
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by(
"-join_time"
)
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
class UserBadgeAPI(APIView):
"""用户奖章API"""
def get(self, request):
"""获取用户的奖章列表"""
# 支持通过username参数获取指定用户的徽章
username = request.GET.get("username")
if username:
# 获取指定用户的徽章
try:
target_user = User.objects.get(username=username, is_disabled=False)
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
except User.DoesNotExist:
return self.error("用户不存在")
else:
# 获取当前用户的徽章
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
serializer = UserBadgeSerializer(badges, many=True)
return self.success(serializer.data)
class ProblemSetBadgeAPI(APIView):
"""题单奖章API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单的奖章列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
serializer = ProblemSetBadgeSerializer(badges, many=True)
return self.success(serializer.data)
class ProblemSetUserProgressAPI(APIView):
"""题单用户进度列表API"""
def get(self, request, problem_set_id: int):
"""获取题单的用户进度列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 获取所有参与该题单的用户进度
progresses = ProblemSetProgress.objects.filter(problemset=problem_set)
# 班级过滤
class_name = request.GET.get("class_name", "").strip()
if class_name:
progresses = progresses.filter(user_username__icontains=class_name)
# 排序
progresses = progresses.order_by(
"-is_completed", "-progress_percentage", "join_time"
)
# 计算统计数据(基于所有数据,而非分页数据)
total_count = progresses.count()
completed_count = progresses.filter(is_completed=True).count()
avg_progress = progresses.aggregate(avg=Avg("progress_percentage"))["avg"] or 0
# 使用分页
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
# 添加统计数据
data["statistics"] = {
"total": total_count,
"completed": completed_count,
"avg_progress": round(avg_progress, 2)
}
return self.success(data)

View File

@@ -5,31 +5,21 @@ description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [ dependencies = [
"channels>=4.2.0",
"channels-redis>=4.2.0",
"coverage==6.5.0",
"daphne>=4.1.2",
"django>=5.2.3", "django>=5.2.3",
"django-cas-ng==5.0.1", "django-dbconn-retry>=0.1.8",
"django-dbconn-retry==0.1.8", "django-dramatiq>=0.13.0",
"django-dramatiq==0.13.0", "django-redis>=5.4.0",
"django-redis==5.4.0", "djangorestframework>=3.16.0",
"djangorestframework==3.16.0", "envelopes>=0.4",
"dramatiq==1.17.0", "gunicorn>=23.0.0",
"entrypoints==0.4", "otpauth>=2.2.1",
"envelopes==0.4", "pillow>=11.2.1",
"flake8==7.0.0", "psycopg>=3.2.9",
"flake8-coding==1.3.2", "psycopg-binary>=3.2.9",
"flake8-quotes==3.3.2", "python-dateutil>=2.9.0.post0",
"gunicorn==22.0.0", "qrcode>=8.2",
"jsonfield==3.1.0", "raven>=6.10.0",
"openai>=1.108.1", "requests>=2.32.4",
"otpauth==1.0.1", "uvicorn>=0.35.0",
"pillow==10.2.0", "xlsxwriter>=3.2.5",
"psycopg==3.2.9",
"psycopg-binary==3.2.9",
"python-dateutil==2.8.2",
"qrcode==8.2",
"raven==6.10.0",
"xlsxwriter==3.2.0",
] ]

View File

@@ -1,84 +0,0 @@
"""
WebSocket consumers for submission updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class SubmissionConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time submission updates
当用户提交代码后,通过 WebSocket 实时接收判题状态更新
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用用户 ID 作为组名,这样可以向特定用户推送消息
self.group_name = f"submission_user_{self.user.id}"
# 加入用户专属的组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或订阅特定提交
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "subscribe":
# 订阅特定提交的更新
submission_id = data.get("submission_id")
if submission_id:
logger.info(f"User {self.user.id} subscribed to submission {submission_id}")
# 可以在这里做额外的订阅逻辑
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def submission_update(self, event):
"""
接收来自 channel layer 的代码提交更新消息并发送给客户端
这个方法名对应 push_submission_update 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent submission update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending submission update to user {self.user.id}: {str(e)}")

View File

@@ -1,19 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-25 07:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0001_initial'),
('problem', '0001_initial'),
('submission', '0001_initial'),
]
operations = [
migrations.AddIndex(
model_name='submission',
index=models.Index(fields=['user_id', 'create_time'], name='user_create_time_idx'),
),
]

View File

@@ -41,11 +41,7 @@ class Submission(models.Model):
ip = models.TextField(null=True) ip = models.TextField(null=True)
def check_user_permission(self, user, check_share=True): def check_user_permission(self, user, check_share=True):
if ( if self.user_id == user.id or user.is_super_admin() or user.can_mgmt_all_problem() or self.problem.created_by_id == user.id:
self.user_id == user.id
or not user.is_regular_user()
or self.problem.created_by_id == user.id
):
return True return True
if check_share: if check_share:
@@ -58,11 +54,6 @@ class Submission(models.Model):
class Meta: class Meta:
db_table = "submission" db_table = "submission"
ordering = ("-create_time",) ordering = ("-create_time",)
indexes = [
models.Index(
fields=["user_id", "create_time"], name="user_create_time_idx"
),
]
def __str__(self): def __str__(self):
return self.id return self.id

View File

@@ -8,7 +8,6 @@ class CreateSubmissionSerializer(serializers.Serializer):
language = LanguageNameChoiceField() language = LanguageNameChoiceField()
code = serializers.CharField(max_length=1024 * 1024) code = serializers.CharField(max_length=1024 * 1024)
contest_id = serializers.IntegerField(required=False) contest_id = serializers.IntegerField(required=False)
problemset_id = serializers.IntegerField(required=False)
captcha = serializers.CharField(required=False) captcha = serializers.CharField(required=False)
@@ -35,7 +34,6 @@ class SubmissionSafeModelSerializer(serializers.ModelSerializer):
class SubmissionListSerializer(serializers.ModelSerializer): class SubmissionListSerializer(serializers.ModelSerializer):
problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
problem_title = serializers.CharField(source="problem.title")
show_link = serializers.SerializerMethodField() show_link = serializers.SerializerMethodField()
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

78
submission/tests.py Normal file
View File

@@ -0,0 +1,78 @@
from copy import deepcopy
from unittest import mock
from problem.models import Problem, ProblemTag
from utils.api.tests import APITestCase
from .models import Submission
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
DEFAULT_SUBMISSION_DATA = {
"problem_id": "1",
"user_id": 1,
"username": "test",
"code": "xxxxxxxxxxxxxx",
"result": -2,
"info": {},
"language": "C",
"statistic_info": {}
}
# todo contest submission
class SubmissionPrepare(APITestCase):
def _create_problem_and_submission(self):
user = self.create_admin("test", "test123", login=False)
problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
tags = problem_data.pop("tags")
problem_data["created_by"] = user
self.problem = Problem.objects.create(**problem_data)
for tag in tags:
tag = ProblemTag.objects.create(name=tag)
self.problem.tags.add(tag)
self.problem.save()
self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
self.submission_data["problem_id"] = self.problem.id
self.submission = Submission.objects.create(**self.submission_data)
class SubmissionListTest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.create_user("123", "345")
self.url = self.reverse("submission_list_api")
def test_get_submission_list(self):
resp = self.client.get(self.url, data={"limit": "10"})
self.assertSuccess(resp)
@mock.patch("submission.views.oj.judge_task.send")
class SubmissionAPITest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.user = self.create_user("123", "test123")
self.url = self.reverse("submission_api")
def test_create_submission(self, judge_task):
resp = self.client.post(self.url, self.submission_data)
self.assertSuccess(resp)
judge_task.assert_called()
def test_create_submission_with_wrong_language(self, judge_task):
self.submission_data.update({"language": "Python3"})
resp = self.client.post(self.url, self.submission_data)
self.assertFailed(resp)
self.assertDictEqual(resp.data, {"error": "error",
"data": "Python3 is now allowed in the problem"})
judge_task.assert_not_called()

View File

@@ -1,6 +1,6 @@
from account.decorators import super_admin_required from account.decorators import super_admin_required
from judge.tasks import judge_task from judge.tasks import judge_task
# from judge.dispatcher import JudgeDispatcher
from utils.api import APIView from utils.api import APIView
from ..models import Submission, JudgeStatus from ..models import Submission, JudgeStatus
from account.models import User, AdminType from account.models import User, AdminType
@@ -8,12 +8,6 @@ from problem.models import Problem
from django.db.models import Count, Q from django.db.models import Count, Q
def get_real_name(username, class_name):
if class_name and username.startswith("ks"):
return username[len(f"ks{class_name}"):]
return username
class SubmissionRejudgeAPI(APIView): class SubmissionRejudgeAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
@@ -21,9 +15,7 @@ class SubmissionRejudgeAPI(APIView):
if not id: if not id:
return self.error("Parameter error, id is required") return self.error("Parameter error, id is required")
try: try:
submission = Submission.objects.select_related("problem").get( submission = Submission.objects.select_related("problem").get(id=id, contest_id__isnull=True)
id=id, contest_id__isnull=True
)
except Submission.DoesNotExist: except Submission.DoesNotExist:
return self.error("Submission does not exists") return self.error("Submission does not exists")
submission.statistic_info = {} submission.statistic_info = {}
@@ -31,123 +23,82 @@ class SubmissionRejudgeAPI(APIView):
judge_task.send(submission.id, submission.problem.id) judge_task.send(submission.id, submission.problem.id)
return self.success() return self.success()
class SubmissionStatisticsAPI(APIView): class SubmissionStatisticsAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
start = request.GET.get("start") start = request.GET.get("start")
end = request.GET.get("end") end = request.GET.get("end")
if not start or not end: if not start or not end:
return self.error("start and end is required") return self.error("start and end is required")
submissions = Submission.objects.filter( submissions = Submission.objects.filter(contest_id__isnull=True,
contest_id__isnull=True, create_time__gte=start, create_time__lte=end create_time__gte=start,
).select_related("problem__created_by") create_time__lte=end).select_related("problem__created_by")
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
if problem_id: if problem_id:
try: try:
problem = Problem.objects.get( problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True)
_id=problem_id, contest_id__isnull=True, visible=True
)
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem doesn't exist") return self.error("Problem doesn't exist")
submissions = submissions.filter(problem=problem) submissions = submissions.filter(problem=problem)
username = request.GET.get("username") username = request.GET.get("username")
all_users_dict = {} # 统计人数
person_count = 0
all_persons = []
if username: if username:
submissions = submissions.filter(username__icontains=username) submissions = submissions.filter(username__icontains=username)
all_users_dict = { all_persons = User.objects.filter(username__icontains=username,
user["username"]: user["class_name"] is_disabled=False,
for user in User.objects.filter( admin_type=AdminType.REGULAR_USER).values_list("username", flat=True)
username__icontains=username, person_count = all_persons.count()
is_disabled=False,
admin_type=AdminType.REGULAR_USER,
).values("username", "class_name")
}
# 优化:一次性获取所有统计数据 submission_count = submissions.count()
submission_stats = submissions.aggregate( accepted_count = submissions.filter(result=JudgeStatus.ACCEPTED).count()
total_count=Count("id"),
accepted_count=Count("id", filter=Q(result=JudgeStatus.ACCEPTED)), try:
) correct_rate = round(accepted_count/submission_count*100, 2)
submission_count = submission_stats["total_count"] except ZeroDivisionError:
accepted_count = submission_stats["accepted_count"] correct_rate = 0
correct_rate = (
round(accepted_count / submission_count * 100, 2) if submission_count else 0
)
# 优化:获取用户提交统计 counts = submissions.values("username").annotate(submission_count=Count("id", distinct=True),
user_submissions = ( accepted_count=Count("id", filter=Q(result=JudgeStatus.ACCEPTED), distinct=True),
submissions.values("username") ).order_by("-submission_count")
.annotate(
submission_count=Count("id"),
accepted_count=Count("id", filter=Q(result=JudgeStatus.ACCEPTED)),
)
.order_by("-submission_count")
)
# 获取所有有提交记录的用户的class_name信息
submitted_usernames = {item["username"] for item in user_submissions}
if submitted_usernames:
submitted_users_dict = {
user["username"]: user["class_name"]
for user in User.objects.filter(
username__in=submitted_usernames
).values("username", "class_name")
}
else:
submitted_users_dict = {}
# 处理有提交记录的用户
accepted = [] accepted = []
for item in counts:
for item in user_submissions:
username_key = item["username"]
if item["accepted_count"] > 0: if item["accepted_count"] > 0:
rate = round(item["accepted_count"] / item["submission_count"] * 100, 2) rate = round(item["accepted_count"]/item["submission_count"]*100, 2)
accepted.append( item["correct_rate"] = f"{rate}%"
{ accepted.append(item)
"username": username_key,
"class_name": submitted_users_dict.get(username_key),
"submission_count": item["submission_count"],
"accepted_count": item["accepted_count"],
"correct_rate": f"{rate}%",
}
)
# 处理无提交记录的用户,只返回姓名列表
unaccepted = [] unaccepted = []
if all_users_dict: if len(accepted) > 0:
unaccepted_usernames = set(all_users_dict.keys()) - submitted_usernames unaccepted = list(set(all_persons) - set([item['username'] for item in accepted]))
for username in unaccepted_usernames:
class_name = all_users_dict[username] # 统计人数完成率
real_name = get_real_name(username, class_name)
unaccepted.append(real_name)
# 计算人数完成率
person_count = len(all_users_dict) if all_users_dict else 0
person_rate = 0 person_rate = 0
if person_count: if person_count:
person_rate = min(100, round(len(accepted) / person_count * 100, 2)) person_rate = round(len(accepted)/person_count*100, 2)
# 处理已删除用户但提交记录仍存在的情况 # 下面是做一些超出 100% 的操作,比如有人已经删号了,但是提交记录还在
if person_rate >= 100:
person_rate = 100
# 搜出来的人数比提交人数还多的情况
if person_count < len(accepted): if person_count < len(accepted):
person_count = len(accepted) person_count = len(accepted)
return self.success( return self.success({
{ "submission_count": submission_count,
"submission_count": submission_count, "accepted_count": accepted_count,
"accepted_count": accepted_count, "correct_rate": f"{correct_rate}%",
"correct_rate": f"{correct_rate}%", "person_count": person_count,
"person_count": person_count, "person_rate": f"{person_rate}%",
"person_rate": f"{person_rate}%", "data": accepted,
"data": accepted, "data_unaccepted": unaccepted
"data_unaccepted": unaccepted, })
}
)

View File

@@ -19,7 +19,6 @@ from ..serializers import (
ShareSubmissionSerializer, ShareSubmissionSerializer,
) )
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
from problemset.models import ProblemSetSubmission
class SubmissionAPI(APIView): class SubmissionAPI(APIView):
@@ -92,7 +91,6 @@ class SubmissionAPI(APIView):
ip=request.session["ip"], ip=request.session["ip"],
contest_id=data.get("contest_id"), contest_id=data.get("contest_id"),
) )
# use this for debug # use this for debug
# JudgeDispatcher(submission.id, problem.id).judge() # JudgeDispatcher(submission.id, problem.id).judge()
judge_task.send(submission.id, problem.id) judge_task.send(submission.id, problem.id)
@@ -166,7 +164,6 @@ class SubmissionListAPI(APIView):
myself = request.GET.get("myself") myself = request.GET.get("myself")
result = request.GET.get("result") result = request.GET.get("result")
username = request.GET.get("username") username = request.GET.get("username")
language = request.GET.get("language")
if problem_id: if problem_id:
try: try:
problem = Problem.objects.get( problem = Problem.objects.get(
@@ -176,7 +173,10 @@ class SubmissionListAPI(APIView):
return self.error("Problem doesn't exist") return self.error("Problem doesn't exist")
submissions = submissions.filter(problem=problem) submissions = submissions.filter(problem=problem)
if not SysOptions.submission_list_show_all and request.user.is_regular_user(): if (
not SysOptions.submission_list_show_all
and not request.user.is_super_admin()
):
return self.success({"results": [], "total": 0}) return self.success({"results": [], "total": 0})
if myself and myself == "1": if myself and myself == "1":
@@ -185,8 +185,6 @@ class SubmissionListAPI(APIView):
submissions = submissions.filter(username__icontains=username) submissions = submissions.filter(username__icontains=username)
if result: if result:
submissions = submissions.filter(result=result) submissions = submissions.filter(result=result)
if language:
submissions = submissions.filter(language=language)
data = self.paginate_data(request, submissions) data = self.paginate_data(request, submissions)
data["results"] = SubmissionListSerializer( data["results"] = SubmissionListSerializer(

View File

@@ -132,7 +132,7 @@ class APIView(View):
results = query_set[offset:offset + limit] results = query_set[offset:offset + limit]
if object_serializer: if object_serializer:
count = query_set.count() count = query_set.count()
results = object_serializer(results, many=True, context={"request": request}).data results = object_serializer(results, many=True).data
else: else:
count = query_set.count() count = query_set.count()
data = {"results": results, data = {"results": results,

View File

@@ -25,7 +25,6 @@ class CacheKey:
waiting_queue = "waiting_queue" waiting_queue = "waiting_queue"
contest_rank_cache = "contest_rank_cache" contest_rank_cache = "contest_rank_cache"
website_config = "website_config" website_config = "website_config"
problem_authors = "problem_authors"
class Difficulty(Choices): class Difficulty(Choices):

View File

@@ -16,6 +16,14 @@ class LanguageNameChoiceField(serializers.CharField):
return data return data
class SPJLanguageNameChoiceField(serializers.CharField):
def to_internal_value(self, data):
data = super().to_internal_value(data)
if data and data not in SysOptions.spj_language_names:
raise InvalidLanguage(data)
return data
class LanguageNameMultiChoiceField(serializers.ListField): class LanguageNameMultiChoiceField(serializers.ListField):
def to_internal_value(self, data): def to_internal_value(self, data):
data = super().to_internal_value(data) data = super().to_internal_value(data)
@@ -23,3 +31,12 @@ class LanguageNameMultiChoiceField(serializers.ListField):
if item not in SysOptions.language_names: if item not in SysOptions.language_names:
raise InvalidLanguage(item) raise InvalidLanguage(item)
return data return data
class SPJLanguageNameMultiChoiceField(serializers.ListField):
def to_internal_value(self, data):
data = super().to_internal_value(data)
for item in data:
if item not in SysOptions.spj_language_names:
raise InvalidLanguage(item)
return data

View File

@@ -57,6 +57,11 @@ def datetime2str(value, format="iso-8601"):
return value return value
return value.strftime(format) return value.strftime(format)
def timestamp2utcstr(value):
return datetime.datetime.utcfromtimestamp(value).isoformat()
def natural_sort_key(s, _nsre=re.compile(r"(\d+)")): def natural_sort_key(s, _nsre=re.compile(r"(\d+)")):
return [int(text) if text.isdigit() else text.lower() return [int(text) if text.isdigit() else text.lower()
for text in re.split(_nsre, s)] for text in re.split(_nsre, s)]

View File

@@ -1,142 +0,0 @@
"""
WebSocket utility functions for pushing real-time updates
"""
import logging
from channels.layers import get_channel_layer
from asgiref.sync import async_to_sync
logger = logging.getLogger(__name__)
def push_submission_update(submission_id: str, user_id: int, data: dict):
"""
推送提交状态更新到指定用户的 WebSocket 连接
Args:
submission_id: 提交 ID
user_id: 用户 ID
data: 要发送的数据,应该包含 type, submission_id, result 等字段
"""
channel_layer = get_channel_layer()
if channel_layer is None:
logger.warning("Channel layer is not configured, cannot push submission update")
return
# 构建组名,与 SubmissionConsumer 中的组名一致
group_name = f"submission_user_{user_id}"
try:
# 向指定用户组发送消息
# type 字段对应 consumer 中的方法名submission_update
async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "submission_update", # 对应 SubmissionConsumer.submission_update 方法
"data": data,
}
)
logger.info(f"Pushed submission update: submission_id={submission_id}, user_id={user_id}, status={data.get('status')}")
except Exception as e:
logger.error(f"Failed to push submission update: submission_id={submission_id}, user_id={user_id}, error={str(e)}")
def push_to_user(user_id: int, message_type: str, data: dict):
"""
向指定用户推送自定义消息
Args:
user_id: 用户 ID
message_type: 消息类型
data: 消息数据
"""
channel_layer = get_channel_layer()
if channel_layer is None:
logger.warning("Channel layer is not configured, cannot push message")
return
group_name = f"submission_user_{user_id}"
try:
async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "submission_update",
"data": {
"type": message_type,
**data
},
}
)
logger.info(f"Pushed message to user {user_id}: type={message_type}")
except Exception as e:
logger.error(f"Failed to push message to user {user_id}: error={str(e)}")
def push_flowchart_evaluation_update(submission_id: str, user_id: int, data: dict):
"""
推送流程图评分状态更新到指定用户的 WebSocket 连接
Args:
submission_id: 流程图提交 ID
user_id: 用户 ID
data: 要发送的数据,应该包含 type, submission_id, score, grade, feedback 等字段
"""
channel_layer = get_channel_layer()
if channel_layer is None:
logger.warning("Channel layer is not configured, cannot push flowchart evaluation update")
return
# 构建组名,与 FlowchartConsumer 中的组名一致
group_name = f"flowchart_user_{user_id}"
try:
# 向指定用户组发送消息
# type 字段对应 consumer 中的方法名flowchart_evaluation_update
async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "flowchart_evaluation_update", # 对应 FlowchartConsumer.flowchart_evaluation_update 方法
"data": data,
}
)
logger.info(f"Pushed flowchart evaluation update: submission_id={submission_id}, user_id={user_id}, type={data.get('type')}")
except Exception as e:
logger.error(f"Failed to push flowchart evaluation update: submission_id={submission_id}, user_id={user_id}, error={str(e)}")
def push_config_update(key: str, value):
"""
推送配置更新到所有连接的客户端
Args:
key: 配置键名
value: 配置值
"""
channel_layer = get_channel_layer()
if channel_layer is None:
logger.warning("Channel layer is not configured, cannot push config update")
return
# 使用全局配置组名
group_name = "config_updates"
try:
# 向所有连接的客户端发送配置更新
async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "config_update",
"data": {
"type": "config_update",
"key": key,
"value": value
}
}
)
logger.info(f"Pushed config update: {key}={value}")
except Exception as e:
logger.error(f"Failed to push config update: {key}={value}, error={str(e)}")

1263
uv.lock generated

File diff suppressed because it is too large Load Diff