Merge branch 'opt'

This commit is contained in:
zema1 2017-10-21 10:51:59 +08:00
commit c1d099ed45
52 changed files with 1011 additions and 589 deletions

View File

@ -92,7 +92,8 @@ def check_contest_permission(func):
if not user.is_authenticated(): if not user.is_authenticated():
return self.error("Please login in first.") return self.error("Please login in first.")
# password error # password error
if ("contests" not in request.session) or (self.contest.id not in request.session["contests"]): if ("accessible_contests" not in request.session) or \
(self.contest.id not in request.session["accessible_contests"]):
return self.error("Password is required.") return self.error("Password is required.")
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@ -1,38 +1,19 @@
import time
import pytz
from django.contrib import auth
from django.utils import timezone
from django.utils.translation import ugettext as _
from django.db import connection from django.db import connection
from django.utils.timezone import now
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
from utils.api import JSONResponse from utils.api import JSONResponse
class SessionSecurityMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
if "last_activity" in request.session and request.user.is_admin_role():
# 24 hours passed since last visit, 86400 = 24 * 60 * 60
if time.time() - request.session["last_activity"] >= 86400:
auth.logout(request)
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
request.session["last_activity"] = time.time()
class SessionRecordMiddleware(MiddlewareMixin): class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
if request.user.is_authenticated(): if request.user.is_authenticated():
session = request.session session = request.session
ip = request.META.get("REMOTE_ADDR", "") session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
user_agent = request.META.get("HTTP_USER_AGENT", "") session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP")
_ip = session.setdefault("ip", ip) session["last_activity"] = now()
_user_agent = session.setdefault("user_agent", user_agent)
if ip != _ip or user_agent != _user_agent:
session.modified = True
user_sessions = request.user.session_keys user_sessions = request.user.session_keys
if request.session.session_key not in user_sessions: if session.session_key not in user_sessions:
user_sessions.append(session.session_key) user_sessions.append(session.session_key)
request.user.save() request.user.save()
@ -42,13 +23,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin):
path = request.path_info path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"): if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not (request.user.is_authenticated() and request.user.is_admin_role()): if not (request.user.is_authenticated() and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) return JSONResponse.response({"error": "login-required", "data": "Please login in first"})
class TimezoneMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
timezone.activate(pytz.timezone(request.user.userprofile.time_zone))
class LogSqlMiddleware(MiddlewareMixin): class LogSqlMiddleware(MiddlewareMixin):

View File

@ -50,7 +50,7 @@ class Migration(migrations.Migration):
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('problems_status', jsonfield.fields.JSONField(default={})), ('problems_status', jsonfield.fields.JSONField(default={})),
('avatar', models.CharField(default=account.models._default_avatar, max_length=50)), ('avatar', models.CharField(default="default.png", max_length=50)),
('blog', models.URLField(blank=True, null=True)), ('blog', models.URLField(blank=True, null=True)),
('mood', models.CharField(blank=True, max_length=200, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)),
('accepted_problem_number', models.IntegerField(default=0)), ('accepted_problem_number', models.IntegerField(default=0)),

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0007_auto_20170920_0254'),
]
operations = [
migrations.RemoveField(
model_name='userprofile',
name='language',
),
migrations.AlterField(
model_name='user',
name='admin_type',
field=models.CharField(default='Regular User', max_length=32),
),
migrations.AlterField(
model_name='user',
name='auth_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='email',
field=models.EmailField(max_length=64, null=True),
),
migrations.AlterField(
model_name='user',
name='open_api_appkey',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='problem_permission',
field=models.CharField(default='None', max_length=32),
),
migrations.AlterField(
model_name='user',
name='reset_password_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='session_keys',
field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
),
migrations.AlterField(
model_name='user',
name='tfa_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='username',
field=models.CharField(max_length=32, unique=True),
),
migrations.AlterField(
model_name='userprofile',
name='acm_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='avatar',
field=models.CharField(default='/static/avatar/default.png', max_length=256),
),
migrations.AlterField(
model_name='userprofile',
name='github',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='major',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='mood',
field=models.CharField(blank=True, max_length=256, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='oi_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='real_name',
field=models.CharField(blank=True, max_length=32, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='school',
field=models.CharField(blank=True, max_length=64, null=True),
),
]

View File

@ -1,7 +1,7 @@
from django.contrib.auth.models import AbstractBaseUser from django.contrib.auth.models import AbstractBaseUser
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from jsonfield import JSONField from utils.models import JSONField
class AdminType(object): class AdminType(object):
@ -24,22 +24,22 @@ class UserManager(models.Manager):
class User(AbstractBaseUser): class User(AbstractBaseUser):
username = models.CharField(max_length=30, unique=True) username = models.CharField(max_length=32, unique=True)
email = models.EmailField(max_length=254, null=True) email = models.EmailField(max_length=64, null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True) create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType # One of UserType
admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER) admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER)
problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE) problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE)
reset_password_token = models.CharField(max_length=40, null=True) reset_password_token = models.CharField(max_length=32, null=True)
reset_password_token_expire_time = models.DateTimeField(null=True) reset_password_token_expire_time = models.DateTimeField(null=True)
# SSO auth token # SSO auth token
auth_token = models.CharField(max_length=40, null=True) auth_token = models.CharField(max_length=32, null=True)
two_factor_auth = models.BooleanField(default=False) two_factor_auth = models.BooleanField(default=False)
tfa_token = models.CharField(max_length=40, null=True) tfa_token = models.CharField(max_length=32, null=True)
session_keys = JSONField(default=[]) session_keys = JSONField(default=list)
# open api key # open api key
open_api = models.BooleanField(default=False) open_api = models.BooleanField(default=False)
open_api_appkey = models.CharField(max_length=35, null=True) open_api_appkey = models.CharField(max_length=32, null=True)
is_disabled = models.BooleanField(default=False) is_disabled = models.BooleanField(default=False)
USERNAME_FIELD = "username" USERNAME_FIELD = "username"
@ -63,26 +63,34 @@ class User(AbstractBaseUser):
db_table = "user" db_table = "user"
def _default_avatar():
return f"/{settings.IMAGE_UPLOAD_DIR}/default.png"
class UserProfile(models.Model): class UserProfile(models.Model):
user = models.OneToOneField(User) user = models.OneToOneField(User)
# Store user problem solution status with json string format # acm_problems_status examples:
# {problems: {1: JudgeStatus.ACCEPTED}, contest_problems: {1: JudgeStatus.ACCEPTED}}, record problem_id and status # {
acm_problems_status = JSONField(default={}) # "problems": {
# {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score # "1": {
oi_problems_status = JSONField(default={}) # "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# },
# "contest_problems": {
# "1": {
# "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# }
# }
acm_problems_status = JSONField(default=dict)
# like acm_problems_status, merely add "score" field
oi_problems_status = JSONField(default=dict)
real_name = models.CharField(max_length=30, blank=True, null=True) real_name = models.CharField(max_length=32, blank=True, null=True)
avatar = models.CharField(max_length=50, default=_default_avatar()) avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png")
blog = models.URLField(blank=True, null=True) blog = models.URLField(blank=True, null=True)
mood = models.CharField(max_length=200, blank=True, null=True) mood = models.CharField(max_length=256, blank=True, null=True)
github = models.CharField(max_length=50, blank=True, null=True) github = models.CharField(max_length=64, blank=True, null=True)
school = models.CharField(max_length=200, blank=True, null=True) school = models.CharField(max_length=64, blank=True, null=True)
major = models.CharField(max_length=200, blank=True, null=True) major = models.CharField(max_length=64, blank=True, null=True)
language = models.CharField(max_length=32, blank=True, null=True)
# for ACM # for ACM
accepted_number = models.IntegerField(default=0) accepted_number = models.IntegerField(default=0)
# for OI # for OI

View File

@ -6,27 +6,26 @@ from .models import AdminType, ProblemPermission, User, UserProfile
class UserLoginSerializer(serializers.Serializer): class UserLoginSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30) username = serializers.CharField()
password = serializers.CharField(max_length=30) password = serializers.CharField()
tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True) tfa_code = serializers.CharField(required=False, allow_null=True)
class UsernameOrEmailCheckSerializer(serializers.Serializer): class UsernameOrEmailCheckSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30, required=False) username = serializers.CharField(required=False)
email = serializers.EmailField(max_length=30, required=False) email = serializers.EmailField(required=False)
class UserRegisterSerializer(serializers.Serializer): class UserRegisterSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30) username = serializers.CharField(max_length=32)
password = serializers.CharField(max_length=30, min_length=6) password = serializers.CharField(min_length=6)
email = serializers.EmailField(max_length=30) email = serializers.EmailField(max_length=64)
captcha = serializers.CharField(max_length=4, min_length=1) captcha = serializers.CharField()
class UserChangePasswordSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField() old_password = serializers.CharField()
new_password = serializers.CharField(max_length=30, min_length=6) new_password = serializers.CharField(min_length=6)
captcha = serializers.CharField(max_length=4, min_length=4)
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = UserProfile model = UserProfile
fields = "__all__"
class UserInfoSerializer(serializers.ModelSerializer): class UserInfoSerializer(serializers.ModelSerializer):
@ -58,9 +58,9 @@ class UserInfoSerializer(serializers.ModelSerializer):
class EditUserSerializer(serializers.Serializer): class EditUserSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
username = serializers.CharField(max_length=30) username = serializers.CharField(max_length=32)
password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None) password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None)
email = serializers.EmailField(max_length=254) email = serializers.EmailField(max_length=64)
admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN)) admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN))
problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN, problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN,
ProblemPermission.ALL)) ProblemPermission.ALL))
@ -70,29 +70,29 @@ class EditUserSerializer(serializers.Serializer):
class EditUserProfileSerializer(serializers.Serializer): class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=30, allow_blank=True) real_name = serializers.CharField(max_length=32, allow_blank=True)
avatar = serializers.CharField(max_length=100, allow_blank=True, required=False) avatar = serializers.CharField(max_length=256, allow_blank=True, required=False)
blog = serializers.URLField(allow_blank=True, required=False) blog = serializers.URLField(max_length=256, allow_blank=True, required=False)
mood = serializers.CharField(max_length=200, allow_blank=True, required=False) mood = serializers.CharField(max_length=256, allow_blank=True, required=False)
github = serializers.CharField(max_length=50, allow_blank=True, required=False) github = serializers.CharField(max_length=64, allow_blank=True, required=False)
school = serializers.CharField(max_length=200, allow_blank=True, required=False) school = serializers.CharField(max_length=64, allow_blank=True, required=False)
major = serializers.CharField(max_length=200, allow_blank=True, required=False) major = serializers.CharField(max_length=64, allow_blank=True, required=False)
class ApplyResetPasswordSerializer(serializers.Serializer): class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
captcha = serializers.CharField(max_length=4, min_length=4) captcha = serializers.CharField()
class ResetPasswordSerializer(serializers.Serializer): class ResetPasswordSerializer(serializers.Serializer):
token = serializers.CharField(min_length=1, max_length=40) token = serializers.CharField()
password = serializers.CharField(min_length=6, max_length=30) password = serializers.CharField(min_length=6)
captcha = serializers.CharField(max_length=4, min_length=4) captcha = serializers.CharField()
class SSOSerializer(serializers.Serializer): class SSOSerializer(serializers.Serializer):
appkey = serializers.CharField(max_length=35) appkey = serializers.CharField()
token = serializers.CharField(max_length=40) token = serializers.CharField()
class TwoFactorAuthCodeSerializer(serializers.Serializer): class TwoFactorAuthCodeSerializer(serializers.Serializer):

