Compare commits

..

10 Commits

Author SHA1 Message Date
a1b51ebb9e update 2025-07-14 21:41:27 +08:00
a9a6b87fef test for asgi 2025-07-14 21:33:03 +08:00
2d3588c755 revert 2025-06-15 20:26:43 +08:00
a2bfc28ac7 test 2025-06-15 20:21:37 +08:00
6aac767641 test 2025-06-15 20:18:24 +08:00
73af9d96b2 test 2025-06-15 20:15:49 +08:00
8a2fa11afc test 2025-06-15 20:12:48 +08:00
3f1c7250bd test 2025-06-15 20:06:50 +08:00
bd0a7f30f8 test 2025-06-15 19:35:11 +08:00
8a043d2ffa test 2025-06-15 19:26:45 +08:00
109 changed files with 1362 additions and 6888 deletions

139
CLAUDE.md
View File

@@ -1,139 +0,0 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
**OnlineJudge** is the backend for an Online Judge platform. Built with Django 5 + Django REST Framework, PostgreSQL, Redis, Django Channels (WebSocket), and Dramatiq (async task queue). Python 3.12+, managed with `uv`.
## Commands
```bash
# Development
python dev.py # Start dev server: Django on :8000 + Daphne WebSocket on :8001
python manage.py runserver # HTTP only (no WebSocket support)
python manage.py migrate # Apply database migrations
python manage.py makemigrations # Create new migrations
# Dependencies
uv sync # Install dependencies from uv.lock
uv add <package> # Add a dependency
# Testing
python manage.py test # Run all tests
python manage.py test account # Run tests for a single app
python manage.py test account.tests.TestClassName # Run a single test class
python run_test.py # Run flake8 lint + coverage in one step
python run_test.py -m account # Run flake8 + tests for a single module
python run_test.py -c # Run flake8 + tests + open HTML coverage report
# Initial setup
python manage.py inituser --username admin --password <pw> --action create_super_admin
python manage.py inituser --username admin --password <pw> --action reset
```
## Architecture
### App Modules
Each Django app follows the same structure:
```
<app>/
├── models.py # Django models
├── serializers.py # DRF serializers
├── views/
│ ├── oj.py # User-facing API views
│ └── admin.py # Admin API views
└── urls/
├── oj.py # User-facing URL patterns
└── admin.py # Admin URL patterns
```
Apps: `account`, `problem`, `submission`, `contest`, `ai`, `flowchart`, `problemset`, `class_pk`, `announcement`, `tutorial`, `message`, `comment`, `conf`, `options`, `judge`
`utils/` is itself a Django app (listed in `INSTALLED_APPS`) — not just a helpers package. It provides `RichTextField` (XSS-sanitized `TextField`), `APIError`, the base `APIView`, caching, WebSocket helpers, and the `inituser` management command. Import shared utilities from `utils.*`.
### URL Routing
All routes are registered in `oj/urls.py`:
- `api/` — user-facing endpoints
- `api/admin/` — admin-only endpoints
WebSocket routing is in `oj/routing.py`.
### Settings Structure
- `oj/settings.py` — base configuration (imports dev or production settings based on `OJ_ENV`)
- `oj/dev_settings.py` — development overrides (imported when `OJ_ENV != "production"`)
- `oj/production_settings.py` — production overrides
### Base APIView & View Patterns
`utils/api/api.py` provides the custom base classes and decorators used by **all** views:
- **`APIView`** — base class for all views (not DRF's `APIView`). Key methods:
- `self.success(data)` — returns `{"error": null, "data": data}`
- `self.error(msg)` — returns `{"error": "error", "data": msg}`
- `self.paginate_data(request, query_set, serializer)` — offset/limit pagination
- `self.invalid_serializer(serializer)` — standard validation error response
- **`CSRFExemptAPIView`** — same as `APIView` but CSRF-exempt
- **`@validate_serializer(SerializerClass)`** — decorator for view methods that validates `request.data` against a serializer before the method runs. On success, `request.data` is replaced with validated data.
Typical view method pattern:
```python
@validate_serializer(CreateProblemSerializer)
@super_admin_required
def post(self, request):
# request.data is already validated
return self.success(...)
```
### Authentication & Permissions
`account/decorators.py` provides decorators used on view methods:
- `@login_required` / `@admin_role_required` / `@super_admin_required`
- `@problem_permission_required`
- `@check_contest_permission(check_type)` — validates contest access, sets `self.contest`
- `ensure_created_by(obj, user)` — helper that raises `APIError` if user doesn't own the object
### Judge System
- `judge/dispatcher.py` — dispatches submissions to the judge sandbox (JudgeServer)
- `judge/tasks.py` — Dramatiq async tasks for judging
- `judge/languages.py` — language configurations (compile/run commands, limits)
Judge status codes are defined in `submission/models.py` (`JudgeStatus` class, codes -2 to 8) and must match the frontend's `utils/constants.ts`.
### Site Configuration (SysOptions)
`options/options.py` provides `SysOptions` — a metaclass-based system for site-wide configuration stored in the database with thread-local caching. Access settings like `SysOptions.smtp_config`, `SysOptions.languages`, etc.
### WebSocket (Channels)
`submission/consumers.py` — WebSocket consumer for real-time submission status updates. Uses `channels-redis` as the channel layer backend. Push updates via `utils/websocket.py:push_submission_update()`.
### Caching
Redis-backed via `django-redis`. Cache keys use MD5 hashing for consistency. See `utils/cache.py`.
### AI Integration
`utils/openai.py` — OpenAI client wrapper configured to work with OpenAI-compatible APIs (e.g., DeepSeek). Used by `ai/` app for submission analysis.
### Data Directory
Test cases and submission outputs are stored in a separate data directory (configured in settings, not in the repo). The `data/` directory in the repo contains configuration templates and `secret.key`.
## Key Domain Concepts
| Concept | Details |
|---|---|
| Problem types | ACM (binary accept/reject) vs OI (partial scoring) |
| Judge statuses | COMPILE_ERROR(-2), WRONG_ANSWER(-1), ACCEPTED(0), CPU_TLE(1), REAL_TLE(2), MLE(3), RE(4), SE(5), PENDING(6), JUDGING(7), PARTIALLY_ACCEPTED(8) |
| User roles | Regular / Admin / Super Admin |
| Contest types | Public vs Password Protected |
| Supported languages | C, C++, Python2, Python3, Java, JavaScript, Golang, Flowchart |
## Related Repository
The frontend is at `../ojnext` — a Vue 3 + Rsbuild project. See its CLAUDE.md for frontend details.

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-19 06:11
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='userprofile',
name='class_name',
field=models.TextField(null=True),
),
]

View File

@@ -1,22 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-19 06:14
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0002_userprofile_class_name'),
]
operations = [
migrations.RemoveField(
model_name='userprofile',
name='class_name',
),
migrations.AddField(
model_name='user',
name='class_name',
field=models.TextField(null=True),
),
]

View File

@@ -25,7 +25,6 @@ class UserManager(models.Manager):
class User(AbstractBaseUser):
username = models.TextField(unique=True)
class_name = models.TextField(null=True)
email = models.TextField(null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType
@@ -51,9 +50,6 @@ class User(AbstractBaseUser):
objects = UserManager()
def is_regular_user(self):
return self.admin_type == AdminType.REGULAR_USER
def is_admin(self):
return self.admin_type == AdminType.ADMIN

View File

@@ -67,7 +67,6 @@ class UserAdminSerializer(serializers.ModelSerializer):
"open_api",
"is_disabled",
"raw_password",
"class_name",
]
def get_real_name(self, obj):
@@ -94,7 +93,6 @@ class UserSerializer(serializers.ModelSerializer):
"two_factor_auth",
"open_api",
"is_disabled",
"class_name",
]
@@ -131,7 +129,7 @@ class EditUserSerializer(serializers.Serializer):
open_api = serializers.BooleanField()
two_factor_auth = serializers.BooleanField()
is_disabled = serializers.BooleanField()
class_name = serializers.CharField(required=False, allow_null=True, allow_blank=True)
class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=32, allow_null=True, required=False)
@@ -143,6 +141,7 @@ class EditUserProfileSerializer(serializers.Serializer):
major = serializers.CharField(max_length=64, allow_blank=True, required=False)
language = serializers.CharField(max_length=32, allow_blank=True, required=False)
class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField()
captcha = serializers.CharField()

View File

@@ -1,9 +1,8 @@
from django.urls import path
from ..views.admin import UserAdminAPI, GenerateUserAPI, ResetUserPasswordAPI
from ..views.admin import UserAdminAPI, GenerateUserAPI
urlpatterns = [
path("user", UserAdminAPI.as_view()),
path("generate_user", GenerateUserAPI.as_view()),
path("reset_password", ResetUserPasswordAPI.as_view()),
]

View File

