Merge branch 'opt'

This commit is contained in:
zema1 2017-10-21 10:51:59 +08:00
commit c1d099ed45
52 changed files with 1011 additions and 589 deletions

View File

@ -92,7 +92,8 @@ def check_contest_permission(func):
if not user.is_authenticated():
return self.error("Please login in first.")
# password error
if ("contests" not in request.session) or (self.contest.id not in request.session["contests"]):
if ("accessible_contests" not in request.session) or \
(self.contest.id not in request.session["accessible_contests"]):
return self.error("Password is required.")
return func(*args, **kwargs)

View File

@ -1,38 +1,19 @@
import time
import pytz
from django.contrib import auth
from django.utils import timezone
from django.utils.translation import ugettext as _
from django.db import connection
from django.utils.timezone import now
from django.utils.deprecation import MiddlewareMixin
from utils.api import JSONResponse
class SessionSecurityMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
if "last_activity" in request.session and request.user.is_admin_role():
# 24 hours passed since last visit, 86400 = 24 * 60 * 60
if time.time() - request.session["last_activity"] >= 86400:
auth.logout(request)
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
request.session["last_activity"] = time.time()
class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
session = request.session
ip = request.META.get("REMOTE_ADDR", "")
user_agent = request.META.get("HTTP_USER_AGENT", "")
_ip = session.setdefault("ip", ip)
_user_agent = session.setdefault("user_agent", user_agent)
if ip != _ip or user_agent != _user_agent:
session.modified = True
session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP")
session["last_activity"] = now()
user_sessions = request.user.session_keys
if request.session.session_key not in user_sessions:
if session.session_key not in user_sessions:
user_sessions.append(session.session_key)
request.user.save()
@ -42,13 +23,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin):
path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not (request.user.is_authenticated() and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
class TimezoneMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
timezone.activate(pytz.timezone(request.user.userprofile.time_zone))
return JSONResponse.response({"error": "login-required", "data": "Please login in first"})
class LogSqlMiddleware(MiddlewareMixin):

View File

@ -50,7 +50,7 @@ class Migration(migrations.Migration):
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('problems_status', jsonfield.fields.JSONField(default={})),
('avatar', models.CharField(default=account.models._default_avatar, max_length=50)),
('avatar', models.CharField(default="default.png", max_length=50)),
('blog', models.URLField(blank=True, null=True)),
('mood', models.CharField(blank=True, max_length=200, null=True)),
('accepted_problem_number', models.IntegerField(default=0)),

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0007_auto_20170920_0254'),
]
operations = [
migrations.RemoveField(
model_name='userprofile',
name='language',
),
migrations.AlterField(
model_name='user',
name='admin_type',
field=models.CharField(default='Regular User', max_length=32),
),
migrations.AlterField(
model_name='user',
name='auth_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='email',
field=models.EmailField(max_length=64, null=True),
),
migrations.AlterField(
model_name='user',
name='open_api_appkey',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='problem_permission',
field=models.CharField(default='None', max_length=32),
),
migrations.AlterField(
model_name='user',
name='reset_password_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='session_keys',
field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
),
migrations.AlterField(
model_name='user',
name='tfa_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='username',
field=models.CharField(max_length=32, unique=True),
),
migrations.AlterField(
model_name='userprofile',
name='acm_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='avatar',
field=models.CharField(default='/static/avatar/default.png', max_length=256),
),
migrations.AlterField(
model_name='userprofile',
name='github',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='major',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='mood',
field=models.CharField(blank=True, max_length=256, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='oi_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='real_name',
field=models.CharField(blank=True, max_length=32, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='school',
field=models.CharField(blank=True, max_length=64, null=True),
),
]

View File

@ -1,7 +1,7 @@
from django.contrib.auth.models import AbstractBaseUser
from django.conf import settings
from django.db import models
from jsonfield import JSONField
from utils.models import JSONField
class AdminType(object):
@ -24,22 +24,22 @@ class UserManager(models.Manager):
class User(AbstractBaseUser):
username = models.CharField(max_length=30, unique=True)
email = models.EmailField(max_length=254, null=True)
username = models.CharField(max_length=32, unique=True)
email = models.EmailField(max_length=64, null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType
admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER)
problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE)
reset_password_token = models.CharField(max_length=40, null=True)
admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER)
problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE)
reset_password_token = models.CharField(max_length=32, null=True)
reset_password_token_expire_time = models.DateTimeField(null=True)
# SSO auth token
auth_token = models.CharField(max_length=40, null=True)
auth_token = models.CharField(max_length=32, null=True)
two_factor_auth = models.BooleanField(default=False)
tfa_token = models.CharField(max_length=40, null=True)
session_keys = JSONField(default=[])
tfa_token = models.CharField(max_length=32, null=True)
session_keys = JSONField(default=list)
# open api key
open_api = models.BooleanField(default=False)
open_api_appkey = models.CharField(max_length=35, null=True)
open_api_appkey = models.CharField(max_length=32, null=True)
is_disabled = models.BooleanField(default=False)
USERNAME_FIELD = "username"
@ -63,26 +63,34 @@ class User(AbstractBaseUser):
db_table = "user"
def _default_avatar():
return f"/{settings.IMAGE_UPLOAD_DIR}/default.png"
class UserProfile(models.Model):
user = models.OneToOneField(User)
# Store user problem solution status with json string format
# {problems: {1: JudgeStatus.ACCEPTED}, contest_problems: {1: JudgeStatus.ACCEPTED}}, record problem_id and status
acm_problems_status = JSONField(default={})
# {problems: {1: 33}, contest_problems: {1: 44}, record problem_id and score
oi_problems_status = JSONField(default={})
# acm_problems_status examples:
# {
# "problems": {
# "1": {
# "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# },
# "contest_problems": {
# "1": {
# "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# }
# }
acm_problems_status = JSONField(default=dict)
# like acm_problems_status, merely add "score" field
oi_problems_status = JSONField(default=dict)
real_name = models.CharField(max_length=30, blank=True, null=True)
avatar = models.CharField(max_length=50, default=_default_avatar())
real_name = models.CharField(max_length=32, blank=True, null=True)
avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png")
blog = models.URLField(blank=True, null=True)
mood = models.CharField(max_length=200, blank=True, null=True)
github = models.CharField(max_length=50, blank=True, null=True)
school = models.CharField(max_length=200, blank=True, null=True)
major = models.CharField(max_length=200, blank=True, null=True)
language = models.CharField(max_length=32, blank=True, null=True)
mood = models.CharField(max_length=256, blank=True, null=True)
github = models.CharField(max_length=64, blank=True, null=True)
school = models.CharField(max_length=64, blank=True, null=True)
major = models.CharField(max_length=64, blank=True, null=True)
# for ACM
accepted_number = models.IntegerField(default=0)
# for OI

View File