View File

@ -1,6 +1,31 @@
from celery import shared_task import logging
from utils.shortcuts import send_email from celery import shared_task
from envelopes import Envelope
from options.options import SysOptions
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SysOptions.smtp_config
if not smtp:
return
envlope = Envelope(from_addr=(smtp["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp["server"],
login=smtp["email"],
password=smtp["password"],
port=smtp["port"],
tls=smtp["tls"])
return True
except Exception as e:
logger.exception(e)
return False
@shared_task @shared_task

View File

@ -8,11 +8,10 @@ from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from utils.cache import default_cache from options.options import SysOptions
from utils.constants import CacheKey
from .models import AdminType, ProblemPermission, User from .models import AdminType, ProblemPermission, User
from conf.models import WebsiteConfig from utils.constants import ContestRuleType
class PermissionDecoratorTest(APITestCase): class PermissionDecoratorTest(APITestCase):
@ -136,7 +135,7 @@ class UserLoginAPITest(APITestCase):
self.user.save() self.user.save()
resp = self.client.post(self.login_url, data={"username": self.username, resp = self.client.post(self.login_url, data={"username": self.username,
"password": self.password}) "password": self.password})
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account have been disabled"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"})
class CaptchaTest(APITestCase): class CaptchaTest(APITestCase):
@ -157,15 +156,11 @@ class UserRegisterAPITest(CaptchaTest):
self.data = {"username": "test_user", "password": "testuserpassword", self.data = {"username": "test_user", "password": "testuserpassword",
"real_name": "real_name", "email": "test@qduoj.com", "real_name": "real_name", "email": "test@qduoj.com",
"captcha": self._set_captcha(self.client.session)} "captcha": self._set_captcha(self.client.session)}
# clea cache in redis
default_cache.delete(CacheKey.website_config)
def test_website_config_limit(self): def test_website_config_limit(self):
website = WebsiteConfig.objects.create() SysOptions.allow_register = False
website.allow_register = False
website.save()
resp = self.client.post(self.register_url, data=self.data) resp = self.client.post(self.register_url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"})
def test_invalid_captcha(self): def test_invalid_captcha(self):
self.data["captcha"] = "****" self.data["captcha"] = "****"
@ -226,7 +221,7 @@ class UserProfileAPITest(APITestCase):
def test_get_profile_without_login(self): def test_get_profile_without_login(self):
resp = self.client.get(self.url) resp = self.client.get(self.url)
self.assertDictEqual(resp.data, {"error": None, "data": {}}) self.assertDictEqual(resp.data, {"error": None, "data": None})
def test_get_profile(self): def test_get_profile(self):
self.create_user("test", "test123") self.create_user("test", "test123")
@ -247,7 +242,6 @@ class TwoFactorAuthAPITest(APITestCase):
def setUp(self): def setUp(self):
self.url = self.reverse("two_factor_auth_api") self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123") self.create_user("test", "test123")
self.create_website_config()
def _get_tfa_code(self): def _get_tfa_code(self):
user = User.objects.first() user = User.objects.first()
@ -295,7 +289,6 @@ class ApplyResetPasswordAPITest(CaptchaTest):
user.email = "test@oj.com" user.email = "test@oj.com"
user.save() user.save()
self.url = self.reverse("apply_reset_password_api") self.url = self.reverse("apply_reset_password_api")
self.create_website_config()
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self): def _refresh_captcha(self):
@ -343,14 +336,14 @@ class ResetPasswordAPITest(CaptchaTest):
def test_reset_password_with_invalid_token(self): def test_reset_password_with_invalid_token(self):
self.data["token"] = "aaaaaaaaaaa" self.data["token"] = "aaaaaaaaaaa"
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"})
def test_reset_password_with_expired_token(self): def test_reset_password_with_expired_token(self):
user = User.objects.first() user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(seconds=30) user.reset_password_token_expire_time = now() - timedelta(seconds=30)
user.save() user.save()
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"})
class UserChangePasswordAPITest(CaptchaTest): class UserChangePasswordAPITest(CaptchaTest):
@ -481,14 +474,14 @@ class UserRankAPITest(APITestCase):
profile2.save() profile2.save()
def test_get_acm_rank(self): def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": "acm"}) resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
self.assertSuccess(resp) self.assertSuccess(resp)
data = resp.data["data"] data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test1") self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2") self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self): def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": "oi"}) resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
self.assertSuccess(resp) self.assertSuccess(resp)
data = resp.data["data"] data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test2") self.assertEqual(data[0]["user"]["username"], "test2")

View File

@ -3,7 +3,7 @@ from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI, UserChangePasswordAPI, UserRegisterAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
from utils.captcha.views import CaptchaAPIView from utils.captcha.views import CaptchaAPIView
@ -19,7 +19,6 @@ urlpatterns = [
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"), url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"), url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"), url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"), url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"), url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"), url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),

View File