@@ -15,7 +15,6 @@ from ..views.oj import (
UserProfileAPI,
UserRankAPI,
UserActivityRankAPI,
UserProblemRankAPI,
CheckTFARequiredAPI,
SessionManagementAPI,
ProfileProblemDisplayIDRefreshAPI,
@@ -46,7 +45,6 @@ urlpatterns = [
),
path("user_rank", UserRankAPI.as_view()),
path("user_activity_rank", UserActivityRankAPI.as_view()),
path("user_problem_rank", UserProblemRankAPI.as_view()),
path("sessions", SessionManagementAPI.as_view()),
path(
"open_api_appkey",

View File

@@ -3,10 +3,9 @@ import re
import xlsxwriter
from django.db import transaction, IntegrityError
from django.db.models import Q, F
from django.db.models import Q
from django.http import HttpResponse
from django.contrib.auth.hashers import make_password
from django.utils.crypto import get_random_string
from submission.models import Submission
from utils.api import APIView, validate_serializer
@@ -22,19 +21,6 @@ from ..serializers import (
from ..serializers import ImportUserSerializer
# ks251XXX 或者 ks2510XX 返回 251 或者 2510
# 其他返回 None
def get_class_name(username):
if username.startswith("ks"):
result = re.search(r"ks\d+", username)
if result:
return result.group(0)[2:]
else:
return None
else:
return None
class UserAdminAPI(APIView):
@validate_serializer(ImportUserSerializer)
@super_admin_required
@@ -54,7 +40,6 @@ class UserAdminAPI(APIView):
password=make_password(user_data[1]),
email=user_data[2],
raw_password=user_data[1],
class_name=get_class_name(user_data[0]),
)
)
@@ -100,13 +85,12 @@ class UserAdminAPI(APIView):
pre_username = user.username
user.username = data["username"].lower()
user.class_name = get_class_name(data["username"])
user.email = data["email"].lower()
user.admin_type = data["admin_type"]
user.is_disabled = data["is_disabled"]
if data["admin_type"] == AdminType.ADMIN:
user.problem_permission = data["problem_permission"] or ProblemPermission.OWN
user.problem_permission = data["problem_permission"]
elif data["admin_type"] == AdminType.SUPER_ADMIN:
user.problem_permission = ProblemPermission.ALL
else:
@@ -154,21 +138,12 @@ class UserAdminAPI(APIView):
return self.error("User does not exist")
return self.success(UserAdminSerializer(user).data)
# 获取排序参数
order_by = request.GET.get("order_by", "")
# 根据排序参数设置排序规则
if order_by == "-last_login":
# 最近登录,将 None 值放在最后
user = User.objects.all().order_by(F("last_login").desc(nulls_last=True))
else:
# 默认按创建时间倒序
user = User.objects.all().order_by("-create_time")
user = User.objects.all().order_by("-create_time")
type = request.GET.get("type", "")
is_admin = request.GET.get("admin", "0")
if type:
user = user.filter(admin_type=type)
if is_admin == "1":
user = user.exclude(admin_type=AdminType.REGULAR_USER)
keyword = request.GET.get("keyword", None)
if keyword:
@@ -264,27 +239,3 @@ class GenerateUserAPI(APIView):
# duplicate key value violates unique constraint "user_username_key"
# DETAIL: Key (username)=(root11) already exists.
return self.error(str(e).split("\n")[1])
class ResetUserPasswordAPI(APIView):
@super_admin_required
def post(self, request):
"""
重置用户密码为随机6位数字(不包括0)
"""
data = request.data
user_id = data["id"]
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
return self.error("User does not exist")
# 生成6位随机数字密码(不包括0)
new_password = get_random_string(6, allowed_chars="123456789")
# 设置新密码
user.set_password(new_password)
user.save()
return self.success(new_password)

View File

@@ -12,7 +12,7 @@ from django.db.models import Count, Q
from django.utils import timezone
import qrcode
from otpauth import OtpAuth
from otpauth import TOTP
from problem.models import Problem
from submission.models import Submission, JudgeStatus
@@ -143,7 +143,7 @@ class TwoFactorAuthAPI(APIView):
label = f"{SysOptions.website_name_shortcut}:{user.username}"
image = qrcode.make(
OtpAuth(token).to_uri(
TOTP(token).to_uri(
"totp", label, SysOptions.website_name.replace(" ", "")
)
)
@@ -157,7 +157,7 @@ class TwoFactorAuthAPI(APIView):
"""
code = request.data["code"]
user = request.user
if OtpAuth(user.tfa_token).valid_totp(code):
if TOTP(user.tfa_token).verify(code):
user.two_factor_auth = True
user.save()
return self.success("Succeeded")
@@ -171,7 +171,7 @@ class TwoFactorAuthAPI(APIView):
user = request.user
if not user.two_factor_auth:
return self.error("2FA is already turned off")
if OtpAuth(user.tfa_token).valid_totp(code):
if TOTP(user.tfa_token).verify(code):
user.two_factor_auth = False
user.save()
return self.success("Succeeded")
@@ -209,23 +209,15 @@ class UserLoginAPI(APIView):
if user.is_disabled:
return self.error("Your account has been disabled")
if not user.two_factor_auth:
prev_login = user.last_login
auth.login(request, user)
request.session["prev_login"] = (
datetime2str(prev_login) if prev_login else ""
)
return self.success("Succeeded")
# `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data:
return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
prev_login = user.last_login
if TOTP(user.tfa_token).verify(data["tfa_code"]):
auth.login(request, user)
request.session["prev_login"] = (
datetime2str(prev_login) if prev_login else ""
)
return self.success("Succeeded")
else:
return self.error("Invalid two factor verification code")
@@ -295,7 +287,7 @@ class UserChangeEmailAPI(APIView):
if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
if not TOTP(user.tfa_token).verify(data["tfa_code"]):
return self.error("Invalid two factor verification code")
data["new_email"] = data["new_email"].lower()
if User.objects.filter(email=data["new_email"]).exists():
@@ -321,7 +313,7 @@ class UserChangePasswordAPI(APIView):
if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
if not TOTP(user.tfa_token).verify(data["tfa_code"]):
return self.error("Invalid two factor verification code")
user.set_password(data["new_password"])
user.save()
@@ -442,9 +434,8 @@ class UserRankAPI(APIView):
n = 0
if rule_type not in ContestRuleType.choices():
rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.filter(
user__admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
user__admin_type=AdminType.REGULAR_USER,
user__is_disabled=False,
user__username__icontains=username,
).select_related("user")
@@ -465,72 +456,23 @@ class UserActivityRankAPI(APIView):
if not start:
return self.error("start time is required")
hidden_names = User.objects.filter(
Q(admin_type=AdminType.SUPER_ADMIN) | Q(is_disabled=True)
Q(admin_type=AdminType.SUPER_ADMIN)
| Q(admin_type=AdminType.ADMIN)
| Q(is_disabled=True)
).values_list("username", flat=True)
submissions = Submission.objects.filter(
contest_id__isnull=True,
create_time__gte=start,
result=JudgeStatus.ACCEPTED,
).exclude(username__in=hidden_names)
data = list(
contest_id__isnull=True, create_time__gte=start, result=JudgeStatus.ACCEPTED
)
counts = (
submissions.values("username")
.annotate(count=Count("problem_id", distinct=True))
.order_by("-count")[:10]
)
return self.success(data)
class UserProblemRankAPI(APIView):
def get(self, request):
problem_id = request.GET.get("problem_id")
user = request.user
if not user.is_authenticated:
return self.error("User is not authenticated")
problem = Problem.objects.get(
_id=problem_id, contest_id__isnull=True, visible=True
)
submissions = Submission.objects.filter(
problem=problem, result=JudgeStatus.ACCEPTED
)
all_ac_count = submissions.values("user_id").distinct().count()
class_name = user.class_name or ""
class_ac_count = 0
if class_name:
users = User.objects.filter(
class_name=user.class_name, is_disabled=False
).values_list("id", flat=True)
user_ids = list(users)
submissions = submissions.filter(user_id__in=user_ids)
class_ac_count = submissions.values("user_id").distinct().count()
my_submissions = submissions.filter(user_id=user.id)
if len(my_submissions) == 0:
return self.success(
{
"class_name": class_name,
"rank": -1,
"class_ac_count": class_ac_count,
"all_ac_count": all_ac_count,
}
)
my_first_submission = my_submissions.order_by("create_time").first()
rank = submissions.filter(
create_time__lte=my_first_submission.create_time
).count()
return self.success(
{
"class_name": class_name,
"rank": rank,
"class_ac_count": class_ac_count,
"all_ac_count": all_ac_count,
}
.order_by("-count")[: 10 + len(hidden_names)]
)
data = []
for count in counts:
if count["username"] not in hidden_names:
data.append(count)
return self.success(data[:10])
class ProfileProblemDisplayIDRefreshAPI(APIView):

View File

View File

@@ -1,6 +0,0 @@
from django.apps import AppConfig
class AiConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'ai'

View File

@@ -1,34 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-24 12:59
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='AIAnalysis',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('provider', models.TextField(default='deepseek')),
('data', models.JSONField()),
('system_prompt', models.TextField()),
('user_prompt', models.TextField()),
('analysis', models.TextField()),
('create_time', models.DateTimeField(auto_now_add=True)),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
'db_table': 'ai_analysis',
'ordering': ['-create_time'],
},
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-24 13:02
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('ai', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='aianalysis',
name='model',
field=models.TextField(default='deepseek-chat'),
),
]

View File

@@ -1,18 +0,0 @@
from django.db import models
from account.models import User
class AIAnalysis(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE)
provider = models.TextField(default="deepseek")
model = models.TextField(default="deepseek-chat")
data = models.JSONField()
system_prompt = models.TextField()
user_prompt = models.TextField()
analysis = models.TextField()
create_time = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "ai_analysis"
ordering = ["-create_time"]

View File

View File

View File

@@ -1,19 +0,0 @@
from django.urls import path
from ..views.oj import (
AIAnalysisAPI,
AIDetailDataAPI,
AIDurationDataAPI,
AIHeatmapDataAPI,
AIHintAPI,
AILoginSummaryAPI,
)
urlpatterns = [
path("ai/detail", AIDetailDataAPI.as_view()),
path("ai/duration", AIDurationDataAPI.as_view()),
path("ai/analysis", AIAnalysisAPI.as_view()),
path("ai/hint", AIHintAPI.as_view()),
path("ai/heatmap", AIHeatmapDataAPI.as_view()),
path("ai/login_summary", AILoginSummaryAPI.as_view()),
]

View File

View File

@@ -1,708 +0,0 @@
from collections import defaultdict
from datetime import datetime, timedelta
import hashlib
import json
from dateutil.relativedelta import relativedelta
from django.core.cache import cache
from django.db.models import Min, Count
from django.db.models.functions import TruncDate
from django.http import StreamingHttpResponse
from django.utils import timezone
from django.utils.dateparse import parse_datetime
from utils.api import APIView
from utils.openai import get_ai_client
from utils.shortcuts import datetime2str
from account.models import User
from problem.models import Problem
from submission.models import Submission, JudgeStatus
from flowchart.models import FlowchartSubmission, FlowchartSubmissionStatus
from account.decorators import login_required
from ai.models import AIAnalysis
CACHE_TIMEOUT = 300
DIFFICULTY_MAP = {"Low": "简单", "Mid": "中等", "High": "困难"}
DEFAULT_CLASS_SIZE = 45
# 评级阈值配置:(百分位上限, 评级)
GRADE_THRESHOLDS = [
(10, "S"), # 前10%: S级 - 卓越
(35, "A"), # 前35%: A级 - 优秀
(75, "B"), # 前75%: B级 - 良好
(100, "C"), # 其余: C级 - 及格
]
# 小规模参与惩罚配置:(最小人数, 等级降级映射)
SMALL_SCALE_PENALTY = {
"threshold": 10,
"downgrade": {"S": "A", "A": "B"},
}
# 等级权重映射(用于加权平均计算)
GRADE_WEIGHTS = {"S": 4, "A": 3, "B": 2, "C": 1}
# 平均等级阈值:(最小权重, 等级)
AVERAGE_GRADE_THRESHOLDS = [(3.5, "S"), (2.5, "A"), (1.5, "B")]
def get_cache_key(prefix, *args):
return hashlib.md5(f"{prefix}:{'_'.join(map(str, args))}".encode()).hexdigest()
def get_difficulty(difficulty):
return DIFFICULTY_MAP.get(difficulty, "中等")
def get_grade(rank, submission_count):
"""
计算题目完成评级
评级标准:
- S级前10%卓越水平10%的人)
- A级前35%优秀水平25%的人)
- B级前75%良好水平40%的人)
- C级75%之后及格水平25%的人)
特殊规则:
- 参与人数少于10人时S级降为A级A级降为B级避免因人少而评级虚高
"""
if not rank or rank <= 0 or submission_count <= 0:
return "C"
percentile = (rank - 1) / submission_count * 100
base_grade = "C"
for threshold, grade in GRADE_THRESHOLDS:
if percentile < threshold:
base_grade = grade
break
if submission_count < SMALL_SCALE_PENALTY["threshold"]:
base_grade = SMALL_SCALE_PENALTY["downgrade"].get(base_grade, base_grade)
return base_grade
def calculate_average_grade(grades):
"""根据等级列表计算加权平均等级"""
scores = [GRADE_WEIGHTS[g] for g in grades if g in GRADE_WEIGHTS]
if not scores:
return ""
avg = sum(scores) / len(scores)
for threshold, grade in AVERAGE_GRADE_THRESHOLDS:
if avg >= threshold:
return grade
return "C"
def find_user_rank(ranking_list, user_id):
"""在排名列表中找到用户的排名1-based未找到返回 None"""
return next(
(idx + 1 for idx, rec in enumerate(ranking_list) if rec["user_id"] == user_id),
None,
)
def get_class_user_ids(user):
if not user.class_name:
return []
cache_key = get_cache_key("class_users", user.class_name)
user_ids = cache.get(cache_key)
if user_ids is None:
user_ids = list(
User.objects.filter(class_name=user.class_name).values_list("id", flat=True)
)
cache.set(cache_key, user_ids, CACHE_TIMEOUT)
return user_ids
def get_user_first_ac_submissions(
user_id, start, end, class_user_ids=None, use_class_scope=False
):
base_qs = Submission.objects.filter(
result=JudgeStatus.ACCEPTED, create_time__gte=start, create_time__lte=end
)
if use_class_scope and class_user_ids:
base_qs = base_qs.filter(user_id__in=class_user_ids)
user_first_ac = list(
base_qs.filter(user_id=user_id)
.values("problem_id")
.annotate(first_ac_time=Min("create_time"))
)
if not user_first_ac:
return [], {}, []
problem_ids = [item["problem_id"] for item in user_first_ac]
ranked_first_ac = list(
base_qs.filter(problem_id__in=problem_ids)
.values("user_id", "problem_id")
.annotate(first_ac_time=Min("create_time"))
)
by_problem = defaultdict(list)
for item in ranked_first_ac:
by_problem[item["problem_id"]].append(item)
for submissions in by_problem.values():
submissions.sort(key=lambda x: (x["first_ac_time"], x["user_id"]))
return user_first_ac, by_problem, problem_ids
def stream_ai_response(client, system_prompt, user_prompt, on_complete=None):
"""SSE 流式响应生成器on_complete(full_text) 在流结束时调用"""
try:
stream = client.chat.completions.create(
model="deepseek-reasoner",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
stream=True,
)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
yield "event: end\n\n"
return
yield "event: start\n\n"
chunks = []
try:
for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
if choice.finish_reason:
if on_complete:
on_complete("".join(chunks).strip())
yield f"data: {json.dumps({'type': 'done'})}\n\n"
break
content = choice.delta.content
if content:
chunks.append(content)
yield f"data: {json.dumps({'type': 'delta', 'content': content})}\n\n"
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
finally:
yield "event: end\n\n"
def make_sse_response(generator):
"""创建 SSE StreamingHttpResponse"""
response = StreamingHttpResponse(
streaming_content=generator,
content_type="text/event-stream",
)
response["Cache-Control"] = "no-cache"
return response
class AIDetailDataAPI(APIView):
@login_required
def get(self, request):
start = request.GET.get("start")
end = request.GET.get("end")
user = request.user
cache_key = get_cache_key(
"ai_detail", user.id, user.class_name or "", start, end
)
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
class_user_ids = get_class_user_ids(user)
use_class_scope = bool(user.class_name) and len(class_user_ids) > 1
user_first_ac, by_problem, problem_ids = get_user_first_ac_submissions(
user.id, start, end, class_user_ids, use_class_scope
)
result = {
"user": user.username,
"class_name": user.class_name,
"start": start,
"end": end,
"solved": [],
"flowcharts": [],
"grade": "",
"tags": {},
"difficulty": {},
"contest_count": 0,
}
if user_first_ac:
problems = {
p.id: p
for p in Problem.objects.filter(id__in=problem_ids)
.select_related("contest")
.prefetch_related("tags")
}
solved, contest_ids = self._build_solved_records(
user_first_ac, by_problem, problems, user.id
)
# 查找 flowchart submissions
flowcharts_query = FlowchartSubmission.objects.filter(
user_id=user,
status=FlowchartSubmissionStatus.COMPLETED,
)
# 添加时间范围过滤
if start:
flowcharts_query = flowcharts_query.filter(create_time__gte=start)
if end:
flowcharts_query = flowcharts_query.filter(create_time__lte=end)
flowcharts = flowcharts_query.select_related("problem").only(
"id",
"create_time",
"ai_score",
"ai_grade",
"problem___id",
"problem__title",
)
# 按problem分组
problem_groups = defaultdict(list)
for flowchart in flowcharts:
problem_id = flowchart.problem._id
problem_groups[problem_id].append(flowchart)
flowcharts_data = []
for problem_id, submissions in problem_groups.items():
if not submissions:
continue
# 获取第一个提交的基本信息
first_submission = submissions[0]
# 计算统计数据
scores = [s.ai_score for s in submissions if s.ai_score is not None]
times = [s.create_time for s in submissions]
# 找到最高分和对应的等级
best_score = max(scores) if scores else 0
best_submission = next(
(s for s in submissions if s.ai_score == best_score), submissions[0]
)
best_grade = best_submission.ai_grade or ""
# 计算平均分
avg_score = sum(scores) / len(scores) if scores else 0
# 最新提交时间
latest_time = max(times) if times else first_submission.create_time
merged_item = {
"problem__id": problem_id,
"problem_title": first_submission.problem.title,
"submission_count": len(submissions),
"best_score": best_score,
"best_grade": best_grade,
"latest_submission_time": latest_time.isoformat() if latest_time else None,
"avg_score": round(avg_score, 0),
}
flowcharts_data.append(merged_item)
# 按最新提交时间排序
flowcharts_data.sort(
key=lambda x: x["latest_submission_time"] or "", reverse=True
)
result.update(
{
"solved": solved,
"flowcharts": flowcharts_data,
"grade": calculate_average_grade([s["grade"] for s in solved]),
"tags": self._calculate_top_tags(problems.values()),
"difficulty": self._calculate_difficulty_distribution(
problems.values()
),
"contest_count": len(set(contest_ids)),
}
)
cache.set(cache_key, result, CACHE_TIMEOUT)
return self.success(result)
def _build_solved_records(self, user_first_ac, by_problem, problems, user_id):
solved, contest_ids = [], []
for item in user_first_ac:
pid = item["problem_id"]
problem = problems.get(pid)
if not problem:
continue
ranking_list = by_problem.get(pid, [])
rank = find_user_rank(ranking_list, user_id)
if problem.contest_id:
contest_ids.append(problem.contest_id)
solved.append(
{
"problem": {
"display_id": problem._id,
"title": problem.title,
"contest_id": problem.contest_id,
"contest_title": getattr(problem.contest, "title", ""),
},
"ac_time": timezone.localtime(item["first_ac_time"]).isoformat(),
"rank": rank,
"ac_count": len(ranking_list),
"grade": get_grade(rank, len(ranking_list)),
"difficulty": get_difficulty(problem.difficulty),
}
)
return sorted(solved, key=lambda x: x["ac_time"]), contest_ids
def _calculate_top_tags(self, problems):
tags_counter = defaultdict(int)
for problem in problems:
for tag in problem.tags.all():
if tag.name:
tags_counter[tag.name] += 1
return dict(sorted(tags_counter.items(), key=lambda x: x[1], reverse=True)[:5])
def _calculate_difficulty_distribution(self, problems):
diff_counter = {"Low": 0, "Mid": 0, "High": 0}
for problem in problems:
diff_counter[
problem.difficulty if problem.difficulty in diff_counter else "Mid"
] += 1
return {
get_difficulty(k): v
for k, v in sorted(diff_counter.items(), key=lambda x: x[1], reverse=True)
}
class AIDurationDataAPI(APIView):
@login_required
def get(self, request):
end_iso = request.GET.get("end")
duration = request.GET.get("duration")
user = request.user
cache_key = get_cache_key(
"ai_duration", user.id, user.class_name or "", end_iso, duration
)
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
class_user_ids = get_class_user_ids(user)
use_class_scope = bool(user.class_name) and len(class_user_ids) > 1
time_config = self._parse_duration(duration)
start = datetime.fromisoformat(end_iso) - time_config["total_delta"]
duration_data = []
for i in range(time_config["show_count"]):
start = start + time_config["delta"]
period_end = start + time_config["delta"]
submission_count = Submission.objects.filter(
user_id=user.id, create_time__gte=start, create_time__lte=period_end
).count()
period_data = {
"unit": time_config["show_unit"],
"index": time_config["show_count"] - 1 - i,
"start": start.isoformat(),
"end": period_end.isoformat(),
"problem_count": 0,
"submission_count": submission_count,
"grade": "",
}
if submission_count > 0:
user_first_ac, by_problem, problem_ids = get_user_first_ac_submissions(
user.id,
start.isoformat(),
period_end.isoformat(),
class_user_ids,
use_class_scope,
)
if user_first_ac:
period_data["problem_count"] = len(problem_ids)
grades = [
get_grade(
find_user_rank(by_problem.get(item["problem_id"], []), user.id),
len(by_problem.get(item["problem_id"], [])),
)
for item in user_first_ac
]
period_data["grade"] = calculate_average_grade(grades)
duration_data.append(period_data)
cache.set(cache_key, duration_data, CACHE_TIMEOUT)
return self.success(duration_data)
def _parse_duration(self, duration):
unit, count = duration.split(":")
count = int(count)
configs = {
("months", 2): {
"show_count": 8,
"show_unit": "weeks",
"total_delta": timedelta(weeks=9),
"delta": timedelta(weeks=1),
},
("months", 6): {
"show_count": 6,
"show_unit": "months",
"total_delta": relativedelta(months=7),
"delta": relativedelta(months=1),
},
("years", 1): {
"show_count": 12,
"show_unit": "months",
"total_delta": relativedelta(months=13),
"delta": relativedelta(months=1),
},
}
return configs.get(
(unit, count),
{
"show_count": 4,
"show_unit": "weeks",
"total_delta": timedelta(weeks=5),
"delta": timedelta(weeks=1),
},
)
class AILoginSummaryAPI(APIView):
@login_required
def get(self, request):
user = request.user
end_time = timezone.now()
start_time = self._resolve_start_time(request, user, end_time)
problems_qs = Problem.objects.filter(
create_time__gte=start_time,
create_time__lte=end_time,
contest_id__isnull=True,
visible=True,
)
new_problem_count = problems_qs.count()
submissions_qs = Submission.objects.filter(
user_id=user.id, create_time__gte=start_time, create_time__lte=end_time
)
submission_count = submissions_qs.count()
accepted_count = submissions_qs.filter(result=JudgeStatus.ACCEPTED).count()
solved_count = (
submissions_qs.filter(result=JudgeStatus.ACCEPTED)
.values("problem_id")
.distinct()
.count()
)
flowchart_submission_count = FlowchartSubmission.objects.filter(
user_id=user.id, create_time__gte=start_time, create_time__lte=end_time
).count()
summary = {
"start": datetime2str(start_time),
"end": datetime2str(end_time),
"new_problem_count": new_problem_count,
"submission_count": submission_count,
"accepted_count": accepted_count,
"solved_count": solved_count,
"flowchart_submission_count": flowchart_submission_count,
}
analysis = ""
analysis_error = ""
if submission_count >= 3:
analysis, analysis_error = self._get_ai_analysis(summary)
data = {"summary": summary, "analysis": analysis}
if analysis_error:
data["analysis_error"] = analysis_error
return self.success(data)
def _resolve_start_time(self, request, user, end_time):
start_raw = request.session.get("prev_login") or request.GET.get("start")
start_time = parse_datetime(start_raw) if start_raw else None
if start_time and timezone.is_naive(start_time):
start_time = timezone.make_aware(
start_time, timezone.get_current_timezone()
)
if not start_time:
if user.last_login and user.last_login < end_time:
start_time = user.last_login
elif user.create_time:
start_time = user.create_time
else:
start_time = end_time - timedelta(days=7)
if start_time >= end_time:
start_time = end_time - timedelta(days=1)
return start_time
def _get_ai_analysis(self, summary):
try:
client = get_ai_client()
except Exception as exc:
return "", str(exc)
system_prompt = (
"你是 OnlineJudge 的学习助教。"
"请根据统计数据给出简短分析(1-2句),再给出一行结论,"
"结论用“结论:”开头。"
)
user_prompt = (
f"时间范围:{summary['start']}{summary['end']}\n"
f"新题目数:{summary['new_problem_count']}\n"
f"提交次数:{summary['submission_count']}\n"
f"AC 次数:{summary['accepted_count']}\n"
f"AC 题目数:{summary['solved_count']}\n"
f"流程图提交数:{summary['flowchart_submission_count']}\n"
)
try:
completion = client.chat.completions.create(
model="deepseek-reasoner",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
)
except Exception as exc:
return "", str(exc)
if not completion.choices:
return "", ""
content = completion.choices[0].message.content or ""
return content.strip(), ""
class AIAnalysisAPI(APIView):
@login_required
def post(self, request):
details = request.data.get("details")
duration = request.data.get("duration")
client = get_ai_client()
system_prompt = "你是一个风趣的编程老师,学生使用判题狗平台进行编程练习。请根据学生提供的详细数据和每周数据,给出用户的学习建议,最后写一句鼓励学生的话。请使用 markdown 格式输出,不要在代码块中输出。"
user_prompt = f"这段时间内的详细数据: {details}\n(其中部分字段含义是 flowcharts:流程图的提交,solved:代码的提交)\n每周或每月的数据: {duration}"
def on_complete(full_text):
AIAnalysis.objects.create(
user=request.user,
provider="deepseek",
model="deepseek-reasoner",
data={"details": details, "duration": duration},
system_prompt=system_prompt,
user_prompt="这段时间内的详细数据,每周或每月的数据。",
analysis=full_text,
)
return make_sse_response(
stream_ai_response(client, system_prompt, user_prompt, on_complete)
)
class AIHintAPI(APIView):
@login_required
def post(self, request):
submission_id = request.data.get("submission_id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = Submission.objects.get(id=submission_id, user_id=request.user.id)
except Submission.DoesNotExist:
return self.error("Submission not found")
problem = submission.problem
client = get_ai_client()
# 获取参考答案(同语言优先,否则取第一个)
answers = problem.answers or []
ref_answer = next(
(a["code"] for a in answers if a["language"] == submission.language),
answers[0]["code"] if answers else "",
)
system_prompt = (
"你是编程助教。你知道题目的参考答案,但【绝对禁止】把参考答案或其中任何代码"
"直接告诉学生,也不能以任何形式暗示完整解法。"
"你的任务是:对照参考答案,找出学生代码中的问题,"
"给出方向性提示(例如:指出哪类边界情况需要考虑、"
"哪个算法思路更合适、哪行代码逻辑可能有问题等)。"
"语气鼓励回复简洁3-5句话使用 Markdown 格式。"
)
user_prompt = (
f"题目:{problem.title}\n"
f"题目描述:{problem.description[:500]}\n"
f"参考答案(仅供你分析,不可透露给学生):\n```\n{ref_answer[:2000]}\n```\n"
f"学生提交语言:{submission.language}\n"
f"判题结果:{submission.result}\n"
f"错误信息:{submission.statistic_info.get('err_info', '')}\n"
f"学生代码:\n```\n{submission.code[:2000]}\n```"
)
return make_sse_response(
stream_ai_response(client, system_prompt, user_prompt)
)
class AIHeatmapDataAPI(APIView):
@login_required
def get(self, request):
user = request.user
cache_key = get_cache_key("ai_heatmap", user.id, user.class_name or "")
cached_result = cache.get(cache_key)
if cached_result:
return self.success(cached_result)
end = datetime.now()
start = end - timedelta(days=365)
# 使用单次查询获取所有数据,按日期分组统计
submission_counts = (
Submission.objects.filter(
user_id=user.id, create_time__gte=start, create_time__lte=end
)
.annotate(date=TruncDate("create_time"))
.values("date")
.annotate(count=Count("id"))
.order_by("date")
)
# 将查询结果转换为字典,便于快速查找
submission_dict = {item["date"]: item["count"] for item in submission_counts}
# 生成365天的热力图数据
heatmap_data = []
current_date = start.date()
for i in range(365):
day_date = current_date + timedelta(days=i)
submission_count = submission_dict.get(day_date, 0)
heatmap_data.append(
{
"timestamp": int(
datetime.combine(day_date, datetime.min.time()).timestamp()
* 1000
),
"value": submission_count,
}
)
cache.set(cache_key, heatmap_data, CACHE_TIMEOUT)
return self.success(heatmap_data)

View File

View File

@@ -1,3 +0,0 @@
from django.contrib import admin
# Register your models here.

View File

@@ -1,7 +0,0 @@
from django.apps import AppConfig
class ClassPkConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'class_pk'
verbose_name = '班级PK'

View File

@@ -1,2 +0,0 @@
# 空文件

View File

@@ -1,4 +0,0 @@
from django.db import models
# 如果需要存储班级PK历史记录可以在这里定义模型
# 目前暂时不需要,因为都是实时计算

View File

@@ -1,3 +0,0 @@
# 如果需要序列化器,可以在这里定义
# 目前使用APIView的paginate_data方法暂时不需要

View File

@@ -1,2 +0,0 @@
# 空文件

View File

@@ -1,10 +0,0 @@
from django.urls import path
from ..views.oj import ClassRankAPI, UserClassRankAPI, ClassPKAPI
urlpatterns = [
path("class_rank", ClassRankAPI.as_view()),
path("user_class_rank", UserClassRankAPI.as_view()),
path("class_pk", ClassPKAPI.as_view()),
]

View File

@@ -1,344 +0,0 @@
import re
import statistics
from datetime import datetime
from django.db.models import Sum, Avg
from django.utils import timezone
from utils.api import APIView
from account.decorators import login_required
from account.models import User, UserProfile, AdminType
from submission.models import Submission, JudgeStatus
class ClassRankAPI(APIView):
"""获取班级排名列表"""
def get(self, request):
# 获取年级参数
grade = int(request.GET.get("grade"))
# 获取所有有用户的班级
classes = (
User.objects.filter(
class_name__isnull=False,
is_disabled=False,
admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
class_name__startswith=str(grade),
)
.values("class_name")
.distinct()
)
class_stats = []
for class_info in classes:
class_name = class_info["class_name"]
users = User.objects.filter(
class_name=class_name,
is_disabled=False,
admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
)
user_ids = list(users.values_list("id", flat=True))
profiles = UserProfile.objects.filter(user_id__in=user_ids)
total_ac = profiles.aggregate(total=Sum("accepted_number"))["total"] or 0
total_submission = (
profiles.aggregate(total=Sum("submission_number"))["total"] or 0
)
avg_ac = profiles.aggregate(avg=Avg("accepted_number"))["avg"] or 0
user_count = users.count()
class_stats.append(
{
"class_name": class_name,
"user_count": user_count,
"total_ac": int(total_ac),
"total_submission": int(total_submission),
"avg_ac": round(avg_ac, 2),
"ac_rate": round(total_ac / total_submission * 100, 2)
if total_submission > 0
else 0,
}
)
# 按总AC数排序
class_stats.sort(key=lambda x: (-x["total_ac"], x["total_submission"]))
# 添加排名
for i, stat in enumerate(class_stats):
stat["rank"] = i + 1
return self.success(class_stats)
class UserClassRankAPI(APIView):
"""获取用户在班级中的排名"""
@login_required
def get(self, request):
user = request.user
if not user.class_name:
return self.error("用户没有班级信息")
scope = request.GET.get("scope", "").lower()
show_all = scope == "all"
try:
limit = int(request.GET.get("limit", "10"))
except ValueError:
limit = 10
if limit <= 0 or limit > 250:
limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
offset = 0
if offset < 0:
offset = 0
# 获取同班所有用户
class_users = User.objects.filter(
class_name=user.class_name,
is_disabled=False,
admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
).select_related("userprofile")
user_ranks = []
for class_user in class_users:
profile = class_user.userprofile
user_ranks.append(
{
"user_id": class_user.id,
"username": class_user.username,
"accepted_number": profile.accepted_number,
"submission_number": profile.submission_number,
}
)
# 按AC数排序
user_ranks.sort(key=lambda x: (-x["accepted_number"], x["submission_number"]))
# 添加排名
my_rank = -1
for i, rank_info in enumerate(user_ranks):
rank_info["rank"] = i + 1
if rank_info["user_id"] == user.id:
my_rank = i + 1
trimmed_ranks = user_ranks
if not show_all and my_rank > 0 and len(user_ranks) > 10:
center_index = my_rank - 1
start = max(0, center_index - 5)
end = start + 10
if end > len(user_ranks):
end = len(user_ranks)
start = max(0, end - 10)
trimmed_ranks = user_ranks[start:end]
elif show_all:
trimmed_ranks = user_ranks[offset : offset + limit]
return self.success(
{
"class_name": user.class_name,
"my_rank": my_rank,
"total": len(user_ranks),
"ranks": trimmed_ranks,
}
)
class ClassPKAPI(APIView):
"""班级PK比较 - 多维度教育评价"""
def post(self, request):
class_names = request.data.get("class_name", [])
if not class_names or len(class_names) < 2:
return self.error("至少需要选择2个班级进行比较")
# 获取时间段参数
start_time = request.data.get("start_time")
end_time = request.data.get("end_time")
# 将时间字符串转换为datetime对象
# 处理空字符串、None 或 undefined 的情况
if start_time and isinstance(start_time, str) and start_time.strip():
try:
start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
if timezone.is_naive(start_time):
start_time = timezone.make_aware(start_time)
except (ValueError, AttributeError):
start_time = None
else:
start_time = None
if end_time and isinstance(end_time, str) and end_time.strip():
try:
end_time = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
if timezone.is_naive(end_time):
end_time = timezone.make_aware(end_time)
except (ValueError, AttributeError):
end_time = None
else:
end_time = None
class_comparisons = []
for class_name in class_names:
users = User.objects.filter(
class_name=class_name,
is_disabled=False,
admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
)
user_ids = list(users.values_list("id", flat=True))
# 获取所有学生的AC数列表用于统计计算
profiles = UserProfile.objects.filter(user_id__in=user_ids)
ac_list = sorted([p.accepted_number for p in profiles], reverse=True)
submission_list = sorted(
[p.submission_number for p in profiles], reverse=True
)
user_count = len(ac_list)
if user_count == 0:
continue
# 基础统计
total_ac = sum(ac_list)
total_submission = sum(submission_list)
avg_ac = statistics.mean(ac_list) if ac_list else 0
# 中位数和分位数
median_ac = statistics.median(ac_list) if ac_list else 0
q1_ac = statistics.quantiles(ac_list, n=4)[0] if len(ac_list) > 1 else 0
q3_ac = statistics.quantiles(ac_list, n=4)[2] if len(ac_list) > 1 else 0
iqr = q3_ac - q1_ac
# 标准差
std_dev = statistics.stdev(ac_list) if len(ac_list) > 1 else 0
# 前10名和后10名统计
top_10_count = min(10, user_count)
bottom_10_count = min(10, user_count)
top_10_avg = (
statistics.mean(ac_list[:top_10_count]) if top_10_count > 0 else 0
)
bottom_10_avg = (
statistics.mean(ac_list[-bottom_10_count:])
if bottom_10_count > 0
else 0
)
# 前25%和后25%统计
top_25_count = max(1, user_count // 4)
bottom_25_count = max(1, user_count // 4)
top_25_avg = (
statistics.mean(ac_list[:top_25_count]) if top_25_count > 0 else 0
)
bottom_25_avg = (
statistics.mean(ac_list[-bottom_25_count:])
if bottom_25_count > 0
else 0
)
# 优秀率AC数 >= 中位数 + 标准差)
# 使用中位数+标准差方法,既不受极端值影响,又能反映班级差异
excellent_threshold = (
median_ac + std_dev if std_dev > 0 else median_ac * 1.5
)
excellent_count = sum(1 for ac in ac_list if ac >= excellent_threshold)
excellent_rate = (
(excellent_count / user_count * 100) if user_count > 0 else 0
)
# 及格率AC数 >= 平均值的0.5倍)
pass_threshold = avg_ac * 0.5
pass_count = sum(1 for ac in ac_list if ac >= pass_threshold)
pass_rate = (pass_count / user_count * 100) if user_count > 0 else 0
# 参与度(有提交记录的学生比例)
active_count = sum(1 for sub in submission_list if sub > 0)
active_rate = (active_count / user_count * 100) if user_count > 0 else 0
# 时间段内的统计(如果提供了时间段)
recent_stats = {}
if start_time and end_time:
submissions = Submission.objects.filter(
user_id__in=user_ids,
create_time__gte=start_time,
create_time__lte=end_time,
)
recent_ac = (
submissions.filter(result=JudgeStatus.ACCEPTED)
.values("user_id", "problem_id")
.distinct()
.count()
)
recent_submission = submissions.count()
# 时间段内的用户AC数列表
recent_user_ac = {}
for user_id in user_ids:
user_recent_ac = (
submissions.filter(user_id=user_id, result=JudgeStatus.ACCEPTED)
.values("problem_id")
.distinct()
.count()
)
recent_user_ac[user_id] = user_recent_ac
recent_ac_list = sorted(recent_user_ac.values(), reverse=True)
if recent_ac_list:
recent_stats = {
"recent_total_ac": recent_ac,
"recent_total_submission": recent_submission,
"recent_avg_ac": statistics.mean(recent_ac_list),
"recent_median_ac": statistics.median(recent_ac_list),
"recent_top_10_avg": statistics.mean(
recent_ac_list[: min(10, len(recent_ac_list))]
)
if recent_ac_list
else 0,
"recent_active_count": sum(
1 for ac in recent_ac_list if ac > 0
),
}
class_comparisons.append(
{
"class_name": class_name,
"user_count": user_count,
# 基础统计
"total_ac": int(total_ac),
"total_submission": int(total_submission),
"avg_ac": round(avg_ac, 2),
# 中位数和分位数
"median_ac": round(median_ac, 2),
"q1_ac": round(q1_ac, 2),
"q3_ac": round(q3_ac, 2),
"iqr": round(iqr, 2),
# 标准差
"std_dev": round(std_dev, 2),
# 分层统计
"top_10_avg": round(top_10_avg, 2),
"bottom_10_avg": round(bottom_10_avg, 2),
"top_25_avg": round(top_25_avg, 2),
"bottom_25_avg": round(bottom_25_avg, 2),
# 比率统计
"excellent_rate": round(excellent_rate, 2),
"pass_rate": round(pass_rate, 2),
"active_rate": round(active_rate, 2),
# 正确率
"ac_rate": round(total_ac / total_submission * 100, 2)
if total_submission > 0
else 0,
# 时间段统计(如果有)
**recent_stats,
}
)
# 按总AC数排序
class_comparisons.sort(key=lambda x: (-x["total_ac"], x["total_submission"]))
return self.success(
{
"comparisons": class_comparisons,
"has_time_range": bool(start_time and end_time),
}
)

View File

@@ -1,97 +0,0 @@
"""
WebSocket consumers for configuration updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class ConfigConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time configuration updates
当管理员修改配置后,通过 WebSocket 实时推送配置变化
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用全局配置组名,所有用户都能接收配置更新
self.group_name = "config_updates"
# 加入配置更新组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"Config WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"Config WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或配置更新请求
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "config_update":
# 处理配置更新请求
key = data.get("key")
value = data.get("value")
if key and value is not None:
logger.info(f"User {self.user.id} requested config update: {key}={value}")
# 这里可以添加权限检查,只有管理员才能发送配置更新
if self.user.is_superuser:
# 广播配置更新给所有连接的客户端
await self.channel_layer.group_send(
self.group_name,
{
"type": "config_update",
"data": {
"type": "config_update",
"key": key,
"value": value
}
}
)
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def config_update(self, event):
"""
接收来自 channel layer 的配置更新消息并发送给客户端
这个方法名对应 group_send 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent config update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending config update to user {self.user.id}: {str(e)}")

View File

@@ -27,7 +27,6 @@ class CreateEditWebsiteConfigSerializer(serializers.Serializer):
allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField()
class_list = serializers.ListField(child=serializers.CharField(max_length=64))
enable_maxkb = serializers.BooleanField()
class JudgeServerSerializer(serializers.ModelSerializer):

View File

@@ -1,12 +1,6 @@
from django.urls import path
from ..views import (
HitokotoAPI,
JudgeServerHeartbeatAPI,
LanguagesAPI,
WebsiteConfigAPI,
ClassUsernamesAPI,
)
from ..views import HitokotoAPI, JudgeServerHeartbeatAPI, LanguagesAPI, WebsiteConfigAPI
urlpatterns = [
path("website", WebsiteConfigAPI.as_view()),
@@ -14,5 +8,4 @@ urlpatterns = [
path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()),
path("languages", LanguagesAPI.as_view()),
path("hitokoto", HitokotoAPI.as_view()),
path("class_usernames", ClassUsernamesAPI.as_view()),
]

View File

@@ -24,7 +24,6 @@ from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.cache import JsonDataLoader
from utils.shortcuts import send_email, get_env
from utils.xss_filter import XSSHtml
from utils.websocket import push_config_update
from .models import JudgeServer
from .serializers import (
CreateEditWebsiteConfigSerializer,
@@ -108,7 +107,6 @@ class WebsiteConfigAPI(APIView):
"allow_register",
"submission_list_show_all",
"class_list",
"enable_maxkb",
]
}
return self.success(ret)
@@ -121,10 +119,6 @@ class WebsiteConfigAPI(APIView):
with XSSHtml() as parser:
v = parser.clean(v)
setattr(SysOptions, k, v)
# 推送配置更新到所有连接的客户端
push_config_update(k, v)
return self.success()
@@ -210,6 +204,7 @@ class LanguagesAPI(APIView):
return self.success(
{
"languages": SysOptions.languages,
"spj_languages": SysOptions.spj_languages,
}
)
@@ -315,32 +310,8 @@ class RandomUsernameAPI(APIView):
class HitokotoAPI(APIView):
def get(self, request):
try:
categories = JsonDataLoader.load_data(
settings.HITOKOTO_DIR, "categories.json"
)
path = random.choice(categories).get("path")
sentences = JsonDataLoader.load_data(settings.HITOKOTO_DIR, path)
sentence = random.choice(sentences)
return self.success(sentence)
except Exception:
return self.error("获取一言失败,请稍后再试")
class ClassUsernamesAPI(APIView):
def get(self, request):
classroom = request.GET.get("classroom", "")
if not classroom:
return self.error("需要班级号")
users = User.objects.filter(class_name=classroom).order_by("-create_time")
names = []
for user in users:
prefix = f"ks{classroom}"
result = (
user.username[len(prefix) :]
if user.username.startswith(prefix)
else user.username
)
names.append(result)
return self.success(names)
categories = JsonDataLoader.load_data(settings.HITOKOTO_DIR, "categories.json")
path = random.choice(categories).get("path")
sentences = JsonDataLoader.load_data(settings.HITOKOTO_DIR, path)
sentence = random.choice(sentences)
return self.success(sentence)

View File

@@ -46,13 +46,10 @@ class Contest(models.Model):
# 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
def problem_details_permission(self, user):
return (
self.rule_type == ContestRuleType.ACM
or self.status == ContestStatus.CONTEST_ENDED
or user.is_authenticated
and user.is_contest_admin(self)
or self.real_time_rank
)
return self.rule_type == ContestRuleType.ACM or \
self.status == ContestStatus.CONTEST_ENDED or \
user.is_authenticated and user.is_contest_admin(self) or \
self.real_time_rank
class Meta:
db_table = "contest"

View File

@@ -6,9 +6,7 @@ from ipaddress import ip_network
import dateutil.parser
from django.http import FileResponse
from problem.models import Problem
from account.decorators import super_admin_required
from account.decorators import check_contest_permission, ensure_created_by
from account.models import User
from submission.models import Submission, JudgeStatus
from utils.api import APIView, validate_serializer
@@ -17,20 +15,14 @@ from utils.constants import CacheKey
from utils.shortcuts import rand_str
from utils.tasks import delete_files
from ..models import Contest, ContestAnnouncement, ACMContestRank
from ..serializers import (
ContestAnnouncementSerializer,
ContestAdminSerializer,
CreateConetestSeriaizer,
CreateContestAnnouncementSerializer,
EditConetestSeriaizer,
EditContestAnnouncementSerializer,
ACMContesHelperSerializer,
)
from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer,
CreateConetestSeriaizer, CreateContestAnnouncementSerializer,
EditConetestSeriaizer, EditContestAnnouncementSerializer,
ACMContesHelperSerializer, )
class ContestAPI(APIView):
@validate_serializer(CreateConetestSeriaizer)
@super_admin_required
def post(self, request):
data = request.data
data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -49,11 +41,11 @@ class ContestAPI(APIView):
return self.success(ContestAdminSerializer(contest).data)
@validate_serializer(EditConetestSeriaizer)
@super_admin_required
def put(self, request):
data = request.data
try:
contest = Contest.objects.get(id=data.pop("id"))
ensure_created_by(contest, request.user)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -76,29 +68,28 @@ class ContestAPI(APIView):
contest.save()
return self.success(ContestAdminSerializer(contest).data)
@super_admin_required
def get(self, request):
contest_id = request.GET.get("id")
if contest_id:
try:
contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
return self.success(ContestAdminSerializer(contest).data)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
contests = Contest.objects.all().order_by("-create_time")
if request.user.is_admin():
contests = contests.filter(created_by=request.user)
keyword = request.GET.get("keyword")
if keyword:
contests = contests.filter(title__contains=keyword)
return self.success(
self.paginate_data(request, contests, ContestAdminSerializer)
)
return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
class ContestAnnouncementAPI(APIView):
@validate_serializer(CreateContestAnnouncementSerializer)
@super_admin_required
def post(self, request):
"""
Create one contest_announcement.
@@ -106,6 +97,7 @@ class ContestAnnouncementAPI(APIView):
data = request.data
try:
contest = Contest.objects.get(id=data.pop("contest_id"))
ensure_created_by(contest, request.user)
data["contest"] = contest
data["created_by"] = request.user
except Contest.DoesNotExist:
@@ -114,7 +106,6 @@ class ContestAnnouncementAPI(APIView):
return self.success(ContestAnnouncementSerializer(announcement).data)
@validate_serializer(EditContestAnnouncementSerializer)
@super_admin_required
def put(self, request):
"""
update contest_announcement
@@ -122,6 +113,7 @@ class ContestAnnouncementAPI(APIView):
data = request.data
try:
contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id"))
ensure_created_by(contest_announcement, request.user)
except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist")
for k, v in data.items():
@@ -129,17 +121,19 @@ class ContestAnnouncementAPI(APIView):
contest_announcement.save()
return self.success()
@super_admin_required
def delete(self, request):
"""
Delete one contest_announcement.
"""
contest_announcement_id = request.GET.get("id")
if contest_announcement_id:
ContestAnnouncement.objects.filter(id=contest_announcement_id).delete()
if request.user.is_admin():
ContestAnnouncement.objects.filter(id=contest_announcement_id,
contest__created_by=request.user).delete()
else:
ContestAnnouncement.objects.filter(id=contest_announcement_id).delete()
return self.success()
@super_admin_required
def get(self, request):
"""
Get one contest_announcement or contest_announcement list.
@@ -147,71 +141,45 @@ class ContestAnnouncementAPI(APIView):
contest_announcement_id = request.GET.get("id")
if contest_announcement_id:
try:
contest_announcement = ContestAnnouncement.objects.get(
id=contest_announcement_id
)
return self.success(
ContestAnnouncementSerializer(contest_announcement).data
)
contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id)
ensure_created_by(contest_announcement, request.user)
return self.success(ContestAnnouncementSerializer(contest_announcement).data)
except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist")
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter error")
contest_announcements = ContestAnnouncement.objects.filter(
contest_id=contest_id
)
contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id)
if request.user.is_admin():
contest_announcements = contest_announcements.filter(created_by=request.user)
keyword = request.GET.get("keyword")
if keyword:
contest_announcements = contest_announcements.filter(
title__contains=keyword
)
return self.success(
ContestAnnouncementSerializer(contest_announcements, many=True).data
)
contest_announcements = contest_announcements.filter(title__contains=keyword)
return self.success(ContestAnnouncementSerializer(contest_announcements, many=True).data)
class ACMContestHelper(APIView):
@super_admin_required
@check_contest_permission(check_type="ranks")
def get(self, request):
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter error, contest_id is required")
try:
contest = Contest.objects.get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
problems = Problem.objects.filter(contest=contest).values("id", "_id")
problem_id_map = {str(p["id"]): p["_id"] for p in problems}
ranks = ACMContestRank.objects.filter(
contest=contest, accepted_number__gt=0
).values(
"id", "user__username", "user__userprofile__real_name", "submission_info"
)
ranks = ACMContestRank.objects.filter(contest=self.contest, accepted_number__gt=0) \
.values("id", "user__username", "user__userprofile__real_name", "submission_info")
results = []
for rank in ranks:
for problem_id, info in rank["submission_info"].items():
if info["is_ac"]:
results.append(
{
"id": rank["id"],
"username": rank["user__username"],
"real_name": rank["user__userprofile__real_name"],
"problem_id": problem_id,
"problem_display_id": problem_id_map.get(
problem_id, problem_id
),
"ac_info": info,
"checked": info.get("checked", False),
}
)
results.append({
"id": rank["id"],
"username": rank["user__username"],
"real_name": rank["user__userprofile__real_name"],
"problem_id": problem_id,
"ac_info": info,
"checked": info.get("checked", False)
})
results.sort(key=lambda x: -x["ac_info"]["ac_time"])
return self.success(results)
@super_admin_required
@check_contest_permission(check_type="ranks")
@validate_serializer(ACMContesHelperSerializer)
def put(self, request):
data = request.data
@@ -232,9 +200,7 @@ class DownloadContestSubmissions(APIView):
problem_ids = contest.problem_set.all().values_list("id", "_id")
id2display_id = {k[0]: k[1] for k in problem_ids}
ac_map = {k[0]: False for k in problem_ids}
submissions = Submission.objects.filter(
contest=contest, result=JudgeStatus.ACCEPTED
).order_by("-create_time")
submissions = Submission.objects.filter(contest=contest, result=JudgeStatus.ACCEPTED).order_by("-create_time")
user_ids = submissions.values_list("user_id", flat=True)
users = User.objects.filter(id__in=user_ids)
path = f"/tmp/{rand_str()}.zip"
@@ -248,25 +214,21 @@ class DownloadContestSubmissions(APIView):
problem_id = submission.problem_id
if user_ac_map[problem_id]:
continue
file_name = (
f"{user.username}_{id2display_id[submission.problem_id]}.txt"
)
file_name = f"{user.username}_{id2display_id[submission.problem_id]}.txt"
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(
zinfo_or_arcname=f"{file_name}",
data=submission.code,
compress_type=compression,
)
zip_file.writestr(zinfo_or_arcname=f"{file_name}",
data=submission.code,
compress_type=compression)
user_ac_map[problem_id] = True
return path
@super_admin_required
def get(self, request):
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter error")
try:
contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
@@ -275,7 +237,5 @@ class DownloadContestSubmissions(APIView):
delete_files.send_with_options(args=(zip_path,), delay=300_000)
resp = FileResponse(open(zip_path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = (
f"attachment;filename={os.path.basename(zip_path)}"
)
resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}"
return resp

View File

@@ -2,23 +2,6 @@ location /public {
root /data;
}
# WebSocket 支持
location /ws/ {
proxy_pass http://websocket;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $http_host;
proxy_set_header X-Real-IP __IP_HEADER__;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# WebSocket 超时设置
proxy_connect_timeout 7d;
proxy_send_timeout 7d;
proxy_read_timeout 7d;
}
location /api {
include api_proxy.conf;
}

View File

@@ -38,11 +38,6 @@ http {
keepalive 32;
}
upstream websocket {
server 127.0.0.1:8001;
keepalive 32;
}
add_header X-XSS-Protection "1; mode=block" always;
add_header X-Frame-Options SAMEORIGIN always;
add_header X-Content-Type-Options nosniff always;
@@ -51,7 +46,7 @@ http {
listen 8000 default_server;
server_name _;
include locations.conf;
include http_locations.conf;
}
# server {

View File

@@ -1,76 +1,31 @@
annotated-types==0.7.0
anyio==4.12.0
asgiref==3.11.0
attrs==25.4.0
autobahn==25.12.2
automat==25.4.16
cbor2==5.7.1
certifi==2025.11.12
cffi==2.0.0
channels==4.3.2
channels-redis==4.3.0
charset-normalizer==3.4.4
constantly==23.10.4
coverage==6.5.0
cryptography==46.0.3
daphne==4.2.1
distro==1.9.0
django==6.0
django-cas-ng==5.0.1
asgiref==3.8.1
certifi==2025.6.15
charset-normalizer==3.4.2
click==8.2.1
django==5.2.3
django-dbconn-retry==0.1.8
django-dramatiq==0.13.0
django-redis==5.4.0
djangorestframework==3.16.0
dramatiq==1.17.0
entrypoints==0.4
dramatiq==1.18.0
envelopes==0.4
flake8==7.0.0
flake8-coding==1.3.2
flake8-quotes==3.3.2
gunicorn==22.0.0
gunicorn==23.0.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
hyperlink==21.0.0
idna==3.11
incremental==24.11.0
jiter==0.12.0
jsonfield==3.1.0
lxml==6.0.2
mccabe==0.7.0
msgpack==1.1.2
openai==2.14.0
otpauth==1.0.1
idna==3.10
otpauth==2.2.1
packaging==25.0
pillow==10.2.0
prometheus-client==0.23.1
pillow==11.2.1
prometheus-client==0.22.1
psycopg==3.2.9
psycopg-binary==3.2.9
py-ubjson==0.16.1
pyasn1==0.6.1
pyasn1-modules==0.4.2
pycodestyle==2.11.1
pycparser==2.23
pydantic==2.12.5
pydantic-core==2.41.5
pyflakes==3.2.0
pyopenssl==25.3.0
python-cas==1.7.1
python-dateutil==2.8.2
python-dateutil==2.9.0.post0
qrcode==8.2
raven==6.10.0
redis==7.1.0
requests==2.32.5
service-identity==24.2.0
redis==6.2.0
requests==2.32.4
six==1.17.0
sniffio==1.3.1
sqlparse==0.5.5
tqdm==4.67.1
twisted==25.5.0
txaio==25.12.2
typing-extensions==4.15.0
typing-inspection==0.4.2
ujson==5.11.0
urllib3==2.6.2
xlsxwriter==3.2.0
zope-interface==8.1.1
sqlparse==0.5.3
typing-extensions==4.14.0
urllib3==2.4.0
uvicorn==0.35.0
xlsxwriter==3.2.5

View File

@@ -28,7 +28,7 @@ stopwaitsecs = 5
killasgroup=true
[program:gunicorn]
command=gunicorn oj.wsgi --user server --group spj --bind 127.0.0.1:8080 --workers %(ENV_MAX_WORKER_NUM)s --threads 4 --max-requests-jitter 10000 --max-requests 1000000 --keep-alive 32
command=gunicorn oj.asgi --user server --group spj --bind 127.0.0.1:8080 --workers %(ENV_MAX_WORKER_NUM)s --threads 4 --max-requests-jitter 10000 --max-requests 1000000 --keep-alive 32 --worker-class uvicorn.workers.UvicornWorker
directory=/app/
stdout_logfile=/data/log/gunicorn.log
stderr_logfile=/data/log/gunicorn.log
@@ -38,18 +38,6 @@ startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:daphne]
command=daphne -b 127.0.0.1 -p 8001 --access-log /data/log/daphne_access.log oj.asgi:application
directory=/app/
user=server
stdout_logfile=/data/log/daphne.log
stderr_logfile=/data/log/daphne.log
autostart=true
autorestart=true
startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:dramatiq]
command=python3 manage.py rundramatiq --processes %(ENV_MAX_WORKER_NUM)s --threads 4
directory=/app/