@ -6,27 +6,26 @@ from .models import AdminType, ProblemPermission, User, UserProfile
class UserLoginSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30)
password = serializers.CharField(max_length=30)
tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True)
username = serializers.CharField()
password = serializers.CharField()
tfa_code = serializers.CharField(required=False, allow_null=True)
class UsernameOrEmailCheckSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30, required=False)
email = serializers.EmailField(max_length=30, required=False)
username = serializers.CharField(required=False)
email = serializers.EmailField(required=False)
class UserRegisterSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30)
password = serializers.CharField(max_length=30, min_length=6)
email = serializers.EmailField(max_length=30)
captcha = serializers.CharField(max_length=4, min_length=1)
username = serializers.CharField(max_length=32)
password = serializers.CharField(min_length=6)
email = serializers.EmailField(max_length=64)
captcha = serializers.CharField()
class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField()
new_password = serializers.CharField(max_length=30, min_length=6)
captcha = serializers.CharField(max_length=4, min_length=4)
new_password = serializers.CharField(min_length=6)
class UserSerializer(serializers.ModelSerializer):
@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer):
class Meta:
model = UserProfile
fields = "__all__"
class UserInfoSerializer(serializers.ModelSerializer):
@ -58,9 +58,9 @@ class UserInfoSerializer(serializers.ModelSerializer):
class EditUserSerializer(serializers.Serializer):
id = serializers.IntegerField()
username = serializers.CharField(max_length=30)
password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None)
email = serializers.EmailField(max_length=254)
username = serializers.CharField(max_length=32)
password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None)
email = serializers.EmailField(max_length=64)
admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN))
problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN,
ProblemPermission.ALL))
@ -70,29 +70,29 @@ class EditUserSerializer(serializers.Serializer):
class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=30, allow_blank=True)
avatar = serializers.CharField(max_length=100, allow_blank=True, required=False)
blog = serializers.URLField(allow_blank=True, required=False)
mood = serializers.CharField(max_length=200, allow_blank=True, required=False)
github = serializers.CharField(max_length=50, allow_blank=True, required=False)
school = serializers.CharField(max_length=200, allow_blank=True, required=False)
major = serializers.CharField(max_length=200, allow_blank=True, required=False)
real_name = serializers.CharField(max_length=32, allow_blank=True)
avatar = serializers.CharField(max_length=256, allow_blank=True, required=False)
blog = serializers.URLField(max_length=256, allow_blank=True, required=False)
mood = serializers.CharField(max_length=256, allow_blank=True, required=False)
github = serializers.CharField(max_length=64, allow_blank=True, required=False)
school = serializers.CharField(max_length=64, allow_blank=True, required=False)
major = serializers.CharField(max_length=64, allow_blank=True, required=False)
class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField()
captcha = serializers.CharField(max_length=4, min_length=4)
captcha = serializers.CharField()
class ResetPasswordSerializer(serializers.Serializer):
token = serializers.CharField(min_length=1, max_length=40)
password = serializers.CharField(min_length=6, max_length=30)
captcha = serializers.CharField(max_length=4, min_length=4)
token = serializers.CharField()
password = serializers.CharField(min_length=6)
captcha = serializers.CharField()
class SSOSerializer(serializers.Serializer):
appkey = serializers.CharField(max_length=35)
token = serializers.CharField(max_length=40)
appkey = serializers.CharField()
token = serializers.CharField()
class TwoFactorAuthCodeSerializer(serializers.Serializer):

View File

@ -1,6 +1,31 @@
from celery import shared_task
import logging
from utils.shortcuts import send_email
from celery import shared_task
from envelopes import Envelope
from options.options import SysOptions
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SysOptions.smtp_config
if not smtp:
return
envlope = Envelope(from_addr=(smtp["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp["server"],
login=smtp["email"],
password=smtp["password"],
port=smtp["port"],
tls=smtp["tls"])
return True
except Exception as e:
logger.exception(e)
return False
@shared_task

View File

@ -8,11 +8,10 @@ from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str
from utils.cache import default_cache
from utils.constants import CacheKey
from options.options import SysOptions
from .models import AdminType, ProblemPermission, User
from conf.models import WebsiteConfig
from utils.constants import ContestRuleType
class PermissionDecoratorTest(APITestCase):
@ -136,7 +135,7 @@ class UserLoginAPITest(APITestCase):
self.user.save()
resp = self.client.post(self.login_url, data={"username": self.username,
"password": self.password})
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account have been disabled"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"})
class CaptchaTest(APITestCase):
@ -157,15 +156,11 @@ class UserRegisterAPITest(CaptchaTest):
self.data = {"username": "test_user", "password": "testuserpassword",
"real_name": "real_name", "email": "test@qduoj.com",
"captcha": self._set_captcha(self.client.session)}
# clea cache in redis
default_cache.delete(CacheKey.website_config)
def test_website_config_limit(self):
website = WebsiteConfig.objects.create()
website.allow_register = False
website.save()
SysOptions.allow_register = False
resp = self.client.post(self.register_url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"})
def test_invalid_captcha(self):
self.data["captcha"] = "****"
@ -226,7 +221,7 @@ class UserProfileAPITest(APITestCase):
def test_get_profile_without_login(self):
resp = self.client.get(self.url)
self.assertDictEqual(resp.data, {"error": None, "data": {}})
self.assertDictEqual(resp.data, {"error": None, "data": None})
def test_get_profile(self):
self.create_user("test", "test123")
@ -247,7 +242,6 @@ class TwoFactorAuthAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123")
self.create_website_config()
def _get_tfa_code(self):
user = User.objects.first()
@ -295,7 +289,6 @@ class ApplyResetPasswordAPITest(CaptchaTest):
user.email = "test@oj.com"
user.save()
self.url = self.reverse("apply_reset_password_api")
self.create_website_config()
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self):
@ -343,14 +336,14 @@ class ResetPasswordAPITest(CaptchaTest):
def test_reset_password_with_invalid_token(self):
self.data["token"] = "aaaaaaaaaaa"
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token dose not exist"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"})
def test_reset_password_with_expired_token(self):
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
user.save()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token have expired"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"})
class UserChangePasswordAPITest(CaptchaTest):
@ -481,14 +474,14 @@ class UserRankAPITest(APITestCase):
profile2.save()
def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": "acm"})
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": "oi"})
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data[0]["user"]["username"], "test2")

View File

@ -3,7 +3,7 @@ from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
from utils.captcha.views import CaptchaAPIView
@ -19,7 +19,6 @@ urlpatterns = [
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),

View File