@ -1,32 +1,28 @@
import os import os
import qrcode
import pickle
from datetime import timedelta from datetime import timedelta
from otpauth import OtpAuth from importlib import import_module
import qrcode
from django.conf import settings from django.conf import settings
from django.contrib import auth from django.contrib import auth
from importlib import import_module from django.template.loader import render_to_string
from django.utils.decorators import method_decorator
from django.utils.timezone import now from django.utils.timezone import now
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.utils.decorators import method_decorator from otpauth import OtpAuth
from django.template.loader import render_to_string
from conf.models import WebsiteConfig from utils.constants import ContestRuleType
from options.options import SysOptions
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.captcha import Captcha from utils.captcha import Captcha
from utils.shortcuts import rand_str, img2base64, timestamp2utcstr from utils.shortcuts import rand_str, img2base64, datetime2str
from utils.cache import default_cache
from utils.constants import CacheKey
from ..decorators import login_required from ..decorators import login_required
from ..models import User, UserProfile from ..models import User, UserProfile
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer, UserChangePasswordSerializer, UserLoginSerializer,
UserRegisterSerializer, UsernameOrEmailCheckSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer,
RankInfoSerializer) RankInfoSerializer)
from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer, from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
UserProfileSerializer,
EditUserProfileSerializer, AvatarUploadForm) EditUserProfileSerializer, AvatarUploadForm)
from ..tasks import send_email_async from ..tasks import send_email_async
@ -39,7 +35,7 @@ class UserProfileAPI(APIView):
""" """
user = request.user user = request.user
if not user.is_authenticated(): if not user.is_authenticated():
return self.success({}) return self.success()
username = request.GET.get("username") username = request.GET.get("username")
try: try:
if username: if username:
@ -48,8 +44,7 @@ class UserProfileAPI(APIView):
user = request.user user = request.user
except User.DoesNotExist: except User.DoesNotExist:
return self.error("User does not exist") return self.error("User does not exist")
profile = UserProfile.objects.select_related("user").get(user=user) return self.success(UserProfileSerializer(user.userprofile).data)
return self.success(UserProfileSerializer(profile).data)
@validate_serializer(EditUserProfileSerializer) @validate_serializer(EditUserProfileSerializer)
@login_required @login_required
@ -72,8 +67,7 @@ class AvatarUploadAPI(APIView):
avatar = form.cleaned_data["file"] avatar = form.cleaned_data["file"]
else: else:
return self.error("Invalid file content") return self.error("Invalid file content")
# 2097152 = 2 * 1024 * 1024 = 2MB if avatar.size > 2 * 1024 * 1024:
if avatar.size > 2097152:
return self.error("Picture is too large") return self.error("Picture is too large")
suffix = os.path.splitext(avatar.name)[-1].lower() suffix = os.path.splitext(avatar.name)[-1].lower()
if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]: if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
@ -84,46 +78,12 @@ class AvatarUploadAPI(APIView):
for chunk in avatar: for chunk in avatar:
img.write(chunk) img.write(chunk)
user_profile = request.user.userprofile user_profile = request.user.userprofile
_, old_avatar = os.path.split(user_profile.avatar)
if old_avatar != "default.png":
os.remove(os.path.join(settings.IMAGE_UPLOAD_DIR_ABS, old_avatar))
user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}" user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}"
user_profile.save() user_profile.save()
return self.success("Succeeded") return self.success("Succeeded")
class SSOAPI(APIView):
@login_required
def get(self, request):
callback = request.GET.get("callback", None)
if not callback:
return self.error("Parameter Error")
token = rand_str()
request.user.auth_token = token
request.user.save()
return self.success({"redirect_url": callback + "?token=" + token,
"callback": callback})
@validate_serializer(SSOSerializer)
def post(self, request):
data = request.data
try:
User.objects.get(open_api_appkey=data["appkey"])
except User.DoesNotExist:
return self.error("Invalid appkey")
try:
user = User.objects.get(auth_token=data["token"])
user.auth_token = None
user.save()
return self.success({"username": user.username,
"id": user.id,
"admin_type": user.admin_type,
"avatar": user.userprofile.avatar})
except User.DoesNotExist:
return self.error("User does not exist")
class TwoFactorAuthAPI(APIView): class TwoFactorAuthAPI(APIView):
@login_required @login_required
def get(self, request): def get(self, request):
@ -132,14 +92,13 @@ class TwoFactorAuthAPI(APIView):
""" """
user = request.user user = request.user
if user.two_factor_auth: if user.two_factor_auth:
return self.error("Already open 2FA") return self.error("2FA is already turned on")
token = rand_str() token = rand_str()
user.tfa_token = token user.tfa_token = token
user.save() user.save()
config = WebsiteConfig.objects.first() label = f"{SysOptions.website_name_shortcut}:{user.username}"
label = f"{config.name_shortcut}:{user.username}" image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name))
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
return self.success(img2base64(image)) return self.success(img2base64(image))
@login_required @login_required
@ -163,7 +122,7 @@ class TwoFactorAuthAPI(APIView):
code = request.data["code"] code = request.data["code"]
user = request.user user = request.user
if not user.two_factor_auth: if not user.two_factor_auth:
return self.error("Other session have disabled TFA") return self.error("2FA is already turned off")
if OtpAuth(user.tfa_token).valid_totp(code): if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False user.two_factor_auth = False
user.save() user.save()
@ -200,7 +159,7 @@ class UserLoginAPI(APIView):
# None is returned if username or password is wrong # None is returned if username or password is wrong
if user: if user:
if user.is_disabled: if user.is_disabled:
return self.error("Your account have been disabled") return self.error("Your account has been disabled")
if not user.two_factor_auth: if not user.two_factor_auth:
auth.login(request, user) auth.login(request, user)
return self.success("Succeeded") return self.success("Succeeded")
@ -220,13 +179,13 @@ class UserLoginAPI(APIView):
# todo remove this, only for debug use # todo remove this, only for debug use
def get(self, request): def get(self, request):
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"])) auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
return self.success({}) return self.success()
class UserLogoutAPI(APIView): class UserLogoutAPI(APIView):
def get(self, request): def get(self, request):
auth.logout(request) auth.logout(request)
return self.success({}) return self.success()
class UsernameOrEmailCheck(APIView): class UsernameOrEmailCheck(APIView):
@ -242,11 +201,9 @@ class UsernameOrEmailCheck(APIView):
"email": False "email": False
} }
if data.get("username"): if data.get("username"):
if User.objects.filter(username=data["username"]).exists(): result["username"] = User.objects.filter(username=data["username"]).exists()
result["username"] = True
if data.get("email"): if data.get("email"):
if User.objects.filter(email=data["email"]).exists(): result["email"] = User.objects.filter(email=data["email"]).exists()
result["email"] = True
return self.success(result) return self.success(result)
@ -256,17 +213,9 @@ class UserRegisterAPI(APIView):
""" """
User register api User register api
""" """
config = default_cache.get(CacheKey.website_config)
if config:
config = pickle.loads(config)
else:
config = WebsiteConfig.objects.first()
if not config:
config = WebsiteConfig.objects.create()
default_cache.set(CacheKey.website_config, pickle.dumps(config))
if not config.allow_register: if not SysOptions.allow_register:
return self.error("Register have been disabled by admin") return self.error("Register function has been disabled by admin")
data = request.data data = request.data
captcha = Captcha(request) captcha = Captcha(request)
@ -295,6 +244,7 @@ class UserChangePasswordAPI(APIView):
username = request.user.username username = request.user.username
user = auth.authenticate(username=username, password=data["old_password"]) user = auth.authenticate(username=username, password=data["old_password"])
if user: if user:
# TODO: check tfa?
user.set_password(data["new_password"]) user.set_password(data["new_password"])
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
@ -307,7 +257,6 @@ class ApplyResetPasswordAPI(APIView):
def post(self, request): def post(self, request):
data = request.data data = request.data
captcha = Captcha(request) captcha = Captcha(request)
config = WebsiteConfig.objects.first()
if not captcha.check(data["captcha"]): if not captcha.check(data["captcha"]):
return self.error("Invalid captcha") return self.error("Invalid captcha")
try: try:
@ -322,14 +271,14 @@ class ApplyResetPasswordAPI(APIView):
user.save() user.save()
render_data = { render_data = {
"username": user.username, "username": user.username,
"website_name": config.name, "website_name": SysOptions.website_name,
"link": f"{config.base_url}/reset-password/{user.reset_password_token}" "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
} }
email_html = render_to_string("reset_password_email.html", render_data) email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(config.name, send_email_async.delay(SysOptions.website_name,
user.email, user.email,
user.username, user.username,
config.name + " 登录信息找回邮件", f"{SysOptions.website_name} 登录信息找回邮件",
email_html) email_html)
return self.success("Succeeded") return self.success("Succeeded")
@ -344,9 +293,9 @@ class ResetPasswordAPI(APIView):
try: try:
user = User.objects.get(reset_password_token=data["token"]) user = User.objects.get(reset_password_token=data["token"])
except User.DoesNotExist: except User.DoesNotExist:
return self.error("Token dose not exist") return self.error("Token does not exist")
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0: if user.reset_password_token_expire_time < now():
return self.error("Token have expired") return self.error("Token has expired")
user.reset_password_token = None user.reset_password_token = None
user.two_factor_auth = False user.two_factor_auth = False
user.set_password(data["password"]) user.set_password(data["password"])
@ -358,14 +307,13 @@ class SessionManagementAPI(APIView):
@login_required @login_required
def get(self, request): def get(self, request):
engine = import_module(settings.SESSION_ENGINE) engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore session_store = engine.SessionStore
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
current_session = request.session.session_key current_session = request.session.session_key
session_keys = request.user.session_keys session_keys = request.user.session_keys
result = [] result = []
modified = False modified = False
for key in session_keys[:]: for key in session_keys[:]:
session = SessionStore(key) session = session_store(key)
# session does not exist or is expiry # session does not exist or is expiry
if not session._session: if not session._session:
session_keys.remove(key) session_keys.remove(key)
@ -377,7 +325,7 @@ class SessionManagementAPI(APIView):
s["current_session"] = True s["current_session"] = True
s["ip"] = session["ip"] s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"] s["user_agent"] = session["user_agent"]
s["last_activity"] = timestamp2utcstr(session["last_activity"]) s["last_activity"] = datetime2str(session["last_activity"])
s["session_key"] = key s["session_key"] = key
result.append(s) result.append(s)
if modified: if modified:
@ -401,12 +349,12 @@ class SessionManagementAPI(APIView):
class UserRankAPI(APIView): class UserRankAPI(APIView):
def get(self, request): def get(self, request):
rule_type = request.GET.get("rule") rule_type = request.GET.get("rule")
if rule_type not in ["acm", "oi"]: if rule_type not in ContestRuleType.choices():
rule_type = "acm" rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.select_related("user")\ profiles = UserProfile.objects.select_related("user")\
.filter(submission_number__gt=0)\ .filter(submission_number__gt=0)\
.exclude(user__is_disabled=True) .exclude(user__is_disabled=True)
if rule_type == "acm": if rule_type == ContestRuleType.ACM:
profiles = profiles.order_by("-accepted_number", "submission_number") profiles = profiles.order_by("-accepted_number", "submission_number")
else: else:
profiles = profiles.order_by("-total_score") profiles = profiles.order_by("-total_score")

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('announcement', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='announcement',
name='title',
field=models.CharField(max_length=64),
),
]

View File

@ -5,7 +5,7 @@ from utils.models import RichTextField
class Announcement(models.Model): class Announcement(models.Model):
title = models.CharField(max_length=50) title = models.CharField(max_length=64)
# HTML # HTML
content = RichTextField() content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)

View File

@ -5,8 +5,8 @@ from .models import Announcement
class CreateAnnouncementSerializer(serializers.Serializer): class CreateAnnouncementSerializer(serializers.Serializer):
title = serializers.CharField(max_length=50) title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=10000) content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField() visible = serializers.BooleanField()
@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer):
class EditAnnouncementSerializer(serializers.Serializer): class EditAnnouncementSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
title = serializers.CharField(max_length=50) title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=10000) content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField() visible = serializers.BooleanField()

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('conf', '0001_initial'),
]
operations = [
migrations.DeleteModel(
name='JudgeServerToken',
),
migrations.DeleteModel(
name='SMTPConfig',
),
migrations.DeleteModel(
name='WebsiteConfig',
),
migrations.AlterField(
model_name='judgeserver',
name='hostname',
field=models.CharField(max_length=128),
),
migrations.AlterField(
model_name='judgeserver',
name='judger_version',
field=models.CharField(max_length=32),
),
migrations.AlterField(
model_name='judgeserver',
name='service_url',
field=models.CharField(blank=True, max_length=256, null=True),
),
]

View File

@ -2,42 +2,17 @@ from django.db import models
from django.utils import timezone from django.utils import timezone
class SMTPConfig(models.Model):
server = models.CharField(max_length=128)
port = models.IntegerField(default=25)
email = models.CharField(max_length=128)
password = models.CharField(max_length=128)
tls = models.BooleanField()
class Meta:
db_table = "smtp_config"
class WebsiteConfig(models.Model):
base_url = models.CharField(max_length=128, default="http://127.0.0.1")
name = models.CharField(max_length=32, default="Online Judge")
name_shortcut = models.CharField(max_length=32, default="oj")
footer = models.TextField(default="Online Judge Footer")
# allow register
allow_register = models.BooleanField(default=True)
# submission list show all user's submission
submission_list_show_all = models.BooleanField(default=True)
class Meta:
db_table = "website_config"
class JudgeServer(models.Model): class JudgeServer(models.Model):
hostname = models.CharField(max_length=64) hostname = models.CharField(max_length=128)
ip = models.CharField(max_length=32, blank=True, null=True) ip = models.CharField(max_length=32, blank=True, null=True)
judger_version = models.CharField(max_length=24) judger_version = models.CharField(max_length=32)
cpu_core = models.IntegerField() cpu_core = models.IntegerField()
memory_usage = models.FloatField() memory_usage = models.FloatField()
cpu_usage = models.FloatField() cpu_usage = models.FloatField()
last_heartbeat = models.DateTimeField() last_heartbeat = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
task_number = models.IntegerField(default=0) task_number = models.IntegerField(default=0)
service_url = models.CharField(max_length=128, blank=True, null=True) service_url = models.CharField(max_length=256, blank=True, null=True)
@property @property
def status(self): def status(self):
@ -48,10 +23,3 @@ class JudgeServer(models.Model):
class Meta: class Meta:
db_table = "judge_server" db_table = "judge_server"
class JudgeServerToken(models.Model):
token = models.CharField(max_length=32)
class Meta:
db_table = "judge_server_token"

View File