173
dev.py
View File

@@ -1,173 +0,0 @@
#!/usr/bin/env python
"""
WebSocket 开发服务器启动脚本
同时启动 Daphne (WebSocket) 和 Django runserver (开发服务器)
支持 Windows 和 Linux
"""
import os
import sys
import subprocess
import platform
import signal
from pathlib import Path
from threading import Thread
import time
def main():
# 获取项目根目录
base_dir = Path(__file__).resolve().parent
os.chdir(base_dir)
print("=" * 70)
print("启动 Django 开发服务器 + WebSocket 服务器")
print("=" * 70)
print()
# 检测操作系统
is_windows = platform.system() == "Windows"
# 检查虚拟环境(跨平台)
if is_windows:
# Windows: .venv/Scripts/python.exe
venv_python = base_dir / ".venv" / "Scripts" / "python.exe"
else:
# Linux/Mac: .venv/bin/python
venv_python = base_dir / ".venv" / "bin" / "python"
if venv_python.exists():
print("[✓] 使用虚拟环境: .venv")
python_exec = str(venv_python)
else:
print("[!] 未找到 .venv 虚拟环境,使用全局 Python")
print("[!] 建议创建虚拟环境: python -m venv .venv")
python_exec = sys.executable
# 检查 daphne 是否安装
try:
result = subprocess.run(
[python_exec, "-m", "daphne", "--version"], capture_output=True, text=True
)
if result.returncode != 0 and result.returncode != 2:
print("[✗] 错误: Daphne 未安装")
print("请运行: pip install daphne channels channels-redis")
sys.exit(1)
except FileNotFoundError:
print("[✗] 错误: 无法找到 Python 解释器")
sys.exit(1)
# 进程列表
processes = []
# 启动两个服务器
try:
# 启动 Django runserver (端口 8000)
print("[*] 启动 Django 开发服务器 (端口 8000)...")
runserver_cmd = ["uv", "run", "manage.py", "runserver", "0.0.0.0:8000"]
runserver_process = subprocess.Popen(
runserver_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
)
processes.append(("Django Runserver", runserver_process))
# 等待一下,让 runserver 先启动
time.sleep(1)
# 启动 Daphne (端口 8001)
print("[*] 启动 Daphne WebSocket 服务器 (端口 8001)...")
daphne_cmd = [
python_exec,
"-m",
"daphne",
"-b",
"0.0.0.0",
"-p",
"8001",
"oj.asgi:application",
]
daphne_process = subprocess.Popen(
daphne_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
)
processes.append(("Daphne", daphne_process))
print()
print("[✓] 所有服务器已启动")
print()
# 创建输出线程
def print_output(name, process):
"""打印进程输出"""
for line in process.stdout:
print(f"[{name}] {line}", end="")
# 启动输出线程
threads = []
for name, process in processes:
thread = Thread(target=print_output, args=(name, process), daemon=True)
thread.start()
threads.append(thread)
# 等待进程(任意一个退出就退出)
while True:
for name, process in processes:
if process.poll() is not None:
print(f"\n[!] {name} 已退出")
raise KeyboardInterrupt
time.sleep(0.5)
except KeyboardInterrupt:
print()
print()
print("[*] 正在停止所有服务器...")
# 终止所有进程
for name, process in processes:
try:
if process.poll() is None: # 如果进程还在运行
print(f"[*] 停止 {name}...")
if is_windows:
# Windows 使用 CTRL_C_EVENT
process.send_signal(signal.CTRL_C_EVENT)
else:
# Unix 使用 SIGTERM
process.terminate()
# 等待进程结束(最多 5 秒)
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(f"[!] {name} 未响应,强制终止...")
process.kill()
process.wait()
except Exception as e:
print(f"[!] 停止 {name} 时出错: {e}")
print()
print("[✓] 所有服务器已停止")
except Exception as e:
print(f"[✗] 错误: {e}")
# 清理所有进程
for name, process in processes:
try:
if process.poll() is None:
process.kill()
process.wait()
except Exception:
pass
sys.exit(1)
if __name__ == "__main__":
main()