@ -1,32 +1,28 @@
import os
import qrcode
import pickle
from datetime import timedelta
from otpauth import OtpAuth
from importlib import import_module
import qrcode
from django.conf import settings
from django.contrib import auth
from importlib import import_module
from django.template.loader import render_to_string
from django.utils.decorators import method_decorator
from django.utils.timezone import now
from django.views.decorators.csrf import ensure_csrf_cookie
from django.utils.decorators import method_decorator
from django.template.loader import render_to_string
from otpauth import OtpAuth
from conf.models import WebsiteConfig
from utils.constants import ContestRuleType
from options.options import SysOptions
from utils.api import APIView, validate_serializer
from utils.captcha import Captcha
from utils.shortcuts import rand_str, img2base64, timestamp2utcstr
from utils.cache import default_cache
from utils.constants import CacheKey
from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required
from ..models import User, UserProfile
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer,
UserRegisterSerializer, UsernameOrEmailCheckSerializer,
RankInfoSerializer)
from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer,
UserProfileSerializer,
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
EditUserProfileSerializer, AvatarUploadForm)
from ..tasks import send_email_async
@ -39,7 +35,7 @@ class UserProfileAPI(APIView):
"""
user = request.user
if not user.is_authenticated():
return self.success({})
return self.success()
username = request.GET.get("username")
try:
if username:
@ -48,8 +44,7 @@ class UserProfileAPI(APIView):
user = request.user
except User.DoesNotExist:
return self.error("User does not exist")
profile = UserProfile.objects.select_related("user").get(user=user)
return self.success(UserProfileSerializer(profile).data)
return self.success(UserProfileSerializer(user.userprofile).data)
@validate_serializer(EditUserProfileSerializer)
@login_required
@ -72,8 +67,7 @@ class AvatarUploadAPI(APIView):
avatar = form.cleaned_data["file"]
else:
return self.error("Invalid file content")
# 2097152 = 2 * 1024 * 1024 = 2MB
if avatar.size > 2097152:
if avatar.size > 2 * 1024 * 1024:
return self.error("Picture is too large")
suffix = os.path.splitext(avatar.name)[-1].lower()
if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
@ -84,46 +78,12 @@ class AvatarUploadAPI(APIView):
for chunk in avatar:
img.write(chunk)
user_profile = request.user.userprofile
_, old_avatar = os.path.split(user_profile.avatar)
if old_avatar != "default.png":
os.remove(os.path.join(settings.IMAGE_UPLOAD_DIR_ABS, old_avatar))
user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}"
user_profile.save()
return self.success("Succeeded")
class SSOAPI(APIView):
@login_required
def get(self, request):
callback = request.GET.get("callback", None)
if not callback:
return self.error("Parameter Error")
token = rand_str()
request.user.auth_token = token
request.user.save()
return self.success({"redirect_url": callback + "?token=" + token,
"callback": callback})
@validate_serializer(SSOSerializer)
def post(self, request):
data = request.data
try:
User.objects.get(open_api_appkey=data["appkey"])
except User.DoesNotExist:
return self.error("Invalid appkey")
try:
user = User.objects.get(auth_token=data["token"])
user.auth_token = None
user.save()
return self.success({"username": user.username,
"id": user.id,
"admin_type": user.admin_type,
"avatar": user.userprofile.avatar})
except User.DoesNotExist:
return self.error("User does not exist")
class TwoFactorAuthAPI(APIView):
@login_required
def get(self, request):
@ -132,14 +92,13 @@ class TwoFactorAuthAPI(APIView):
"""
user = request.user
if user.two_factor_auth:
return self.error("Already open 2FA")
return self.error("2FA is already turned on")
token = rand_str()
user.tfa_token = token
user.save()
config = WebsiteConfig.objects.first()
label = f"{config.name_shortcut}:{user.username}"
image = qrcode.make(OtpAuth(token).to_uri("totp", label, config.name))
label = f"{SysOptions.website_name_shortcut}:{user.username}"
image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name))
return self.success(img2base64(image))
@login_required
@ -163,7 +122,7 @@ class TwoFactorAuthAPI(APIView):
code = request.data["code"]
user = request.user
if not user.two_factor_auth:
return self.error("Other session have disabled TFA")
return self.error("2FA is already turned off")
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False
user.save()
@ -200,7 +159,7 @@ class UserLoginAPI(APIView):
# None is returned if username or password is wrong
if user:
if user.is_disabled:
return self.error("Your account have been disabled")
return self.error("Your account has been disabled")
if not user.two_factor_auth:
auth.login(request, user)
return self.success("Succeeded")
@ -220,13 +179,13 @@ class UserLoginAPI(APIView):
# todo remove this, only for debug use
def get(self, request):
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
return self.success({})
return self.success()
class UserLogoutAPI(APIView):
def get(self, request):
auth.logout(request)
return self.success({})
return self.success()
class UsernameOrEmailCheck(APIView):
@ -242,11 +201,9 @@ class UsernameOrEmailCheck(APIView):
"email": False
}
if data.get("username"):
if User.objects.filter(username=data["username"]).exists():
result["username"] = True
result["username"] = User.objects.filter(username=data["username"]).exists()
if data.get("email"):
if User.objects.filter(email=data["email"]).exists():
result["email"] = True
result["email"] = User.objects.filter(email=data["email"]).exists()
return self.success(result)
@ -256,17 +213,9 @@ class UserRegisterAPI(APIView):
"""
User register api
"""
config = default_cache.get(CacheKey.website_config)
if config:
config = pickle.loads(config)
else:
config = WebsiteConfig.objects.first()
if not config:
config = WebsiteConfig.objects.create()
default_cache.set(CacheKey.website_config, pickle.dumps(config))
if not config.allow_register:
return self.error("Register have been disabled by admin")
if not SysOptions.allow_register:
return self.error("Register function has been disabled by admin")
data = request.data
captcha = Captcha(request)
@ -295,6 +244,7 @@ class UserChangePasswordAPI(APIView):
username = request.user.username
user = auth.authenticate(username=username, password=data["old_password"])
if user:
# TODO: check tfa?
user.set_password(data["new_password"])
user.save()
return self.success("Succeeded")
@ -307,7 +257,6 @@ class ApplyResetPasswordAPI(APIView):
def post(self, request):
data = request.data
captcha = Captcha(request)
config = WebsiteConfig.objects.first()
if not captcha.check(data["captcha"]):
return self.error("Invalid captcha")
try:
@ -322,14 +271,14 @@ class ApplyResetPasswordAPI(APIView):
user.save()
render_data = {
"username": user.username,
"website_name": config.name,
"link": f"{config.base_url}/reset-password/{user.reset_password_token}"
"website_name": SysOptions.website_name,
"link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
}
email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(config.name,
send_email_async.delay(SysOptions.website_name,
user.email,
user.username,
config.name + " 登录信息找回邮件",
f"{SysOptions.website_name} 登录信息找回邮件",
email_html)
return self.success("Succeeded")
@ -344,9 +293,9 @@ class ResetPasswordAPI(APIView):
try:
user = User.objects.get(reset_password_token=data["token"])
except User.DoesNotExist:
return self.error("Token dose not exist")
if int((user.reset_password_token_expire_time - now()).total_seconds()) < 0:
return self.error("Token have expired")
return self.error("Token does not exist")
if user.reset_password_token_expire_time < now():
return self.error("Token has expired")
user.reset_password_token = None
user.two_factor_auth = False
user.set_password(data["password"])
@ -358,14 +307,13 @@ class SessionManagementAPI(APIView):
@login_required
def get(self, request):
engine = import_module(settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
current_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME)
session_store = engine.SessionStore
current_session = request.session.session_key
session_keys = request.user.session_keys
result = []
modified = False
for key in session_keys[:]:
session = SessionStore(key)
session = session_store(key)
# session does not exist or is expiry
if not session._session:
session_keys.remove(key)
@ -377,7 +325,7 @@ class SessionManagementAPI(APIView):
s["current_session"] = True
s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"]
s["last_activity"] = timestamp2utcstr(session["last_activity"])
s["last_activity"] = datetime2str(session["last_activity"])
s["session_key"] = key
result.append(s)
if modified:
@ -401,12 +349,12 @@ class SessionManagementAPI(APIView):
class UserRankAPI(APIView):
def get(self, request):
rule_type = request.GET.get("rule")
if rule_type not in ["acm", "oi"]:
rule_type = "acm"
if rule_type not in ContestRuleType.choices():
rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.select_related("user")\
.filter(submission_number__gt=0)\
.exclude(user__is_disabled=True)
if rule_type == "acm":
if rule_type == ContestRuleType.ACM:
profiles = profiles.order_by("-accepted_number", "submission_number")
else:
profiles = profiles.order_by("-total_score")

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('announcement', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='announcement',
name='title',
field=models.CharField(max_length=64),
),
]

View File

@ -5,7 +5,7 @@ from utils.models import RichTextField
class Announcement(models.Model):
title = models.CharField(max_length=50)
title = models.CharField(max_length=64)
# HTML
content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True)

View File

@ -5,8 +5,8 @@ from .models import Announcement
class CreateAnnouncementSerializer(serializers.Serializer):
title = serializers.CharField(max_length=50)
content = serializers.CharField(max_length=10000)
title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField()
@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer):
class EditAnnouncementSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=50)
content = serializers.CharField(max_length=10000)
title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField()

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('conf', '0001_initial'),
]
operations = [
migrations.DeleteModel(
name='JudgeServerToken',
),
migrations.DeleteModel(
name='SMTPConfig',
),
migrations.DeleteModel(
name='WebsiteConfig',
),
migrations.AlterField(
model_name='judgeserver',
name='hostname',
field=models.CharField(max_length=128),
),
migrations.AlterField(
model_name='judgeserver',
name='judger_version',
field=models.CharField(max_length=32),
),
migrations.AlterField(
model_name='judgeserver',
name='service_url',
field=models.CharField(blank=True, max_length=256, null=True),
),
]

View File

@ -2,42 +2,17 @@ from django.db import models
from django.utils import timezone
class SMTPConfig(models.Model):
server = models.CharField(max_length=128)
port = models.IntegerField(default=25)
email = models.CharField(max_length=128)
password = models.CharField(max_length=128)
tls = models.BooleanField()
class Meta:
db_table = "smtp_config"
class WebsiteConfig(models.Model):
base_url = models.CharField(max_length=128, default="http://127.0.0.1")
name = models.CharField(max_length=32, default="Online Judge")
name_shortcut = models.CharField(max_length=32, default="oj")
footer = models.TextField(default="Online Judge Footer")
# allow register
allow_register = models.BooleanField(default=True)
# submission list show all user's submission
submission_list_show_all = models.BooleanField(default=True)
class Meta:
db_table = "website_config"
class JudgeServer(models.Model):
hostname = models.CharField(max_length=64)
hostname = models.CharField(max_length=128)
ip = models.CharField(max_length=32, blank=True, null=True)
judger_version = models.CharField(max_length=24)
judger_version = models.CharField(max_length=32)
cpu_core = models.IntegerField()
memory_usage = models.FloatField()
cpu_usage = models.FloatField()
last_heartbeat = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True)
task_number = models.IntegerField(default=0)
service_url = models.CharField(max_length=128, blank=True, null=True)
service_url = models.CharField(max_length=256, blank=True, null=True)
@property
def status(self):
@ -48,10 +23,3 @@ class JudgeServer(models.Model):
class Meta:
db_table = "judge_server"
class JudgeServerToken(models.Model):
token = models.CharField(max_length=32)
class Meta:
db_table = "judge_server_token"

View File