@ -1,6 +1,6 @@
from utils.api import DateTimeTZField, serializers from utils.api import DateTimeTZField, serializers
from .models import JudgeServer, SMTPConfig, WebsiteConfig from .models import JudgeServer
class EditSMTPConfigSerializer(serializers.Serializer): class EditSMTPConfigSerializer(serializers.Serializer):
@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer):
password = serializers.CharField(max_length=128) password = serializers.CharField(max_length=128)
class SMTPConfigSerializer(serializers.ModelSerializer):
class Meta:
model = SMTPConfig
exclude = ["id", "password"]
class TestSMTPConfigSerializer(serializers.Serializer): class TestSMTPConfigSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
class CreateEditWebsiteConfigSerializer(serializers.Serializer): class CreateEditWebsiteConfigSerializer(serializers.Serializer):
base_url = serializers.CharField(max_length=128) website_base_url = serializers.CharField(max_length=128)
name = serializers.CharField(max_length=32) website_name = serializers.CharField(max_length=64)
name_shortcut = serializers.CharField(max_length=32) website_name_shortcut = serializers.CharField(max_length=64)
footer = serializers.CharField(max_length=1024) website_footer = serializers.CharField(max_length=1024 * 1024)
allow_register = serializers.BooleanField() allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField() submission_list_show_all = serializers.BooleanField()
class WebsiteConfigSerializer(serializers.ModelSerializer):
class Meta:
model = WebsiteConfig
exclude = ["id"]
class JudgeServerSerializer(serializers.ModelSerializer): class JudgeServerSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField() create_time = DateTimeTZField()
last_heartbeat = DateTimeTZField() last_heartbeat = DateTimeTZField()
@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = JudgeServer model = JudgeServer
fields = "__all__"
class JudgeServerHeartbeatSerializer(serializers.Serializer): class JudgeServerHeartbeatSerializer(serializers.Serializer):
hostname = serializers.CharField(max_length=64) hostname = serializers.CharField(max_length=128)
judger_version = serializers.CharField(max_length=24) judger_version = serializers.CharField(max_length=32)
cpu_core = serializers.IntegerField(min_value=1) cpu_core = serializers.IntegerField(min_value=1)
memory = serializers.FloatField(min_value=0, max_value=100) memory = serializers.FloatField(min_value=0, max_value=100)
cpu = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100)
action = serializers.ChoiceField(choices=("heartbeat", )) action = serializers.ChoiceField(choices=("heartbeat", ))
service_url = serializers.CharField(max_length=128, required=False) service_url = serializers.CharField(max_length=256, required=False)

View File

@ -2,11 +2,10 @@ import hashlib
from django.utils import timezone from django.utils import timezone
from options.options import SysOptions
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from utils.cache import default_cache
from utils.constants import CacheKey from utils.constants import CacheKey
from .models import JudgeServer
from .models import JudgeServer, JudgeServerToken, SMTPConfig
class SMTPConfigTest(APITestCase): class SMTPConfigTest(APITestCase):
@ -29,10 +28,6 @@ class SMTPConfigTest(APITestCase):
"tls": True} "tls": True}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, self.password)
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
def test_edit_without_password1(self): def test_edit_without_password1(self):
self.test_create_smtp_config() self.test_create_smtp_config()
@ -40,7 +35,6 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": ""} "tls": True, "password": ""}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(SMTPConfig.objects.first().password, self.password)
def test_edit_with_password(self): def test_edit_with_password(self):
self.test_create_smtp_config() self.test_create_smtp_config()
@ -48,18 +42,14 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": "newpassword"} "tls": True, "password": "newpassword"}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, "newpassword")
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
class WebsiteConfigAPITest(APITestCase): class WebsiteConfigAPITest(APITestCase):
def test_create_website_config(self): def test_create_website_config(self):
self.create_super_admin() self.create_super_admin()
url = self.reverse("website_config_api") url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name", data = {"website_base_url": "http://test.com", "website_name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>", "website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False} "allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data) resp = self.client.post(url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
@ -67,8 +57,8 @@ class WebsiteConfigAPITest(APITestCase):
def test_edit_website_config(self): def test_edit_website_config(self):
self.create_super_admin() self.create_super_admin()
url = self.reverse("website_config_api") url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name", data = {"website_base_url": "http://test.com", "website_name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>", "website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False} "allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data) resp = self.client.post(url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
@ -78,10 +68,6 @@ class WebsiteConfigAPITest(APITestCase):
url = self.reverse("website_info_api") url = self.reverse("website_info_api")
resp = self.client.get(url) resp = self.client.get(url)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["name_shortcut"], "oj")
def tearDown(self):
default_cache.delete(CacheKey.website_config)
class JudgeServerHeartbeatTest(APITestCase): class JudgeServerHeartbeatTest(APITestCase):
@ -91,7 +77,7 @@ class JudgeServerHeartbeatTest(APITestCase):
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"} "cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
self.token = "test" self.token = "test"
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
JudgeServerToken.objects.create(token=self.token) SysOptions.judge_server_token = self.token
def test_new_heartbeat(self): def test_new_heartbeat(self):
resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token}) resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token})
@ -127,11 +113,9 @@ class JudgeServerAPITest(APITestCase):
self.create_super_admin() self.create_super_admin()
def test_get_judge_server(self): def test_get_judge_server(self):
self.assertFalse(JudgeServerToken.objects.exists())
resp = self.client.get(self.url) resp = self.client.get(self.url)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["servers"]), 1) self.assertEqual(len(resp.data["data"]["servers"]), 1)
self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"])
def test_delete_judge_server(self): def test_delete_judge_server(self):
resp = self.client.delete(self.url + "?hostname=testhostname") resp = self.client.delete(self.url + "?hostname=testhostname")

View File

@ -1,54 +1,45 @@
import hashlib import hashlib
import pickle
from django.utils import timezone from django.utils import timezone
from account.decorators import super_admin_required from account.decorators import super_admin_required
from judge.languages import languages, spj_languages
from judge.dispatcher import process_pending_task from judge.dispatcher import process_pending_task
from judge.languages import languages, spj_languages
from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str from .models import JudgeServer
from utils.cache import default_cache
from utils.constants import CacheKey
from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig
from .serializers import (CreateEditWebsiteConfigSerializer, from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer, CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
JudgeServerHeartbeatSerializer, JudgeServerHeartbeatSerializer,
JudgeServerSerializer, SMTPConfigSerializer, JudgeServerSerializer, TestSMTPConfigSerializer)
TestSMTPConfigSerializer, WebsiteConfigSerializer)
class SMTPAPI(APIView): class SMTPAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
smtp = SMTPConfig.objects.first() smtp = SysOptions.smtp_config
if not smtp: if not smtp:
return self.success(None) return self.success(None)
return self.success(SMTPConfigSerializer(smtp).data) smtp.pop("password")
return self.success(smtp)
@validate_serializer(CreateSMTPConfigSerializer) @validate_serializer(CreateSMTPConfigSerializer)
@super_admin_required @super_admin_required
def post(self, request): def post(self, request):
SMTPConfig.objects.all().delete() SysOptions.smtp_config = request.data
smtp = SMTPConfig.objects.create(**request.data) return self.success()
return self.success(SMTPConfigSerializer(smtp).data)
@validate_serializer(EditSMTPConfigSerializer) @validate_serializer(EditSMTPConfigSerializer)
@super_admin_required @super_admin_required
def put(self, request): def put(self, request):
smtp = SysOptions.smtp_config
data = request.data data = request.data
smtp = SMTPConfig.objects.first() for item in ["server", "port", "email", "tls"]:
if not smtp: smtp[item] = data[item]
return self.error("SMTP config is missing") if "password" in data:
smtp.server = data["server"] smtp["password"] = data["password"]
smtp.port = data["port"] SysOptions.smtp_config = smtp
smtp.email = data["email"] return self.success()
smtp.tls = data["tls"]
if data.get("password"):
smtp.password = data["password"]
smtp.save()
return self.success(SMTPConfigSerializer(smtp).data)
class SMTPTestAPI(APIView): class SMTPTestAPI(APIView):
@ -60,37 +51,24 @@ class SMTPTestAPI(APIView):
class WebsiteConfigAPI(APIView): class WebsiteConfigAPI(APIView):
def get(self, request): def get(self, request):
config = default_cache.get(CacheKey.website_config) ret = {key: getattr(SysOptions, key) for key in
if config: ["website_base_url", "website_name", "website_name_shortcut",
config = pickle.loads(config) "website_footer", "allow_register", "submission_list_show_all"]}
else: return self.success(ret)
config = WebsiteConfig.objects.first()
if not config:
config = WebsiteConfig.objects.create()
default_cache.set(CacheKey.website_config, pickle.dumps(config))
return self.success(WebsiteConfigSerializer(config).data)
@validate_serializer(CreateEditWebsiteConfigSerializer) @validate_serializer(CreateEditWebsiteConfigSerializer)
@super_admin_required @super_admin_required
def post(self, request): def post(self, request):
data = request.data for k, v in request.data.items():
WebsiteConfig.objects.all().delete() setattr(SysOptions, k, v)
config = WebsiteConfig.objects.create(**data) return self.success()
default_cache.set(CacheKey.website_config, pickle.dumps(config))
return self.success(WebsiteConfigSerializer(config).data)
class JudgeServerAPI(APIView): class JudgeServerAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
servers = JudgeServer.objects.all().order_by("-last_heartbeat") servers = JudgeServer.objects.all().order_by("-last_heartbeat")
return self.success({"token": token, return self.success({"token": SysOptions.judge_server_token,
"servers": JudgeServerSerializer(servers, many=True).data}) "servers": JudgeServerSerializer(servers, many=True).data})
@super_admin_required @super_admin_required
@ -104,15 +82,9 @@ class JudgeServerAPI(APIView):
class JudgeServerHeartbeatAPI(CSRFExemptAPIView): class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
@validate_serializer(JudgeServerHeartbeatSerializer) @validate_serializer(JudgeServerHeartbeatSerializer)
def post(self, request): def post(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
data = request.data data = request.data
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token: if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
return self.error("Invalid token") return self.error("Invalid token")
service_url = data.get("service_url") service_url = data.get("service_url")

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0005_auto_20170823_0918'),
]
operations = [
migrations.AlterField(
model_name='acmcontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='oicontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
]

View File

@ -1,27 +1,13 @@
from utils.constants import ContestRuleType # noqa
from django.db import models from django.db import models
from django.utils.timezone import now from django.utils.timezone import now
from jsonfield import JSONField from utils.models import JSONField
from utils.constants import ContestStatus, ContestType
from account.models import User, AdminType from account.models import User, AdminType
from utils.models import RichTextField from utils.models import RichTextField
class ContestType(object):
PUBLIC_CONTEST = "Public"
PASSWORD_PROTECTED_CONTEST = "Password Protected"
class ContestStatus(object):
CONTEST_NOT_START = "1"
CONTEST_ENDED = "-1"
CONTEST_UNDERWAY = "0"
class ContestRuleType(object):
ACM = "ACM"
OI = "OI"
class Contest(models.Model): class Contest(models.Model):
title = models.CharField(max_length=40) title = models.CharField(max_length=40)
description = RichTextField() description = RichTextField()
@ -59,12 +45,20 @@ class Contest(models.Model):
def is_contest_admin(self, user): def is_contest_admin(self, user):
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN) return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
def check_oi_permission(self, user):
if self.status != ContestStatus.CONTEST_ENDED and self.real_time_rank == False:
if self.is_contest_admin(user):
return True
else:
return False
return True
class Meta: class Meta:
db_table = "contest" db_table = "contest"
ordering = ("-create_time",) ordering = ("-create_time",)
class ContestRank(models.Model): class AbstractContestRank(models.Model):
user = models.ForeignKey(User) user = models.ForeignKey(User)
contest = models.ForeignKey(Contest) contest = models.ForeignKey(Contest)
submission_number = models.IntegerField(default=0) submission_number = models.IntegerField(default=0)
@ -73,30 +67,27 @@ class ContestRank(models.Model):
abstract = True abstract = True
class ACMContestRank(ContestRank): class ACMContestRank(AbstractContestRank):
accepted_number = models.IntegerField(default=0) accepted_number = models.IntegerField(default=0)
# total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60 # total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60
total_time = models.IntegerField(default=0) total_time = models.IntegerField(default=0)
# {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}} # {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}}
# key is problem id # key is problem id
submission_info = JSONField(default={}) submission_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "acm_contest_rank" db_table = "acm_contest_rank"
class OIContestRank(ContestRank): class OIContestRank(AbstractContestRank):
total_score = models.IntegerField(default=0) total_score = models.IntegerField(default=0)
# {23: 333}} # {23: 333}}
# key is problem id, value is current score # key is problem id, value is current score
submission_info = JSONField(default={}) submission_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "oi_contest_rank" db_table = "oi_contest_rank"
def update_rank(self, submission):
self.submission_number += 1
class ContestAnnouncement(models.Model): class ContestAnnouncement(models.Model):
contest = models.ForeignKey(Contest) contest = models.ForeignKey(Contest)