View File

View File

View File

@@ -1,7 +0,0 @@
from django.apps import AppConfig
class FlowchartConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'flowchart'
verbose_name = '流程图管理'

View File

@@ -1,83 +0,0 @@
"""
WebSocket consumers for flowchart evaluation updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class FlowchartConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time flowchart evaluation updates
当用户提交流程图后,通过 WebSocket 实时接收AI评分状态更新
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用用户 ID 作为组名,这样可以向特定用户推送消息
self.group_name = f"flowchart_user_{self.user.id}"
# 加入用户专属的组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"Flowchart WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"Flowchart WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或订阅特定流程图提交
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "subscribe":
# 订阅特定流程图提交的更新
submission_id = data.get("submission_id")
if submission_id:
logger.info(f"User {self.user.id} subscribed to flowchart submission {submission_id}")
# 可以在这里做额外的订阅逻辑
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def flowchart_evaluation_update(self, event):
"""
接收来自 channel layer 的流程图评分更新消息并发送给客户端
这个方法名对应 push_flowchart_evaluation_update 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent flowchart evaluation update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending flowchart evaluation update to user {self.user.id}: {str(e)}")

View File

@@ -1,45 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
import django.db.models.deletion
import utils.shortcuts
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='FlowchartSubmission',
fields=[
('id', models.TextField(db_index=True, default=utils.shortcuts.rand_str, primary_key=True, serialize=False)),
('mermaid_code', models.TextField()),
('flowchart_data', models.JSONField(default=dict)),
('status', models.IntegerField(default=0)),
('create_time', models.DateTimeField(auto_now_add=True)),
('ai_score', models.FloatField(blank=True, null=True)),
('ai_grade', models.CharField(blank=True, max_length=10, null=True)),
('ai_feedback', models.TextField(blank=True, null=True)),
('ai_suggestions', models.TextField(blank=True, null=True)),
('ai_criteria_details', models.JSONField(default=dict)),
('ai_provider', models.CharField(default='deepseek', max_length=50)),
('ai_model', models.CharField(default='deepseek-chat', max_length=50)),
('processing_time', models.FloatField(blank=True, null=True)),
('evaluation_time', models.DateTimeField(blank=True, null=True)),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to='problem.problem')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='flowchart_submissions', to=settings.AUTH_USER_MODEL)),
],
options={
'db_table': 'flowchart_submission',
'ordering': ['-create_time'],
'indexes': [models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'), models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'), models.Index(fields=['status'], name='flowchart_status_idx')],
},
),
]

View File

@@ -1,65 +0,0 @@
from django.db import models
from django.contrib.auth import get_user_model
from utils.shortcuts import rand_str
from problem.models import Problem
User = get_user_model()
class FlowchartSubmissionStatus:
PENDING = 0 # 等待AI评分
PROCESSING = 1 # AI评分中
COMPLETED = 2 # 评分完成
FAILED = 3 # 评分失败
class FlowchartSubmission(models.Model):
"""流程图提交模型"""
id = models.TextField(default=rand_str, primary_key=True, db_index=True)
# 基础信息
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='flowchart_submissions')
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, related_name='flowchart_submissions')
# 提交内容
mermaid_code = models.TextField() # Mermaid代码
flowchart_data = models.JSONField(default=dict) # 流程图元数据
# 状态信息
status = models.IntegerField(default=FlowchartSubmissionStatus.PENDING)
create_time = models.DateTimeField(auto_now_add=True)
# AI评分结果
ai_score = models.FloatField(null=True, blank=True) # AI评分 (0-100)
ai_grade = models.CharField(max_length=10, null=True, blank=True) # 等级 (S/A/B/C)
ai_feedback = models.TextField(null=True, blank=True) # AI反馈
ai_suggestions = models.TextField(null=True, blank=True) # AI建议
ai_criteria_details = models.JSONField(default=dict) # 详细评分标准
# 处理信息
ai_provider = models.CharField(max_length=50, default='deepseek')
ai_model = models.CharField(max_length=50, default='deepseek-chat')
processing_time = models.FloatField(null=True, blank=True) # AI处理耗时(秒)
evaluation_time = models.DateTimeField(null=True, blank=True) # 评分完成时间
class Meta:
db_table = 'flowchart_submission'
ordering = ['-create_time']
indexes = [
models.Index(fields=['user', 'create_time'], name='flowchart_user_time_idx'),
models.Index(fields=['problem', 'create_time'], name='flowchart_problem_time_idx'),
models.Index(fields=['status'], name='flowchart_status_idx'),
]
def __str__(self):
return f"FlowchartSubmission {self.id}"
def check_user_permission(self, user, check_share=True):
"""检查用户权限"""
if (
self.user_id == user.id
or not user.is_regular_user()
or self.problem.created_by_id == user.id
):
return True
return False

View File

@@ -1,91 +0,0 @@
from rest_framework import serializers
from .models import FlowchartSubmission
class CreateFlowchartSubmissionSerializer(serializers.Serializer):
problem_id = serializers.IntegerField()
mermaid_code = serializers.CharField()
flowchart_data = serializers.JSONField(required=False, default=dict)
def validate_mermaid_code(self, value):
if not value.strip():
raise serializers.ValidationError("Mermaid代码不能为空")
return value
class FlowchartSubmissionSerializer(serializers.ModelSerializer):
class Meta:
model = FlowchartSubmission
fields = [
"id",
"user",
"problem",
"mermaid_code",
"flowchart_data",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_feedback",
"ai_suggestions",
"ai_criteria_details",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]
read_only_fields = ["id", "create_time", "evaluation_time"]
class FlowchartSubmissionListSerializer(serializers.ModelSerializer):
"""用于列表显示的简化序列化器"""
username = serializers.CharField(source="user.username")
problem = serializers.CharField(source="problem._id")
problem_title = serializers.CharField(source="problem.title")
class Meta:
model = FlowchartSubmission
fields = [
"id",
"username",
"problem_title",
"problem",
"status",
"create_time",
"ai_score",
"ai_grade",
"ai_provider",
"ai_model",
"processing_time",
"evaluation_time",
]
class FlowchartSubmissionSummarySerializer(serializers.ModelSerializer):
"""用于AI详情页面的极简序列化器只包含必要字段"""
problem_title = serializers.CharField(source="problem.title")
problem__id = serializers.CharField(source="problem._id")
class Meta:
model = FlowchartSubmission
fields = [
"id",
"problem__id",
"problem_title",
"ai_score",
"ai_grade",
"create_time",
]
class FlowchartSubmissionMergedSerializer(serializers.Serializer):
"""合并后的流程图提交序列化器"""
problem__id = serializers.CharField()
problem_title = serializers.CharField()
submission_count = serializers.IntegerField()
best_score = serializers.FloatField()
best_grade = serializers.CharField()
latest_submission_time = serializers.DateTimeField()
avg_score = serializers.FloatField()

View File

@@ -1,180 +0,0 @@
import dramatiq
import json
import time
from django.db import transaction
from django.utils import timezone
from utils.openai import get_ai_client
from utils.shortcuts import DRAMATIQ_WORKER_ARGS
from .models import FlowchartSubmission, FlowchartSubmissionStatus
@dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3))
def evaluate_flowchart_task(submission_id):
"""异步AI评分任务"""
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
# 更新状态为处理中
submission.status = FlowchartSubmissionStatus.PROCESSING
submission.save()
start_time = time.time()
# 使用固定评分标准
system_prompt = build_evaluation_prompt(submission.problem)
# 构建用户提示词,包含标准答案对比
user_prompt = f"""
请对以下Mermaid流程图进行评分
学生提交的流程图:
```mermaid
{submission.mermaid_code}
```
标准答案参考:
```mermaid
{submission.problem.mermaid_code}
```
"""
# 如果有流程图提示,添加到提示词中
if submission.problem.flowchart_hint:
user_prompt += f"""
设计提示:{submission.problem.flowchart_hint}
"""
user_prompt += """
请按照评分标准进行详细评估并给出0-100的分数。
"""
# 调用AI进行评分
client = get_ai_client()
response = client.chat.completions.create(
model="deepseek-reasoner",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3,
)
ai_response = response.choices[0].message.content
score_data = parse_ai_evaluation_response(ai_response)
processing_time = time.time() - start_time
# 保存评分结果
with transaction.atomic():
submission.ai_score = score_data['score']
submission.ai_grade = score_data['grade']
submission.ai_feedback = score_data['feedback']
submission.ai_suggestions = score_data.get('suggestions', '')
submission.ai_criteria_details = score_data.get('criteria_details', {})
submission.ai_provider = 'deepseek'
submission.ai_model = 'deepseek-reasoner'
submission.processing_time = processing_time
submission.status = FlowchartSubmissionStatus.COMPLETED
submission.evaluation_time = timezone.now()
submission.save()
# 推送评分完成通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_completed",
"score": score_data['score'],
"grade": score_data['grade'],
}
)
except Exception as e:
# 处理失败
submission.status = FlowchartSubmissionStatus.FAILED
submission.save()
# 推送错误通知
from utils.websocket import push_flowchart_evaluation_update
push_flowchart_evaluation_update(
submission_id=str(submission.id),
user_id=submission.user_id,
data={
"type": "flowchart_evaluation_failed",
"submission_id": str(submission.id),
"error": str(e)
}
)
raise e
def build_evaluation_prompt(problem):
"""构建AI评分提示词 - 使用固定标准"""
# 使用固定的评分标准
criteria_text = """
- 逻辑正确性 (权重: 1.0, 最高分: 40): 检查流程图的逻辑是否正确,包括条件判断、循环结构等
- 完整性 (权重: 0.8, 最高分: 30): 检查流程图是否包含所有必要的步骤和分支
- 规范性 (权重: 0.6, 最高分: 20): 检查流程图符号使用是否规范,是否符合标准
- 清晰度 (权重: 0.4, 最高分: 10): 评估流程图的整体布局和连线情况(不用考虑节点ID是否复杂)
"""
return f"""
你是一个专业的编程教学助手负责评估学生提交的Mermaid流程图。
评分标准:
{criteria_text}
评分要求:
1. 仔细分析流程图的逻辑正确性、完整性和清晰度
2. 检查是否涵盖了题目的所有要求
3. 评估流程图的规范性和可读性(不用考虑节点ID是否复杂)
4. 给出0-100的分数
5. 提供详细的反馈和改进建议
评分等级:
- S级 (90-100分): 优秀,逻辑清晰,完全符合要求
- A级 (80-89分): 良好,基本符合要求,有少量改进空间
- B级 (70-79分): 及格,基本正确但存在一些问题
- C级 (0-69分): 需要改进,存在明显问题
请以JSON格式返回评分结果
{{
"score": 85,
"grade": "A",
"feedback": "详细的反馈内容",
"suggestions": "改进建议",
"criteria_details": {{
"逻辑正确性": {{"score": 35, "max": 40, "comment": "逻辑基本正确"}},
"完整性": {{"score": 25, "max": 30, "comment": "缺少部分步骤"}},
"规范性": {{"score": 18, "max": 20, "comment": "符号使用规范"}},
"清晰度": {{"score": 8, "max": 10, "comment": "布局清晰"}}
}}
}}
"""
def parse_ai_evaluation_response(ai_response):
"""解析AI评分响应"""
try:
import re
json_match = re.search(r'\{.*\}', ai_response, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
else:
data = {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}
return data
except Exception:
return {
"score": 60,
"grade": "C",
"feedback": "AI评分解析失败请重新提交",
"suggestions": "",
"criteria_details": {}
}

View File

@@ -1 +0,0 @@
# URLs package

View File

@@ -1,16 +0,0 @@
from django.urls import path
from ..views.oj import (
FlowchartSubmissionAPI,
FlowchartSubmissionListAPI,
FlowchartSubmissionRetryAPI,
FlowchartSubmissionCurrentAPI,
FlowchartSubmissionDetailAPI,
)
urlpatterns = [
path('flowchart/submission', FlowchartSubmissionAPI.as_view()),
path('flowchart/submissions', FlowchartSubmissionListAPI.as_view()),
path('flowchart/submission/retry', FlowchartSubmissionRetryAPI.as_view()),
path('flowchart/submission/detail', FlowchartSubmissionDetailAPI.as_view()),
path('flowchart/submission/current', FlowchartSubmissionCurrentAPI.as_view()),
]

View File

@@ -1,3 +0,0 @@
from django.shortcuts import render
# Create your views here.

View File

@@ -1 +0,0 @@
# Views package

View File

@@ -1,199 +0,0 @@
from utils.api import APIView
from account.decorators import login_required
from flowchart.models import FlowchartSubmission, FlowchartSubmissionStatus
from flowchart.serializers import (
CreateFlowchartSubmissionSerializer,
FlowchartSubmissionSerializer,
FlowchartSubmissionListSerializer,
)
from flowchart.tasks import evaluate_flowchart_task
from problem.models import Problem
class FlowchartSubmissionAPI(APIView):
@login_required
def post(self, request):
"""创建流程图提交"""
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
if not serializer.is_valid():
return self.error(serializer.errors)
data = serializer.validated_data
# 验证题目存在
try:
from problem.models import Problem
problem = Problem.objects.get(id=data["problem_id"])
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
# 验证题目是否允许流程图提交
if not problem.allow_flowchart:
return self.error("This problem does not allow flowchart submission")
# 创建提交记录
submission = FlowchartSubmission.objects.create(
user=request.user,
problem=problem,
mermaid_code=data["mermaid_code"],
flowchart_data=data.get("flowchart_data", {}),
)
# 启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success({"submission_id": submission.id, "status": "pending"})
@login_required
def get(self, request):
"""获取流程图提交详情"""
submission_id = request.GET.get("id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
serializer = FlowchartSubmissionSerializer(submission)
return self.success(serializer.data)
class FlowchartSubmissionListAPI(APIView):
def get(self, request):
"""获取流程图提交列表"""
username = request.GET.get("username")
problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself")
queryset = FlowchartSubmission.objects.select_related("user", "problem")
if problem_id:
try:
problem = Problem.objects.get(
_id=problem_id, contest_id__isnull=True, visible=True
)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
queryset = queryset.filter(problem=problem)
if myself and myself == "1":
queryset = queryset.filter(user=request.user)
if username:
queryset = queryset.filter(user__username__icontains=username)
data = self.paginate_data(request, queryset)
data["results"] = FlowchartSubmissionListSerializer(
data["results"], many=True
).data
return self.success(data)
class FlowchartSubmissionRetryAPI(APIView):
@login_required
def post(self, request):
"""重新触发AI评分"""
submission_id = request.data.get("submission_id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
# 检查权限
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
# 检查是否可以重新评分
if submission.status not in [
FlowchartSubmissionStatus.FAILED,
FlowchartSubmissionStatus.COMPLETED,
]:
return self.error("Submission is not in a state that allows retry")
# 重置状态并重新启动AI评分
submission.status = FlowchartSubmissionStatus.PENDING
submission.ai_score = None
submission.ai_grade = None
submission.ai_feedback = None
submission.ai_suggestions = None
submission.ai_criteria_details = {}
submission.processing_time = None
submission.evaluation_time = None
submission.save()
# 重新启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success(
{
"submission_id": submission.id,
"status": "pending",
"message": "AI evaluation restarted",
}
)
class FlowchartSubmissionDetailAPI(APIView):
@login_required
def get(self, request):
"""获取当前用户对指定题目的流程图提交详情"""
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
page = int(request.GET.get("page", 0))
submissions = FlowchartSubmission.objects.filter(
user=request.user,
problem=problem,
status=FlowchartSubmissionStatus.COMPLETED,
).order_by("create_time")
if page == 0:
submission = submissions.last()
else:
submission = submissions[page - 1]
serializer = FlowchartSubmissionSerializer(submission)
return self.success({"submission": serializer.data})
class FlowchartSubmissionCurrentAPI(APIView):
@login_required
def get(self, request):
"""获取当前用户对指定题目的最新流程图提交,只返回次数和分数"""
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
submissions = (
FlowchartSubmission.objects.filter(
user=request.user,
problem=problem,
status=FlowchartSubmissionStatus.COMPLETED,
)
.values("ai_score", "ai_grade")
.order_by("-create_time")
)
count = submissions.count()
if count == 0:
return self.success({"count": 0, "score": 0, "grade": ""})
submission = submissions[0]
return self.success(
{
"count": count,
"score": submission["ai_score"],
"grade": submission["ai_grade"],
}
)

View File

@@ -16,7 +16,6 @@ from problem.utils import parse_problem_template
from submission.models import JudgeStatus, Submission
from utils.cache import cache
from utils.constants import CacheKey
from utils.websocket import push_submission_update
logger = logging.getLogger(__name__)
@@ -67,6 +66,26 @@ class DispatcherBase(object):
logger.exception(e)
class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language):
super().__init__()
spj_compile_config = list(filter(lambda config: spj_language == config["name"], SysOptions.spj_languages))[0]["spj"][
"compile"]
self.data = {
"src": spj_code,
"spj_version": spj_version,
"spj_compile_config": spj_compile_config
}
def compile_spj(self):
with ChooseJudgeServer() as server:
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
if not result:
return "Failed to call judge server"
if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase):
@@ -106,6 +125,12 @@ class JudgeDispatcher(DispatcherBase):
def judge(self):
language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {}
if self.problem.spj_code:
for lang in SysOptions.spj_languages:
if lang["name"] == self.problem.spj_language:
spj_config = lang["spj"]
break
if language in self.problem.template:
template = parse_problem_template(self.problem.template[language])
@@ -120,6 +145,10 @@ class JudgeDispatcher(DispatcherBase):
"max_memory": 1024 * 1024 * self.problem.memory_limit,
"test_case_id": self.problem.test_case_id,
"output": False,
"spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code,
"io_mode": self.problem.io_mode
}
@@ -127,56 +156,12 @@ class JudgeDispatcher(DispatcherBase):
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
# 推送排队状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.PENDING,
"status": "pending",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
return
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
# 推送判题中状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.JUDGING,
"status": "judging",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if not resp:
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR)
# 推送系统错误状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": JudgeStatus.SYSTEM_ERROR,
"status": "error",
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
return
if resp["err"]:
@@ -197,24 +182,6 @@ class JudgeDispatcher(DispatcherBase):
else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save()
# 推送判题完成状态
try:
push_submission_update(
submission_id=str(self.submission.id),
user_id=self.submission.user_id,
data={
"type": "submission_update",
"submission_id": str(self.submission.id),
"result": self.submission.result,
"status": "finished",
"time_cost": self.submission.statistic_info.get("time_cost"),
"memory_cost": self.submission.statistic_info.get("memory_cost"),
"score": self.submission.statistic_info.get("score", 0),
}
)
except Exception as e:
logger.error(f"Failed to push submission update: {str(e)}")
if self.contest_id:
if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \

View File

@@ -35,6 +35,20 @@ int main() {
}
}
_c_lang_spj_compile = {
"src_name": "spj-{spj_version}.c",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 3000,
"max_real_time": 10000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c17 {src_path} -lm -o {exe_path}"
}
_c_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_cpp_lang_config = {
"template": """//PREPEND BEGIN
@@ -68,6 +82,20 @@ int main() {
}
}
_cpp_lang_spj_compile = {
"src_name": "spj-{spj_version}.cpp",
"exe_name": "spj-{spj_version}",
"max_cpu_time": 10000,
"max_real_time": 20000,
"max_memory": 1024 * 1024 * 1024,
"compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++20 {src_path} -lm -o {exe_path}"
}
_cpp_lang_spj_config = {
"exe_name": "spj-{spj_version}",
"command": "{exe_path} {in_file_path} {user_out_file_path}",
"seccomp_rule": "c_cpp"
}
_java_lang_config = {
"template": """//PREPEND BEGIN
@@ -196,8 +224,10 @@ console.log(add(1, 2))
}
languages = [
{"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc"},
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src"},
{"config": _c_lang_config, "name": "C", "description": "GCC 13", "content_type": "text/x-csrc",
"spj": {"compile": _c_lang_spj_compile, "config": _c_lang_spj_config}},
{"config": _cpp_lang_config, "name": "C++", "description": "GCC 13", "content_type": "text/x-c++src",
"spj": {"compile": _cpp_lang_spj_compile, "config": _cpp_lang_spj_config}},
{"config": _java_lang_config, "name": "Java", "description": "Temurin 21", "content_type": "text/x-java"},
{"config": _py3_lang_config, "name": "Python3", "description": "Python 3.12", "content_type": "text/x-python"},
{"config": _go_lang_config, "name": "Golang", "description": "Golang 1.22", "content_type": "text/x-go"},

View File

@@ -1,30 +1,7 @@
"""
ASGI config for oj project.
It exposes the ASGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/
"""
import os
from django.core.asgi import get_asgi_application
from channels.routing import ProtocolTypeRouter, URLRouter
from channels.auth import AuthMiddlewareStack
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
# Initialize Django ASGI application early to ensure the AppRegistry
# is populated before importing code that may import ORM models.
django_asgi_app = get_asgi_application()
# Import routing after Django setup
from oj.routing import websocket_urlpatterns
application = ProtocolTypeRouter(
{
"http": django_asgi_app,
"websocket": AuthMiddlewareStack(URLRouter(websocket_urlpatterns)),
}
)
import os
from django.core.asgi import get_asgi_application
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
application = get_asgi_application()

View File

@@ -1,22 +1,19 @@
# coding=utf-8
import os
from utils.shortcuts import get_env
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
"default": {
"ENGINE": "django.db.backends.postgresql",
"HOST": "150.158.29.156",
"PORT": "5455",
"NAME": "onlinejudge",
"USER": "onlinejudge",
"PASSWORD": "onlinejudge",
"ENGINE": "django.db.backends.sqlite3",
"NAME": os.path.join(BASE_DIR, "db.sqlite3"),
}
}
REDIS_CONF = {
"host": "150.158.29.156",
"port": 5456,
"host": get_env("REDIS_HOST", "127.0.0.1"),
"port": get_env("REDIS_PORT", "6380"),
}

View File

@@ -1,15 +0,0 @@
"""
WebSocket URL Configuration for oj project.
"""
from django.urls import path
from submission.consumers import SubmissionConsumer
from conf.consumers import ConfigConsumer
from flowchart.consumers import FlowchartConsumer
websocket_urlpatterns = [
path("ws/submission/", SubmissionConsumer.as_asgi()),
path("ws/config/", ConfigConsumer.as_asgi()),
path("ws/flowchart/", FlowchartConsumer.as_asgi()),
]