@ -1,6 +1,6 @@
from utils.api import DateTimeTZField, serializers
from .models import JudgeServer, SMTPConfig, WebsiteConfig
from .models import JudgeServer
class EditSMTPConfigSerializer(serializers.Serializer):
@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer):
password = serializers.CharField(max_length=128)
class SMTPConfigSerializer(serializers.ModelSerializer):
class Meta:
model = SMTPConfig
exclude = ["id", "password"]
class TestSMTPConfigSerializer(serializers.Serializer):
email = serializers.EmailField()
class CreateEditWebsiteConfigSerializer(serializers.Serializer):
base_url = serializers.CharField(max_length=128)
name = serializers.CharField(max_length=32)
name_shortcut = serializers.CharField(max_length=32)
footer = serializers.CharField(max_length=1024)
website_base_url = serializers.CharField(max_length=128)
website_name = serializers.CharField(max_length=64)
website_name_shortcut = serializers.CharField(max_length=64)
website_footer = serializers.CharField(max_length=1024 * 1024)
allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField()
class WebsiteConfigSerializer(serializers.ModelSerializer):
class Meta:
model = WebsiteConfig
exclude = ["id"]
class JudgeServerSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField()
last_heartbeat = DateTimeTZField()
@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer):
class Meta:
model = JudgeServer
fields = "__all__"
class JudgeServerHeartbeatSerializer(serializers.Serializer):
hostname = serializers.CharField(max_length=64)
judger_version = serializers.CharField(max_length=24)
hostname = serializers.CharField(max_length=128)
judger_version = serializers.CharField(max_length=32)
cpu_core = serializers.IntegerField(min_value=1)
memory = serializers.FloatField(min_value=0, max_value=100)
cpu = serializers.FloatField(min_value=0, max_value=100)
action = serializers.ChoiceField(choices=("heartbeat", ))
service_url = serializers.CharField(max_length=128, required=False)
service_url = serializers.CharField(max_length=256, required=False)

View File

@ -2,11 +2,10 @@ import hashlib
from django.utils import timezone
from options.options import SysOptions
from utils.api.tests import APITestCase
from utils.cache import default_cache
from utils.constants import CacheKey
from .models import JudgeServer, JudgeServerToken, SMTPConfig
from .models import JudgeServer
class SMTPConfigTest(APITestCase):
@ -29,10 +28,6 @@ class SMTPConfigTest(APITestCase):
"tls": True}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, self.password)
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
def test_edit_without_password1(self):
self.test_create_smtp_config()
@ -40,7 +35,6 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": ""}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
self.assertEqual(SMTPConfig.objects.first().password, self.password)
def test_edit_with_password(self):
self.test_create_smtp_config()
@ -48,18 +42,14 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": "newpassword"}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, "newpassword")
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
class WebsiteConfigAPITest(APITestCase):
def test_create_website_config(self):
self.create_super_admin()
url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>",
data = {"website_base_url": "http://test.com", "website_name": "test name",
"website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
@ -67,8 +57,8 @@ class WebsiteConfigAPITest(APITestCase):
def test_edit_website_config(self):
self.create_super_admin()
url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>",
data = {"website_base_url": "http://test.com", "website_name": "test name",
"website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
@ -78,10 +68,6 @@ class WebsiteConfigAPITest(APITestCase):
url = self.reverse("website_info_api")
resp = self.client.get(url)
self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["name_shortcut"], "oj")
def tearDown(self):
default_cache.delete(CacheKey.website_config)
class JudgeServerHeartbeatTest(APITestCase):
@ -91,7 +77,7 @@ class JudgeServerHeartbeatTest(APITestCase):
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
self.token = "test"
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
JudgeServerToken.objects.create(token=self.token)
SysOptions.judge_server_token = self.token
def test_new_heartbeat(self):
resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token})
@ -127,11 +113,9 @@ class JudgeServerAPITest(APITestCase):
self.create_super_admin()
def test_get_judge_server(self):
self.assertFalse(JudgeServerToken.objects.exists())
resp = self.client.get(self.url)
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["servers"]), 1)
self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"])
def test_delete_judge_server(self):
resp = self.client.delete(self.url + "?hostname=testhostname")

View File

@ -1,54 +1,45 @@
import hashlib
import pickle
from django.utils import timezone
from account.decorators import super_admin_required
from judge.languages import languages, spj_languages
from judge.dispatcher import process_pending_task
from judge.languages import languages, spj_languages
from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str
from utils.cache import default_cache
from utils.constants import CacheKey
from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig
from .models import JudgeServer
from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
JudgeServerHeartbeatSerializer,
JudgeServerSerializer, SMTPConfigSerializer,
TestSMTPConfigSerializer, WebsiteConfigSerializer)
JudgeServerSerializer, TestSMTPConfigSerializer)
class SMTPAPI(APIView):
@super_admin_required
def get(self, request):
smtp = SMTPConfig.objects.first()
smtp = SysOptions.smtp_config
if not smtp:
return self.success(None)
return self.success(SMTPConfigSerializer(smtp).data)
smtp.pop("password")
return self.success(smtp)
@validate_serializer(CreateSMTPConfigSerializer)
@super_admin_required
def post(self, request):
SMTPConfig.objects.all().delete()
smtp = SMTPConfig.objects.create(**request.data)
return self.success(SMTPConfigSerializer(smtp).data)
SysOptions.smtp_config = request.data
return self.success()
@validate_serializer(EditSMTPConfigSerializer)
@super_admin_required
def put(self, request):
smtp = SysOptions.smtp_config
data = request.data
smtp = SMTPConfig.objects.first()
if not smtp:
return self.error("SMTP config is missing")
smtp.server = data["server"]
smtp.port = data["port"]
smtp.email = data["email"]
smtp.tls = data["tls"]
if data.get("password"):
smtp.password = data["password"]
smtp.save()
return self.success(SMTPConfigSerializer(smtp).data)
for item in ["server", "port", "email", "tls"]:
smtp[item] = data[item]
if "password" in data:
smtp["password"] = data["password"]
SysOptions.smtp_config = smtp
return self.success()
class SMTPTestAPI(APIView):
@ -60,37 +51,24 @@ class SMTPTestAPI(APIView):
class WebsiteConfigAPI(APIView):
def get(self, request):
config = default_cache.get(CacheKey.website_config)
if config:
config = pickle.loads(config)
else:
config = WebsiteConfig.objects.first()
if not config:
config = WebsiteConfig.objects.create()
default_cache.set(CacheKey.website_config, pickle.dumps(config))
return self.success(WebsiteConfigSerializer(config).data)
ret = {key: getattr(SysOptions, key) for key in
["website_base_url", "website_name", "website_name_shortcut",
"website_footer", "allow_register", "submission_list_show_all"]}
return self.success(ret)
@validate_serializer(CreateEditWebsiteConfigSerializer)
@super_admin_required
def post(self, request):
data = request.data
WebsiteConfig.objects.all().delete()
config = WebsiteConfig.objects.create(**data)
default_cache.set(CacheKey.website_config, pickle.dumps(config))
return self.success(WebsiteConfigSerializer(config).data)
for k, v in request.data.items():
setattr(SysOptions, k, v)
return self.success()
class JudgeServerAPI(APIView):
@super_admin_required
def get(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
servers = JudgeServer.objects.all().order_by("-last_heartbeat")
return self.success({"token": token,
return self.success({"token": SysOptions.judge_server_token,
"servers": JudgeServerSerializer(servers, many=True).data})
@super_admin_required
@ -104,15 +82,9 @@ class JudgeServerAPI(APIView):
class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
@validate_serializer(JudgeServerHeartbeatSerializer)
def post(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
data = request.data
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token:
if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
return self.error("Invalid token")
service_url = data.get("service_url")

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0005_auto_20170823_0918'),
]
operations = [
migrations.AlterField(
model_name='acmcontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='oicontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
]

View File

@ -1,27 +1,13 @@
from utils.constants import ContestRuleType # noqa
from django.db import models
from django.utils.timezone import now
from jsonfield import JSONField
from utils.models import JSONField
from utils.constants import ContestStatus, ContestType
from account.models import User, AdminType
from utils.models import RichTextField
class ContestType(object):
PUBLIC_CONTEST = "Public"
PASSWORD_PROTECTED_CONTEST = "Password Protected"
class ContestStatus(object):
CONTEST_NOT_START = "1"
CONTEST_ENDED = "-1"
CONTEST_UNDERWAY = "0"
class ContestRuleType(object):
ACM = "ACM"
OI = "OI"
class Contest(models.Model):
title = models.CharField(max_length=40)
description = RichTextField()
@ -59,12 +45,20 @@ class Contest(models.Model):
def is_contest_admin(self, user):
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
def check_oi_permission(self, user):
if self.status != ContestStatus.CONTEST_ENDED and self.real_time_rank == False:
if self.is_contest_admin(user):
return True
else:
return False
return True
class Meta:
db_table = "contest"
ordering = ("-create_time",)
class ContestRank(models.Model):
class AbstractContestRank(models.Model):
user = models.ForeignKey(User)
contest = models.ForeignKey(Contest)
submission_number = models.IntegerField(default=0)
@ -73,30 +67,27 @@ class ContestRank(models.Model):
abstract = True
class ACMContestRank(ContestRank):
class ACMContestRank(AbstractContestRank):
accepted_number = models.IntegerField(default=0)
# total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60
total_time = models.IntegerField(default=0)
# {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}}
# key is problem id
submission_info = JSONField(default={})
submission_info = JSONField(default=dict)
class Meta:
db_table = "acm_contest_rank"
class OIContestRank(ContestRank):
class OIContestRank(AbstractContestRank):
total_score = models.IntegerField(default=0)
# {23: 333}}
# key is problem id, value is current score
submission_info = JSONField(default={})
submission_info = JSONField(default=dict)
class Meta:
db_table = "oi_contest_rank"
def update_rank(self, submission):
self.submission_number += 1
class ContestAnnouncement(models.Model):
contest = models.ForeignKey(Contest)

View File

@ -79,7 +79,7 @@ class ContestAPITest(APITestCase):
self.create_user("test", "test123")
url = self.reverse("contest_password_api")
resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Password doesn't match."})
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp)
@ -89,7 +89,7 @@ class ContestAPITest(APITestCase):
self.create_user("test", "test123")
url = self.reverse("contest_access_api")
resp = self.client.get(url + "?contest_id=" + str(contest_id))
self.assertFalse(resp.data["data"]["Access"])
self.assertFalse(resp.data["data"]["access"])
password_url = self.reverse("contest_password_api")
resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})

View File

@ -1,13 +1,11 @@
import pickle
from django.utils.timezone import now
from django.db.models import Q
from django.core.cache import cache
from utils.api import APIView, validate_serializer
from utils.cache import default_cache
from utils.constants import CacheKey
from account.decorators import login_required, check_contest_permission
from ..models import ContestAnnouncement, Contest, ContestStatus, ContestRuleType
from ..models import OIContestRank, ACMContestRank
from utils.constants import ContestRuleType, ContestStatus
from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
from ..serializers import ContestAnnouncementSerializer
from ..serializers import ContestSerializer, ContestPasswordVerifySerializer
from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
@ -32,7 +30,7 @@ class ContestAPI(APIView):
try:
contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest doesn't exist.")
return self.error("Contest does not exist")
return self.success(ContestSerializer(contest).data)
contests = Contest.objects.select_related("created_by").filter(visible=True)
@ -50,7 +48,7 @@ class ContestAPI(APIView):
elif status == ContestStatus.CONTEST_ENDED:
contests = contests.filter(end_time__lt=cur)
else:
contests = contests.filter(Q(start_time__lte=cur) & Q(end_time__gte=cur))
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
return self.success(self.paginate_data(request, contests, ContestSerializer))
@ -62,14 +60,14 @@ class ContestPasswordVerifyAPI(APIView):
try:
contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False)
except Contest.DoesNotExist:
return self.error("Contest %s doesn't exist." % data["contest_id"])
return self.error("Contest does not exist")
if contest.password != data["password"]:
return self.error("Password doesn't match.")
return self.error("Wrong password")
# password verify OK.
if "contests" not in request.session:
request.session["contests"] = []
request.session["contests"].append(int(data["contest_id"]))
if "accessible_contests" not in request.session:
request.session["accessible_contests"] = []
request.session["accessible_contests"].append(contest.id)
# https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved
request.session.modified = True
return self.success(True)
@ -80,13 +78,8 @@ class ContestAccessAPI(APIView):
def get(self, request):
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter contest_id not exist.")
if "contests" not in request.session:
request.session["contests"] = []
if int(contest_id) in request.session["contests"]:
return self.success({"Access": True})
else:
return self.success({"Access": False})
return self.error()
return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])})
class ContestRankAPI(APIView):
@ -100,17 +93,17 @@ class ContestRankAPI(APIView):
@check_contest_permission
def get(self, request):
if self.contest.rule_type == ContestRuleType.ACM:
serializer = ACMContestRankSerializer
else:
if self.contest.rule_type == ContestRuleType.OI:
if not self.contest.check_oi_permission(request.user):
return self.error("You have no permission for ranks now")
serializer = OIContestRankSerializer
cache_key = CacheKey.contest_rank_cache + str(self.contest.id)
qs = default_cache.get(cache_key)
if not qs:
ranks = self.get_rank()
default_cache.set(cache_key, pickle.dumps(ranks))
else:
ranks = pickle.loads(qs)
serializer = ACMContestRankSerializer
return self.success(self.paginate_data(request, ranks, serializer))
cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
qs = cache.get(cache_key)
if not qs:
qs = self.get_rank()
cache.set(cache_key, qs)
return self.success(self.paginate_data(request, qs, serializer))

