mirror of
https://github.com/QingdaoU/OnlineJudge.git
synced 2024-09-21 00:13:18 +00:00
Merge branch 'opt'
This commit is contained in:
commit
c1d099ed45
@ -92,7 +92,8 @@ def check_contest_permission(func):
|
||||
if not user.is_authenticated():
|
||||
return self.error("Please login in first.")
|
||||
# password error
|
||||
if ("contests" not in request.session) or (self.contest.id not in request.session["contests"]):
|
||||
if ("accessible_contests" not in request.session) or \
|
||||
(self.contest.id not in request.session["accessible_contests"]):
|
||||
return self.error("Password is required.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
@ -1,38 +1,19 @@
|
||||
import time
|
||||
import pytz
|
||||
|
||||
from django.contrib import auth
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import ugettext as _
|
||||
from django.db import connection
|
||||
from django.utils.timezone import now
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
from utils.api import JSONResponse
|
||||
|
||||
|
||||
class SessionSecurityMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
if request.user.is_authenticated():
|
||||
if "last_activity" in request.session and request.user.is_admin_role():
|
||||
# 24 hours passed since last visit, 86400 = 24 * 60 * 60
|
||||
if time.time() - request.session["last_activity"] >= 86400:
|
||||
auth.logout(request)
|
||||
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
|
||||
request.session["last_activity"] = time.time()
|
||||
|
||||
|
||||
class SessionRecordMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
if request.user.is_authenticated():
|
||||
session = request.session
|
||||
ip = request.META.get("REMOTE_ADDR", "")
|
||||
user_agent = request.META.get("HTTP_USER_AGENT", "")
|
||||
_ip = session.setdefault("ip", ip)
|
||||
_user_agent = session.setdefault("user_agent", user_agent)
|
||||
if ip != _ip or user_agent != _user_agent:
|
||||
session.modified = True
|
||||
session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
|
||||
session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP")
|
||||
session["last_activity"] = now()
|
||||
user_sessions = request.user.session_keys
|
||||
if request.session.session_key not in user_sessions:
|
||||
if session.session_key not in user_sessions:
|
||||
user_sessions.append(session.session_key)
|
||||
request.user.save()
|
||||
|
||||
@ -42,13 +23,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin):
|
||||
path = request.path_info
|
||||
if path.startswith("/admin/") or path.startswith("/api/admin/"):
|
||||
if not (request.user.is_authenticated() and request.user.is_admin_role()):
|
||||
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
|
||||
|
||||
|
||||
class TimezoneMiddleware(MiddlewareMixin):
|
||||
def process_request(self, request):
|
||||
if request.user.is_authenticated():
|
||||
timezone.activate(pytz.timezone(request.user.userprofile.time_zone))
|
||||
return JSONResponse.response({"error": "login-required", "data": "Please login in first"})
|
||||
|
||||
|
||||
class LogSqlMiddleware(MiddlewareMixin):
|
||||
|
@ -50,7 +50,7 @@ class Migration(migrations.Migration):
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('problems_status', jsonfield.fields.JSONField(default={})),
|
||||
('avatar', models.CharField(default=account.models._default_avatar, max_length=50)),
|
||||
('avatar', models.CharField(default="default.png", max_length=50)),
|
||||
('blog', models.URLField(blank=True, null=True)),
|
||||
('mood', models.CharField(blank=True, max_length=200, null=True)),
|
||||
('accepted_problem_number', models.IntegerField(default=0)),
|
||||
|
105
account/migrations/0008_auto_20171011_1214.py
Normal file
105
account/migrations/0008_auto_20171011_1214.py
Normal file
@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import django.contrib.postgres.fields.jsonb
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('account', '0007_auto_20170920_0254'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='userprofile',
|
||||
name='language',
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='admin_type',
|
||||
field=models.CharField(default='Regular User', max_length=32),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='auth_token',
|
||||
field=models.CharField(max_length=32, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='email',
|
||||
field=models.EmailField(max_length=64, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='open_api_appkey',
|
||||
field=models.CharField(max_length=32, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='problem_permission',
|
||||
field=models.CharField(default='None', max_length=32),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='reset_password_token',
|
||||
field=models.CharField(max_length=32, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='session_keys',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='tfa_token',
|
||||
field=models.CharField(max_length=32, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='username',
|
||||
field=models.CharField(max_length=32, unique=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='acm_problems_status',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='avatar',
|
||||
field=models.CharField(default='/static/avatar/default.png', max_length=256),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='github',
|
||||
field=models.CharField(blank=True, max_length=64, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='major',
|
||||
field=models.CharField(blank=True, max_length=64, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='mood',
|
||||
field=models.CharField(blank=True, max_length=256, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='oi_problems_status',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='real_name',
|
||||
field=models.CharField(blank=True, max_length=32, null=True),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userprofile',
|
||||
name='school',
|
||||
field=models.CharField(blank=True, max_length=64, null=True),
|
||||
),
|
||||
]
|
@ -1,7 +1,7 @@
|
||||
from django.contrib.auth.models import AbstractBaseUser
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from jsonfield import JSONField
|
||||
from utils.models import JSONField
|
||||
|
||||
|
||||
class AdminType(object):
|
||||
@ -24,22 +24,22 @@ class UserManager(models.Manager):
|
||||
|
||||
|
||||
class User(AbstractBaseUser):
|
||||
username = models.CharField(max_length=30, unique=True)
|
||||
email = models.EmailField(max_length=254, null=True)
|
||||
username = models.CharField(max_length=32, unique=True)
|
||||
email = models.EmailField(max_length=64, null=True)
|
||||
create_time = models.DateTimeField(auto_now_add=True, null=True)
|
||||
# One of UserType
|
||||
admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER)
|
||||
problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE)
|
||||
reset_password_token = models.CharField(max_length=40, null=True)
|
||||
admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER)
|
||||
problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE)
|
||||
reset_password_token = models.CharField(max_length=32, null=True)
|
||||
reset_password_token_expire_time = models.DateTimeField(null=True)
|
||||
# SSO auth token
|
||||
auth_token = models.CharField(max_length=40, null=True)
|
||||
auth_token = models.CharField(max_length=32, null=True)
|
||||
two_factor_auth = models.BooleanField(default=False)
|
||||
tfa_token = models.CharField(max_length=40, null=True)
|
||||
session_keys = JSONField(default=[])
|
||||
tfa_token = models.CharField(max_length=32, null=True)
|
||||
session_keys = JSONField(default=list)
|
||||
# open api key
|
||||
open_api = models.BooleanField(default=False)
|
||||
open_api_appkey = models.CharField(max_length=35, null=True)
|
||||
open_api_appkey = models.CharField(max_length=32, null=True)
|
||||
is_disabled = models.BooleanField(default=False)
|
||||
|
||||
USERNAME_FIELD = "username"
|
||||
@ -63,26 +63,34 @@ class User(AbstractBaseUser):
|
||||
db_table = "user"
|
||||
|
||||
|
||||
def _default_avatar():
|
||||
return f"/{settings.IMAGE_UPLOAD_DIR}/default.png"
|
||||
|
||||
|
||||
class UserProfile(models.Model):
|
||||
user = models.OneToOneField(User)
|
||||
# Store user problem solution status with json string format
|
||||
# {problems: {1: JudgeStatus.ACCEPTED}, contest_problems: {1: JudgeStatus.ACCEPTED}}, record problem_id and status
|
||||
acm_problems_status = JSONField(default={})
|
||||
# {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score
|
||||
oi_problems_status = JSONField(default={})
|
||||
# acm_problems_status examples:
|
||||
# {
|
||||
# "problems": {
|
||||
# "1": {
|
||||
# "status": JudgeStatus.ACCEPTED,
|
||||
# "_id": "1000"
|
||||
# }
|
||||
# },
|
||||
# "contest_problems": {
|
||||
# "1": {
|
||||
# "status": JudgeStatus.ACCEPTED,
|
||||
# "_id": "1000"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
acm_problems_status = JSONField(default=dict)
|
||||
# like acm_problems_status, merely add "score" field
|
||||
oi_problems_status = JSONField(default=dict)
|
||||
|
||||
real_name = models.CharField(max_length=30, blank=True, null=True)
|
||||
avatar = models.CharField(max_length=50, default=_default_avatar())
|
||||
real_name = models.CharField(max_length=32, blank=True, null=True)
|
||||
avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png")
|
||||
blog = models.URLField(blank=True, null=True)
|
||||
mood = models.CharField(max_length=200, blank=True, null=True)
|
||||
github = models.CharField(max_length=50, blank=True, null=True)
|
||||
school = models.CharField(max_length=200, blank=True, null=True)
|
||||
major = models.CharField(max_length=200, blank=True, null=True)
|
||||
language = models.CharField(max_length=32, blank=True, null=True)
|
||||
mood = models.CharField(max_length=256, blank=True, null=True)
|
||||
github = models.CharField(max_length=64, blank=True, null=True)
|
||||
school = models.CharField(max_length=64, blank=True, null=True)
|
||||
major = models.CharField(max_length=64, blank=True, null=True)
|
||||
# for ACM
|
||||
accepted_number = models.IntegerField(default=0)
|
||||
# for OI
|
||||
|
@ -6,27 +6,26 @@ from .models import AdminType, ProblemPermission, User, UserProfile
|
||||
|
||||
|
||||
class UserLoginSerializer(serializers.Serializer):
|
||||
username = serializers.CharField(max_length=30)
|
||||
password = serializers.CharField(max_length=30)
|
||||
tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True)
|
||||
username = serializers.CharField()
|
||||
password = serializers.CharField()
|
||||
tfa_code = serializers.CharField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class UsernameOrEmailCheckSerializer(serializers.Serializer):
|
||||
username = serializers.CharField(max_length=30, required=False)
|
||||
email = serializers.EmailField(max_length=30, required=False)
|
||||
username = serializers.CharField(required=False)
|
||||
email = serializers.EmailField(required=False)
|
||||
|
||||
|
||||
class UserRegisterSerializer(serializers.Serializer):
|
||||
username = serializers.CharField(max_length=30)
|
||||
password = serializers.CharField(max_length=30, min_length=6)
|
||||
email = serializers.EmailField(max_length=30)
|
||||
captcha = serializers.CharField(max_length=4, min_length=1)
|
||||
username = serializers.CharField(max_length=32)
|
||||
password = serializers.CharField(min_length=6)
|
||||
email = serializers.EmailField(max_length=64)
|
||||
captcha = serializers.CharField()
|
||||
|
||||
|
||||
class UserChangePasswordSerializer(serializers.Serializer):
|
||||
old_password = serializers.CharField()
|
||||
new_password = serializers.CharField(max_length=30, min_length=6)
|
||||
captcha = serializers.CharField(max_length=4, min_length=4)
|
||||
new_password = serializers.CharField(min_length=6)
|
||||
|
||||
|
||||
class UserSerializer(serializers.ModelSerializer):
|
||||
@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = UserProfile
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class UserInfoSerializer(serializers.ModelSerializer):
|
||||
@ -58,9 +58,9 @@ class UserInfoSerializer(serializers.ModelSerializer):
|
||||
|
||||
class EditUserSerializer(serializers.Serializer):
|
||||
id = serializers.IntegerField()
|
||||
username = serializers.CharField(max_length=30)
|
||||
password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None)
|
||||
email = serializers.EmailField(max_length=254)
|
||||
username = serializers.CharField(max_length=32)
|
||||
password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None)
|
||||
email = serializers.EmailField(max_length=64)
|
||||
admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN))
|
||||
problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN,
|
||||
ProblemPermission.ALL))
|
||||
@ -70,29 +70,29 @@ class EditUserSerializer(serializers.Serializer):
|
||||
|
||||
|
||||
class EditUserProfileSerializer(serializers.Serializer):
|
||||
real_name = serializers.CharField(max_length=30, allow_blank=True)
|
||||
avatar = serializers.CharField(max_length=100, allow_blank=True, required=False)
|
||||
blog = serializers.URLField(allow_blank=True, required=False)
|
||||
mood = serializers.CharField(max_length=200, allow_blank=True, required=False)
|
||||
github = serializers.CharField(max_length=50, allow_blank=True, required=False)
|
||||
school = serializers.CharField(max_length=200, allow_blank=True, required=False)
|
||||
major = serializers.CharField(max_length=200, allow_blank=True, required=False)
|
||||
real_name = serializers.CharField(max_length=32, allow_blank=True)
|
||||
avatar = serializers.CharField(max_length=256, allow_blank=True, required=False)
|
||||
blog = serializers.URLField(max_length=256, allow_blank=True, required=False)
|
||||
mood = serializers.CharField(max_length=256, allow_blank=True, required=False)
|
||||
github = serializers.CharField(max_length=64, allow_blank=True, required=False)
|
||||
school = serializers.CharField(max_length=64, allow_blank=True, required=False)
|
||||
major = serializers.CharField(max_length=64, allow_blank=True, required=False)
|
||||
|
||||
|
||||
class ApplyResetPasswordSerializer(serializers.Serializer):
|
||||
email = serializers.EmailField()
|
||||
captcha = serializers.CharField(max_length=4, min_length=4)
|
||||
captcha = serializers.CharField()
|
||||
|
||||
|
||||
class ResetPasswordSerializer(serializers.Serializer):
|
||||
token = serializers.CharField(min_length=1, max_length=40)
|
||||
password = serializers.CharField(min_length=6, max_length=30)
|
||||
captcha = serializers.CharField(max_length=4, min_length=4)
|
||||
token = serializers.CharField()
|
||||
password = serializers.CharField(min_length=6)
|
||||
captcha = serializers.CharField()
|
||||
|
||||
|
||||
class SSOSerializer(serializers.Serializer):
|
||||
appkey = serializers.CharField(max_length=35)
|
||||
token = serializers.CharField(max_length=40)
|
||||
appkey = serializers.CharField()
|
||||
token = serializers.CharField()
|
||||
|
||||
|
||||
class TwoFactorAuthCodeSerializer(serializers.Serializer):
|
||||
|
@ -1,6 +1,31 @@
|
||||
from celery import shared_task
|
||||
import logging
|
||||
|
||||
from utils.shortcuts import send_email
|
||||
from celery import shared_task
|
||||
from envelopes import Envelope
|
||||
|
||||
from options.options import SysOptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def send_email(from_name, to_email, to_name, subject, content):
|
||||
smtp = SysOptions.smtp_config
|
||||
if not smtp:
|
||||
return
|
||||
envlope = Envelope(from_addr=(smtp["email"], from_name),
|
||||
to_addr=(to_email, to_name),
|
||||
subject=subject,
|
||||
html_body=content)
|
||||
try:
|
||||
envlope.send(smtp["server"],
|
||||
login=smtp["email"],
|
||||
password=smtp["password"],
|
||||
port=smtp["port"],
|
||||
tls=smtp["tls"])
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
@shared_task
|
||||
|
@ -8,11 +8,10 @@ from otpauth import OtpAuth
|
||||
|
||||
from utils.api.tests import APIClient, APITestCase
|
||||
from utils.shortcuts import rand_str
|
||||
from utils.cache import default_cache
|
||||
from utils.constants import CacheKey
|
||||
from options.options import SysOptions
|
||||
|
||||
from .models import AdminType, ProblemPermission, User
|
||||
from conf.models import WebsiteConfig
|
||||
from utils.constants import ContestRuleType
|
||||
|
||||
|
||||
class PermissionDecoratorTest(APITestCase):
|
||||
@ -136,7 +135,7 @@ class UserLoginAPITest(APITestCase):
|
||||
self.user.save()
|
||||
resp = self.client.post(self.login_url, data={"username": self.username,
|
||||
"password": self.password})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account have been disabled"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"})
|
||||
|
||||
|
||||
class CaptchaTest(APITestCase):
|
||||
@ -157,15 +156,11 @@ class UserRegisterAPITest(CaptchaTest):
|
||||
self.data = {"username": "test_user", "password": "testuserpassword",
|
||||
"real_name": "real_name", "email": "test@qduoj.com",
|
||||
"captcha": self._set_captcha(self.client.session)}
|
||||
# clea cache in redis
|
||||
default_cache.delete(CacheKey.website_config)
|
||||
|
||||
def test_website_config_limit(self):
|
||||
website = WebsiteConfig.objects.create()
|
||||
website.allow_register = False
|
||||
website.save()
|
||||
SysOptions.allow_register = False
|
||||
resp = self.client.post(self.register_url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"})
|
||||
|
||||
def test_invalid_captcha(self):
|
||||
self.data["captcha"] = "****"
|
||||
@ -226,7 +221,7 @@ class UserProfileAPITest(APITestCase):
|
||||
|
||||
def test_get_profile_without_login(self):
|
||||
resp = self.client.get(self.url)
|
||||
self.assertDictEqual(resp.data, {"error": None, "data": {}})
|
||||
self.assertDictEqual(resp.data, {"error": None, "data": None})
|
||||
|
||||
def test_get_profile(self):
|
||||
self.create_user("test", "test123")
|
||||
@ -247,7 +242,6 @@ class TwoFactorAuthAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("two_factor_auth_api")
|
||||
self.create_user("test", "test123")
|
||||
self.create_website_config()
|
||||
|
||||
def _get_tfa_code(self):
|
||||
user = User.objects.first()
|
||||
@ -295,7 +289,6 @@ class ApplyResetPasswordAPITest(CaptchaTest):
|
||||
user.email = "test@oj.com"
|
||||
user.save()
|
||||
self.url = self.reverse("apply_reset_password_api")
|
||||
self.create_website_config()
|
||||
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
|
||||
|
||||
def _refresh_captcha(self):
|
||||
@ -343,14 +336,14 @@ class ResetPasswordAPITest(CaptchaTest):
|
||||
def test_reset_password_with_invalid_token(self):
|
||||
self.data["token"] = "aaaaaaaaaaa"
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"})
|
||||
|
||||
def test_reset_password_with_expired_token(self):
|
||||
user = User.objects.first()
|
||||
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
|
||||
user.save()
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"})
|
||||
|
||||
|
||||
class UserChangePasswordAPITest(CaptchaTest):
|
||||
@ -481,14 +474,14 @@ class UserRankAPITest(APITestCase):
|
||||
profile2.save()
|
||||
|
||||
def test_get_acm_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": "acm"})
|
||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test1")
|
||||
self.assertEqual(data[1]["user"]["username"], "test2")
|
||||
|
||||
def test_get_oi_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": "oi"})
|
||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test2")
|
||||
|
@ -3,7 +3,7 @@ from django.conf.urls import url
|
||||
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
|
||||
UserChangePasswordAPI, UserRegisterAPI,
|
||||
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
|
||||
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
||||
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
||||
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
|
||||
|
||||
from utils.captcha.views import CaptchaAPIView
|
||||
@ -19,7 +19,6 @@ urlpatterns = [
|
||||
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
|
||||
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
|
||||
url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
|
||||
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
|
||||
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
|
||||
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
|
||||
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
|
||||
|
@ -1,32 +1,28 @@
|
||||
import os
|
||||
import qrcode
|
||||
import pickle
|
||||
from datetime import timedelta
|
||||
from otpauth import OtpAuth
|
||||
from importlib import import_module
|
||||
|
||||
import qrcode
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
from importlib import import_module
|
||||
from django.template.loader import render_to_string
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.utils.timezone import now
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.template.loader import render_to_string
|
||||
from otpauth import OtpAuth
|
||||
|
||||
from conf.models import WebsiteConfig
|
||||
from utils.constants import ContestRuleType
|
||||
from options.options import SysOptions
|
||||
from utils.api import APIView, validate_serializer
|
||||
from utils.captcha import Captcha
|
||||
from utils.shortcuts import rand_str, img2base64, timestamp2utcstr
|
||||
from utils.cache import default_cache
|
||||
from utils.constants import CacheKey
|
||||
|
||||
from utils.shortcuts import rand_str, img2base64, datetime2str
|
||||
from ..decorators import login_required
|
||||
from ..models import User, UserProfile
|
||||
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
|
||||
UserChangePasswordSerializer, UserLoginSerializer,
|
||||
UserRegisterSerializer, UsernameOrEmailCheckSerializer,
|
||||
RankInfoSerializer)
|
||||
from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer,
|
||||
UserProfileSerializer,
|
||||
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
|
||||
EditUserProfileSerializer, AvatarUploadForm)
|
||||
from ..tasks import send_email_async
|
||||
|
||||
@ -39,7 +35,7 @@ class UserProfileAPI(APIView):
|
||||
"""
|
||||
user = request.user
|
||||
if not user.is_authenticated():
|
||||
return self.success({})
|
||||
return self.success()
|
||||
username = request.GET.get("username")
|
||||
try:
|
||||
if username:
|
||||
@ -48,8 +44,7 @@ class UserProfileAPI(APIView):
|
||||
user = request.user
|
||||
except User.DoesNotExist:
|
||||
return self.error("User does not exist")
|
||||
profile = UserProfile.objects.select_related("user").get(user=user)
|
||||
return self.success(UserProfileSerializer(profile).data)
|
||||
return self.success(UserProfileSerializer(user.userprofile).data)
|
||||
|
||||
@validate_serializer(EditUserProfileSerializer)
|
||||
@login_required
|
||||
@ -72,8 +67,7 @@ class AvatarUploadAPI(APIView):
|
||||
avatar = form.cleaned_data["file"]
|
||||
else:
|
||||
return self.error("Invalid file content")
|
||||
# 2097152 = 2 * 1024 * 1024 = 2MB
|
||||
if avatar.size > 2097152:
|
||||
if avatar.size > 2 * 1024 * 1024:
|
||||
return self.error("Picture is too large")
|
||||
suffix = os.path.splitext(avatar.name)[-1].lower()
|
||||
if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
|
||||
@ -84,46 +78,12 @@ class AvatarUploadAPI(APIView):
|
||||
for chunk in avatar:
|
||||
img.write(chunk)
|
||||
user_profile = request.user.userprofile
|
||||
_, old_avatar = os.path.split(user_profile.avatar)
|
||||
if old_avatar != "default.png":
|
||||
os.remove(os.path.join(settings.IMAGE_UPLOAD_DIR_ABS, old_avatar))
|
||||
|
||||
user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}"
|
||||
user_profile.save()
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
class SSOAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
callback = request.GET.get("callback", None)
|
||||
if not callback:
|
||||
return self.error("Parameter Error")
|
||||
token = rand_str()
|
||||
request.user.auth_token = token
|
||||
request.user.save()
|
||||
return self.success({"redirect_url": callback + "?token=" + token,
|
||||
"callback": callback})
|
||||
|
||||
@validate_serializer(SSOSerializer)
|
||||
def post(self, request):
|
||||
data = request.data
|
||||
try:
|
||||
User.objects.get(open_api_appkey=data["appkey"])
|
||||
except User.DoesNotExist:
|
||||
return self.error("Invalid appkey")
|
||||
try:
|
||||
user = User.objects.get(auth_token=data["token"])
|
||||
user.auth_token = None
|
||||
user.save()
|
||||
return self.success({"username": user.username,
|
||||
"id": user.id,
|
||||
"admin_type": user.admin_type,
|
||||
"avatar": user.userprofile.avatar})
|
||||
except User.DoesNotExist:
|
||||
return self.error("User does not exist")
|
||||
|
||||
|
||||
class TwoFactorAuthAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
@ -132,14 +92,13 @@ class TwoFactorAuthAPI(APIView):
|
||||
"""
|
||||
user = request.user
|
||||
if user.two_factor_auth:
|
||||
return self.error("Already open 2FA")
|
||||
return self.error("2FA is already turned on")
|
||||
token = rand_str()
|
||||
user.tfa_token = token
|
||||
user.save()
|
||||
|
||||
config = WebsiteConfig.objects.first()
|
||||
label = f"{config.name_shortcut}:{user.username}"
|
||||
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
|
||||
label = f"{SysOptions.website_name_shortcut}:{user.username}"
|
||||
image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name))
|
||||
return self.success(img2base64(image))
|
||||
|
||||
@login_required
|
||||
@ -163,7 +122,7 @@ class TwoFactorAuthAPI(APIView):
|
||||
code = request.data["code"]
|
||||
user = request.user
|
||||
if not user.two_factor_auth:
|
||||
return self.error("Other session have disabled TFA")
|
||||
return self.error("2FA is already turned off")
|
||||
if OtpAuth(user.tfa_token).valid_totp(code):
|
||||
user.two_factor_auth = False
|
||||
user.save()
|
||||
@ -200,7 +159,7 @@ class UserLoginAPI(APIView):
|
||||
# None is returned if username or password is wrong
|
||||
if user:
|
||||
if user.is_disabled:
|
||||
return self.error("Your account have been disabled")
|
||||
return self.error("Your account has been disabled")
|
||||
if not user.two_factor_auth:
|
||||
auth.login(request, user)
|
||||
return self.success("Succeeded")
|
||||
@ -220,13 +179,13 @@ class UserLoginAPI(APIView):
|
||||
# todo remove this, only for debug use
|
||||
def get(self, request):
|
||||
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
|
||||
return self.success({})
|
||||
return self.success()
|
||||
|
||||
|
||||
class UserLogoutAPI(APIView):
|
||||
def get(self, request):
|
||||
auth.logout(request)
|
||||
return self.success({})
|
||||
return self.success()
|
||||
|
||||
|
||||
class UsernameOrEmailCheck(APIView):
|
||||
@ -242,11 +201,9 @@ class UsernameOrEmailCheck(APIView):
|
||||
"email": False
|
||||
}
|
||||
if data.get("username"):
|
||||
if User.objects.filter(username=data["username"]).exists():
|
||||
result["username"] = True
|
||||
result["username"] = User.objects.filter(username=data["username"]).exists()
|
||||
if data.get("email"):
|
||||
if User.objects.filter(email=data["email"]).exists():
|
||||
result["email"] = True
|
||||
result["email"] = User.objects.filter(email=data["email"]).exists()
|
||||
return self.success(result)
|
||||
|
||||
|
||||
@ -256,17 +213,9 @@ class UserRegisterAPI(APIView):
|
||||
"""
|
||||
User register api
|
||||
"""
|
||||
config = default_cache.get(CacheKey.website_config)
|
||||
if config:
|
||||
config = pickle.loads(config)
|
||||
else:
|
||||
config = WebsiteConfig.objects.first()
|
||||
if not config:
|
||||
config = WebsiteConfig.objects.create()
|
||||
default_cache.set(CacheKey.website_config, pickle.dumps(config))
|
||||
|
||||
if not config.allow_register:
|
||||
return self.error("Register have been disabled by admin")
|
||||
if not SysOptions.allow_register:
|
||||
return self.error("Register function has been disabled by admin")
|
||||
|
||||
data = request.data
|
||||
captcha = Captcha(request)
|
||||
@ -295,6 +244,7 @@ class UserChangePasswordAPI(APIView):
|
||||
username = request.user.username
|
||||
user = auth.authenticate(username=username, password=data["old_password"])
|
||||
if user:
|
||||
# TODO: check tfa?
|
||||
user.set_password(data["new_password"])
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
@ -307,7 +257,6 @@ class ApplyResetPasswordAPI(APIView):
|
||||
def post(self, request):
|
||||
data = request.data
|
||||
captcha = Captcha(request)
|
||||
config = WebsiteConfig.objects.first()
|
||||
if not captcha.check(data["captcha"]):
|
||||
return self.error("Invalid captcha")
|
||||
try:
|
||||
@ -322,14 +271,14 @@ class ApplyResetPasswordAPI(APIView):
|
||||
user.save()
|
||||
render_data = {
|
||||
"username": user.username,
|
||||
"website_name": config.name,
|
||||
"link": f"{config.base_url}/reset-password/{user.reset_password_token}"
|
||||
"website_name": SysOptions.website_name,
|
||||
"link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
|
||||
}
|
||||
email_html = render_to_string("reset_password_email.html", render_data)
|
||||
send_email_async.delay(config.name,
|
||||
send_email_async.delay(SysOptions.website_name,
|
||||
user.email,
|
||||
user.username,
|
||||
config.name + " 登录信息找回邮件",
|
||||
f"{SysOptions.website_name} 登录信息找回邮件",
|
||||
email_html)
|
||||
return self.success("Succeeded")
|
||||
|
||||
@ -344,9 +293,9 @@ class ResetPasswordAPI(APIView):
|
||||
try:
|
||||
user = User.objects.get(reset_password_token=data["token"])
|
||||
except User.DoesNotExist:
|
||||
return self.error("Token dose not exist")
|
||||
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0:
|
||||
return self.error("Token have expired")
|
||||
return self.error("Token does not exist")
|
||||
if user.reset_password_token_expire_time < now():
|
||||
return self.error("Token has expired")
|
||||
user.reset_password_token = None
|
||||
user.two_factor_auth = False
|
||||
user.set_password(data["password"])
|
||||
@ -358,14 +307,13 @@ class SessionManagementAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
engine = import_module(settings.SESSION_ENGINE)
|
||||
SessionStore = engine.SessionStore
|
||||
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
|
||||
session_store = engine.SessionStore
|
||||
current_session = request.session.session_key
|
||||
session_keys = request.user.session_keys
|
||||
result = []
|
||||
modified = False
|
||||
for key in session_keys[:]:
|
||||
session = SessionStore(key)
|
||||
session = session_store(key)
|
||||
# session does not exist or is expiry
|
||||
if not session._session:
|
||||
session_keys.remove(key)
|
||||
@ -377,7 +325,7 @@ class SessionManagementAPI(APIView):
|
||||
s["current_session"] = True
|
||||
s["ip"] = session["ip"]
|
||||
s["user_agent"] = session["user_agent"]
|
||||
s["last_activity"] = timestamp2utcstr(session["last_activity"])
|
||||
s["last_activity"] = datetime2str(session["last_activity"])
|
||||
s["session_key"] = key
|
||||
result.append(s)
|
||||
if modified:
|
||||
@ -401,12 +349,12 @@ class SessionManagementAPI(APIView):
|
||||
class UserRankAPI(APIView):
|
||||
def get(self, request):
|
||||
rule_type = request.GET.get("rule")
|
||||
if rule_type not in ["acm", "oi"]:
|
||||
rule_type = "acm"
|
||||
if rule_type not in ContestRuleType.choices():
|
||||
rule_type = ContestRuleType.ACM
|
||||
profiles = UserProfile.objects.select_related("user")\
|
||||
.filter(submission_number__gt=0)\
|
||||
.exclude(user__is_disabled=True)
|
||||
if rule_type == "acm":
|
||||
if rule_type == ContestRuleType.ACM:
|
||||
profiles = profiles.order_by("-accepted_number", "submission_number")
|
||||
else:
|
||||
profiles = profiles.order_by("-total_score")
|
||||
|
20
announcement/migrations/0002_auto_20171011_1214.py
Normal file
20
announcement/migrations/0002_auto_20171011_1214.py
Normal file
@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('announcement', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='announcement',
|
||||
name='title',
|
||||
field=models.CharField(max_length=64),
|
||||
),
|
||||
]
|
@ -5,7 +5,7 @@ from utils.models import RichTextField
|
||||
|
||||
|
||||
class Announcement(models.Model):
|
||||
title = models.CharField(max_length=50)
|
||||
title = models.CharField(max_length=64)
|
||||
# HTML
|
||||
content = RichTextField()
|
||||
create_time = models.DateTimeField(auto_now_add=True)
|
||||
|
@ -5,8 +5,8 @@ from .models import Announcement
|
||||
|
||||
|
||||
class CreateAnnouncementSerializer(serializers.Serializer):
|
||||
title = serializers.CharField(max_length=50)
|
||||
content = serializers.CharField(max_length=10000)
|
||||
title = serializers.CharField(max_length=64)
|
||||
content = serializers.CharField(max_length=1024 * 1024 * 8)
|
||||
visible = serializers.BooleanField()
|
||||
|
||||
|
||||
@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer):
|
||||
|
||||
class EditAnnouncementSerializer(serializers.Serializer):
|
||||
id = serializers.IntegerField()
|
||||
title = serializers.CharField(max_length=50)
|
||||
content = serializers.CharField(max_length=10000)
|
||||
title = serializers.CharField(max_length=64)
|
||||
content = serializers.CharField(max_length=1024 * 1024 * 8)
|
||||
visible = serializers.BooleanField()
|
||||
|
39
conf/migrations/0002_auto_20171011_1214.py
Normal file
39
conf/migrations/0002_auto_20171011_1214.py
Normal file
@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('conf', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.DeleteModel(
|
||||
name='JudgeServerToken',
|
||||
),
|
||||
migrations.DeleteModel(
|
||||
name='SMTPConfig',
|
||||
),
|
||||
migrations.DeleteModel(
|
||||
name='WebsiteConfig',
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='judgeserver',
|
||||
name='hostname',
|
||||
field=models.CharField(max_length=128),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='judgeserver',
|
||||
name='judger_version',
|
||||
field=models.CharField(max_length=32),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='judgeserver',
|
||||
name='service_url',
|
||||
field=models.CharField(blank=True, max_length=256, null=True),
|
||||
),
|
||||
]
|
@ -2,42 +2,17 @@ from django.db import models
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
class SMTPConfig(models.Model):
|
||||
server = models.CharField(max_length=128)
|
||||
port = models.IntegerField(default=25)
|
||||
email = models.CharField(max_length=128)
|
||||
password = models.CharField(max_length=128)
|
||||
tls = models.BooleanField()
|
||||
|
||||
class Meta:
|
||||
db_table = "smtp_config"
|
||||
|
||||
|
||||
class WebsiteConfig(models.Model):
|
||||
base_url = models.CharField(max_length=128, default="http://127.0.0.1")
|
||||
name = models.CharField(max_length=32, default="Online Judge")
|
||||
name_shortcut = models.CharField(max_length=32, default="oj")
|
||||
footer = models.TextField(default="Online Judge Footer")
|
||||
# allow register
|
||||
allow_register = models.BooleanField(default=True)
|
||||
# submission list show all user's submission
|
||||
submission_list_show_all = models.BooleanField(default=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "website_config"
|
||||
|
||||
|
||||
class JudgeServer(models.Model):
|
||||
hostname = models.CharField(max_length=64)
|
||||
hostname = models.CharField(max_length=128)
|
||||
ip = models.CharField(max_length=32, blank=True, null=True)
|
||||
judger_version = models.CharField(max_length=24)
|
||||
judger_version = models.CharField(max_length=32)
|
||||
cpu_core = models.IntegerField()
|
||||
memory_usage = models.FloatField()
|
||||
cpu_usage = models.FloatField()
|
||||
last_heartbeat = models.DateTimeField()
|
||||
create_time = models.DateTimeField(auto_now_add=True)
|
||||
task_number = models.IntegerField(default=0)
|
||||
service_url = models.CharField(max_length=128, blank=True, null=True)
|
||||
service_url = models.CharField(max_length=256, blank=True, null=True)
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
@ -48,10 +23,3 @@ class JudgeServer(models.Model):
|
||||
|
||||
class Meta:
|
||||
db_table = "judge_server"
|
||||
|
||||
|
||||
class JudgeServerToken(models.Model):
|
||||
token = models.CharField(max_length=32)
|
||||
|
||||
class Meta:
|
||||
db_table = "judge_server_token"
|
||||
|
@ -1,6 +1,6 @@
|
||||
from utils.api import DateTimeTZField, serializers
|
||||
|
||||
from .models import JudgeServer, SMTPConfig, WebsiteConfig
|
||||
from .models import JudgeServer
|
||||
|
||||
|
||||
class EditSMTPConfigSerializer(serializers.Serializer):
|
||||
@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer):
|
||||
password = serializers.CharField(max_length=128)
|
||||
|
||||
|
||||
class SMTPConfigSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = SMTPConfig
|
||||
exclude = ["id", "password"]
|
||||
|
||||
|
||||
class TestSMTPConfigSerializer(serializers.Serializer):
|
||||
email = serializers.EmailField()
|
||||
|
||||
|
||||
class CreateEditWebsiteConfigSerializer(serializers.Serializer):
|
||||
base_url = serializers.CharField(max_length=128)
|
||||
name = serializers.CharField(max_length=32)
|
||||
name_shortcut = serializers.CharField(max_length=32)
|
||||
footer = serializers.CharField(max_length=1024)
|
||||
website_base_url = serializers.CharField(max_length=128)
|
||||
website_name = serializers.CharField(max_length=64)
|
||||
website_name_shortcut = serializers.CharField(max_length=64)
|
||||
website_footer = serializers.CharField(max_length=1024 * 1024)
|
||||
allow_register = serializers.BooleanField()
|
||||
submission_list_show_all = serializers.BooleanField()
|
||||
|
||||
|
||||
class WebsiteConfigSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = WebsiteConfig
|
||||
exclude = ["id"]
|
||||
|
||||
|
||||
class JudgeServerSerializer(serializers.ModelSerializer):
|
||||
create_time = DateTimeTZField()
|
||||
last_heartbeat = DateTimeTZField()
|
||||
@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = JudgeServer
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class JudgeServerHeartbeatSerializer(serializers.Serializer):
|
||||
hostname = serializers.CharField(max_length=64)
|
||||
judger_version = serializers.CharField(max_length=24)
|
||||
hostname = serializers.CharField(max_length=128)
|
||||
judger_version = serializers.CharField(max_length=32)
|
||||
cpu_core = serializers.IntegerField(min_value=1)
|
||||
memory = serializers.FloatField(min_value=0, max_value=100)
|
||||
cpu = serializers.FloatField(min_value=0, max_value=100)
|
||||
action = serializers.ChoiceField(choices=("heartbeat", ))
|
||||
service_url = serializers.CharField(max_length=128, required=False)
|
||||
service_url = serializers.CharField(max_length=256, required=False)
|
||||
|
@ -2,11 +2,10 @@ import hashlib
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from options.options import SysOptions
|
||||
from utils.api.tests import APITestCase
|
||||
from utils.cache import default_cache
|
||||
from utils.constants import CacheKey
|
||||
|
||||
from .models import JudgeServer, JudgeServerToken, SMTPConfig
|
||||
from .models import JudgeServer
|
||||
|
||||
|
||||
class SMTPConfigTest(APITestCase):
|
||||
@ -29,10 +28,6 @@ class SMTPConfigTest(APITestCase):
|
||||
"tls": True}
|
||||
resp = self.client.put(self.url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
smtp = SMTPConfig.objects.first()
|
||||
self.assertEqual(smtp.password, self.password)
|
||||
self.assertEqual(smtp.server, "smtp1.test.com")
|
||||
self.assertEqual(smtp.email, "test2@test.com")
|
||||
|
||||
def test_edit_without_password1(self):
|
||||
self.test_create_smtp_config()
|
||||
@ -40,7 +35,6 @@ class SMTPConfigTest(APITestCase):
|
||||
"tls": True, "password": ""}
|
||||
resp = self.client.put(self.url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(SMTPConfig.objects.first().password, self.password)
|
||||
|
||||
def test_edit_with_password(self):
|
||||
self.test_create_smtp_config()
|
||||
@ -48,18 +42,14 @@ class SMTPConfigTest(APITestCase):
|
||||
"tls": True, "password": "newpassword"}
|
||||
resp = self.client.put(self.url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
smtp = SMTPConfig.objects.first()
|
||||
self.assertEqual(smtp.password, "newpassword")
|
||||
self.assertEqual(smtp.server, "smtp1.test.com")
|
||||
self.assertEqual(smtp.email, "test2@test.com")
|
||||
|
||||
|
||||
class WebsiteConfigAPITest(APITestCase):
|
||||
def test_create_website_config(self):
|
||||
self.create_super_admin()
|
||||
url = self.reverse("website_config_api")
|
||||
data = {"base_url": "http://test.com", "name": "test name",
|
||||
"name_shortcut": "test oj", "footer": "<a>test</a>",
|
||||
data = {"website_base_url": "http://test.com", "website_name": "test name",
|
||||
"website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
|
||||
"allow_register": True, "submission_list_show_all": False}
|
||||
resp = self.client.post(url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
@ -67,8 +57,8 @@ class WebsiteConfigAPITest(APITestCase):
|
||||
def test_edit_website_config(self):
|
||||
self.create_super_admin()
|
||||
url = self.reverse("website_config_api")
|
||||
data = {"base_url": "http://test.com", "name": "test name",
|
||||
"name_shortcut": "test oj", "footer": "<a>test</a>",
|
||||
data = {"website_base_url": "http://test.com", "website_name": "test name",
|
||||
"website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
|
||||
"allow_register": True, "submission_list_show_all": False}
|
||||
resp = self.client.post(url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
@ -78,10 +68,6 @@ class WebsiteConfigAPITest(APITestCase):
|
||||
url = self.reverse("website_info_api")
|
||||
resp = self.client.get(url)
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(resp.data["data"]["name_shortcut"], "oj")
|
||||
|
||||
def tearDown(self):
|
||||
default_cache.delete(CacheKey.website_config)
|
||||
|
||||
|
||||
class JudgeServerHeartbeatTest(APITestCase):
|
||||
@ -91,7 +77,7 @@ class JudgeServerHeartbeatTest(APITestCase):
|
||||
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
|
||||
self.token = "test"
|
||||
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
|
||||
JudgeServerToken.objects.create(token=self.token)
|
||||
SysOptions.judge_server_token = self.token
|
||||
|
||||
def test_new_heartbeat(self):
|
||||
resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token})
|
||||
@ -127,11 +113,9 @@ class JudgeServerAPITest(APITestCase):
|
||||
self.create_super_admin()
|
||||
|
||||
def test_get_judge_server(self):
|
||||
self.assertFalse(JudgeServerToken.objects.exists())
|
||||
resp = self.client.get(self.url)
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(len(resp.data["data"]["servers"]), 1)
|
||||
self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"])
|
||||
|
||||
def test_delete_judge_server(self):
|
||||
resp = self.client.delete(self.url + "?hostname=testhostname")
|
||||
|
@ -1,54 +1,45 @@
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from account.decorators import super_admin_required
|
||||
from judge.languages import languages, spj_languages
|
||||
from judge.dispatcher import process_pending_task
|
||||
from judge.languages import languages, spj_languages
|
||||
from options.options import SysOptions
|
||||
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
|
||||
from utils.shortcuts import rand_str
|
||||
from utils.cache import default_cache
|
||||
from utils.constants import CacheKey
|
||||
|
||||
from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig
|
||||
from .models import JudgeServer
|
||||
from .serializers import (CreateEditWebsiteConfigSerializer,
|
||||
CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
|
||||
JudgeServerHeartbeatSerializer,
|
||||
JudgeServerSerializer, SMTPConfigSerializer,
|
||||
TestSMTPConfigSerializer, WebsiteConfigSerializer)
|
||||
JudgeServerSerializer, TestSMTPConfigSerializer)
|
||||
|
||||
|
||||
class SMTPAPI(APIView):
|
||||
@super_admin_required
|
||||
def get(self, request):
|
||||
smtp = SMTPConfig.objects.first()
|
||||
smtp = SysOptions.smtp_config
|
||||
if not smtp:
|
||||
return self.success(None)
|
||||
return self.success(SMTPConfigSerializer(smtp).data)
|
||||
smtp.pop("password")
|
||||
return self.success(smtp)
|
||||
|
||||
@validate_serializer(CreateSMTPConfigSerializer)
|
||||
@super_admin_required
|
||||
def post(self, request):
|
||||
SMTPConfig.objects.all().delete()
|
||||
smtp = SMTPConfig.objects.create(**request.data)
|
||||
return self.success(SMTPConfigSerializer(smtp).data)
|
||||
SysOptions.smtp_config = request.data
|
||||
return self.success()
|
||||
|
||||
@validate_serializer(EditSMTPConfigSerializer)
|
||||
@super_admin_required
|
||||
def put(self, request):
|
||||
smtp = SysOptions.smtp_config
|
||||
data = request.data
|
||||
smtp = SMTPConfig.objects.first()
|
||||
if not smtp:
|
||||
return self.error("SMTP config is missing")
|
||||
smtp.server = data["server"]
|
||||
smtp.port = data["port"]
|
||||
smtp.email = data["email"]
|
||||
smtp.tls = data["tls"]
|
||||
if data.get("password"):
|
||||
smtp.password = data["password"]
|
||||
smtp.save()
|
||||
return self.success(SMTPConfigSerializer(smtp).data)
|
||||
for item in ["server", "port", "email", "tls"]:
|
||||
smtp[item] = data[item]
|
||||
if "password" in data:
|
||||
smtp["password"] = data["password"]
|
||||
SysOptions.smtp_config = smtp
|
||||
return self.success()
|
||||
|
||||
|
||||
class SMTPTestAPI(APIView):
|
||||
@ -60,37 +51,24 @@ class SMTPTestAPI(APIView):
|
||||
|
||||
class WebsiteConfigAPI(APIView):
|
||||
def get(self, request):
|
||||
config = default_cache.get(CacheKey.website_config)
|
||||
if config:
|
||||
config = pickle.loads(config)
|
||||
else:
|
||||
config = WebsiteConfig.objects.first()
|
||||
if not config:
|
||||
config = WebsiteConfig.objects.create()
|
||||
default_cache.set(CacheKey.website_config, pickle.dumps(config))
|
||||
return self.success(WebsiteConfigSerializer(config).data)
|
||||
ret = {key: getattr(SysOptions, key) for key in
|
||||
["website_base_url", "website_name", "website_name_shortcut",
|
||||
"website_footer", "allow_register", "submission_list_show_all"]}
|
||||
return self.success(ret)
|
||||
|
||||
@validate_serializer(CreateEditWebsiteConfigSerializer)
|
||||
@super_admin_required
|
||||
def post(self, request):
|
||||
data = request.data
|
||||
WebsiteConfig.objects.all().delete()
|
||||
config = WebsiteConfig.objects.create(**data)
|
||||
default_cache.set(CacheKey.website_config, pickle.dumps(config))
|
||||
return self.success(WebsiteConfigSerializer(config).data)
|
||||
for k, v in request.data.items():
|
||||
setattr(SysOptions, k, v)
|
||||
return self.success()
|
||||
|
||||
|
||||
class JudgeServerAPI(APIView):
|
||||
@super_admin_required
|
||||
def get(self, request):
|
||||
judge_server_token = JudgeServerToken.objects.first()
|
||||
if not judge_server_token:
|
||||
token = rand_str(12)
|
||||
JudgeServerToken.objects.create(token=token)
|
||||
else:
|
||||
token = judge_server_token.token
|
||||
servers = JudgeServer.objects.all().order_by("-last_heartbeat")
|
||||
return self.success({"token": token,
|
||||
return self.success({"token": SysOptions.judge_server_token,
|
||||
"servers": JudgeServerSerializer(servers, many=True).data})
|
||||
|
||||
@super_admin_required
|
||||
@ -104,15 +82,9 @@ class JudgeServerAPI(APIView):
|
||||
class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
|
||||
@validate_serializer(JudgeServerHeartbeatSerializer)
|
||||
def post(self, request):
|
||||
judge_server_token = JudgeServerToken.objects.first()
|
||||
if not judge_server_token:
|
||||
token = rand_str(12)
|
||||
JudgeServerToken.objects.create(token=token)
|
||||
else:
|
||||
token = judge_server_token.token
|
||||
data = request.data
|
||||
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
|
||||
if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token:
|
||||
if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
|
||||
return self.error("Invalid token")
|
||||
service_url = data.get("service_url")
|
||||
|
||||
|
26
contest/migrations/0006_auto_20171011_1214.py
Normal file
26
contest/migrations/0006_auto_20171011_1214.py
Normal file
@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import django.contrib.postgres.fields.jsonb
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('contest', '0005_auto_20170823_0918'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='acmcontestrank',
|
||||
name='submission_info',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='oicontestrank',
|
||||
name='submission_info',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
]
|
@ -1,27 +1,13 @@
|
||||
from utils.constants import ContestRuleType # noqa
|
||||
from django.db import models
|
||||
from django.utils.timezone import now
|
||||
from jsonfield import JSONField
|
||||
from utils.models import JSONField
|
||||
|
||||
from utils.constants import ContestStatus, ContestType
|
||||
from account.models import User, AdminType
|
||||
from utils.models import RichTextField
|
||||
|
||||
|
||||
class ContestType(object):
|
||||
PUBLIC_CONTEST = "Public"
|
||||
PASSWORD_PROTECTED_CONTEST = "Password Protected"
|
||||
|
||||
|
||||
class ContestStatus(object):
|
||||
CONTEST_NOT_START = "1"
|
||||
CONTEST_ENDED = "-1"
|
||||
CONTEST_UNDERWAY = "0"
|
||||
|
||||
|
||||
class ContestRuleType(object):
|
||||
ACM = "ACM"
|
||||
OI = "OI"
|
||||
|
||||
|
||||
class Contest(models.Model):
|
||||
title = models.CharField(max_length=40)
|
||||
description = RichTextField()
|
||||
@ -59,12 +45,20 @@ class Contest(models.Model):
|
||||
def is_contest_admin(self, user):
|
||||
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
|
||||
|
||||
def check_oi_permission(self, user):
|
||||
if self.status != ContestStatus.CONTEST_ENDED and self.real_time_rank == False:
|
||||
if self.is_contest_admin(user):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
class Meta:
|
||||
db_table = "contest"
|
||||
ordering = ("-create_time",)
|
||||
|
||||
|
||||
class ContestRank(models.Model):
|
||||
class AbstractContestRank(models.Model):
|
||||
user = models.ForeignKey(User)
|
||||
contest = models.ForeignKey(Contest)
|
||||
submission_number = models.IntegerField(default=0)
|
||||
@ -73,30 +67,27 @@ class ContestRank(models.Model):
|
||||
abstract = True
|
||||
|
||||
|
||||
class ACMContestRank(ContestRank):
|
||||
class ACMContestRank(AbstractContestRank):
|
||||
accepted_number = models.IntegerField(default=0)
|
||||
# total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60
|
||||
total_time = models.IntegerField(default=0)
|
||||
# {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}}
|
||||
# key is problem id
|
||||
submission_info = JSONField(default={})
|
||||
submission_info = JSONField(default=dict)
|
||||
|
||||
class Meta:
|
||||
db_table = "acm_contest_rank"
|
||||
|
||||
|
||||
class OIContestRank(ContestRank):
|
||||
class OIContestRank(AbstractContestRank):
|
||||
total_score = models.IntegerField(default=0)
|
||||
# {23: 333}}
|
||||
# key is problem id, value is current score
|
||||
submission_info = JSONField(default={})
|
||||
submission_info = JSONField(default=dict)
|
||||
|
||||
class Meta:
|
||||
db_table = "oi_contest_rank"
|
||||
|
||||
def update_rank(self, submission):
|
||||
self.submission_number += 1
|
||||
|
||||
|
||||
class ContestAnnouncement(models.Model):
|
||||
contest = models.ForeignKey(Contest)
|
||||
|
@ -79,7 +79,7 @@ class ContestAPITest(APITestCase):
|
||||
self.create_user("test", "test123")
|
||||
url = self.reverse("contest_password_api")
|
||||
resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Password doesn't match."})
|
||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
|
||||
|
||||
resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
|
||||
self.assertSuccess(resp)
|
||||
@ -89,7 +89,7 @@ class ContestAPITest(APITestCase):
|
||||
self.create_user("test", "test123")
|
||||
url = self.reverse("contest_access_api")
|
||||
resp = self.client.get(url + "?contest_id=" + str(contest_id))
|
||||
self.assertFalse(resp.data["data"]["Access"])
|
||||
self.assertFalse(resp.data["data"]["access"])
|
||||
|
||||
password_url = self.reverse("contest_password_api")
|
||||
resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
|
||||
|
@ -1,13 +1,11 @@
|
||||
import pickle
|
||||
from django.utils.timezone import now
|
||||
from django.db.models import Q
|
||||
from django.core.cache import cache
|
||||
from utils.api import APIView, validate_serializer
|
||||
from utils.cache import default_cache
|
||||
from utils.constants import CacheKey
|
||||
from account.decorators import login_required, check_contest_permission
|
||||
|
||||
from ..models import ContestAnnouncement, Contest, ContestStatus, ContestRuleType
|
||||
from ..models import OIContestRank, ACMContestRank
|
||||
from utils.constants import ContestRuleType, ContestStatus
|
||||
from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
|
||||
from ..serializers import ContestAnnouncementSerializer
|
||||
from ..serializers import ContestSerializer, ContestPasswordVerifySerializer
|
||||
from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
|
||||
@ -32,7 +30,7 @@ class ContestAPI(APIView):
|
||||
try:
|
||||
contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
||||
except Contest.DoesNotExist:
|
||||
return self.error("Contest doesn't exist.")
|
||||
return self.error("Contest does not exist")
|
||||
return self.success(ContestSerializer(contest).data)
|
||||
|
||||
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
||||
@ -50,7 +48,7 @@ class ContestAPI(APIView):
|
||||
elif status == ContestStatus.CONTEST_ENDED:
|
||||
contests = contests.filter(end_time__lt=cur)
|
||||
else:
|
||||
contests = contests.filter(Q(start_time__lte=cur) & Q(end_time__gte=cur))
|
||||
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
||||
return self.success(self.paginate_data(request, contests, ContestSerializer))
|
||||
|
||||
|
||||
@ -62,14 +60,14 @@ class ContestPasswordVerifyAPI(APIView):
|
||||
try:
|
||||
contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False)
|
||||
except Contest.DoesNotExist:
|
||||
return self.error("Contest %s doesn't exist." % data["contest_id"])
|
||||
return self.error("Contest does not exist")
|
||||
if contest.password != data["password"]:
|
||||
return self.error("Password doesn't match.")
|
||||
return self.error("Wrong password")
|
||||
|
||||
# password verify OK.
|
||||
if "contests" not in request.session:
|
||||
request.session["contests"] = []
|
||||
request.session["contests"].append(int(data["contest_id"]))
|
||||
if "accessible_contests" not in request.session:
|
||||
request.session["accessible_contests"] = []
|
||||
request.session["accessible_contests"].append(contest.id)
|
||||
# https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved
|
||||
request.session.modified = True
|
||||
return self.success(True)
|
||||
@ -80,13 +78,8 @@ class ContestAccessAPI(APIView):
|
||||
def get(self, request):
|
||||
contest_id = request.GET.get("contest_id")
|
||||
if not contest_id:
|
||||
return self.error("Parameter contest_id not exist.")
|
||||
if "contests" not in request.session:
|
||||
request.session["contests"] = []
|
||||
if int(contest_id) in request.session["contests"]:
|
||||
return self.success({"Access": True})
|
||||
else:
|
||||
return self.success({"Access": False})
|
||||
return self.error()
|
||||
return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])})
|
||||
|
||||
|
||||
class ContestRankAPI(APIView):
|
||||
@ -100,17 +93,17 @@ class ContestRankAPI(APIView):
|
||||
|
||||
@check_contest_permission
|
||||
def get(self, request):
|
||||
if self.contest.rule_type == ContestRuleType.ACM:
|
||||
serializer = ACMContestRankSerializer
|
||||
else:
|
||||
if self.contest.rule_type == ContestRuleType.OI:
|
||||
if not self.contest.check_oi_permission(request.user):
|
||||
return self.error("You have no permission for ranks now")
|
||||
serializer = OIContestRankSerializer
|
||||
|
||||
cache_key = CacheKey.contest_rank_cache + str(self.contest.id)
|
||||
qs = default_cache.get(cache_key)
|
||||
if not qs:
|
||||
ranks = self.get_rank()
|
||||
default_cache.set(cache_key, pickle.dumps(ranks))
|
||||
else:
|
||||
ranks = pickle.loads(qs)
|
||||
serializer = ACMContestRankSerializer
|
||||
|
||||
return self.success(self.paginate_data(request, ranks, serializer))
|
||||
cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
|
||||
qs = cache.get(cache_key)
|
||||
if not qs:
|
||||
qs = self.get_rank()
|
||||
cache.set(cache_key, qs)
|
||||
|
||||
return self.success(self.paginate_data(request, qs, serializer))
|
||||
|
@ -8,12 +8,13 @@ from django.db import transaction
|
||||
from django.db.models import F
|
||||
|
||||
from account.models import User
|
||||
from conf.models import JudgeServer, JudgeServerToken
|
||||
from conf.models import JudgeServer
|
||||
from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus
|
||||
from judge.languages import languages
|
||||
from options.options import SysOptions
|
||||
from problem.models import Problem, ProblemRuleType
|
||||
from submission.models import JudgeStatus, Submission
|
||||
from utils.cache import judge_cache, default_cache
|
||||
from utils.cache import cache
|
||||
from utils.constants import CacheKey
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -21,41 +22,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# 继续处理在队列中的问题
|
||||
def process_pending_task():
|
||||
if judge_cache.llen(CacheKey.waiting_queue):
|
||||
if cache.llen(CacheKey.waiting_queue):
|
||||
# 防止循环引入
|
||||
from judge.tasks import judge_task
|
||||
data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
|
||||
data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
|
||||
judge_task.delay(**data)
|
||||
|
||||
|
||||
class JudgeDispatcher(object):
|
||||
def __init__(self, submission_id, problem_id):
|
||||
token = JudgeServerToken.objects.first().token
|
||||
self.token = hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
self.redis_conn = judge_cache
|
||||
self.submission = Submission.objects.get(pk=submission_id)
|
||||
self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
|
||||
self.submission = Submission.objects.get(id=submission_id)
|
||||
self.contest_id = self.submission.contest_id
|
||||
if self.contest_id:
|
||||
self.problem = Problem.objects.select_related("contest") \
|
||||
.get(id=problem_id, contest_id=self.contest_id)
|
||||
self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
|
||||
self.contest = self.problem.contest
|
||||
else:
|
||||
self.problem = Problem.objects.get(id=problem_id)
|
||||
|
||||
def _request(self, url, data=None):
|
||||
kwargs = {"headers": {"X-Judge-Server-Token": self.token,
|
||||
"Content-Type": "application/json"}}
|
||||
kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
|
||||
if data:
|
||||
kwargs["data"] = json.dumps(data)
|
||||
kwargs["json"] = data
|
||||
try:
|
||||
return requests.post(url, **kwargs).json()
|
||||
except Exception as e:
|
||||
logger.error(e.with_traceback())
|
||||
logger.exception(e)
|
||||
|
||||
@staticmethod
|
||||
def choose_judge_server():
|
||||
with transaction.atomic():
|
||||
# TODO: use more reasonable way
|
||||
servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
|
||||
servers = [s for s in servers if s.status == "normal"]
|
||||
if servers:
|
||||
@ -65,10 +61,10 @@ class JudgeDispatcher(object):
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def release_judge_res(judge_server_id):
|
||||
def release_judge_server(judge_server_id):
|
||||
with transaction.atomic():
|
||||
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
|
||||
server = JudgeServer.objects.select_for_update().get(id=judge_server_id)
|
||||
server = JudgeServer.objects.get(id=judge_server_id)
|
||||
server.used_instance_number = F("task_number") - 1
|
||||
server.save()
|
||||
|
||||
@ -94,7 +90,7 @@ class JudgeDispatcher(object):
|
||||
server = self.choose_judge_server()
|
||||
if not server:
|
||||
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
|
||||
self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data))
|
||||
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
|
||||
return
|
||||
|
||||
sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0]
|
||||
@ -138,7 +134,7 @@ class JudgeDispatcher(object):
|
||||
else:
|
||||
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
|
||||
self.submission.save()
|
||||
self.release_judge_res(server.id)
|
||||
self.release_judge_server(server.id)
|
||||
|
||||
self.update_problem_status()
|
||||
if self.contest_id:
|
||||
@ -160,7 +156,6 @@ class JudgeDispatcher(object):
|
||||
with transaction.atomic():
|
||||
# prepare problem and user_profile
|
||||
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
|
||||
problem_info = problem.statistic_info
|
||||
user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id)
|
||||
user_profile = user.userprofile
|
||||
if self.contest_id:
|
||||
@ -169,50 +164,46 @@ class JudgeDispatcher(object):
|
||||
key = "problems"
|
||||
acm_problems_status = user_profile.acm_problems_status.get(key, {})
|
||||
oi_problems_status = user_profile.oi_problems_status.get(key, {})
|
||||
problem_id = str(self.problem.id)
|
||||
problem_info = problem.statistic_info
|
||||
|
||||
# update problem info
|
||||
result = str(self.submission.result)
|
||||
problem_info[result] = problem_info.get(result, 0) + 1
|
||||
problem.statistic_info = problem_info
|
||||
|
||||
# update submission and accepted number counter
|
||||
problem.submission_number += 1
|
||||
if self.submission.result == JudgeStatus.ACCEPTED:
|
||||
problem.accepted_number += 1
|
||||
# only when submission is not in contest, we update user profile,
|
||||
# in other words, users' submission in a contest will not be counted in user profile
|
||||
# submission in a contest will not be counted in user profile
|
||||
if not self.contest_id:
|
||||
user_profile.submission_number += 1
|
||||
if self.submission.result == JudgeStatus.ACCEPTED:
|
||||
user_profile.accepted_number += 1
|
||||
|
||||
problem_id = str(self.problem.id)
|
||||
if self.problem.rule_type == ProblemRuleType.ACM:
|
||||
# update acm problem info
|
||||
result = str(self.submission.result)
|
||||
problem_info[result] = problem_info.get(result, 0) + 1
|
||||
problem.statistic_info = problem_info
|
||||
|
||||
# update user_profile
|
||||
if problem_id not in acm_problems_status:
|
||||
acm_problems_status[problem_id] = self.submission.result
|
||||
acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
|
||||
# skip if the problem has been accepted
|
||||
elif acm_problems_status[problem_id] != JudgeStatus.ACCEPTED:
|
||||
if self.submission.result == JudgeStatus.ACCEPTED:
|
||||
acm_problems_status[problem_id] = JudgeStatus.ACCEPTED
|
||||
else:
|
||||
acm_problems_status[problem_id] = self.submission.result
|
||||
elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
|
||||
acm_problems_status[problem_id]["status"] = self.submission.result
|
||||
user_profile.acm_problems_status[key] = acm_problems_status
|
||||
|
||||
else:
|
||||
# update oi problem info
|
||||
score = self.submission.statistic_info["score"]
|
||||
problem_info[score] = problem_info.get(score, 0) + 1
|
||||
problem.statistic_info = problem_info
|
||||
|
||||
# update user_profile
|
||||
score = self.submission.statistic_info["score"]
|
||||
if problem_id not in oi_problems_status:
|
||||
user_profile.add_score(score)
|
||||
oi_problems_status[problem_id] = score
|
||||
oi_problems_status[problem_id] = {"status": self.submission.result,
|
||||
"_id": self.problem._id,
|
||||
"score": score}
|
||||
else:
|
||||
# minus last time score, add this time score
|
||||
user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id])
|
||||
oi_problems_status[problem_id] = score
|
||||
user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]["score"])
|
||||
oi_problems_status[problem_id]["score"] = score
|
||||
oi_problems_status[problem_id]["status"] = self.submission.result
|
||||
user_profile.oi_problems_status[key] = oi_problems_status
|
||||
|
||||
problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"])
|
||||
@ -222,8 +213,8 @@ class JudgeDispatcher(object):
|
||||
def update_contest_rank(self):
|
||||
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
|
||||
return
|
||||
if self.contest.real_time_rank:
|
||||
default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id))
|
||||
if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
|
||||
cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
|
||||
with transaction.atomic():
|
||||
if self.contest.rule_type == ContestRuleType.ACM:
|
||||
acm_rank, _ = ACMContestRank.objects.select_for_update(). \
|
||||
|
@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
|
||||
'ENGINE': 'django.db.backends.postgresql_psycopg2',
|
||||
'HOST': '127.0.0.1',
|
||||
'PORT': 5433,
|
||||
'NAME': "onlinejudge",
|
||||
'USER': "onlinejudge",
|
||||
'PASSWORD': 'onlinejudge'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,7 @@ INSTALLED_APPS = (
|
||||
'contest',
|
||||
'utils',
|
||||
'submission',
|
||||
'options',
|
||||
)
|
||||
|
||||
MIDDLEWARE_CLASSES = (
|
||||
@ -57,11 +58,9 @@ MIDDLEWARE_CLASSES = (
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
'account.middleware.AdminRoleRequiredMiddleware',
|
||||
'account.middleware.SessionSecurityMiddleware',
|
||||
'account.middleware.SessionRecordMiddleware',
|
||||
# 'account.middleware.LogSqlMiddleware',
|
||||
)
|
||||
SESSION_ENGINE = 'django.contrib.sessions.backends.cache'
|
||||
ROOT_URLCONF = 'oj.urls'
|
||||
|
||||
TEMPLATES = [
|
||||
@ -164,8 +163,6 @@ LOGGING = {
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
REST_FRAMEWORK = {
|
||||
'TEST_REQUEST_DEFAULT_FORMAT': 'json',
|
||||
'DEFAULT_RENDERER_CLASSES': (
|
||||
@ -173,34 +170,34 @@ REST_FRAMEWORK = {
|
||||
)
|
||||
}
|
||||
|
||||
CACHE_JUDGE_QUEUE = "judge_queue"
|
||||
CACHE_THROTTLING = "throttling"
|
||||
|
||||
REDIS_URL = "redis://127.0.0.1:6379"
|
||||
|
||||
|
||||
def redis_config(db):
|
||||
def make_key(key, key_prefix, version):
|
||||
return key
|
||||
|
||||
return {
|
||||
"BACKEND": "utils.cache.MyRedisCache",
|
||||
"LOCATION": f"{REDIS_URL}/{db}",
|
||||
"TIMEOUT": None,
|
||||
"KEY_PREFIX": "",
|
||||
"KEY_FUNCTION": make_key
|
||||
}
|
||||
|
||||
CACHES = {
|
||||
"default": {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": "redis://127.0.0.1:6379/1",
|
||||
"OPTIONS": {
|
||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||
}
|
||||
},
|
||||
CACHE_JUDGE_QUEUE: {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": "redis://127.0.0.1:6379/2",
|
||||
"OPTIONS": {
|
||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||
}
|
||||
},
|
||||
CACHE_THROTTLING: {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": "redis://127.0.0.1:6379/3",
|
||||
"OPTIONS": {
|
||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||
}
|
||||
}
|
||||
"default": redis_config(db=1)
|
||||
}
|
||||
|
||||
|
||||
CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2"
|
||||
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
|
||||
|
||||
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
|
||||
SESSION_CACHE_ALIAS = "default"
|
||||
|
||||
|
||||
# For celery
|
||||
REDIS_QUEUE = {
|
||||
"host": "127.0.0.1",
|
||||
|
0
options/__init__.py
Normal file
0
options/__init__.py
Normal file
25
options/migrations/0001_initial.py
Normal file
25
options/migrations/0001_initial.py
Normal file
@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.3 on 2017-10-01 19:19
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import jsonfield.fields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='SysOptions',
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('key', models.CharField(db_index=True, max_length=128, unique=True)),
|
||||
('value', jsonfield.fields.JSONField()),
|
||||
],
|
||||
),
|
||||
]
|
21
options/migrations/0002_auto_20171011_1214.py
Normal file
21
options/migrations/0002_auto_20171011_1214.py
Normal file
@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import django.contrib.postgres.fields.jsonb
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('options', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='sysoptions',
|
||||
name='value',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(),
|
||||
),
|
||||
]
|
0
options/migrations/__init__.py
Normal file
0
options/migrations/__init__.py
Normal file
7
options/models.py
Normal file
7
options/models.py
Normal file
@ -0,0 +1,7 @@
|
||||
from django.db import models
|
||||
from utils.models import JSONField
|
||||
|
||||
|
||||
class SysOptions(models.Model):
|
||||
key = models.CharField(max_length=128, unique=True, db_index=True)
|
||||
value = JSONField()
|
179
options/options.py
Normal file
179
options/options.py
Normal file
@ -0,0 +1,179 @@
|
||||
from django.core.cache import cache
|
||||
from django.db import transaction, IntegrityError
|
||||
|
||||
from utils.constants import CacheKey
|
||||
from utils.shortcuts import rand_str
|
||||
from .models import SysOptions as SysOptionsModel
|
||||
|
||||
|
||||
class OptionKeys:
|
||||
website_base_url = "website_base_url"
|
||||
website_name = "website_name"
|
||||
website_name_shortcut = "website_name_shortcut"
|
||||
website_footer = "website_footer"
|
||||
allow_register = "allow_register"
|
||||
submission_list_show_all = "submission_list_show_all"
|
||||
smtp_config = "smtp_config"
|
||||
judge_server_token = "judge_server_token"
|
||||
|
||||
|
||||
class OptionDefaultValue:
|
||||
website_base_url = "http://127.0.0.1"
|
||||
website_name = "Online Judge"
|
||||
website_name_shortcut = "oj"
|
||||
website_footer = "Online Judge Footer"
|
||||
allow_register = True
|
||||
submission_list_show_all = True
|
||||
smtp_config = {}
|
||||
judge_server_token = rand_str
|
||||
|
||||
|
||||
class _SysOptionsMeta(type):
|
||||
@classmethod
|
||||
def _set_cache(mcs, option_key, option_value):
|
||||
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
|
||||
|
||||
@classmethod
|
||||
def _del_cache(mcs, option_key):
|
||||
cache.delete(f"{CacheKey.option}:{option_key}")
|
||||
|
||||
@classmethod
|
||||
def _get_keys(cls):
|
||||
return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
|
||||
|
||||
def rebuild_cache(cls):
|
||||
for key in cls._get_keys():
|
||||
# get option 的时候会写 cache 的
|
||||
cls._get_option(key, use_cache=False)
|
||||
|
||||
@classmethod
|
||||
def _init_option(mcs):
|
||||
for item in mcs._get_keys():
|
||||
if not SysOptionsModel.objects.filter(key=item).exists():
|
||||
default_value = getattr(OptionDefaultValue, item)
|
||||
if callable(default_value):
|
||||
default_value = default_value()
|
||||
try:
|
||||
SysOptionsModel.objects.create(key=item, value=default_value)
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _get_option(mcs, option_key, use_cache=True):
|
||||
try:
|
||||
if use_cache:
|
||||
option = cache.get(f"{CacheKey.option}:{option_key}")
|
||||
if option:
|
||||
return option
|
||||
option = SysOptionsModel.objects.get(key=option_key)
|
||||
value = option.value
|
||||
mcs._set_cache(option_key, value)
|
||||
return value
|
||||
except SysOptionsModel.DoesNotExist:
|
||||
mcs._init_option()
|
||||
return mcs._get_option(option_key, use_cache=use_cache)
|
||||
|
||||
@classmethod
|
||||
def _set_option(mcs, option_key: str, option_value):
|
||||
try:
|
||||
with transaction.atomic():
|
||||
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
|
||||
option.value = option_value
|
||||
option.save()
|
||||
mcs._del_cache(option_key)
|
||||
except SysOptionsModel.DoesNotExist:
|
||||
mcs._init_option()
|
||||
mcs._set_option(option_key, option_value)
|
||||
|
||||
@classmethod
|
||||
def _increment(mcs, option_key):
|
||||
try:
|
||||
with transaction.atomic():
|
||||
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
|
||||
value = option.value + 1
|
||||
option.value = value
|
||||
option.save()
|
||||
mcs._del_cache(option_key)
|
||||
except SysOptionsModel.DoesNotExist:
|
||||
mcs._init_option()
|
||||
return mcs._increment(option_key)
|
||||
|
||||
@classmethod
|
||||
def set_options(mcs, options):
|
||||
for key, value in options:
|
||||
mcs._set_option(key, value)
|
||||
|
||||
@classmethod
|
||||
def get_options(mcs, keys):
|
||||
result = {}
|
||||
for key in keys:
|
||||
result[key] = mcs._get_option(key)
|
||||
return result
|
||||
|
||||
@property
|
||||
def website_base_url(cls):
|
||||
return cls._get_option(OptionKeys.website_base_url)
|
||||
|
||||
@website_base_url.setter
|
||||
def website_base_url(cls, value):
|
||||
cls._set_option(OptionKeys.website_base_url, value)
|
||||
|
||||
@property
|
||||
def website_name(cls):
|
||||
return cls._get_option(OptionKeys.website_name)
|
||||
|
||||
@website_name.setter
|
||||
def website_name(cls, value):
|
||||
cls._set_option(OptionKeys.website_name, value)
|
||||
|
||||
@property
|
||||
def website_name_shortcut(cls):
|
||||
return cls._get_option(OptionKeys.website_name_shortcut)
|
||||
|
||||
@website_name_shortcut.setter
|
||||
def website_name_shortcut(cls, value):
|
||||
cls._set_option(OptionKeys.website_name_shortcut, value)
|
||||
|
||||
@property
|
||||
def website_footer(cls):
|
||||
return cls._get_option(OptionKeys.website_footer)
|
||||
|
||||
@website_footer.setter
|
||||
def website_footer(cls, value):
|
||||
cls._set_option(OptionKeys.website_footer, value)
|
||||
|
||||
@property
|
||||
def allow_register(cls):
|
||||
return cls._get_option(OptionKeys.allow_register)
|
||||
|
||||
@allow_register.setter
|
||||
def allow_register(cls, value):
|
||||
cls._set_option(OptionKeys.allow_register, value)
|
||||
|
||||
@property
|
||||
def submission_list_show_all(cls):
|
||||
return cls._get_option(OptionKeys.submission_list_show_all)
|
||||
|
||||
@submission_list_show_all.setter
|
||||
def submission_list_show_all(cls, value):
|
||||
cls._set_option(OptionKeys.submission_list_show_all, value)
|
||||
|
||||
@property
|
||||
def smtp_config(cls):
|
||||
return cls._get_option(OptionKeys.smtp_config)
|
||||
|
||||
@smtp_config.setter
|
||||
def smtp_config(cls, value):
|
||||
cls._set_option(OptionKeys.smtp_config, value)
|
||||
|
||||
@property
|
||||
def judge_server_token(cls):
|
||||
return cls._get_option(OptionKeys.judge_server_token)
|
||||
|
||||
@judge_server_token.setter
|
||||
def judge_server_token(cls, value):
|
||||
cls._set_option(OptionKeys.judge_server_token, value)
|
||||
|
||||
|
||||
class SysOptions(metaclass=_SysOptionsMeta):
|
||||
pass
|
1
options/tests.py
Normal file
1
options/tests.py
Normal file
@ -0,0 +1 @@
|
||||
# Create your tests here.
|
1
options/views.py
Normal file
1
options/views.py
Normal file
@ -0,0 +1 @@
|
||||
# Create your views here.
|
41
problem/migrations/0009_auto_20171011_1214.py
Normal file
41
problem/migrations/0009_auto_20171011_1214.py
Normal file
@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import django.contrib.postgres.fields.jsonb
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('problem', '0008_auto_20170923_1318'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='problem',
|
||||
name='languages',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='problem',
|
||||
name='samples',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='problem',
|
||||
name='statistic_info',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='problem',
|
||||
name='template',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='problem',
|
||||
name='test_case_score',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(),
|
||||
),
|
||||
]
|
@ -1,5 +1,5 @@
|
||||
from django.db import models
|
||||
from jsonfield import JSONField
|
||||
from utils.models import JSONField
|
||||
|
||||
from account.models import User
|
||||
from contest.models import Contest
|
||||
@ -65,8 +65,8 @@ class Problem(models.Model):
|
||||
total_score = models.IntegerField(default=0, blank=True)
|
||||
submission_number = models.BigIntegerField(default=0)
|
||||
accepted_number = models.BigIntegerField(default=0)
|
||||
# ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
|
||||
statistic_info = JSONField(default={})
|
||||
# {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
|
||||
statistic_info = JSONField(default=dict)
|
||||
|
||||
class Meta:
|
||||
db_table = "problem"
|
||||
|
@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
|
||||
class TagSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ProblemTag
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class BaseProblemSerializer(serializers.ModelSerializer):
|
||||
@ -88,11 +89,13 @@ class BaseProblemSerializer(serializers.ModelSerializer):
|
||||
class ProblemAdminSerializer(BaseProblemSerializer):
|
||||
class Meta:
|
||||
model = Problem
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ContestProblemAdminSerializer(BaseProblemSerializer):
|
||||
class Meta:
|
||||
model = Problem
|
||||
fields = "__all__"
|
||||
|
||||
|
||||
class ProblemSerializer(BaseProblemSerializer):
|
||||
|
@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
from datetime import timedelta
|
||||
@ -9,6 +10,7 @@ from django.conf import settings
|
||||
from utils.api.tests import APITestCase
|
||||
|
||||
from .models import ProblemTag
|
||||
from .models import Problem, ProblemRuleType
|
||||
from .views.admin import TestCaseUploadAPI
|
||||
from contest.models import Contest
|
||||
from contest.tests import DEFAULT_CONTEST_DATA
|
||||
@ -23,6 +25,40 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test
|
||||
"input_size": 0, "score": 0}],
|
||||
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
|
||||
|
||||
class ProblemCreateTestBase(APITestCase):
|
||||
@staticmethod
|
||||
def add_problem(problem_data, created_by):
|
||||
data = copy.deepcopy(problem_data)
|
||||
if data["spj"]:
|
||||
if not data["spj_language"] or not data["spj_code"]:
|
||||
raise ValueError("Invalid spj")
|
||||
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
|
||||
else:
|
||||
data["spj_language"] = None
|
||||
data["spj_code"] = None
|
||||
if data["rule_type"] == ProblemRuleType.OI:
|
||||
total_score = 0
|
||||
for item in data["test_case_score"]:
|
||||
if item["score"] <= 0:
|
||||
raise ValueError("invalid score")
|
||||
else:
|
||||
total_score += item["score"]
|
||||
data["total_score"] = total_score
|
||||
data["created_by"] = created_by
|
||||
tags = data.pop("tags")
|
||||
|
||||
data["languages"] = list(data["languages"])
|
||||
|
||||
problem = Problem.objects.create(**data)
|
||||
|
||||
for item in tags:
|
||||
try:
|
||||
tag = ProblemTag.objects.get(name=item)
|
||||
except ProblemTag.DoesNotExist:
|
||||
tag = ProblemTag.objects.create(name=item)
|
||||
problem.tags.add(tag)
|
||||
return problem
|
||||
|
||||
|
||||
class ProblemTagListAPITest(APITestCase):
|
||||
def test_get_tag_list(self):
|
||||
@ -96,7 +132,7 @@ class ProblemAdminAPITest(APITestCase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("problem_admin_api")
|
||||
self.create_super_admin()
|
||||
self.data = DEFAULT_PROBLEM_DATA
|
||||
self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
|
||||
|
||||
def test_create_problem(self):
|
||||
resp = self.client.post(self.url, data=self.data)
|
||||
@ -138,23 +174,19 @@ class ProblemAdminAPITest(APITestCase):
|
||||
self.assertSuccess(resp)
|
||||
|
||||
|
||||
class ProblemAPITest(APITestCase):
|
||||
class ProblemAPITest(ProblemCreateTestBase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("problem_api")
|
||||
self.create_admin()
|
||||
|
||||
def create_problem(self):
|
||||
url = self.reverse("problem_admin_api")
|
||||
return self.client.post(url, data=DEFAULT_PROBLEM_DATA)
|
||||
admin = self.create_admin(login=False)
|
||||
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
|
||||
self.create_user("test", "test123")
|
||||
|
||||
def test_get_problem_list(self):
|
||||
self.create_problem()
|
||||
resp = self.client.get(f"{self.url}?limit=10")
|
||||
self.assertSuccess(resp)
|
||||
|
||||
def get_one_problem(self):
|
||||
problem_id = self.create_problem().data["data"]["_id"]
|
||||
resp = self.client.get(self.url + "?id=" + str(problem_id))
|
||||
resp = self.client.get(self.url + "?id=" + self.problem._id)
|
||||
self.assertSuccess(resp)
|
||||
|
||||
|
||||
@ -169,51 +201,49 @@ class ContestProblemAdminTest(APITestCase):
|
||||
|
||||
def test_create_contest_problem(self):
|
||||
contest = self.create_contest()
|
||||
data = DEFAULT_PROBLEM_DATA
|
||||
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
|
||||
data["contest_id"] = contest.data["data"]["id"]
|
||||
resp = self.client.post(self.url, data=data)
|
||||
self.assertSuccess(resp)
|
||||
return resp
|
||||
return contest, resp
|
||||
|
||||
def test_get_contest_problem(self):
|
||||
contest = self.test_create_contest_problem()
|
||||
contest, contest_problem = self.test_create_contest_problem()
|
||||
contest_id = contest.data["data"]["id"]
|
||||
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(len(resp.data["data"]), 1)
|
||||
|
||||
def test_get_one_contest_problem(self):
|
||||
contest = self.test_create_contest_problem()
|
||||
contest, contest_problem = self.test_create_contest_problem()
|
||||
contest_id = contest.data["data"]["id"]
|
||||
resp = self.client.get(self.url + "?id=" + str(contest_id))
|
||||
problem_id = contest_problem.data["data"]["id"]
|
||||
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
|
||||
self.assertSuccess(resp)
|
||||
|
||||
|
||||
class ContestProblemTest(APITestCase):
|
||||
class ContestProblemTest(ProblemCreateTestBase):
|
||||
def setUp(self):
|
||||
self.url = self.reverse("contest_problem_api")
|
||||
self.create_admin()
|
||||
|
||||
admin = self.create_admin()
|
||||
url = self.reverse("contest_admin_api")
|
||||
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
|
||||
contest_data["password"] = ""
|
||||
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
|
||||
self.contest = self.client.post(url, data=contest_data).data["data"]
|
||||
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
|
||||
self.problem.contest_id = self.contest["id"]
|
||||
self.problem.save()
|
||||
self.url = self.reverse("contest_problem_api")
|
||||
|
||||
problem_data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
|
||||
problem_data["contest"] = self.contest["id"]
|
||||
url = self.reverse("contest_problem_admin_api")
|
||||
self.problem = self.client.post(url, problem_data).data["data"]
|
||||
|
||||
def test_get_contest_problem_list(self):
|
||||
def test_admin_get_contest_problem_list(self):
|
||||
contest_id = self.contest["id"]
|
||||
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
|
||||
self.assertSuccess(resp)
|
||||
self.assertEqual(len(resp.data["data"]), 1)
|
||||
|
||||
def test_get_one_contest_problem(self):
|
||||
def test_admin_get_one_contest_problem(self):
|
||||
contest_id = self.contest["id"]
|
||||
problem_id = self.problem["_id"]
|
||||
problem_id = self.problem._id
|
||||
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
|
||||
self.assertSuccess(resp)
|
||||
|
||||
|
@ -223,6 +223,7 @@ class ProblemAPI(APIView):
|
||||
data["total_score"] = total_score
|
||||
# todo check filename and score info
|
||||
tags = data.pop("tags")
|
||||
data["languages"] = list(data["languages"])
|
||||
|
||||
for k, v in data.items():
|
||||
setattr(problem, k, v)
|
||||
|
@ -13,6 +13,24 @@ class ProblemTagAPI(APIView):
|
||||
|
||||
|
||||
class ProblemAPI(APIView):
|
||||
@staticmethod
|
||||
def _add_problem_status(request, queryset_values):
|
||||
if request.user.is_authenticated():
|
||||
profile = request.user.userprofile
|
||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||
oi_problems_status = profile.oi_problems_status.get("problems", {})
|
||||
# paginate data
|
||||
results = queryset_values.get("results")
|
||||
if results is not None:
|
||||
problems = results
|
||||
else:
|
||||
problems = [queryset_values,]
|
||||
for problem in problems:
|
||||
if problem["rule_type"] == ProblemRuleType.ACM:
|
||||
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
|
||||
else:
|
||||
problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status")
|
||||
|
||||
def get(self, request):
|
||||
# 问题详情页
|
||||
problem_id = request.GET.get("problem_id")
|
||||
@ -20,7 +38,9 @@ class ProblemAPI(APIView):
|
||||
try:
|
||||
problem = Problem.objects.select_related("created_by")\
|
||||
.get(_id=problem_id, contest_id__isnull=True, visible=True)
|
||||
return self.success(ProblemSerializer(problem).data)
|
||||
problem_data = ProblemSerializer(problem).data
|
||||
self._add_problem_status(request, problem_data)
|
||||
return self.success(problem_data)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem does not exist")
|
||||
|
||||
@ -32,11 +52,7 @@ class ProblemAPI(APIView):
|
||||
# 按照标签筛选
|
||||
tag_text = request.GET.get("tag")
|
||||
if tag_text:
|
||||
try:
|
||||
tag = ProblemTag.objects.get(name=tag_text)
|
||||
except ProblemTag.DoesNotExist:
|
||||
return self.error("The Tag does not exist.")
|
||||
problems = tag.problem_set.all().filter(visible=True)
|
||||
problems = problems.filter(tags__name=tag_text)
|
||||
|
||||
# 搜索的情况
|
||||
keyword = request.GET.get("keyword", "").strip()
|
||||
@ -49,19 +65,23 @@ class ProblemAPI(APIView):
|
||||
problems = problems.filter(difficulty=difficulty)
|
||||
# 根据profile 为做过的题目添加标记
|
||||
data = self.paginate_data(request, problems, ProblemSerializer)
|
||||
if request.user.id:
|
||||
profile = request.user.userprofile
|
||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||
oi_problems_status = profile.oi_problems_status.get("problems", {})
|
||||
for problem in data["results"]:
|
||||
if problem["rule_type"] == ProblemRuleType.ACM:
|
||||
problem["my_status"] = acm_problems_status.get(str(problem["id"]), None)
|
||||
else:
|
||||
problem["my_status"] = oi_problems_status.get(str(problem["id"]), None)
|
||||
self._add_problem_status(request, data)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class ContestProblemAPI(APIView):
|
||||
def _add_problem_status(self, request, queryset_values):
|
||||
if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user):
|
||||
return
|
||||
if request.user.is_authenticated():
|
||||
profile = request.user.userprofile
|
||||
if self.contest.rule_type == ContestRuleType.ACM:
|
||||
problems_status = profile.acm_problems_status.get("contest_problems", {})
|
||||
else:
|
||||
problems_status = profile.oi_problems_status.get("contest_problems", {})
|
||||
for problem in queryset_values:
|
||||
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
|
||||
|
||||
@check_contest_permission
|
||||
def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
@ -72,17 +92,12 @@ class ContestProblemAPI(APIView):
|
||||
visible=True)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem does not exist.")
|
||||
return self.success(ContestProblemSerializer(problem).data)
|
||||
problem_data = ContestProblemSerializer(problem).data
|
||||
self._add_problem_status(request, [problem_data,])
|
||||
return self.success(problem_data)
|
||||
|
||||
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
|
||||
# 根据profile, 为做过的题目添加标记
|
||||
data = ContestProblemSerializer(contest_problems, many=True).data
|
||||
if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI:
|
||||
profile = request.user.userprofile
|
||||
if self.contest.rule_type == ContestRuleType.ACM:
|
||||
problems_status = profile.acm_problems_status.get("contest_problems", {})
|
||||
else:
|
||||
problems_status = profile.oi_problems_status.get("contest_problems", {})
|
||||
for problem in data:
|
||||
problem["my_status"] = problems_status.get(str(problem["id"]), None)
|
||||
self._add_problem_status(request, data)
|
||||
return self.success(data)
|
||||
|
26
submission/migrations/0008_auto_20171011_1214.py
Normal file
26
submission/migrations/0008_auto_20171011_1214.py
Normal file
@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by Django 1.11.4 on 2017-10-11 12:14
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import django.contrib.postgres.fields.jsonb
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('submission', '0007_auto_20170923_1318'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='submission',
|
||||
name='info',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='submission',
|
||||
name='statistic_info',
|
||||
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
|
||||
),
|
||||
]
|
@ -1,5 +1,5 @@
|
||||
from django.db import models
|
||||
from jsonfield import JSONField
|
||||
from utils.models import JSONField
|
||||
from account.models import AdminType
|
||||
from problem.models import Problem
|
||||
from contest.models import Contest
|
||||
@ -30,18 +30,20 @@ class Submission(models.Model):
|
||||
username = models.CharField(max_length=30)
|
||||
code = models.TextField()
|
||||
result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING)
|
||||
# 判题结果的详细信息
|
||||
info = JSONField(default={})
|
||||
# 从JudgeServer返回的判题详情
|
||||
info = JSONField(default=dict)
|
||||
language = models.CharField(max_length=20)
|
||||
shared = models.BooleanField(default=False)
|
||||
# 存储该提交所用时间和内存值,方便提交列表显示
|
||||
# {time_cost: "", memory_cost: "", err_info: "", score: 0}
|
||||
statistic_info = JSONField(default={})
|
||||
statistic_info = JSONField(default=dict)
|
||||
|
||||
def check_user_permission(self, user):
|
||||
def check_user_permission(self, user, check_share=True):
|
||||
return self.user_id == user.id or \
|
||||
self.shared is True or \
|
||||
user.admin_type == AdminType.SUPER_ADMIN
|
||||
(check_share and self.shared is True) or \
|
||||
user.is_super_admin() or \
|
||||
user.can_mgmt_all_problem() or \
|
||||
self.problem.created_by_id == user.id
|
||||
|
||||
class Meta:
|
||||
db_table = "submission"
|
||||
|
@ -10,6 +10,11 @@ class CreateSubmissionSerializer(serializers.Serializer):
|
||||
contest_id = serializers.IntegerField(required=False)
|
||||
|
||||
|
||||
class ShareSubmissionSerializer(serializers.Serializer):
|
||||
id = serializers.CharField()
|
||||
shared = serializers.BooleanField()
|
||||
|
||||
|
||||
class SubmissionModelSerializer(serializers.ModelSerializer):
|
||||
info = serializers.JSONField()
|
||||
statistic_info = serializers.JSONField()
|
||||
@ -19,7 +24,7 @@ class SubmissionModelSerializer(serializers.ModelSerializer):
|
||||
|
||||
|
||||
# 不显示submission info的serializer, 用于ACM rule_type
|
||||
class SubmissionSafeSerializer(serializers.ModelSerializer):
|
||||
class SubmissionSafeModelSerializer(serializers.ModelSerializer):
|
||||
problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
|
||||
statistic_info = serializers.JSONField()
|
||||
|
||||
@ -43,6 +48,6 @@ class SubmissionListSerializer(serializers.ModelSerializer):
|
||||
|
||||
def get_show_link(self, obj):
|
||||
# 没传user或为匿名user
|
||||
if self.user is None or self.user.id is None:
|
||||
if self.user is None or not self.user.is_authenticated():
|
||||
return False
|
||||
return obj.check_user_permission(self.user)
|
||||
|
@ -32,14 +32,16 @@ class SubmissionPrepare(APITestCase):
|
||||
def _create_problem_and_submission(self):
|
||||
user = self.create_admin("test", "test123", login=False)
|
||||
problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
|
||||
problem_data.pop("tags")
|
||||
tags = problem_data.pop("tags")
|
||||
problem_data["created_by"] = user
|
||||
self.problem = Problem.objects.create(**problem_data)
|
||||
for tag in DEFAULT_PROBLEM_DATA["tags"]:
|
||||
for tag in tags:
|
||||
tag = ProblemTag.objects.create(name=tag)
|
||||
self.problem.tags.add(tag)
|
||||
self.problem.save()
|
||||
self.submission = Submission.objects.create(**DEFAULT_SUBMISSION_DATA)
|
||||
self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
|
||||
self.submission_data["problem_id"] = self.problem.id
|
||||
self.submission = Submission.objects.create(**self.submission_data)
|
||||
|
||||
|
||||
class SubmissionListTest(SubmissionPrepare):
|
||||
@ -61,6 +63,6 @@ class SubmissionAPITest(SubmissionPrepare):
|
||||
self.url = self.reverse("submission_api")
|
||||
|
||||
def test_create_submission(self, judge_task):
|
||||
resp = self.client.post(self.url, DEFAULT_SUBMISSION_DATA)
|
||||
resp = self.client.post(self.url, self.submission_data)
|
||||
self.assertSuccess(resp)
|
||||
judge_task.assert_called()
|
||||
|
@ -5,16 +5,17 @@ from problem.models import Problem, ProblemRuleType
|
||||
from contest.models import Contest, ContestStatus, ContestRuleType
|
||||
from utils.api import APIView, validate_serializer
|
||||
from utils.throttling import TokenBucket, BucketController
|
||||
from utils.cache import cache
|
||||
from ..models import Submission
|
||||
from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer
|
||||
from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer
|
||||
from utils.cache import throttling_cache
|
||||
from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
|
||||
ShareSubmissionSerializer)
|
||||
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
|
||||
|
||||
|
||||
def _submit(response, user, problem_id, language, code, contest_id):
|
||||
# TODO: 预设默认值,需修改
|
||||
controller = BucketController(user_id=user.id,
|
||||
redis_conn=throttling_cache,
|
||||
redis_conn=cache,
|
||||
default_capacity=30)
|
||||
bucket = TokenBucket(fill_rate=10, capacity=20,
|
||||
last_capacity=controller.last_capacity,
|
||||
@ -63,17 +64,36 @@ class SubmissionAPI(APIView):
|
||||
def get(self, request):
|
||||
submission_id = request.GET.get("id")
|
||||
if not submission_id:
|
||||
return self.error("Parameter id doesn't exist.")
|
||||
return self.error("Parameter id doesn't exist")
|
||||
try:
|
||||
submission = Submission.objects.select_related("problem").get(id=submission_id)
|
||||
except Submission.DoesNotExist:
|
||||
return self.error("Submission doesn't exist.")
|
||||
return self.error("Submission doesn't exist")
|
||||
if not submission.check_user_permission(request.user):
|
||||
return self.error("No permission for this submission.")
|
||||
return self.error("No permission for this submission")
|
||||
|
||||
if submission.problem.rule_type == ProblemRuleType.ACM:
|
||||
return self.success(SubmissionSafeSerializer(submission).data)
|
||||
return self.success(SubmissionModelSerializer(submission).data)
|
||||
submission_data = SubmissionSafeModelSerializer(submission).data
|
||||
else:
|
||||
submission_data = SubmissionModelSerializer(submission).data
|
||||
# 是否有权限取消共享
|
||||
submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False)
|
||||
return self.success(submission_data)
|
||||
|
||||
@validate_serializer(ShareSubmissionSerializer)
|
||||
@login_required
|
||||
def put(self, request):
|
||||
try:
|
||||
submission = Submission.objects.select_related("problem").get(id=request.data["id"])
|
||||
except Submission.DoesNotExist:
|
||||
return self.error("Submission doesn't exist")
|
||||
if not submission.check_user_permission(request.user, check_share=False):
|
||||
return self.error("No permission to share the submission")
|
||||
if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY:
|
||||
return self.error("Can not share submission now")
|
||||
submission.shared = request.data["shared"]
|
||||
submission.save(update_fields=["shared"])
|
||||
return self.success()
|
||||
|
||||
|
||||
class SubmissionListAPI(APIView):
|
||||
@ -83,7 +103,7 @@ class SubmissionListAPI(APIView):
|
||||
if request.GET.get("contest_id"):
|
||||
return self.error("Parameter error")
|
||||
|
||||
submissions = Submission.objects.filter(contest_id__isnull=True)
|
||||
submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by")
|
||||
problem_id = request.GET.get("problem_id")
|
||||
myself = request.GET.get("myself")
|
||||
result = request.GET.get("result")
|
||||
@ -109,10 +129,10 @@ class ContestSubmissionListAPI(APIView):
|
||||
return self.error("Limit is needed")
|
||||
|
||||
contest = self.contest
|
||||
if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user):
|
||||
if not contest.check_oi_permission(request.user):
|
||||
return self.error("No permission for OI contest submissions")
|
||||
|
||||
submissions = Submission.objects.filter(contest_id=contest.id)
|
||||
submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
|
||||
problem_id = request.GET.get("problem_id")
|
||||
myself = request.GET.get("myself")
|
||||
result = request.GET.get("result")
|
||||
@ -133,9 +153,11 @@ class ContestSubmissionListAPI(APIView):
|
||||
submissions = submissions.filter(create_time__gte=contest.start_time)
|
||||
|
||||
# 封榜的时候只能看到自己的提交
|
||||
if not contest.real_time_rank and not contest.is_contest_admin(request.user):
|
||||
submissions = submissions.filter(user_id=request.user.id)
|
||||
if contest.rule_type == ContestRuleType.ACM:
|
||||
if not contest.real_time_rank and not contest.is_contest_admin(request.user):
|
||||
submissions = submissions.filter(user_id=request.user.id)
|
||||
|
||||
data = self.paginate_data(request, submissions)
|
||||
data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data
|
||||
return self.success(data)
|
||||
|
||||
|
@ -1,2 +1,2 @@
|
||||
from .api import * # NOQA
|
||||
from ._serializers import * # NOQA
|
||||
from .api import * # NOQA
|
||||
|
@ -79,7 +79,7 @@ class APIView(View):
|
||||
def success(self, data=None):
|
||||
return self.response({"error": None, "data": data})
|
||||
|
||||
def error(self, msg, err="error"):
|
||||
def error(self, msg="error", err="error"):
|
||||
return self.response({"error": err, "data": msg})
|
||||
|
||||
def _serializer_error_to_str(self, errors):
|
||||
|
@ -3,7 +3,6 @@ from django.test.testcases import TestCase
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from account.models import AdminType, ProblemPermission, User, UserProfile
|
||||
from conf.models import WebsiteConfig
|
||||
|
||||
|
||||
class APITestCase(TestCase):
|
||||
@ -28,9 +27,6 @@ class APITestCase(TestCase):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
|
||||
problem_permission=ProblemPermission.ALL, login=login)
|
||||
|
||||
def create_website_config(self):
|
||||
return WebsiteConfig.objects.create()
|
||||
|
||||
def reverse(self, url_name):
|
||||
return reverse(url_name)
|
||||
|
||||
|
@ -1,6 +1,27 @@
|
||||
from django.conf import settings
|
||||
from django_redis import get_redis_connection
|
||||
from django.core.cache import cache, caches # noqa
|
||||
from django.conf import settings # noqa
|
||||
|
||||
judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE)
|
||||
throttling_cache = get_redis_connection(settings.CACHE_THROTTLING)
|
||||
default_cache = get_redis_connection("default")
|
||||
from django_redis.cache import RedisCache
|
||||
from django_redis.client.default import DefaultClient
|
||||
|
||||
|
||||
class MyRedisClient(DefaultClient):
|
||||
def __getattr__(self, item):
|
||||
client = self.get_client(write=True)
|
||||
return getattr(client, item)
|
||||
|
||||
def redis_incr(self, key, count=1):
|
||||
"""
|
||||
django 默认的 incr 在 key 不存在时候会抛异常
|
||||
"""
|
||||
client = self.get_client(write=True)
|
||||
return client.incr(key, count)
|
||||
|
||||
|
||||
class MyRedisCache(RedisCache):
|
||||
def __init__(self, server, params):
|
||||
super().__init__(server, params)
|
||||
self._client_cls = MyRedisClient
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.client, item)
|
||||
|
@ -1,4 +1,28 @@
|
||||
class Choices:
|
||||
@classmethod
|
||||
def choices(cls):
|
||||
d = cls.__dict__
|
||||
return [d[item] for item in d.keys() if not item.startswith("__")]
|
||||
|
||||
|
||||
class ContestType:
|
||||
PUBLIC_CONTEST = "Public"
|
||||
PASSWORD_PROTECTED_CONTEST = "Password Protected"
|
||||
|
||||
|
||||
class ContestStatus:
|
||||
CONTEST_NOT_START = "1"
|
||||
CONTEST_ENDED = "-1"
|
||||
CONTEST_UNDERWAY = "0"
|
||||
|
||||
|
||||
class ContestRuleType(Choices):
|
||||
ACM = "ACM"
|
||||
OI = "OI"
|
||||
|
||||
|
||||
class CacheKey:
|
||||
waiting_queue = "waiting_queue"
|
||||
contest_rank_cache = "contest_rank_cache_"
|
||||
contest_rank_cache = "contest_rank_cache"
|
||||
website_config = "website_config"
|
||||
option = "option"
|
||||
|
@ -1,3 +1,4 @@
|
||||
from django.contrib.postgres.fields import JSONField # NOQA
|
||||
from django.db import models
|
||||
|
||||
from utils.xss_filter import XssHtml
|
||||
|
@ -1,35 +1,9 @@
|
||||
import logging
|
||||
import random
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
import random
|
||||
from base64 import b64encode
|
||||
from io import BytesIO
|
||||
|
||||
from django.utils.crypto import get_random_string
|
||||
from envelopes import Envelope
|
||||
|
||||
from conf.models import SMTPConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def send_email(from_name, to_email, to_name, subject, content):
|
||||
smtp = SMTPConfig.objects.first()
|
||||
if not smtp:
|
||||
return
|
||||
envlope = Envelope(from_addr=(smtp.email, from_name),
|
||||
to_addr=(to_email, to_name),
|
||||
subject=subject,
|
||||
html_body=content)
|
||||
try:
|
||||
envlope.send(smtp.server,
|
||||
login=smtp.email,
|
||||
password=smtp.password,
|
||||
port=smtp.port,
|
||||
tls=smtp.tls)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False
|
||||
|
||||
|
||||
def rand_str(length=32, type="lower_hex"):
|
||||
|
@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7
|
||||
浏览器版本:IE7+ 或其他浏览器,无法防御IE6及以下版本浏览器中的XSS
|
||||
"""
|
||||
import re
|
||||
|
||||
try:
|
||||
from html.parser import HTMLParser
|
||||
except:
|
||||
from HTMLParser import HTMLParser
|
||||
import copy
|
||||
from html.parser import HTMLParser
|
||||
|
||||
|
||||
class XssHtml(HTMLParser):
|
||||
@ -163,7 +160,7 @@ class XssHtml(HTMLParser):
|
||||
else:
|
||||
other = []
|
||||
if attrs:
|
||||
for (key, value) in attrs.items():
|
||||
for key, value in copy.deepcopy(attrs).items():
|
||||
if key not in self.common_attrs + other:
|
||||
del attrs[key]
|
||||
return attrs
|
||||
|
Loading…
Reference in New Issue
Block a user