View File

@@ -28,14 +28,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Applications
VENDOR_APPS = [
"daphne", # Channels ASGI server - must be first
"django.contrib.auth",
"django.contrib.sessions",
"django.contrib.contenttypes",
"django.contrib.messages",
"django.contrib.staticfiles",
"rest_framework",
"channels",
"django_dramatiq",
"django_dbconn_retry",
]
@@ -57,10 +55,6 @@ LOCAL_APPS = [
"message",
"comment",
"tutorial",
"ai",
"flowchart",
"problemset",
"class_pk",
]
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
@@ -97,9 +91,6 @@ TEMPLATES = [
]
WSGI_APPLICATION = "oj.wsgi.application"
# ASGI Application for WebSocket support
ASGI_APPLICATION = "oj.asgi.application"
# Password validation
# https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators
@@ -121,9 +112,13 @@ AUTH_PASSWORD_VALIDATORS = [
# Internationalization
# https://docs.djangoproject.com/en/1.8/topics/i18n/
LANGUAGE_CODE = "zh-cn"
LANGUAGE_CODE = "en-us"
TIME_ZONE = "Asia/Shanghai"
TIME_ZONE = "UTC"
USE_I18N = True
USE_L10N = True
USE_TZ = True
@@ -215,23 +210,12 @@ def redis_config(db):
}
CACHES = {"default": redis_config(db=1)}
if production_env:
CACHES = {"default": redis_config(db=1)}
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
# Channels Configuration
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
"hosts": [(REDIS_CONF["host"], REDIS_CONF["port"])],
"capacity": 1500, # 每个频道的最大消息数
"expiry": 10, # 消息过期时间(秒)
},
},
}
DRAMATIQ_BROKER = {
"BROKER": "dramatiq.brokers.redis.RedisBroker",
"OPTIONS": {

View File

@@ -19,9 +19,4 @@ urlpatterns = [
path("api/admin/", include("comment.urls.admin")),
path("api/", include("tutorial.urls.tutorial")),
path("api/admin/", include("tutorial.urls.admin")),
path("api/", include("ai.urls.oj")),
path("api/", include("flowchart.urls.oj")),
path("api/", include("problemset.urls.oj")),
path("api/admin/", include("problemset.urls.admin")),
path("api/", include("class_pk.urls.oj")),
]

View File

@@ -104,7 +104,6 @@ class OptionKeys:
judge_server_token = "judge_server_token"
throttling = "throttling"
languages = "languages"
enable_maxkb = "enable_maxkb"
class OptionDefaultValue:
@@ -120,7 +119,6 @@ class OptionDefaultValue:
throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50},
"user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}}
languages = languages
enable_maxkb = True
class _SysOptionsMeta(type):
@@ -273,18 +271,17 @@ class _SysOptionsMeta(type):
def languages(cls, value):
cls._set_option(OptionKeys.languages, value)
@my_property(ttl=DEFAULT_SHORT_TTL)
def spj_languages(cls):
return [item for item in cls.languages if "spj" in item]
@my_property(ttl=DEFAULT_SHORT_TTL)
def language_names(cls):
return [item["name"] for item in cls.languages]
@my_property(ttl=DEFAULT_SHORT_TTL)
def enable_maxkb(cls):
return cls._get_option(OptionKeys.enable_maxkb)
@enable_maxkb.setter
def enable_maxkb(cls, value):
cls._set_option(OptionKeys.enable_maxkb, value)
def spj_language_names(cls):
return [item["name"] for item in cls.languages if "spj" in item]
def reset_languages(cls):
cls.languages = languages

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-03 16:31
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='problem',
name='prompt',
field=models.TextField(null=True),
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-03 16:56
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0002_problem_prompt'),
]
operations = [
migrations.AddField(
model_name='problem',
name='answers',
field=models.JSONField(null=True),
),
]

View File

@@ -1,38 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 14:57
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0003_problem_answers'),
]
operations = [
migrations.AddField(
model_name='problem',
name='allow_flowchart',
field=models.BooleanField(default=False),
),
migrations.AddField(
model_name='problem',
name='flowchart_data',
field=models.JSONField(default=dict),
),
migrations.AddField(
model_name='problem',
name='flowchart_hint',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='mermaid_code',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='problem',
name='show_flowchart',
field=models.BooleanField(default=False),
),
]

View File

@@ -1,33 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-11 15:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0004_problem_allow_flowchart_problem_flowchart_data_and_more'),
]
operations = [
migrations.RemoveField(
model_name='problem',
name='spj',
),
migrations.RemoveField(
model_name='problem',
name='spj_code',
),
migrations.RemoveField(
model_name='problem',
name='spj_compile_ok',
),
migrations.RemoveField(
model_name='problem',
name='spj_language',
),
migrations.RemoveField(
model_name='problem',
name='spj_version',
),
]

View File

@@ -1,4 +1,6 @@
from django.conf import settings
from django.db import models
from utils.models import JSONField
from account.models import User
from contest.models import Contest
@@ -49,13 +51,13 @@ class Problem(models.Model):
input_description = RichTextField()
output_description = RichTextField()
# [{input: "test", output: "123"}, {input: "test123", output: "456"}]
samples = models.JSONField()
samples = JSONField()
test_case_id = models.TextField()
# [{"input_name": "1.in", "output_name": "1.out", "score": 0}]
test_case_score = models.JSONField()
test_case_score = JSONField()
hint = RichTextField(null=True)
languages = models.JSONField()
template = models.JSONField()
languages = JSONField()
template = JSONField()
create_time = models.DateTimeField(auto_now_add=True)
# we can not use auto_now here
last_update_time = models.DateTimeField(auto_now=True, null=True)
@@ -65,29 +67,25 @@ class Problem(models.Model):
# MB
memory_limit = models.IntegerField()
# io mode
io_mode = models.JSONField(default=_default_io_mode)
io_mode = JSONField(default=_default_io_mode)
# special judge related
spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True)
spj_code = models.TextField(null=True)
spj_version = models.TextField(null=True)
spj_compile_ok = models.BooleanField(default=False)
rule_type = models.TextField()
visible = models.BooleanField(default=True)
difficulty = models.TextField()
tags = models.ManyToManyField(ProblemTag)
source = models.TextField(null=True)
prompt = models.TextField(null=True)
# [{language: "python", code: "..."}]
answers = models.JSONField(null=True)
# for OI mode
total_score = models.IntegerField(default=0)
submission_number = models.BigIntegerField(default=0)
accepted_number = models.BigIntegerField(default=0)
# {JudgeStatus.ACCEPTED: 3, JudgeStatus.WRONG_ANSWER: 11}, the number means count
statistic_info = models.JSONField(default=dict)
statistic_info = JSONField(default=dict)
share_submission = models.BooleanField(default=False)
# 流程图相关字段
allow_flowchart = models.BooleanField(default=False) # 是否允许/需要提交流程图
mermaid_code = models.TextField(null=True, blank=True) # 流程图答案(Mermaid代码)
flowchart_data = models.JSONField(default=dict) # 流程图答案元数据(JSON格式)
flowchart_hint = models.TextField(null=True, blank=True) # 流程图提示信息
show_flowchart = models.BooleanField(default=False) # 是否显示流程图答案数据如果True这样就不需要提交流程图了说明就是给学生看的
class Meta:
db_table = "problem"

View File

@@ -2,10 +2,12 @@ import re
from django import forms
from options.options import SysOptions
from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty
from utils.serializers import (
LanguageNameMultiChoiceField,
SPJLanguageNameChoiceField,
LanguageNameChoiceField,
)
@@ -14,6 +16,7 @@ from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form):
spj = forms.CharField(max_length=12)
file = forms.FileField()
@@ -22,11 +25,6 @@ class CreateSampleSerializer(serializers.Serializer):
output = serializers.CharField(trim_whitespace=False)
class CreateAnswerSerializer(serializers.Serializer):
language = serializers.CharField()
code = serializers.CharField()
class CreateTestCaseScoreSerializer(serializers.Serializer):
input_name = serializers.CharField(max_length=32)
output_name = serializers.CharField(max_length=32)
@@ -70,6 +68,10 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
choices=[ProblemRuleType.ACM, ProblemRuleType.OI]
)
io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=Difficulty.choices())
tags = serializers.ListField(
@@ -77,25 +79,8 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
)
hint = serializers.CharField(allow_blank=True, allow_null=True)
source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True)
prompt = serializers.CharField(allow_blank=True, allow_null=True)
answers = serializers.ListField(
child=CreateAnswerSerializer(),
allow_empty=True,
allow_null=True,
)
share_submission = serializers.BooleanField()
# 流程图相关字段
allow_flowchart = serializers.BooleanField(required=False, default=False)
show_flowchart = serializers.BooleanField(required=False, default=False)
mermaid_code = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
flowchart_hint = serializers.CharField(
allow_blank=True, allow_null=True, required=False
)
class CreateProblemSerializer(CreateOrEditProblemSerializer):
pass
@@ -120,6 +105,11 @@ class TagSerializer(serializers.ModelSerializer):
fields = "__all__"
class CompileSPJSerializer(serializers.Serializer):
spj_language = SPJLanguageNameChoiceField()
spj_code = serializers.CharField()
class BaseProblemSerializer(serializers.ModelSerializer):
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
created_by = UsernameSerializer()
@@ -145,8 +135,6 @@ class ProblemAdminListSerializer(BaseProblemSerializer):
class ProblemSerializer(BaseProblemSerializer):
template = serializers.SerializerMethodField("get_public_template")
mermaid_code = serializers.SerializerMethodField()
flowchart_data = serializers.SerializerMethodField()
class Meta:
model = Problem
@@ -155,21 +143,11 @@ class ProblemSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"answers",
"spj_code",
"spj_version",
"spj_compile_ok",
)
def get_mermaid_code(self, obj):
# 当 allow_flowchart 为 True 时,不返回 mermaid_code
if obj.allow_flowchart:
return None
return obj.mermaid_code
def get_flowchart_data(self, obj):
# 当 allow_flowchart 为 True 时,不返回 flowchart_data
if obj.allow_flowchart:
return None
return obj.flowchart_data
class ProblemListSerializer(BaseProblemSerializer):
class Meta:
@@ -184,14 +162,12 @@ class ProblemListSerializer(BaseProblemSerializer):
"created_by",
"tags",
"contest",
"allow_flowchart",
"rule_type",
]
class ProblemSafeSerializer(BaseProblemSerializer):
template = serializers.SerializerMethodField("get_public_template")
mermaid_code = serializers.SerializerMethodField()
flowchart_data = serializers.SerializerMethodField()
class Meta:
model = Problem
@@ -200,36 +176,116 @@ class ProblemSafeSerializer(BaseProblemSerializer):
"test_case_id",
"visible",
"is_public",
"spj_code",
"spj_version",
"spj_compile_ok",
"difficulty",
"submission_number",
"accepted_number",
"statistic_info",
"answers",
)
def get_mermaid_code(self, obj):
# 当 allow_flowchart 为 True 时,不返回 mermaid_code
if obj.allow_flowchart:
return None
return obj.mermaid_code
def get_flowchart_data(self, obj):
# 当 allow_flowchart 为 True 时,不返回 flowchart_data
if obj.allow_flowchart:
return None
return obj.flowchart_data
class ContestProblemMakePublicSerializer(serializers.Serializer):
id = serializers.IntegerField()
display_id = serializers.CharField(max_length=32)
class ExportProblemSerializer(serializers.ModelSerializer):
display_id = serializers.SerializerMethodField()
description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField()
source = serializers.SerializerMethodField()
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
def get_display_id(self, obj):
return obj._id
def _html_format_value(self, value):
return {"format": "html", "value": value}
def get_description(self, obj):
return self._html_format_value(obj.description)
def get_input_description(self, obj):
return self._html_format_value(obj.input_description)
def get_output_description(self, obj):
return self._html_format_value(obj.output_description)
def get_hint(self, obj):
return self._html_format_value(obj.hint)
def get_test_case_score(self, obj):
return [
{
"score": item["score"] if obj.rule_type == ProblemRuleType.OI else 100,
"input_name": item["input_name"],
"output_name": item["output_name"],
}
for item in obj.test_case_score
]
def get_spj(self, obj):
return {"code": obj.spj_code, "language": obj.spj_language} if obj.spj else None
def get_template(self, obj):
ret = {}
for k, v in obj.template.items():
ret[k] = parse_problem_template(v)
return ret
def get_source(self, obj):
return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}"
class Meta:
model = Problem
fields = (
"display_id",
"title",
"description",
"tags",
"input_description",
"output_description",
"test_case_score",
"hint",
"time_limit",
"memory_limit",
"samples",
"template",
"spj",
"rule_type",
"source",
"template",
)
class AddContestProblemSerializer(serializers.Serializer):
contest_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
display_id = serializers.CharField()
class ExportProblemRequestSerializer(serializers.Serializer):
problem_id = serializers.ListField(
child=serializers.IntegerField(), allow_empty=False
)
class UploadProblemForm(forms.Form):
file = forms.FileField()
class FormatValueSerializer(serializers.Serializer):
format = serializers.ChoiceField(choices=["html", "markdown"])
value = serializers.CharField(allow_blank=True)
class TestCaseScoreSerializer(serializers.Serializer):
score = serializers.IntegerField(min_value=1)
input_name = serializers.CharField(max_length=32)
@@ -242,6 +298,58 @@ class TemplateSerializer(serializers.Serializer):
append = serializers.CharField()
class SPJSerializer(serializers.Serializer):
code = serializers.CharField()
language = SPJLanguageNameChoiceField()
class AnswerSerializer(serializers.Serializer):
code = serializers.CharField()
language = LanguageNameChoiceField()
class ImportProblemSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=128)
title = serializers.CharField(max_length=128)
description = FormatValueSerializer()
input_description = FormatValueSerializer()
output_description = FormatValueSerializer()
hint = FormatValueSerializer()
test_case_score = serializers.ListField(
child=TestCaseScoreSerializer(), allow_null=True
)
time_limit = serializers.IntegerField(min_value=1, max_value=60000)
memory_limit = serializers.IntegerField(min_value=1, max_value=10240)
samples = serializers.ListField(child=CreateSampleSerializer())
template = serializers.DictField(child=TemplateSerializer())
spj = SPJSerializer(allow_null=True)
rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
answers = serializers.ListField(child=AnswerSerializer())
tags = serializers.ListField(child=serializers.CharField())
class FPSProblemSerializer(serializers.Serializer):
class UnitSerializer(serializers.Serializer):
unit = serializers.ChoiceField(choices=["MB", "s", "ms"])
value = serializers.IntegerField(min_value=1, max_value=60000)
title = serializers.CharField(max_length=128)
description = serializers.CharField()
input = serializers.CharField()
output = serializers.CharField()
hint = serializers.CharField(allow_blank=True, allow_null=True)
time_limit = UnitSerializer()
memory_limit = UnitSerializer()
samples = serializers.ListField(child=CreateSampleSerializer())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
spj = SPJSerializer(allow_null=True)
template = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
append = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)
prepend = serializers.ListField(
child=serializers.DictField(), allow_empty=True, allow_null=True
)

324
problem/tests.py Normal file
View File

@@ -0,0 +1,324 @@
import copy
import hashlib
import os
import shutil
from datetime import timedelta
from zipfile import ZipFile
from django.conf import settings
from utils.api.tests import APITestCase
from .models import ProblemTag, ProblemIOMode
from .models import Problem, ProblemRuleType
from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA
from .views.admin import TestCaseAPI
from .utils import parse_problem_template
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"},
"share_submission": False,
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
raise ValueError("invalid score")
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = created_by
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
for item in tags:
try:
tag = ProblemTag.objects.get(name=item)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
return problem
class ProblemTagListAPITest(APITestCase):
def test_get_tag_list(self):
ProblemTag.objects.create(name="name1")
ProblemTag.objects.create(name="name2")
resp = self.client.get(self.reverse("problem_tag_list_api"))
self.assertSuccess(resp)
class TestCaseUploadAPITest(APITestCase):
def setUp(self):
self.api = TestCaseAPI()
self.url = self.reverse("test_case_api")
self.create_super_admin()
def test_filter_file_name(self):
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
self.assertEqual(self.api.filter_name_list(["2.in", "3.in"], spj=True), [])
def make_test_case_zip(self):
base_dir = os.path.join("/tmp", "test_case")
shutil.rmtree(base_dir, ignore_errors=True)
os.mkdir(base_dir)
file_names = ["1.in", "1.out", "2.in", ".DS_Store"]
for item in file_names:
with open(os.path.join(base_dir, item), "w", encoding="utf-8") as f:
f.write(item + "\n" + item + "\r\n" + "end")
zip_file = os.path.join(base_dir, "test_case.zip")
with ZipFile(os.path.join(base_dir, "test_case.zip"), "w") as f:
for item in file_names:
f.write(os.path.join(base_dir, item), item)
return zip_file
def test_upload_spj_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "true", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], True)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
def test_upload_test_case_zip(self):
with open(self.make_test_case_zip(), "rb") as f:
resp = self.client.post(self.url,
data={"spj": "false", "file": f}, format="multipart")
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["spj"], False)
test_case_dir = os.path.join(settings.TEST_CASE_DIR, data["id"])
self.assertTrue(os.path.exists(test_case_dir))
for item in data["info"]:
name = item["input_name"]
with open(os.path.join(test_case_dir, name), "r", encoding="utf-8") as f:
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
class ProblemAdminAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("problem_admin_api")
self.create_super_admin()
self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
def test_create_problem(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
return resp
def test_duplicate_display_id(self):
self.test_create_problem()
resp = self.client.post(self.url, data=self.data)
self.assertFailed(resp, "Display ID already exists")
def test_spj(self):
data = copy.deepcopy(self.data)
data["spj"] = True
resp = self.client.post(self.url, data)
self.assertFailed(resp, "Invalid spj")
data["spj_code"] = "test"
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
def test_get_problem(self):
self.test_create_problem()
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_get_one_problem(self):
problem_id = self.test_create_problem().data["data"]["id"]
resp = self.client.get(self.url + "?id=" + str(problem_id))
self.assertSuccess(resp)
def test_edit_problem(self):
problem_id = self.test_create_problem().data["data"]["id"]
data = copy.deepcopy(self.data)
data["id"] = problem_id
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
class ProblemAPITest(ProblemCreateTestBase):
def setUp(self):
self.url = self.reverse("problem_api")
admin = self.create_admin(login=False)
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.create_user("test", "test123")
def test_get_problem_list(self):
resp = self.client.get(f"{self.url}?limit=10")
self.assertSuccess(resp)
def get_one_problem(self):
resp = self.client.get(self.url + "?id=" + self.problem._id)
self.assertSuccess(resp)
class ContestProblemAdminTest(APITestCase):
def setUp(self):
self.url = self.reverse("contest_problem_admin_api")
self.create_admin()
self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
def test_create_contest_problem(self):
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = self.contest["id"]
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
return resp.data["data"]
def test_get_contest_problem(self):
self.test_create_contest_problem()
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["results"]), 1)
def test_get_one_contest_problem(self):
contest_problem = self.test_create_contest_problem()
contest_id = self.contest["id"]
problem_id = contest_problem["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp)
class ContestProblemTest(ProblemCreateTestBase):
def setUp(self):
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.problem.contest_id = self.contest["id"]
self.problem.save()
self.url = self.reverse("contest_problem_api")
def test_admin_get_contest_problem_list(self):
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1)
def test_admin_get_one_contest_problem(self):
contest_id = self.contest["id"]
problem_id = self.problem._id
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
self.assertSuccess(resp)
def test_regular_user_get_not_started_contest_problem(self):
self.create_user("test", "test123")
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."})
def test_reguar_user_get_started_contest_problem(self):
self.create_user("test", "test123")
contest = Contest.objects.first()
contest.start_time = contest.start_time - timedelta(hours=1)
contest.save()
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertSuccess(resp)
class AddProblemFromPublicProblemAPITest(ProblemCreateTestBase):
def setUp(self):
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.url = self.reverse("add_contest_problem_from_public_api")
self.data = {
"display_id": "1000",
"contest_id": self.contest["id"],
"problem_id": self.problem.id
}
def test_add_contest_problem(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
self.assertTrue(Problem.objects.all().exists())
self.assertTrue(Problem.objects.filter(contest_id=self.contest["id"]).exists())
class ParseProblemTemplateTest(APITestCase):
def test_parse(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//TEMPLATE BEGIN
bbb
//TEMPLATE END
//APPEND BEGIN
ccc
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "bbb\n")
self.assertEqual(ret["append"], "ccc\n")
def test_parse1(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//APPEND BEGIN
ccc
//APPEND END
//APPEND BEGIN
ddd
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "")
self.assertEqual(ret["append"], "ccc\n")

View File

@@ -1,23 +1,18 @@
from django.urls import path
from ..views.admin import (
ContestProblemAPI,
ProblemAPI,
ProblemFlowchartAIGen,
StuckProblemsAPI,
TestCaseAPI,
MakeContestProblemPublicAPIView,
AddContestProblemAPI,
ProblemVisibleAPI,
)
from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView,
CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI,
FPSProblemImport, ProblemVisibleAPI)
urlpatterns = [
path("test_case", TestCaseAPI.as_view()),
path("compile_spj", CompileSPJAPI.as_view()),
path("problem", ProblemAPI.as_view()),
path("problem/visible", ProblemVisibleAPI.as_view()),
path("problem/stuck", StuckProblemsAPI.as_view()),
path("problem/flowchart", ProblemFlowchartAIGen.as_view()),
path("contest/problem", ContestProblemAPI.as_view()),
path("contest_problem/make_public", MakeContestProblemPublicAPIView.as_view()),
path("contest/add_problem_from_public", AddContestProblemAPI.as_view()),
path("export_problem", ExportProblemAPI.as_view()),
path("import_problem", ImportProblemAPI.as_view()),
path("import_fps", FPSProblemImport.as_view()),
]

View File

@@ -6,16 +6,12 @@ from ..views.oj import (
ProblemAPI,
ContestProblemAPI,
PickOneAPI,
ProblemAuthorAPI,
SimilarProblemAPI,
)
urlpatterns = [
path("problem/tags", ProblemTagAPI.as_view()),
path("problem", ProblemAPI.as_view()),
path("problem/beat_count", ProblemSolvedPeopleCount.as_view()),
path("problem/similar", SimilarProblemAPI.as_view()),
path("problem/author", ProblemAuthorAPI.as_view()),
path("pickone", PickOneAPI.as_view()),
path("contest/problem", ContestProblemAPI.as_view()),
]

View File