View File

@ -8,12 +8,13 @@ from django.db import transaction
from django.db.models import F
from account.models import User
from conf.models import JudgeServer, JudgeServerToken
from conf.models import JudgeServer
from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus
from judge.languages import languages
from options.options import SysOptions
from problem.models import Problem, ProblemRuleType
from submission.models import JudgeStatus, Submission
from utils.cache import judge_cache, default_cache
from utils.cache import cache
from utils.constants import CacheKey
logger = logging.getLogger(__name__)
@ -21,41 +22,36 @@ logger = logging.getLogger(__name__)
# 继续处理在队列中的问题
def process_pending_task():
if judge_cache.llen(CacheKey.waiting_queue):
if cache.llen(CacheKey.waiting_queue):
# 防止循环引入
from judge.tasks import judge_task
data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data)
class JudgeDispatcher(object):
def __init__(self, submission_id, problem_id):
token = JudgeServerToken.objects.first().token
self.token = hashlib.sha256(token.encode("utf-8")).hexdigest()
self.redis_conn = judge_cache
self.submission = Submission.objects.get(pk=submission_id)
self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
self.submission = Submission.objects.get(id=submission_id)
self.contest_id = self.submission.contest_id
if self.contest_id:
self.problem = Problem.objects.select_related("contest") \
.get(id=problem_id, contest_id=self.contest_id)
self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
self.contest = self.problem.contest
else:
self.problem = Problem.objects.get(id=problem_id)
def _request(self, url, data=None):
kwargs = {"headers": {"X-Judge-Server-Token": self.token,
"Content-Type": "application/json"}}
kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
if data:
kwargs["data"] = json.dumps(data)
kwargs["json"] = data
try:
return requests.post(url, **kwargs).json()
except Exception as e:
logger.error(e.with_traceback())
logger.exception(e)
@staticmethod
def choose_judge_server():
with transaction.atomic():
# TODO: use more reasonable way
servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
if servers:
@ -65,10 +61,10 @@ class JudgeDispatcher(object):
return server
@staticmethod
def release_judge_res(judge_server_id):
def release_judge_server(judge_server_id):
with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.select_for_update().get(id=judge_server_id)
server = JudgeServer.objects.get(id=judge_server_id)
server.used_instance_number = F("task_number") - 1
server.save()
@ -94,7 +90,7 @@ class JudgeDispatcher(object):
server = self.choose_judge_server()
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data))
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0]
@ -138,7 +134,7 @@ class JudgeDispatcher(object):
else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save()
self.release_judge_res(server.id)
self.release_judge_server(server.id)
self.update_problem_status()
if self.contest_id:
@ -160,7 +156,6 @@ class JudgeDispatcher(object):
with transaction.atomic():
# prepare problem and user_profile
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
problem_info = problem.statistic_info
user = User.objects.select_for_update().select_for_update("userprofile").get(id=self.submission.user_id)
user_profile = user.userprofile
if self.contest_id:
@ -169,50 +164,46 @@ class JudgeDispatcher(object):
key = "problems"
acm_problems_status = user_profile.acm_problems_status.get(key, {})
oi_problems_status = user_profile.oi_problems_status.get(key, {})
problem_id = str(self.problem.id)
problem_info = problem.statistic_info
# update problem info
result = str(self.submission.result)
problem_info[result] = problem_info.get(result, 0) + 1
problem.statistic_info = problem_info
# update submission and accepted number counter
problem.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED:
problem.accepted_number += 1
# only when submission is not in contest, we update user profile,
# in other words, users' submission in a contest will not be counted in user profile
# submission in a contest will not be counted in user profile
if not self.contest_id:
user_profile.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED:
user_profile.accepted_number += 1
problem_id = str(self.problem.id)
if self.problem.rule_type == ProblemRuleType.ACM:
# update acm problem info
result = str(self.submission.result)
problem_info[result] = problem_info.get(result, 0) + 1
problem.statistic_info = problem_info
# update user_profile
if problem_id not in acm_problems_status:
acm_problems_status[problem_id] = self.submission.result
acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
# skip if the problem has been accepted
elif acm_problems_status[problem_id] != JudgeStatus.ACCEPTED:
if self.submission.result == JudgeStatus.ACCEPTED:
acm_problems_status[problem_id] = JudgeStatus.ACCEPTED
else:
acm_problems_status[problem_id] = self.submission.result
elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
acm_problems_status[problem_id]["status"] = self.submission.result
user_profile.acm_problems_status[key] = acm_problems_status
else:
# update oi problem info
score = self.submission.statistic_info["score"]
problem_info[score] = problem_info.get(score, 0) + 1
problem.statistic_info = problem_info
# update user_profile
score = self.submission.statistic_info["score"]
if problem_id not in oi_problems_status:
user_profile.add_score(score)
oi_problems_status[problem_id] = score
oi_problems_status[problem_id] = {"status": self.submission.result,
"_id": self.problem._id,
"score": score}
else:
# minus last time score, add this time score
user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id])
oi_problems_status[problem_id] = score
user_profile.add_score(this_time_score=score, last_time_score=oi_problems_status[problem_id]["score"])
oi_problems_status[problem_id]["score"] = score
oi_problems_status[problem_id]["status"] = self.submission.result
user_profile.oi_problems_status[key] = oi_problems_status
problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"])
@ -222,8 +213,8 @@ class JudgeDispatcher(object):
def update_contest_rank(self):
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
return
if self.contest.real_time_rank:
default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id))
if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM:
acm_rank, _ = ACMContestRank.objects.select_for_update(). \

