重构用户权限

This commit is contained in:
2025-09-25 18:41:23 +08:00
parent 8436a4602f
commit a6d76a64c4
7 changed files with 72 additions and 58 deletions

View File

@@ -51,6 +51,9 @@ class User(AbstractBaseUser):
objects = UserManager() objects = UserManager()
def is_regular_user(self):
return self.admin_type == AdminType.REGULAR_USER
def is_admin(self): def is_admin(self):
return self.admin_type == AdminType.ADMIN return self.admin_type == AdminType.ADMIN

View File

@@ -1,4 +1,5 @@
from django import forms from django import forms
from httpx import request
from utils.api import serializers, UsernameSerializer from utils.api import serializers, UsernameSerializer
@@ -131,7 +132,7 @@ class EditUserSerializer(serializers.Serializer):
open_api = serializers.BooleanField() open_api = serializers.BooleanField()
two_factor_auth = serializers.BooleanField() two_factor_auth = serializers.BooleanField()
is_disabled = serializers.BooleanField() is_disabled = serializers.BooleanField()
class_name = serializers.CharField(max_length=32, allow_blank=True, required=False) class_name = serializers.CharField(required=False, allow_null=True)
class EditUserProfileSerializer(serializers.Serializer): class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=32, allow_null=True, required=False) real_name = serializers.CharField(max_length=32, allow_null=True, required=False)

View File

@@ -106,7 +106,7 @@ class UserAdminAPI(APIView):
user.is_disabled = data["is_disabled"] user.is_disabled = data["is_disabled"]
if data["admin_type"] == AdminType.ADMIN: if data["admin_type"] == AdminType.ADMIN:
user.problem_permission = data["problem_permission"] user.problem_permission = data["problem_permission"] or ProblemPermission.OWN
elif data["admin_type"] == AdminType.SUPER_ADMIN: elif data["admin_type"] == AdminType.SUPER_ADMIN:
user.problem_permission = ProblemPermission.ALL user.problem_permission = ProblemPermission.ALL
else: else:
@@ -156,10 +156,10 @@ class UserAdminAPI(APIView):
user = User.objects.all().order_by("-create_time") user = User.objects.all().order_by("-create_time")
is_admin = request.GET.get("admin", "0") type = request.GET.get("type", "")
if is_admin == "1": if type:
user = user.exclude(admin_type=AdminType.REGULAR_USER) user = user.filter(admin_type=type)
keyword = request.GET.get("keyword", None) keyword = request.GET.get("keyword", None)
if keyword: if keyword:

View File

@@ -434,8 +434,9 @@ class UserRankAPI(APIView):
n = 0 n = 0
if rule_type not in ContestRuleType.choices(): if rule_type not in ContestRuleType.choices():
rule_type = ContestRuleType.ACM rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.filter( profiles = UserProfile.objects.filter(
user__admin_type=AdminType.REGULAR_USER, user__admin_type__in=[AdminType.REGULAR_USER, AdminType.ADMIN],
user__is_disabled=False, user__is_disabled=False,
user__username__icontains=username, user__username__icontains=username,
).select_related("user") ).select_related("user")
@@ -456,23 +457,19 @@ class UserActivityRankAPI(APIView):
if not start: if not start:
return self.error("start time is required") return self.error("start time is required")
hidden_names = User.objects.filter( hidden_names = User.objects.filter(
Q(admin_type=AdminType.SUPER_ADMIN) Q(admin_type=AdminType.SUPER_ADMIN) | Q(is_disabled=True)
| Q(admin_type=AdminType.ADMIN)
| Q(is_disabled=True)
).values_list("username", flat=True) ).values_list("username", flat=True)
submissions = Submission.objects.filter( submissions = Submission.objects.filter(
contest_id__isnull=True, create_time__gte=start, result=JudgeStatus.ACCEPTED contest_id__isnull=True,
) create_time__gte=start,
counts = ( result=JudgeStatus.ACCEPTED,
).exclude(username__in=hidden_names)
data = list(
submissions.values("username") submissions.values("username")
.annotate(count=Count("problem_id", distinct=True)) .annotate(count=Count("problem_id", distinct=True))
.order_by("-count")[: 10 + len(hidden_names)] .order_by("-count")[:10]
) )
data = [] return self.success(data)
for count in counts:
if count["username"] not in hidden_names:
data.append(count)
return self.success(data[:10])
class UserProblemRankAPI(APIView): class UserProblemRankAPI(APIView):
@@ -482,8 +479,12 @@ class UserProblemRankAPI(APIView):
if not user.is_authenticated: if not user.is_authenticated:
return self.error("User is not authenticated") return self.error("User is not authenticated")
problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True) problem = Problem.objects.get(
submissions = Submission.objects.filter(problem=problem, result=JudgeStatus.ACCEPTED) _id=problem_id, contest_id__isnull=True, visible=True
)
submissions = Submission.objects.filter(
problem=problem, result=JudgeStatus.ACCEPTED
)
all_ac_count = submissions.values("user_id").distinct().count() all_ac_count = submissions.values("user_id").distinct().count()
@@ -491,7 +492,9 @@ class UserProblemRankAPI(APIView):
class_ac_count = 0 class_ac_count = 0
if class_name: if class_name:
users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True) users = User.objects.filter(
class_name=user.class_name, is_disabled=False
).values_list("id", flat=True)
user_ids = list(users) user_ids = list(users)
submissions = submissions.filter(user_id__in=user_ids) submissions = submissions.filter(user_id__in=user_ids)
class_ac_count = submissions.values("user_id").distinct().count() class_ac_count = submissions.values("user_id").distinct().count()
@@ -499,21 +502,27 @@ class UserProblemRankAPI(APIView):
my_submissions = submissions.filter(user_id=user.id) my_submissions = submissions.filter(user_id=user.id)
if len(my_submissions) == 0: if len(my_submissions) == 0:
return self.success({ return self.success(
{
"class_name": class_name, "class_name": class_name,
"rank": -1, "rank": -1,
"class_ac_count": class_ac_count, "class_ac_count": class_ac_count,
"all_ac_count": all_ac_count "all_ac_count": all_ac_count,
}) }
)
my_first_submission = my_submissions.order_by("create_time").first() my_first_submission = my_submissions.order_by("create_time").first()
rank = submissions.filter(create_time__lte=my_first_submission.create_time).count() rank = submissions.filter(
return self.success({ create_time__lte=my_first_submission.create_time
).count()
return self.success(
{
"class_name": class_name, "class_name": class_name,
"rank": rank, "rank": rank,
"class_ac_count": class_ac_count, "class_ac_count": class_ac_count,
"all_ac_count": all_ac_count, "all_ac_count": all_ac_count,
}) }
)
class ProfileProblemDisplayIDRefreshAPI(APIView): class ProfileProblemDisplayIDRefreshAPI(APIView):