@@ -1,45 +1,44 @@
import hashlib
import json
import os
# import shutil
import tempfile
import zipfile
from wsgiref.util import FileWrapper
from django.conf import settings
from django.db import transaction
from django.db.models import Q
from django.http import StreamingHttpResponse
from django.db.models import Count
from django.http import StreamingHttpResponse, FileResponse
from account.decorators import problem_permission_required, ensure_created_by, super_admin_required
from contest.models import Contest, ContestStatus
from submission.models import Submission
from fps.parser import FPSHelper, FPSParser
from judge.dispatcher import SPJCompiler
from options.options import SysOptions
from submission.models import Submission, JudgeStatus
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.constants import Difficulty
from utils.shortcuts import rand_str, natural_sort_key
from utils.openai import get_ai_client
from utils.tasks import delete_files
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (
CreateContestProblemSerializer,
CreateProblemSerializer,
EditProblemSerializer,
EditContestProblemSerializer,
ProblemAdminSerializer,
ProblemAdminListSerializer,
TestCaseUploadForm,
ContestProblemMakePublicSerializer,
AddContestProblemSerializer,
)
from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
ProblemAdminSerializer, ProblemAdminListSerializer, TestCaseUploadForm,
ContestProblemMakePublicSerializer, AddContestProblemSerializer, ExportProblemSerializer,
ExportProblemRequestSerializer, UploadProblemForm, ImportProblemSerializer,
FPSProblemSerializer)
from ..utils import TEMPLATE_BASE, build_problem_template
class TestCaseZipProcessor(object):
def process_zip(self, uploaded_zip_file, dir=""):
def process_zip(self, uploaded_zip_file, spj, dir=""):
try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile:
raise APIError("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, dir=dir)
test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
if not test_case_list:
raise APIError("Empty file")
@@ -58,22 +57,26 @@ class TestCaseZipProcessor(object):
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"test_cases": {}}
test_case_info = {"spj": spj, "test_cases": {}}
info = []
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {
"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1],
}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
@@ -83,19 +86,29 @@ class TestCaseZipProcessor(object):
return info, test_case_id
def filter_name_list(self, name_list, dir=""):
def filter_name_list(self, name_list, spj, dir=""):
ret = []
prefix = 1
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
if spj:
while True:
in_name = f"{prefix}.in"
if f"{dir}{in_name}" in name_list:
ret.append(in_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
else:
while True:
in_name = f"{prefix}.in"
out_name = f"{prefix}.out"
if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name)
ret.append(out_name)
prefix += 1
continue
else:
return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
@@ -118,25 +131,23 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if not os.path.isdir(test_case_dir):
return self.error("Test case does not exists")
name_list = self.filter_name_list(os.listdir(test_case_dir))
name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
name_list.append("info")
file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
with zipfile.ZipFile(file_name, "w") as file:
for test_case in name_list:
file.write(f"{test_case_dir}/{test_case}", test_case)
response = StreamingHttpResponse(
FileWrapper(open(file_name, "rb")), content_type="application/octet-stream"
)
response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")),
content_type="application/octet-stream")
response["Content-Disposition"] = (
f"attachment; filename=problem_{problem.id}_test_cases.zip"
)
response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
response["Content-Length"] = os.path.getsize(file_name)
return response
def post(self, request):
form = TestCaseUploadForm(request.POST, request.FILES)
if form.is_valid():
spj = form.cleaned_data["spj"] == "true"
file = form.cleaned_data["file"]
else:
return self.error("Upload failed")
@@ -144,14 +155,36 @@ class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
with open(zip_file, "wb") as f:
for chunk in file:
f.write(chunk)
info, test_case_id = self.process_zip(zip_file)
info, test_case_id = self.process_zip(zip_file, spj=spj)
os.remove(zip_file)
return self.success({"id": test_case_id, "info": info})
return self.success({"id": test_case_id, "info": info, "spj": spj})
class CompileSPJAPI(APIView):
@validate_serializer(CompileSPJSerializer)
def post(self, request):
data = request.data
spj_version = rand_str(8)
error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj()
if error:
return self.error(error)
else:
return self.success()
class ProblemBase(APIView):
def common_checks(self, request):
data = request.data
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return "Invalid spj"
if not data["spj_compile_ok"]:
return "SPJ code must be compiled successfully"
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
@@ -194,6 +227,7 @@ class ProblemAPI(ProblemBase):
@problem_permission_required
def get(self, request):
problem_id = request.GET.get("id")
rule_type = request.GET.get("rule_type")
user = request.user
if problem_id:
try:
@@ -203,24 +237,19 @@ class ProblemAPI(ProblemBase):
except Problem.DoesNotExist:
return self.error("Problem does not exist")
problems = Problem.objects.filter(contest_id__isnull=True).order_by(
"-create_time"
)
author = request.GET.get("author", "")
if author:
problems = problems.filter(created_by__username=author)
problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
if rule_type:
if rule_type not in ProblemRuleType.choices():
return self.error("Invalid rule_type")
else:
problems = problems.filter(rule_type=rule_type)
keyword = request.GET.get("keyword", "").strip()
if keyword:
problems = problems.filter(
Q(title__icontains=keyword) | Q(_id__icontains=keyword)
)
problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
if not user.can_mgmt_all_problem():
problems = problems.filter(created_by=user)
return self.success(
self.paginate_data(request, problems, ProblemAdminListSerializer)
)
return self.success(self.paginate_data(request, problems, ProblemAdminListSerializer))
@problem_permission_required
@validate_serializer(EditProblemSerializer)
@@ -237,11 +266,7 @@ class ProblemAPI(ProblemBase):
_id = data["_id"]
if not _id:
return self.error("Display ID is required")
if (
Problem.objects.exclude(id=problem_id)
.filter(_id=_id, contest_id__isnull=True)
.exists()
):
if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
return self.error("Display ID already exists")
error_info = self.common_checks(request)
@@ -345,9 +370,7 @@ class ContestProblemAPI(ProblemBase):
keyword = request.GET.get("keyword")
if keyword:
problems = problems.filter(title__contains=keyword)
return self.success(
self.paginate_data(request, problems, ProblemAdminListSerializer)
)
return self.success(self.paginate_data(request, problems, ProblemAdminListSerializer))
@validate_serializer(EditContestProblemSerializer)
def put(self, request):
@@ -373,11 +396,7 @@ class ContestProblemAPI(ProblemBase):
_id = data["_id"]
if not _id:
return self.error("Display ID is required")
if (
Problem.objects.exclude(id=problem_id)
.filter(_id=_id, contest=contest)
.exists()
):
if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
return self.error("Display ID already exists")
error_info = self.common_checks(request)
@@ -436,6 +455,7 @@ class MakeContestProblemPublicAPIView(APIView):
return self.error("Already be a public problem")
problem.is_public = True
problem.save()
# https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
tags = problem.tags.all()
problem.pk = None
problem.contest = None
@@ -476,6 +496,215 @@ class AddContestProblemAPI(APIView):
return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = Submission.objects.filter(problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED).order_by("-create_time").first()
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression)
problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression)
if not info["spj"]:
zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression)
@validate_serializer(ExportProblemRequestSerializer)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = "attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(f"Invalid problem format, error is {serializer.errors}")
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in SysOptions.language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(v["prepend"], v["template"],
v["append"])
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/")
problem_obj = Problem.objects.create(_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"],
output_description=problem_info["output_description"][
"value"],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"][
"language"] if spj else None,
spj_version=rand_str(8) if spj else "",
languages=SysOptions.language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI else 0,
test_case_id=test_case_id
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, ""))
spj = problem_data["spj"] is not None
Problem.objects.create(_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=problem_data["test_case_score"],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=SysOptions.language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"])
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
tf.file.flush()
os.fsync(tf.file)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
score = []
for item in helper.save_test_case(_problem, test_case_dir)["test_cases"].values():
score.append({"score": 0, "input_name": item["input_name"],
"output_name": item.get("output_name")})
problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
problem_data["test_case_score"] = score
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})
class ProblemVisibleAPI(APIView):
@problem_permission_required
def put(self, request):
@@ -486,65 +715,4 @@ class ProblemVisibleAPI(APIView):
self.error("problem does not exists")
problem.visible = not problem.visible
problem.save()
return self.success()
class ProblemFlowchartAIGen(APIView):
@problem_permission_required
def post(self, request):
python_code = request.data.get("python", "")
client = get_ai_client()
response = client.chat.completions.create(
model="deepseek-reasoner",
messages=[
{
"role": "system",
"content": """你是一个可以将Python代码转换为mermaid的助手。
请将用户提供的Python代码转换为 Mermaid 纯文本。
注意括号内的内容用引号包裹,如果本身就有引号,请注意双引号和单引号的问题。
请只返回 mermaid 代码,连 ``` 都不需要。""",
},
{"role": "user", "content": python_code},
],
temperature=1.0,
)
mermaid_code = response.choices[0].message.content
return self.success({"flowchart": mermaid_code})
class StuckProblemsAPI(APIView):
@super_admin_required
def get(self, request):
from submission.models import JudgeStatus
failed_q = Q(result__in=[
JudgeStatus.WRONG_ANSWER,
JudgeStatus.COMPILE_ERROR,
JudgeStatus.RUNTIME_ERROR,
])
rows = (
Submission.objects.values("problem_id", "problem___id", "problem__title")
.annotate(
total=Count("id"),
accepted=Count("id", filter=Q(result=JudgeStatus.ACCEPTED)),
failed=Count("id", filter=failed_q),
failed_users=Count("user_id", filter=failed_q, distinct=True),
)
.filter(failed_users__gt=0)
.order_by("-failed_users")[:40]
)
result = [
{
"problem_id": r["problem___id"],
"problem_title": r["problem__title"],
"total": r["total"],
"failed": r["failed"],
"failed_users": r["failed_users"],
"ac_rate": round(r["accepted"] / r["total"] * 100, 1)
if r["total"]
else 0,
}
for r in rows
]
return self.success(result)
return self.success()

View File

@@ -1,12 +1,10 @@
from datetime import datetime
import random
from django.db.models import Q, Count
from django.core.cache import cache
from account.models import User
from submission.models import Submission, JudgeStatus
from utils.api import APIView
from account.decorators import check_contest_permission
from utils.constants import CacheKey
from ..models import ProblemTag, Problem, ProblemRuleType
from ..serializers import (
ProblemSerializer,
@@ -42,16 +40,24 @@ class ProblemAPI(APIView):
if request.user.is_authenticated:
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
# paginate data
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values]
problems = [
queryset_values,
]
for problem in problems:
problem["my_status"] = acm_problems_status.get(
str(problem["id"]), {}
).get("status")
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(
str(problem["id"]), {}
).get("status")
else:
problem["my_status"] = oi_problems_status.get(
str(problem["id"]), {}
).get("status")
def get(self, request):
# 问题详情页
@@ -63,22 +69,6 @@ class ProblemAPI(APIView):
)
problem_data = ProblemSerializer(problem).data
self._add_problem_status(request, problem_data)
if request.user.is_authenticated:
failed_statuses = [
JudgeStatus.WRONG_ANSWER,
JudgeStatus.CPU_TIME_LIMIT_EXCEEDED,
JudgeStatus.REAL_TIME_LIMIT_EXCEEDED,
JudgeStatus.MEMORY_LIMIT_EXCEEDED,
JudgeStatus.RUNTIME_ERROR,
JudgeStatus.COMPILE_ERROR,
]
problem_data["my_failed_count"] = Submission.objects.filter(
user_id=request.user.id,
problem_id=problem.id,
result__in=failed_statuses,
).count()
else:
problem_data["my_failed_count"] = 0
return self.success(problem_data)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
@@ -92,11 +82,6 @@ class ProblemAPI(APIView):
.filter(contest_id__isnull=True, visible=True)
.order_by("-create_time")
)
author = request.GET.get("author")
if author:
problems = problems.filter(created_by__username=author)
# 按照标签筛选
tag_text = request.GET.get("tag")
if tag_text:
@@ -113,12 +98,6 @@ class ProblemAPI(APIView):
difficulty = request.GET.get("difficulty")
if difficulty:
problems = problems.filter(difficulty=difficulty)
# 排序
sort = request.GET.get("sort")
if sort:
problems = problems.order_by(sort)
# 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemListSerializer)
self._add_problem_status(request, data)
@@ -187,82 +166,17 @@ class ProblemSolvedPeopleCount(APIView):
if submission_count == 0:
return self.success(rate)
today = datetime.today()
years_ago = datetime(today.year - 2, today.month, today.day, 0, 0)
twoYearAge = datetime(today.year - 2, today.month, today.day, 0, 0)
total_count = User.objects.filter(
is_disabled=False, last_login__gte=years_ago
is_disabled=False, last_login__gte=twoYearAge
).count()
accepted_count = Submission.objects.filter(
problem_id=problem_id,
result=JudgeStatus.ACCEPTED,
create_time__gte=years_ago,
create_time__gte=twoYearAge,
).aggregate(user_count=Count("user_id", distinct=True))["user_count"]
if accepted_count < total_count:
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
else:
rate = "0"
return self.success(rate)
class SimilarProblemAPI(APIView):
def get(self, request):
problem_display_id = request.GET.get("problem_id")
if not problem_display_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(_id=problem_display_id, contest__isnull=True)
except Problem.DoesNotExist:
return self.error("Problem not found")
tag_ids = list(problem.tags.values_list("id", flat=True))
if not tag_ids:
return self.success([])
exclude_ids = [problem_display_id]
if request.user.is_authenticated:
profile = request.user.userprofile
ac_display_ids = [
v["_id"]
for v in profile.acm_problems_status.get("problems", {}).values()
if v.get("status") == JudgeStatus.ACCEPTED
]
exclude_ids.extend(ac_display_ids)
similar = (
Problem.objects.filter(tags__in=tag_ids, visible=True, contest__isnull=True)
.exclude(_id__in=exclude_ids)
.distinct()
.order_by("difficulty")[:5]
)
return self.success(ProblemListSerializer(similar, many=True).data)
class ProblemAuthorAPI(APIView):
def get(self, request):
show_all = request.GET.get("all", "0") == "1"
cached_data = cache.get(
f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
)
if cached_data:
return self.success(cached_data)
problem_filter = {"contest_id__isnull": True, "created_by__is_disabled": False}
if not show_all:
problem_filter["visible"] = True
authors = (
Problem.objects.filter(**problem_filter)
.values("created_by__username")
.annotate(problem_count=Count("id"))
.order_by("-problem_count")
)
result = [
{
"username": author["created_by__username"],
"problem_count": author["problem_count"],
}
for author in authors
]
cache.set(CacheKey.problem_authors, result, 7200)
return self.success(result)

View File

View File

@@ -1,9 +0,0 @@
from django.apps import AppConfig
class ProblemsetConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'problemset'
def ready(self):
import problemset.signals

View File

@@ -1,115 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 10:27
import django.db.models.deletion
import utils.models
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
('problem', '0005_remove_spj_fields'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='ProblemSet',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.TextField(verbose_name='题单标题')),
('description', utils.models.RichTextField(verbose_name='题单描述')),
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('last_update_time', models.DateTimeField(auto_now=True, verbose_name='更新时间')),
('visible', models.BooleanField(default=True, verbose_name='是否可见')),
('is_public', models.BooleanField(default=True, verbose_name='是否公开')),
('difficulty', models.TextField(default='Easy', verbose_name='难度等级')),
('status', models.TextField(default='active', verbose_name='状态')),
('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='创建者')),
],
options={
'verbose_name': '题单',
'verbose_name_plural': '题单',
'db_table': 'problemset',
'ordering': ('-create_time',),
},
),
migrations.CreateModel(
name='ProblemSetBadge',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('name', models.TextField(verbose_name='奖章名称')),
('description', models.TextField(verbose_name='奖章描述')),
('icon', models.TextField(verbose_name='奖章图标')),
('condition_type', models.TextField(verbose_name='获得条件类型')),
('condition_value', models.IntegerField(default=0, verbose_name='条件值')),
('level', models.IntegerField(default=1, verbose_name='奖章等级')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
],
options={
'verbose_name': '题单奖章',
'verbose_name_plural': '题单奖章',
'db_table': 'problemset_badge',
},
),
migrations.CreateModel(
name='ProblemSetProblem',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('order', models.IntegerField(default=0, verbose_name='顺序')),
('is_required', models.BooleanField(default=True, verbose_name='是否必做')),
('score', models.IntegerField(default=0, verbose_name='分值')),
('hint', models.TextField(blank=True, null=True, verbose_name='提示')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.problem', verbose_name='题目')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
],
options={
'verbose_name': '题单题目',
'verbose_name_plural': '题单题目',
'db_table': 'problemset_problem',
'ordering': ('order',),
'unique_together': {('problemset', 'problem')},
},
),
migrations.CreateModel(
name='ProblemSetProgress',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('join_time', models.DateTimeField(auto_now_add=True, verbose_name='加入时间')),
('complete_time', models.DateTimeField(blank=True, null=True, verbose_name='完成时间')),
('is_completed', models.BooleanField(default=False, verbose_name='是否完成')),
('progress_percentage', models.FloatField(default=0.0, verbose_name='完成进度')),
('completed_problems_count', models.IntegerField(default=0, verbose_name='已完成题目数')),
('total_problems_count', models.IntegerField(default=0, verbose_name='总题目数')),
('total_score', models.IntegerField(default=0, verbose_name='总分')),
('progress_detail', models.JSONField(default=dict, verbose_name='详细进度')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '题单进度',
'verbose_name_plural': '题单进度',
'db_table': 'problemset_progress',
'unique_together': {('problemset', 'user')},
},
),
migrations.CreateModel(
name='UserBadge',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('earned_time', models.DateTimeField(auto_now_add=True, verbose_name='获得时间')),
('is_displayed', models.BooleanField(default=False, verbose_name='是否已展示')),
('badge', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemsetbadge', verbose_name='奖章')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '用户奖章',
'verbose_name_plural': '用户奖章',
'db_table': 'user_badge',
'unique_together': {('user', 'badge')},
},
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 11:04
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0001_initial'),
]
operations = [
migrations.RemoveField(
model_name='problemset',
name='is_public',
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 12:04
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0002_remove_is_public_field'),
]
operations = [
migrations.RemoveField(
model_name='problemsetbadge',
name='level',
),
]

View File

@@ -1,47 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-22 16:49
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0005_remove_spj_fields'),
('problemset', '0003_remove_badge_level'),
('submission', '0002_submission_user_create_time_idx'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterField(
model_name='problemset',
name='status',
field=models.TextField(default='draft', verbose_name='状态'),
),
migrations.CreateModel(
name='ProblemSetSubmission',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('submit_time', models.DateTimeField(auto_now_add=True, verbose_name='提交时间')),
('result', models.IntegerField(verbose_name='提交结果')),
('score', models.IntegerField(default=0, verbose_name='得分')),
('language', models.CharField(max_length=20, verbose_name='编程语言')),
('code_length', models.IntegerField(default=0, verbose_name='代码长度')),
('execution_time', models.IntegerField(default=0, verbose_name='执行时间')),
('memory_usage', models.IntegerField(default=0, verbose_name='内存使用')),
('problem', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.problem', verbose_name='题目')),
('problemset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problemset.problemset', verbose_name='题单')),
('submission', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='submission.submission', verbose_name='提交记录')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '题单提交记录',
'verbose_name_plural': '题单提交记录',
'db_table': 'problemset_submission',
'ordering': ('-submit_time',),
'indexes': [models.Index(fields=['problemset', 'user'], name='problemset__problem_1f39fa_idx'), models.Index(fields=['problemset', 'problem'], name='problemset__problem_22f053_idx'), models.Index(fields=['user', 'submit_time'], name='problemset__user_id_63c1d0_idx')],
},
),
]

View File

@@ -1,57 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-23 01:34
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0005_remove_spj_fields'),
('problemset', '0004_alter_problemset_status_problemsetsubmission'),
('submission', '0002_submission_user_create_time_idx'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterModelOptions(
name='problemsetsubmission',
options={'ordering': ('-submission__create_time',), 'verbose_name': '题单提交记录', 'verbose_name_plural': '题单提交记录'},
),
migrations.RemoveIndex(
model_name='problemsetsubmission',
name='problemset__user_id_63c1d0_idx',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='code_length',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='execution_time',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='language',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='memory_usage',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='result',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='score',
),
migrations.RemoveField(
model_name='problemsetsubmission',
name='submit_time',
),
migrations.AddIndex(
model_name='problemsetsubmission',
index=models.Index(fields=['user'], name='problemset__user_id_2f1501_idx'),
),
]

View File