View File

@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'HOST': '127.0.0.1',
'PORT': 5433,
'NAME': "onlinejudge",
'USER': "onlinejudge",
'PASSWORD': 'onlinejudge'
}
}

View File

@ -46,6 +46,7 @@ INSTALLED_APPS = (
'contest',
'utils',
'submission',
'options',
)
MIDDLEWARE_CLASSES = (
@ -57,11 +58,9 @@ MIDDLEWARE_CLASSES = (
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.security.SecurityMiddleware',
'account.middleware.AdminRoleRequiredMiddleware',
'account.middleware.SessionSecurityMiddleware',
'account.middleware.SessionRecordMiddleware',
# 'account.middleware.LogSqlMiddleware',
)
SESSION_ENGINE = 'django.contrib.sessions.backends.cache'
ROOT_URLCONF = 'oj.urls'
TEMPLATES = [
@ -164,8 +163,6 @@ LOGGING = {
}
},
}
REST_FRAMEWORK = {
'TEST_REQUEST_DEFAULT_FORMAT': 'json',
'DEFAULT_RENDERER_CLASSES': (
@ -173,34 +170,34 @@ REST_FRAMEWORK = {
)
}
CACHE_JUDGE_QUEUE = "judge_queue"
CACHE_THROTTLING = "throttling"
REDIS_URL = "redis://127.0.0.1:6379"
def redis_config(db):
def make_key(key, key_prefix, version):
return key
return {
"BACKEND": "utils.cache.MyRedisCache",
"LOCATION": f"{REDIS_URL}/{db}",
"TIMEOUT": None,
"KEY_PREFIX": "",
"KEY_FUNCTION": make_key
}
CACHES = {
"default": {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/1",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_JUDGE_QUEUE: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/2",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_THROTTLING: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/3",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
}
"default": redis_config(db=1)
}
CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2"
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
# For celery
REDIS_QUEUE = {
"host": "127.0.0.1",

0
options/__init__.py Normal file
View File

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2017-10-01 19:19
from __future__ import unicode_literals
import jsonfield.fields
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='SysOptions',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('key', models.CharField(db_index=True, max_length=128, unique=True)),
('value', jsonfield.fields.JSONField()),
],
),
]

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('options', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='sysoptions',
name='value',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
]

View File

7
options/models.py Normal file
View File

@ -0,0 +1,7 @@
from django.db import models
from utils.models import JSONField
class SysOptions(models.Model):
key = models.CharField(max_length=128, unique=True, db_index=True)
value = JSONField()

179
options/options.py Normal file
View File

@ -0,0 +1,179 @@
from django.core.cache import cache
from django.db import transaction, IntegrityError
from utils.constants import CacheKey
from utils.shortcuts import rand_str
from .models import SysOptions as SysOptionsModel
class OptionKeys:
website_base_url = "website_base_url"
website_name = "website_name"
website_name_shortcut = "website_name_shortcut"
website_footer = "website_footer"
allow_register = "allow_register"
submission_list_show_all = "submission_list_show_all"
smtp_config = "smtp_config"
judge_server_token = "judge_server_token"
class OptionDefaultValue:
website_base_url = "http://127.0.0.1"
website_name = "Online Judge"
website_name_shortcut = "oj"
website_footer = "Online Judge Footer"
allow_register = True
submission_list_show_all = True
smtp_config = {}
judge_server_token = rand_str
class _SysOptionsMeta(type):
@classmethod
def _set_cache(mcs, option_key, option_value):
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
@classmethod
def _del_cache(mcs, option_key):
cache.delete(f"{CacheKey.option}:{option_key}")
@classmethod
def _get_keys(cls):
return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
def rebuild_cache(cls):
for key in cls._get_keys():
# get option 的时候会写 cache 的
cls._get_option(key, use_cache=False)
@classmethod
def _init_option(mcs):
for item in mcs._get_keys():
if not SysOptionsModel.objects.filter(key=item).exists():
default_value = getattr(OptionDefaultValue, item)
if callable(default_value):
default_value = default_value()
try:
SysOptionsModel.objects.create(key=item, value=default_value)
except IntegrityError:
pass
@classmethod
def _get_option(mcs, option_key, use_cache=True):
try:
if use_cache:
option = cache.get(f"{CacheKey.option}:{option_key}")
if option:
return option
option = SysOptionsModel.objects.get(key=option_key)
value = option.value
mcs._set_cache(option_key, value)
return value
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._get_option(option_key, use_cache=use_cache)
@classmethod
def _set_option(mcs, option_key: str, option_value):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
option.value = option_value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
mcs._set_option(option_key, option_value)
@classmethod
def _increment(mcs, option_key):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
value = option.value + 1
option.value = value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._increment(option_key)
@classmethod
def set_options(mcs, options):
for key, value in options:
mcs._set_option(key, value)
@classmethod
def get_options(mcs, keys):
result = {}
for key in keys:
result[key] = mcs._get_option(key)
return result
@property
def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url)
@website_base_url.setter
def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value)
@property
def website_name(cls):
return cls._get_option(OptionKeys.website_name)
@website_name.setter
def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value)
@property
def website_name_shortcut(cls):
return cls._get_option(OptionKeys.website_name_shortcut)
@website_name_shortcut.setter
def website_name_shortcut(cls, value):
cls._set_option(OptionKeys.website_name_shortcut, value)
@property
def website_footer(cls):
return cls._get_option(OptionKeys.website_footer)
@website_footer.setter
def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value)
@property
def allow_register(cls):
return cls._get_option(OptionKeys.allow_register)
@allow_register.setter
def allow_register(cls, value):
cls._set_option(OptionKeys.allow_register, value)
@property
def submission_list_show_all(cls):
return cls._get_option(OptionKeys.submission_list_show_all)
@submission_list_show_all.setter
def submission_list_show_all(cls, value):
cls._set_option(OptionKeys.submission_list_show_all, value)
@property
def smtp_config(cls):
return cls._get_option(OptionKeys.smtp_config)
@smtp_config.setter
def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value)
@property
def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token)
@judge_server_token.setter
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
class SysOptions(metaclass=_SysOptionsMeta):
pass

1
options/tests.py Normal file
View File

@ -0,0 +1 @@
# Create your tests here.

1
options/views.py Normal file
View File

@ -0,0 +1 @@
# Create your views here.

View File

