diff --git a/account/decorators.py b/account/decorators.py index f47c4a03..b3523664 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -92,7 +92,8 @@ def check_contest_permission(func): if not user.is_authenticated(): return self.error("Please login in first.") # password error - if ("contests" not in request.session) or (self.contest.id not in request.session["contests"]): + if ("accessible_contests" not in request.session) or \ + (self.contest.id not in request.session["accessible_contests"]): return self.error("Password is required.") return func(*args, **kwargs) diff --git a/account/middleware.py b/account/middleware.py index 48a6942e..b674346d 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -1,38 +1,19 @@ -import time -import pytz - -from django.contrib import auth -from django.utils import timezone -from django.utils.translation import ugettext as _ from django.db import connection +from django.utils.timezone import now from django.utils.deprecation import MiddlewareMixin from utils.api import JSONResponse -class SessionSecurityMiddleware(MiddlewareMixin): - def process_request(self, request): - if request.user.is_authenticated(): - if "last_activity" in request.session and request.user.is_admin_role(): - # 24 hours passed since last visit, 86400 = 24 * 60 * 60 - if time.time() - request.session["last_activity"] >= 86400: - auth.logout(request) - return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) - request.session["last_activity"] = time.time() - - class SessionRecordMiddleware(MiddlewareMixin): def process_request(self, request): if request.user.is_authenticated(): session = request.session - ip = request.META.get("REMOTE_ADDR", "") - user_agent = request.META.get("HTTP_USER_AGENT", "") - _ip = session.setdefault("ip", ip) - _user_agent = session.setdefault("user_agent", user_agent) - if ip != _ip or user_agent != _user_agent: - session.modified = True + session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") + session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP") + session["last_activity"] = now() user_sessions = request.user.session_keys - if request.session.session_key not in user_sessions: + if session.session_key not in user_sessions: user_sessions.append(session.session_key) request.user.save() @@ -42,13 +23,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): if not (request.user.is_authenticated() and request.user.is_admin_role()): - return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) - - -class TimezoneMiddleware(MiddlewareMixin): - def process_request(self, request): - if request.user.is_authenticated(): - timezone.activate(pytz.timezone(request.user.userprofile.time_zone)) + return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) class LogSqlMiddleware(MiddlewareMixin): diff --git a/account/migrations/0001_initial.py b/account/migrations/0001_initial.py index a96776e0..e1e588ee 100644 --- a/account/migrations/0001_initial.py +++ b/account/migrations/0001_initial.py @@ -50,7 +50,7 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('problems_status', jsonfield.fields.JSONField(default={})), - ('avatar', models.CharField(default=account.models._default_avatar, max_length=50)), + ('avatar', models.CharField(default="default.png", max_length=50)), ('blog', models.URLField(blank=True, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)), ('accepted_problem_number', models.IntegerField(default=0)), diff --git a/account/migrations/0008_auto_20171011_1214.py b/account/migrations/0008_auto_20171011_1214.py new file mode 100644 index 00000000..7426a1f1 --- /dev/null +++ b/account/migrations/0008_auto_20171011_1214.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('account', '0007_auto_20170920_0254'), + ] + + operations = [ + migrations.RemoveField( + model_name='userprofile', + name='language', + ), + migrations.AlterField( + model_name='user', + name='admin_type', + field=models.CharField(default='Regular User', max_length=32), + ), + migrations.AlterField( + model_name='user', + name='auth_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='email', + field=models.EmailField(max_length=64, null=True), + ), + migrations.AlterField( + model_name='user', + name='open_api_appkey', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='problem_permission', + field=models.CharField(default='None', max_length=32), + ), + migrations.AlterField( + model_name='user', + name='reset_password_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='session_keys', + field=django.contrib.postgres.fields.jsonb.JSONField(default=list), + ), + migrations.AlterField( + model_name='user', + name='tfa_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='username', + field=models.CharField(max_length=32, unique=True), + ), + migrations.AlterField( + model_name='userprofile', + name='acm_problems_status', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='userprofile', + name='avatar', + field=models.CharField(default='/static/avatar/default.png', max_length=256), + ), + migrations.AlterField( + model_name='userprofile', + name='github', + field=models.CharField(blank=True, max_length=64, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='major', + field=models.CharField(blank=True, max_length=64, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='mood', + field=models.CharField(blank=True, max_length=256, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='oi_problems_status', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='userprofile', + name='real_name', + field=models.CharField(blank=True, max_length=32, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='school', + field=models.CharField(blank=True, max_length=64, null=True), + ), + ] diff --git a/account/models.py b/account/models.py index b909f932..6ef7f71f 100644 --- a/account/models.py +++ b/account/models.py @@ -1,7 +1,7 @@ from django.contrib.auth.models import AbstractBaseUser from django.conf import settings from django.db import models -from jsonfield import JSONField +from utils.models import JSONField class AdminType(object): @@ -24,22 +24,22 @@ class UserManager(models.Manager): class User(AbstractBaseUser): - username = models.CharField(max_length=30, unique=True) - email = models.EmailField(max_length=254, null=True) + username = models.CharField(max_length=32, unique=True) + email = models.EmailField(max_length=64, null=True) create_time = models.DateTimeField(auto_now_add=True, null=True) # One of UserType - admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER) - problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE) - reset_password_token = models.CharField(max_length=40, null=True) + admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER) + problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE) + reset_password_token = models.CharField(max_length=32, null=True) reset_password_token_expire_time = models.DateTimeField(null=True) # SSO auth token - auth_token = models.CharField(max_length=40, null=True) + auth_token = models.CharField(max_length=32, null=True) two_factor_auth = models.BooleanField(default=False) - tfa_token = models.CharField(max_length=40, null=True) - session_keys = JSONField(default=[]) + tfa_token = models.CharField(max_length=32, null=True) + session_keys = JSONField(default=list) # open api key open_api = models.BooleanField(default=False) - open_api_appkey = models.CharField(max_length=35, null=True) + open_api_appkey = models.CharField(max_length=32, null=True) is_disabled = models.BooleanField(default=False) USERNAME_FIELD = "username" @@ -63,26 +63,34 @@ class User(AbstractBaseUser): db_table = "user" -def _default_avatar(): - return f"/{settings.IMAGE_UPLOAD_DIR}/default.png" - - class UserProfile(models.Model): user = models.OneToOneField(User) - # Store user problem solution status with json string format - # {problems: {1: JudgeStatus.ACCEPTED}, contest_problems: {1: JudgeStatus.ACCEPTED}}, record problem_id and status - acm_problems_status = JSONField(default={}) - # {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score - oi_problems_status = JSONField(default={}) + # acm_problems_status examples: + # { + # "problems": { + # "1": { + # "status": JudgeStatus.ACCEPTED, + # "_id": "1000" + # } + # }, + # "contest_problems": { + # "1": { + # "status": JudgeStatus.ACCEPTED, + # "_id": "1000" + # } + # } + # } + acm_problems_status = JSONField(default=dict) + # like acm_problems_status, merely add "score" field + oi_problems_status = JSONField(default=dict) - real_name = models.CharField(max_length=30, blank=True, null=True) - avatar = models.CharField(max_length=50, default=_default_avatar()) + real_name = models.CharField(max_length=32, blank=True, null=True) + avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png") blog = models.URLField(blank=True, null=True) - mood = models.CharField(max_length=200, blank=True, null=True) - github = models.CharField(max_length=50, blank=True, null=True) - school = models.CharField(max_length=200, blank=True, null=True) - major = models.CharField(max_length=200, blank=True, null=True) - language = models.CharField(max_length=32, blank=True, null=True) + mood = models.CharField(max_length=256, blank=True, null=True) + github = models.CharField(max_length=64, blank=True, null=True) + school = models.CharField(max_length=64, blank=True, null=True) + major = models.CharField(max_length=64, blank=True, null=True) # for ACM accepted_number = models.IntegerField(default=0) # for OI diff --git a/account/serializers.py b/account/serializers.py index 8345fc69..aa9675a3 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -6,27 +6,26 @@ from .models import AdminType, ProblemPermission, User, UserProfile class UserLoginSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30) - tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True) + username = serializers.CharField() + password = serializers.CharField() + tfa_code = serializers.CharField(required=False, allow_null=True) class UsernameOrEmailCheckSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30, required=False) - email = serializers.EmailField(max_length=30, required=False) + username = serializers.CharField(required=False) + email = serializers.EmailField(required=False) class UserRegisterSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30, min_length=6) - email = serializers.EmailField(max_length=30) - captcha = serializers.CharField(max_length=4, min_length=1) + username = serializers.CharField(max_length=32) + password = serializers.CharField(min_length=6) + email = serializers.EmailField(max_length=64) + captcha = serializers.CharField() class UserChangePasswordSerializer(serializers.Serializer): old_password = serializers.CharField() - new_password = serializers.CharField(max_length=30, min_length=6) - captcha = serializers.CharField(max_length=4, min_length=4) + new_password = serializers.CharField(min_length=6) class UserSerializer(serializers.ModelSerializer): @@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer): class Meta: model = UserProfile + fields = "__all__" class UserInfoSerializer(serializers.ModelSerializer): @@ -58,9 +58,9 @@ class UserInfoSerializer(serializers.ModelSerializer): class EditUserSerializer(serializers.Serializer): id = serializers.IntegerField() - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None) - email = serializers.EmailField(max_length=254) + username = serializers.CharField(max_length=32) + password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None) + email = serializers.EmailField(max_length=64) admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN)) problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN, ProblemPermission.ALL)) @@ -70,29 +70,29 @@ class EditUserSerializer(serializers.Serializer): class EditUserProfileSerializer(serializers.Serializer): - real_name = serializers.CharField(max_length=30, allow_blank=True) - avatar = serializers.CharField(max_length=100, allow_blank=True, required=False) - blog = serializers.URLField(allow_blank=True, required=False) - mood = serializers.CharField(max_length=200, allow_blank=True, required=False) - github = serializers.CharField(max_length=50, allow_blank=True, required=False) - school = serializers.CharField(max_length=200, allow_blank=True, required=False) - major = serializers.CharField(max_length=200, allow_blank=True, required=False) + real_name = serializers.CharField(max_length=32, allow_blank=True) + avatar = serializers.CharField(max_length=256, allow_blank=True, required=False) + blog = serializers.URLField(max_length=256, allow_blank=True, required=False) + mood = serializers.CharField(max_length=256, allow_blank=True, required=False) + github = serializers.CharField(max_length=64, allow_blank=True, required=False) + school = serializers.CharField(max_length=64, allow_blank=True, required=False) + major = serializers.CharField(max_length=64, allow_blank=True, required=False) class ApplyResetPasswordSerializer(serializers.Serializer): email = serializers.EmailField() - captcha = serializers.CharField(max_length=4, min_length=4) + captcha = serializers.CharField() class ResetPasswordSerializer(serializers.Serializer): - token = serializers.CharField(min_length=1, max_length=40) - password = serializers.CharField(min_length=6, max_length=30) - captcha = serializers.CharField(max_length=4, min_length=4) + token = serializers.CharField() + password = serializers.CharField(min_length=6) + captcha = serializers.CharField() class SSOSerializer(serializers.Serializer): - appkey = serializers.CharField(max_length=35) - token = serializers.CharField(max_length=40) + appkey = serializers.CharField() + token = serializers.CharField() class TwoFactorAuthCodeSerializer(serializers.Serializer): diff --git a/account/tasks.py b/account/tasks.py index 0aacec96..3e7c1d2f 100644 --- a/account/tasks.py +++ b/account/tasks.py @@ -1,6 +1,31 @@ -from celery import shared_task +import logging -from utils.shortcuts import send_email +from celery import shared_task +from envelopes import Envelope + +from options.options import SysOptions + +logger = logging.getLogger(__name__) + + +def send_email(from_name, to_email, to_name, subject, content): + smtp = SysOptions.smtp_config + if not smtp: + return + envlope = Envelope(from_addr=(smtp["email"], from_name), + to_addr=(to_email, to_name), + subject=subject, + html_body=content) + try: + envlope.send(smtp["server"], + login=smtp["email"], + password=smtp["password"], + port=smtp["port"], + tls=smtp["tls"]) + return True + except Exception as e: + logger.exception(e) + return False @shared_task diff --git a/account/tests.py b/account/tests.py index 7331a0ee..75515ced 100644 --- a/account/tests.py +++ b/account/tests.py @@ -8,11 +8,10 @@ from otpauth import OtpAuth from utils.api.tests import APIClient, APITestCase from utils.shortcuts import rand_str -from utils.cache import default_cache -from utils.constants import CacheKey +from options.options import SysOptions from .models import AdminType, ProblemPermission, User -from conf.models import WebsiteConfig +from utils.constants import ContestRuleType class PermissionDecoratorTest(APITestCase): @@ -136,7 +135,7 @@ class UserLoginAPITest(APITestCase): self.user.save() resp = self.client.post(self.login_url, data={"username": self.username, "password": self.password}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Your account have been disabled"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"}) class CaptchaTest(APITestCase): @@ -157,15 +156,11 @@ class UserRegisterAPITest(CaptchaTest): self.data = {"username": "test_user", "password": "testuserpassword", "real_name": "real_name", "email": "test@qduoj.com", "captcha": self._set_captcha(self.client.session)} - # clea cache in redis - default_cache.delete(CacheKey.website_config) def test_website_config_limit(self): - website = WebsiteConfig.objects.create() - website.allow_register = False - website.save() + SysOptions.allow_register = False resp = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"}) def test_invalid_captcha(self): self.data["captcha"] = "****" @@ -226,7 +221,7 @@ class UserProfileAPITest(APITestCase): def test_get_profile_without_login(self): resp = self.client.get(self.url) - self.assertDictEqual(resp.data, {"error": None, "data": {}}) + self.assertDictEqual(resp.data, {"error": None, "data": None}) def test_get_profile(self): self.create_user("test", "test123") @@ -247,7 +242,6 @@ class TwoFactorAuthAPITest(APITestCase): def setUp(self): self.url = self.reverse("two_factor_auth_api") self.create_user("test", "test123") - self.create_website_config() def _get_tfa_code(self): user = User.objects.first() @@ -295,7 +289,6 @@ class ApplyResetPasswordAPITest(CaptchaTest): user.email = "test@oj.com" user.save() self.url = self.reverse("apply_reset_password_api") - self.create_website_config() self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} def _refresh_captcha(self): @@ -343,14 +336,14 @@ class ResetPasswordAPITest(CaptchaTest): def test_reset_password_with_invalid_token(self): self.data["token"] = "aaaaaaaaaaa" resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"}) def test_reset_password_with_expired_token(self): user = User.objects.first() user.reset_password_token_expire_time = now() - timedelta(seconds=30) user.save() resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"}) class UserChangePasswordAPITest(CaptchaTest): @@ -481,14 +474,14 @@ class UserRankAPITest(APITestCase): profile2.save() def test_get_acm_rank(self): - resp = self.client.get(self.url, data={"rule": "acm"}) + resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) self.assertSuccess(resp) data = resp.data["data"] self.assertEqual(data[0]["user"]["username"], "test1") self.assertEqual(data[1]["user"]["username"], "test2") def test_get_oi_rank(self): - resp = self.client.get(self.url, data={"rule": "oi"}) + resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) self.assertSuccess(resp) data = resp.data["data"] self.assertEqual(data[0]["user"]["username"], "test2") diff --git a/account/urls/oj.py b/account/urls/oj.py index b80b34e0..aacbb59d 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -3,7 +3,7 @@ from django.conf.urls import url from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, UserChangePasswordAPI, UserRegisterAPI, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, - SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, + AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) from utils.captcha.views import CaptchaAPIView @@ -19,7 +19,6 @@ urlpatterns = [ url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), - url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"), url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"), url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), diff --git a/account/views/oj.py b/account/views/oj.py index cfb94726..5f72156e 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -1,32 +1,28 @@ import os -import qrcode -import pickle from datetime import timedelta -from otpauth import OtpAuth +from importlib import import_module +import qrcode from django.conf import settings from django.contrib import auth -from importlib import import_module +from django.template.loader import render_to_string +from django.utils.decorators import method_decorator from django.utils.timezone import now from django.views.decorators.csrf import ensure_csrf_cookie -from django.utils.decorators import method_decorator -from django.template.loader import render_to_string +from otpauth import OtpAuth -from conf.models import WebsiteConfig +from utils.constants import ContestRuleType +from options.options import SysOptions from utils.api import APIView, validate_serializer from utils.captcha import Captcha -from utils.shortcuts import rand_str, img2base64, timestamp2utcstr -from utils.cache import default_cache -from utils.constants import CacheKey - +from utils.shortcuts import rand_str, img2base64, datetime2str from ..decorators import login_required from ..models import User, UserProfile from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, UserChangePasswordSerializer, UserLoginSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer, RankInfoSerializer) -from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer, - UserProfileSerializer, +from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer, EditUserProfileSerializer, AvatarUploadForm) from ..tasks import send_email_async @@ -39,7 +35,7 @@ class UserProfileAPI(APIView): """ user = request.user if not user.is_authenticated(): - return self.success({}) + return self.success() username = request.GET.get("username") try: if username: @@ -48,8 +44,7 @@ class UserProfileAPI(APIView): user = request.user except User.DoesNotExist: return self.error("User does not exist") - profile = UserProfile.objects.select_related("user").get(user=user) - return self.success(UserProfileSerializer(profile).data) + return self.success(UserProfileSerializer(user.userprofile).data) @validate_serializer(EditUserProfileSerializer) @login_required @@ -72,8 +67,7 @@ class AvatarUploadAPI(APIView): avatar = form.cleaned_data["file"] else: return self.error("Invalid file content") - # 2097152 = 2 * 1024 * 1024 = 2MB - if avatar.size > 2097152: + if avatar.size > 2 * 1024 * 1024: return self.error("Picture is too large") suffix = os.path.splitext(avatar.name)[-1].lower() if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: @@ -84,46 +78,12 @@ class AvatarUploadAPI(APIView): for chunk in avatar: img.write(chunk) user_profile = request.user.userprofile - _, old_avatar = os.path.split(user_profile.avatar) - if old_avatar != "default.png": - os.remove(os.path.join(settings.IMAGE_UPLOAD_DIR_ABS, old_avatar)) user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}" user_profile.save() return self.success("Succeeded") -class SSOAPI(APIView): - @login_required - def get(self, request): - callback = request.GET.get("callback", None) - if not callback: - return self.error("Parameter Error") - token = rand_str() - request.user.auth_token = token - request.user.save() - return self.success({"redirect_url": callback + "?token=" + token, - "callback": callback}) - - @validate_serializer(SSOSerializer) - def post(self, request): - data = request.data - try: - User.objects.get(open_api_appkey=data["appkey"]) - except User.DoesNotExist: - return self.error("Invalid appkey") - try: - user = User.objects.get(auth_token=data["token"]) - user.auth_token = None - user.save() - return self.success({"username": user.username, - "id": user.id, - "admin_type": user.admin_type, - "avatar": user.userprofile.avatar}) - except User.DoesNotExist: - return self.error("User does not exist") - - class TwoFactorAuthAPI(APIView): @login_required def get(self, request): @@ -132,14 +92,13 @@ class TwoFactorAuthAPI(APIView): """ user = request.user if user.two_factor_auth: - return self.error("Already open 2FA") + return self.error("2FA is already turned on") token = rand_str() user.tfa_token = token user.save() - config = WebsiteConfig.objects.first() - label = f"{config.name_shortcut}:{user.username}" - image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name)) + label = f"{SysOptions.website_name_shortcut}:{user.username}" + image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name)) return self.success(img2base64(image)) @login_required @@ -163,7 +122,7 @@ class TwoFactorAuthAPI(APIView): code = request.data["code"] user = request.user if not user.two_factor_auth: - return self.error("Other session have disabled TFA") + return self.error("2FA is already turned off") if OtpAuth(user.tfa_token).valid_totp(code): user.two_factor_auth = False user.save() @@ -200,7 +159,7 @@ class UserLoginAPI(APIView): # None is returned if username or password is wrong if user: if user.is_disabled: - return self.error("Your account have been disabled") + return self.error("Your account has been disabled") if not user.two_factor_auth: auth.login(request, user) return self.success("Succeeded") @@ -220,13 +179,13 @@ class UserLoginAPI(APIView): # todo remove this, only for debug use def get(self, request): auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"])) - return self.success({}) + return self.success() class UserLogoutAPI(APIView): def get(self, request): auth.logout(request) - return self.success({}) + return self.success() class UsernameOrEmailCheck(APIView): @@ -242,11 +201,9 @@ class UsernameOrEmailCheck(APIView): "email": False } if data.get("username"): - if User.objects.filter(username=data["username"]).exists(): - result["username"] = True + result["username"] = User.objects.filter(username=data["username"]).exists() if data.get("email"): - if User.objects.filter(email=data["email"]).exists(): - result["email"] = True + result["email"] = User.objects.filter(email=data["email"]).exists() return self.success(result) @@ -256,17 +213,9 @@ class UserRegisterAPI(APIView): """ User register api """ - config = default_cache.get(CacheKey.website_config) - if config: - config = pickle.loads(config) - else: - config = WebsiteConfig.objects.first() - if not config: - config = WebsiteConfig.objects.create() - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - if not config.allow_register: - return self.error("Register have been disabled by admin") + if not SysOptions.allow_register: + return self.error("Register function has been disabled by admin") data = request.data captcha = Captcha(request) @@ -295,6 +244,7 @@ class UserChangePasswordAPI(APIView): username = request.user.username user = auth.authenticate(username=username, password=data["old_password"]) if user: + # TODO: check tfa? user.set_password(data["new_password"]) user.save() return self.success("Succeeded") @@ -307,7 +257,6 @@ class ApplyResetPasswordAPI(APIView): def post(self, request): data = request.data captcha = Captcha(request) - config = WebsiteConfig.objects.first() if not captcha.check(data["captcha"]): return self.error("Invalid captcha") try: @@ -322,14 +271,14 @@ class ApplyResetPasswordAPI(APIView): user.save() render_data = { "username": user.username, - "website_name": config.name, - "link": f"{config.base_url}/reset-password/{user.reset_password_token}" + "website_name": SysOptions.website_name, + "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" } email_html = render_to_string("reset_password_email.html", render_data) - send_email_async.delay(config.name, + send_email_async.delay(SysOptions.website_name, user.email, user.username, - config.name + " 登录信息找回邮件", + f"{SysOptions.website_name} 登录信息找回邮件", email_html) return self.success("Succeeded") @@ -344,9 +293,9 @@ class ResetPasswordAPI(APIView): try: user = User.objects.get(reset_password_token=data["token"]) except User.DoesNotExist: - return self.error("Token dose not exist") - if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0: - return self.error("Token have expired") + return self.error("Token does not exist") + if user.reset_password_token_expire_time < now(): + return self.error("Token has expired") user.reset_password_token = None user.two_factor_auth = False user.set_password(data["password"]) @@ -358,14 +307,13 @@ class SessionManagementAPI(APIView): @login_required def get(self, request): engine = import_module(settings.SESSION_ENGINE) - SessionStore = engine.SessionStore - current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) + session_store = engine.SessionStore current_session = request.session.session_key session_keys = request.user.session_keys result = [] modified = False for key in session_keys[:]: - session = SessionStore(key) + session = session_store(key) # session does not exist or is expiry if not session._session: session_keys.remove(key) @@ -377,7 +325,7 @@ class SessionManagementAPI(APIView): s["current_session"] = True s["ip"] = session["ip"] s["user_agent"] = session["user_agent"] - s["last_activity"] = timestamp2utcstr(session["last_activity"]) + s["last_activity"] = datetime2str(session["last_activity"]) s["session_key"] = key result.append(s) if modified: @@ -401,12 +349,12 @@ class SessionManagementAPI(APIView): class UserRankAPI(APIView): def get(self, request): rule_type = request.GET.get("rule") - if rule_type not in ["acm", "oi"]: - rule_type = "acm" + if rule_type not in ContestRuleType.choices(): + rule_type = ContestRuleType.ACM profiles = UserProfile.objects.select_related("user")\ .filter(submission_number__gt=0)\ .exclude(user__is_disabled=True) - if rule_type == "acm": + if rule_type == ContestRuleType.ACM: profiles = profiles.order_by("-accepted_number", "submission_number") else: profiles = profiles.order_by("-total_score") diff --git a/announcement/migrations/0002_auto_20171011_1214.py b/announcement/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ffe4c969 --- /dev/null +++ b/announcement/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('announcement', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='announcement', + name='title', + field=models.CharField(max_length=64), + ), + ] diff --git a/announcement/models.py b/announcement/models.py index 186d4ea6..49f57b82 100644 --- a/announcement/models.py +++ b/announcement/models.py @@ -5,7 +5,7 @@ from utils.models import RichTextField class Announcement(models.Model): - title = models.CharField(max_length=50) + title = models.CharField(max_length=64) # HTML content = RichTextField() create_time = models.DateTimeField(auto_now_add=True) diff --git a/announcement/serializers.py b/announcement/serializers.py index 0c0beccc..b660a615 100644 --- a/announcement/serializers.py +++ b/announcement/serializers.py @@ -5,8 +5,8 @@ from .models import Announcement class CreateAnnouncementSerializer(serializers.Serializer): - title = serializers.CharField(max_length=50) - content = serializers.CharField(max_length=10000) + title = serializers.CharField(max_length=64) + content = serializers.CharField(max_length=1024 * 1024 * 8) visible = serializers.BooleanField() @@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer): class EditAnnouncementSerializer(serializers.Serializer): id = serializers.IntegerField() - title = serializers.CharField(max_length=50) - content = serializers.CharField(max_length=10000) + title = serializers.CharField(max_length=64) + content = serializers.CharField(max_length=1024 * 1024 * 8) visible = serializers.BooleanField() diff --git a/conf/migrations/0002_auto_20171011_1214.py b/conf/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ef355b50 --- /dev/null +++ b/conf/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('conf', '0001_initial'), + ] + + operations = [ + migrations.DeleteModel( + name='JudgeServerToken', + ), + migrations.DeleteModel( + name='SMTPConfig', + ), + migrations.DeleteModel( + name='WebsiteConfig', + ), + migrations.AlterField( + model_name='judgeserver', + name='hostname', + field=models.CharField(max_length=128), + ), + migrations.AlterField( + model_name='judgeserver', + name='judger_version', + field=models.CharField(max_length=32), + ), + migrations.AlterField( + model_name='judgeserver', + name='service_url', + field=models.CharField(blank=True, max_length=256, null=True), + ), + ] diff --git a/conf/models.py b/conf/models.py index 9fe3cc51..4c6348d5 100644 --- a/conf/models.py +++ b/conf/models.py @@ -2,42 +2,17 @@ from django.db import models from django.utils import timezone -class SMTPConfig(models.Model): - server = models.CharField(max_length=128) - port = models.IntegerField(default=25) - email = models.CharField(max_length=128) - password = models.CharField(max_length=128) - tls = models.BooleanField() - - class Meta: - db_table = "smtp_config" - - -class WebsiteConfig(models.Model): - base_url = models.CharField(max_length=128, default="http://127.0.0.1") - name = models.CharField(max_length=32, default="Online Judge") - name_shortcut = models.CharField(max_length=32, default="oj") - footer = models.TextField(default="Online Judge Footer") - # allow register - allow_register = models.BooleanField(default=True) - # submission list show all user's submission - submission_list_show_all = models.BooleanField(default=True) - - class Meta: - db_table = "website_config" - - class JudgeServer(models.Model): - hostname = models.CharField(max_length=64) + hostname = models.CharField(max_length=128) ip = models.CharField(max_length=32, blank=True, null=True) - judger_version = models.CharField(max_length=24) + judger_version = models.CharField(max_length=32) cpu_core = models.IntegerField() memory_usage = models.FloatField() cpu_usage = models.FloatField() last_heartbeat = models.DateTimeField() create_time = models.DateTimeField(auto_now_add=True) task_number = models.IntegerField(default=0) - service_url = models.CharField(max_length=128, blank=True, null=True) + service_url = models.CharField(max_length=256, blank=True, null=True) @property def status(self): @@ -48,10 +23,3 @@ class JudgeServer(models.Model): class Meta: db_table = "judge_server" - - -class JudgeServerToken(models.Model): - token = models.CharField(max_length=32) - - class Meta: - db_table = "judge_server_token" diff --git a/conf/serializers.py b/conf/serializers.py index 59b7203c..7f0cf575 100644 --- a/conf/serializers.py +++ b/conf/serializers.py @@ -1,6 +1,6 @@ from utils.api import DateTimeTZField, serializers -from .models import JudgeServer, SMTPConfig, WebsiteConfig +from .models import JudgeServer class EditSMTPConfigSerializer(serializers.Serializer): @@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer): password = serializers.CharField(max_length=128) -class SMTPConfigSerializer(serializers.ModelSerializer): - class Meta: - model = SMTPConfig - exclude = ["id", "password"] - - class TestSMTPConfigSerializer(serializers.Serializer): email = serializers.EmailField() class CreateEditWebsiteConfigSerializer(serializers.Serializer): - base_url = serializers.CharField(max_length=128) - name = serializers.CharField(max_length=32) - name_shortcut = serializers.CharField(max_length=32) - footer = serializers.CharField(max_length=1024) + website_base_url = serializers.CharField(max_length=128) + website_name = serializers.CharField(max_length=64) + website_name_shortcut = serializers.CharField(max_length=64) + website_footer = serializers.CharField(max_length=1024 * 1024) allow_register = serializers.BooleanField() submission_list_show_all = serializers.BooleanField() -class WebsiteConfigSerializer(serializers.ModelSerializer): - class Meta: - model = WebsiteConfig - exclude = ["id"] - - class JudgeServerSerializer(serializers.ModelSerializer): create_time = DateTimeTZField() last_heartbeat = DateTimeTZField() @@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer): class Meta: model = JudgeServer + fields = "__all__" class JudgeServerHeartbeatSerializer(serializers.Serializer): - hostname = serializers.CharField(max_length=64) - judger_version = serializers.CharField(max_length=24) + hostname = serializers.CharField(max_length=128) + judger_version = serializers.CharField(max_length=32) cpu_core = serializers.IntegerField(min_value=1) memory = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100) action = serializers.ChoiceField(choices=("heartbeat", )) - service_url = serializers.CharField(max_length=128, required=False) + service_url = serializers.CharField(max_length=256, required=False) diff --git a/conf/tests.py b/conf/tests.py index 1694c218..5906b078 100644 --- a/conf/tests.py +++ b/conf/tests.py @@ -2,11 +2,10 @@ import hashlib from django.utils import timezone +from options.options import SysOptions from utils.api.tests import APITestCase -from utils.cache import default_cache from utils.constants import CacheKey - -from .models import JudgeServer, JudgeServerToken, SMTPConfig +from .models import JudgeServer class SMTPConfigTest(APITestCase): @@ -29,10 +28,6 @@ class SMTPConfigTest(APITestCase): "tls": True} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - smtp = SMTPConfig.objects.first() - self.assertEqual(smtp.password, self.password) - self.assertEqual(smtp.server, "smtp1.test.com") - self.assertEqual(smtp.email, "test2@test.com") def test_edit_without_password1(self): self.test_create_smtp_config() @@ -40,7 +35,6 @@ class SMTPConfigTest(APITestCase): "tls": True, "password": ""} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - self.assertEqual(SMTPConfig.objects.first().password, self.password) def test_edit_with_password(self): self.test_create_smtp_config() @@ -48,18 +42,14 @@ class SMTPConfigTest(APITestCase): "tls": True, "password": "newpassword"} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - smtp = SMTPConfig.objects.first() - self.assertEqual(smtp.password, "newpassword") - self.assertEqual(smtp.server, "smtp1.test.com") - self.assertEqual(smtp.email, "test2@test.com") class WebsiteConfigAPITest(APITestCase): def test_create_website_config(self): self.create_super_admin() url = self.reverse("website_config_api") - data = {"base_url": "http://test.com", "name": "test name", - "name_shortcut": "test oj", "footer": "test", + data = {"website_base_url": "http://test.com", "website_name": "test name", + "website_name_shortcut": "test oj", "website_footer": "test", "allow_register": True, "submission_list_show_all": False} resp = self.client.post(url, data=data) self.assertSuccess(resp) @@ -67,8 +57,8 @@ class WebsiteConfigAPITest(APITestCase): def test_edit_website_config(self): self.create_super_admin() url = self.reverse("website_config_api") - data = {"base_url": "http://test.com", "name": "test name", - "name_shortcut": "test oj", "footer": "test", + data = {"website_base_url": "http://test.com", "website_name": "test name", + "website_name_shortcut": "test oj", "website_footer": "test", "allow_register": True, "submission_list_show_all": False} resp = self.client.post(url, data=data) self.assertSuccess(resp) @@ -78,10 +68,6 @@ class WebsiteConfigAPITest(APITestCase): url = self.reverse("website_info_api") resp = self.client.get(url) self.assertSuccess(resp) - self.assertEqual(resp.data["data"]["name_shortcut"], "oj") - - def tearDown(self): - default_cache.delete(CacheKey.website_config) class JudgeServerHeartbeatTest(APITestCase): @@ -91,7 +77,7 @@ class JudgeServerHeartbeatTest(APITestCase): "cpu": 90.5, "memory": 80.3, "action": "heartbeat"} self.token = "test" self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() - JudgeServerToken.objects.create(token=self.token) + SysOptions.judge_server_token = self.token def test_new_heartbeat(self): resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token}) @@ -127,11 +113,9 @@ class JudgeServerAPITest(APITestCase): self.create_super_admin() def test_get_judge_server(self): - self.assertFalse(JudgeServerToken.objects.exists()) resp = self.client.get(self.url) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]["servers"]), 1) - self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"]) def test_delete_judge_server(self): resp = self.client.delete(self.url + "?hostname=testhostname") diff --git a/conf/views.py b/conf/views.py index 814c0620..f09972cc 100644 --- a/conf/views.py +++ b/conf/views.py @@ -1,54 +1,45 @@ import hashlib -import pickle from django.utils import timezone from account.decorators import super_admin_required -from judge.languages import languages, spj_languages from judge.dispatcher import process_pending_task +from judge.languages import languages, spj_languages +from options.options import SysOptions from utils.api import APIView, CSRFExemptAPIView, validate_serializer -from utils.shortcuts import rand_str -from utils.cache import default_cache -from utils.constants import CacheKey - -from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig +from .models import JudgeServer from .serializers import (CreateEditWebsiteConfigSerializer, CreateSMTPConfigSerializer, EditSMTPConfigSerializer, JudgeServerHeartbeatSerializer, - JudgeServerSerializer, SMTPConfigSerializer, - TestSMTPConfigSerializer, WebsiteConfigSerializer) + JudgeServerSerializer, TestSMTPConfigSerializer) class SMTPAPI(APIView): @super_admin_required def get(self, request): - smtp = SMTPConfig.objects.first() + smtp = SysOptions.smtp_config if not smtp: return self.success(None) - return self.success(SMTPConfigSerializer(smtp).data) + smtp.pop("password") + return self.success(smtp) @validate_serializer(CreateSMTPConfigSerializer) @super_admin_required def post(self, request): - SMTPConfig.objects.all().delete() - smtp = SMTPConfig.objects.create(**request.data) - return self.success(SMTPConfigSerializer(smtp).data) + SysOptions.smtp_config = request.data + return self.success() @validate_serializer(EditSMTPConfigSerializer) @super_admin_required def put(self, request): + smtp = SysOptions.smtp_config data = request.data - smtp = SMTPConfig.objects.first() - if not smtp: - return self.error("SMTP config is missing") - smtp.server = data["server"] - smtp.port = data["port"] - smtp.email = data["email"] - smtp.tls = data["tls"] - if data.get("password"): - smtp.password = data["password"] - smtp.save() - return self.success(SMTPConfigSerializer(smtp).data) + for item in ["server", "port", "email", "tls"]: + smtp[item] = data[item] + if "password" in data: + smtp["password"] = data["password"] + SysOptions.smtp_config = smtp + return self.success() class SMTPTestAPI(APIView): @@ -60,37 +51,24 @@ class SMTPTestAPI(APIView): class WebsiteConfigAPI(APIView): def get(self, request): - config = default_cache.get(CacheKey.website_config) - if config: - config = pickle.loads(config) - else: - config = WebsiteConfig.objects.first() - if not config: - config = WebsiteConfig.objects.create() - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - return self.success(WebsiteConfigSerializer(config).data) + ret = {key: getattr(SysOptions, key) for key in + ["website_base_url", "website_name", "website_name_shortcut", + "website_footer", "allow_register", "submission_list_show_all"]} + return self.success(ret) @validate_serializer(CreateEditWebsiteConfigSerializer) @super_admin_required def post(self, request): - data = request.data - WebsiteConfig.objects.all().delete() - config = WebsiteConfig.objects.create(**data) - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - return self.success(WebsiteConfigSerializer(config).data) + for k, v in request.data.items(): + setattr(SysOptions, k, v) + return self.success() class JudgeServerAPI(APIView): @super_admin_required def get(self, request): - judge_server_token = JudgeServerToken.objects.first() - if not judge_server_token: - token = rand_str(12) - JudgeServerToken.objects.create(token=token) - else: - token = judge_server_token.token servers = JudgeServer.objects.all().order_by("-last_heartbeat") - return self.success({"token": token, + return self.success({"token": SysOptions.judge_server_token, "servers": JudgeServerSerializer(servers, many=True).data}) @super_admin_required @@ -104,15 +82,9 @@ class JudgeServerAPI(APIView): class JudgeServerHeartbeatAPI(CSRFExemptAPIView): @validate_serializer(JudgeServerHeartbeatSerializer) def post(self, request): - judge_server_token = JudgeServerToken.objects.first() - if not judge_server_token: - token = rand_str(12) - JudgeServerToken.objects.create(token=token) - else: - token = judge_server_token.token data = request.data client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") - if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token: + if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token: return self.error("Invalid token") service_url = data.get("service_url") diff --git a/contest/migrations/0006_auto_20171011_1214.py b/contest/migrations/0006_auto_20171011_1214.py new file mode 100644 index 00000000..d429742f --- /dev/null +++ b/contest/migrations/0006_auto_20171011_1214.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('contest', '0005_auto_20170823_0918'), + ] + + operations = [ + migrations.AlterField( + model_name='acmcontestrank', + name='submission_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='oicontestrank', + name='submission_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + ] diff --git a/contest/models.py b/contest/models.py index 3383d17a..eb1c88a8 100644 --- a/contest/models.py +++ b/contest/models.py @@ -1,27 +1,13 @@ +from utils.constants import ContestRuleType # noqa from django.db import models from django.utils.timezone import now -from jsonfield import JSONField +from utils.models import JSONField +from utils.constants import ContestStatus, ContestType from account.models import User, AdminType from utils.models import RichTextField -class ContestType(object): - PUBLIC_CONTEST = "Public" - PASSWORD_PROTECTED_CONTEST = "Password Protected" - - -class ContestStatus(object): - CONTEST_NOT_START = "1" - CONTEST_ENDED = "-1" - CONTEST_UNDERWAY = "0" - - -class ContestRuleType(object): - ACM = "ACM" - OI = "OI" - - class Contest(models.Model): title = models.CharField(max_length=40) description = RichTextField() @@ -59,12 +45,20 @@ class Contest(models.Model): def is_contest_admin(self, user): return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN) + def check_oi_permission(self, user): + if self.status != ContestStatus.CONTEST_ENDED and self.real_time_rank == False: + if self.is_contest_admin(user): + return True + else: + return False + return True + class Meta: db_table = "contest" ordering = ("-create_time",) -class ContestRank(models.Model): +class AbstractContestRank(models.Model): user = models.ForeignKey(User) contest = models.ForeignKey(Contest) submission_number = models.IntegerField(default=0) @@ -73,30 +67,27 @@ class ContestRank(models.Model): abstract = True -class ACMContestRank(ContestRank): +class ACMContestRank(AbstractContestRank): accepted_number = models.IntegerField(default=0) # total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60 total_time = models.IntegerField(default=0) # {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}} # key is problem id - submission_info = JSONField(default={}) + submission_info = JSONField(default=dict) class Meta: db_table = "acm_contest_rank" -class OIContestRank(ContestRank): +class OIContestRank(AbstractContestRank): total_score = models.IntegerField(default=0) # {23: 333}} # key is problem id, value is current score - submission_info = JSONField(default={}) + submission_info = JSONField(default=dict) class Meta: db_table = "oi_contest_rank" - def update_rank(self, submission): - self.submission_number += 1 - class ContestAnnouncement(models.Model): contest = models.ForeignKey(Contest) diff --git a/contest/tests.py b/contest/tests.py index c481becb..df05df87 100644 --- a/contest/tests.py +++ b/contest/tests.py @@ -79,7 +79,7 @@ class ContestAPITest(APITestCase): self.create_user("test", "test123") url = self.reverse("contest_password_api") resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Password doesn't match."}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) self.assertSuccess(resp) @@ -89,7 +89,7 @@ class ContestAPITest(APITestCase): self.create_user("test", "test123") url = self.reverse("contest_access_api") resp = self.client.get(url + "?contest_id=" + str(contest_id)) - self.assertFalse(resp.data["data"]["Access"]) + self.assertFalse(resp.data["data"]["access"]) password_url = self.reverse("contest_password_api") resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) diff --git a/contest/views/oj.py b/contest/views/oj.py index b25019c5..1e787ce2 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -1,13 +1,11 @@ -import pickle from django.utils.timezone import now -from django.db.models import Q +from django.core.cache import cache from utils.api import APIView, validate_serializer -from utils.cache import default_cache from utils.constants import CacheKey from account.decorators import login_required, check_contest_permission -from ..models import ContestAnnouncement, Contest, ContestStatus, ContestRuleType -from ..models import OIContestRank, ACMContestRank +from utils.constants import ContestRuleType, ContestStatus +from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestSerializer, ContestPasswordVerifySerializer from ..serializers import OIContestRankSerializer, ACMContestRankSerializer @@ -32,7 +30,7 @@ class ContestAPI(APIView): try: contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) except Contest.DoesNotExist: - return self.error("Contest doesn't exist.") + return self.error("Contest does not exist") return self.success(ContestSerializer(contest).data) contests = Contest.objects.select_related("created_by").filter(visible=True) @@ -50,7 +48,7 @@ class ContestAPI(APIView): elif status == ContestStatus.CONTEST_ENDED: contests = contests.filter(end_time__lt=cur) else: - contests = contests.filter(Q(start_time__lte=cur) & Q(end_time__gte=cur)) + contests = contests.filter(start_time__lte=cur, end_time__gte=cur) return self.success(self.paginate_data(request, contests, ContestSerializer)) @@ -62,14 +60,14 @@ class ContestPasswordVerifyAPI(APIView): try: contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False) except Contest.DoesNotExist: - return self.error("Contest %s doesn't exist." % data["contest_id"]) + return self.error("Contest does not exist") if contest.password != data["password"]: - return self.error("Password doesn't match.") + return self.error("Wrong password") # password verify OK. - if "contests" not in request.session: - request.session["contests"] = [] - request.session["contests"].append(int(data["contest_id"])) + if "accessible_contests" not in request.session: + request.session["accessible_contests"] = [] + request.session["accessible_contests"].append(contest.id) # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved request.session.modified = True return self.success(True) @@ -80,13 +78,8 @@ class ContestAccessAPI(APIView): def get(self, request): contest_id = request.GET.get("contest_id") if not contest_id: - return self.error("Parameter contest_id not exist.") - if "contests" not in request.session: - request.session["contests"] = [] - if int(contest_id) in request.session["contests"]: - return self.success({"Access": True}) - else: - return self.success({"Access": False}) + return self.error() + return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])}) class ContestRankAPI(APIView): @@ -100,17 +93,17 @@ class ContestRankAPI(APIView): @check_contest_permission def get(self, request): - if self.contest.rule_type == ContestRuleType.ACM: - serializer = ACMContestRankSerializer - else: + if self.contest.rule_type == ContestRuleType.OI: + if not self.contest.check_oi_permission(request.user): + return self.error("You have no permission for ranks now") serializer = OIContestRankSerializer - - cache_key = CacheKey.contest_rank_cache + str(self.contest.id) - qs = default_cache.get(cache_key) - if not qs: - ranks = self.get_rank() - default_cache.set(cache_key, pickle.dumps(ranks)) else: - ranks = pickle.loads(qs) + serializer = ACMContestRankSerializer - return self.success(self.paginate_data(request, ranks, serializer)) + cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}" + qs = cache.get(cache_key) + if not qs: + qs = self.get_rank() + cache.set(cache_key, qs) + + return self.success(self.paginate_data(request, qs, serializer)) diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 8ab66885..829a43d7 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -8,12 +8,13 @@ from django.db import transaction from django.db.models import F from account.models import User -from conf.models import JudgeServer, JudgeServerToken +from conf.models import JudgeServer from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus from judge.languages import languages +from options.options import SysOptions from problem.models import Problem, ProblemRuleType from submission.models import JudgeStatus, Submission -from utils.cache import judge_cache, default_cache +from utils.cache import cache from utils.constants import CacheKey logger = logging.getLogger(__name__) @@ -21,41 +22,36 @@ logger = logging.getLogger(__name__) # 继续处理在队列中的问题 def process_pending_task(): - if judge_cache.llen(CacheKey.waiting_queue): + if cache.llen(CacheKey.waiting_queue): # 防止循环引入 from judge.tasks import judge_task - data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8")) + data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8")) judge_task.delay(**data) class JudgeDispatcher(object): def __init__(self, submission_id, problem_id): - token = JudgeServerToken.objects.first().token - self.token = hashlib.sha256(token.encode("utf-8")).hexdigest() - self.redis_conn = judge_cache - self.submission = Submission.objects.get(pk=submission_id) + self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() + self.submission = Submission.objects.get(id=submission_id) self.contest_id = self.submission.contest_id if self.contest_id: - self.problem = Problem.objects.select_related("contest") \ - .get(id=problem_id, contest_id=self.contest_id) + self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id) self.contest = self.problem.contest else: self.problem = Problem.objects.get(id=problem_id) def _request(self, url, data=None): - kwargs = {"headers": {"X-Judge-Server-Token": self.token, - "Content-Type": "application/json"}} + kwargs = {"headers": {"X-Judge-Server-Token": self.token}} if data: - kwargs["data"] = json.dumps(data) + kwargs["json"] = data try: return requests.post(url, **kwargs).json() except Exception as e: - logger.error(e.with_traceback()) + logger.exception(e) @staticmethod def choose_judge_server(): with transaction.atomic(): - # TODO: use more reasonable way servers = JudgeServer.objects.select_for_update().all().order_by("task_number") servers = [s for s in servers if s.status == "normal"] if servers: @@ -65,10 +61,10 @@ class JudgeDispatcher(object): return server @staticmethod - def release_judge_res(judge_server_id): + def release_judge_server(judge_server_id): with transaction.atomic(): # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下 - server = JudgeServer.objects.select_for_update().get(id=judge_server_id) + server = JudgeServer.objects.get(id=judge_server_id) server.used_instance_number = F("task_number") - 1 server.save() @@ -94,7 +90,7 @@ class JudgeDispatcher(object): server = self.choose_judge_server() if not server: data = {"submission_id": self.submission.id, "problem_id": self.problem.id} - self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data)) + cache.lpush(CacheKey.waiting_queue, json.dumps(data)) return sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0] @@ -138,7 +134,7 @@ class JudgeDispatcher(object): else: self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.save() - self.release_judge_res(server.id) + self.release_judge_server(server.id) self.update_problem_status() if self.contest_id: @@ -160,7 +156,6 @@ class JudgeDispatcher(object): with transaction.atomic(): # prepare problem and user_profile problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) - problem_info = problem.statistic_info user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id) user_profile = user.userprofile if self.contest_id: @@ -169,50 +164,46 @@ class JudgeDispatcher(object): key = "problems" acm_problems_status = user_profile.acm_problems_status.get(key, {}) oi_problems_status = user_profile.oi_problems_status.get(key, {}) + problem_id = str(self.problem.id) + problem_info = problem.statistic_info + + # update problem info + result = str(self.submission.result) + problem_info[result] = problem_info.get(result, 0) + 1 + problem.statistic_info = problem_info # update submission and accepted number counter problem.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: problem.accepted_number += 1 - # only when submission is not in contest, we update user profile, - # in other words, users' submission in a contest will not be counted in user profile + # submission in a contest will not be counted in user profile if not self.contest_id: user_profile.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - problem_id = str(self.problem.id) if self.problem.rule_type == ProblemRuleType.ACM: - # update acm problem info - result = str(self.submission.result) - problem_info[result] = problem_info.get(result, 0) + 1 - problem.statistic_info = problem_info - # update user_profile if problem_id not in acm_problems_status: - acm_problems_status[problem_id] = self.submission.result + acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id} # skip if the problem has been accepted - elif acm_problems_status[problem_id] != JudgeStatus.ACCEPTED: - if self.submission.result == JudgeStatus.ACCEPTED: - acm_problems_status[problem_id] = JudgeStatus.ACCEPTED - else: - acm_problems_status[problem_id] = self.submission.result + elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: + acm_problems_status[problem_id]["status"] = self.submission.result user_profile.acm_problems_status[key] = acm_problems_status else: - # update oi problem info - score = self.submission.statistic_info["score"] - problem_info[score] = problem_info.get(score, 0) + 1 - problem.statistic_info = problem_info - # update user_profile + score = self.submission.statistic_info["score"] if problem_id not in oi_problems_status: user_profile.add_score(score) - oi_problems_status[problem_id] = score + oi_problems_status[problem_id] = {"status": self.submission.result, + "_id": self.problem._id, + "score": score} else: # minus last time score, add this time score - user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]) - oi_problems_status[problem_id] = score + user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]["score"]) + oi_problems_status[problem_id]["score"] = score + oi_problems_status[problem_id]["status"] = self.submission.result user_profile.oi_problems_status[key] = oi_problems_status problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"]) @@ -222,8 +213,8 @@ class JudgeDispatcher(object): def update_contest_rank(self): if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY: return - if self.contest.real_time_rank: - default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) + if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank: + cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}") with transaction.atomic(): if self.contest.rule_type == ContestRuleType.ACM: acm_rank, _ = ACMContestRank.objects.select_for_update(). \ diff --git a/oj/local_settings.py b/oj/local_settings.py index cee68f25..bbe2398f 100644 --- a/oj/local_settings.py +++ b/oj/local_settings.py @@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'HOST': '127.0.0.1', + 'PORT': 5433, + 'NAME': "onlinejudge", + 'USER': "onlinejudge", + 'PASSWORD': 'onlinejudge' } } diff --git a/oj/settings.py b/oj/settings.py index eef91a66..950c426d 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -46,6 +46,7 @@ INSTALLED_APPS = ( 'contest', 'utils', 'submission', + 'options', ) MIDDLEWARE_CLASSES = ( @@ -57,11 +58,9 @@ MIDDLEWARE_CLASSES = ( 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.security.SecurityMiddleware', 'account.middleware.AdminRoleRequiredMiddleware', - 'account.middleware.SessionSecurityMiddleware', 'account.middleware.SessionRecordMiddleware', # 'account.middleware.LogSqlMiddleware', ) -SESSION_ENGINE = 'django.contrib.sessions.backends.cache' ROOT_URLCONF = 'oj.urls' TEMPLATES = [ @@ -164,8 +163,6 @@ LOGGING = { } }, } - - REST_FRAMEWORK = { 'TEST_REQUEST_DEFAULT_FORMAT': 'json', 'DEFAULT_RENDERER_CLASSES': ( @@ -173,34 +170,34 @@ REST_FRAMEWORK = { ) } -CACHE_JUDGE_QUEUE = "judge_queue" -CACHE_THROTTLING = "throttling" +REDIS_URL = "redis://127.0.0.1:6379" + + +def redis_config(db): + def make_key(key, key_prefix, version): + return key + + return { + "BACKEND": "utils.cache.MyRedisCache", + "LOCATION": f"{REDIS_URL}/{db}", + "TIMEOUT": None, + "KEY_PREFIX": "", + "KEY_FUNCTION": make_key + } CACHES = { - "default": { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/1", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - }, - CACHE_JUDGE_QUEUE: { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/2", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - }, - CACHE_THROTTLING: { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/3", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - } + "default": redis_config(db=1) } + +CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2" +CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180 + +SESSION_ENGINE = "django.contrib.sessions.backends.cache" +SESSION_CACHE_ALIAS = "default" + + # For celery REDIS_QUEUE = { "host": "127.0.0.1", diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/options/migrations/0001_initial.py b/options/migrations/0001_initial.py new file mode 100644 index 00000000..db40e1e8 --- /dev/null +++ b/options/migrations/0001_initial.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.3 on 2017-10-01 19:19 +from __future__ import unicode_literals + +import jsonfield.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='SysOptions', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('key', models.CharField(db_index=True, max_length=128, unique=True)), + ('value', jsonfield.fields.JSONField()), + ], + ), + ] diff --git a/options/migrations/0002_auto_20171011_1214.py b/options/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ee52ffa4 --- /dev/null +++ b/options/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('options', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='sysoptions', + name='value', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + ] diff --git a/options/migrations/__init__.py b/options/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/options/models.py b/options/models.py new file mode 100644 index 00000000..04dee5e2 --- /dev/null +++ b/options/models.py @@ -0,0 +1,7 @@ +from django.db import models +from utils.models import JSONField + + +class SysOptions(models.Model): + key = models.CharField(max_length=128, unique=True, db_index=True) + value = JSONField() diff --git a/options/options.py b/options/options.py new file mode 100644 index 00000000..b2d76f16 --- /dev/null +++ b/options/options.py @@ -0,0 +1,179 @@ +from django.core.cache import cache +from django.db import transaction, IntegrityError + +from utils.constants import CacheKey +from utils.shortcuts import rand_str +from .models import SysOptions as SysOptionsModel + + +class OptionKeys: + website_base_url = "website_base_url" + website_name = "website_name" + website_name_shortcut = "website_name_shortcut" + website_footer = "website_footer" + allow_register = "allow_register" + submission_list_show_all = "submission_list_show_all" + smtp_config = "smtp_config" + judge_server_token = "judge_server_token" + + +class OptionDefaultValue: + website_base_url = "http://127.0.0.1" + website_name = "Online Judge" + website_name_shortcut = "oj" + website_footer = "Online Judge Footer" + allow_register = True + submission_list_show_all = True + smtp_config = {} + judge_server_token = rand_str + + +class _SysOptionsMeta(type): + @classmethod + def _set_cache(mcs, option_key, option_value): + cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60) + + @classmethod + def _del_cache(mcs, option_key): + cache.delete(f"{CacheKey.option}:{option_key}") + + @classmethod + def _get_keys(cls): + return [key for key in OptionKeys.__dict__ if not key.startswith("__")] + + def rebuild_cache(cls): + for key in cls._get_keys(): + # get option 的时候会写 cache 的 + cls._get_option(key, use_cache=False) + + @classmethod + def _init_option(mcs): + for item in mcs._get_keys(): + if not SysOptionsModel.objects.filter(key=item).exists(): + default_value = getattr(OptionDefaultValue, item) + if callable(default_value): + default_value = default_value() + try: + SysOptionsModel.objects.create(key=item, value=default_value) + except IntegrityError: + pass + + @classmethod + def _get_option(mcs, option_key, use_cache=True): + try: + if use_cache: + option = cache.get(f"{CacheKey.option}:{option_key}") + if option: + return option + option = SysOptionsModel.objects.get(key=option_key) + value = option.value + mcs._set_cache(option_key, value) + return value + except SysOptionsModel.DoesNotExist: + mcs._init_option() + return mcs._get_option(option_key, use_cache=use_cache) + + @classmethod + def _set_option(mcs, option_key: str, option_value): + try: + with transaction.atomic(): + option = SysOptionsModel.objects.select_for_update().get(key=option_key) + option.value = option_value + option.save() + mcs._del_cache(option_key) + except SysOptionsModel.DoesNotExist: + mcs._init_option() + mcs._set_option(option_key, option_value) + + @classmethod + def _increment(mcs, option_key): + try: + with transaction.atomic(): + option = SysOptionsModel.objects.select_for_update().get(key=option_key) + value = option.value + 1 + option.value = value + option.save() + mcs._del_cache(option_key) + except SysOptionsModel.DoesNotExist: + mcs._init_option() + return mcs._increment(option_key) + + @classmethod + def set_options(mcs, options): + for key, value in options: + mcs._set_option(key, value) + + @classmethod + def get_options(mcs, keys): + result = {} + for key in keys: + result[key] = mcs._get_option(key) + return result + + @property + def website_base_url(cls): + return cls._get_option(OptionKeys.website_base_url) + + @website_base_url.setter + def website_base_url(cls, value): + cls._set_option(OptionKeys.website_base_url, value) + + @property + def website_name(cls): + return cls._get_option(OptionKeys.website_name) + + @website_name.setter + def website_name(cls, value): + cls._set_option(OptionKeys.website_name, value) + + @property + def website_name_shortcut(cls): + return cls._get_option(OptionKeys.website_name_shortcut) + + @website_name_shortcut.setter + def website_name_shortcut(cls, value): + cls._set_option(OptionKeys.website_name_shortcut, value) + + @property + def website_footer(cls): + return cls._get_option(OptionKeys.website_footer) + + @website_footer.setter + def website_footer(cls, value): + cls._set_option(OptionKeys.website_footer, value) + + @property + def allow_register(cls): + return cls._get_option(OptionKeys.allow_register) + + @allow_register.setter + def allow_register(cls, value): + cls._set_option(OptionKeys.allow_register, value) + + @property + def submission_list_show_all(cls): + return cls._get_option(OptionKeys.submission_list_show_all) + + @submission_list_show_all.setter + def submission_list_show_all(cls, value): + cls._set_option(OptionKeys.submission_list_show_all, value) + + @property + def smtp_config(cls): + return cls._get_option(OptionKeys.smtp_config) + + @smtp_config.setter + def smtp_config(cls, value): + cls._set_option(OptionKeys.smtp_config, value) + + @property + def judge_server_token(cls): + return cls._get_option(OptionKeys.judge_server_token) + + @judge_server_token.setter + def judge_server_token(cls, value): + cls._set_option(OptionKeys.judge_server_token, value) + + +class SysOptions(metaclass=_SysOptionsMeta): + pass diff --git a/options/tests.py b/options/tests.py new file mode 100644 index 00000000..a39b155a --- /dev/null +++ b/options/tests.py @@ -0,0 +1 @@ +# Create your tests here. diff --git a/options/views.py b/options/views.py new file mode 100644 index 00000000..60f00ef0 --- /dev/null +++ b/options/views.py @@ -0,0 +1 @@ +# Create your views here. diff --git a/problem/migrations/0009_auto_20171011_1214.py b/problem/migrations/0009_auto_20171011_1214.py new file mode 100644 index 00000000..7073b8f9 --- /dev/null +++ b/problem/migrations/0009_auto_20171011_1214.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('problem', '0008_auto_20170923_1318'), + ] + + operations = [ + migrations.AlterField( + model_name='problem', + name='languages', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='samples', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='statistic_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='problem', + name='template', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='test_case_score', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + ] diff --git a/problem/models.py b/problem/models.py index 0e9c5e2f..90537932 100644 --- a/problem/models.py +++ b/problem/models.py @@ -1,5 +1,5 @@ from django.db import models -from jsonfield import JSONField +from utils.models import JSONField from account.models import User from contest.models import Contest @@ -65,8 +65,8 @@ class Problem(models.Model): total_score = models.IntegerField(default=0, blank=True) submission_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0) - # ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count - statistic_info = JSONField(default={}) + # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count + statistic_info = JSONField(default=dict) class Meta: db_table = "problem" diff --git a/problem/serializers.py b/problem/serializers.py index d9e49195..0e44710b 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer): class TagSerializer(serializers.ModelSerializer): class Meta: model = ProblemTag + fields = "__all__" class BaseProblemSerializer(serializers.ModelSerializer): @@ -88,11 +89,13 @@ class BaseProblemSerializer(serializers.ModelSerializer): class ProblemAdminSerializer(BaseProblemSerializer): class Meta: model = Problem + fields = "__all__" class ContestProblemAdminSerializer(BaseProblemSerializer): class Meta: model = Problem + fields = "__all__" class ProblemSerializer(BaseProblemSerializer): diff --git a/problem/tests.py b/problem/tests.py index 4e081ff7..7efa668b 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -1,4 +1,5 @@ import copy +import hashlib import os import shutil from datetime import timedelta @@ -9,6 +10,7 @@ from django.conf import settings from utils.api.tests import APITestCase from .models import ProblemTag +from .models import Problem, ProblemRuleType from .views.admin import TestCaseUploadAPI from contest.models import Contest from contest.tests import DEFAULT_CONTEST_DATA @@ -23,6 +25,40 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "

test "input_size": 0, "score": 0}], "rule_type": "ACM", "hint": "

test

", "source": "test"} +class ProblemCreateTestBase(APITestCase): + @staticmethod + def add_problem(problem_data, created_by): + data = copy.deepcopy(problem_data) + if data["spj"]: + if not data["spj_language"] or not data["spj_code"]: + raise ValueError("Invalid spj") + data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + else: + data["spj_language"] = None + data["spj_code"] = None + if data["rule_type"] == ProblemRuleType.OI: + total_score = 0 + for item in data["test_case_score"]: + if item["score"] <= 0: + raise ValueError("invalid score") + else: + total_score += item["score"] + data["total_score"] = total_score + data["created_by"] = created_by + tags = data.pop("tags") + + data["languages"] = list(data["languages"]) + + problem = Problem.objects.create(**data) + + for item in tags: + try: + tag = ProblemTag.objects.get(name=item) + except ProblemTag.DoesNotExist: + tag = ProblemTag.objects.create(name=item) + problem.tags.add(tag) + return problem + class ProblemTagListAPITest(APITestCase): def test_get_tag_list(self): @@ -96,7 +132,7 @@ class ProblemAdminAPITest(APITestCase): def setUp(self): self.url = self.reverse("problem_admin_api") self.create_super_admin() - self.data = DEFAULT_PROBLEM_DATA + self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA) def test_create_problem(self): resp = self.client.post(self.url, data=self.data) @@ -138,23 +174,19 @@ class ProblemAdminAPITest(APITestCase): self.assertSuccess(resp) -class ProblemAPITest(APITestCase): +class ProblemAPITest(ProblemCreateTestBase): def setUp(self): self.url = self.reverse("problem_api") - self.create_admin() - - def create_problem(self): - url = self.reverse("problem_admin_api") - return self.client.post(url, data=DEFAULT_PROBLEM_DATA) + admin = self.create_admin(login=False) + self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) + self.create_user("test", "test123") def test_get_problem_list(self): - self.create_problem() resp = self.client.get(f"{self.url}?limit=10") self.assertSuccess(resp) def get_one_problem(self): - problem_id = self.create_problem().data["data"]["_id"] - resp = self.client.get(self.url + "?id=" + str(problem_id)) + resp = self.client.get(self.url + "?id=" + self.problem._id) self.assertSuccess(resp) @@ -169,51 +201,49 @@ class ContestProblemAdminTest(APITestCase): def test_create_contest_problem(self): contest = self.create_contest() - data = DEFAULT_PROBLEM_DATA + data = copy.deepcopy(DEFAULT_PROBLEM_DATA) data["contest_id"] = contest.data["data"]["id"] resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - return resp + return contest, resp def test_get_contest_problem(self): - contest = self.test_create_contest_problem() + contest, contest_problem = self.test_create_contest_problem() contest_id = contest.data["data"]["id"] resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]), 1) def test_get_one_contest_problem(self): - contest = self.test_create_contest_problem() + contest, contest_problem = self.test_create_contest_problem() contest_id = contest.data["data"]["id"] - resp = self.client.get(self.url + "?id=" + str(contest_id)) + problem_id = contest_problem.data["data"]["id"] + resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") self.assertSuccess(resp) -class ContestProblemTest(APITestCase): +class ContestProblemTest(ProblemCreateTestBase): def setUp(self): - self.url = self.reverse("contest_problem_api") - self.create_admin() - + admin = self.create_admin() url = self.reverse("contest_admin_api") contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) contest_data["password"] = "" contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) self.contest = self.client.post(url, data=contest_data).data["data"] + self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) + self.problem.contest_id = self.contest["id"] + self.problem.save() + self.url = self.reverse("contest_problem_api") - problem_data = copy.deepcopy(DEFAULT_PROBLEM_DATA) - problem_data["contest"] = self.contest["id"] - url = self.reverse("contest_problem_admin_api") - self.problem = self.client.post(url, problem_data).data["data"] - - def test_get_contest_problem_list(self): + def test_admin_get_contest_problem_list(self): contest_id = self.contest["id"] resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]), 1) - def test_get_one_contest_problem(self): + def test_admin_get_one_contest_problem(self): contest_id = self.contest["id"] - problem_id = self.problem["_id"] + problem_id = self.problem._id resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id)) self.assertSuccess(resp) diff --git a/problem/views/admin.py b/problem/views/admin.py index 572e2fb4..00f46c4d 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -223,6 +223,7 @@ class ProblemAPI(APIView): data["total_score"] = total_score # todo check filename and score info tags = data.pop("tags") + data["languages"] = list(data["languages"]) for k, v in data.items(): setattr(problem, k, v) diff --git a/problem/views/oj.py b/problem/views/oj.py index 0a6b066b..330e9332 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -13,6 +13,24 @@ class ProblemTagAPI(APIView): class ProblemAPI(APIView): + @staticmethod + def _add_problem_status(request, queryset_values): + if request.user.is_authenticated(): + profile = request.user.userprofile + acm_problems_status = profile.acm_problems_status.get("problems", {}) + oi_problems_status = profile.oi_problems_status.get("problems", {}) + # paginate data + results = queryset_values.get("results") + if results is not None: + problems = results + else: + problems = [queryset_values,] + for problem in problems: + if problem["rule_type"] == ProblemRuleType.ACM: + problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") + else: + problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") + def get(self, request): # 问题详情页 problem_id = request.GET.get("problem_id") @@ -20,7 +38,9 @@ class ProblemAPI(APIView): try: problem = Problem.objects.select_related("created_by")\ .get(_id=problem_id, contest_id__isnull=True, visible=True) - return self.success(ProblemSerializer(problem).data) + problem_data = ProblemSerializer(problem).data + self._add_problem_status(request, problem_data) + return self.success(problem_data) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -32,11 +52,7 @@ class ProblemAPI(APIView): # 按照标签筛选 tag_text = request.GET.get("tag") if tag_text: - try: - tag = ProblemTag.objects.get(name=tag_text) - except ProblemTag.DoesNotExist: - return self.error("The Tag does not exist.") - problems = tag.problem_set.all().filter(visible=True) + problems = problems.filter(tags__name=tag_text) # 搜索的情况 keyword = request.GET.get("keyword", "").strip() @@ -49,19 +65,23 @@ class ProblemAPI(APIView): problems = problems.filter(difficulty=difficulty) # 根据profile 为做过的题目添加标记 data = self.paginate_data(request, problems, ProblemSerializer) - if request.user.id: - profile = request.user.userprofile - acm_problems_status = profile.acm_problems_status.get("problems", {}) - oi_problems_status = profile.oi_problems_status.get("problems", {}) - for problem in data["results"]: - if problem["rule_type"] == ProblemRuleType.ACM: - problem["my_status"] = acm_problems_status.get(str(problem["id"]), None) - else: - problem["my_status"] = oi_problems_status.get(str(problem["id"]), None) + self._add_problem_status(request, data) return self.success(data) class ContestProblemAPI(APIView): + def _add_problem_status(self, request, queryset_values): + if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user): + return + if request.user.is_authenticated(): + profile = request.user.userprofile + if self.contest.rule_type == ContestRuleType.ACM: + problems_status = profile.acm_problems_status.get("contest_problems", {}) + else: + problems_status = profile.oi_problems_status.get("contest_problems", {}) + for problem in queryset_values: + problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") + @check_contest_permission def get(self, request): problem_id = request.GET.get("problem_id") @@ -72,17 +92,12 @@ class ContestProblemAPI(APIView): visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist.") - return self.success(ContestProblemSerializer(problem).data) + problem_data = ContestProblemSerializer(problem).data + self._add_problem_status(request, [problem_data,]) + return self.success(problem_data) contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) # 根据profile, 为做过的题目添加标记 data = ContestProblemSerializer(contest_problems, many=True).data - if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI: - profile = request.user.userprofile - if self.contest.rule_type == ContestRuleType.ACM: - problems_status = profile.acm_problems_status.get("contest_problems", {}) - else: - problems_status = profile.oi_problems_status.get("contest_problems", {}) - for problem in data: - problem["my_status"] = problems_status.get(str(problem["id"]), None) + self._add_problem_status(request, data) return self.success(data) diff --git a/submission/migrations/0008_auto_20171011_1214.py b/submission/migrations/0008_auto_20171011_1214.py new file mode 100644 index 00000000..1c585d8a --- /dev/null +++ b/submission/migrations/0008_auto_20171011_1214.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('submission', '0007_auto_20170923_1318'), + ] + + operations = [ + migrations.AlterField( + model_name='submission', + name='info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='submission', + name='statistic_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + ] diff --git a/submission/models.py b/submission/models.py index f4cb2f2e..7f046a04 100644 --- a/submission/models.py +++ b/submission/models.py @@ -1,5 +1,5 @@ from django.db import models -from jsonfield import JSONField +from utils.models import JSONField from account.models import AdminType from problem.models import Problem from contest.models import Contest @@ -30,18 +30,20 @@ class Submission(models.Model): username = models.CharField(max_length=30) code = models.TextField() result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING) - # 判题结果的详细信息 - info = JSONField(default={}) + # 从JudgeServer返回的判题详情 + info = JSONField(default=dict) language = models.CharField(max_length=20) shared = models.BooleanField(default=False) # 存储该提交所用时间和内存值,方便提交列表显示 # {time_cost: "", memory_cost: "", err_info: "", score: 0} - statistic_info = JSONField(default={}) + statistic_info = JSONField(default=dict) - def check_user_permission(self, user): + def check_user_permission(self, user, check_share=True): return self.user_id == user.id or \ - self.shared is True or \ - user.admin_type == AdminType.SUPER_ADMIN + (check_share and self.shared is True) or \ + user.is_super_admin() or \ + user.can_mgmt_all_problem() or \ + self.problem.created_by_id == user.id class Meta: db_table = "submission" diff --git a/submission/serializers.py b/submission/serializers.py index ae8c3a60..66a517bd 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -10,6 +10,11 @@ class CreateSubmissionSerializer(serializers.Serializer): contest_id = serializers.IntegerField(required=False) +class ShareSubmissionSerializer(serializers.Serializer): + id = serializers.CharField() + shared = serializers.BooleanField() + + class SubmissionModelSerializer(serializers.ModelSerializer): info = serializers.JSONField() statistic_info = serializers.JSONField() @@ -19,7 +24,7 @@ class SubmissionModelSerializer(serializers.ModelSerializer): # 不显示submission info的serializer, 用于ACM rule_type -class SubmissionSafeSerializer(serializers.ModelSerializer): +class SubmissionSafeModelSerializer(serializers.ModelSerializer): problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") statistic_info = serializers.JSONField() @@ -43,6 +48,6 @@ class SubmissionListSerializer(serializers.ModelSerializer): def get_show_link(self, obj): # 没传user或为匿名user - if self.user is None or self.user.id is None: + if self.user is None or not self.user.is_authenticated(): return False return obj.check_user_permission(self.user) diff --git a/submission/tests.py b/submission/tests.py index c7348827..fdcd12ec 100644 --- a/submission/tests.py +++ b/submission/tests.py @@ -32,14 +32,16 @@ class SubmissionPrepare(APITestCase): def _create_problem_and_submission(self): user = self.create_admin("test", "test123", login=False) problem_data = deepcopy(DEFAULT_PROBLEM_DATA) - problem_data.pop("tags") + tags = problem_data.pop("tags") problem_data["created_by"] = user self.problem = Problem.objects.create(**problem_data) - for tag in DEFAULT_PROBLEM_DATA["tags"]: + for tag in tags: tag = ProblemTag.objects.create(name=tag) self.problem.tags.add(tag) self.problem.save() - self.submission = Submission.objects.create(**DEFAULT_SUBMISSION_DATA) + self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA) + self.submission_data["problem_id"] = self.problem.id + self.submission = Submission.objects.create(**self.submission_data) class SubmissionListTest(SubmissionPrepare): @@ -61,6 +63,6 @@ class SubmissionAPITest(SubmissionPrepare): self.url = self.reverse("submission_api") def test_create_submission(self, judge_task): - resp = self.client.post(self.url, DEFAULT_SUBMISSION_DATA) + resp = self.client.post(self.url, self.submission_data) self.assertSuccess(resp) judge_task.assert_called() diff --git a/submission/views/oj.py b/submission/views/oj.py index 274418a7..4cf6a8c4 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -5,16 +5,17 @@ from problem.models import Problem, ProblemRuleType from contest.models import Contest, ContestStatus, ContestRuleType from utils.api import APIView, validate_serializer from utils.throttling import TokenBucket, BucketController +from utils.cache import cache from ..models import Submission -from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer -from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer -from utils.cache import throttling_cache +from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, + ShareSubmissionSerializer) +from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer def _submit(response, user, problem_id, language, code, contest_id): # TODO: 预设默认值,需修改 controller = BucketController(user_id=user.id, - redis_conn=throttling_cache, + redis_conn=cache, default_capacity=30) bucket = TokenBucket(fill_rate=10, capacity=20, last_capacity=controller.last_capacity, @@ -63,17 +64,36 @@ class SubmissionAPI(APIView): def get(self, request): submission_id = request.GET.get("id") if not submission_id: - return self.error("Parameter id doesn't exist.") + return self.error("Parameter id doesn't exist") try: submission = Submission.objects.select_related("problem").get(id=submission_id) except Submission.DoesNotExist: - return self.error("Submission doesn't exist.") + return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user): - return self.error("No permission for this submission.") + return self.error("No permission for this submission") if submission.problem.rule_type == ProblemRuleType.ACM: - return self.success(SubmissionSafeSerializer(submission).data) - return self.success(SubmissionModelSerializer(submission).data) + submission_data = SubmissionSafeModelSerializer(submission).data + else: + submission_data = SubmissionModelSerializer(submission).data + # 是否有权限取消共享 + submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False) + return self.success(submission_data) + + @validate_serializer(ShareSubmissionSerializer) + @login_required + def put(self, request): + try: + submission = Submission.objects.select_related("problem").get(id=request.data["id"]) + except Submission.DoesNotExist: + return self.error("Submission doesn't exist") + if not submission.check_user_permission(request.user, check_share=False): + return self.error("No permission to share the submission") + if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY: + return self.error("Can not share submission now") + submission.shared = request.data["shared"] + submission.save(update_fields=["shared"]) + return self.success() class SubmissionListAPI(APIView): @@ -83,7 +103,7 @@ class SubmissionListAPI(APIView): if request.GET.get("contest_id"): return self.error("Parameter error") - submissions = Submission.objects.filter(contest_id__isnull=True) + submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") result = request.GET.get("result") @@ -109,10 +129,10 @@ class ContestSubmissionListAPI(APIView): return self.error("Limit is needed") contest = self.contest - if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user): + if not contest.check_oi_permission(request.user): return self.error("No permission for OI contest submissions") - submissions = Submission.objects.filter(contest_id=contest.id) + submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") result = request.GET.get("result") @@ -133,9 +153,11 @@ class ContestSubmissionListAPI(APIView): submissions = submissions.filter(create_time__gte=contest.start_time) # 封榜的时候只能看到自己的提交 - if not contest.real_time_rank and not contest.is_contest_admin(request.user): - submissions = submissions.filter(user_id=request.user.id) + if contest.rule_type == ContestRuleType.ACM: + if not contest.real_time_rank and not contest.is_contest_admin(request.user): + submissions = submissions.filter(user_id=request.user.id) data = self.paginate_data(request, submissions) data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data return self.success(data) + diff --git a/utils/api/__init__.py b/utils/api/__init__.py index dedbe3a9..9384481c 100644 --- a/utils/api/__init__.py +++ b/utils/api/__init__.py @@ -1,2 +1,2 @@ -from .api import * # NOQA from ._serializers import * # NOQA +from .api import * # NOQA diff --git a/utils/api/api.py b/utils/api/api.py index 920b827a..f49b039a 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -79,7 +79,7 @@ class APIView(View): def success(self, data=None): return self.response({"error": None, "data": data}) - def error(self, msg, err="error"): + def error(self, msg="error", err="error"): return self.response({"error": err, "data": msg}) def _serializer_error_to_str(self, errors): diff --git a/utils/api/tests.py b/utils/api/tests.py index 3d9cc306..4b485c9c 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -3,7 +3,6 @@ from django.test.testcases import TestCase from rest_framework.test import APIClient from account.models import AdminType, ProblemPermission, User, UserProfile -from conf.models import WebsiteConfig class APITestCase(TestCase): @@ -28,9 +27,6 @@ class APITestCase(TestCase): return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, problem_permission=ProblemPermission.ALL, login=login) - def create_website_config(self): - return WebsiteConfig.objects.create() - def reverse(self, url_name): return reverse(url_name) diff --git a/utils/cache.py b/utils/cache.py index c77131f6..ed9059b1 100644 --- a/utils/cache.py +++ b/utils/cache.py @@ -1,6 +1,27 @@ -from django.conf import settings -from django_redis import get_redis_connection +from django.core.cache import cache, caches # noqa +from django.conf import settings # noqa -judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE) -throttling_cache = get_redis_connection(settings.CACHE_THROTTLING) -default_cache = get_redis_connection("default") +from django_redis.cache import RedisCache +from django_redis.client.default import DefaultClient + + +class MyRedisClient(DefaultClient): + def __getattr__(self, item): + client = self.get_client(write=True) + return getattr(client, item) + + def redis_incr(self, key, count=1): + """ + django 默认的 incr 在 key 不存在时候会抛异常 + """ + client = self.get_client(write=True) + return client.incr(key, count) + + +class MyRedisCache(RedisCache): + def __init__(self, server, params): + super().__init__(server, params) + self._client_cls = MyRedisClient + + def __getattr__(self, item): + return getattr(self.client, item) diff --git a/utils/constants.py b/utils/constants.py index b11aa216..390d5685 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -1,4 +1,28 @@ +class Choices: + @classmethod + def choices(cls): + d = cls.__dict__ + return [d[item] for item in d.keys() if not item.startswith("__")] + + +class ContestType: + PUBLIC_CONTEST = "Public" + PASSWORD_PROTECTED_CONTEST = "Password Protected" + + +class ContestStatus: + CONTEST_NOT_START = "1" + CONTEST_ENDED = "-1" + CONTEST_UNDERWAY = "0" + + +class ContestRuleType(Choices): + ACM = "ACM" + OI = "OI" + + class CacheKey: waiting_queue = "waiting_queue" - contest_rank_cache = "contest_rank_cache_" + contest_rank_cache = "contest_rank_cache" website_config = "website_config" + option = "option" diff --git a/utils/models.py b/utils/models.py index b651aa2d..3c114522 100644 --- a/utils/models.py +++ b/utils/models.py @@ -1,3 +1,4 @@ +from django.contrib.postgres.fields import JSONField # NOQA from django.db import models from utils.xss_filter import XssHtml diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 525eef0e..8fc1ffac 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -1,35 +1,9 @@ -import logging -import random import datetime -from io import BytesIO +import random from base64 import b64encode +from io import BytesIO from django.utils.crypto import get_random_string -from envelopes import Envelope - -from conf.models import SMTPConfig - -logger = logging.getLogger(__name__) - - -def send_email(from_name, to_email, to_name, subject, content): - smtp = SMTPConfig.objects.first() - if not smtp: - return - envlope = Envelope(from_addr=(smtp.email, from_name), - to_addr=(to_email, to_name), - subject=subject, - html_body=content) - try: - envlope.send(smtp.server, - login=smtp.email, - password=smtp.password, - port=smtp.port, - tls=smtp.tls) - return True - except Exception as e: - logger.exception(e) - return False def rand_str(length=32, type="lower_hex"): diff --git a/utils/xss_filter.py b/utils/xss_filter.py index d29495b6..34d65a8b 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7 浏览器版本:IE7+ 或其他浏览器,无法防御IE6及以下版本浏览器中的XSS """ import re - -try: - from html.parser import HTMLParser -except: - from HTMLParser import HTMLParser +import copy +from html.parser import HTMLParser class XssHtml(HTMLParser): @@ -163,7 +160,7 @@ class XssHtml(HTMLParser): else: other = [] if attrs: - for (key, value) in attrs.items(): + for key, value in copy.deepcopy(attrs).items(): if key not in self.common_attrs + other: del attrs[key] return attrs