View File

@ -79,7 +79,7 @@ class ContestAPITest(APITestCase):
self.create_user("test", "test123") self.create_user("test", "test123")
url = self.reverse("contest_password_api") url = self.reverse("contest_password_api")
resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"}) resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Password doesn't match."}) self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp) self.assertSuccess(resp)
@ -89,7 +89,7 @@ class ContestAPITest(APITestCase):
self.create_user("test", "test123") self.create_user("test", "test123")
url = self.reverse("contest_access_api") url = self.reverse("contest_access_api")
resp = self.client.get(url + "?contest_id=" + str(contest_id)) resp = self.client.get(url + "?contest_id=" + str(contest_id))
self.assertFalse(resp.data["data"]["Access"]) self.assertFalse(resp.data["data"]["access"])
password_url = self.reverse("contest_password_api") password_url = self.reverse("contest_password_api")
resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})

View File

@ -1,13 +1,11 @@
import pickle
from django.utils.timezone import now from django.utils.timezone import now
from django.db.models import Q from django.core.cache import cache
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.cache import default_cache
from utils.constants import CacheKey from utils.constants import CacheKey
from account.decorators import login_required, check_contest_permission from account.decorators import login_required, check_contest_permission
from ..models import ContestAnnouncement, Contest, ContestStatus, ContestRuleType from utils.constants import ContestRuleType, ContestStatus
from ..models import OIContestRank, ACMContestRank from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestAnnouncementSerializer
from ..serializers import ContestSerializer, ContestPasswordVerifySerializer from ..serializers import ContestSerializer, ContestPasswordVerifySerializer
from ..serializers import OIContestRankSerializer, ACMContestRankSerializer from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
@ -32,7 +30,7 @@ class ContestAPI(APIView):
try: try:
contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest doesn't exist.") return self.error("Contest does not exist")
return self.success(ContestSerializer(contest).data) return self.success(ContestSerializer(contest).data)
contests = Contest.objects.select_related("created_by").filter(visible=True) contests = Contest.objects.select_related("created_by").filter(visible=True)
@ -50,7 +48,7 @@ class ContestAPI(APIView):
elif status == ContestStatus.CONTEST_ENDED: elif status == ContestStatus.CONTEST_ENDED:
contests = contests.filter(end_time__lt=cur) contests = contests.filter(end_time__lt=cur)
else: else:
contests = contests.filter(Q(start_time__lte=cur) & Q(end_time__gte=cur)) contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
return self.success(self.paginate_data(request, contests, ContestSerializer)) return self.success(self.paginate_data(request, contests, ContestSerializer))
@ -62,14 +60,14 @@ class ContestPasswordVerifyAPI(APIView):
try: try:
contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False) contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest %s doesn't exist." % data["contest_id"]) return self.error("Contest does not exist")
if contest.password != data["password"]: if contest.password != data["password"]:
return self.error("Password doesn't match.") return self.error("Wrong password")
# password verify OK. # password verify OK.
if "contests" not in request.session: if "accessible_contests" not in request.session:
request.session["contests"] = [] request.session["accessible_contests"] = []
request.session["contests"].append(int(data["contest_id"])) request.session["accessible_contests"].append(contest.id)
# https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved
request.session.modified = True request.session.modified = True
return self.success(True) return self.success(True)
@ -80,13 +78,8 @@ class ContestAccessAPI(APIView):
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
return self.error("Parameter contest_id not exist.") return self.error()
if "contests" not in request.session: return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])})
request.session["contests"] = []
if int(contest_id) in request.session["contests"]:
return self.success({"Access": True})
else:
return self.success({"Access": False})
class ContestRankAPI(APIView): class ContestRankAPI(APIView):
@ -100,17 +93,17 @@ class ContestRankAPI(APIView):
@check_contest_permission @check_contest_permission
def get(self, request): def get(self, request):
if self.contest.rule_type == ContestRuleType.ACM: if self.contest.rule_type == ContestRuleType.OI:
serializer = ACMContestRankSerializer if not self.contest.check_oi_permission(request.user):
else: return self.error("You have no permission for ranks now")
serializer = OIContestRankSerializer serializer = OIContestRankSerializer
cache_key = CacheKey.contest_rank_cache + str(self.contest.id)
qs = default_cache.get(cache_key)
if not qs:
ranks = self.get_rank()
default_cache.set(cache_key, pickle.dumps(ranks))
else: else:
ranks = pickle.loads(qs) serializer = ACMContestRankSerializer
return self.success(self.paginate_data(request, ranks, serializer)) cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
qs = cache.get(cache_key)
if not qs:
qs = self.get_rank()
cache.set(cache_key, qs)
return self.success(self.paginate_data(request, qs, serializer))

View File

@ -8,12 +8,13 @@ from django.db import transaction
from django.db.models import F from django.db.models import F
from account.models import User from account.models import User
from conf.models import JudgeServer, JudgeServerToken from conf.models import JudgeServer
from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus
from judge.languages import languages from judge.languages import languages
from options.options import SysOptions
from problem.models import Problem, ProblemRuleType from problem.models import Problem, ProblemRuleType
from submission.models import JudgeStatus, Submission from submission.models import JudgeStatus, Submission
from utils.cache import judge_cache, default_cache from utils.cache import cache
from utils.constants import CacheKey from utils.constants import CacheKey
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,41 +22,36 @@ logger = logging.getLogger(__name__)
# 继续处理在队列中的问题 # 继续处理在队列中的问题
def process_pending_task(): def process_pending_task():
if judge_cache.llen(CacheKey.waiting_queue): if cache.llen(CacheKey.waiting_queue):
# 防止循环引入 # 防止循环引入
from judge.tasks import judge_task from judge.tasks import judge_task
data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8")) data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data) judge_task.delay(**data)
class JudgeDispatcher(object): class JudgeDispatcher(object):
def __init__(self, submission_id, problem_id): def __init__(self, submission_id, problem_id):
token = JudgeServerToken.objects.first().token self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
self.token = hashlib.sha256(token.encode("utf-8")).hexdigest() self.submission = Submission.objects.get(id=submission_id)
self.redis_conn = judge_cache
self.submission = Submission.objects.get(pk=submission_id)
self.contest_id = self.submission.contest_id self.contest_id = self.submission.contest_id
if self.contest_id: if self.contest_id:
self.problem = Problem.objects.select_related("contest") \ self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
.get(id=problem_id, contest_id=self.contest_id)
self.contest = self.problem.contest self.contest = self.problem.contest
else: else:
self.problem = Problem.objects.get(id=problem_id) self.problem = Problem.objects.get(id=problem_id)
def _request(self, url, data=None): def _request(self, url, data=None):
kwargs = {"headers": {"X-Judge-Server-Token": self.token, kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
"Content-Type": "application/json"}}
if data: if data:
kwargs["data"] = json.dumps(data) kwargs["json"] = data
try: try:
return requests.post(url, **kwargs).json() return requests.post(url, **kwargs).json()
except Exception as e: except Exception as e:
logger.error(e.with_traceback()) logger.exception(e)
@staticmethod @staticmethod
def choose_judge_server(): def choose_judge_server():
with transaction.atomic(): with transaction.atomic():
# TODO: use more reasonable way
servers = JudgeServer.objects.select_for_update().all().order_by("task_number") servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
servers = [s for s in servers if s.status == "normal"] servers = [s for s in servers if s.status == "normal"]
if servers: if servers:
@ -65,10 +61,10 @@ class JudgeDispatcher(object):
return server return server
@staticmethod @staticmethod
def release_judge_res(judge_server_id): def release_judge_server(judge_server_id):
with transaction.atomic(): with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下 # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.select_for_update().get(id=judge_server_id) server = JudgeServer.objects.get(id=judge_server_id)
server.used_instance_number = F("task_number") - 1 server.used_instance_number = F("task_number") - 1
server.save() server.save()
@ -94,7 +90,7 @@ class JudgeDispatcher(object):
server = self.choose_judge_server() server = self.choose_judge_server()
if not server: if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id} data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data)) cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return return
sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0] sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0]
@ -138,7 +134,7 @@ class JudgeDispatcher(object):
else: else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save() self.submission.save()
self.release_judge_res(server.id) self.release_judge_server(server.id)
self.update_problem_status() self.update_problem_status()
if self.contest_id: if self.contest_id:
@ -160,7 +156,6 @@ class JudgeDispatcher(object):
with transaction.atomic(): with transaction.atomic():
# prepare problem and user_profile # prepare problem and user_profile
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
problem_info = problem.statistic_info
user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id) user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id)
user_profile = user.userprofile user_profile = user.userprofile
if self.contest_id: if self.contest_id:
@ -169,50 +164,46 @@ class JudgeDispatcher(object):
key = "problems" key = "problems"
acm_problems_status = user_profile.acm_problems_status.get(key, {}) acm_problems_status = user_profile.acm_problems_status.get(key, {})
oi_problems_status = user_profile.oi_problems_status.get(key, {}) oi_problems_status = user_profile.oi_problems_status.get(key, {})
problem_id = str(self.problem.id)
problem_info = problem.statistic_info
# update problem info
result = str(self.submission.result)
problem_info[result] = problem_info.get(result, 0) + 1
problem.statistic_info = problem_info
# update submission and accepted number counter # update submission and accepted number counter
problem.submission_number += 1 problem.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED: if self.submission.result == JudgeStatus.ACCEPTED:
problem.accepted_number += 1 problem.accepted_number += 1
# only when submission is not in contest, we update user profile, # submission in a contest will not be counted in user profile
# in other words, users' submission in a contest will not be counted in user profile
if not self.contest_id: if not self.contest_id:
user_profile.submission_number += 1 user_profile.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED: if self.submission.result == JudgeStatus.ACCEPTED:
user_profile.accepted_number += 1 user_profile.accepted_number += 1
problem_id = str(self.problem.id)
if self.problem.rule_type == ProblemRuleType.ACM: if self.problem.rule_type == ProblemRuleType.ACM:
# update acm problem info
result = str(self.submission.result)
problem_info[result] = problem_info.get(result, 0) + 1
problem.statistic_info = problem_info
# update user_profile # update user_profile
if problem_id not in acm_problems_status: if problem_id not in acm_problems_status:
acm_problems_status[problem_id] = self.submission.result acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
# skip if the problem has been accepted # skip if the problem has been accepted
elif acm_problems_status[problem_id] != JudgeStatus.ACCEPTED: elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
if self.submission.result == JudgeStatus.ACCEPTED: acm_problems_status[problem_id]["status"] = self.submission.result
acm_problems_status[problem_id] = JudgeStatus.ACCEPTED
else:
acm_problems_status[problem_id] = self.submission.result
user_profile.acm_problems_status[key] = acm_problems_status user_profile.acm_problems_status[key] = acm_problems_status
else: else:
# update oi problem info
score = self.submission.statistic_info["score"]
problem_info[score] = problem_info.get(score, 0) + 1
problem.statistic_info = problem_info
# update user_profile # update user_profile
score = self.submission.statistic_info["score"]
if problem_id not in oi_problems_status: if problem_id not in oi_problems_status:
user_profile.add_score(score) user_profile.add_score(score)
oi_problems_status[problem_id] = score oi_problems_status[problem_id] = {"status": self.submission.result,
"_id": self.problem._id,
"score": score}
else: else:
# minus last time score, add this time score # minus last time score, add this time score
user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]) user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]["score"])
oi_problems_status[problem_id] = score oi_problems_status[problem_id]["score"] = score
oi_problems_status[problem_id]["status"] = self.submission.result
user_profile.oi_problems_status[key] = oi_problems_status user_profile.oi_problems_status[key] = oi_problems_status
problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"]) problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"])
@ -222,8 +213,8 @@ class JudgeDispatcher(object):
def update_contest_rank(self): def update_contest_rank(self):
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY: if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
return return
if self.contest.real_time_rank: if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
with transaction.atomic(): with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM: if self.contest.rule_type == ContestRuleType.ACM:
acm_rank, _ = ACMContestRank.objects.select_for_update(). \ acm_rank, _ = ACMContestRank.objects.select_for_update(). \