@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0008_auto_20170923_1318'),
]
operations = [
migrations.AlterField(
model_name='problem',
name='languages',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='samples',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='statistic_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='problem',
name='template',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='test_case_score',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
]

View File

@ -1,5 +1,5 @@
from django.db import models
from jsonfield import JSONField
from utils.models import JSONField
from account.models import User
from contest.models import Contest
@ -65,8 +65,8 @@ class Problem(models.Model):
total_score = models.IntegerField(default=0, blank=True)
submission_number = models.BigIntegerField(default=0)
accepted_number = models.BigIntegerField(default=0)
# ACM rule_type: {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default={})
# {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default=dict)
class Meta:
db_table = "problem"

View File

@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
class TagSerializer(serializers.ModelSerializer):
class Meta:
model = ProblemTag
fields = "__all__"
class BaseProblemSerializer(serializers.ModelSerializer):
@ -88,11 +89,13 @@ class BaseProblemSerializer(serializers.ModelSerializer):
class ProblemAdminSerializer(BaseProblemSerializer):
class Meta:
model = Problem
fields = "__all__"
class ContestProblemAdminSerializer(BaseProblemSerializer):
class Meta:
model = Problem
fields = "__all__"
class ProblemSerializer(BaseProblemSerializer):

View File

@ -1,4 +1,5 @@
import copy
import hashlib
import os
import shutil
from datetime import timedelta
@ -9,6 +10,7 @@ from django.conf import settings
from utils.api.tests import APITestCase
from .models import ProblemTag
from .models import Problem, ProblemRuleType
from .views.admin import TestCaseUploadAPI
from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA
@ -23,6 +25,40 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
raise ValueError("invalid score")
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = created_by
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
for item in tags:
try:
tag = ProblemTag.objects.get(name=item)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
return problem
class ProblemTagListAPITest(APITestCase):
def test_get_tag_list(self):
@ -96,7 +132,7 @@ class ProblemAdminAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("problem_admin_api")
self.create_super_admin()
self.data = DEFAULT_PROBLEM_DATA
self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
def test_create_problem(self):
resp = self.client.post(self.url, data=self.data)
@ -138,23 +174,19 @@ class ProblemAdminAPITest(APITestCase):
self.assertSuccess(resp)
class ProblemAPITest(APITestCase):
class ProblemAPITest(ProblemCreateTestBase):
def setUp(self):
self.url = self.reverse("problem_api")
self.create_admin()
def create_problem(self):
url = self.reverse("problem_admin_api")
return self.client.post(url, data=DEFAULT_PROBLEM_DATA)
admin = self.create_admin(login=False)
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.create_user("test", "test123")
def test_get_problem_list(self):
self.create_problem()
resp = self.client.get(f"{self.url}?limit=10")
self.assertSuccess(resp)
def get_one_problem(self):
problem_id = self.create_problem().data["data"]["_id"]
resp = self.client.get(self.url + "?id=" + str(problem_id))
resp = self.client.get(self.url + "?id=" + self.problem._id)
self.assertSuccess(resp)
@ -169,51 +201,49 @@ class ContestProblemAdminTest(APITestCase):
def test_create_contest_problem(self):
contest = self.create_contest()
data = DEFAULT_PROBLEM_DATA
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = contest.data["data"]["id"]
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
return resp
return contest, resp
def test_get_contest_problem(self):
contest = self.test_create_contest_problem()
contest, contest_problem = self.test_create_contest_problem()
contest_id = contest.data["data"]["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1)
def test_get_one_contest_problem(self):
contest = self.test_create_contest_problem()
contest, contest_problem = self.test_create_contest_problem()
contest_id = contest.data["data"]["id"]
resp = self.client.get(self.url + "?id=" + str(contest_id))
problem_id = contest_problem.data["data"]["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp)
class ContestProblemTest(APITestCase):
class ContestProblemTest(ProblemCreateTestBase):
def setUp(self):
self.url = self.reverse("contest_problem_api")
self.create_admin()
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.problem.contest_id = self.contest["id"]
self.problem.save()
self.url = self.reverse("contest_problem_api")
problem_data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
problem_data["contest"] = self.contest["id"]
url = self.reverse("contest_problem_admin_api")
self.problem = self.client.post(url, problem_data).data["data"]
def test_get_contest_problem_list(self):
def test_admin_get_contest_problem_list(self):
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1)
def test_get_one_contest_problem(self):
def test_admin_get_one_contest_problem(self):
contest_id = self.contest["id"]
problem_id = self.problem["_id"]
problem_id = self.problem._id
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
self.assertSuccess(resp)

View File

@ -223,6 +223,7 @@ class ProblemAPI(APIView):
data["total_score"] = total_score
# todo check filename and score info
tags = data.pop("tags")
data["languages"] = list(data["languages"])
for k, v in data.items():
setattr(problem, k, v)

View File

@ -13,6 +13,24 @@ class ProblemTagAPI(APIView):
class ProblemAPI(APIView):
@staticmethod
def _add_problem_status(request, queryset_values):
if request.user.is_authenticated():
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
# paginate data
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values,]
for problem in problems:
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
else:
problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status")
def get(self, request):
# 问题详情页
problem_id = request.GET.get("problem_id")
@ -20,7 +38,9 @@ class ProblemAPI(APIView):
try:
problem = Problem.objects.select_related("created_by")\
.get(_id=problem_id, contest_id__isnull=True, visible=True)
return self.success(ProblemSerializer(problem).data)
problem_data = ProblemSerializer(problem).data
self._add_problem_status(request, problem_data)
return self.success(problem_data)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
@ -32,11 +52,7 @@ class ProblemAPI(APIView):
# 按照标签筛选
tag_text = request.GET.get("tag")
if tag_text:
try:
tag = ProblemTag.objects.get(name=tag_text)
except ProblemTag.DoesNotExist:
return self.error("The Tag does not exist.")
problems = tag.problem_set.all().filter(visible=True)
problems = problems.filter(tags__name=tag_text)
# 搜索的情况
keyword = request.GET.get("keyword", "").strip()
@ -49,19 +65,23 @@ class ProblemAPI(APIView):
problems = problems.filter(difficulty=difficulty)
# 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemSerializer)
if request.user.id:
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
for problem in data["results"]:
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), None)
else:
problem["my_status"] = oi_problems_status.get(str(problem["id"]), None)
self._add_problem_status(request, data)
return self.success(data)
class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values):
if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user):
return
if request.user.is_authenticated():
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in queryset_values:
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
@check_contest_permission
def get(self, request):
problem_id = request.GET.get("problem_id")
@ -72,17 +92,12 @@ class ContestProblemAPI(APIView):
visible=True)
except Problem.DoesNotExist:
return self.error("Problem does not exist.")
return self.success(ContestProblemSerializer(problem).data)
problem_data = ContestProblemSerializer(problem).data
self._add_problem_status(request, [problem_data,])
return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
# 根据profile 为做过的题目添加标记
data = ContestProblemSerializer(contest_problems, many=True).data
if request.user.is_authenticated() and self.contest.rule_type != ContestRuleType.OI:
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in data:
problem["my_status"] = problems_status.get(str(problem["id"]), None)
self._add_problem_status(request, data)
return self.success(data)

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('submission', '0007_auto_20170923_1318'),
]
operations = [
migrations.AlterField(
model_name='submission',
name='info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='submission',
name='statistic_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
]

View File

@ -1,5 +1,5 @@
from django.db import models
from jsonfield import JSONField
from utils.models import JSONField
from account.models import AdminType
from problem.models import Problem
from contest.models import Contest
@ -30,18 +30,20 @@ class Submission(models.Model):
username = models.CharField(max_length=30)
code = models.TextField()
result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING)
# 判题结果的详细信息
info = JSONField(default={})
# 从JudgeServer返回的判题详情
info = JSONField(default=dict)
language = models.CharField(max_length=20)
shared = models.BooleanField(default=False)
# 存储该提交所用时间和内存值,方便提交列表显示
# {time_cost: "", memory_cost: "", err_info: "", score: 0}
statistic_info = JSONField(default={})
statistic_info = JSONField(default=dict)
def check_user_permission(self, user):
def check_user_permission(self, user, check_share=True):
return self.user_id == user.id or \
self.shared is True or \
user.admin_type == AdminType.SUPER_ADMIN
(check_share and self.shared is True) or \
user.is_super_admin() or \
user.can_mgmt_all_problem() or \
self.problem.created_by_id == user.id
class Meta:
db_table = "submission"

View File