@@ -1,17 +0,0 @@
# Generated by Django 5.2.3 on 2025-10-23 03:41
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problemset', '0005_alter_problemsetsubmission_options_and_more'),
]
operations = [
migrations.RemoveField(
model_name='userbadge',
name='is_displayed',
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 6.0 on 2026-03-16 15:32
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problemset', '0006_remove_is_displayed_field'),
]
operations = [
migrations.AddField(
model_name='problemset',
name='end_time',
field=models.DateTimeField(blank=True, null=True, verbose_name='截止时间'),
),
]

View File

@@ -1,286 +0,0 @@
from django.db import models
from django.utils.timezone import now
from account.models import User
from problem.models import Problem
from utils.models import RichTextField, JSONField
class ProblemSet(models.Model):
"""题单模型"""
title = models.TextField(verbose_name="题单标题")
description = RichTextField(verbose_name="题单描述")
# 创建者
created_by = models.ForeignKey(
User, on_delete=models.CASCADE, verbose_name="创建者"
)
# 创建时间
create_time = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
# 更新时间
last_update_time = models.DateTimeField(auto_now=True, verbose_name="更新时间")
# 是否可见
visible = models.BooleanField(default=True, verbose_name="是否可见")
# 题单难度等级
difficulty = models.TextField(default="Easy", verbose_name="难度等级")
# 题单状态
status = models.TextField(
default="draft", verbose_name="状态"
) # active, archived, draft
# 截止时间(到期后自动解除防作弊隐藏)
end_time = models.DateTimeField(null=True, blank=True, verbose_name="截止时间")
class Meta:
db_table = "problemset"
ordering = ("-create_time",)
verbose_name = "题单"
verbose_name_plural = "题单"
def __str__(self):
return self.title
class ProblemSetProblem(models.Model):
"""题单题目关联模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, verbose_name="题目")
# 在题单中的顺序
order = models.IntegerField(default=0, verbose_name="顺序")
# 是否为必做题
is_required = models.BooleanField(default=True, verbose_name="是否必做")
# 题目在题单中的分值
score = models.IntegerField(default=0, verbose_name="分值")
# 题目提示信息
hint = models.TextField(null=True, blank=True, verbose_name="提示")
class Meta:
db_table = "problemset_problem"
unique_together = (("problemset", "problem"),)
ordering = ("order",)
verbose_name = "题单题目"
verbose_name_plural = "题单题目"
def __str__(self):
return f"{self.problemset.title} - {self.problem.title}"
class ProblemSetBadge(models.Model):
"""题单奖章模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
name = models.TextField(verbose_name="奖章名称")
description = models.TextField(verbose_name="奖章描述")
# 奖章图标路径
icon = models.TextField(verbose_name="奖章图标")
# 获得条件:完成所有题目、完成指定数量题目、达到指定分数等
condition_type = models.TextField(
verbose_name="获得条件类型"
) # all_problems, problem_count, score
condition_value = models.IntegerField(default=0, verbose_name="条件值")
class Meta:
db_table = "problemset_badge"
verbose_name = "题单奖章"
verbose_name_plural = "题单奖章"
def __str__(self):
return f"{self.problemset.title} - {self.name}"
def recalculate_user_badges(self):
"""重新计算所有用户的徽章资格"""
# 获取所有已加入该题单的用户进度
user_progresses = ProblemSetProgress.objects.filter(problemset=self.problemset)
# 删除该徽章的所有现有用户徽章记录
UserBadge.objects.filter(badge=self).delete()
# 重新评估每个用户的徽章资格
for progress in user_progresses:
self._check_user_badge_eligibility(progress)
def _check_user_badge_eligibility(self, progress):
"""检查用户是否符合该徽章的条件"""
# 检查是否已经拥有该徽章
if UserBadge.objects.filter(user=progress.user, badge=self).exists():
return False
# 根据条件类型检查用户是否符合条件
if self.condition_type == "all_problems":
if progress.completed_problems_count == progress.total_problems_count:
UserBadge.objects.create(user=progress.user, badge=self)
return True
elif self.condition_type == "problem_count":
if progress.completed_problems_count >= self.condition_value:
UserBadge.objects.create(user=progress.user, badge=self)
return True
elif self.condition_type == "score":
if progress.total_score >= self.condition_value:
UserBadge.objects.create(user=progress.user, badge=self)
return True
return False
class ProblemSetProgress(models.Model):
"""题单进度模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
# 加入时间
join_time = models.DateTimeField(auto_now_add=True, verbose_name="加入时间")
# 完成时间
complete_time = models.DateTimeField(null=True, blank=True, verbose_name="完成时间")
# 是否完成
is_completed = models.BooleanField(default=False, verbose_name="是否完成")
# 完成进度百分比
progress_percentage = models.FloatField(default=0.0, verbose_name="完成进度")
# 已完成的题目数量
completed_problems_count = models.IntegerField(
default=0, verbose_name="已完成题目数"
)
# 总题目数量
total_problems_count = models.IntegerField(default=0, verbose_name="总题目数")
# 获得的总分
total_score = models.IntegerField(default=0, verbose_name="总分")
# 用户在该题单中的详细进度信息
# {"problem_id": {"score": 20, "submit_time": "2024-01-01T00:00:00Z"}}
progress_detail = JSONField(default=dict, verbose_name="详细进度")
class Meta:
db_table = "problemset_progress"
unique_together = (("problemset", "user"),)
verbose_name = "题单进度"
verbose_name_plural = "题单进度"
def __str__(self):
return f"{self.user.username} - {self.problemset.title}"
def update_progress(self):
"""更新进度信息"""
# 获取题单中的所有题目
problemset_problems = ProblemSetProblem.objects.filter(
problemset=self.problemset
)
self.total_problems_count = problemset_problems.count()
# 获取当前题单中所有题目的ID集合
current_problem_ids = {str(psp.problem.id) for psp in problemset_problems}
# 清理已删除题目的进度记录
progress_detail_to_remove = []
for problem_id in self.progress_detail.keys():
if problem_id not in current_problem_ids:
progress_detail_to_remove.append(problem_id)
for problem_id in progress_detail_to_remove:
del self.progress_detail[problem_id]
# 计算已完成题目数
completed_count = 0
total_score = 0
for psp in problemset_problems:
problem_id = str(psp.problem.id)
if problem_id in self.progress_detail:
problem_progress = self.progress_detail[problem_id]
completed_count += 1
total_score += psp.score
problem_progress["score"] = psp.score
self.completed_problems_count = completed_count
self.total_score = total_score
# 计算完成百分比
if self.total_problems_count > 0:
self.progress_percentage = (
completed_count / self.total_problems_count
) * 100
else:
self.progress_percentage = 0
# 检查是否完成
self.is_completed = completed_count == self.total_problems_count
if self.is_completed and not self.complete_time:
self.complete_time = now()
self.save()
@classmethod
def sync_all_progress_for_problemset(cls, problemset):
"""同步指定题单的所有用户进度"""
progresses = cls.objects.filter(problemset=problemset)
for progress in progresses:
progress.update_progress()
return progresses.count()
class ProblemSetSubmission(models.Model):
"""题单提交记录模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
submission = models.ForeignKey(
"submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录"
)
problem = models.ForeignKey(
"problem.Problem", on_delete=models.CASCADE, verbose_name="题目"
)
class Meta:
db_table = "problemset_submission"
ordering = ("-submission__create_time",)
verbose_name = "题单提交记录"
verbose_name_plural = "题单提交记录"
indexes = [
models.Index(fields=["problemset", "user"]),
models.Index(fields=["problemset", "problem"]),
models.Index(fields=["user"]),
]
def __str__(self):
return f"{self.user.username} - {self.problemset.title} - {self.problem.title}"
@property
def submit_time(self):
"""提交时间"""
return self.submission.create_time
@property
def result(self):
"""提交结果"""
return self.submission.result
@property
def language(self):
"""编程语言"""
return self.submission.language
class UserBadge(models.Model):
"""用户奖章模型"""
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
badge = models.ForeignKey(
ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章"
)
# 获得时间
earned_time = models.DateTimeField(auto_now_add=True, verbose_name="获得时间")
class Meta:
db_table = "user_badge"
unique_together = (("user", "badge"),)
verbose_name = "用户奖章"
verbose_name_plural = "用户奖章"
def __str__(self):
return f"{self.user.username} - {self.badge.name}"

View File

@@ -1,314 +0,0 @@
from utils.api import UsernameSerializer, serializers
from .models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
UserBadge,
)
def get_user_progress_data(problemset, request):
"""获取当前用户在该题单中的进度 - 公共方法"""
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=problemset, user=request.user
)
return {
"is_joined": True,
"progress_percentage": progress.progress_percentage,
"completed_count": progress.completed_problems_count,
"total_count": progress.total_problems_count,
"is_completed": progress.is_completed,
}
except ProblemSetProgress.DoesNotExist:
return {
"is_joined": False,
"progress_percentage": 0,
"completed_count": 0,
"total_count": 0,
"is_completed": False,
}
return {
"is_joined": False,
"progress_percentage": 0,
"completed_count": 0,
"total_count": 0,
"is_completed": False,
}
class ProblemSetSerializer(serializers.ModelSerializer):
"""题单序列化器"""
created_by = UsernameSerializer()
problems_count = serializers.SerializerMethodField()
completed_count = serializers.SerializerMethodField()
user_progress = serializers.SerializerMethodField()
class Meta:
model = ProblemSet
fields = "__all__"
def get_problems_count(self, obj):
"""获取题单中的题目数量"""
return ProblemSetProblem.objects.filter(problemset=obj).count()
def get_completed_count(self, obj):
"""获取当前用户在该题单中完成的题目数量"""
request = self.context.get("request")
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj, user=request.user
)
return progress.completed_problems_count
except ProblemSetProgress.DoesNotExist:
return 0
return 0
def get_user_progress(self, obj):
"""获取当前用户在该题单中的进度"""
request = self.context.get("request")
return get_user_progress_data(obj, request)
class ProblemSetListSerializer(serializers.ModelSerializer):
"""题单列表序列化器"""
created_by = UsernameSerializer()
problems_count = serializers.IntegerField(read_only=True)
user_progress = serializers.SerializerMethodField()
badges = serializers.SerializerMethodField()
class Meta:
model = ProblemSet
fields = [
"id",
"title",
"description",
"created_by",
"create_time",
"difficulty",
"status",
"end_time",
"problems_count",
"user_progress",
"badges",
"visible",
]
def get_user_progress(self, obj):
"""获取当前用户在该题单中的进度"""
request = self.context.get("request")
if request and hasattr(request, "_user_progress_map"):
progress = request._user_progress_map.get(obj.id)
if progress:
return {
"is_joined": True,
"progress_percentage": progress.progress_percentage,
"completed_count": progress.completed_problems_count,
"total_count": progress.total_problems_count,
"is_completed": progress.is_completed,
}
return {
"is_joined": False,
"progress_percentage": 0,
"completed_count": 0,
"total_count": 0,
"is_completed": False,
}
def get_badges(self, obj):
"""获取题单的奖章列表,并标记用户已获得的徽章"""
request = self.context.get("request")
# 使用预加载的奖章数据
badges = getattr(obj, "badges", [])
badge_data = ProblemSetBadgeSerializer(badges, many=True).data
# 如果用户已登录,检查哪些徽章已被获得
if request and request.user.is_authenticated and hasattr(request, "_user_earned_badge_ids"):
earned_badge_ids = request._user_earned_badge_ids
# 为每个徽章添加是否已获得的标记
for badge in badge_data:
badge['is_earned'] = badge['id'] in earned_badge_ids
else:
# 未登录用户或未预加载,所有徽章都标记为未获得
for badge in badge_data:
badge['is_earned'] = False
return badge_data
class CreateProblemSetSerializer(serializers.Serializer):
"""创建题单序列化器"""
title = serializers.CharField(max_length=200)
description = serializers.CharField()
difficulty = serializers.CharField(default="Easy")
status = serializers.CharField(default="active")
end_time = serializers.DateTimeField(required=False)
class EditProblemSetSerializer(serializers.Serializer):
"""编辑题单序列化器"""
id = serializers.IntegerField()
title = serializers.CharField(max_length=200, required=False)
description = serializers.CharField(required=False)
difficulty = serializers.CharField(required=False)
status = serializers.CharField(required=False)
visible = serializers.BooleanField(required=False)
end_time = serializers.DateTimeField(required=False, allow_null=True)
class ProblemSetProblemSerializer(serializers.ModelSerializer):
"""题单题目序列化器"""
problem = serializers.SerializerMethodField()
is_completed = serializers.SerializerMethodField()
class Meta:
model = ProblemSetProblem
fields = "__all__"
def get_problem(self, obj):
"""获取题目详细信息"""
from problem.serializers import ProblemListSerializer
return ProblemListSerializer(obj.problem, context=self.context).data
def get_is_completed(self, obj):
"""获取当前用户是否已完成该题目"""
request = self.context.get("request")
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj.problemset, user=request.user
)
problem_id = str(obj.problem.id)
return problem_id in progress.progress_detail
except ProblemSetProgress.DoesNotExist:
return False
return False
class AddProblemToSetSerializer(serializers.Serializer):
"""添加题目到题单序列化器"""
problem_id = serializers.CharField()
order = serializers.IntegerField(default=0)
is_required = serializers.BooleanField(default=True)
score = serializers.IntegerField(default=0)
hint = serializers.CharField(required=False, allow_blank=True)
class EditProblemInSetSerializer(serializers.Serializer):
"""编辑题单中的题目序列化器"""
order = serializers.IntegerField(required=False)
is_required = serializers.BooleanField(required=False)
score = serializers.IntegerField(required=False)
hint = serializers.CharField(required=False, allow_blank=True)
class ProblemSetBadgeSerializer(serializers.ModelSerializer):
"""题单奖章序列化器"""
class Meta:
model = ProblemSetBadge
fields = "__all__"
class CreateProblemSetBadgeSerializer(serializers.Serializer):
"""创建题单奖章序列化器"""
name = serializers.CharField(max_length=100)
description = serializers.CharField()
icon = serializers.CharField()
condition_type = serializers.CharField() # all_problems, problem_count, score
condition_value = serializers.IntegerField(required=False)
class EditProblemSetBadgeSerializer(serializers.Serializer):
"""编辑题单奖章序列化器"""
name = serializers.CharField(max_length=100, required=False)
description = serializers.CharField(required=False)
icon = serializers.CharField(required=False)
condition_type = serializers.CharField(required=False) # all_problems, problem_count, score
condition_value = serializers.IntegerField(required=False)
class ProblemSetProgressSerializer(serializers.ModelSerializer):
"""题单进度序列化器"""
user = UsernameSerializer()
completed_problems = serializers.SerializerMethodField()
class Meta:
model = ProblemSetProgress
fields = "__all__"
def get_completed_problems(self, obj):
"""获取已完成的题目列表"""
completed_problems = []
# 尝试从 request 中获取预加载的问题字典(用于批量查询优化)
problems_dict = {}
request = self.context.get('request')
if request and hasattr(request, '_problems_dict_cache'):
problems_dict = request._problems_dict_cache
if obj.progress_detail:
for problem_id in obj.progress_detail.keys():
# 优先使用预加载的问题字典
if problems_dict:
problem = problems_dict.get(problem_id)
if problem:
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
continue
# 如果没有预加载字典,则回退到单独查询(向后兼容)
from problem.models import Problem
try:
problem = Problem.objects.get(id=problem_id)
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
except Problem.DoesNotExist:
continue
return completed_problems
class UserBadgeSerializer(serializers.ModelSerializer):
"""用户奖章序列化器"""
badge = ProblemSetBadgeSerializer()
class Meta:
model = UserBadge
fields = "__all__"
class JoinProblemSetSerializer(serializers.Serializer):
"""加入题单序列化器"""
problemset_id = serializers.IntegerField()
class UpdateProgressSerializer(serializers.Serializer):
"""更新进度序列化器"""
problemset_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
submission_id = serializers.CharField(required=False)

View File

@@ -1,91 +0,0 @@
# 题单应用信号处理
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from .models import ProblemSetProblem, ProblemSetProgress, ProblemSetBadge, UserBadge
from django.db import transaction
import logging
logger = logging.getLogger(__name__)
@receiver(post_save, sender=ProblemSetProblem)
def sync_progress_on_problem_change(sender, instance, created, **kwargs):
"""当题单题目发生变化时,同步所有用户的进度"""
try:
with transaction.atomic():
# 获取该题单的所有用户进度
progresses = ProblemSetProgress.objects.filter(
problemset=instance.problemset
)
# 批量更新所有用户的进度
for progress in progresses:
progress.update_progress()
# 重新计算该题单的所有徽章资格
badges = ProblemSetBadge.objects.filter(problemset=instance.problemset)
for badge in badges:
badge.recalculate_user_badges()
logger.info(f"已同步题单 {instance.problemset.id} 的所有用户进度和徽章资格")
except Exception as e:
logger.error(f"同步题单进度时出错: {e}")
@receiver(post_delete, sender=ProblemSetProblem)
def sync_progress_on_problem_delete(sender, instance, **kwargs):
"""当题单题目被删除时,同步所有用户的进度并清理相关提交记录"""
try:
with transaction.atomic():
# 清理该题目在题单中的所有提交记录
from .models import ProblemSetSubmission
ProblemSetSubmission.objects.filter(
problemset=instance.problemset,
problem=instance.problem
).delete()
# 获取该题单的所有用户进度
progresses = ProblemSetProgress.objects.filter(
problemset=instance.problemset
)
# 批量更新所有用户的进度
for progress in progresses:
progress.update_progress()
# 重新计算该题单的所有徽章资格
badges = ProblemSetBadge.objects.filter(problemset=instance.problemset)
for badge in badges:
badge.recalculate_user_badges()
logger.info(f"已同步题单 {instance.problemset.id} 的所有用户进度和徽章资格(删除题目后)")
except Exception as e:
logger.error(f"同步题单进度时出错: {e}")
@receiver(post_save, sender=ProblemSetBadge)
def sync_badges_on_badge_change(sender, instance, created, **kwargs):
"""当题单奖章发生变化时,重新计算所有用户的奖章资格"""
try:
with transaction.atomic():
# 重新计算该奖章的所有用户资格
instance.recalculate_user_badges()
logger.info(f"已重新计算题单 {instance.problemset.id} 的奖章 {instance.id} 的用户资格")
except Exception as e:
logger.error(f"重新计算奖章资格时出错: {e}")
@receiver(post_delete, sender=ProblemSetBadge)
def cleanup_badges_on_badge_delete(sender, instance, **kwargs):
"""当题单奖章被删除时,清理相关的用户奖章记录"""
try:
with transaction.atomic():
# 删除该奖章的所有用户奖章记录
UserBadge.objects.filter(badge=instance).delete()
logger.info(f"已清理奖章 {instance.id} 的所有用户奖章记录")
except Exception as e:
logger.error(f"清理用户奖章记录时出错: {e}")

View File

@@ -1,71 +0,0 @@
from django.urls import path
from problemset.views.admin import (
ProblemSetAdminAPI,
ProblemSetBadgeAdminAPI,
ProblemSetDetailAdminAPI,
ProblemSetProblemAdminAPI,
ProblemSetProgressAdminAPI,
ProblemSetStatusAPI,
ProblemSetSyncAPI,
ProblemSetVisibleAPI,
)
urlpatterns = [
# 管理员题单管理API
path("problemset", ProblemSetAdminAPI.as_view(), name="admin_problemset_api"),
path(
"problemset/<int:problem_set_id>",
ProblemSetDetailAdminAPI.as_view(),
name="admin_problemset_detail_api",
),
path(
"problemset/<int:problem_set_id>/problems",
ProblemSetProblemAdminAPI.as_view(),
name="admin_problemset_problems_api",
),
path(
"problemset/<int:problem_set_id>/problems/<int:problem_set_problem_id>",
ProblemSetProblemAdminAPI.as_view(),
name="admin_problemset_problem_detail_api",
),
# 管理员奖章管理API
path(
"problemset/<int:problem_set_id>/badges",
ProblemSetBadgeAdminAPI.as_view(),
name="admin_problemset_badges_api",
),
path(
"problemset/<int:problem_set_id>/badges/<int:badge_id>",
ProblemSetBadgeAdminAPI.as_view(),
name="admin_problemset_badge_detail_api",
),
# 管理员进度管理API
path(
"problemset/<int:problem_set_id>/progress",
ProblemSetProgressAdminAPI.as_view(),
name="admin_problemset_progress_api",
),
path(
"problemset/<int:problem_set_id>/progress/<int:user_id>",
ProblemSetProgressAdminAPI.as_view(),
name="admin_problemset_progress_detail_api",
),
# 题单同步管理API
path(
"problemset/<int:problem_set_id>/sync",
ProblemSetSyncAPI.as_view(),
name="admin_problemset_sync_api",
),
# 题单状态管理API
path(
"problemset/visible",
ProblemSetVisibleAPI.as_view(),
name="admin_problemset_visible_api",
),
path(
"problemset/status",
ProblemSetStatusAPI.as_view(),
name="admin_problemset_status_api",
),
]

View File

@@ -1,55 +0,0 @@
from django.urls import path
from problemset.views.oj import (
ProblemSetAPI,
ProblemSetDetailAPI,
ProblemSetProblemAPI,
ProblemSetProgressAPI,
UserBadgeAPI,
UserProgressAPI,
ProblemSetBadgeAPI,
ProblemSetUserProgressAPI,
)
urlpatterns = [
# 题单相关API
path("problemset", ProblemSetAPI.as_view(), name="problemset_api"),
path(
"problemset/<int:problem_set_id>",
ProblemSetDetailAPI.as_view(),
name="problemset_detail_api",
),
path(
"problemset/<int:problem_set_id>/problems",
ProblemSetProblemAPI.as_view(),
name="problemset_problems_api",
),
path(
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
ProblemSetProblemAPI.as_view(),
name="problemset_problem_detail_api",
),
# 进度相关API
path(
"problemset/progress",
ProblemSetProgressAPI.as_view(),
name="problemset_progress_api",
),
path(
"problemset/<int:problem_set_id>/progress",
ProblemSetProgressAPI.as_view(),
name="problemset_progress_detail_api",
),
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"),
# 奖章相关API
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
path(
"problemset/<int:problem_set_id>/badges",
ProblemSetBadgeAPI.as_view(),
name="problemset_badges_api",
),
path(
"problemset/<int:problem_set_id>/users_progress",
ProblemSetUserProgressAPI.as_view(),
name="problemset_user_progress_api",
),
]

View File