View File

@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 'HOST': '127.0.0.1',
'PORT': 5433,
'NAME': "onlinejudge",
'USER': "onlinejudge",
'PASSWORD': 'onlinejudge'
} }
} }

View File

@ -46,6 +46,7 @@ INSTALLED_APPS = (
'contest', 'contest',
'utils', 'utils',
'submission', 'submission',
'options',
) )
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
@ -57,11 +58,9 @@ MIDDLEWARE_CLASSES = (
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'account.middleware.AdminRoleRequiredMiddleware', 'account.middleware.AdminRoleRequiredMiddleware',
'account.middleware.SessionSecurityMiddleware',
'account.middleware.SessionRecordMiddleware', 'account.middleware.SessionRecordMiddleware',
# 'account.middleware.LogSqlMiddleware', # 'account.middleware.LogSqlMiddleware',
) )
SESSION_ENGINE = 'django.contrib.sessions.backends.cache'
ROOT_URLCONF = 'oj.urls' ROOT_URLCONF = 'oj.urls'
TEMPLATES = [ TEMPLATES = [
@ -164,8 +163,6 @@ LOGGING = {
} }
}, },
} }
REST_FRAMEWORK = { REST_FRAMEWORK = {
'TEST_REQUEST_DEFAULT_FORMAT': 'json', 'TEST_REQUEST_DEFAULT_FORMAT': 'json',
'DEFAULT_RENDERER_CLASSES': ( 'DEFAULT_RENDERER_CLASSES': (
@ -173,34 +170,34 @@ REST_FRAMEWORK = {
) )
} }
CACHE_JUDGE_QUEUE = "judge_queue"
CACHE_THROTTLING = "throttling"
REDIS_URL = "redis://127.0.0.1:6379"
def redis_config(db):
def make_key(key, key_prefix, version):
return key
return {
"BACKEND": "utils.cache.MyRedisCache",
"LOCATION": f"{REDIS_URL}/{db}",
"TIMEOUT": None,
"KEY_PREFIX": "",
"KEY_FUNCTION": make_key
}
CACHES = { CACHES = {
"default": { "default": redis_config(db=1)
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/1",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_JUDGE_QUEUE: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/2",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_THROTTLING: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/3",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
}
} }
CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2"
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
# For celery # For celery
REDIS_QUEUE = { REDIS_QUEUE = {
"host": "127.0.0.1", "host": "127.0.0.1",

0
options/__init__.py Normal file
View File

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-10-01 19:19
from __future__ import unicode_literals
import jsonfield.fields
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='SysOptions',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('key', models.CharField(db_index=True, max_length=128, unique=True)),
('value', jsonfield.fields.JSONField()),
],
),
]

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('options', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='sysoptions',
name='value',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
]

View File

7
options/models.py Normal file
View File

@ -0,0 +1,7 @@
from django.db import models
from utils.models import JSONField
class SysOptions(models.Model):
key = models.CharField(max_length=128, unique=True, db_index=True)
value = JSONField()

179
options/options.py Normal file
View File

@ -0,0 +1,179 @@
from django.core.cache import cache
from django.db import transaction, IntegrityError
from utils.constants import CacheKey
from utils.shortcuts import rand_str
from .models import SysOptions as SysOptionsModel
class OptionKeys:
website_base_url = "website_base_url"
website_name = "website_name"
website_name_shortcut = "website_name_shortcut"
website_footer = "website_footer"
allow_register = "allow_register"
submission_list_show_all = "submission_list_show_all"
smtp_config = "smtp_config"
judge_server_token = "judge_server_token"
class OptionDefaultValue:
website_base_url = "http://127.0.0.1"
website_name = "Online Judge"
website_name_shortcut = "oj"
website_footer = "Online Judge Footer"
allow_register = True
submission_list_show_all = True
smtp_config = {}
judge_server_token = rand_str
class _SysOptionsMeta(type):
@classmethod
def _set_cache(mcs, option_key, option_value):
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
@classmethod
def _del_cache(mcs, option_key):
cache.delete(f"{CacheKey.option}:{option_key}")
@classmethod
def _get_keys(cls):
return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
def rebuild_cache(cls):
for key in cls._get_keys():
# get option 的时候会写 cache 的
cls._get_option(key, use_cache=False)
@classmethod
def _init_option(mcs):
for item in mcs._get_keys():
if not SysOptionsModel.objects.filter(key=item).exists():
default_value = getattr(OptionDefaultValue, item)
if callable(default_value):
default_value = default_value()
try:
SysOptionsModel.objects.create(key=item, value=default_value)
except IntegrityError:
pass
@classmethod
def _get_option(mcs, option_key, use_cache=True):
try:
if use_cache:
option = cache.get(f"{CacheKey.option}:{option_key}")
if option:
return option
option = SysOptionsModel.objects.get(key=option_key)
value = option.value
mcs._set_cache(option_key, value)
return value
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._get_option(option_key, use_cache=use_cache)
@classmethod
def _set_option(mcs, option_key: str, option_value):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
option.value = option_value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
mcs._set_option(option_key, option_value)
@classmethod
def _increment(mcs, option_key):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
value = option.value + 1
option.value = value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._increment(option_key)
@classmethod
def set_options(mcs, options):
for key, value in options:
mcs._set_option(key, value)
@classmethod
def get_options(mcs, keys):
result = {}
for key in keys:
result[key] = mcs._get_option(key)
return result
@property
def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url)
@website_base_url.setter
def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value)
@property
def website_name(cls):
return cls._get_option(OptionKeys.website_name)
@website_name.setter
def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value)
@property
def website_name_shortcut(cls):
return cls._get_option(OptionKeys.website_name_shortcut)
@website_name_shortcut.setter
def website_name_shortcut(cls, value):
cls._set_option(OptionKeys.website_name_shortcut, value)
@property
def website_footer(cls):
return cls._get_option(OptionKeys.website_footer)
@website_footer.setter
def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value)
@property
def allow_register(cls):
return cls._get_option(OptionKeys.allow_register)
@allow_register.setter
def allow_register(cls, value):
cls._set_option(OptionKeys.allow_register, value)
@property
def submission_list_show_all(cls):
return cls._get_option(OptionKeys.submission_list_show_all)
@submission_list_show_all.setter
def submission_list_show_all(cls, value):
cls._set_option(OptionKeys.submission_list_show_all, value)
@property
def smtp_config(cls):
return cls._get_option(OptionKeys.smtp_config)
@smtp_config.setter
def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value)
@property
def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token)
@judge_server_token.setter
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
class SysOptions(metaclass=_SysOptionsMeta):
pass

1
options/tests.py Normal file
View File

@ -0,0 +1 @@
# Create your tests here.

1
options/views.py Normal file
View File

@ -0,0 +1 @@
# Create your views here.

View File

@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0008_auto_20170923_1318'),
]
operations = [
migrations.AlterField(
model_name='problem',
name='languages',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='samples',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='statistic_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='problem',
name='template',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='test_case_score',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
]

View File

@ -1,5 +1,5 @@
from django.db import models from django.db import models
from jsonfield import JSONField from utils.models import JSONField
from account.models import User from account.models import User
from contest.models import Contest from contest.models import Contest
@ -65,8 +65,8 @@ class Problem(models.Model):
total_score = models.IntegerField(default=0, blank=True) total_score = models.IntegerField(default=0, blank=True)
submission_number = models.BigIntegerField(default=0) submission_number = models.BigIntegerField(default=0)
accepted_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0)
# ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default={}) statistic_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "problem" db_table = "problem"

View File

@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
class TagSerializer(serializers.ModelSerializer): class TagSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ProblemTag model = ProblemTag
fields = "__all__"
class BaseProblemSerializer(serializers.ModelSerializer): class BaseProblemSerializer(serializers.ModelSerializer):
@ -88,11 +89,13 @@ class BaseProblemSerializer(serializers.ModelSerializer):
class ProblemAdminSerializer(BaseProblemSerializer): class ProblemAdminSerializer(BaseProblemSerializer):
class Meta: class Meta:
model = Problem model = Problem
fields = "__all__"
class ContestProblemAdminSerializer(BaseProblemSerializer): class ContestProblemAdminSerializer(BaseProblemSerializer):
class Meta: class Meta:
model = Problem model = Problem
fields = "__all__"
class ProblemSerializer(BaseProblemSerializer): class ProblemSerializer(BaseProblemSerializer):