@ -10,6 +10,11 @@ class CreateSubmissionSerializer(serializers.Serializer):
contest_id = serializers.IntegerField(required=False)
class ShareSubmissionSerializer(serializers.Serializer):
id = serializers.CharField()
shared = serializers.BooleanField()
class SubmissionModelSerializer(serializers.ModelSerializer):
info = serializers.JSONField()
statistic_info = serializers.JSONField()
@ -19,7 +24,7 @@ class SubmissionModelSerializer(serializers.ModelSerializer):
# 不显示submission info的serializer, 用于ACM rule_type
class SubmissionSafeSerializer(serializers.ModelSerializer):
class SubmissionSafeModelSerializer(serializers.ModelSerializer):
problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
statistic_info = serializers.JSONField()
@ -43,6 +48,6 @@ class SubmissionListSerializer(serializers.ModelSerializer):
def get_show_link(self, obj):
# 没传user或为匿名user
if self.user is None or self.user.id is None:
if self.user is None or not self.user.is_authenticated():
return False
return obj.check_user_permission(self.user)

View File

@ -32,14 +32,16 @@ class SubmissionPrepare(APITestCase):
def _create_problem_and_submission(self):
user = self.create_admin("test", "test123", login=False)
problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
problem_data.pop("tags")
tags = problem_data.pop("tags")
problem_data["created_by"] = user
self.problem = Problem.objects.create(**problem_data)
for tag in DEFAULT_PROBLEM_DATA["tags"]:
for tag in tags:
tag = ProblemTag.objects.create(name=tag)
self.problem.tags.add(tag)
self.problem.save()
self.submission = Submission.objects.create(**DEFAULT_SUBMISSION_DATA)
self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
self.submission_data["problem_id"] = self.problem.id
self.submission = Submission.objects.create(**self.submission_data)
class SubmissionListTest(SubmissionPrepare):
@ -61,6 +63,6 @@ class SubmissionAPITest(SubmissionPrepare):
self.url = self.reverse("submission_api")
def test_create_submission(self, judge_task):
resp = self.client.post(self.url, DEFAULT_SUBMISSION_DATA)
resp = self.client.post(self.url, self.submission_data)
self.assertSuccess(resp)
judge_task.assert_called()

View File

@ -5,16 +5,17 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType
from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController
from utils.cache import cache
from ..models import Submission
from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer
from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer
from utils.cache import throttling_cache
from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
ShareSubmissionSerializer)
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
def _submit(response, user, problem_id, language, code, contest_id):
# TODO: 预设默认值,需修改
controller = BucketController(user_id=user.id,
redis_conn=throttling_cache,
redis_conn=cache,
default_capacity=30)
bucket = TokenBucket(fill_rate=10, capacity=20,
last_capacity=controller.last_capacity,
@ -63,17 +64,36 @@ class SubmissionAPI(APIView):
def get(self, request):
submission_id = request.GET.get("id")
if not submission_id:
return self.error("Parameter id doesn't exist.")
return self.error("Parameter id doesn't exist")
try:
submission = Submission.objects.select_related("problem").get(id=submission_id)
except Submission.DoesNotExist:
return self.error("Submission doesn't exist.")
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission.")
return self.error("No permission for this submission")
if submission.problem.rule_type == ProblemRuleType.ACM:
return self.success(SubmissionSafeSerializer(submission).data)
return self.success(SubmissionModelSerializer(submission).data)
submission_data = SubmissionSafeModelSerializer(submission).data
else:
submission_data = SubmissionModelSerializer(submission).data
# 是否有权限取消共享
submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False)
return self.success(submission_data)
@validate_serializer(ShareSubmissionSerializer)
@login_required
def put(self, request):
try:
submission = Submission.objects.select_related("problem").get(id=request.data["id"])
except Submission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user, check_share=False):
return self.error("No permission to share the submission")
if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY:
return self.error("Can not share submission now")
submission.shared = request.data["shared"]
submission.save(update_fields=["shared"])
return self.success()
class SubmissionListAPI(APIView):
@ -83,7 +103,7 @@ class SubmissionListAPI(APIView):
if request.GET.get("contest_id"):
return self.error("Parameter error")
submissions = Submission.objects.filter(contest_id__isnull=True)
submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by")
problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself")
result = request.GET.get("result")
@ -109,10 +129,10 @@ class ContestSubmissionListAPI(APIView):
return self.error("Limit is needed")
contest = self.contest
if contest.rule_type == ContestRuleType.OI and not contest.is_contest_admin(request.user):
if not contest.check_oi_permission(request.user):
return self.error("No permission for OI contest submissions")
submissions = Submission.objects.filter(contest_id=contest.id)
submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself")
result = request.GET.get("result")
@ -133,9 +153,11 @@ class ContestSubmissionListAPI(APIView):
submissions = submissions.filter(create_time__gte=contest.start_time)
# 封榜的时候只能看到自己的提交
if contest.rule_type == ContestRuleType.ACM:
if not contest.real_time_rank and not contest.is_contest_admin(request.user):
submissions = submissions.filter(user_id=request.user.id)
data = self.paginate_data(request, submissions)
data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data
return self.success(data)

View File

@ -1,2 +1,2 @@
from .api import * # NOQA
from ._serializers import * # NOQA
from .api import * # NOQA

View File

@ -79,7 +79,7 @@ class APIView(View):
def success(self, data=None):
return self.response({"error": None, "data": data})
def error(self, msg, err="error"):
def error(self, msg="error", err="error"):
return self.response({"error": err, "data": msg})
def _serializer_error_to_str(self, errors):

View File

@ -3,7 +3,6 @@ from django.test.testcases import TestCase
from rest_framework.test import APIClient
from account.models import AdminType, ProblemPermission, User, UserProfile
from conf.models import WebsiteConfig
class APITestCase(TestCase):
@ -28,9 +27,6 @@ class APITestCase(TestCase):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login)
def create_website_config(self):
return WebsiteConfig.objects.create()
def reverse(self, url_name):
return reverse(url_name)

View File

@ -1,6 +1,27 @@
from django.conf import settings
from django_redis import get_redis_connection
from django.core.cache import cache, caches # noqa
from django.conf import settings # noqa
judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE)
throttling_cache = get_redis_connection(settings.CACHE_THROTTLING)
default_cache = get_redis_connection("default")
from django_redis.cache import RedisCache
from django_redis.client.default import DefaultClient
class MyRedisClient(DefaultClient):
def __getattr__(self, item):
client = self.get_client(write=True)
return getattr(client, item)
def redis_incr(self, key, count=1):
"""
django 默认的 incr key 不存在时候会抛异常
"""
client = self.get_client(write=True)
return client.incr(key, count)
class MyRedisCache(RedisCache):
def __init__(self, server, params):
super().__init__(server, params)
self._client_cls = MyRedisClient
def __getattr__(self, item):
return getattr(self.client, item)

View File

@ -1,4 +1,28 @@
class Choices:
@classmethod
def choices(cls):
d = cls.__dict__
return [d[item] for item in d.keys() if not item.startswith("__")]
class ContestType:
PUBLIC_CONTEST = "Public"
PASSWORD_PROTECTED_CONTEST = "Password Protected"
class ContestStatus:
CONTEST_NOT_START = "1"
CONTEST_ENDED = "-1"
CONTEST_UNDERWAY = "0"
class ContestRuleType(Choices):
ACM = "ACM"
OI = "OI"
class CacheKey:
waiting_queue = "waiting_queue"
contest_rank_cache = "contest_rank_cache_"
contest_rank_cache = "contest_rank_cache"
website_config = "website_config"
option = "option"

View File

@ -1,3 +1,4 @@
from django.contrib.postgres.fields import JSONField # NOQA
from django.db import models
from utils.xss_filter import XssHtml

View File

@ -1,35 +1,9 @@
import logging
import random
import datetime
from io import BytesIO
import random
from base64 import b64encode
from io import BytesIO
from django.utils.crypto import get_random_string
from envelopes import Envelope
from conf.models import SMTPConfig
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SMTPConfig.objects.first()
if not smtp:
return
envlope = Envelope(from_addr=(smtp.email, from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp.server,
login=smtp.email,
password=smtp.password,
port=smtp.port,
tls=smtp.tls)
return True
except Exception as e:
logger.exception(e)
return False
def rand_str(length=32, type="lower_hex"):

View File

@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7
浏览器版本IE7+ 或其他浏览器无法防御IE6及以下版本浏览器中的XSS
"""
import re
try:
import copy
from html.parser import HTMLParser
except:
from HTMLParser import HTMLParser
class XssHtml(HTMLParser):
@ -163,7 +160,7 @@ class XssHtml(HTMLParser):
else:
other = []
if attrs:
for (key, value) in attrs.items():
for key, value in copy.deepcopy(attrs).items():
if key not in self.common_attrs + other:
del attrs[key]
return attrs