@@ -1,428 +0,0 @@
from django.db.models import Q
from utils.api import APIView, validate_serializer
from account.decorators import super_admin_required, ensure_created_by
from problemset.models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
)
from problemset.serializers import (
ProblemSetSerializer,
ProblemSetListSerializer,
CreateProblemSetSerializer,
EditProblemSetSerializer,
ProblemSetProblemSerializer,
AddProblemToSetSerializer,
EditProblemInSetSerializer,
ProblemSetBadgeSerializer,
CreateProblemSetBadgeSerializer,
EditProblemSetBadgeSerializer,
ProblemSetProgressSerializer,
)
from problem.models import Problem
class ProblemSetAdminAPI(APIView):
"""题单管理API"""
@super_admin_required
def get(self, request):
"""获取题单列表(管理员)"""
problem_sets = ProblemSet.objects.all().order_by("-create_time")
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
difficulty = request.GET.get("difficulty")
if difficulty:
problem_sets = problem_sets.filter(difficulty=difficulty)
status = request.GET.get("status")
if status:
problem_sets = problem_sets.filter(status=status)
# 使用统一的分页方法
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
return self.success(data)
@super_admin_required
@validate_serializer(CreateProblemSetSerializer)
def post(self, request):
"""创建题单"""
data = request.data
data["created_by"] = request.user
problem_set = ProblemSet.objects.create(**data)
return self.success(ProblemSetSerializer(problem_set).data)
@super_admin_required
@validate_serializer(EditProblemSetSerializer)
def put(self, request):
"""编辑题单"""
data = request.data
problem_set_id = data.pop("id")
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 更新题单信息
for key, value in data.items():
if key != "id":
setattr(problem_set, key, value)
problem_set.save()
return self.success(ProblemSetSerializer(problem_set).data)
@super_admin_required
def delete(self, request):
"""删除题单"""
problem_set_id = request.GET.get("id")
if not problem_set_id:
return self.error("题单ID是必需的")
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 软删除:设置为不可见
problem_set.visible = False
problem_set.save()
return self.success("题单已删除")
class ProblemSetDetailAdminAPI(APIView):
"""题单详情管理API"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单详情(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
serializer = ProblemSetSerializer(problem_set, context={"request": request})
return self.success(serializer.data)
class ProblemSetProblemAdminAPI(APIView):
"""题单题目管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单中的题目列表(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
)
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
)
return self.success(serializer.data)
@super_admin_required
@validate_serializer(AddProblemToSetSerializer)
def post(self, request, problem_set_id):
"""添加题目到题单(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
data = request.data
try:
problem = Problem.objects.filter(
_id=data["problem_id"],
visible=True,
contest_id__isnull=True,
).get()
except Problem.DoesNotExist:
return self.error("题目不存在或不可见")
# 检查题目是否已经在题单中
if ProblemSetProblem.objects.filter(
problemset=problem_set, problem=problem
).exists():
return self.error("题目已在该题单中")
ProblemSetProblem.objects.create(
problemset=problem_set,
problem=problem,
order=data.get("order", 0),
is_required=data.get("is_required", True),
score=data.get("score", 0),
hint=data.get("hint", ""),
)
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已添加到题单")
@super_admin_required
@validate_serializer(EditProblemInSetSerializer)
def put(self, request, problem_set_id, problem_set_problem_id):
"""编辑题单中的题目(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
data = request.data
# 更新题目属性
if "order" in data:
problem_set_problem.order = data["order"]
if "is_required" in data:
problem_set_problem.is_required = data["is_required"]
if "score" in data:
problem_set_problem.score = data["score"]
if "hint" in data:
problem_set_problem.hint = data["hint"]
problem_set_problem.save()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已更新")
@super_admin_required
def delete(self, request, problem_set_id, problem_set_problem_id):
"""从题单中移除题目(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
problem_set_problem.delete()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已从题单中移除")
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
class ProblemSetBadgeAdminAPI(APIView):
"""题单奖章管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单的奖章列表(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
serializer = ProblemSetBadgeSerializer(badges, many=True)
return self.success(serializer.data)
@super_admin_required
@validate_serializer(CreateProblemSetBadgeSerializer)
def post(self, request, problem_set_id):
"""创建题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
data = request.data
data["problemset"] = problem_set
badge = ProblemSetBadge.objects.create(**data)
return self.success(ProblemSetBadgeSerializer(badge).data)
@super_admin_required
@validate_serializer(EditProblemSetBadgeSerializer)
def put(self, request, problem_set_id, badge_id):
"""编辑题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
badge = ProblemSetBadge.objects.get(id=badge_id, problemset=problem_set)
except ProblemSetBadge.DoesNotExist:
return self.error("奖章不存在")
data = request.data
# 记录是否修改了条件相关的字段
condition_changed = False
# 更新奖章属性
if "name" in data:
badge.name = data["name"]
if "description" in data:
badge.description = data["description"]
if "icon" in data:
badge.icon = data["icon"]
if "condition_type" in data:
badge.condition_type = data["condition_type"]
condition_changed = True
if "condition_value" in data:
badge.condition_value = data["condition_value"]
condition_changed = True
if "level" in data:
badge.level = data["level"]
badge.save()
# 如果修改了条件,重新计算所有用户的徽章资格
if condition_changed:
try:
badge.recalculate_user_badges()
return self.success("奖章已更新,并重新计算了所有用户的徽章资格")
except Exception as e:
return self.error(f"奖章已更新,但重新计算徽章资格时出错: {str(e)}")
return self.success("奖章已更新")
@super_admin_required
def delete(self, request, problem_set_id, badge_id):
"""删除题单奖章(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
badge = ProblemSetBadge.objects.get(id=badge_id, problemset=problem_set)
badge.delete()
return self.success("奖章已删除")
except ProblemSetBadge.DoesNotExist:
return self.error("奖章不存在")
class ProblemSetProgressAdminAPI(APIView):
"""题单进度管理API管理员"""
@super_admin_required
def get(self, request, problem_set_id):
"""获取题单的所有用户进度(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
progress_list = ProblemSetProgress.objects.filter(
problemset=problem_set
).order_by("-join_time")
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
@super_admin_required
def delete(self, request, problem_set_id, user_id):
"""移除用户从题单(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user_id=user_id
)
progress.delete()
return self.success("用户已从题单中移除")
except ProblemSetProgress.DoesNotExist:
return self.error("用户未加入该题单")
class ProblemSetSyncAPI(APIView):
"""题单同步管理API"""
@super_admin_required
def post(self, request, problem_set_id):
"""手动同步题单的所有用户进度(管理员)"""
try:
problem_set = ProblemSet.objects.get(id=problem_set_id)
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 同步所有用户的进度
synced_count = ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success(f"已同步 {synced_count} 个用户的进度")
class ProblemSetVisibleAPI(APIView):
"""题单可见性管理API"""
@super_admin_required
def put(self, request):
"""切换题单可见性"""
data = request.data
try:
problem_set = ProblemSet.objects.get(id=data["id"])
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problem_set.visible = not problem_set.visible
problem_set.save()
return self.success()
class ProblemSetStatusAPI(APIView):
"""题单状态管理API"""
@super_admin_required
def put(self, request):
"""更新题单状态"""
data = request.data
try:
problem_set = ProblemSet.objects.get(id=data["id"])
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
status = data.get("status")
if status not in ["active", "archived", "draft"]:
return self.error("无效的状态")
problem_set.status = status
problem_set.save()
return self.success()

View File

@@ -1,460 +0,0 @@
from django.db.models import Q, Avg, Count, Prefetch
from django.utils import timezone
from utils.api import APIView, validate_serializer
from account.models import User
from problemset.models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
ProblemSetSubmission,
UserBadge,
)
from problemset.serializers import (
ProblemSetSerializer,
ProblemSetListSerializer,
ProblemSetProblemSerializer,
ProblemSetBadgeSerializer,
ProblemSetProgressSerializer,
UserBadgeSerializer,
JoinProblemSetSerializer,
UpdateProgressSerializer,
)
from submission.models import Submission
from problem.models import Problem
class ProblemSetAPI(APIView):
"""题单API - 用户端"""
def get(self, request):
"""获取题单列表"""
# 预加载创建者信息
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft").select_related("created_by")
# 使用annotate在查询时计算题目数量避免N+1查询
problem_sets = problem_sets.annotate(
problems_count=Count("problemsetproblem", distinct=True)
)
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
difficulty = request.GET.get("difficulty")
if difficulty:
problem_sets = problem_sets.filter(difficulty=difficulty)
status_filter = request.GET.get("status")
if status_filter:
problem_sets = problem_sets.filter(status=status_filter)
# 排序
sort = request.GET.get("sort")
if sort:
problem_sets = problem_sets.order_by(sort)
else:
problem_sets = problem_sets.order_by("-create_time")
# 批量查询用户进度和已获得的奖章(如果用户已登录)
# 注意需要在应用prefetch_related之前获取ID列表避免不必要的预加载
user_progress_map = {}
user_earned_badge_ids = set()
if request.user.is_authenticated:
# 先获取所有题单ID不应用prefetch_related只获取ID
problem_set_ids = list(problem_sets.values_list("id", flat=True))
if problem_set_ids:
# 批量查询用户在这些题单中的进度
user_progresses = ProblemSetProgress.objects.filter(
problemset_id__in=problem_set_ids,
user=request.user
).select_related("problemset")
# 构建映射题单ID -> 进度对象
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
# 批量查询用户已获得的奖章ID这些题单相关的
user_earned_badge_ids = set(
UserBadge.objects.filter(
user=request.user,
badge__problemset_id__in=problem_set_ids
).values_list('badge_id', flat=True)
)
# 预加载奖章信息在获取ID之后应用避免在获取ID时也预加载
problem_sets = problem_sets.prefetch_related(
Prefetch(
"problemsetbadge_set",
queryset=ProblemSetBadge.objects.all(),
to_attr="badges"
)
)
# 将用户进度映射和已获得的奖章ID集合存储到request中供序列化器使用
request._user_progress_map = user_progress_map
request._user_earned_badge_ids = user_earned_badge_ids
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
return self.success(data)
class ProblemSetDetailAPI(APIView):
"""题单详情API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单详情"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
serializer = ProblemSetSerializer(problem_set, context={"request": request})
return self.success(serializer.data)
class ProblemSetProblemAPI(APIView):
"""题单题目API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单中的题目列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
)
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
)
return self.success(serializer.data)
class ProblemSetProgressAPI(APIView):
"""题单进度API"""
@validate_serializer(JoinProblemSetSerializer)
def post(self, request):
"""加入题单"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
if ProblemSetProgress.objects.filter(
problemset=problem_set, user=request.user
).exists():
return self.error("已经加入该题单")
# 创建进度记录
progress = ProblemSetProgress.objects.create(
problemset=problem_set, user=request.user
)
progress.update_progress()
return self.success("成功加入题单")
def get(self, request, problem_set_id):
"""获取题单进度"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
serializer = ProblemSetProgressSerializer(progress)
return self.success(serializer.data)
@validate_serializer(UpdateProgressSerializer)
def put(self, request):
"""更新进度"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
# 更新详细进度
problem_id = str(data["problem_id"])
# 获取该题目在题单中的分值
try:
problemset_problem = ProblemSetProblem.objects.get(
problemset=problem_set, problem_id=problem_id
)
problem_score = problemset_problem.score
except ProblemSetProblem.DoesNotExist:
problem_score = 0
progress.progress_detail[problem_id] = {
"score": problem_score, # 题单中设置的分值
"submit_time": data.get("submit_time", timezone.now().isoformat()),
}
# 更新进度
progress.update_progress()
# 只有当提供了submission_id时才创建ProblemSetSubmission记录
if "submission_id" in data and data["submission_id"]:
try:
submission = Submission.objects.get(id=data["submission_id"])
problem = Problem.objects.get(id=problem_id)
has_accepted = ProblemSetSubmission.objects.filter(
problemset=problem_set,
user=request.user,
problem=problem,
).exists()
if not has_accepted:
ProblemSetSubmission.objects.create(
problemset=problem_set,
user=request.user,
submission=submission,
problem=problem,
)
except Submission.DoesNotExist:
# 如果提交记录不存在,记录错误但不中断流程
pass
# 检查是否获得奖章
self._check_badges(progress)
return self.success("进度已更新")
def _check_badges(self, progress):
"""检查是否获得奖章"""
badges = ProblemSetBadge.objects.filter(problemset=progress.problemset)
for badge in badges:
if UserBadge.objects.filter(user=progress.user, badge=badge).exists():
continue
if badge.condition_type == "all_problems":
if progress.completed_problems_count == progress.total_problems_count:
UserBadge.objects.create(user=progress.user, badge=badge)
elif badge.condition_type == "problem_count":
if progress.completed_problems_count >= badge.condition_value:
UserBadge.objects.create(user=progress.user, badge=badge)
elif badge.condition_type == "score":
if progress.total_score >= badge.condition_value:
UserBadge.objects.create(user=progress.user, badge=badge)
class UserProgressAPI(APIView):
"""用户进度API"""
def get(self, request):
"""获取用户的题单进度列表"""
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by(
"-join_time"
)
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
class UserBadgeAPI(APIView):
"""用户奖章API"""
def get(self, request):
"""获取用户的奖章列表"""
# 支持通过username参数获取指定用户的徽章
username = request.GET.get("username")
if username:
# 获取指定用户的徽章
try:
target_user = User.objects.get(username=username, is_disabled=False)
badges = UserBadge.objects.filter(user=target_user).order_by(
"-earned_time"
)
except User.DoesNotExist:
return self.error("用户不存在")
else:
# 获取当前用户的徽章
badges = UserBadge.objects.filter(user=request.user).order_by(
"-earned_time"
)
serializer = UserBadgeSerializer(badges, many=True)
return self.success(serializer.data)
class ProblemSetBadgeAPI(APIView):
"""题单奖章API - 用户端"""
def get(self, request, problem_set_id):
"""获取题单的奖章列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
serializer = ProblemSetBadgeSerializer(badges, many=True)
return self.success(serializer.data)
class ProblemSetUserProgressAPI(APIView):
"""题单用户进度列表API"""
def get(self, request, problem_set_id: int):
"""获取题单的用户进度列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
progresses = ProblemSetProgress.objects.filter(
problemset=problem_set
).select_related("user")
# 班级过滤
class_name = request.GET.get("class_name", "").strip()
if class_name:
progresses = progresses.filter(user__username__icontains=class_name)
# 完成度筛选
completion_status = request.GET.get("completion_status", "").strip()
if completion_status == "completed":
# 已完成:所有题目都已完成
progresses = progresses.filter(is_completed=True)
elif completion_status == "in_progress":
# 进行中:未完成且已开始(至少完成了一道题,排除未开始的用户)
progresses = progresses.filter(is_completed=False, completed_problems_count__gt=0)
elif completion_status == "not_started":
# 未开始:还没有完成任何题目
progresses = progresses.filter(completed_problems_count=0)
# 排序
progresses = progresses.order_by(
"-is_completed", "-progress_percentage", "join_time"
)
# 计算统计数据(基于所有数据,而非分页数据)
# 使用一次查询获取所有统计数据
stats = progresses.aggregate(
total=Count("id"),
completed=Count("id", filter=Q(is_completed=True)),
avg_progress=Avg("progress_percentage"),
)
total_count = stats["total"]
completed_count = stats["completed"]
avg_progress = stats["avg_progress"] or 0
# 获取分页参数
try:
limit = int(request.GET.get("limit", "10"))
except ValueError:
limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
offset = 0
if offset < 0:
offset = 0
# 提前获取题单的所有题目(用于前端显示未完成题目和序列化器)
# 使用 select_related 和 only 优化查询,只选择需要的字段
all_problemset_problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem")
.only("problem__id", "problem___id", "problem__title", "order")
.order_by("order")
)
# 构建题单所有题目的数据结构和映射
all_problems_list = []
all_problems_map = {}
for psp in all_problemset_problems:
problem_data = {
"id": psp.problem.id,
"_id": psp.problem._id,
"title": psp.problem.title,
}
all_problems_list.append(problem_data)
# 用于序列化器查找key 使用字符串格式(与 progress_detail 的 key 格式一致)
all_problems_map[str(psp.problem.id)] = psp.problem
# 从当前页的数据中收集已完成的问题ID用于序列化器
paginated_progresses = list(progresses[offset : offset + limit])
completed_problem_ids = set()
for progress in paginated_progresses:
if progress.progress_detail:
# progress_detail 的 key 是字符串格式的 problem_id
completed_problem_ids.update(progress.progress_detail.keys())
# 从已加载的题单题目中构建 problems_dict避免重复查询
problems_dict = {
pid: all_problems_map[pid]
for pid in completed_problem_ids
if pid in all_problems_map
}
# 将预加载的问题字典存储到 request 中,供序列化器使用
request._problems_dict_cache = problems_dict
# 使用分页
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
# 添加统计数据
data["statistics"] = {
"total": total_count,
"completed": completed_count,
"avg_progress": round(avg_progress, 2),
}
# 返回题单的所有题目
data["problems"] = all_problems_list
return self.success(data)

View File

@@ -5,31 +5,21 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"channels>=4.2.0",
"channels-redis>=4.2.0",
"coverage==6.5.0",
"daphne>=4.1.2",
"django>=5.2.3",
"django-cas-ng==5.0.1",
"django-dbconn-retry==0.1.8",
"django-dramatiq==0.13.0",
"django-redis==5.4.0",
"djangorestframework==3.16.0",
"dramatiq==1.17.0",
"entrypoints==0.4",
"envelopes==0.4",
"flake8==7.0.0",
"flake8-coding==1.3.2",
"flake8-quotes==3.3.2",
"gunicorn==22.0.0",
"jsonfield==3.1.0",
"openai>=1.108.1",
"otpauth==1.0.1",
"pillow==10.2.0",
"psycopg==3.2.9",
"psycopg-binary==3.2.9",
"python-dateutil==2.8.2",
"qrcode==8.2",
"raven==6.10.0",
"xlsxwriter==3.2.0",
"django-dbconn-retry>=0.1.8",
"django-dramatiq>=0.13.0",
"django-redis>=5.4.0",
"djangorestframework>=3.16.0",
"envelopes>=0.4",
"gunicorn>=23.0.0",
"otpauth>=2.2.1",
"pillow>=11.2.1",
"psycopg>=3.2.9",
"psycopg-binary>=3.2.9",
"python-dateutil>=2.9.0.post0",
"qrcode>=8.2",
"raven>=6.10.0",
"requests>=2.32.4",
"uvicorn>=0.35.0",
"xlsxwriter>=3.2.5",
]

View File

@@ -1,84 +0,0 @@
"""
WebSocket consumers for submission updates
"""
import json
import logging
from channels.generic.websocket import AsyncWebsocketConsumer
logger = logging.getLogger(__name__)
class SubmissionConsumer(AsyncWebsocketConsumer):
"""
WebSocket consumer for real-time submission updates
当用户提交代码后,通过 WebSocket 实时接收判题状态更新
"""
async def connect(self):
"""处理 WebSocket 连接"""
self.user = self.scope["user"]
# 只允许认证用户连接
if not self.user.is_authenticated:
await self.close()
return
# 使用用户 ID 作为组名,这样可以向特定用户推送消息
self.group_name = f"submission_user_{self.user.id}"
# 加入用户专属的组
await self.channel_layer.group_add(
self.group_name,
self.channel_name
)
await self.accept()
logger.info(f"WebSocket connected: user_id={self.user.id}, channel={self.channel_name}")
async def disconnect(self, close_code):
"""处理 WebSocket 断开连接"""
if hasattr(self, 'group_name'):
await self.channel_layer.group_discard(
self.group_name,
self.channel_name
)
logger.info(f"WebSocket disconnected: user_id={self.user.id}, close_code={close_code}")
async def receive(self, text_data):
"""
接收客户端消息
客户端可以发送心跳包或订阅特定提交
"""
try:
data = json.loads(text_data)
message_type = data.get("type")
if message_type == "ping":
# 响应心跳包
await self.send(text_data=json.dumps({
"type": "pong",
"timestamp": data.get("timestamp")
}))
elif message_type == "subscribe":
# 订阅特定提交的更新
submission_id = data.get("submission_id")
if submission_id:
logger.info(f"User {self.user.id} subscribed to submission {submission_id}")
# 可以在这里做额外的订阅逻辑
except json.JSONDecodeError:
logger.error(f"Invalid JSON received from user {self.user.id}")
except Exception as e:
logger.error(f"Error handling message from user {self.user.id}: {str(e)}")
async def submission_update(self, event):
"""
接收来自 channel layer 的代码提交更新消息并发送给客户端
这个方法名对应 push_submission_update 中的 type 字段
"""
try:
# 从 event 中提取数据并发送给客户端
await self.send(text_data=json.dumps(event["data"]))
logger.debug(f"Sent submission update to user {self.user.id}: {event['data']}")
except Exception as e:
logger.error(f"Error sending submission update to user {self.user.id}: {str(e)}")

View File

@@ -1,19 +0,0 @@
# Generated by Django 5.2.3 on 2025-09-25 07:03
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0001_initial'),
('problem', '0001_initial'),
('submission', '0001_initial'),
]
operations = [
migrations.AddIndex(
model_name='submission',
index=models.Index(fields=['user_id', 'create_time'], name='user_create_time_idx'),
),
]

View File

@@ -1,19 +0,0 @@
# Generated by Django 6.0 on 2026-03-09 13:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0001_initial'),
('problem', '0005_remove_spj_fields'),
('submission', '0002_submission_user_create_time_idx'),
]
operations = [
migrations.AddIndex(
model_name='submission',
index=models.Index(fields=['contest_id', '-create_time'], name='contest_create_time_idx'),
),
]

View File

@@ -41,11 +41,7 @@ class Submission(models.Model):
ip = models.TextField(null=True)
def check_user_permission(self, user, check_share=True):
if (
self.user_id == user.id
or not user.is_regular_user()
or self.problem.created_by_id == user.id
):
if self.user_id == user.id or user.is_super_admin() or user.can_mgmt_all_problem() or self.problem.created_by_id == user.id:
return True
if check_share:
@@ -58,14 +54,6 @@ class Submission(models.Model):
class Meta:
db_table = "submission"
ordering = ("-create_time",)
indexes = [
models.Index(
fields=["user_id", "create_time"], name="user_create_time_idx"
),
models.Index(
fields=["contest_id", "-create_time"], name="contest_create_time_idx"
),
]
def __str__(self):
return self.id

View File

@@ -1,10 +1,6 @@
from django.db import models
from django.utils import timezone
from .models import Submission
from utils.api import serializers
from utils.serializers import LanguageNameChoiceField
from problemset.models import ProblemSetProgress
class CreateSubmissionSerializer(serializers.Serializer):
@@ -12,7 +8,6 @@ class CreateSubmissionSerializer(serializers.Serializer):
language = LanguageNameChoiceField()
code = serializers.CharField(max_length=1024 * 1024)
contest_id = serializers.IntegerField(required=False)
problemset_id = serializers.IntegerField(required=False)
captcha = serializers.CharField(required=False)
@@ -39,7 +34,6 @@ class SubmissionSafeModelSerializer(serializers.ModelSerializer):
class SubmissionListSerializer(serializers.ModelSerializer):
problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
problem_title = serializers.CharField(source="problem.title")
show_link = serializers.SerializerMethodField()
def __init__(self, *args, **kwargs):
@@ -54,36 +48,4 @@ class SubmissionListSerializer(serializers.ModelSerializer):
# 没传user或为匿名user
if self.user is None or not self.user.is_authenticated:
return False
if not obj.check_user_permission(self.user):
return False
# 题单防作弊:用户加入了包含该题目的 active 题单时,隐藏加入前的提交链接
# 如果该题目已在题单中做出来了,则恢复显示
if obj.user_id == self.user.id and self.user.is_regular_user():
progress = self._get_problemset_progress(obj.problem_id)
if (
progress
and obj.create_time < progress.join_time
and str(obj.problem_id) not in progress.progress_detail
):
return False
return True
def _get_problemset_progress(self, problem_id):
"""查询用户是否加入了包含该题目的 active 题单,带缓存避免 N+1"""
if not hasattr(self, "_problemset_progress_cache"):
self._problemset_progress_cache = {}
if problem_id not in self._problemset_progress_cache:
self._problemset_progress_cache[problem_id] = (
ProblemSetProgress.objects.filter(
user=self.user,
problemset__status="active",
problemset__problemsetproblem__problem_id=problem_id,
)
.filter(
models.Q(problemset__end_time__isnull=True)
| models.Q(problemset__end_time__gt=timezone.now())
)
.only("join_time", "progress_detail")
.first()
)
return self._problemset_progress_cache[problem_id]
return obj.check_user_permission(self.user)

78
submission/tests.py Normal file
View File

@@ -0,0 +1,78 @@
from copy import deepcopy
from unittest import mock
from problem.models import Problem, ProblemTag
from utils.api.tests import APITestCase
from .models import Submission
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
DEFAULT_SUBMISSION_DATA = {
"problem_id": "1",
"user_id": 1,
"username": "test",
"code": "xxxxxxxxxxxxxx",
"result": -2,
"info": {},
"language": "C",
"statistic_info": {}
}
# todo contest submission
class SubmissionPrepare(APITestCase):
def _create_problem_and_submission(self):
user = self.create_admin("test", "test123", login=False)
problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
tags = problem_data.pop("tags")
problem_data["created_by"] = user
self.problem = Problem.objects.create(**problem_data)
for tag in tags:
tag = ProblemTag.objects.create(name=tag)
self.problem.tags.add(tag)
self.problem.save()
self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
self.submission_data["problem_id"] = self.problem.id
self.submission = Submission.objects.create(**self.submission_data)
class SubmissionListTest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.create_user("123", "345")
self.url = self.reverse("submission_list_api")
def test_get_submission_list(self):
resp = self.client.get(self.url, data={"limit": "10"})
self.assertSuccess(resp)
@mock.patch("submission.views.oj.judge_task.send")
class SubmissionAPITest(SubmissionPrepare):
def setUp(self):
self._create_problem_and_submission()
self.user = self.create_user("123", "test123")
self.url = self.reverse("submission_api")
def test_create_submission(self, judge_task):
resp = self.client.post(self.url, self.submission_data)
self.assertSuccess(resp)
judge_task.assert_called()
def test_create_submission_with_wrong_language(self, judge_task):
self.submission_data.update({"language": "Python3"})
resp = self.client.post(self.url, self.submission_data)
self.assertFailed(resp)
self.assertDictEqual(resp.data, {"error": "error",
"data": "Python3 is now allowed in the problem"})
judge_task.assert_not_called()

Some files were not shown because too many files have changed in this diff Show More