View File

@ -1,4 +1,5 @@
import copy import copy
import hashlib
import os import os
import shutil import shutil
from datetime import timedelta from datetime import timedelta
@ -9,6 +10,7 @@ from django.conf import settings
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from .models import ProblemTag from .models import ProblemTag
from .models import Problem, ProblemRuleType
from .views.admin import TestCaseUploadAPI from .views.admin import TestCaseUploadAPI
from contest.models import Contest from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA from contest.tests import DEFAULT_CONTEST_DATA
@ -23,6 +25,40 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test
"input_size": 0, "score": 0}], "input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"} "rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
raise ValueError("invalid score")
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = created_by
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
for item in tags:
try:
tag = ProblemTag.objects.get(name=item)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
return problem
class ProblemTagListAPITest(APITestCase): class ProblemTagListAPITest(APITestCase):
def test_get_tag_list(self): def test_get_tag_list(self):
@ -96,7 +132,7 @@ class ProblemAdminAPITest(APITestCase):
def setUp(self): def setUp(self):
self.url = self.reverse("problem_admin_api") self.url = self.reverse("problem_admin_api")
self.create_super_admin() self.create_super_admin()
self.data = DEFAULT_PROBLEM_DATA self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
def test_create_problem(self): def test_create_problem(self):
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
@ -138,23 +174,19 @@ class ProblemAdminAPITest(APITestCase):
self.assertSuccess(resp) self.assertSuccess(resp)
class ProblemAPITest(APITestCase): class ProblemAPITest(ProblemCreateTestBase):
def setUp(self): def setUp(self):
self.url = self.reverse("problem_api") self.url = self.reverse("problem_api")
self.create_admin() admin = self.create_admin(login=False)
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
def create_problem(self): self.create_user("test", "test123")
url = self.reverse("problem_admin_api")
return self.client.post(url, data=DEFAULT_PROBLEM_DATA)
def test_get_problem_list(self): def test_get_problem_list(self):
self.create_problem()
resp = self.client.get(f"{self.url}?limit=10") resp = self.client.get(f"{self.url}?limit=10")
self.assertSuccess(resp) self.assertSuccess(resp)
def get_one_problem(self): def get_one_problem(self):
problem_id = self.create_problem().data["data"]["_id"] resp = self.client.get(self.url + "?id=" + self.problem._id)
resp = self.client.get(self.url + "?id=" + str(problem_id))
self.assertSuccess(resp) self.assertSuccess(resp)
@ -169,51 +201,49 @@ class ContestProblemAdminTest(APITestCase):
def test_create_contest_problem(self): def test_create_contest_problem(self):
contest = self.create_contest() contest = self.create_contest()
data = DEFAULT_PROBLEM_DATA data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = contest.data["data"]["id"] data["contest_id"] = contest.data["data"]["id"]
resp = self.client.post(self.url, data=data) resp = self.client.post(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
return resp return contest, resp
def test_get_contest_problem(self): def test_get_contest_problem(self):
contest = self.test_create_contest_problem() contest, contest_problem = self.test_create_contest_problem()
contest_id = contest.data["data"]["id"] contest_id = contest.data["data"]["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1) self.assertEqual(len(resp.data["data"]), 1)
def test_get_one_contest_problem(self): def test_get_one_contest_problem(self):
contest = self.test_create_contest_problem() contest, contest_problem = self.test_create_contest_problem()
contest_id = contest.data["data"]["id"] contest_id = contest.data["data"]["id"]
resp = self.client.get(self.url + "?id=" + str(contest_id)) problem_id = contest_problem.data["data"]["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp) self.assertSuccess(resp)
class ContestProblemTest(APITestCase): class ContestProblemTest(ProblemCreateTestBase):
def setUp(self): def setUp(self):
self.url = self.reverse("contest_problem_api") admin = self.create_admin()
self.create_admin()
url = self.reverse("contest_admin_api") url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA) contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = "" contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1) contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"] self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.problem.contest_id = self.contest["id"]
self.problem.save()
self.url = self.reverse("contest_problem_api")
problem_data = copy.deepcopy(DEFAULT_PROBLEM_DATA) def test_admin_get_contest_problem_list(self):
problem_data["contest"] = self.contest["id"]
url = self.reverse("contest_problem_admin_api")
self.problem = self.client.post(url, problem_data).data["data"]
def test_get_contest_problem_list(self):
contest_id = self.contest["id"] contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1) self.assertEqual(len(resp.data["data"]), 1)
def test_get_one_contest_problem(self): def test_admin_get_one_contest_problem(self):
contest_id = self.contest["id"] contest_id = self.contest["id"]
problem_id = self.problem["_id"] problem_id = self.problem._id
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id)) resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
self.assertSuccess(resp) self.assertSuccess(resp)

View File

@ -223,6 +223,7 @@ class ProblemAPI(APIView):
data["total_score"] = total_score data["total_score"] = total_score
# todo check filename and score info # todo check filename and score info
tags = data.pop("tags") tags = data.pop("tags")
data["languages"] = list(data["languages"])
for k, v in data.items(): for k, v in data.items():
setattr(problem, k, v) setattr(problem, k, v)

View File

@ -13,6 +13,24 @@ class ProblemTagAPI(APIView):
class ProblemAPI(APIView): class ProblemAPI(APIView):
@staticmethod
def _add_problem_status(request, queryset_values):
if request.user.is_authenticated():
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
# paginate data
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values,]
for problem in problems:
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
else:
problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status")
def get(self, request): def get(self, request):
# 问题详情页 # 问题详情页
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
@ -20,7 +38,9 @@ class ProblemAPI(APIView):
try: try:
problem = Problem.objects.select_related("created_by")\ problem = Problem.objects.select_related("created_by")\
.get(_id=problem_id, contest_id__isnull=True, visible=True) .get(_id=problem_id, contest_id__isnull=True, visible=True)
return self.success(ProblemSerializer(problem).data) problem_data = ProblemSerializer(problem).data
self._add_problem_status(request, problem_data)
return self.success(problem_data)
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist") return self.error("Problem does not exist")
@ -32,11 +52,7 @@ class ProblemAPI(APIView):
# 按照标签筛选 # 按照标签筛选
tag_text = request.GET.get("tag") tag_text = request.GET.get("tag")
if tag_text: if tag_text:
try: problems = problems.filter(tags__name=tag_text)
tag = ProblemTag.objects.get(name=tag_text)
except ProblemTag.DoesNotExist:
return self.error("The Tag does not exist.")
problems = tag.problem_set.all().filter(visible=True)
# 搜索的情况 # 搜索的情况
keyword = request.GET.get("keyword", "").strip() keyword = request.GET.get("keyword", "").strip()
@ -49,19 +65,23 @@ class ProblemAPI(APIView):
problems = problems.filter(difficulty=difficulty) problems = problems.filter(difficulty=difficulty)
# 根据profile 为做过的题目添加标记 # 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemSerializer) data = self.paginate_data(request, problems, ProblemSerializer)
if request.user.id: self._add_problem_status(request, data)
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
for problem in data["results"]:
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), None)
else:
problem["my_status"] = oi_problems_status.get(str(problem["id"]), None)
return self.success(data) return self.success(data)
class ContestProblemAPI(APIView): class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values):
if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user):
return
if request.user.is_authenticated():
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in queryset_values:
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
@check_contest_permission @check_contest_permission
def get(self, request): def get(self, request):
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
@ -72,17 +92,12 @@ class ContestProblemAPI(APIView):
visible=True) visible=True)
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist.") return self.error("Problem does not exist.")
return self.success(ContestProblemSerializer(problem).data) problem_data = ContestProblemSerializer(problem).data
self._add_problem_status(request, [problem_data,])
return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
# 根据profile 为做过的题目添加标记 # 根据profile 为做过的题目添加标记
data = ContestProblemSerializer(contest_problems, many=True).data data = ContestProblemSerializer(contest_problems, many=True).data
if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI: self._add_problem_status(request, data)
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in data:
problem["my_status"] = problems_status.get(str(problem["id"]), None)
return self.success(data) return self.success(data)

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('submission', '0007_auto_20170923_1318'),
]
operations = [
migrations.AlterField(
model_name='submission',
name='info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='submission',
name='statistic_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
]

View File

@ -1,5 +1,5 @@
from django.db import models from django.db import models
from jsonfield import JSONField from utils.models import JSONField
from account.models import AdminType from account.models import AdminType
from problem.models import Problem from problem.models import Problem
from contest.models import Contest from contest.models import Contest
@ -30,18 +30,20 @@ class Submission(models.Model):
username = models.CharField(max_length=30) username = models.CharField(max_length=30)
code = models.TextField() code = models.TextField()
result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING) result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING)
# 判题结果的详细信息 # 从JudgeServer返回的判题详情
info = JSONField(default={}) info = JSONField(default=dict)
language = models.CharField(max_length=20) language = models.CharField(max_length=20)
shared = models.BooleanField(default=False) shared = models.BooleanField(default=False)
# 存储该提交所用时间和内存值,方便提交列表显示 # 存储该提交所用时间和内存值,方便提交列表显示
# {time_cost: "", memory_cost: "", err_info: "", score: 0} # {time_cost: "", memory_cost: "", err_info: "", score: 0}
statistic_info = JSONField(default={}) statistic_info = JSONField(default=dict)
def check_user_permission(self, user): def check_user_permission(self, user, check_share=True):
return self.user_id == user.id or \ return self.user_id == user.id or \
self.shared is True or \ (check_share and self.shared is True) or \
user.admin_type == AdminType.SUPER_ADMIN user.is_super_admin() or \
user.can_mgmt_all_problem() or \
self.problem.created_by_id == user.id
class Meta: class Meta:
db_table = "submission" db_table = "submission"

View File

@ -10,6 +10,11 @@ class CreateSubmissionSerializer(serializers.Serializer):
contest_id = serializers.IntegerField(required=False) contest_id = serializers.IntegerField(required=False)
class ShareSubmissionSerializer(serializers.Serializer):
id = serializers.CharField()
shared = serializers.BooleanField()
class SubmissionModelSerializer(serializers.ModelSerializer): class SubmissionModelSerializer(serializers.ModelSerializer):
info = serializers.JSONField() info = serializers.JSONField()
statistic_info = serializers.JSONField() statistic_info = serializers.JSONField()
@ -19,7 +24,7 @@ class SubmissionModelSerializer(serializers.ModelSerializer):
# 不显示submission info的serializer, 用于ACM rule_type # 不显示submission info的serializer, 用于ACM rule_type
class SubmissionSafeSerializer(serializers.ModelSerializer): class SubmissionSafeModelSerializer(serializers.ModelSerializer):
problem = serializers.SlugRelatedField(read_only=True, slug_field="_id") problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
statistic_info = serializers.JSONField() statistic_info = serializers.JSONField()
@ -43,6 +48,6 @@ class SubmissionListSerializer(serializers.ModelSerializer):
def get_show_link(self, obj): def get_show_link(self, obj):
# 没传user或为匿名user # 没传user或为匿名user
if self.user is None or self.user.id is None: if self.user is None or not self.user.is_authenticated():
return False return False
return obj.check_user_permission(self.user) return obj.check_user_permission(self.user)

