From 9990cf647a2d03521eb4c09b0c0bbba8f7b599f7 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Mon, 2 Oct 2017 03:54:34 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20SysOptions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- account/tasks.py | 29 ++++- account/views/oj.py | 22 ++-- conf/models.py | 32 ------ conf/serializers.py | 23 +--- conf/tests.py | 26 ++--- conf/views.py | 78 ++++--------- judge/dispatcher.py | 8 +- oj/settings.py | 1 + options/__init__.py | 0 options/migrations/0001_initial.py | 25 ++++ options/migrations/__init__.py | 0 options/models.py | 7 ++ options/options.py | 179 +++++++++++++++++++++++++++++ options/tests.py | 1 + options/views.py | 1 + utils/api/__init__.py | 2 +- utils/api/tests.py | 4 - utils/constants.py | 1 + utils/shortcuts.py | 30 +---- 19 files changed, 297 insertions(+), 172 deletions(-) create mode 100644 options/__init__.py create mode 100644 options/migrations/0001_initial.py create mode 100644 options/migrations/__init__.py create mode 100644 options/models.py create mode 100644 options/options.py create mode 100644 options/tests.py create mode 100644 options/views.py diff --git a/account/tasks.py b/account/tasks.py index 0aacec96..3e7c1d2f 100644 --- a/account/tasks.py +++ b/account/tasks.py @@ -1,6 +1,31 @@ -from celery import shared_task +import logging -from utils.shortcuts import send_email +from celery import shared_task +from envelopes import Envelope + +from options.options import SysOptions + +logger = logging.getLogger(__name__) + + +def send_email(from_name, to_email, to_name, subject, content): + smtp = SysOptions.smtp_config + if not smtp: + return + envlope = Envelope(from_addr=(smtp["email"], from_name), + to_addr=(to_email, to_name), + subject=subject, + html_body=content) + try: + envlope.send(smtp["server"], + login=smtp["email"], + password=smtp["password"], + port=smtp["port"], + tls=smtp["tls"]) + return True + except Exception as e: + logger.exception(e) + return False @shared_task diff --git a/account/views/oj.py b/account/views/oj.py index cfb94726..140c0e9c 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -1,24 +1,23 @@ import os -import qrcode import pickle from datetime import timedelta -from otpauth import OtpAuth +from importlib import import_module +import qrcode from django.conf import settings from django.contrib import auth -from importlib import import_module +from django.template.loader import render_to_string +from django.utils.decorators import method_decorator from django.utils.timezone import now from django.views.decorators.csrf import ensure_csrf_cookie -from django.utils.decorators import method_decorator -from django.template.loader import render_to_string +from otpauth import OtpAuth -from conf.models import WebsiteConfig +from options.options import SysOptions from utils.api import APIView, validate_serializer -from utils.captcha import Captcha -from utils.shortcuts import rand_str, img2base64, timestamp2utcstr from utils.cache import default_cache +from utils.captcha import Captcha from utils.constants import CacheKey - +from utils.shortcuts import rand_str, img2base64, timestamp2utcstr from ..decorators import login_required from ..models import User, UserProfile from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, @@ -137,9 +136,8 @@ class TwoFactorAuthAPI(APIView): user.tfa_token = token user.save() - config = WebsiteConfig.objects.first() - label = f"{config.name_shortcut}:{user.username}" - image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name)) + label = f"{SysOptions.website_name_shortcut}:{user.username}" + image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name)) return self.success(img2base64(image)) @login_required diff --git a/conf/models.py b/conf/models.py index 9fe3cc51..86248dbb 100644 --- a/conf/models.py +++ b/conf/models.py @@ -2,31 +2,6 @@ from django.db import models from django.utils import timezone -class SMTPConfig(models.Model): - server = models.CharField(max_length=128) - port = models.IntegerField(default=25) - email = models.CharField(max_length=128) - password = models.CharField(max_length=128) - tls = models.BooleanField() - - class Meta: - db_table = "smtp_config" - - -class WebsiteConfig(models.Model): - base_url = models.CharField(max_length=128, default="http://127.0.0.1") - name = models.CharField(max_length=32, default="Online Judge") - name_shortcut = models.CharField(max_length=32, default="oj") - footer = models.TextField(default="Online Judge Footer") - # allow register - allow_register = models.BooleanField(default=True) - # submission list show all user's submission - submission_list_show_all = models.BooleanField(default=True) - - class Meta: - db_table = "website_config" - - class JudgeServer(models.Model): hostname = models.CharField(max_length=64) ip = models.CharField(max_length=32, blank=True, null=True) @@ -48,10 +23,3 @@ class JudgeServer(models.Model): class Meta: db_table = "judge_server" - - -class JudgeServerToken(models.Model): - token = models.CharField(max_length=32) - - class Meta: - db_table = "judge_server_token" diff --git a/conf/serializers.py b/conf/serializers.py index 59b7203c..09f9940d 100644 --- a/conf/serializers.py +++ b/conf/serializers.py @@ -1,6 +1,6 @@ from utils.api import DateTimeTZField, serializers -from .models import JudgeServer, SMTPConfig, WebsiteConfig +from .models import JudgeServer class EditSMTPConfigSerializer(serializers.Serializer): @@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer): password = serializers.CharField(max_length=128) -class SMTPConfigSerializer(serializers.ModelSerializer): - class Meta: - model = SMTPConfig - exclude = ["id", "password"] - - class TestSMTPConfigSerializer(serializers.Serializer): email = serializers.EmailField() class CreateEditWebsiteConfigSerializer(serializers.Serializer): - base_url = serializers.CharField(max_length=128) - name = serializers.CharField(max_length=32) - name_shortcut = serializers.CharField(max_length=32) - footer = serializers.CharField(max_length=1024) + website_base_url = serializers.CharField(max_length=128) + website_name = serializers.CharField(max_length=32) + website_name_shortcut = serializers.CharField(max_length=32) + website_footer = serializers.CharField(max_length=1024) allow_register = serializers.BooleanField() submission_list_show_all = serializers.BooleanField() -class WebsiteConfigSerializer(serializers.ModelSerializer): - class Meta: - model = WebsiteConfig - exclude = ["id"] - - class JudgeServerSerializer(serializers.ModelSerializer): create_time = DateTimeTZField() last_heartbeat = DateTimeTZField() @@ -47,6 +35,7 @@ class JudgeServerSerializer(serializers.ModelSerializer): class Meta: model = JudgeServer + fields = "__all__" class JudgeServerHeartbeatSerializer(serializers.Serializer): diff --git a/conf/tests.py b/conf/tests.py index 1694c218..eff8cfde 100644 --- a/conf/tests.py +++ b/conf/tests.py @@ -2,11 +2,11 @@ import hashlib from django.utils import timezone +from options.options import SysOptions from utils.api.tests import APITestCase from utils.cache import default_cache from utils.constants import CacheKey - -from .models import JudgeServer, JudgeServerToken, SMTPConfig +from .models import JudgeServer class SMTPConfigTest(APITestCase): @@ -29,10 +29,6 @@ class SMTPConfigTest(APITestCase): "tls": True} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - smtp = SMTPConfig.objects.first() - self.assertEqual(smtp.password, self.password) - self.assertEqual(smtp.server, "smtp1.test.com") - self.assertEqual(smtp.email, "test2@test.com") def test_edit_without_password1(self): self.test_create_smtp_config() @@ -40,7 +36,6 @@ class SMTPConfigTest(APITestCase): "tls": True, "password": ""} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - self.assertEqual(SMTPConfig.objects.first().password, self.password) def test_edit_with_password(self): self.test_create_smtp_config() @@ -48,18 +43,14 @@ class SMTPConfigTest(APITestCase): "tls": True, "password": "newpassword"} resp = self.client.put(self.url, data=data) self.assertSuccess(resp) - smtp = SMTPConfig.objects.first() - self.assertEqual(smtp.password, "newpassword") - self.assertEqual(smtp.server, "smtp1.test.com") - self.assertEqual(smtp.email, "test2@test.com") class WebsiteConfigAPITest(APITestCase): def test_create_website_config(self): self.create_super_admin() url = self.reverse("website_config_api") - data = {"base_url": "http://test.com", "name": "test name", - "name_shortcut": "test oj", "footer": "test", + data = {"website_base_url": "http://test.com", "website_name": "test name", + "website_name_shortcut": "test oj", "website_footer": "test", "allow_register": True, "submission_list_show_all": False} resp = self.client.post(url, data=data) self.assertSuccess(resp) @@ -67,8 +58,8 @@ class WebsiteConfigAPITest(APITestCase): def test_edit_website_config(self): self.create_super_admin() url = self.reverse("website_config_api") - data = {"base_url": "http://test.com", "name": "test name", - "name_shortcut": "test oj", "footer": "test", + data = {"website_base_url": "http://test.com", "website_name": "test name", + "website_name_shortcut": "test oj", "website_footer": "test", "allow_register": True, "submission_list_show_all": False} resp = self.client.post(url, data=data) self.assertSuccess(resp) @@ -78,7 +69,6 @@ class WebsiteConfigAPITest(APITestCase): url = self.reverse("website_info_api") resp = self.client.get(url) self.assertSuccess(resp) - self.assertEqual(resp.data["data"]["name_shortcut"], "oj") def tearDown(self): default_cache.delete(CacheKey.website_config) @@ -91,7 +81,7 @@ class JudgeServerHeartbeatTest(APITestCase): "cpu": 90.5, "memory": 80.3, "action": "heartbeat"} self.token = "test" self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() - JudgeServerToken.objects.create(token=self.token) + SysOptions.judge_server_token = self.token def test_new_heartbeat(self): resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token}) @@ -127,11 +117,9 @@ class JudgeServerAPITest(APITestCase): self.create_super_admin() def test_get_judge_server(self): - self.assertFalse(JudgeServerToken.objects.exists()) resp = self.client.get(self.url) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]["servers"]), 1) - self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"]) def test_delete_judge_server(self): resp = self.client.delete(self.url + "?hostname=testhostname") diff --git a/conf/views.py b/conf/views.py index 814c0620..f09972cc 100644 --- a/conf/views.py +++ b/conf/views.py @@ -1,54 +1,45 @@ import hashlib -import pickle from django.utils import timezone from account.decorators import super_admin_required -from judge.languages import languages, spj_languages from judge.dispatcher import process_pending_task +from judge.languages import languages, spj_languages +from options.options import SysOptions from utils.api import APIView, CSRFExemptAPIView, validate_serializer -from utils.shortcuts import rand_str -from utils.cache import default_cache -from utils.constants import CacheKey - -from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig +from .models import JudgeServer from .serializers import (CreateEditWebsiteConfigSerializer, CreateSMTPConfigSerializer, EditSMTPConfigSerializer, JudgeServerHeartbeatSerializer, - JudgeServerSerializer, SMTPConfigSerializer, - TestSMTPConfigSerializer, WebsiteConfigSerializer) + JudgeServerSerializer, TestSMTPConfigSerializer) class SMTPAPI(APIView): @super_admin_required def get(self, request): - smtp = SMTPConfig.objects.first() + smtp = SysOptions.smtp_config if not smtp: return self.success(None) - return self.success(SMTPConfigSerializer(smtp).data) + smtp.pop("password") + return self.success(smtp) @validate_serializer(CreateSMTPConfigSerializer) @super_admin_required def post(self, request): - SMTPConfig.objects.all().delete() - smtp = SMTPConfig.objects.create(**request.data) - return self.success(SMTPConfigSerializer(smtp).data) + SysOptions.smtp_config = request.data + return self.success() @validate_serializer(EditSMTPConfigSerializer) @super_admin_required def put(self, request): + smtp = SysOptions.smtp_config data = request.data - smtp = SMTPConfig.objects.first() - if not smtp: - return self.error("SMTP config is missing") - smtp.server = data["server"] - smtp.port = data["port"] - smtp.email = data["email"] - smtp.tls = data["tls"] - if data.get("password"): - smtp.password = data["password"] - smtp.save() - return self.success(SMTPConfigSerializer(smtp).data) + for item in ["server", "port", "email", "tls"]: + smtp[item] = data[item] + if "password" in data: + smtp["password"] = data["password"] + SysOptions.smtp_config = smtp + return self.success() class SMTPTestAPI(APIView): @@ -60,37 +51,24 @@ class SMTPTestAPI(APIView): class WebsiteConfigAPI(APIView): def get(self, request): - config = default_cache.get(CacheKey.website_config) - if config: - config = pickle.loads(config) - else: - config = WebsiteConfig.objects.first() - if not config: - config = WebsiteConfig.objects.create() - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - return self.success(WebsiteConfigSerializer(config).data) + ret = {key: getattr(SysOptions, key) for key in + ["website_base_url", "website_name", "website_name_shortcut", + "website_footer", "allow_register", "submission_list_show_all"]} + return self.success(ret) @validate_serializer(CreateEditWebsiteConfigSerializer) @super_admin_required def post(self, request): - data = request.data - WebsiteConfig.objects.all().delete() - config = WebsiteConfig.objects.create(**data) - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - return self.success(WebsiteConfigSerializer(config).data) + for k, v in request.data.items(): + setattr(SysOptions, k, v) + return self.success() class JudgeServerAPI(APIView): @super_admin_required def get(self, request): - judge_server_token = JudgeServerToken.objects.first() - if not judge_server_token: - token = rand_str(12) - JudgeServerToken.objects.create(token=token) - else: - token = judge_server_token.token servers = JudgeServer.objects.all().order_by("-last_heartbeat") - return self.success({"token": token, + return self.success({"token": SysOptions.judge_server_token, "servers": JudgeServerSerializer(servers, many=True).data}) @super_admin_required @@ -104,15 +82,9 @@ class JudgeServerAPI(APIView): class JudgeServerHeartbeatAPI(CSRFExemptAPIView): @validate_serializer(JudgeServerHeartbeatSerializer) def post(self, request): - judge_server_token = JudgeServerToken.objects.first() - if not judge_server_token: - token = rand_str(12) - JudgeServerToken.objects.create(token=token) - else: - token = judge_server_token.token data = request.data client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") - if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token: + if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token: return self.error("Invalid token") service_url = data.get("service_url") diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 8ab66885..597ed7d1 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -8,9 +8,10 @@ from django.db import transaction from django.db.models import F from account.models import User -from conf.models import JudgeServer, JudgeServerToken +from conf.models import JudgeServer from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus from judge.languages import languages +from options.options import SysOptions from problem.models import Problem, ProblemRuleType from submission.models import JudgeStatus, Submission from utils.cache import judge_cache, default_cache @@ -30,8 +31,7 @@ def process_pending_task(): class JudgeDispatcher(object): def __init__(self, submission_id, problem_id): - token = JudgeServerToken.objects.first().token - self.token = hashlib.sha256(token.encode("utf-8")).hexdigest() + self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() self.redis_conn = judge_cache self.submission = Submission.objects.get(pk=submission_id) self.contest_id = self.submission.contest_id @@ -50,7 +50,7 @@ class JudgeDispatcher(object): try: return requests.post(url, **kwargs).json() except Exception as e: - logger.error(e.with_traceback()) + logger.exception(e) @staticmethod def choose_judge_server(): diff --git a/oj/settings.py b/oj/settings.py index eef91a66..4633cf0f 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -46,6 +46,7 @@ INSTALLED_APPS = ( 'contest', 'utils', 'submission', + 'options', ) MIDDLEWARE_CLASSES = ( diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/options/migrations/0001_initial.py b/options/migrations/0001_initial.py new file mode 100644 index 00000000..db40e1e8 --- /dev/null +++ b/options/migrations/0001_initial.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.3 on 2017-10-01 19:19 +from __future__ import unicode_literals + +import jsonfield.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='SysOptions', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('key', models.CharField(db_index=True, max_length=128, unique=True)), + ('value', jsonfield.fields.JSONField()), + ], + ), + ] diff --git a/options/migrations/__init__.py b/options/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/options/models.py b/options/models.py new file mode 100644 index 00000000..6d9ac75c --- /dev/null +++ b/options/models.py @@ -0,0 +1,7 @@ +from django.db import models +from jsonfield import JSONField + + +class SysOptions(models.Model): + key = models.CharField(max_length=128, unique=True, db_index=True) + value = JSONField() diff --git a/options/options.py b/options/options.py new file mode 100644 index 00000000..51beb539 --- /dev/null +++ b/options/options.py @@ -0,0 +1,179 @@ +from django.core.cache import cache +from django.db import transaction, IntegrityError + +from utils.constants import CacheKey +from utils.shortcuts import rand_str +from .models import SysOptions as SysOptionsModel + + +class OptionKeys: + website_base_url = "website_base_url" + website_name = "website_name" + website_name_shortcut = "website_name_shortcut" + website_footer = "website_footer" + allow_register = "allow_register" + submission_list_show_all = "submission_list_show_all" + smtp_config = "smtp_config" + judge_server_token = "judge_server_token" + + +class OptionDefaultValue: + website_base_url = "http://127.0.0.1" + website_name = "Online Judge" + website_name_shortcut = "oj" + website_footer = "Online Judge Footer" + allow_register = True + submission_list_show_all = True + smtp_config = {} + judge_server_token = rand_str + + +class _SysOptionsMeta(type): + @classmethod + def _set_cache(mcs, option_key, option_value): + cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60) + + @classmethod + def _del_cache(mcs, option_key): + cache.delete(f"{CacheKey.option}:{option_key}") + + @classmethod + def _get_keys(cls): + return [key for key in OptionKeys.__dict__ if not key.startswith("__")] + + def rebuild_cache(cls): + for key in cls._get_keys(): + # get option 的时候会写 cache 的 + cls._get_option(key, use_cache=False) + + @classmethod + def _init_option(mcs): + for item in mcs._get_keys(): + if not SysOptionsModel.objects.filter(key=item).exists(): + default_value = getattr(OptionDefaultValue, item) + if callable(default_value): + default_value = default_value() + try: + SysOptionsModel.objects.create(key=item, value=default_value) + except IntegrityError: + pass + + @classmethod + def _get_option(mcs, option_key, use_cache=True): + try: + if use_cache: + option = cache.get(f"{CacheKey.option}:{option_key}") + if option: + return option + option = SysOptionsModel.objects.get(key=option_key) + value = option.value + mcs._set_cache(option_key, value) + return value + except SysOptionsModel.DoesNotExist: + mcs._init_option() + return mcs._get_option(option_key, use_cache=use_cache) + + @classmethod + def _set_option(mcs, option_key: str, option_value): + try: + with transaction.atomic(): + option = SysOptionsModel.objects.select_for_update().get(key=option_key) + option.value = option_value + option.save() + mcs._del_cache(option_key) + except SysOptionsModel.DoesNotExist: + mcs._init_option() + mcs._set_option(option_key, option_value) + + @classmethod + def _increment(mcs, option_key): + try: + with transaction.atomic(): + option = SysOptionsModel.objects.select_for_update().get(key=option_key) + value = option.value + 1 + option.value = value + option.save() + mcs._del_cache(option_key) + except SysOptionsModel.DoesNotExist: + mcs._init_option() + return mcs._increment(option_key) + + @classmethod + def set_options(mcs, options): + for key, value in options: + mcs._set_option(key, value) + + @classmethod + def get_options(mcs, keys): + result = {} + for key in keys: + result[key] = mcs._get_option(key) + return result + + @property + def website_base_url(cls): + return cls._get_option(OptionKeys.website_base_url) + + @website_base_url.setter + def website_base_url(cls, value): + cls._set_option(OptionKeys.website_base_url, value) + + @property + def website_name(cls): + return cls._get_option(OptionKeys.website_name) + + @website_name.setter + def website_name(cls, value): + cls._set_option(OptionKeys.website_name, value) + + @property + def website_name_shortcut(cls): + return cls._get_option(OptionKeys.website_name_shortcut) + + @website_name_shortcut.setter + def website_name_shortcut(cls, value): + cls._set_option(OptionKeys.website_name_shortcut, value) + + @property + def website_footer(cls): + return cls._get_option(OptionKeys.website_footer) + + @website_footer.setter + def website_footer(cls, value): + cls._set_option(OptionKeys.website_footer, value) + + @property + def allow_register(cls): + return cls._get_option(OptionKeys.allow_register) + + @allow_register.setter + def allow_register(cls, value): + cls._set_option(OptionKeys.allow_register, value) + + @property + def submission_list_show_all(cls): + return cls._get_option(OptionKeys.submission_list_show_all) + + @submission_list_show_all.setter + def submission_list_show_all(cls, value): + cls._set_option(OptionKeys.submission_list_show_all, value) + + @property + def smtp_config(cls): + return cls._get_option(OptionKeys.smtp_config) + + @smtp_config.setter + def smtp_config(cls, value): + cls._set_option(OptionKeys.smtp_config, value) + + @property + def judge_server_token(cls): + return cls._get_option(OptionKeys.judge_server_token) + + @judge_server_token.setter + def judge_server_token(cls, value): + cls._set_option(OptionKeys.judge_server_token, value) + + +class SysOptions(metaclass=_SysOptionsMeta): + pass diff --git a/options/tests.py b/options/tests.py new file mode 100644 index 00000000..a39b155a --- /dev/null +++ b/options/tests.py @@ -0,0 +1 @@ +# Create your tests here. diff --git a/options/views.py b/options/views.py new file mode 100644 index 00000000..60f00ef0 --- /dev/null +++ b/options/views.py @@ -0,0 +1 @@ +# Create your views here. diff --git a/utils/api/__init__.py b/utils/api/__init__.py index dedbe3a9..9384481c 100644 --- a/utils/api/__init__.py +++ b/utils/api/__init__.py @@ -1,2 +1,2 @@ -from .api import * # NOQA from ._serializers import * # NOQA +from .api import * # NOQA diff --git a/utils/api/tests.py b/utils/api/tests.py index 3d9cc306..4b485c9c 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -3,7 +3,6 @@ from django.test.testcases import TestCase from rest_framework.test import APIClient from account.models import AdminType, ProblemPermission, User, UserProfile -from conf.models import WebsiteConfig class APITestCase(TestCase): @@ -28,9 +27,6 @@ class APITestCase(TestCase): return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, problem_permission=ProblemPermission.ALL, login=login) - def create_website_config(self): - return WebsiteConfig.objects.create() - def reverse(self, url_name): return reverse(url_name) diff --git a/utils/constants.py b/utils/constants.py index b11aa216..14f13afd 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -2,3 +2,4 @@ class CacheKey: waiting_queue = "waiting_queue" contest_rank_cache = "contest_rank_cache_" website_config = "website_config" + option = "option" diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 525eef0e..8fc1ffac 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -1,35 +1,9 @@ -import logging -import random import datetime -from io import BytesIO +import random from base64 import b64encode +from io import BytesIO from django.utils.crypto import get_random_string -from envelopes import Envelope - -from conf.models import SMTPConfig - -logger = logging.getLogger(__name__) - - -def send_email(from_name, to_email, to_name, subject, content): - smtp = SMTPConfig.objects.first() - if not smtp: - return - envlope = Envelope(from_addr=(smtp.email, from_name), - to_addr=(to_email, to_name), - subject=subject, - html_body=content) - try: - envlope.send(smtp.server, - login=smtp.email, - password=smtp.password, - port=smtp.port, - tls=smtp.tls) - return True - except Exception as e: - logger.exception(e) - return False def rand_str(length=32, type="lower_hex"): From edb32eaf7bd589ed3d08aade16cf2fa839ff3e35 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Mon, 2 Oct 2017 04:33:43 +0800 Subject: [PATCH 2/8] tiny work --- account/middleware.py | 21 ++------------------- account/views/oj.py | 1 - contest/models.py | 9 +++------ contest/views/oj.py | 37 +++++++++++++++---------------------- oj/settings.py | 1 - utils/api/api.py | 2 +- 6 files changed, 21 insertions(+), 50 deletions(-) diff --git a/account/middleware.py b/account/middleware.py index 48a6942e..9141d534 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -10,22 +10,11 @@ from django.utils.deprecation import MiddlewareMixin from utils.api import JSONResponse -class SessionSecurityMiddleware(MiddlewareMixin): - def process_request(self, request): - if request.user.is_authenticated(): - if "last_activity" in request.session and request.user.is_admin_role(): - # 24 hours passed since last visit, 86400 = 24 * 60 * 60 - if time.time() - request.session["last_activity"] >= 86400: - auth.logout(request) - return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) - request.session["last_activity"] = time.time() - - class SessionRecordMiddleware(MiddlewareMixin): def process_request(self, request): if request.user.is_authenticated(): session = request.session - ip = request.META.get("REMOTE_ADDR", "") + ip = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP") user_agent = request.META.get("HTTP_USER_AGENT", "") _ip = session.setdefault("ip", ip) _user_agent = session.setdefault("user_agent", user_agent) @@ -42,13 +31,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): if not (request.user.is_authenticated() and request.user.is_admin_role()): - return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) - - -class TimezoneMiddleware(MiddlewareMixin): - def process_request(self, request): - if request.user.is_authenticated(): - timezone.activate(pytz.timezone(request.user.userprofile.time_zone)) + return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) class LogSqlMiddleware(MiddlewareMixin): diff --git a/account/views/oj.py b/account/views/oj.py index 140c0e9c..aad7c7d4 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -357,7 +357,6 @@ class SessionManagementAPI(APIView): def get(self, request): engine = import_module(settings.SESSION_ENGINE) SessionStore = engine.SessionStore - current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) current_session = request.session.session_key session_keys = request.user.session_keys result = [] diff --git a/contest/models.py b/contest/models.py index 3383d17a..38b23568 100644 --- a/contest/models.py +++ b/contest/models.py @@ -64,7 +64,7 @@ class Contest(models.Model): ordering = ("-create_time",) -class ContestRank(models.Model): +class AbstractContestRank(models.Model): user = models.ForeignKey(User) contest = models.ForeignKey(Contest) submission_number = models.IntegerField(default=0) @@ -73,7 +73,7 @@ class ContestRank(models.Model): abstract = True -class ACMContestRank(ContestRank): +class ACMContestRank(AbstractContestRank): accepted_number = models.IntegerField(default=0) # total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60 total_time = models.IntegerField(default=0) @@ -85,7 +85,7 @@ class ACMContestRank(ContestRank): db_table = "acm_contest_rank" -class OIContestRank(ContestRank): +class OIContestRank(AbstractContestRank): total_score = models.IntegerField(default=0) # {23: 333}} # key is problem id, value is current score @@ -94,9 +94,6 @@ class OIContestRank(ContestRank): class Meta: db_table = "oi_contest_rank" - def update_rank(self, submission): - self.submission_number += 1 - class ContestAnnouncement(models.Model): contest = models.ForeignKey(Contest) diff --git a/contest/views/oj.py b/contest/views/oj.py index b25019c5..a3195137 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -1,6 +1,6 @@ import pickle from django.utils.timezone import now -from django.db.models import Q +from django.core.cache import cache from utils.api import APIView, validate_serializer from utils.cache import default_cache from utils.constants import CacheKey @@ -32,7 +32,7 @@ class ContestAPI(APIView): try: contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) except Contest.DoesNotExist: - return self.error("Contest doesn't exist.") + return self.error("Contest does not exist") return self.success(ContestSerializer(contest).data) contests = Contest.objects.select_related("created_by").filter(visible=True) @@ -50,7 +50,7 @@ class ContestAPI(APIView): elif status == ContestStatus.CONTEST_ENDED: contests = contests.filter(end_time__lt=cur) else: - contests = contests.filter(Q(start_time__lte=cur) & Q(end_time__gte=cur)) + contests = contests.filter(start_time__lte=cur, end_time__gte=cur) return self.success(self.paginate_data(request, contests, ContestSerializer)) @@ -62,14 +62,14 @@ class ContestPasswordVerifyAPI(APIView): try: contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False) except Contest.DoesNotExist: - return self.error("Contest %s doesn't exist." % data["contest_id"]) + return self.error("Contest does not exist") if contest.password != data["password"]: - return self.error("Password doesn't match.") + return self.error("Wrong password") # password verify OK. - if "contests" not in request.session: - request.session["contests"] = [] - request.session["contests"].append(int(data["contest_id"])) + if "accessible_contests" not in request.session: + request.session["accessible_contests"] = [] + request.session["contests"].append(contest.id) # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved request.session.modified = True return self.success(True) @@ -80,13 +80,8 @@ class ContestAccessAPI(APIView): def get(self, request): contest_id = request.GET.get("contest_id") if not contest_id: - return self.error("Parameter contest_id not exist.") - if "contests" not in request.session: - request.session["contests"] = [] - if int(contest_id) in request.session["contests"]: - return self.success({"Access": True}) - else: - return self.success({"Access": False}) + return self.error() + return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])}) class ContestRankAPI(APIView): @@ -105,12 +100,10 @@ class ContestRankAPI(APIView): else: serializer = OIContestRankSerializer - cache_key = CacheKey.contest_rank_cache + str(self.contest.id) - qs = default_cache.get(cache_key) + cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}" + qs = cache.get(cache_key) if not qs: - ranks = self.get_rank() - default_cache.set(cache_key, pickle.dumps(ranks)) - else: - ranks = pickle.loads(qs) + qs = self.get_rank() + cache.set(cache_key, qs) - return self.success(self.paginate_data(request, ranks, serializer)) + return self.success(self.paginate_data(request, qs, serializer)) diff --git a/oj/settings.py b/oj/settings.py index 4633cf0f..672bc6f3 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -58,7 +58,6 @@ MIDDLEWARE_CLASSES = ( 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.security.SecurityMiddleware', 'account.middleware.AdminRoleRequiredMiddleware', - 'account.middleware.SessionSecurityMiddleware', 'account.middleware.SessionRecordMiddleware', # 'account.middleware.LogSqlMiddleware', ) diff --git a/utils/api/api.py b/utils/api/api.py index 920b827a..f49b039a 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -79,7 +79,7 @@ class APIView(View): def success(self, data=None): return self.response({"error": None, "data": data}) - def error(self, msg, err="error"): + def error(self, msg="error", err="error"): return self.response({"error": err, "data": msg}) def _serializer_error_to_str(self, errors): From a324d55364ca9857d5cabf541b3470a3afbd3a07 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Mon, 2 Oct 2017 05:16:14 +0800 Subject: [PATCH 3/8] tiny work --- account/models.py | 33 +++++------- account/serializers.py | 54 +++++++++---------- account/urls/oj.py | 1 - account/views/oj.py | 101 ++++++++++-------------------------- announcement/models.py | 2 +- announcement/serializers.py | 8 +-- conf/models.py | 6 +-- conf/serializers.py | 12 ++--- contest/models.py | 17 +----- contest/views/oj.py | 6 +-- utils/constants.py | 23 ++++++++ utils/xss_filter.py | 9 ++-- 12 files changed, 111 insertions(+), 161 deletions(-) diff --git a/account/models.py b/account/models.py index b909f932..2db0e55b 100644 --- a/account/models.py +++ b/account/models.py @@ -24,22 +24,22 @@ class UserManager(models.Manager): class User(AbstractBaseUser): - username = models.CharField(max_length=30, unique=True) - email = models.EmailField(max_length=254, null=True) + username = models.CharField(max_length=32, unique=True) + email = models.EmailField(max_length=64, null=True) create_time = models.DateTimeField(auto_now_add=True, null=True) # One of UserType - admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER) - problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE) - reset_password_token = models.CharField(max_length=40, null=True) + admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER) + problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE) + reset_password_token = models.CharField(max_length=32, null=True) reset_password_token_expire_time = models.DateTimeField(null=True) # SSO auth token - auth_token = models.CharField(max_length=40, null=True) + auth_token = models.CharField(max_length=32, null=True) two_factor_auth = models.BooleanField(default=False) - tfa_token = models.CharField(max_length=40, null=True) + tfa_token = models.CharField(max_length=32, null=True) session_keys = JSONField(default=[]) # open api key open_api = models.BooleanField(default=False) - open_api_appkey = models.CharField(max_length=35, null=True) + open_api_appkey = models.CharField(max_length=32, null=True) is_disabled = models.BooleanField(default=False) USERNAME_FIELD = "username" @@ -63,10 +63,6 @@ class User(AbstractBaseUser): db_table = "user" -def _default_avatar(): - return f"/{settings.IMAGE_UPLOAD_DIR}/default.png" - - class UserProfile(models.Model): user = models.OneToOneField(User) # Store user problem solution status with json string format @@ -75,14 +71,13 @@ class UserProfile(models.Model): # {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score oi_problems_status = JSONField(default={}) - real_name = models.CharField(max_length=30, blank=True, null=True) - avatar = models.CharField(max_length=50, default=_default_avatar()) + real_name = models.CharField(max_length=32, blank=True, null=True) + avatar = models.CharField(max_length=256, default=f"{settings.IMAGE_UPLOAD_DIR}/default.png") blog = models.URLField(blank=True, null=True) - mood = models.CharField(max_length=200, blank=True, null=True) - github = models.CharField(max_length=50, blank=True, null=True) - school = models.CharField(max_length=200, blank=True, null=True) - major = models.CharField(max_length=200, blank=True, null=True) - language = models.CharField(max_length=32, blank=True, null=True) + mood = models.CharField(max_length=256, blank=True, null=True) + github = models.CharField(max_length=64, blank=True, null=True) + school = models.CharField(max_length=64, blank=True, null=True) + major = models.CharField(max_length=64, blank=True, null=True) # for ACM accepted_number = models.IntegerField(default=0) # for OI diff --git a/account/serializers.py b/account/serializers.py index 8345fc69..1b29170c 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -6,27 +6,27 @@ from .models import AdminType, ProblemPermission, User, UserProfile class UserLoginSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30) - tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True) + username = serializers.CharField() + password = serializers.CharField() + tfa_code = serializers.CharField(required=False, allow_null=True) class UsernameOrEmailCheckSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30, required=False) - email = serializers.EmailField(max_length=30, required=False) + username = serializers.CharField(required=False) + email = serializers.EmailField(required=False) class UserRegisterSerializer(serializers.Serializer): - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30, min_length=6) - email = serializers.EmailField(max_length=30) - captcha = serializers.CharField(max_length=4, min_length=1) + username = serializers.CharField(max_length=32) + password = serializers.CharField(min_length=6) + email = serializers.EmailField(max_length=64) + captcha = serializers.CharField() class UserChangePasswordSerializer(serializers.Serializer): old_password = serializers.CharField() - new_password = serializers.CharField(max_length=30, min_length=6) - captcha = serializers.CharField(max_length=4, min_length=4) + new_password = serializers.CharField(min_length=6) + captcha = serializers.CharField() class UserSerializer(serializers.ModelSerializer): @@ -58,9 +58,9 @@ class UserInfoSerializer(serializers.ModelSerializer): class EditUserSerializer(serializers.Serializer): id = serializers.IntegerField() - username = serializers.CharField(max_length=30) - password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None) - email = serializers.EmailField(max_length=254) + username = serializers.CharField(max_length=32) + password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None) + email = serializers.EmailField(max_length=64) admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN)) problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN, ProblemPermission.ALL)) @@ -70,29 +70,29 @@ class EditUserSerializer(serializers.Serializer): class EditUserProfileSerializer(serializers.Serializer): - real_name = serializers.CharField(max_length=30, allow_blank=True) - avatar = serializers.CharField(max_length=100, allow_blank=True, required=False) - blog = serializers.URLField(allow_blank=True, required=False) - mood = serializers.CharField(max_length=200, allow_blank=True, required=False) - github = serializers.CharField(max_length=50, allow_blank=True, required=False) - school = serializers.CharField(max_length=200, allow_blank=True, required=False) - major = serializers.CharField(max_length=200, allow_blank=True, required=False) + real_name = serializers.CharField(max_length=32, allow_blank=True) + avatar = serializers.CharField(max_length=256, allow_blank=True, required=False) + blog = serializers.URLField(max_length=256, allow_blank=True, required=False) + mood = serializers.CharField(max_length=256, allow_blank=True, required=False) + github = serializers.CharField(max_length=64, allow_blank=True, required=False) + school = serializers.CharField(max_length=64, allow_blank=True, required=False) + major = serializers.CharField(max_length=64, allow_blank=True, required=False) class ApplyResetPasswordSerializer(serializers.Serializer): email = serializers.EmailField() - captcha = serializers.CharField(max_length=4, min_length=4) + captcha = serializers.CharField() class ResetPasswordSerializer(serializers.Serializer): - token = serializers.CharField(min_length=1, max_length=40) - password = serializers.CharField(min_length=6, max_length=30) - captcha = serializers.CharField(max_length=4, min_length=4) + token = serializers.CharField() + password = serializers.CharField(min_length=6) + captcha = serializers.CharField() class SSOSerializer(serializers.Serializer): - appkey = serializers.CharField(max_length=35) - token = serializers.CharField(max_length=40) + appkey = serializers.CharField() + token = serializers.CharField() class TwoFactorAuthCodeSerializer(serializers.Serializer): diff --git a/account/urls/oj.py b/account/urls/oj.py index b80b34e0..cc716830 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -19,7 +19,6 @@ urlpatterns = [ url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), - url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"), url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"), url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), diff --git a/account/views/oj.py b/account/views/oj.py index aad7c7d4..1225556e 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -12,11 +12,10 @@ from django.utils.timezone import now from django.views.decorators.csrf import ensure_csrf_cookie from otpauth import OtpAuth +from utils.constants import ContestRuleType from options.options import SysOptions from utils.api import APIView, validate_serializer -from utils.cache import default_cache from utils.captcha import Captcha -from utils.constants import CacheKey from utils.shortcuts import rand_str, img2base64, timestamp2utcstr from ..decorators import login_required from ..models import User, UserProfile @@ -38,7 +37,7 @@ class UserProfileAPI(APIView): """ user = request.user if not user.is_authenticated(): - return self.success({}) + return self.success() username = request.GET.get("username") try: if username: @@ -47,8 +46,7 @@ class UserProfileAPI(APIView): user = request.user except User.DoesNotExist: return self.error("User does not exist") - profile = UserProfile.objects.select_related("user").get(user=user) - return self.success(UserProfileSerializer(profile).data) + return self.success(UserProfileSerializer(user.userprofile).data) @validate_serializer(EditUserProfileSerializer) @login_required @@ -71,8 +69,7 @@ class AvatarUploadAPI(APIView): avatar = form.cleaned_data["file"] else: return self.error("Invalid file content") - # 2097152 = 2 * 1024 * 1024 = 2MB - if avatar.size > 2097152: + if avatar.size > 2 * 1024 * 1024: return self.error("Picture is too large") suffix = os.path.splitext(avatar.name)[-1].lower() if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: @@ -83,46 +80,12 @@ class AvatarUploadAPI(APIView): for chunk in avatar: img.write(chunk) user_profile = request.user.userprofile - _, old_avatar = os.path.split(user_profile.avatar) - if old_avatar != "default.png": - os.remove(os.path.join(settings.IMAGE_UPLOAD_DIR_ABS, old_avatar)) - user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}" + user_profile.avatar = f"{settings.IMAGE_UPLOAD_DIR}/{name}" user_profile.save() return self.success("Succeeded") -class SSOAPI(APIView): - @login_required - def get(self, request): - callback = request.GET.get("callback", None) - if not callback: - return self.error("Parameter Error") - token = rand_str() - request.user.auth_token = token - request.user.save() - return self.success({"redirect_url": callback + "?token=" + token, - "callback": callback}) - - @validate_serializer(SSOSerializer) - def post(self, request): - data = request.data - try: - User.objects.get(open_api_appkey=data["appkey"]) - except User.DoesNotExist: - return self.error("Invalid appkey") - try: - user = User.objects.get(auth_token=data["token"]) - user.auth_token = None - user.save() - return self.success({"username": user.username, - "id": user.id, - "admin_type": user.admin_type, - "avatar": user.userprofile.avatar}) - except User.DoesNotExist: - return self.error("User does not exist") - - class TwoFactorAuthAPI(APIView): @login_required def get(self, request): @@ -131,7 +94,7 @@ class TwoFactorAuthAPI(APIView): """ user = request.user if user.two_factor_auth: - return self.error("Already open 2FA") + return self.error("2FA is already turned on") token = rand_str() user.tfa_token = token user.save() @@ -161,7 +124,7 @@ class TwoFactorAuthAPI(APIView): code = request.data["code"] user = request.user if not user.two_factor_auth: - return self.error("Other session have disabled TFA") + return self.error("2FA is already turned off") if OtpAuth(user.tfa_token).valid_totp(code): user.two_factor_auth = False user.save() @@ -198,7 +161,7 @@ class UserLoginAPI(APIView): # None is returned if username or password is wrong if user: if user.is_disabled: - return self.error("Your account have been disabled") + return self.error("Your account has been disabled") if not user.two_factor_auth: auth.login(request, user) return self.success("Succeeded") @@ -218,13 +181,13 @@ class UserLoginAPI(APIView): # todo remove this, only for debug use def get(self, request): auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"])) - return self.success({}) + return self.success() class UserLogoutAPI(APIView): def get(self, request): auth.logout(request) - return self.success({}) + return self.success() class UsernameOrEmailCheck(APIView): @@ -240,11 +203,9 @@ class UsernameOrEmailCheck(APIView): "email": False } if data.get("username"): - if User.objects.filter(username=data["username"]).exists(): - result["username"] = True + result["username"] = User.objects.filter(username=data["username"]).exists() if data.get("email"): - if User.objects.filter(email=data["email"]).exists(): - result["email"] = True + result["email"] = User.objects.filter(email=data["email"]).exists() return self.success(result) @@ -254,17 +215,9 @@ class UserRegisterAPI(APIView): """ User register api """ - config = default_cache.get(CacheKey.website_config) - if config: - config = pickle.loads(config) - else: - config = WebsiteConfig.objects.first() - if not config: - config = WebsiteConfig.objects.create() - default_cache.set(CacheKey.website_config, pickle.dumps(config)) - if not config.allow_register: - return self.error("Register have been disabled by admin") + if not SysOptions.allow_register: + return self.error("Register function has been disabled by admin") data = request.data captcha = Captcha(request) @@ -293,6 +246,7 @@ class UserChangePasswordAPI(APIView): username = request.user.username user = auth.authenticate(username=username, password=data["old_password"]) if user: + # TODO: check tfa? user.set_password(data["new_password"]) user.save() return self.success("Succeeded") @@ -305,7 +259,6 @@ class ApplyResetPasswordAPI(APIView): def post(self, request): data = request.data captcha = Captcha(request) - config = WebsiteConfig.objects.first() if not captcha.check(data["captcha"]): return self.error("Invalid captcha") try: @@ -320,14 +273,14 @@ class ApplyResetPasswordAPI(APIView): user.save() render_data = { "username": user.username, - "website_name": config.name, - "link": f"{config.base_url}/reset-password/{user.reset_password_token}" + "website_name": SysOptions.website_name, + "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" } email_html = render_to_string("reset_password_email.html", render_data) - send_email_async.delay(config.name, + send_email_async.delay(SysOptions.website_name, user.email, user.username, - config.name + " 登录信息找回邮件", + f"{SysOptions.website_name} 登录信息找回邮件", email_html) return self.success("Succeeded") @@ -342,9 +295,9 @@ class ResetPasswordAPI(APIView): try: user = User.objects.get(reset_password_token=data["token"]) except User.DoesNotExist: - return self.error("Token dose not exist") - if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0: - return self.error("Token have expired") + return self.error("Token does not exist") + if user.reset_password_token_expire_time < now(): + return self.error("Token has expired") user.reset_password_token = None user.two_factor_auth = False user.set_password(data["password"]) @@ -356,13 +309,13 @@ class SessionManagementAPI(APIView): @login_required def get(self, request): engine = import_module(settings.SESSION_ENGINE) - SessionStore = engine.SessionStore + session_store = engine.SessionStore current_session = request.session.session_key session_keys = request.user.session_keys result = [] modified = False for key in session_keys[:]: - session = SessionStore(key) + session = session_store(key) # session does not exist or is expiry if not session._session: session_keys.remove(key) @@ -398,12 +351,12 @@ class SessionManagementAPI(APIView): class UserRankAPI(APIView): def get(self, request): rule_type = request.GET.get("rule") - if rule_type not in ["acm", "oi"]: - rule_type = "acm" + if rule_type not in ContestRuleType.choices(): + rule_type = ContestRuleType.ACM profiles = UserProfile.objects.select_related("user")\ .filter(submission_number__gt=0)\ .exclude(user__is_disabled=True) - if rule_type == "acm": + if rule_type == ContestRuleType.ACM: profiles = profiles.order_by("-accepted_number", "submission_number") else: profiles = profiles.order_by("-total_score") diff --git a/announcement/models.py b/announcement/models.py index 186d4ea6..49f57b82 100644 --- a/announcement/models.py +++ b/announcement/models.py @@ -5,7 +5,7 @@ from utils.models import RichTextField class Announcement(models.Model): - title = models.CharField(max_length=50) + title = models.CharField(max_length=64) # HTML content = RichTextField() create_time = models.DateTimeField(auto_now_add=True) diff --git a/announcement/serializers.py b/announcement/serializers.py index 0c0beccc..b660a615 100644 --- a/announcement/serializers.py +++ b/announcement/serializers.py @@ -5,8 +5,8 @@ from .models import Announcement class CreateAnnouncementSerializer(serializers.Serializer): - title = serializers.CharField(max_length=50) - content = serializers.CharField(max_length=10000) + title = serializers.CharField(max_length=64) + content = serializers.CharField(max_length=1024 * 1024 * 8) visible = serializers.BooleanField() @@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer): class EditAnnouncementSerializer(serializers.Serializer): id = serializers.IntegerField() - title = serializers.CharField(max_length=50) - content = serializers.CharField(max_length=10000) + title = serializers.CharField(max_length=64) + content = serializers.CharField(max_length=1024 * 1024 * 8) visible = serializers.BooleanField() diff --git a/conf/models.py b/conf/models.py index 86248dbb..4c6348d5 100644 --- a/conf/models.py +++ b/conf/models.py @@ -3,16 +3,16 @@ from django.utils import timezone class JudgeServer(models.Model): - hostname = models.CharField(max_length=64) + hostname = models.CharField(max_length=128) ip = models.CharField(max_length=32, blank=True, null=True) - judger_version = models.CharField(max_length=24) + judger_version = models.CharField(max_length=32) cpu_core = models.IntegerField() memory_usage = models.FloatField() cpu_usage = models.FloatField() last_heartbeat = models.DateTimeField() create_time = models.DateTimeField(auto_now_add=True) task_number = models.IntegerField(default=0) - service_url = models.CharField(max_length=128, blank=True, null=True) + service_url = models.CharField(max_length=256, blank=True, null=True) @property def status(self): diff --git a/conf/serializers.py b/conf/serializers.py index 09f9940d..7f0cf575 100644 --- a/conf/serializers.py +++ b/conf/serializers.py @@ -21,9 +21,9 @@ class TestSMTPConfigSerializer(serializers.Serializer): class CreateEditWebsiteConfigSerializer(serializers.Serializer): website_base_url = serializers.CharField(max_length=128) - website_name = serializers.CharField(max_length=32) - website_name_shortcut = serializers.CharField(max_length=32) - website_footer = serializers.CharField(max_length=1024) + website_name = serializers.CharField(max_length=64) + website_name_shortcut = serializers.CharField(max_length=64) + website_footer = serializers.CharField(max_length=1024 * 1024) allow_register = serializers.BooleanField() submission_list_show_all = serializers.BooleanField() @@ -39,10 +39,10 @@ class JudgeServerSerializer(serializers.ModelSerializer): class JudgeServerHeartbeatSerializer(serializers.Serializer): - hostname = serializers.CharField(max_length=64) - judger_version = serializers.CharField(max_length=24) + hostname = serializers.CharField(max_length=128) + judger_version = serializers.CharField(max_length=32) cpu_core = serializers.IntegerField(min_value=1) memory = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100) action = serializers.ChoiceField(choices=("heartbeat", )) - service_url = serializers.CharField(max_length=128, required=False) + service_url = serializers.CharField(max_length=256, required=False) diff --git a/contest/models.py b/contest/models.py index 38b23568..d66d4007 100644 --- a/contest/models.py +++ b/contest/models.py @@ -2,26 +2,11 @@ from django.db import models from django.utils.timezone import now from jsonfield import JSONField +from utils.constants import ContestStatus, ContestRuleType, ContestType from account.models import User, AdminType from utils.models import RichTextField -class ContestType(object): - PUBLIC_CONTEST = "Public" - PASSWORD_PROTECTED_CONTEST = "Password Protected" - - -class ContestStatus(object): - CONTEST_NOT_START = "1" - CONTEST_ENDED = "-1" - CONTEST_UNDERWAY = "0" - - -class ContestRuleType(object): - ACM = "ACM" - OI = "OI" - - class Contest(models.Model): title = models.CharField(max_length=40) description = RichTextField() diff --git a/contest/views/oj.py b/contest/views/oj.py index a3195137..c4fd4fc0 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -1,13 +1,11 @@ -import pickle from django.utils.timezone import now from django.core.cache import cache from utils.api import APIView, validate_serializer -from utils.cache import default_cache from utils.constants import CacheKey from account.decorators import login_required, check_contest_permission -from ..models import ContestAnnouncement, Contest, ContestStatus, ContestRuleType -from ..models import OIContestRank, ACMContestRank +from utils.constants import ContestRuleType, ContestType, ContestStatus +from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestSerializer, ContestPasswordVerifySerializer from ..serializers import OIContestRankSerializer, ACMContestRankSerializer diff --git a/utils/constants.py b/utils/constants.py index 14f13afd..be7057a6 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -1,3 +1,26 @@ +class Choices: + @classmethod + def choices(cls): + d = cls.__dict__ + return [d[item] for item in d.keys() if not item.startswith("__")] + + +class ContestType: + PUBLIC_CONTEST = "Public" + PASSWORD_PROTECTED_CONTEST = "Password Protected" + + +class ContestStatus: + CONTEST_NOT_START = "1" + CONTEST_ENDED = "-1" + CONTEST_UNDERWAY = "0" + + +class ContestRuleType(Choices): + ACM = "ACM" + OI = "OI" + + class CacheKey: waiting_queue = "waiting_queue" contest_rank_cache = "contest_rank_cache_" diff --git a/utils/xss_filter.py b/utils/xss_filter.py index d29495b6..34d65a8b 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7 浏览器版本:IE7+ 或其他浏览器,无法防御IE6及以下版本浏览器中的XSS """ import re - -try: - from html.parser import HTMLParser -except: - from HTMLParser import HTMLParser +import copy +from html.parser import HTMLParser class XssHtml(HTMLParser): @@ -163,7 +160,7 @@ class XssHtml(HTMLParser): else: other = [] if attrs: - for (key, value) in attrs.items(): + for key, value in copy.deepcopy(attrs).items(): if key not in self.common_attrs + other: del attrs[key] return attrs From 93bd77d8d83249bbcc5b3908977de66fd03459d0 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Fri, 6 Oct 2017 17:46:14 +0800 Subject: [PATCH 4/8] bug fixes --- account/middleware.py | 18 +++-------- account/migrations/0001_initial.py | 2 +- account/models.py | 2 +- account/serializers.py | 2 +- account/tests.py | 12 ++----- account/urls/oj.py | 2 +- account/views/oj.py | 10 +++--- contest/models.py | 3 +- contest/views/oj.py | 2 +- judge/dispatcher.py | 28 +++++++--------- oj/local_settings.py | 8 +++-- oj/settings.py | 51 ++++++++++++------------------ options/options.py | 8 ++--- problem/serializers.py | 2 ++ submission/views/oj.py | 4 +-- utils/cache.py | 31 +++++++++++++++--- 16 files changed, 91 insertions(+), 94 deletions(-) diff --git a/account/middleware.py b/account/middleware.py index 9141d534..b674346d 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -1,10 +1,5 @@ -import time -import pytz - -from django.contrib import auth -from django.utils import timezone -from django.utils.translation import ugettext as _ from django.db import connection +from django.utils.timezone import now from django.utils.deprecation import MiddlewareMixin from utils.api import JSONResponse @@ -14,14 +9,11 @@ class SessionRecordMiddleware(MiddlewareMixin): def process_request(self, request): if request.user.is_authenticated(): session = request.session - ip = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP") - user_agent = request.META.get("HTTP_USER_AGENT", "") - _ip = session.setdefault("ip", ip) - _user_agent = session.setdefault("user_agent", user_agent) - if ip != _ip or user_agent != _user_agent: - session.modified = True + session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") + session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP") + session["last_activity"] = now() user_sessions = request.user.session_keys - if request.session.session_key not in user_sessions: + if session.session_key not in user_sessions: user_sessions.append(session.session_key) request.user.save() diff --git a/account/migrations/0001_initial.py b/account/migrations/0001_initial.py index a96776e0..e1e588ee 100644 --- a/account/migrations/0001_initial.py +++ b/account/migrations/0001_initial.py @@ -50,7 +50,7 @@ class Migration(migrations.Migration): fields=[ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('problems_status', jsonfield.fields.JSONField(default={})), - ('avatar', models.CharField(default=account.models._default_avatar, max_length=50)), + ('avatar', models.CharField(default="default.png", max_length=50)), ('blog', models.URLField(blank=True, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)), ('accepted_problem_number', models.IntegerField(default=0)), diff --git a/account/models.py b/account/models.py index 2db0e55b..3ba07d9a 100644 --- a/account/models.py +++ b/account/models.py @@ -72,7 +72,7 @@ class UserProfile(models.Model): oi_problems_status = JSONField(default={}) real_name = models.CharField(max_length=32, blank=True, null=True) - avatar = models.CharField(max_length=256, default=f"{settings.IMAGE_UPLOAD_DIR}/default.png") + avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png") blog = models.URLField(blank=True, null=True) mood = models.CharField(max_length=256, blank=True, null=True) github = models.CharField(max_length=64, blank=True, null=True) diff --git a/account/serializers.py b/account/serializers.py index 1b29170c..aa9675a3 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -26,7 +26,6 @@ class UserRegisterSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer): old_password = serializers.CharField() new_password = serializers.CharField(min_length=6) - captcha = serializers.CharField() class UserSerializer(serializers.ModelSerializer): @@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer): class Meta: model = UserProfile + fields = "__all__" class UserInfoSerializer(serializers.ModelSerializer): diff --git a/account/tests.py b/account/tests.py index 7331a0ee..906ef28b 100644 --- a/account/tests.py +++ b/account/tests.py @@ -8,11 +8,9 @@ from otpauth import OtpAuth from utils.api.tests import APIClient, APITestCase from utils.shortcuts import rand_str -from utils.cache import default_cache -from utils.constants import CacheKey +from options.options import SysOptions from .models import AdminType, ProblemPermission, User -from conf.models import WebsiteConfig class PermissionDecoratorTest(APITestCase): @@ -157,13 +155,9 @@ class UserRegisterAPITest(CaptchaTest): self.data = {"username": "test_user", "password": "testuserpassword", "real_name": "real_name", "email": "test@qduoj.com", "captcha": self._set_captcha(self.client.session)} - # clea cache in redis - default_cache.delete(CacheKey.website_config) def test_website_config_limit(self): - website = WebsiteConfig.objects.create() - website.allow_register = False - website.save() + SysOptions.allow_register = False resp = self.client.post(self.register_url, data=self.data) self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"}) @@ -247,7 +241,6 @@ class TwoFactorAuthAPITest(APITestCase): def setUp(self): self.url = self.reverse("two_factor_auth_api") self.create_user("test", "test123") - self.create_website_config() def _get_tfa_code(self): user = User.objects.first() @@ -295,7 +288,6 @@ class ApplyResetPasswordAPITest(CaptchaTest): user.email = "test@oj.com" user.save() self.url = self.reverse("apply_reset_password_api") - self.create_website_config() self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} def _refresh_captcha(self): diff --git a/account/urls/oj.py b/account/urls/oj.py index cc716830..aacbb59d 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -3,7 +3,7 @@ from django.conf.urls import url from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, UserChangePasswordAPI, UserRegisterAPI, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, - SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, + AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) from utils.captcha.views import CaptchaAPIView diff --git a/account/views/oj.py b/account/views/oj.py index 1225556e..5f72156e 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -1,5 +1,4 @@ import os -import pickle from datetime import timedelta from importlib import import_module @@ -16,15 +15,14 @@ from utils.constants import ContestRuleType from options.options import SysOptions from utils.api import APIView, validate_serializer from utils.captcha import Captcha -from utils.shortcuts import rand_str, img2base64, timestamp2utcstr +from utils.shortcuts import rand_str, img2base64, datetime2str from ..decorators import login_required from ..models import User, UserProfile from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, UserChangePasswordSerializer, UserLoginSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer, RankInfoSerializer) -from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer, - UserProfileSerializer, +from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer, EditUserProfileSerializer, AvatarUploadForm) from ..tasks import send_email_async @@ -81,7 +79,7 @@ class AvatarUploadAPI(APIView): img.write(chunk) user_profile = request.user.userprofile - user_profile.avatar = f"{settings.IMAGE_UPLOAD_DIR}/{name}" + user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}" user_profile.save() return self.success("Succeeded") @@ -327,7 +325,7 @@ class SessionManagementAPI(APIView): s["current_session"] = True s["ip"] = session["ip"] s["user_agent"] = session["user_agent"] - s["last_activity"] = timestamp2utcstr(session["last_activity"]) + s["last_activity"] = datetime2str(session["last_activity"]) s["session_key"] = key result.append(s) if modified: diff --git a/contest/models.py b/contest/models.py index d66d4007..10581ccb 100644 --- a/contest/models.py +++ b/contest/models.py @@ -1,8 +1,9 @@ +from utils.constants import ContestRuleType # noqa from django.db import models from django.utils.timezone import now from jsonfield import JSONField -from utils.constants import ContestStatus, ContestRuleType, ContestType +from utils.constants import ContestStatus, ContestType from account.models import User, AdminType from utils.models import RichTextField diff --git a/contest/views/oj.py b/contest/views/oj.py index c4fd4fc0..6811ef61 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -4,7 +4,7 @@ from utils.api import APIView, validate_serializer from utils.constants import CacheKey from account.decorators import login_required, check_contest_permission -from utils.constants import ContestRuleType, ContestType, ContestStatus +from utils.constants import ContestRuleType, ContestStatus from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestSerializer, ContestPasswordVerifySerializer diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 597ed7d1..29b4475f 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -14,7 +14,7 @@ from judge.languages import languages from options.options import SysOptions from problem.models import Problem, ProblemRuleType from submission.models import JudgeStatus, Submission -from utils.cache import judge_cache, default_cache +from utils.cache import cache from utils.constants import CacheKey logger = logging.getLogger(__name__) @@ -22,31 +22,28 @@ logger = logging.getLogger(__name__) # 继续处理在队列中的问题 def process_pending_task(): - if judge_cache.llen(CacheKey.waiting_queue): + if cache.llen(CacheKey.waiting_queue): # 防止循环引入 from judge.tasks import judge_task - data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8")) + data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8")) judge_task.delay(**data) class JudgeDispatcher(object): def __init__(self, submission_id, problem_id): self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() - self.redis_conn = judge_cache - self.submission = Submission.objects.get(pk=submission_id) + self.submission = Submission.objects.get(id=submission_id) self.contest_id = self.submission.contest_id if self.contest_id: - self.problem = Problem.objects.select_related("contest") \ - .get(id=problem_id, contest_id=self.contest_id) + self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id) self.contest = self.problem.contest else: self.problem = Problem.objects.get(id=problem_id) def _request(self, url, data=None): - kwargs = {"headers": {"X-Judge-Server-Token": self.token, - "Content-Type": "application/json"}} + kwargs = {"headers": {"X-Judge-Server-Token": self.token}} if data: - kwargs["data"] = json.dumps(data) + kwargs["json"] = data try: return requests.post(url, **kwargs).json() except Exception as e: @@ -55,7 +52,6 @@ class JudgeDispatcher(object): @staticmethod def choose_judge_server(): with transaction.atomic(): - # TODO: use more reasonable way servers = JudgeServer.objects.select_for_update().all().order_by("task_number") servers = [s for s in servers if s.status == "normal"] if servers: @@ -65,10 +61,10 @@ class JudgeDispatcher(object): return server @staticmethod - def release_judge_res(judge_server_id): + def release_judge_server(judge_server_id): with transaction.atomic(): # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下 - server = JudgeServer.objects.select_for_update().get(id=judge_server_id) + server = JudgeServer.objects.get(id=judge_server_id) server.used_instance_number = F("task_number") - 1 server.save() @@ -94,7 +90,7 @@ class JudgeDispatcher(object): server = self.choose_judge_server() if not server: data = {"submission_id": self.submission.id, "problem_id": self.problem.id} - self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data)) + cache.lpush(CacheKey.waiting_queue, json.dumps(data)) return sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0] @@ -138,7 +134,7 @@ class JudgeDispatcher(object): else: self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.save() - self.release_judge_res(server.id) + self.release_judge_server(server.id) self.update_problem_status() if self.contest_id: @@ -223,7 +219,7 @@ class JudgeDispatcher(object): if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY: return if self.contest.real_time_rank: - default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) + cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) with transaction.atomic(): if self.contest.rule_type == ContestRuleType.ACM: acm_rank, _ = ACMContestRank.objects.select_for_update(). \ diff --git a/oj/local_settings.py b/oj/local_settings.py index cee68f25..bbe2398f 100644 --- a/oj/local_settings.py +++ b/oj/local_settings.py @@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), + 'ENGINE': 'django.db.backends.postgresql_psycopg2', + 'HOST': '127.0.0.1', + 'PORT': 5433, + 'NAME': "onlinejudge", + 'USER': "onlinejudge", + 'PASSWORD': 'onlinejudge' } } diff --git a/oj/settings.py b/oj/settings.py index 672bc6f3..ac972c79 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -61,7 +61,6 @@ MIDDLEWARE_CLASSES = ( 'account.middleware.SessionRecordMiddleware', # 'account.middleware.LogSqlMiddleware', ) -SESSION_ENGINE = 'django.contrib.sessions.backends.cache' ROOT_URLCONF = 'oj.urls' TEMPLATES = [ @@ -166,41 +165,33 @@ LOGGING = { } -REST_FRAMEWORK = { - 'TEST_REQUEST_DEFAULT_FORMAT': 'json', - 'DEFAULT_RENDERER_CLASSES': ( - 'rest_framework.renderers.JSONRenderer', - ) -} +REDIS_URL = "redis://127.0.0.1:6379" -CACHE_JUDGE_QUEUE = "judge_queue" -CACHE_THROTTLING = "throttling" +def redis_config(db): + def make_key(key, key_prefix, version): + return key + + return { + "BACKEND": "utils.cache.MyRedisCache", + "LOCATION": f"{REDIS_URL}/{db}", + "TIMEOUT": None, + "KEY_PREFIX": "", + "KEY_FUNCTION": make_key + } CACHES = { - "default": { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/1", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - }, - CACHE_JUDGE_QUEUE: { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/2", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - }, - CACHE_THROTTLING: { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": "redis://127.0.0.1:6379/3", - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - } - } + "default": redis_config(db=1) } + +CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2" +CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180 + +SESSION_ENGINE = "django.contrib.sessions.backends.cache" +SESSION_CACHE_ALIAS = "default" + + # For celery REDIS_QUEUE = { "host": "127.0.0.1", diff --git a/options/options.py b/options/options.py index 51beb539..b2d76f16 100644 --- a/options/options.py +++ b/options/options.py @@ -113,15 +113,15 @@ class _SysOptionsMeta(type): @property def website_base_url(cls): return cls._get_option(OptionKeys.website_base_url) - + @website_base_url.setter def website_base_url(cls, value): cls._set_option(OptionKeys.website_base_url, value) - + @property def website_name(cls): return cls._get_option(OptionKeys.website_name) - + @website_name.setter def website_name(cls, value): cls._set_option(OptionKeys.website_name, value) @@ -173,7 +173,7 @@ class _SysOptionsMeta(type): @judge_server_token.setter def judge_server_token(cls, value): cls._set_option(OptionKeys.judge_server_token, value) - + class SysOptions(metaclass=_SysOptionsMeta): pass diff --git a/problem/serializers.py b/problem/serializers.py index d9e49195..4856b6d6 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer): class TagSerializer(serializers.ModelSerializer): class Meta: model = ProblemTag + fields = "__all__" class BaseProblemSerializer(serializers.ModelSerializer): @@ -88,6 +89,7 @@ class BaseProblemSerializer(serializers.ModelSerializer): class ProblemAdminSerializer(BaseProblemSerializer): class Meta: model = Problem + fields = "__all__" class ContestProblemAdminSerializer(BaseProblemSerializer): diff --git a/submission/views/oj.py b/submission/views/oj.py index 274418a7..e9672a0a 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -5,16 +5,16 @@ from problem.models import Problem, ProblemRuleType from contest.models import Contest, ContestStatus, ContestRuleType from utils.api import APIView, validate_serializer from utils.throttling import TokenBucket, BucketController +from utils.cache import cache from ..models import Submission from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer -from utils.cache import throttling_cache def _submit(response, user, problem_id, language, code, contest_id): # TODO: 预设默认值,需修改 controller = BucketController(user_id=user.id, - redis_conn=throttling_cache, + redis_conn=cache, default_capacity=30) bucket = TokenBucket(fill_rate=10, capacity=20, last_capacity=controller.last_capacity, diff --git a/utils/cache.py b/utils/cache.py index c77131f6..ed9059b1 100644 --- a/utils/cache.py +++ b/utils/cache.py @@ -1,6 +1,27 @@ -from django.conf import settings -from django_redis import get_redis_connection +from django.core.cache import cache, caches # noqa +from django.conf import settings # noqa -judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE) -throttling_cache = get_redis_connection(settings.CACHE_THROTTLING) -default_cache = get_redis_connection("default") +from django_redis.cache import RedisCache +from django_redis.client.default import DefaultClient + + +class MyRedisClient(DefaultClient): + def __getattr__(self, item): + client = self.get_client(write=True) + return getattr(client, item) + + def redis_incr(self, key, count=1): + """ + django 默认的 incr 在 key 不存在时候会抛异常 + """ + client = self.get_client(write=True) + return client.incr(key, count) + + +class MyRedisCache(RedisCache): + def __init__(self, server, params): + super().__init__(server, params) + self._client_cls = MyRedisClient + + def __getattr__(self, item): + return getattr(self.client, item) From 080ecf1bcf81dc9f206ada138e130c0ebc595337 Mon Sep 17 00:00:00 2001 From: zema1 Date: Wed, 11 Oct 2017 21:43:29 +0800 Subject: [PATCH 5/8] migrate to postgres json field --- account/migrations/0008_auto_20171011_1214.py | 105 ++++++++++++++++++ account/models.py | 23 ++-- .../migrations/0002_auto_20171011_1214.py | 20 ++++ conf/migrations/0002_auto_20171011_1214.py | 39 +++++++ contest/migrations/0006_auto_20171011_1214.py | 26 +++++ contest/models.py | 6 +- judge/dispatcher.py | 18 +-- options/migrations/0002_auto_20171011_1214.py | 21 ++++ options/models.py | 2 +- problem/migrations/0009_auto_20171011_1214.py | 41 +++++++ problem/models.py | 4 +- problem/views/oj.py | 4 +- .../migrations/0008_auto_20171011_1214.py | 26 +++++ submission/models.py | 6 +- utils/models.py | 1 + 15 files changed, 315 insertions(+), 27 deletions(-) create mode 100644 account/migrations/0008_auto_20171011_1214.py create mode 100644 announcement/migrations/0002_auto_20171011_1214.py create mode 100644 conf/migrations/0002_auto_20171011_1214.py create mode 100644 contest/migrations/0006_auto_20171011_1214.py create mode 100644 options/migrations/0002_auto_20171011_1214.py create mode 100644 problem/migrations/0009_auto_20171011_1214.py create mode 100644 submission/migrations/0008_auto_20171011_1214.py diff --git a/account/migrations/0008_auto_20171011_1214.py b/account/migrations/0008_auto_20171011_1214.py new file mode 100644 index 00000000..7426a1f1 --- /dev/null +++ b/account/migrations/0008_auto_20171011_1214.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('account', '0007_auto_20170920_0254'), + ] + + operations = [ + migrations.RemoveField( + model_name='userprofile', + name='language', + ), + migrations.AlterField( + model_name='user', + name='admin_type', + field=models.CharField(default='Regular User', max_length=32), + ), + migrations.AlterField( + model_name='user', + name='auth_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='email', + field=models.EmailField(max_length=64, null=True), + ), + migrations.AlterField( + model_name='user', + name='open_api_appkey', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='problem_permission', + field=models.CharField(default='None', max_length=32), + ), + migrations.AlterField( + model_name='user', + name='reset_password_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='session_keys', + field=django.contrib.postgres.fields.jsonb.JSONField(default=list), + ), + migrations.AlterField( + model_name='user', + name='tfa_token', + field=models.CharField(max_length=32, null=True), + ), + migrations.AlterField( + model_name='user', + name='username', + field=models.CharField(max_length=32, unique=True), + ), + migrations.AlterField( + model_name='userprofile', + name='acm_problems_status', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='userprofile', + name='avatar', + field=models.CharField(default='/static/avatar/default.png', max_length=256), + ), + migrations.AlterField( + model_name='userprofile', + name='github', + field=models.CharField(blank=True, max_length=64, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='major', + field=models.CharField(blank=True, max_length=64, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='mood', + field=models.CharField(blank=True, max_length=256, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='oi_problems_status', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='userprofile', + name='real_name', + field=models.CharField(blank=True, max_length=32, null=True), + ), + migrations.AlterField( + model_name='userprofile', + name='school', + field=models.CharField(blank=True, max_length=64, null=True), + ), + ] diff --git a/account/models.py b/account/models.py index 3ba07d9a..0b0e6ca5 100644 --- a/account/models.py +++ b/account/models.py @@ -1,7 +1,7 @@ from django.contrib.auth.models import AbstractBaseUser from django.conf import settings from django.db import models -from jsonfield import JSONField +from utils.models import JSONField class AdminType(object): @@ -36,7 +36,7 @@ class User(AbstractBaseUser): auth_token = models.CharField(max_length=32, null=True) two_factor_auth = models.BooleanField(default=False) tfa_token = models.CharField(max_length=32, null=True) - session_keys = JSONField(default=[]) + session_keys = JSONField(default=list) # open api key open_api = models.BooleanField(default=False) open_api_appkey = models.CharField(max_length=32, null=True) @@ -65,11 +65,20 @@ class User(AbstractBaseUser): class UserProfile(models.Model): user = models.OneToOneField(User) - # Store user problem solution status with json string format - # {problems: {1: JudgeStatus.ACCEPTED}, contest_problems: {1: JudgeStatus.ACCEPTED}}, record problem_id and status - acm_problems_status = JSONField(default={}) - # {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score - oi_problems_status = JSONField(default={}) + # acm_problems_status examples: + # { + # "problems": { + # "1": { + # "status": JudgeStatus.ACCEPTED, + # "_id": "1000" + # } + # }, + # "contest_problems": { + # } + # } + acm_problems_status = JSONField(default=dict) + # like acm_problems_status, merely add "score" field + oi_problems_status = JSONField(default=dict) real_name = models.CharField(max_length=32, blank=True, null=True) avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png") diff --git a/announcement/migrations/0002_auto_20171011_1214.py b/announcement/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ffe4c969 --- /dev/null +++ b/announcement/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('announcement', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='announcement', + name='title', + field=models.CharField(max_length=64), + ), + ] diff --git a/conf/migrations/0002_auto_20171011_1214.py b/conf/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ef355b50 --- /dev/null +++ b/conf/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('conf', '0001_initial'), + ] + + operations = [ + migrations.DeleteModel( + name='JudgeServerToken', + ), + migrations.DeleteModel( + name='SMTPConfig', + ), + migrations.DeleteModel( + name='WebsiteConfig', + ), + migrations.AlterField( + model_name='judgeserver', + name='hostname', + field=models.CharField(max_length=128), + ), + migrations.AlterField( + model_name='judgeserver', + name='judger_version', + field=models.CharField(max_length=32), + ), + migrations.AlterField( + model_name='judgeserver', + name='service_url', + field=models.CharField(blank=True, max_length=256, null=True), + ), + ] diff --git a/contest/migrations/0006_auto_20171011_1214.py b/contest/migrations/0006_auto_20171011_1214.py new file mode 100644 index 00000000..d429742f --- /dev/null +++ b/contest/migrations/0006_auto_20171011_1214.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('contest', '0005_auto_20170823_0918'), + ] + + operations = [ + migrations.AlterField( + model_name='acmcontestrank', + name='submission_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='oicontestrank', + name='submission_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + ] diff --git a/contest/models.py b/contest/models.py index 10581ccb..42519988 100644 --- a/contest/models.py +++ b/contest/models.py @@ -1,7 +1,7 @@ from utils.constants import ContestRuleType # noqa from django.db import models from django.utils.timezone import now -from jsonfield import JSONField +from utils.models import JSONField from utils.constants import ContestStatus, ContestType from account.models import User, AdminType @@ -65,7 +65,7 @@ class ACMContestRank(AbstractContestRank): total_time = models.IntegerField(default=0) # {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}} # key is problem id - submission_info = JSONField(default={}) + submission_info = JSONField(default=dict) class Meta: db_table = "acm_contest_rank" @@ -75,7 +75,7 @@ class OIContestRank(AbstractContestRank): total_score = models.IntegerField(default=0) # {23: 333}} # key is problem id, value is current score - submission_info = JSONField(default={}) + submission_info = JSONField(default=dict) class Meta: db_table = "oi_contest_rank" diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 29b4475f..e457b29c 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -186,13 +186,10 @@ class JudgeDispatcher(object): # update user_profile if problem_id not in acm_problems_status: - acm_problems_status[problem_id] = self.submission.result + acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id} # skip if the problem has been accepted - elif acm_problems_status[problem_id] != JudgeStatus.ACCEPTED: - if self.submission.result == JudgeStatus.ACCEPTED: - acm_problems_status[problem_id] = JudgeStatus.ACCEPTED - else: - acm_problems_status[problem_id] = self.submission.result + elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED: + acm_problems_status[problem_id]["status"] = self.submission.result user_profile.acm_problems_status[key] = acm_problems_status else: @@ -204,11 +201,14 @@ class JudgeDispatcher(object): # update user_profile if problem_id not in oi_problems_status: user_profile.add_score(score) - oi_problems_status[problem_id] = score + oi_problems_status[problem_id] = {"status": self.submission.result, + "_id": self.problem._id, + "score": score} else: # minus last time score, add this time score - user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]) - oi_problems_status[problem_id] = score + user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]["score"]) + oi_problems_status[problem_id]["score"] = score + oi_problems_status[problem_id]["status"] = self.submission.result user_profile.oi_problems_status[key] = oi_problems_status problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"]) diff --git a/options/migrations/0002_auto_20171011_1214.py b/options/migrations/0002_auto_20171011_1214.py new file mode 100644 index 00000000..ee52ffa4 --- /dev/null +++ b/options/migrations/0002_auto_20171011_1214.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('options', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='sysoptions', + name='value', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + ] diff --git a/options/models.py b/options/models.py index 6d9ac75c..04dee5e2 100644 --- a/options/models.py +++ b/options/models.py @@ -1,5 +1,5 @@ from django.db import models -from jsonfield import JSONField +from utils.models import JSONField class SysOptions(models.Model): diff --git a/problem/migrations/0009_auto_20171011_1214.py b/problem/migrations/0009_auto_20171011_1214.py new file mode 100644 index 00000000..7073b8f9 --- /dev/null +++ b/problem/migrations/0009_auto_20171011_1214.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('problem', '0008_auto_20170923_1318'), + ] + + operations = [ + migrations.AlterField( + model_name='problem', + name='languages', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='samples', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='statistic_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='problem', + name='template', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + migrations.AlterField( + model_name='problem', + name='test_case_score', + field=django.contrib.postgres.fields.jsonb.JSONField(), + ), + ] diff --git a/problem/models.py b/problem/models.py index 0e9c5e2f..72f3fc9d 100644 --- a/problem/models.py +++ b/problem/models.py @@ -1,5 +1,5 @@ from django.db import models -from jsonfield import JSONField +from utils.models import JSONField from account.models import User from contest.models import Contest @@ -66,7 +66,7 @@ class Problem(models.Model): submission_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0) # ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count - statistic_info = JSONField(default={}) + statistic_info = JSONField(default=dict) class Meta: db_table = "problem" diff --git a/problem/views/oj.py b/problem/views/oj.py index 0a6b066b..6764249d 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -55,9 +55,9 @@ class ProblemAPI(APIView): oi_problems_status = profile.oi_problems_status.get("problems", {}) for problem in data["results"]: if problem["rule_type"] == ProblemRuleType.ACM: - problem["my_status"] = acm_problems_status.get(str(problem["id"]), None) + problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") else: - problem["my_status"] = oi_problems_status.get(str(problem["id"]), None) + problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") return self.success(data) diff --git a/submission/migrations/0008_auto_20171011_1214.py b/submission/migrations/0008_auto_20171011_1214.py new file mode 100644 index 00000000..1c585d8a --- /dev/null +++ b/submission/migrations/0008_auto_20171011_1214.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-10-11 12:14 +from __future__ import unicode_literals + +import django.contrib.postgres.fields.jsonb +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('submission', '0007_auto_20170923_1318'), + ] + + operations = [ + migrations.AlterField( + model_name='submission', + name='info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + migrations.AlterField( + model_name='submission', + name='statistic_info', + field=django.contrib.postgres.fields.jsonb.JSONField(default=dict), + ), + ] diff --git a/submission/models.py b/submission/models.py index f4cb2f2e..ef408935 100644 --- a/submission/models.py +++ b/submission/models.py @@ -1,5 +1,5 @@ from django.db import models -from jsonfield import JSONField +from utils.models import JSONField from account.models import AdminType from problem.models import Problem from contest.models import Contest @@ -31,12 +31,12 @@ class Submission(models.Model): code = models.TextField() result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING) # 判题结果的详细信息 - info = JSONField(default={}) + info = JSONField(default=dict) language = models.CharField(max_length=20) shared = models.BooleanField(default=False) # 存储该提交所用时间和内存值,方便提交列表显示 # {time_cost: "", memory_cost: "", err_info: "", score: 0} - statistic_info = JSONField(default={}) + statistic_info = JSONField(default=dict) def check_user_permission(self, user): return self.user_id == user.id or \ diff --git a/utils/models.py b/utils/models.py index b651aa2d..3c114522 100644 --- a/utils/models.py +++ b/utils/models.py @@ -1,3 +1,4 @@ +from django.contrib.postgres.fields import JSONField # NOQA from django.db import models from utils.xss_filter import XssHtml From 2c5a1e42bf5bc866067edccc9aa30b278c73a9f2 Mon Sep 17 00:00:00 2001 From: zema1 Date: Sun, 15 Oct 2017 18:36:55 +0800 Subject: [PATCH 6/8] support share submission --- account/models.py | 4 +++ problem/models.py | 2 +- problem/views/oj.py | 56 ++++++++++++++++++++++++++------------- submission/models.py | 10 ++++--- submission/serializers.py | 9 +++++-- submission/views/oj.py | 39 ++++++++++++++++++++------- 6 files changed, 85 insertions(+), 35 deletions(-) diff --git a/account/models.py b/account/models.py index 0b0e6ca5..6ef7f71f 100644 --- a/account/models.py +++ b/account/models.py @@ -74,6 +74,10 @@ class UserProfile(models.Model): # } # }, # "contest_problems": { + # "1": { + # "status": JudgeStatus.ACCEPTED, + # "_id": "1000" + # } # } # } acm_problems_status = JSONField(default=dict) diff --git a/problem/models.py b/problem/models.py index 72f3fc9d..90537932 100644 --- a/problem/models.py +++ b/problem/models.py @@ -65,7 +65,7 @@ class Problem(models.Model): total_score = models.IntegerField(default=0, blank=True) submission_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0) - # ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count + # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count statistic_info = JSONField(default=dict) class Meta: diff --git a/problem/views/oj.py b/problem/views/oj.py index 6764249d..321b1bb9 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -13,6 +13,25 @@ class ProblemTagAPI(APIView): class ProblemAPI(APIView): + @staticmethod + def _add_problem_status(request, queryset_values): + if request.user.is_authenticated(): + profile = request.user.userprofile + acm_problems_status = profile.acm_problems_status.get("problems", {}) + oi_problems_status = profile.oi_problems_status.get("problems", {}) + # paginate data + results = queryset_values.get("results") + if results: + problems = results + else: + problems = [queryset_values,] + + for problem in problems: + if problem["rule_type"] == ProblemRuleType.ACM: + problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") + else: + problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") + def get(self, request): # 问题详情页 problem_id = request.GET.get("problem_id") @@ -20,7 +39,9 @@ class ProblemAPI(APIView): try: problem = Problem.objects.select_related("created_by")\ .get(_id=problem_id, contest_id__isnull=True, visible=True) - return self.success(ProblemSerializer(problem).data) + problem_data = ProblemSerializer(problem).data + self._add_problem_status(request, problem_data) + return self.success(problem_data) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -49,19 +70,21 @@ class ProblemAPI(APIView): problems = problems.filter(difficulty=difficulty) # 根据profile 为做过的题目添加标记 data = self.paginate_data(request, problems, ProblemSerializer) - if request.user.id: - profile = request.user.userprofile - acm_problems_status = profile.acm_problems_status.get("problems", {}) - oi_problems_status = profile.oi_problems_status.get("problems", {}) - for problem in data["results"]: - if problem["rule_type"] == ProblemRuleType.ACM: - problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") - else: - problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status") + self._add_problem_status(request, data) return self.success(data) class ContestProblemAPI(APIView): + def _add_problem_status(self, request, queryset_values): + if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI: + profile = request.user.userprofile + if self.contest.rule_type == ContestRuleType.ACM: + problems_status = profile.acm_problems_status.get("contest_problems", {}) + else: + problems_status = profile.oi_problems_status.get("contest_problems", {}) + for problem in queryset_values: + problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") + @check_contest_permission def get(self, request): problem_id = request.GET.get("problem_id") @@ -72,17 +95,12 @@ class ContestProblemAPI(APIView): visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist.") - return self.success(ContestProblemSerializer(problem).data) + problem_data = ContestProblemSerializer(problem).data + self._add_problem_status(request, problem_data) + return self.success(problem_data) contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) # 根据profile, 为做过的题目添加标记 data = ContestProblemSerializer(contest_problems, many=True).data - if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI: - profile = request.user.userprofile - if self.contest.rule_type == ContestRuleType.ACM: - problems_status = profile.acm_problems_status.get("contest_problems", {}) - else: - problems_status = profile.oi_problems_status.get("contest_problems", {}) - for problem in data: - problem["my_status"] = problems_status.get(str(problem["id"]), None) + self._add_problem_status(request, data) return self.success(data) diff --git a/submission/models.py b/submission/models.py index ef408935..7f046a04 100644 --- a/submission/models.py +++ b/submission/models.py @@ -30,7 +30,7 @@ class Submission(models.Model): username = models.CharField(max_length=30) code = models.TextField() result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING) - # 判题结果的详细信息 + # 从JudgeServer返回的判题详情 info = JSONField(default=dict) language = models.CharField(max_length=20) shared = models.BooleanField(default=False) @@ -38,10 +38,12 @@ class Submission(models.Model): # {time_cost: "", memory_cost: "", err_info: "", score: 0} statistic_info = JSONField(default=dict) - def check_user_permission(self, user): + def check_user_permission(self, user, check_share=True): return self.user_id == user.id or \ - self.shared is True or \ - user.admin_type == AdminType.SUPER_ADMIN + (check_share and self.shared is True) or \ + user.is_super_admin() or \ + user.can_mgmt_all_problem() or \ + self.problem.created_by_id == user.id class Meta: db_table = "submission" diff --git a/submission/serializers.py b/submission/serializers.py index ae8c3a60..66a517bd 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -10,6 +10,11 @@ class CreateSubmissionSerializer(serializers.Serializer): contest_id = serializers.IntegerField(required=False) +class ShareSubmissionSerializer(serializers.Serializer): + id = serializers.CharField() + shared = serializers.BooleanField() + + class SubmissionModelSerializer(serializers.ModelSerializer): info = serializers.JSONField() statistic_info = serializers.JSONField() @@ -19,7 +24,7 @@ class SubmissionModelSerializer(serializers.ModelSerializer): # 不显示submission info的serializer, 用于ACM rule_type -class SubmissionSafeSerializer(serializers.ModelSerializer): +class SubmissionSafeModelSerializer(serializers.ModelSerializer): problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") statistic_info = serializers.JSONField() @@ -43,6 +48,6 @@ class SubmissionListSerializer(serializers.ModelSerializer): def get_show_link(self, obj): # 没传user或为匿名user - if self.user is None or self.user.id is None: + if self.user is None or not self.user.is_authenticated(): return False return obj.check_user_permission(self.user) diff --git a/submission/views/oj.py b/submission/views/oj.py index e9672a0a..c4bfe21b 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -7,8 +7,9 @@ from utils.api import APIView, validate_serializer from utils.throttling import TokenBucket, BucketController from utils.cache import cache from ..models import Submission -from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer -from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer +from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, + ShareSubmissionSerializer) +from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer def _submit(response, user, problem_id, language, code, contest_id): @@ -63,17 +64,37 @@ class SubmissionAPI(APIView): def get(self, request): submission_id = request.GET.get("id") if not submission_id: - return self.error("Parameter id doesn't exist.") + return self.error("Parameter id doesn't exist") try: submission = Submission.objects.select_related("problem").get(id=submission_id) except Submission.DoesNotExist: - return self.error("Submission doesn't exist.") + return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user): - return self.error("No permission for this submission.") + return self.error("No permission for this submission") if submission.problem.rule_type == ProblemRuleType.ACM: - return self.success(SubmissionSafeSerializer(submission).data) - return self.success(SubmissionModelSerializer(submission).data) + submission_data = SubmissionSafeModelSerializer(submission).data + else: + submission_data = SubmissionModelSerializer(submission).data + # 是否有权限取消共享 + submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False) + return self.success(submission_data) + + @validate_serializer(ShareSubmissionSerializer) + @login_required + def put(self, request): + try: + submission = Submission.objects.select_related("problem")\ + .get(id=request.data["id"], contest__isnull=True) + except Submission.DoesNotExist: + return self.error("Submission doesn't exist") + if not submission.check_user_permission(request.user, check_share=False): + return self.error("No permission to share the submission") + if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY: + return self.error("Can not share submission during a contest going") + submission.shared = request.data["shared"] + submission.save(update_fields=["shared"]) + return self.success() class SubmissionListAPI(APIView): @@ -83,7 +104,7 @@ class SubmissionListAPI(APIView): if request.GET.get("contest_id"): return self.error("Parameter error") - submissions = Submission.objects.filter(contest_id__isnull=True) + submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") result = request.GET.get("result") @@ -112,7 +133,7 @@ class ContestSubmissionListAPI(APIView): if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user): return self.error("No permission for OI contest submissions") - submissions = Submission.objects.filter(contest_id=contest.id) + submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") result = request.GET.get("result") From f5566148bce13b91d651e4089f2e1e970de05335 Mon Sep 17 00:00:00 2001 From: zema1 Date: Mon, 16 Oct 2017 09:45:29 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E5=AE=8C=E5=96=84OI=E7=BB=86=E5=88=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- account/decorators.py | 3 ++- contest/models.py | 8 ++++++++ contest/views/oj.py | 10 ++++++---- judge/dispatcher.py | 27 +++++++++++---------------- problem/serializers.py | 1 + problem/views/oj.py | 17 ++++++++--------- submission/views/oj.py | 13 +++++++------ utils/constants.py | 2 +- 8 files changed, 44 insertions(+), 37 deletions(-) diff --git a/account/decorators.py b/account/decorators.py index f47c4a03..b3523664 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -92,7 +92,8 @@ def check_contest_permission(func): if not user.is_authenticated(): return self.error("Please login in first.") # password error - if ("contests" not in request.session) or (self.contest.id not in request.session["contests"]): + if ("accessible_contests" not in request.session) or \ + (self.contest.id not in request.session["accessible_contests"]): return self.error("Password is required.") return func(*args, **kwargs) diff --git a/contest/models.py b/contest/models.py index 42519988..eb1c88a8 100644 --- a/contest/models.py +++ b/contest/models.py @@ -45,6 +45,14 @@ class Contest(models.Model): def is_contest_admin(self, user): return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN) + def check_oi_permission(self, user): + if self.status != ContestStatus.CONTEST_ENDED and self.real_time_rank == False: + if self.is_contest_admin(user): + return True + else: + return False + return True + class Meta: db_table = "contest" ordering = ("-create_time",) diff --git a/contest/views/oj.py b/contest/views/oj.py index 6811ef61..1e787ce2 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -67,7 +67,7 @@ class ContestPasswordVerifyAPI(APIView): # password verify OK. if "accessible_contests" not in request.session: request.session["accessible_contests"] = [] - request.session["contests"].append(contest.id) + request.session["accessible_contests"].append(contest.id) # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved request.session.modified = True return self.success(True) @@ -93,10 +93,12 @@ class ContestRankAPI(APIView): @check_contest_permission def get(self, request): - if self.contest.rule_type == ContestRuleType.ACM: - serializer = ACMContestRankSerializer - else: + if self.contest.rule_type == ContestRuleType.OI: + if not self.contest.check_oi_permission(request.user): + return self.error("You have no permission for ranks now") serializer = OIContestRankSerializer + else: + serializer = ACMContestRankSerializer cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}" qs = cache.get(cache_key) diff --git a/judge/dispatcher.py b/judge/dispatcher.py index e457b29c..829a43d7 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -156,7 +156,6 @@ class JudgeDispatcher(object): with transaction.atomic(): # prepare problem and user_profile problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) - problem_info = problem.statistic_info user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id) user_profile = user.userprofile if self.contest_id: @@ -165,25 +164,25 @@ class JudgeDispatcher(object): key = "problems" acm_problems_status = user_profile.acm_problems_status.get(key, {}) oi_problems_status = user_profile.oi_problems_status.get(key, {}) + problem_id = str(self.problem.id) + problem_info = problem.statistic_info + + # update problem info + result = str(self.submission.result) + problem_info[result] = problem_info.get(result, 0) + 1 + problem.statistic_info = problem_info # update submission and accepted number counter problem.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: problem.accepted_number += 1 - # only when submission is not in contest, we update user profile, - # in other words, users' submission in a contest will not be counted in user profile + # submission in a contest will not be counted in user profile if not self.contest_id: user_profile.submission_number += 1 if self.submission.result == JudgeStatus.ACCEPTED: user_profile.accepted_number += 1 - problem_id = str(self.problem.id) if self.problem.rule_type == ProblemRuleType.ACM: - # update acm problem info - result = str(self.submission.result) - problem_info[result] = problem_info.get(result, 0) + 1 - problem.statistic_info = problem_info - # update user_profile if problem_id not in acm_problems_status: acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id} @@ -193,12 +192,8 @@ class JudgeDispatcher(object): user_profile.acm_problems_status[key] = acm_problems_status else: - # update oi problem info - score = self.submission.statistic_info["score"] - problem_info[score] = problem_info.get(score, 0) + 1 - problem.statistic_info = problem_info - # update user_profile + score = self.submission.statistic_info["score"] if problem_id not in oi_problems_status: user_profile.add_score(score) oi_problems_status[problem_id] = {"status": self.submission.result, @@ -218,8 +213,8 @@ class JudgeDispatcher(object): def update_contest_rank(self): if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY: return - if self.contest.real_time_rank: - cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) + if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank: + cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}") with transaction.atomic(): if self.contest.rule_type == ContestRuleType.ACM: acm_rank, _ = ACMContestRank.objects.select_for_update(). \ diff --git a/problem/serializers.py b/problem/serializers.py index 4856b6d6..0e44710b 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -95,6 +95,7 @@ class ProblemAdminSerializer(BaseProblemSerializer): class ContestProblemAdminSerializer(BaseProblemSerializer): class Meta: model = Problem + fields = "__all__" class ProblemSerializer(BaseProblemSerializer): diff --git a/problem/views/oj.py b/problem/views/oj.py index 321b1bb9..3410ae1c 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -21,11 +21,10 @@ class ProblemAPI(APIView): oi_problems_status = profile.oi_problems_status.get("problems", {}) # paginate data results = queryset_values.get("results") - if results: + if results is not None: problems = results else: problems = [queryset_values,] - for problem in problems: if problem["rule_type"] == ProblemRuleType.ACM: problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") @@ -53,11 +52,7 @@ class ProblemAPI(APIView): # 按照标签筛选 tag_text = request.GET.get("tag") if tag_text: - try: - tag = ProblemTag.objects.get(name=tag_text) - except ProblemTag.DoesNotExist: - return self.error("The Tag does not exist.") - problems = tag.problem_set.all().filter(visible=True) + problems = problems.filter(tags__name=tag_text) # 搜索的情况 keyword = request.GET.get("keyword", "").strip() @@ -76,7 +71,11 @@ class ProblemAPI(APIView): class ContestProblemAPI(APIView): def _add_problem_status(self, request, queryset_values): - if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI: + print("checking") + if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user): + return + print('here') + if request.user.is_authenticated(): profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: problems_status = profile.acm_problems_status.get("contest_problems", {}) @@ -96,7 +95,7 @@ class ContestProblemAPI(APIView): except Problem.DoesNotExist: return self.error("Problem does not exist.") problem_data = ContestProblemSerializer(problem).data - self._add_problem_status(request, problem_data) + self._add_problem_status(request, [problem_data,]) return self.success(problem_data) contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) diff --git a/submission/views/oj.py b/submission/views/oj.py index c4bfe21b..4cf6a8c4 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -84,14 +84,13 @@ class SubmissionAPI(APIView): @login_required def put(self, request): try: - submission = Submission.objects.select_related("problem")\ - .get(id=request.data["id"], contest__isnull=True) + submission = Submission.objects.select_related("problem").get(id=request.data["id"]) except Submission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user, check_share=False): return self.error("No permission to share the submission") if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY: - return self.error("Can not share submission during a contest going") + return self.error("Can not share submission now") submission.shared = request.data["shared"] submission.save(update_fields=["shared"]) return self.success() @@ -130,7 +129,7 @@ class ContestSubmissionListAPI(APIView): return self.error("Limit is needed") contest = self.contest - if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user): + if not contest.check_oi_permission(request.user): return self.error("No permission for OI contest submissions") submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") @@ -154,9 +153,11 @@ class ContestSubmissionListAPI(APIView): submissions = submissions.filter(create_time__gte=contest.start_time) # 封榜的时候只能看到自己的提交 - if not contest.real_time_rank and not contest.is_contest_admin(request.user): - submissions = submissions.filter(user_id=request.user.id) + if contest.rule_type == ContestRuleType.ACM: + if not contest.real_time_rank and not contest.is_contest_admin(request.user): + submissions = submissions.filter(user_id=request.user.id) data = self.paginate_data(request, submissions) data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data return self.success(data) + diff --git a/utils/constants.py b/utils/constants.py index be7057a6..390d5685 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -23,6 +23,6 @@ class ContestRuleType(Choices): class CacheKey: waiting_queue = "waiting_queue" - contest_rank_cache = "contest_rank_cache_" + contest_rank_cache = "contest_rank_cache" website_config = "website_config" option = "option" From d8bf33a12d68e688c2df8a066f72a71bc1bc96f0 Mon Sep 17 00:00:00 2001 From: zema1 Date: Sat, 21 Oct 2017 10:51:35 +0800 Subject: [PATCH 8/8] fix tests --- account/tests.py | 15 ++++---- conf/tests.py | 4 -- contest/tests.py | 4 +- oj/settings.py | 6 +++ problem/tests.py | 84 ++++++++++++++++++++++++++++-------------- problem/views/admin.py | 1 + problem/views/oj.py | 2 - submission/tests.py | 10 +++-- 8 files changed, 80 insertions(+), 46 deletions(-) diff --git a/account/tests.py b/account/tests.py index 906ef28b..75515ced 100644 --- a/account/tests.py +++ b/account/tests.py @@ -11,6 +11,7 @@ from utils.shortcuts import rand_str from options.options import SysOptions from .models import AdminType, ProblemPermission, User +from utils.constants import ContestRuleType class PermissionDecoratorTest(APITestCase): @@ -134,7 +135,7 @@ class UserLoginAPITest(APITestCase): self.user.save() resp = self.client.post(self.login_url, data={"username": self.username, "password": self.password}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Your account have been disabled"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"}) class CaptchaTest(APITestCase): @@ -159,7 +160,7 @@ class UserRegisterAPITest(CaptchaTest): def test_website_config_limit(self): SysOptions.allow_register = False resp = self.client.post(self.register_url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"}) def test_invalid_captcha(self): self.data["captcha"] = "****" @@ -220,7 +221,7 @@ class UserProfileAPITest(APITestCase): def test_get_profile_without_login(self): resp = self.client.get(self.url) - self.assertDictEqual(resp.data, {"error": None, "data": {}}) + self.assertDictEqual(resp.data, {"error": None, "data": None}) def test_get_profile(self): self.create_user("test", "test123") @@ -335,14 +336,14 @@ class ResetPasswordAPITest(CaptchaTest): def test_reset_password_with_invalid_token(self): self.data["token"] = "aaaaaaaaaaa" resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"}) def test_reset_password_with_expired_token(self): user = User.objects.first() user.reset_password_token_expire_time = now() - timedelta(seconds=30) user.save() resp = self.client.post(self.url, data=self.data) - self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"}) class UserChangePasswordAPITest(CaptchaTest): @@ -473,14 +474,14 @@ class UserRankAPITest(APITestCase): profile2.save() def test_get_acm_rank(self): - resp = self.client.get(self.url, data={"rule": "acm"}) + resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) self.assertSuccess(resp) data = resp.data["data"] self.assertEqual(data[0]["user"]["username"], "test1") self.assertEqual(data[1]["user"]["username"], "test2") def test_get_oi_rank(self): - resp = self.client.get(self.url, data={"rule": "oi"}) + resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) self.assertSuccess(resp) data = resp.data["data"] self.assertEqual(data[0]["user"]["username"], "test2") diff --git a/conf/tests.py b/conf/tests.py index eff8cfde..5906b078 100644 --- a/conf/tests.py +++ b/conf/tests.py @@ -4,7 +4,6 @@ from django.utils import timezone from options.options import SysOptions from utils.api.tests import APITestCase -from utils.cache import default_cache from utils.constants import CacheKey from .models import JudgeServer @@ -70,9 +69,6 @@ class WebsiteConfigAPITest(APITestCase): resp = self.client.get(url) self.assertSuccess(resp) - def tearDown(self): - default_cache.delete(CacheKey.website_config) - class JudgeServerHeartbeatTest(APITestCase): def setUp(self): diff --git a/contest/tests.py b/contest/tests.py index c481becb..df05df87 100644 --- a/contest/tests.py +++ b/contest/tests.py @@ -79,7 +79,7 @@ class ContestAPITest(APITestCase): self.create_user("test", "test123") url = self.reverse("contest_password_api") resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"}) - self.assertDictEqual(resp.data, {"error": "error", "data": "Password doesn't match."}) + self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) self.assertSuccess(resp) @@ -89,7 +89,7 @@ class ContestAPITest(APITestCase): self.create_user("test", "test123") url = self.reverse("contest_access_api") resp = self.client.get(url + "?contest_id=" + str(contest_id)) - self.assertFalse(resp.data["data"]["Access"]) + self.assertFalse(resp.data["data"]["access"]) password_url = self.reverse("contest_password_api") resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) diff --git a/oj/settings.py b/oj/settings.py index ac972c79..950c426d 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -163,6 +163,12 @@ LOGGING = { } }, } +REST_FRAMEWORK = { + 'TEST_REQUEST_DEFAULT_FORMAT': 'json', + 'DEFAULT_RENDERER_CLASSES': ( + 'rest_framework.renderers.JSONRenderer', + ) +} REDIS_URL = "redis://127.0.0.1:6379" diff --git a/problem/tests.py b/problem/tests.py index 4e081ff7..7efa668b 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -1,4 +1,5 @@ import copy +import hashlib import os import shutil from datetime import timedelta @@ -9,6 +10,7 @@ from django.conf import settings from utils.api.tests import APITestCase from .models import ProblemTag +from .models import Problem, ProblemRuleType from .views.admin import TestCaseUploadAPI from contest.models import Contest from contest.tests import DEFAULT_CONTEST_DATA @@ -23,6 +25,40 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "

test "input_size": 0, "score": 0}], "rule_type": "ACM", "hint": "

test

", "source": "test"} +class ProblemCreateTestBase(APITestCase): + @staticmethod + def add_problem(problem_data, created_by): + data = copy.deepcopy(problem_data) + if data["spj"]: + if not data["spj_language"] or not data["spj_code"]: + raise ValueError("Invalid spj") + data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() + else: + data["spj_language"] = None + data["spj_code"] = None + if data["rule_type"] == ProblemRuleType.OI: + total_score = 0 + for item in data["test_case_score"]: + if item["score"] <= 0: + raise ValueError("invalid score") + else: + total_score += item["score"] + data["total_score"] = total_score + data["created_by"] = created_by + tags = data.pop("tags") + + data["languages"] = list(data["languages"]) + + problem = Problem.objects.create(**data) + + for item in tags: + try: + tag = ProblemTag.objects.get(name=item) + except ProblemTag.DoesNotExist: + tag = ProblemTag.objects.create(name=item) + problem.tags.add(tag) + return problem + class ProblemTagListAPITest(APITestCase): def test_get_tag_list(self): @@ -96,7 +132,7 @@ class ProblemAdminAPITest(APITestCase): def setUp(self): self.url = self.reverse("problem_admin_api") self.create_super_admin() - self.data = DEFAULT_PROBLEM_DATA + self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA) def test_create_problem(self): resp = self.client.post(self.url, data=self.data) @@ -138,23 +174,19 @@ class ProblemAdminAPITest(APITestCase): self.assertSuccess(resp) -class ProblemAPITest(APITestCase): +class ProblemAPITest(ProblemCreateTestBase): def setUp(self): self.url = self.reverse("problem_api") - self.create_admin() - - def create_problem(self): - url = self.reverse("problem_admin_api") - return self.client.post(url, data=DEFAULT_PROBLEM_DATA) + admin = self.create_admin(login=False) + self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) + self.create_user("test", "test123") def test_get_problem_list(self): - self.create_problem() resp = self.client.get(f"{self.url}?limit=10") self.assertSuccess(resp) def get_one_problem(self): - problem_id = self.create_problem().data["data"]["_id"] - resp = self.client.get(self.url + "?id=" + str(problem_id)) + resp = self.client.get(self.url + "?id=" + self.problem._id) self.assertSuccess(resp) @@ -169,51 +201,49 @@ class ContestProblemAdminTest(APITestCase): def test_create_contest_problem(self): contest = self.create_contest() - data = DEFAULT_PROBLEM_DATA + data = copy.deepcopy(DEFAULT_PROBLEM_DATA) data["contest_id"] = contest.data["data"]["id"] resp = self.client.post(self.url, data=data) self.assertSuccess(resp) - return resp + return contest, resp def test_get_contest_problem(self): - contest = self.test_create_contest_problem() + contest, contest_problem = self.test_create_contest_problem() contest_id = contest.data["data"]["id"] resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]), 1) def test_get_one_contest_problem(self): - contest = self.test_create_contest_problem() + contest, contest_problem = self.test_create_contest_problem() contest_id = contest.data["data"]["id"] - resp = self.client.get(self.url + "?id=" + str(contest_id)) + problem_id = contest_problem.data["data"]["id"] + resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") self.assertSuccess(resp) -class ContestProblemTest(APITestCase): +class ContestProblemTest(ProblemCreateTestBase): def setUp(self): - self.url = self.reverse("contest_problem_api") - self.create_admin() - + admin = self.create_admin() url = self.reverse("contest_admin_api") contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) contest_data["password"] = "" contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) self.contest = self.client.post(url, data=contest_data).data["data"] + self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin) + self.problem.contest_id = self.contest["id"] + self.problem.save() + self.url = self.reverse("contest_problem_api") - problem_data = copy.deepcopy(DEFAULT_PROBLEM_DATA) - problem_data["contest"] = self.contest["id"] - url = self.reverse("contest_problem_admin_api") - self.problem = self.client.post(url, problem_data).data["data"] - - def test_get_contest_problem_list(self): + def test_admin_get_contest_problem_list(self): contest_id = self.contest["id"] resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) self.assertSuccess(resp) self.assertEqual(len(resp.data["data"]), 1) - def test_get_one_contest_problem(self): + def test_admin_get_one_contest_problem(self): contest_id = self.contest["id"] - problem_id = self.problem["_id"] + problem_id = self.problem._id resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id)) self.assertSuccess(resp) diff --git a/problem/views/admin.py b/problem/views/admin.py index 572e2fb4..00f46c4d 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -223,6 +223,7 @@ class ProblemAPI(APIView): data["total_score"] = total_score # todo check filename and score info tags = data.pop("tags") + data["languages"] = list(data["languages"]) for k, v in data.items(): setattr(problem, k, v) diff --git a/problem/views/oj.py b/problem/views/oj.py index 3410ae1c..330e9332 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -71,10 +71,8 @@ class ProblemAPI(APIView): class ContestProblemAPI(APIView): def _add_problem_status(self, request, queryset_values): - print("checking") if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user): return - print('here') if request.user.is_authenticated(): profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: diff --git a/submission/tests.py b/submission/tests.py index c7348827..fdcd12ec 100644 --- a/submission/tests.py +++ b/submission/tests.py @@ -32,14 +32,16 @@ class SubmissionPrepare(APITestCase): def _create_problem_and_submission(self): user = self.create_admin("test", "test123", login=False) problem_data = deepcopy(DEFAULT_PROBLEM_DATA) - problem_data.pop("tags") + tags = problem_data.pop("tags") problem_data["created_by"] = user self.problem = Problem.objects.create(**problem_data) - for tag in DEFAULT_PROBLEM_DATA["tags"]: + for tag in tags: tag = ProblemTag.objects.create(name=tag) self.problem.tags.add(tag) self.problem.save() - self.submission = Submission.objects.create(**DEFAULT_SUBMISSION_DATA) + self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA) + self.submission_data["problem_id"] = self.problem.id + self.submission = Submission.objects.create(**self.submission_data) class SubmissionListTest(SubmissionPrepare): @@ -61,6 +63,6 @@ class SubmissionAPITest(SubmissionPrepare): self.url = self.reverse("submission_api") def test_create_submission(self, judge_task): - resp = self.client.post(self.url, DEFAULT_SUBMISSION_DATA) + resp = self.client.post(self.url, self.submission_data) self.assertSuccess(resp) judge_task.assert_called()