View File

@@ -6,7 +6,7 @@ from ipaddress import ip_network
import dateutil.parser import dateutil.parser
from django.http import FileResponse from django.http import FileResponse
from account.decorators import check_contest_permission, ensure_created_by from account.decorators import super_admin_required
from account.models import User from account.models import User
from submission.models import Submission, JudgeStatus from submission.models import Submission, JudgeStatus
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
@@ -23,6 +23,7 @@ from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer
class ContestAPI(APIView): class ContestAPI(APIView):
@validate_serializer(CreateConetestSeriaizer) @validate_serializer(CreateConetestSeriaizer)
@super_admin_required
def post(self, request): def post(self, request):
data = request.data data = request.data
data["start_time"] = dateutil.parser.parse(data["start_time"]) data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -41,11 +42,11 @@ class ContestAPI(APIView):
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
@validate_serializer(EditConetestSeriaizer) @validate_serializer(EditConetestSeriaizer)
@super_admin_required
def put(self, request): def put(self, request):
data = request.data data = request.data
try: try:
contest = Contest.objects.get(id=data.pop("id")) contest = Contest.objects.get(id=data.pop("id"))
ensure_created_by(contest, request.user)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
data["start_time"] = dateutil.parser.parse(data["start_time"]) data["start_time"] = dateutil.parser.parse(data["start_time"])
@@ -68,19 +69,17 @@ class ContestAPI(APIView):
contest.save() contest.save()
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
@super_admin_required
def get(self, request): def get(self, request):
contest_id = request.GET.get("id") contest_id = request.GET.get("id")
if contest_id: if contest_id:
try: try:
contest = Contest.objects.get(id=contest_id) contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
return self.success(ContestAdminSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
contests = Contest.objects.all().order_by("-create_time") contests = Contest.objects.all().order_by("-create_time")
if request.user.is_admin():
contests = contests.filter(created_by=request.user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
@@ -90,6 +89,7 @@ class ContestAPI(APIView):
class ContestAnnouncementAPI(APIView): class ContestAnnouncementAPI(APIView):
@validate_serializer(CreateContestAnnouncementSerializer) @validate_serializer(CreateContestAnnouncementSerializer)
@super_admin_required
def post(self, request): def post(self, request):
""" """
Create one contest_announcement. Create one contest_announcement.
@@ -97,7 +97,6 @@ class ContestAnnouncementAPI(APIView):
data = request.data data = request.data
try: try:
contest = Contest.objects.get(id=data.pop("contest_id")) contest = Contest.objects.get(id=data.pop("contest_id"))
ensure_created_by(contest, request.user)
data["contest"] = contest data["contest"] = contest
data["created_by"] = request.user data["created_by"] = request.user
except Contest.DoesNotExist: except Contest.DoesNotExist:
@@ -106,6 +105,7 @@ class ContestAnnouncementAPI(APIView):
return self.success(ContestAnnouncementSerializer(announcement).data) return self.success(ContestAnnouncementSerializer(announcement).data)
@validate_serializer(EditContestAnnouncementSerializer) @validate_serializer(EditContestAnnouncementSerializer)
@super_admin_required
def put(self, request): def put(self, request):
""" """
update contest_announcement update contest_announcement
@@ -113,7 +113,6 @@ class ContestAnnouncementAPI(APIView):
data = request.data data = request.data
try: try:
contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id")) contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id"))
ensure_created_by(contest_announcement, request.user)
except ContestAnnouncement.DoesNotExist: except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist") return self.error("Contest announcement does not exist")
for k, v in data.items(): for k, v in data.items():
@@ -121,19 +120,17 @@ class ContestAnnouncementAPI(APIView):
contest_announcement.save() contest_announcement.save()
return self.success() return self.success()
@super_admin_required
def delete(self, request): def delete(self, request):
""" """
Delete one contest_announcement. Delete one contest_announcement.
""" """
contest_announcement_id = request.GET.get("id") contest_announcement_id = request.GET.get("id")
if contest_announcement_id: if contest_announcement_id:
if request.user.is_admin():
ContestAnnouncement.objects.filter(id=contest_announcement_id,
contest__created_by=request.user).delete()
else:
ContestAnnouncement.objects.filter(id=contest_announcement_id).delete() ContestAnnouncement.objects.filter(id=contest_announcement_id).delete()
return self.success() return self.success()
@super_admin_required
def get(self, request): def get(self, request):
""" """
Get one contest_announcement or contest_announcement list. Get one contest_announcement or contest_announcement list.
@@ -142,7 +139,6 @@ class ContestAnnouncementAPI(APIView):
if contest_announcement_id: if contest_announcement_id:
try: try:
contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id) contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id)
ensure_created_by(contest_announcement, request.user)
return self.success(ContestAnnouncementSerializer(contest_announcement).data) return self.success(ContestAnnouncementSerializer(contest_announcement).data)
except ContestAnnouncement.DoesNotExist: except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist") return self.error("Contest announcement does not exist")
@@ -151,8 +147,6 @@ class ContestAnnouncementAPI(APIView):
if not contest_id: if not contest_id:
return self.error("Parameter error") return self.error("Parameter error")
contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id) contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id)
if request.user.is_admin():
contest_announcements = contest_announcements.filter(created_by=request.user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
contest_announcements = contest_announcements.filter(title__contains=keyword) contest_announcements = contest_announcements.filter(title__contains=keyword)
@@ -160,9 +154,17 @@ class ContestAnnouncementAPI(APIView):
class ACMContestHelper(APIView): class ACMContestHelper(APIView):
@check_contest_permission(check_type="ranks") @super_admin_required
def get(self, request): def get(self, request):
ranks = ACMContestRank.objects.filter(contest=self.contest, accepted_number__gt=0) \ contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter error, contest_id is required")
try:
contest = Contest.objects.get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
ranks = ACMContestRank.objects.filter(contest=contest, accepted_number__gt=0) \
.values("id", "user__username", "user__userprofile__real_name", "submission_info") .values("id", "user__username", "user__userprofile__real_name", "submission_info")
results = [] results = []
for rank in ranks: for rank in ranks:
@@ -179,7 +181,7 @@ class ACMContestHelper(APIView):
results.sort(key=lambda x: -x["ac_info"]["ac_time"]) results.sort(key=lambda x: -x["ac_info"]["ac_time"])
return self.success(results) return self.success(results)
@check_contest_permission(check_type="ranks") @super_admin_required
@validate_serializer(ACMContesHelperSerializer) @validate_serializer(ACMContesHelperSerializer)
def put(self, request): def put(self, request):
data = request.data data = request.data
@@ -222,13 +224,13 @@ class DownloadContestSubmissions(APIView):
user_ac_map[problem_id] = True user_ac_map[problem_id] = True
return path return path
@super_admin_required
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
return self.error("Parameter error") return self.error("Parameter error")
try: try:
contest = Contest.objects.get(id=contest_id) contest = Contest.objects.get(id=contest_id)
ensure_created_by(contest, request.user)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")

View File

@@ -43,8 +43,7 @@ class Submission(models.Model):
def check_user_permission(self, user, check_share=True): def check_user_permission(self, user, check_share=True):
if ( if (
self.user_id == user.id self.user_id == user.id
or user.is_super_admin() or not user.is_regular_user()
or user.can_mgmt_all_problem()
or self.problem.created_by_id == user.id or self.problem.created_by_id == user.id
): ):
return True return True

View File

@@ -176,7 +176,7 @@ class SubmissionListAPI(APIView):
if ( if (
not SysOptions.submission_list_show_all not SysOptions.submission_list_show_all
and not request.user.is_super_admin() and request.user.is_regular_user()
): ):
return self.success({"results": [], "total": 0}) return self.success({"results": [], "total": 0})