View File

@ -32,14 +32,16 @@ class SubmissionPrepare(APITestCase):
def _create_problem_and_submission(self): def _create_problem_and_submission(self):
user = self.create_admin("test", "test123", login=False) user = self.create_admin("test", "test123", login=False)
problem_data = deepcopy(DEFAULT_PROBLEM_DATA) problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
problem_data.pop("tags") tags = problem_data.pop("tags")
problem_data["created_by"] = user problem_data["created_by"] = user
self.problem = Problem.objects.create(**problem_data) self.problem = Problem.objects.create(**problem_data)
for tag in DEFAULT_PROBLEM_DATA["tags"]: for tag in tags:
tag = ProblemTag.objects.create(name=tag) tag = ProblemTag.objects.create(name=tag)
self.problem.tags.add(tag) self.problem.tags.add(tag)
self.problem.save() self.problem.save()
self.submission = Submission.objects.create(**DEFAULT_SUBMISSION_DATA) self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
self.submission_data["problem_id"] = self.problem.id
self.submission = Submission.objects.create(**self.submission_data)
class SubmissionListTest(SubmissionPrepare): class SubmissionListTest(SubmissionPrepare):
@ -61,6 +63,6 @@ class SubmissionAPITest(SubmissionPrepare):
self.url = self.reverse("submission_api") self.url = self.reverse("submission_api")
def test_create_submission(self, judge_task): def test_create_submission(self, judge_task):
resp = self.client.post(self.url, DEFAULT_SUBMISSION_DATA) resp = self.client.post(self.url, self.submission_data)
self.assertSuccess(resp) self.assertSuccess(resp)
judge_task.assert_called() judge_task.assert_called()

View File

@ -5,16 +5,17 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType from contest.models import Contest, ContestStatus, ContestRuleType
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController from utils.throttling import TokenBucket, BucketController
from utils.cache import cache
from ..models import Submission from ..models import Submission
from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer ShareSubmissionSerializer)
from utils.cache import throttling_cache from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
def _submit(response, user, problem_id, language, code, contest_id): def _submit(response, user, problem_id, language, code, contest_id):
# TODO: 预设默认值,需修改 # TODO: 预设默认值,需修改
controller = BucketController(user_id=user.id, controller = BucketController(user_id=user.id,
redis_conn=throttling_cache, redis_conn=cache,
default_capacity=30) default_capacity=30)
bucket = TokenBucket(fill_rate=10, capacity=20, bucket = TokenBucket(fill_rate=10, capacity=20,
last_capacity=controller.last_capacity, last_capacity=controller.last_capacity,
@ -63,17 +64,36 @@ class SubmissionAPI(APIView):
def get(self, request): def get(self, request):
submission_id = request.GET.get("id") submission_id = request.GET.get("id")
if not submission_id: if not submission_id:
return self.error("Parameter id doesn't exist.") return self.error("Parameter id doesn't exist")
try: try:
submission = Submission.objects.select_related("problem").get(id=submission_id) submission = Submission.objects.select_related("problem").get(id=submission_id)
except Submission.DoesNotExist: except Submission.DoesNotExist:
return self.error("Submission doesn't exist.") return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user): if not submission.check_user_permission(request.user):
return self.error("No permission for this submission.") return self.error("No permission for this submission")
if submission.problem.rule_type == ProblemRuleType.ACM: if submission.problem.rule_type == ProblemRuleType.ACM:
return self.success(SubmissionSafeSerializer(submission).data) submission_data = SubmissionSafeModelSerializer(submission).data
return self.success(SubmissionModelSerializer(submission).data) else:
submission_data = SubmissionModelSerializer(submission).data
# 是否有权限取消共享
submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False)
return self.success(submission_data)
@validate_serializer(ShareSubmissionSerializer)
@login_required
def put(self, request):
try:
submission = Submission.objects.select_related("problem").get(id=request.data["id"])
except Submission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user, check_share=False):
return self.error("No permission to share the submission")
if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY:
return self.error("Can not share submission now")
submission.shared = request.data["shared"]
submission.save(update_fields=["shared"])
return self.success()
class SubmissionListAPI(APIView): class SubmissionListAPI(APIView):
@ -83,7 +103,7 @@ class SubmissionListAPI(APIView):
if request.GET.get("contest_id"): if request.GET.get("contest_id"):
return self.error("Parameter error") return self.error("Parameter error")
submissions = Submission.objects.filter(contest_id__isnull=True) submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by")
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself") myself = request.GET.get("myself")
result = request.GET.get("result") result = request.GET.get("result")
@ -109,10 +129,10 @@ class ContestSubmissionListAPI(APIView):
return self.error("Limit is needed") return self.error("Limit is needed")
contest = self.contest contest = self.contest
if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user): if not contest.check_oi_permission(request.user):
return self.error("No permission for OI contest submissions") return self.error("No permission for OI contest submissions")
submissions = Submission.objects.filter(contest_id=contest.id) submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself") myself = request.GET.get("myself")
result = request.GET.get("result") result = request.GET.get("result")
@ -133,9 +153,11 @@ class ContestSubmissionListAPI(APIView):
submissions = submissions.filter(create_time__gte=contest.start_time) submissions = submissions.filter(create_time__gte=contest.start_time)
# 封榜的时候只能看到自己的提交 # 封榜的时候只能看到自己的提交
if contest.rule_type == ContestRuleType.ACM:
if not contest.real_time_rank and not contest.is_contest_admin(request.user): if not contest.real_time_rank and not contest.is_contest_admin(request.user):
submissions = submissions.filter(user_id=request.user.id) submissions = submissions.filter(user_id=request.user.id)
data = self.paginate_data(request, submissions) data = self.paginate_data(request, submissions)
data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data
return self.success(data) return self.success(data)

View File

@ -1,2 +1,2 @@
from .api import * # NOQA
from ._serializers import * # NOQA from ._serializers import * # NOQA
from .api import * # NOQA

View File

@ -79,7 +79,7 @@ class APIView(View):
def success(self, data=None): def success(self, data=None):
return self.response({"error": None, "data": data}) return self.response({"error": None, "data": data})
def error(self, msg, err="error"): def error(self, msg="error", err="error"):
return self.response({"error": err, "data": msg}) return self.response({"error": err, "data": msg})
def _serializer_error_to_str(self, errors): def _serializer_error_to_str(self, errors):

View File

@ -3,7 +3,6 @@ from django.test.testcases import TestCase
from rest_framework.test import APIClient from rest_framework.test import APIClient
from account.models import AdminType, ProblemPermission, User, UserProfile from account.models import AdminType, ProblemPermission, User, UserProfile
from conf.models import WebsiteConfig
class APITestCase(TestCase): class APITestCase(TestCase):
@ -28,9 +27,6 @@ class APITestCase(TestCase):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login) problem_permission=ProblemPermission.ALL, login=login)
def create_website_config(self):
return WebsiteConfig.objects.create()
def reverse(self, url_name): def reverse(self, url_name):
return reverse(url_name) return reverse(url_name)

View File

@ -1,6 +1,27 @@
from django.conf import settings from django.core.cache import cache, caches # noqa
from django_redis import get_redis_connection from django.conf import settings # noqa
judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE) from django_redis.cache import RedisCache
throttling_cache = get_redis_connection(settings.CACHE_THROTTLING) from django_redis.client.default import DefaultClient
default_cache = get_redis_connection("default")
class MyRedisClient(DefaultClient):
def __getattr__(self, item):
client = self.get_client(write=True)
return getattr(client, item)
def redis_incr(self, key, count=1):
"""
django 默认的 incr key 不存在时候会抛异常
"""
client = self.get_client(write=True)
return client.incr(key, count)
class MyRedisCache(RedisCache):
def __init__(self, server, params):
super().__init__(server, params)
self._client_cls = MyRedisClient
def __getattr__(self, item):
return getattr(self.client, item)

View File

@ -1,4 +1,28 @@
class Choices:
@classmethod
def choices(cls):
d = cls.__dict__
return [d[item] for item in d.keys() if not item.startswith("__")]
class ContestType:
PUBLIC_CONTEST = "Public"
PASSWORD_PROTECTED_CONTEST = "Password Protected"
class ContestStatus:
CONTEST_NOT_START = "1"
CONTEST_ENDED = "-1"
CONTEST_UNDERWAY = "0"
class ContestRuleType(Choices):
ACM = "ACM"
OI = "OI"
class CacheKey: class CacheKey:
waiting_queue = "waiting_queue" waiting_queue = "waiting_queue"
contest_rank_cache = "contest_rank_cache_" contest_rank_cache = "contest_rank_cache"
website_config = "website_config" website_config = "website_config"
option = "option"

View File

@ -1,3 +1,4 @@
from django.contrib.postgres.fields import JSONField # NOQA
from django.db import models from django.db import models
from utils.xss_filter import XssHtml from utils.xss_filter import XssHtml

View File

@ -1,35 +1,9 @@
import logging
import random
import datetime import datetime
from io import BytesIO import random
from base64 import b64encode from base64 import b64encode
from io import BytesIO
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
from envelopes import Envelope
from conf.models import SMTPConfig
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SMTPConfig.objects.first()
if not smtp:
return
envlope = Envelope(from_addr=(smtp.email, from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp.server,
login=smtp.email,
password=smtp.password,
port=smtp.port,
tls=smtp.tls)
return True
except Exception as e:
logger.exception(e)
return False
def rand_str(length=32, type="lower_hex"): def rand_str(length=32, type="lower_hex"):

View File

@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7
浏览器版本IE7+ 或其他浏览器无法防御IE6及以下版本浏览器中的XSS 浏览器版本IE7+ 或其他浏览器无法防御IE6及以下版本浏览器中的XSS
""" """
import re import re
import copy
try:
from html.parser import HTMLParser from html.parser import HTMLParser
except:
from HTMLParser import HTMLParser
class XssHtml(HTMLParser): class XssHtml(HTMLParser):
@ -163,7 +160,7 @@ class XssHtml(HTMLParser):
else: else:
other = [] other = []
if attrs: if attrs:
for (key, value) in attrs.items(): for key, value in copy.deepcopy(attrs).items():
if key not in self.common_attrs + other: if key not in self.common_attrs + other:
del attrs[key] del attrs[key]
return attrs return attrs