Merge branch 'zemal_dev'

This commit is contained in:
zema1 2017-11-28 16:22:42 +08:00
commit 7ce13911a7
129 changed files with 4740 additions and 999 deletions

View File

@ -1,8 +1,9 @@
[flake8] [flake8]
exclude = exclude =
xss_filter.py, xss_filter.py,
migrations/, */migrations/,
*settings.py *settings.py
*/apps.py
max-line-length = 180 max-line-length = 180
inline-quotes = " inline-quotes = "
no-accept-encodings = True no-accept-encodings = True

23
.gitignore vendored
View File

@ -54,21 +54,18 @@ db.db
#*.out #*.out
*.sqlite3 *.sqlite3
.DS_Store .DS_Store
log/
static/release/css
static/release/js
static/release/img
static/src/upload_image/*
build.txt build.txt
tmp/ tmp/
test_case/
release/
upload/
custom_settings.py custom_settings.py
docker-compose.yml
*.zip *.zip
rsyncd.passwd
node_modules/ data/log/*
update.sh !data/log/.gitkeep
ssh.sh data/test_case/*
!data/test_case/.gitkeep
data/ssl/*
!data/ssl/.gitkeep
data/public/upload/*
!data/public/upload/.gitkeep
data/public/avatar/*
!data/public/avatar/default.png

View File

@ -1 +1 @@
3.5.0 3.6.2

View File

@ -1,14 +1,19 @@
language: python language: python
python: python:
- "3.5" - "3.6"
services:
- redis-server
- docker
before_install:
- docker pull postgres:10
- docker run -it -d -e POSTGRES_DB=onlinejudge -e POSTGRES_USER=onlinejudge -e POSTGRES_PASSWORD=onlinejudge -p 127.0.0.1:5433:5432 postgres:10
install: install:
- pip install -r deploy/requirements.txt - pip install -r deploy/requirements.txt
- mkdir log test_case upload
- cp oj/custom_settings.example.py oj/custom_settings.py - cp oj/custom_settings.example.py oj/custom_settings.py
- echo "SECRET_KEY=\"`cat /dev/urandom | head -1 | md5sum | head -c 32`\"" >> oj/custom_settings.py - echo "SECRET_KEY=\"`cat /dev/urandom | head -1 | md5sum | head -c 32`\"" >> oj/custom_settings.py
- python manage.py migrate - python manage.py migrate
- python manage.py initadmin
script: script:
- docker ps -a
- flake8 . - flake8 .
- coverage run --include="$PWD/*" manage.py test - coverage run --include="$PWD/*" manage.py test
- coverage report - coverage report

15
Dockerfile Normal file
View File

@ -0,0 +1,15 @@
FROM python:3.6-alpine3.6
ENV OJ_ENV production
ADD . /app
WORKDIR /app
RUN printf "https://mirrors.tuna.tsinghua.edu.cn/alpine/v3.6/community/\nhttps://mirrors.tuna.tsinghua.edu.cn/alpine/v3.6/main/" > /etc/apk/repositories && \
apk add --update --no-cache build-base nginx openssl curl unzip supervisor jpeg-dev zlib-dev postgresql-dev freetype-dev && \
pip install --no-cache-dir -r /app/deploy/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple && \
apk del build-base --purge
RUN curl -L $(curl -s https://api.github.com/repos/QingdaoU/OnlineJudgeFE/releases/latest | grep /dist.zip | cut -d '"' -f 4) -o dist.zip && \
unzip dist.zip && \
rm dist.zip
CMD sh /app/deploy/run.sh

View File

@ -4,6 +4,8 @@ from utils.api import JSONResponse
from .models import ProblemPermission from .models import ProblemPermission
from contest.models import Contest, ContestType, ContestStatus, ContestRuleType
class BasePermissionDecorator(object): class BasePermissionDecorator(object):
def __init__(self, func): def __init__(self, func):
@ -23,7 +25,7 @@ class BasePermissionDecorator(object):
return self.error("Your account is disabled") return self.error("Your account is disabled")
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
else: else:
return self.error("Please login in first") return self.error("Please login first")
def check_permission(self): def check_permission(self):
raise NotImplementedError() raise NotImplementedError()
@ -53,3 +55,56 @@ class problem_permission_required(admin_role_required):
if self.request.user.problem_permission == ProblemPermission.NONE: if self.request.user.problem_permission == ProblemPermission.NONE:
return False return False
return True return True
def check_contest_permission(check_type="details"):
"""
只供Class based view 使用检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions
若通过验证在view中可通过self.contest获得该contest
"""
def decorator(func):
def _check_permission(*args, **kwargs):
self = args[0]
request = args[1]
user = request.user
if kwargs.get("contest_id"):
contest_id = kwargs.pop("contest_id")
else:
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter contest_id doesn't exist.")
try:
# use self.contest to avoid query contest again in view.
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest %s doesn't exist" % contest_id)
# creator or owner
if user.is_authenticated() and user.is_contest_admin(self.contest):
return func(*args, **kwargs)
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
# Anonymous
if not user.is_authenticated():
return self.error("Please login first.")
# password error
if ("accessible_contests" not in request.session) or \
(self.contest.id not in request.session["accessible_contests"]):
return self.error("Password is required.")
# regular user get contest problems, ranks etc. before contest started
if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details":
return self.error("Contest has not started yet.")
# check does user have permission to get ranks, submissions in OI Contest
if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI:
if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"):
return self.error(f"No permission to get {check_type}")
return func(*args, **kwargs)
return _check_permission
return decorator

View File

@ -1,34 +1,50 @@
import time from django.db import connection
from django.utils.timezone import now
import pytz from django.utils.deprecation import MiddlewareMixin
from django.contrib import auth
from django.utils import timezone
from django.utils.translation import ugettext as _
from utils.api import JSONResponse from utils.api import JSONResponse
from account.models import User
class SessionSecurityMiddleware(object): class APITokenAuthMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
if request.user.is_authenticated() and request.user.is_admin_role(): appkey = request.META.get("HTTP_APPKEY")
if "last_activity" in request.session: if appkey:
# 24 hours passed since last visit try:
if time.time() - request.session["last_activity"] >= 24 * 60 * 60: request.user = User.objects.get(open_api_appkey=appkey, open_api=True, is_disabled=False)
auth.logout(request) request.csrf_processing_done = True
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) except User.DoesNotExist:
# update last active time pass
request.session["last_activity"] = time.time()
class AdminRoleRequiredMiddleware(object): class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request):
if request.user.is_authenticated():
session = request.session
session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
session["ip"] = request.META.get("HTTP_X_REAL_IP", request.META.get("REMOTE_ADDR"))
session["last_activity"] = now()
user_sessions = request.user.session_keys
if session.session_key not in user_sessions:
user_sessions.append(session.session_key)
request.user.save()
class AdminRoleRequiredMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
path = request.path_info path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"): if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not(request.user.is_authenticated() and request.user.is_admin_role()): if not (request.user.is_authenticated() and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) return JSONResponse.response({"error": "login-required", "data": "Please login in first"})
class TimezoneMiddleware(object): class LogSqlMiddleware(MiddlewareMixin):
def process_request(self, request): def process_response(self, request, response):
if request.user.is_authenticated(): print("\033[94m", "#" * 30, "\033[0m")
timezone.activate(pytz.timezone(request.user.userprofile.time_zone)) time_threshold = 0.03
for query in connection.queries:
if float(query["time"]) > time_threshold:
print("\033[93m", query, "\n", "-" * 30, "\033[0m")
else:
print(query, "\n", "-" * 30)
return response

View File

@ -50,7 +50,7 @@ class Migration(migrations.Migration):
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('problems_status', jsonfield.fields.JSONField(default={})), ('problems_status', jsonfield.fields.JSONField(default={})),
('avatar', models.CharField(default=account.models._random_avatar, max_length=50)), ('avatar', models.CharField(default="default.png", max_length=50)),
('blog', models.URLField(blank=True, null=True)), ('blog', models.URLField(blank=True, null=True)),
('mood', models.CharField(blank=True, max_length=200, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)),
('accepted_problem_number', models.IntegerField(default=0)), ('accepted_problem_number', models.IntegerField(default=0)),

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-08-20 02:03
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0002_auto_20170209_1028'),
]
operations = [
migrations.AddField(
model_name='userprofile',
name='total_score',
field=models.BigIntegerField(default=0),
),
migrations.RenameField(
model_name='userprofile',
old_name='accepted_problem_number',
new_name='accepted_number',
),
migrations.RemoveField(
model_name='userprofile',
name='time_zone',
)
]

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-08-30 11:54
from __future__ import unicode_literals
from django.db import migrations, models
import jsonfield.fields
class Migration(migrations.Migration):
dependencies = [
('account', '0003_userprofile_total_score'),
]
operations = [
migrations.RenameField(
model_name='userprofile',
old_name='problems_status',
new_name='acm_problems_status',
),
migrations.AddField(
model_name='userprofile',
name='oi_problems_status',
field=jsonfield.fields.JSONField(default={}),
),
migrations.RemoveField(
model_name='user',
name='real_name',
),
migrations.RemoveField(
model_name='userprofile',
name='student_id',
),
migrations.AddField(
model_name='userprofile',
name='real_name',
field=models.CharField(max_length=30, blank=True, null=True),
),
]

View File

@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-09-16 06:22
from __future__ import unicode_literals
from django.db import migrations, models
import jsonfield.fields
class Migration(migrations.Migration):
dependencies = [
('account', '0005_auto_20170830_1154'),
]
operations = [
migrations.AddField(
model_name='user',
name='session_keys',
field=jsonfield.fields.JSONField(default=[]),
),
migrations.RenameField(
model_name='userprofile',
old_name='phone_number',
new_name='github',
),
migrations.AlterField(
model_name='userprofile',
name='avatar',
field=models.CharField(default='/static/avatar/default.png', max_length=50),
),
migrations.AlterField(
model_name='userprofile',
name='github',
field=models.CharField(blank=True, max_length=50, null=True),
),
]

View File

@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0006_user_session_keys'),
]
operations = [
migrations.RemoveField(
model_name='userprofile',
name='language',
),
migrations.AlterField(
model_name='user',
name='admin_type',
field=models.CharField(default='Regular User', max_length=32),
),
migrations.AlterField(
model_name='user',
name='auth_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='email',
field=models.EmailField(max_length=64, null=True),
),
migrations.AlterField(
model_name='user',
name='open_api_appkey',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='problem_permission',
field=models.CharField(default='None', max_length=32),
),
migrations.AlterField(
model_name='user',
name='reset_password_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='session_keys',
field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
),
migrations.AlterField(
model_name='user',
name='tfa_token',
field=models.CharField(max_length=32, null=True),
),
migrations.AlterField(
model_name='user',
name='username',
field=models.CharField(max_length=32, unique=True),
),
migrations.AlterField(
model_name='userprofile',
name='acm_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='avatar',
field=models.CharField(default='/static/avatar/default.png', max_length=256),
),
migrations.AlterField(
model_name='userprofile',
name='github',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='major',
field=models.CharField(blank=True, max_length=64, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='mood',
field=models.CharField(blank=True, max_length=256, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='oi_problems_status',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='userprofile',
name='real_name',
field=models.CharField(blank=True, max_length=32, null=True),
),
migrations.AlterField(
model_name='userprofile',
name='school',
field=models.CharField(blank=True, max_length=64, null=True),
),
]

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-11-25 15:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('account', '0008_auto_20171011_1214'),
]
operations = [
migrations.AlterField(
model_name='userprofile',
name='avatar',
field=models.CharField(default='/public/avatar/default.png', max_length=256),
),
]

View File

@ -1,6 +1,7 @@
from django.contrib.auth.models import AbstractBaseUser from django.contrib.auth.models import AbstractBaseUser
from django.conf import settings
from django.db import models from django.db import models
from jsonfield import JSONField from utils.models import JSONField
class AdminType(object): class AdminType(object):
@ -9,11 +10,6 @@ class AdminType(object):
SUPER_ADMIN = "Super Admin" SUPER_ADMIN = "Super Admin"
class ProblemSolutionStatus(object):
ACCEPTED = 1
PENDING = 2
class ProblemPermission(object): class ProblemPermission(object):
NONE = "None" NONE = "None"
OWN = "Own" OWN = "Own"
@ -24,26 +20,26 @@ class UserManager(models.Manager):
use_in_migrations = True use_in_migrations = True
def get_by_natural_key(self, username): def get_by_natural_key(self, username):
return self.get(**{self.model.USERNAME_FIELD: username}) return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
class User(AbstractBaseUser): class User(AbstractBaseUser):
username = models.CharField(max_length=30, unique=True) username = models.CharField(max_length=32, unique=True)
real_name = models.CharField(max_length=30, null=True) email = models.EmailField(max_length=64, null=True)
email = models.EmailField(max_length=254, null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True) create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType # One of UserType
admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER) admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER)
problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE) problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE)
reset_password_token = models.CharField(max_length=40, null=True) reset_password_token = models.CharField(max_length=32, null=True)
reset_password_token_expire_time = models.DateTimeField(null=True) reset_password_token_expire_time = models.DateTimeField(null=True)
# SSO auth token # SSO auth token
auth_token = models.CharField(max_length=40, null=True) auth_token = models.CharField(max_length=32, null=True)
two_factor_auth = models.BooleanField(default=False) two_factor_auth = models.BooleanField(default=False)
tfa_token = models.CharField(max_length=40, null=True) tfa_token = models.CharField(max_length=32, null=True)
session_keys = JSONField(default=list)
# open api key # open api key
open_api = models.BooleanField(default=False) open_api = models.BooleanField(default=False)
open_api_appkey = models.CharField(max_length=35, null=True) open_api_appkey = models.CharField(max_length=32, null=True)
is_disabled = models.BooleanField(default=False) is_disabled = models.BooleanField(default=False)
USERNAME_FIELD = "username" USERNAME_FIELD = "username"
@ -63,42 +59,59 @@ class User(AbstractBaseUser):
def can_mgmt_all_problem(self): def can_mgmt_all_problem(self):
return self.problem_permission == ProblemPermission.ALL return self.problem_permission == ProblemPermission.ALL
def is_contest_admin(self, contest):
return self.is_authenticated() and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN)
class Meta: class Meta:
db_table = "user" db_table = "user"
def _random_avatar():
import random
return "/static/img/avatar/avatar-" + str(random.randint(1, 20)) + ".png"
class UserProfile(models.Model): class UserProfile(models.Model):
user = models.OneToOneField(User) user = models.OneToOneField(User, on_delete=models.CASCADE)
# Store user problem solution status with json string format # acm_problems_status examples:
# {"problems": {1: ProblemSolutionStatus.ACCEPTED}, "contest_problems": {20: ProblemSolutionStatus.PENDING)} # {
problems_status = JSONField(default={}) # "problems": {
avatar = models.CharField(max_length=50, default=_random_avatar) # "1": {
# "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# },
# "contest_problems": {
# "1": {
# "status": JudgeStatus.ACCEPTED,
# "_id": "1000"
# }
# }
# }
acm_problems_status = JSONField(default=dict)
# like acm_problems_status, merely add "score" field
oi_problems_status = JSONField(default=dict)
real_name = models.CharField(max_length=32, blank=True, null=True)
avatar = models.CharField(max_length=256, default=f"{settings.AVATAR_URI_PREFIX}/default.png")
blog = models.URLField(blank=True, null=True) blog = models.URLField(blank=True, null=True)
mood = models.CharField(max_length=200, blank=True, null=True) mood = models.CharField(max_length=256, blank=True, null=True)
accepted_problem_number = models.IntegerField(default=0) github = models.CharField(max_length=64, blank=True, null=True)
school = models.CharField(max_length=64, blank=True, null=True)
major = models.CharField(max_length=64, blank=True, null=True)
# for ACM
accepted_number = models.IntegerField(default=0)
# for OI
total_score = models.BigIntegerField(default=0)
submission_number = models.IntegerField(default=0) submission_number = models.IntegerField(default=0)
phone_number = models.CharField(max_length=15, blank=True, null=True)
school = models.CharField(max_length=200, blank=True, null=True)
major = models.CharField(max_length=200, blank=True, null=True)
student_id = models.CharField(max_length=15, blank=True, null=True)
time_zone = models.CharField(max_length=32, blank=True, null=True)
language = models.CharField(max_length=32, blank=True, null=True)
def add_accepted_problem_number(self): def add_accepted_problem_number(self):
self.accepted_problem_number = models.F("accepted_problem_number") + 1 self.accepted_number = models.F("accepted_number") + 1
self.save() self.save()
def add_submission_number(self): def add_submission_number(self):
self.submission_number = models.F("submission_number") + 1 self.submission_number = models.F("submission_number") + 1
self.save() self.save()
def minus_accepted_problem_number(self): # 计算总分时, 应先减掉上次该题所得分数, 然后再加上本次所得分数
self.accepted_problem_number = models.F("accepted_problem_number") - 1 def add_score(self, this_time_score, last_time_score=None):
last_time_score = last_time_score or 0
self.total_score = models.F("total_score") - last_time_score + this_time_score
self.save() self.save()
class Meta: class Meta:

View File

@ -1,25 +1,51 @@
from utils.api import DateTimeTZField, serializers from django import forms
from .models import AdminType, ProblemPermission, User from utils.api import DateTimeTZField, serializers, UsernameSerializer
from .models import AdminType, ProblemPermission, User, UserProfile
class UserLoginSerializer(serializers.Serializer): class UserLoginSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30) username = serializers.CharField()
password = serializers.CharField(max_length=30) password = serializers.CharField()
tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True) tfa_code = serializers.CharField(required=False, allow_blank=True)
class UsernameOrEmailCheckSerializer(serializers.Serializer):
username = serializers.CharField(required=False)
email = serializers.EmailField(required=False)
class UserRegisterSerializer(serializers.Serializer): class UserRegisterSerializer(serializers.Serializer):
username = serializers.CharField(max_length=30) username = serializers.CharField(max_length=32)
password = serializers.CharField(max_length=30, min_length=6) password = serializers.CharField(min_length=6)
email = serializers.EmailField(max_length=254) email = serializers.EmailField(max_length=64)
captcha = serializers.CharField(max_length=4, min_length=4) captcha = serializers.CharField()
class UserChangePasswordSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField() old_password = serializers.CharField()
new_password = serializers.CharField(max_length=30, min_length=6) new_password = serializers.CharField(min_length=6)
captcha = serializers.CharField(max_length=4, min_length=4) tfa_code = serializers.CharField(required=False, allow_blank=True)
class UserChangeEmailSerializer(serializers.Serializer):
password = serializers.CharField()
new_email = serializers.EmailField(max_length=64)
tfa_code = serializers.CharField(required=False, allow_blank=True)
class GenerateUserSerializer(serializers.Serializer):
prefix = serializers.CharField(max_length=16, allow_blank=True)
suffix = serializers.CharField(max_length=16, allow_blank=True)
number_from = serializers.IntegerField()
number_to = serializers.IntegerField()
password_length = serializers.IntegerField(max_value=16, default=8)
class ImportUserSeralizer(serializers.Serializer):
users = serializers.ListField(
child=serializers.ListField(child=serializers.CharField(max_length=64)))
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
@ -28,16 +54,33 @@ class UserSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = User model = User
fields = ["id", "username", "real_name", "email", "admin_type", "problem_permission", fields = ["id", "username", "email", "admin_type", "problem_permission",
"create_time", "last_login", "two_factor_auth", "open_api", "is_disabled"] "create_time", "last_login", "two_factor_auth", "open_api", "is_disabled"]
class UserProfileSerializer(serializers.ModelSerializer):
user = UserSerializer()
acm_problems_status = serializers.JSONField()
oi_problems_status = serializers.JSONField()
class Meta:
model = UserProfile
fields = "__all__"
class UserInfoSerializer(serializers.ModelSerializer):
acm_problems_status = serializers.JSONField()
oi_problems_status = serializers.JSONField()
class Meta:
model = UserProfile
class EditUserSerializer(serializers.Serializer): class EditUserSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
username = serializers.CharField(max_length=30) username = serializers.CharField(max_length=32)
real_name = serializers.CharField(max_length=30) password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None)
password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None) email = serializers.EmailField(max_length=64)
email = serializers.EmailField(max_length=254)
admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN)) admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN))
problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN, problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN,
ProblemPermission.ALL)) ProblemPermission.ALL))
@ -46,21 +89,42 @@ class EditUserSerializer(serializers.Serializer):
is_disabled = serializers.BooleanField() is_disabled = serializers.BooleanField()
class EditUserProfileSerializer(serializers.Serializer):
real_name = serializers.CharField(max_length=32, allow_null=True, required=False)
avatar = serializers.CharField(max_length=256, allow_null=True, allow_blank=True, required=False)
blog = serializers.URLField(max_length=256, allow_null=True, allow_blank=True, required=False)
mood = serializers.CharField(max_length=256, allow_null=True, allow_blank=True, required=False)
github = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
school = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
major = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
class ApplyResetPasswordSerializer(serializers.Serializer): class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
captcha = serializers.CharField(max_length=4, min_length=4) captcha = serializers.CharField()
class ResetPasswordSerializer(serializers.Serializer): class ResetPasswordSerializer(serializers.Serializer):
token = serializers.CharField(min_length=1, max_length=40) token = serializers.CharField()
password = serializers.CharField(min_length=6, max_length=30) password = serializers.CharField(min_length=6)
captcha = serializers.CharField(max_length=4, min_length=4) captcha = serializers.CharField()
class SSOSerializer(serializers.Serializer): class SSOSerializer(serializers.Serializer):
appkey = serializers.CharField(max_length=35) appkey = serializers.CharField()
token = serializers.CharField(max_length=40) token = serializers.CharField()
class TwoFactorAuthCodeSerializer(serializers.Serializer): class TwoFactorAuthCodeSerializer(serializers.Serializer):
code = serializers.IntegerField() code = serializers.IntegerField()
class ImageUploadForm(forms.Form):
image = forms.FileField()
class RankInfoSerializer(serializers.ModelSerializer):
user = UsernameSerializer()
class Meta:
model = UserProfile

View File

@ -1,6 +1,31 @@
from celery import shared_task import logging
from utils.shortcuts import send_email from celery import shared_task
from envelopes import Envelope
from options.options import SysOptions
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SysOptions.smtp_config
if not smtp:
return
envlope = Envelope(from_addr=(smtp["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp["server"],
login=smtp["email"],
password=smtp["password"],
port=smtp["port"],
tls=smtp["tls"])
return True
except Exception as e:
logger.exception(e)
return False
@shared_task @shared_task

View File

@ -8,7 +8,7 @@
<tbody> <tbody>
<tr height="39" style="background-color:#50a5e6;"> <tr height="39" style="background-color:#50a5e6;">
<td style="padding-left:15px;font-family:'微软雅黑','黑体',arial;"> <td style="padding-left:15px;font-family:'微软雅黑','黑体',arial;">
{{ website_name }} 登录信息找回 {{ website_name }}
</td> </td>
</tr> </tr>
</tbody> </tbody>
@ -32,18 +32,18 @@
</tr> </tr>
<tr height="30"> <tr height="30">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;"> <td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
您刚刚在 {{ website_name }} 申请了找回登录信息服务。 We received a request to reset your password for {{ website_name }}.
</td> </td>
</tr> </tr>
<tr height="30"> <tr height="30">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;"> <td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
请在<span style="color:rgb(255,0,0)">30分钟</span>内点击下面链接设置您的新密码: You can use the following link to reset your password in <span style="color:rgb(255,0,0)">20 minutes.</span>
</td> </td>
</tr> </tr>
<tr height="60"> <tr height="60">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;"> <td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
<a href="{{ link }}" target="_blank" <a href="{{ link }}" target="_blank"
style="color: rgb(255,255,255);text-decoration: none;display: block;min-height: 39px;width: 158px;line-height: 39px;background-color:rgb(80,165,230);font-size:20px;text-align:center;">重置密码</a> style="color: rgb(255,255,255);text-decoration: none;display: block;min-height: 39px;width: 158px;line-height: 39px;background-color:rgb(80,165,230);font-size:20px;text-align:center;">Reset Password</a>
</td> </td>
</tr> </tr>
<tr height="10"> <tr height="10">
@ -51,7 +51,7 @@
</tr> </tr>
<tr height="20"> <tr height="20">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;"> <td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;">
如果上面的链接点击无效,请复制以下链接至浏览器的地址栏直接打开。 If the button above doesn't work, please copy the following link to your browser and press enter.
</td> </td>
</tr> </tr>
<tr height="30"> <tr height="30">
@ -63,8 +63,7 @@
</tr> </tr>
<tr height="20"> <tr height="20">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;"> <td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;">
如果您没有提出过该申请,请忽略此邮件。有可能是其他用户误填了您的邮件地址,我们不会对你的帐户进行任何修改。 If you did not ask that, please ignore this email. It will expire and become useless in 20 minutes.
请不要向他人透露本邮件的内容,否则可能会导致您的账号被盗。
</td> </td>
</tr> </tr>
<tr height="20"> <tr height="20">

View File

@ -1,13 +1,19 @@
import time import time
from unittest import mock from unittest import mock
from datetime import timedelta
from copy import deepcopy
from django.contrib import auth from django.contrib import auth
from django.utils.timezone import now
from otpauth import OtpAuth from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from options.options import SysOptions
from .models import AdminType, ProblemPermission, User from .models import AdminType, ProblemPermission, User
from utils.constants import ContestRuleType
class PermissionDecoratorTest(APITestCase): class PermissionDecoratorTest(APITestCase):
@ -28,6 +34,54 @@ class PermissionDecoratorTest(APITestCase):
pass pass
class DuplicateUserCheckAPITest(APITestCase):
def setUp(self):
user = self.create_user("test", "test123", login=False)
user.email = "test@test.com"
user.save()
self.url = self.reverse("check_username_or_email")
def test_duplicate_username(self):
resp = self.client.post(self.url, data={"username": "test"})
data = resp.data["data"]
self.assertEqual(data["username"], True)
resp = self.client.post(self.url, data={"username": "Test"})
self.assertEqual(resp.data["data"]["username"], True)
def test_ok_username(self):
resp = self.client.post(self.url, data={"username": "test1"})
data = resp.data["data"]
self.assertFalse(data["username"])
def test_duplicate_email(self):
resp = self.client.post(self.url, data={"email": "test@test.com"})
self.assertEqual(resp.data["data"]["email"], True)
resp = self.client.post(self.url, data={"email": "Test@Test.com"})
self.assertTrue(resp.data["data"]["email"])
def test_ok_email(self):
resp = self.client.post(self.url, data={"email": "aa@test.com"})
self.assertFalse(resp.data["data"]["email"])
class TFARequiredCheckAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("tfa_required_check")
self.create_user("test", "test123", login=False)
def test_not_required_tfa(self):
resp = self.client.post(self.url, data={"username": "test"})
self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["result"], False)
def test_required_tfa(self):
user = User.objects.first()
user.two_factor_auth = True
user.save()
resp = self.client.post(self.url, data={"username": "test"})
self.assertEqual(resp.data["data"]["result"], True)
class UserLoginAPITest(APITestCase): class UserLoginAPITest(APITestCase):
def setUp(self): def setUp(self):
self.username = self.password = "test" self.username = self.password = "test"
@ -49,6 +103,12 @@ class UserLoginAPITest(APITestCase):
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated()) self.assertTrue(user.is_authenticated())
def test_login_with_correct_info_upper_username(self):
resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password})
self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated())
def test_login_with_wrong_info(self): def test_login_with_wrong_info(self):
response = self.client.post(self.login_url, response = self.client.post(self.login_url,
data={"username": self.username, "password": "invalid_password"}) data={"username": self.username, "password": "invalid_password"})
@ -87,11 +147,18 @@ class UserLoginAPITest(APITestCase):
response = self.client.post(self.login_url, response = self.client.post(self.login_url,
data={"username": self.username, data={"username": self.username,
"password": self.password}) "password": self.password})
self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"}) self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated()) self.assertFalse(user.is_authenticated())
def test_user_disabled(self):
self.user.is_disabled = True
self.user.save()
resp = self.client.post(self.login_url, data={"username": self.username,
"password": self.password})
self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"})
class CaptchaTest(APITestCase): class CaptchaTest(APITestCase):
def _set_captcha(self, session): def _set_captcha(self, session):
@ -112,6 +179,11 @@ class UserRegisterAPITest(CaptchaTest):
"real_name": "real_name", "email": "test@qduoj.com", "real_name": "real_name", "email": "test@qduoj.com",
"captcha": self._set_captcha(self.client.session)} "captcha": self._set_captcha(self.client.session)}
def test_website_config_limit(self):
SysOptions.allow_register = False
resp = self.client.post(self.register_url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"})
def test_invalid_captcha(self): def test_invalid_captcha(self):
self.data["captcha"] = "****" self.data["captcha"] = "****"
response = self.client.post(self.register_url, data=self.data) response = self.client.post(self.register_url, data=self.data)
@ -142,23 +214,206 @@ class UserRegisterAPITest(CaptchaTest):
self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"}) self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"})
class UserChangePasswordAPITest(CaptchaTest): class SessionManagementAPITest(APITestCase):
def setUp(self):
self.create_user("test", "test123")
self.url = self.reverse("session_management_api")
# launch a request to provide session data
login_url = self.reverse("user_login_api")
self.client.post(login_url, data={"username": "test", "password": "test123"})
def test_get_sessions(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(len(data), 1)
# def test_delete_session_key(self):
# resp = self.client.delete(self.url + "?session_key=" + self.session_key)
# self.assertSuccess(resp)
def test_delete_session_with_invalid_key(self):
resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa")
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"})
class UserProfileAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_profile_api")
def test_get_profile_without_login(self):
resp = self.client.get(self.url)
self.assertDictEqual(resp.data, {"error": None, "data": None})
def test_get_profile(self):
self.create_user("test", "test123")
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_update_profile(self):
self.create_user("test", "test123")
update_data = {"real_name": "zemal", "submission_number": 233}
resp = self.client.put(self.url, data=update_data)
self.assertSuccess(resp)
data = resp.data["data"]
self.assertEqual(data["real_name"], "zemal")
self.assertEqual(data["submission_number"], 0)
class TwoFactorAuthAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123")
def _get_tfa_code(self):
user = User.objects.first()
code = OtpAuth(user.tfa_token).totp()
if len(str(code)) < 6:
code = (6 - len(str(code))) * "0" + str(code)
return code
def test_get_image(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_open_tfa_with_invalid_code(self):
self.test_get_image()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_open_tfa_with_correct_code(self):
self.test_get_image()
code = self._get_tfa_code()
resp = self.client.post(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, True)
def test_close_tfa_with_invalid_code(self):
self.test_open_tfa_with_correct_code()
resp = self.client.post(self.url, data={"code": "000000"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
def test_close_tfa_with_correct_code(self):
self.test_open_tfa_with_correct_code()
code = self._get_tfa_code()
resp = self.client.put(self.url, data={"code": code})
self.assertSuccess(resp)
user = User.objects.first()
self.assertEqual(user.two_factor_auth, False)
@mock.patch("account.views.oj.send_email_async.delay")
class ApplyResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
user = User.objects.first()
user.email = "test@oj.com"
user.save()
self.url = self.reverse("apply_reset_password_api")
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self):
self.data["captcha"] = self._set_captcha(self.client.session)
def test_apply_reset_password(self, send_email_delay):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
send_email_delay.assert_called()
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
self.test_apply_reset_password()
send_email_delay.reset_mock()
self._refresh_captcha()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
send_email_delay.assert_not_called()
def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
self.test_apply_reset_password()
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(minutes=21)
user.save()
self._refresh_captcha()
self.test_apply_reset_password()
class ResetPasswordAPITest(CaptchaTest):
def setUp(self):
self.create_user("test", "test123", login=False)
self.url = self.reverse("reset_password_api")
user = User.objects.first()
user.reset_password_token = "online_judge?"
user.reset_password_token_expire_time = now() + timedelta(minutes=20)
user.save()
self.data = {"token": user.reset_password_token,
"captcha": self._set_captcha(self.client.session),
"password": "test456"}
def test_reset_password_with_correct_token(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
self.assertTrue(self.client.login(username="test", password="test456"))
def test_reset_password_with_invalid_token(self):
self.data["token"] = "aaaaaaaaaaa"
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"})
def test_reset_password_with_expired_token(self):
user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(seconds=30)
user.save()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"})
class UserChangeEmailAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_change_email_api")
self.user = self.create_user("test", "test123")
self.new_mail = "test@oj.com"
self.data = {"password": "test123", "new_email": self.new_mail}
def test_change_email_success(self):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
def test_wrong_password(self):
self.data["password"] = "aaaa"
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
def test_duplicate_email(self):
u = self.create_user("aa", "bb", login=False)
u.email = self.new_mail
u.save()
resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "The email is owned by other account"})
class UserChangePasswordAPITest(APITestCase):
def setUp(self): def setUp(self):
self.client = APIClient()
self.url = self.reverse("user_change_password_api") self.url = self.reverse("user_change_password_api")
# Create user at first # Create user at first
self.username = "test_user" self.username = "test_user"
self.old_password = "testuserpassword" self.old_password = "testuserpassword"
self.new_password = "new_password" self.new_password = "new_password"
self.create_user(username=self.username, password=self.old_password, login=False) self.user = self.create_user(username=self.username, password=self.old_password, login=False)
self.data = {"old_password": self.old_password, "new_password": self.new_password, self.data = {"old_password": self.old_password, "new_password": self.new_password}
"captcha": self._set_captcha(self.client.session)}
def _get_tfa_code(self):
user = User.objects.first()
code = OtpAuth(user.tfa_token).totp()
if len(str(code)) < 6:
code = (6 - len(str(code))) * "0" + str(code)
return code
def test_login_required(self): def test_login_required(self):
response = self.client.post(self.url, data=self.data) response = self.client.post(self.url, data=self.data)
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"}) self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"})
def test_valid_ola_password(self): def test_valid_ola_password(self):
self.assertTrue(self.client.login(username=self.username, password=self.old_password)) self.assertTrue(self.client.login(username=self.username, password=self.old_password))
@ -172,6 +427,58 @@ class UserChangePasswordAPITest(CaptchaTest):
response = self.client.post(self.url, data=self.data) response = self.client.post(self.url, data=self.data)
self.assertEqual(response.data, {"error": "error", "data": "Invalid old password"}) self.assertEqual(response.data, {"error": "error", "data": "Invalid old password"})
def test_tfa_code_required(self):
self.user.two_factor_auth = True
self.user.tfa_token = "tfa_token"
self.user.save()
self.assertTrue(self.client.login(username=self.username, password=self.old_password))
self.data["tfa_code"] = rand_str(6)
resp = self.client.post(self.url, data=self.data)
self.assertEqual(resp.data, {"error": "error", "data": "Invalid two factor verification code"})
self.data["tfa_code"] = self._get_tfa_code()
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
class UserRankAPITest(APITestCase):
def setUp(self):
self.url = self.reverse("user_rank_api")
self.create_user("test1", "test123", login=False)
self.create_user("test2", "test123", login=False)
test1 = User.objects.get(username="test1")
profile1 = test1.userprofile
profile1.submission_number = 10
profile1.accepted_number = 10
profile1.total_score = 240
profile1.save()
test2 = User.objects.get(username="test2")
profile2 = test2.userprofile
profile2.submission_number = 15
profile2.accepted_number = 10
profile2.total_score = 700
profile2.save()
def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
self.assertSuccess(resp)
data = resp.data["data"]["results"]
self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
self.assertSuccess(resp)
data = resp.data["data"]["results"]
self.assertEqual(data[0]["user"]["username"], "test2")
self.assertEqual(data[1]["user"]["username"], "test1")
class ProfileProblemDisplayIDRefreshAPITest(APITestCase):
def setUp(self):
pass
class AdminUserTest(APITestCase): class AdminUserTest(APITestCase):
def setUp(self): def setUp(self):
@ -194,7 +501,6 @@ class AdminUserTest(APITestCase):
resp_data = response.data["data"] resp_data = response.data["data"]
self.assertEqual(resp_data["username"], self.username) self.assertEqual(resp_data["username"], self.username)
self.assertEqual(resp_data["email"], "test@qq.com") self.assertEqual(resp_data["email"], "test@qq.com")
self.assertEqual(resp_data["real_name"], "test_name")
self.assertEqual(resp_data["open_api"], True) self.assertEqual(resp_data["open_api"], True)
self.assertEqual(resp_data["two_factor_auth"], False) self.assertEqual(resp_data["two_factor_auth"], False)
self.assertEqual(resp_data["is_disabled"], False) self.assertEqual(resp_data["is_disabled"], False)
@ -249,3 +555,75 @@ class AdminUserTest(APITestCase):
# if `openapi_app_key` is not None, the value is not changed # if `openapi_app_key` is not None, the value is not changed
self.assertTrue(resp_data["open_api"]) self.assertTrue(resp_data["open_api"])
self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key) self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key)
def test_import_users(self):
data = {"users": [["user1", "pass1", "eami1@e.com"],
["user2", "pass3", "eamil3@e.com"]]
}
resp = self.client.post(self.url, data)
self.assertSuccess(resp)
# successfully created 2 users
self.assertEqual(User.objects.all().count(), 4)
def test_import_duplicate_user(self):
data = {"users": [["user1", "pass1", "eami1@e.com"],
["user1", "pass1", "eami1@e.com"]]
}
resp = self.client.post(self.url, data)
self.assertFailed(resp, "DETAIL: Key (username)=(user1) already exists.")
# no user is created
self.assertEqual(User.objects.all().count(), 2)
def test_delete_users(self):
self.test_import_users()
user_ids = User.objects.filter(username__in=["user1", "user2"]).values_list("id", flat=True)
user_ids = ",".join([str(id) for id in user_ids])
resp = self.client.delete(self.url + "?id=" + user_ids)
self.assertSuccess(resp)
self.assertEqual(User.objects.all().count(), 2)
class GenerateUserAPITest(APITestCase):
def setUp(self):
self.create_super_admin()
self.url = self.reverse("generate_user_api")
self.data = {
"number_from": 100, "number_to": 105,
"prefix": "pre", "suffix": "suf",
"default_email": "test@test.com",
"password_length": 8
}
def test_error_case(self):
data = deepcopy(self.data)
data["prefix"] = "t" * 16
data["suffix"] = "s" * 14
resp = self.client.post(self.url, data=data)
self.assertEqual(resp.data["data"], "Username should not more than 32 characters")
data2 = deepcopy(self.data)
data2["number_from"] = 106
resp = self.client.post(self.url, data=data2)
self.assertEqual(resp.data["data"], "Start number must be lower than end number")
@mock.patch("account.views.admin.xlsxwriter.Workbook")
def test_generate_user_success(self, mock_workbook):
resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp)
mock_workbook.assert_called()
class OpenAPIAppkeyAPITest(APITestCase):
def setUp(self):
self.user = self.create_super_admin()
self.url = self.reverse("open_api_appkey_api")
def test_reset_appkey(self):
resp = self.client.post(self.url, data={})
self.assertFailed(resp)
self.user.open_api = True
self.user.save()
resp = self.client.post(self.url, data={})
self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["appkey"], User.objects.get(username=self.user.username).open_api_appkey)

View File

@ -1,7 +1,8 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.admin import UserAdminAPI from ..views.admin import UserAdminAPI, GenerateUserAPI
urlpatterns = [ urlpatterns = [
url(r"^user/?$", UserAdminAPI.as_view(), name="user_admin_api"), url(r"^user/?$", UserAdminAPI.as_view(), name="user_admin_api"),
url(r"^generate_user/?$", GenerateUserAPI.as_view(), name="generate_user_api"),
] ]

View File

@ -1,12 +1,30 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserLoginAPI, UserRegisterAPI) UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI,
ProfileProblemDisplayIDRefreshAPI, OpenAPIAppkeyAPI)
from utils.captcha.views import CaptchaAPIView
urlpatterns = [ urlpatterns = [
url(r"^login/?$", UserLoginAPI.as_view(), name="user_login_api"), url(r"^login/?$", UserLoginAPI.as_view(), name="user_login_api"),
url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email_api"),
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api") url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view(), name="display_id_fresh"),
url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api"),
url(r"^open_api_appkey/?$", OpenAPIAppkeyAPI.as_view(), name="open_api_appkey_api"),
] ]

View File

@ -1,12 +0,0 @@
from django.conf.urls import url
from ..views.user import (SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI,
UserInfoAPI, UserProfileAPI)
urlpatterns = [
url(r"^user/?$", UserInfoAPI.as_view(), name="user_info_api"),
url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api")
]

View File

@ -1,15 +1,48 @@
from django.core.exceptions import MultipleObjectsReturned import os
from django.db.models import Q import re
import xlsxwriter
from django.db import transaction, IntegrityError
from django.db.models import Q
from django.http import HttpResponse
from django.contrib.auth.hashers import make_password
from submission.models import Submission
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from ..decorators import super_admin_required from ..decorators import super_admin_required
from ..models import AdminType, ProblemPermission, User from ..models import AdminType, ProblemPermission, User, UserProfile
from ..serializers import EditUserSerializer, UserSerializer from ..serializers import EditUserSerializer, UserSerializer, GenerateUserSerializer
from ..serializers import ImportUserSeralizer
class UserAdminAPI(APIView): class UserAdminAPI(APIView):
@validate_serializer(ImportUserSeralizer)
@super_admin_required
def post(self, request):
"""
Import User
"""
data = request.data["users"]
user_list = []
for user_data in data:
if len(user_data) != 3 or len(user_data[0]) > 32:
return self.error(f"Error occurred while processing data '{user_data}'")
user_list.append(User(username=user_data[0], password=make_password(user_data[1]), email=user_data[2]))
try:
with transaction.atomic():
ret = User.objects.bulk_create(user_list)
UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret])
return self.success()
except IntegrityError as e:
# Extract detail from exception message
# duplicate key value violates unique constraint "user_username_key"
# DETAIL: Key (username)=(root11) already exists.
return self.error(str(e).split("\n")[1])
@validate_serializer(EditUserSerializer) @validate_serializer(EditUserSerializer)
@super_admin_required @super_admin_required
def put(self, request): def put(self, request):
@ -21,25 +54,13 @@ class UserAdminAPI(APIView):
user = User.objects.get(id=data["id"]) user = User.objects.get(id=data["id"])
except User.DoesNotExist: except User.DoesNotExist:
return self.error("User does not exist") return self.error("User does not exist")
try: if User.objects.filter(username=data["username"]).exclude(id=user.id).exists():
user = User.objects.get(username=data["username"])
if user.id != data["id"]:
return self.error("Username already exists") return self.error("Username already exists")
except User.DoesNotExist: if User.objects.filter(email=data["email"].lower()).exclude(id=user.id).exists():
pass
try:
user = User.objects.get(email=data["email"])
if user.id != data["id"]:
return self.error("Email already exists") return self.error("Email already exists")
# Some old data has duplicate email
except MultipleObjectsReturned:
return self.error("Email already exists")
except User.DoesNotExist:
pass
pre_username = user.username
user.username = data["username"] user.username = data["username"]
user.real_name = data["real_name"]
user.email = data["email"] user.email = data["email"]
user.admin_type = data["admin_type"] user.admin_type = data["admin_type"]
user.is_disabled = data["is_disabled"] user.is_disabled = data["is_disabled"]
@ -72,6 +93,8 @@ class UserAdminAPI(APIView):
user.two_factor_auth = data["two_factor_auth"] user.two_factor_auth = data["two_factor_auth"]
user.save() user.save()
if pre_username != user.username:
Submission.objects.filter(username=pre_username).update(username=user.username)
return self.success(UserSerializer(user).data) return self.success(UserSerializer(user).data)
@super_admin_required @super_admin_required
@ -91,7 +114,97 @@ class UserAdminAPI(APIView):
keyword = request.GET.get("keyword", None) keyword = request.GET.get("keyword", None)
if keyword: if keyword:
user = user.filter(Q(username__contains=keyword) | user = user.filter(Q(username__icontains=keyword) |
Q(real_name__contains=keyword) | Q(userprofile__real_name__icontains=keyword) |
Q(email__contains=keyword)) Q(email__icontains=keyword))
return self.success(self.paginate_data(request, user, UserSerializer)) return self.success(self.paginate_data(request, user, UserSerializer))
def delete_one(self, user_id):
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
return f"User {user_id} does not exist"
if Submission.objects.filter(user_id=user_id).exists():
return f"Can't delete the user {user_id} as he/she has submissions"
user.delete()
@super_admin_required
def delete(self, request):
id = request.GET.get("id")
if not id:
return self.error("Invalid Parameter, id is required")
for user_id in id.split(","):
if user_id:
error = self.delete_one(user_id)
if error:
return self.error(error)
return self.success()
class GenerateUserAPI(APIView):
@super_admin_required
def get(self, request):
"""
download users excel
"""
file_id = request.GET.get("file_id")
if not file_id:
return self.error("Invalid Parameter, file_id is required")
if not re.match(r"^[a-zA-Z0-9]+$", file_id):
return self.error("Illegal file_id")
file_path = f"/tmp/{file_id}.xlsx"
if not os.path.isfile(file_path):
return self.error("File does not exist")
with open(file_path, "rb") as f:
raw_data = f.read()
os.remove(file_path)
response = HttpResponse(raw_data)
response["Content-Disposition"] = f"attachment; filename=users.xlsx"
response["Content-Type"] = "application/xlsx"
return response
@validate_serializer(GenerateUserSerializer)
@super_admin_required
def post(self, request):
"""
Generate User
"""
data = request.data
number_max_length = max(len(str(data["number_from"])), len(str(data["number_to"])))
if number_max_length + len(data["prefix"]) + len(data["suffix"]) > 32:
return self.error("Username should not more than 32 characters")
if data["number_from"] > data["number_to"]:
return self.error("Start number must be lower than end number")
file_id = rand_str(8)
filename = f"/tmp/{file_id}.xlsx"
workbook = xlsxwriter.Workbook(filename)
worksheet = workbook.add_worksheet()
worksheet.set_column("A:B", 20)
worksheet.write("A1", "Username")
worksheet.write("B1", "Password")
i = 1
user_list = []
for number in range(data["number_from"], data["number_to"] + 1):
raw_password = rand_str(data["password_length"])
user = User(username=f"{data['prefix']}{number}{data['suffix']}", password=make_password(raw_password))
user.raw_password = raw_password
user_list.append(user)
try:
with transaction.atomic():
ret = User.objects.bulk_create(user_list)
UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret])
for item in user_list:
worksheet.write_string(i, 0, item.username)
worksheet.write_string(i, 1, item.raw_password)
i += 1
workbook.close()
return self.success({"file_id": file_id})
except IntegrityError as e:
# Extract detail from exception message
# duplicate key value violates unique constraint "user_username_key"
# DETAIL: Key (username)=(root11) already exists.
return self.error(str(e).split("\n")[1])

View File

@ -1,25 +1,154 @@
import os
from datetime import timedelta from datetime import timedelta
from importlib import import_module
import qrcode
from django.conf import settings from django.conf import settings
from django.contrib import auth from django.contrib import auth
from django.core.exceptions import MultipleObjectsReturned from django.template.loader import render_to_string
from django.utils.decorators import method_decorator
from django.utils.timezone import now from django.utils.timezone import now
from django.views.decorators.csrf import ensure_csrf_cookie
from otpauth import OtpAuth from otpauth import OtpAuth
from conf.models import WebsiteConfig from problem.models import Problem
from utils.constants import ContestRuleType
from options.options import SysOptions
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.captcha import Captcha from utils.captcha import Captcha
from utils.shortcuts import rand_str from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required from ..decorators import login_required
from ..models import User, UserProfile from ..models import User, UserProfile
from ..serializers import (ApplyResetPasswordSerializer, from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer, UserChangePasswordSerializer, UserLoginSerializer,
UserRegisterSerializer) UserRegisterSerializer, UsernameOrEmailCheckSerializer,
RankInfoSerializer, UserChangeEmailSerializer)
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
EditUserProfileSerializer, ImageUploadForm)
from ..tasks import send_email_async from ..tasks import send_email_async
class UserProfileAPI(APIView):
@method_decorator(ensure_csrf_cookie)
def get(self, request, **kwargs):
"""
判断是否登录 若登录返回用户信息
"""
user = request.user
if not user.is_authenticated():
return self.success()
username = request.GET.get("username")
try:
if username:
user = User.objects.get(username=username, is_disabled=False)
else:
user = request.user
except User.DoesNotExist:
return self.error("User does not exist")
return self.success(UserProfileSerializer(user.userprofile).data)
@validate_serializer(EditUserProfileSerializer)
@login_required
def put(self, request):
data = request.data
user_profile = request.user.userprofile
for k, v in data.items():
setattr(user_profile, k, v)
user_profile.save()
return self.success(UserProfileSerializer(user_profile).data)
class AvatarUploadAPI(APIView):
request_parsers = ()
@login_required
def post(self, request):
form = ImageUploadForm(request.POST, request.FILES)
if form.is_valid():
avatar = form.cleaned_data["image"]
else:
return self.error("Invalid file content")
if avatar.size > 2 * 1024 * 1024:
return self.error("Picture is too large")
suffix = os.path.splitext(avatar.name)[-1].lower()
if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
return self.error("Unsupported file format")
name = rand_str(10) + suffix
with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
for chunk in avatar:
img.write(chunk)
user_profile = request.user.userprofile
user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
user_profile.save()
return self.success("Succeeded")
class TwoFactorAuthAPI(APIView):
@login_required
def get(self, request):
"""
Get QR code
"""
user = request.user
if user.two_factor_auth:
return self.error("2FA is already turned on")
token = rand_str()
user.tfa_token = token
user.save()
label = f"{SysOptions.website_name_shortcut}:{user.username}"
image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name))
return self.success(img2base64(image))
@login_required
@validate_serializer(TwoFactorAuthCodeSerializer)
def post(self, request):
"""
Open 2FA
"""
code = request.data["code"]
user = request.user
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = True
user.save()
return self.success("Succeeded")
else:
return self.error("Invalid code")
@login_required
@validate_serializer(TwoFactorAuthCodeSerializer)
def put(self, request):
code = request.data["code"]
user = request.user
if not user.two_factor_auth:
return self.error("2FA is already turned off")
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False
user.save()
return self.success("Succeeded")
else:
return self.error("Invalid code")
class CheckTFARequiredAPI(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
"""
Check TFA is required
"""
data = request.data
result = False
if data.get("username"):
try:
user = User.objects.get(username=data["username"])
result = user.two_factor_auth
except User.DoesNotExist:
pass
return self.success({"result": result})
class UserLoginAPI(APIView): class UserLoginAPI(APIView):
@validate_serializer(UserLoginSerializer) @validate_serializer(UserLoginSerializer)
def post(self, request): def post(self, request):
@ -30,13 +159,15 @@ class UserLoginAPI(APIView):
user = auth.authenticate(username=data["username"], password=data["password"]) user = auth.authenticate(username=data["username"], password=data["password"])
# None is returned if username or password is wrong # None is returned if username or password is wrong
if user: if user:
if user.is_disabled:
return self.error("Your account has been disabled")
if not user.two_factor_auth: if not user.two_factor_auth:
auth.login(request, user) auth.login(request, user)
return self.success("Succeeded") return self.success("Succeeded")
# `tfa_code` not in post data # `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data: if user.two_factor_auth and "tfa_code" not in data:
return self.success("tfa_required") return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
auth.login(request, user) auth.login(request, user)
@ -46,10 +177,30 @@ class UserLoginAPI(APIView):
else: else:
return self.error("Invalid username or password") return self.error("Invalid username or password")
# todo remove this, only for debug use
class UserLogoutAPI(APIView):
def get(self, request): def get(self, request):
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"])) auth.logout(request)
return self.success({}) return self.success()
class UsernameOrEmailCheck(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
"""
check username or email is duplicate
"""
data = request.data
# True means already exist.
result = {
"username": False,
"email": False
}
if data.get("username"):
result["username"] = User.objects.filter(username=data["username"].lower()).exists()
if data.get("email"):
result["email"] = User.objects.filter(email=data["email"].lower()).exists()
return self.success(result)
class UserRegisterAPI(APIView): class UserRegisterAPI(APIView):
@ -58,22 +209,19 @@ class UserRegisterAPI(APIView):
""" """
User register api User register api
""" """
if not SysOptions.allow_register:
return self.error("Register function has been disabled by admin")
data = request.data data = request.data
captcha = Captcha(request) captcha = Captcha(request)
if not captcha.check(data["captcha"]): if not captcha.check(data["captcha"]):
return self.error("Invalid captcha") return self.error("Invalid captcha")
try: if User.objects.filter(username=data["username"]).exists():
User.objects.get(username=data["username"])
return self.error("Username already exists") return self.error("Username already exists")
except User.DoesNotExist: data["email"] = data["email"].lower()
pass if User.objects.filter(email=data["email"]).exists():
try:
User.objects.get(email=data["email"])
return self.error("Email already exists") return self.error("Email already exists")
# Some old data has duplicate email
except MultipleObjectsReturned:
return self.error("Email already exists")
except User.DoesNotExist:
user = User.objects.create(username=data["username"], email=data["email"]) user = User.objects.create(username=data["username"], email=data["email"])
user.set_password(data["password"]) user.set_password(data["password"])
user.save() user.save()
@ -81,6 +229,28 @@ class UserRegisterAPI(APIView):
return self.success("Succeeded") return self.success("Succeeded")
class UserChangeEmailAPI(APIView):
@validate_serializer(UserChangeEmailSerializer)
@login_required
def post(self, request):
data = request.data
user = auth.authenticate(username=request.user.username, password=data["password"])
if user:
if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
return self.error("Invalid two factor verification code")
data["new_email"] = data["new_email"].lower()
if User.objects.filter(email=data["new_email"]).exists():
return self.error("The email is owned by other account")
user.email = data["new_email"]
user.save()
return self.success("Succeeded")
else:
return self.error("Wrong password")
class UserChangePasswordAPI(APIView): class UserChangePasswordAPI(APIView):
@validate_serializer(UserChangePasswordSerializer) @validate_serializer(UserChangePasswordSerializer)
@login_required @login_required
@ -89,12 +259,14 @@ class UserChangePasswordAPI(APIView):
User change password api User change password api
""" """
data = request.data data = request.data
captcha = Captcha(request)
if not captcha.check(data["captcha"]):
return self.error("Invalid captcha")
username = request.user.username username = request.user.username
user = auth.authenticate(username=username, password=data["old_password"]) user = auth.authenticate(username=username, password=data["old_password"])
if user: if user:
if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
return self.error("Invalid two factor verification code")
user.set_password(data["new_password"]) user.set_password(data["new_password"])
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
@ -105,33 +277,33 @@ class UserChangePasswordAPI(APIView):
class ApplyResetPasswordAPI(APIView): class ApplyResetPasswordAPI(APIView):
@validate_serializer(ApplyResetPasswordSerializer) @validate_serializer(ApplyResetPasswordSerializer)
def post(self, request): def post(self, request):
if request.user.is_authenticated():
return self.error("You have already logged in, are you kidding me? ")
data = request.data data = request.data
captcha = Captcha(request) captcha = Captcha(request)
config = WebsiteConfig.objects.first()
if not captcha.check(data["captcha"]): if not captcha.check(data["captcha"]):
return self.error("Invalid captcha") return self.error("Invalid captcha")
try: try:
user = User.objects.get(email=data["email"]) user = User.objects.get(email__iexact=data["email"])
except User.DoesNotExist: except User.DoesNotExist:
return self.error("User does not exist") return self.error("User does not exist")
if user.reset_password_token_expire_time and 0 < ( if user.reset_password_token_expire_time and 0 < int(
user.reset_password_token_expire_time - now()).total_seconds() < 20 * 60: (user.reset_password_token_expire_time - now()).total_seconds()) < 20 * 60:
return self.error("You can only reset password once per 20 minutes") return self.error("You can only reset password once per 20 minutes")
user.reset_password_token = rand_str() user.reset_password_token = rand_str()
user.reset_password_token_expire_time = now() + timedelta(minutes=20) user.reset_password_token_expire_time = now() + timedelta(minutes=20)
user.save() user.save()
email_template = open("reset_password_email.html", "w", render_data = {
encoding="utf-8").read() "username": user.username,
email_template = email_template.replace("{{ username }}", user.username). \ "website_name": SysOptions.website_name,
replace("{{ website_name }}", settings.WEBSITE_INFO["website_name"]). \ "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
replace("{{ link }}", settings.WEBSITE_INFO["url"] + "/reset_password/t/" + }
user.reset_password_token) email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(config.name, send_email_async.delay(SysOptions.website_name,
user.email, user.email,
user.username, user.username,
config.name + " 登录信息找回邮件", f"{SysOptions.website_name} 登录信息找回邮件",
email_template) email_html)
return self.success("Succeeded") return self.success("Succeeded")
@ -145,10 +317,99 @@ class ResetPasswordAPI(APIView):
try: try:
user = User.objects.get(reset_password_token=data["token"]) user = User.objects.get(reset_password_token=data["token"])
except User.DoesNotExist: except User.DoesNotExist:
return self.error("Token dose not exist") return self.error("Token does not exist")
if 0 < (user.reset_password_token_expire_time - now()).total_seconds() < 30 * 60: if user.reset_password_token_expire_time < now():
return self.error("Token expired") return self.error("Token has expired")
user.reset_password_token = None user.reset_password_token = None
user.two_factor_auth = False
user.set_password(data["password"]) user.set_password(data["password"])
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")
class SessionManagementAPI(APIView):
@login_required
def get(self, request):
engine = import_module(settings.SESSION_ENGINE)
session_store = engine.SessionStore
current_session = request.session.session_key
session_keys = request.user.session_keys
result = []
modified = False
for key in session_keys[:]:
session = session_store(key)
# session does not exist or is expiry
if not session._session:
session_keys.remove(key)
modified = True
continue
s = {}
if current_session == key:
s["current_session"] = True
s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"]
s["last_activity"] = datetime2str(session["last_activity"])
s["session_key"] = key
result.append(s)
if modified:
request.user.save()
return self.success(result)
@login_required
def delete(self, request):
session_key = request.GET.get("session_key")
if not session_key:
return self.error("Parameter Error")
request.session.delete(session_key)
if session_key in request.user.session_keys:
request.user.session_keys.remove(session_key)
request.user.save()
return self.success("Succeeded")
else:
return self.error("Invalid session_key")
class UserRankAPI(APIView):
def get(self, request):
rule_type = request.GET.get("rule")
if rule_type not in ContestRuleType.choices():
rule_type = ContestRuleType.ACM
profiles = UserProfile.objects.select_related("user")\
.exclude(user__is_disabled=True)
if rule_type == ContestRuleType.ACM:
profiles = profiles.filter(submission_number__gt=0).order_by("-accepted_number", "submission_number")
else:
profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
return self.success(self.paginate_data(request, profiles, RankInfoSerializer))
class ProfileProblemDisplayIDRefreshAPI(APIView):
@login_required
def get(self, request):
profile = request.user.userprofile
acm_problems = profile.acm_problems_status.get("problems", {})
oi_problems = profile.oi_problems_status.get("problems", {})
ids = list(acm_problems.keys()) + list(oi_problems.keys())
if not ids:
return self.success()
display_ids = Problem.objects.filter(id__in=ids).values_list("_id", flat=True)
id_map = dict(zip(ids, display_ids))
for k, v in acm_problems.items():
v["_id"] = id_map[k]
for k, v in oi_problems.items():
v["_id"] = id_map[k]
profile.save(update_fields=["acm_problems_status", "oi_problems_status"])
return self.success()
class OpenAPIAppkeyAPI(APIView):
@login_required
def post(self, request):
user = request.user
if not user.open_api:
return self.error("Permission denied")
api_appkey = rand_str()
user.open_api_appkey = api_appkey
user.save()
return self.success({"appkey": api_appkey})

View File

@ -1,148 +0,0 @@
import os
from io import StringIO
import qrcode
from django.conf import settings
from django.http import HttpResponse
from otpauth import OtpAuth
from conf.models import WebsiteConfig
from utils.api import APIView, validate_serializer
from utils.shortcuts import rand_str
from ..decorators import login_required
from ..models import User
from ..serializers import (EditUserSerializer, SSOSerializer,
TwoFactorAuthCodeSerializer, UserSerializer)
class UserInfoAPI(APIView):
@login_required
def get(self, request):
"""
Return user info api
"""
return self.success(UserSerializer(request.user).data)
class UserProfileAPI(APIView):
@login_required
def get(self, request):
"""
Return user info api
"""
return self.success(UserSerializer(request.user).data)
@validate_serializer(EditUserSerializer)
@login_required
def put(self, request):
data = request.data
user_profile = request.user.userprofile
if data["avatar"]:
user_profile.avatar = data["avatar"]
else:
user_profile.mood = data["mood"]
user_profile.blog = data["blog"]
user_profile.school = data["school"]
user_profile.student_id = data["student_id"]
user_profile.phone_number = data["phone_number"]
user_profile.major = data["major"]
# Timezone & language 暂时不加
user_profile.save()
return self.success("Succeeded")
class AvatarUploadAPI(APIView):
def post(self, request):
if "file" not in request.FILES:
return self.error("Upload failed")
f = request.FILES["file"]
if f.size > 1024 * 1024:
return self.error("Picture too large")
if os.path.splitext(f.name)[-1].lower() not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
return self.error("Unsupported file format")
name = "avatar_" + rand_str(5) + os.path.splitext(f.name)[-1]
with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img:
for chunk in request.FILES["file"]:
img.write(chunk)
return self.success({"path": "/static/upload/" + name})
class SSOAPI(APIView):
@login_required
def get(self, request):
callback = request.GET.get("callback", None)
if not callback:
return self.error("Parameter Error")
token = rand_str()
request.user.auth_token = token
request.user.save()
return self.success({"redirect_url": callback + "?token=" + token,
"callback": callback})
@validate_serializer(SSOSerializer)
def post(self, request):
data = request.data
try:
User.objects.get(open_api_appkey=data["appkey"])
except User.DoesNotExist:
return self.error("Invalid appkey")
try:
user = User.objects.get(auth_token=data["token"])
user.auth_token = None
user.save()
return self.success({"username": user.username,
"id": user.id,
"admin_type": user.admin_type,
"avatar": user.userprofile.avatar})
except User.DoesNotExist:
return self.error("User does not exist")
class TwoFactorAuthAPI(APIView):
@login_required
def get(self, request):
"""
Get QR code
"""
user = request.user
if user.two_factor_auth:
return self.error("Already open 2FA")
token = rand_str()
user.tfa_token = token
user.save()
config = WebsiteConfig.objects.first()
image = qrcode.make(OtpAuth(token).to_uri("totp", config.base_url, config.name))
buf = StringIO()
image.save(buf, "gif")
return HttpResponse(buf.getvalue(), "image/gif")
@login_required
@validate_serializer(TwoFactorAuthCodeSerializer)
def post(self, request):
"""
Open 2FA
"""
code = request.data["code"]
user = request.user
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = True
user.save()
return self.success("Succeeded")
else:
return self.error("Invalid captcha")
@login_required
@validate_serializer(TwoFactorAuthCodeSerializer)
def put(self, request):
code = request.data["code"]
user = request.user
if OtpAuth(user.tfa_token).valid_totp(code):
user.two_factor_auth = False
user.save()
else:
return self.error("Invalid captcha")

View File

@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('announcement', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='announcement',
name='title',
field=models.CharField(max_length=64),
),
migrations.AlterModelOptions(
name='announcement',
options={'ordering': ('-create_time',)},
),
]

View File

@ -5,7 +5,7 @@ from utils.models import RichTextField
class Announcement(models.Model): class Announcement(models.Model):
title = models.CharField(max_length=50) title = models.CharField(max_length=64)
# HTML # HTML
content = RichTextField() content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
@ -15,3 +15,4 @@ class Announcement(models.Model):
class Meta: class Meta:
db_table = "announcement" db_table = "announcement"
ordering = ("-create_time",)

View File

@ -5,8 +5,8 @@ from .models import Announcement
class CreateAnnouncementSerializer(serializers.Serializer): class CreateAnnouncementSerializer(serializers.Serializer):
title = serializers.CharField(max_length=50) title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=10000) content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField() visible = serializers.BooleanField()
@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer):
class EditAnnouncementSerializer(serializers.Serializer): class EditAnnouncementSerializer(serializers.Serializer):
id = serializers.IntegerField() id = serializers.IntegerField()
title = serializers.CharField(max_length=50) title = serializers.CharField(max_length=64)
content = serializers.CharField(max_length=10000) content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField() visible = serializers.BooleanField()

View File

@ -35,3 +35,14 @@ class AnnouncementAdminTest(APITestCase):
resp = self.client.delete(self.url + "?id=" + str(id)) resp = self.client.delete(self.url + "?id=" + str(id))
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertFalse(Announcement.objects.filter(id=id).exists()) self.assertFalse(Announcement.objects.filter(id=id).exists())
class AnnouncementAPITest(APITestCase):
def setUp(self):
self.user = self.create_super_admin()
Announcement.objects.create(title="title", content="content", visible=True, created_by=self.user)
self.url = self.reverse("announcement_api")
def test_get_announcement_list(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)

View File

@ -1,6 +1,6 @@
from django.conf.urls import url from django.conf.urls import url
from ..views import AnnouncementAdminAPI from ..views.admin import AnnouncementAdminAPI
urlpatterns = [ urlpatterns = [
url(r"^announcement/?$", AnnouncementAdminAPI.as_view(), name="announcement_admin_api"), url(r"^announcement/?$", AnnouncementAdminAPI.as_view(), name="announcement_admin_api"),

7
announcement/urls/oj.py Normal file
View File

@ -0,0 +1,7 @@
from django.conf.urls import url
from ..views.oj import AnnouncementAPI
urlpatterns = [
url(r"^announcement/?$", AnnouncementAPI.as_view(), name="announcement_api"),
]

View File

View File

@ -1,8 +1,8 @@
from account.decorators import super_admin_required from account.decorators import super_admin_required
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from .models import Announcement from announcement.models import Announcement
from .serializers import (AnnouncementSerializer, CreateAnnouncementSerializer, from announcement.serializers import (AnnouncementSerializer, CreateAnnouncementSerializer,
EditAnnouncementSerializer) EditAnnouncementSerializer)
@ -28,13 +28,12 @@ class AnnouncementAdminAPI(APIView):
""" """
data = request.data data = request.data
try: try:
announcement = Announcement.objects.get(id=data["id"]) announcement = Announcement.objects.get(id=data.pop("id"))
except Announcement.DoesNotExist: except Announcement.DoesNotExist:
return self.error("Announcement does not exist") return self.error("Announcement does not exist")
announcement.title = data["title"] for k, v in data.items():
announcement.content = data["content"] setattr(announcement, k, v)
announcement.visible = data["visible"]
announcement.save() announcement.save()
return self.success(AnnouncementSerializer(announcement).data) return self.success(AnnouncementSerializer(announcement).data)

10
announcement/views/oj.py Normal file
View File

@ -0,0 +1,10 @@
from utils.api import APIView
from announcement.models import Announcement
from announcement.serializers import AnnouncementSerializer
class AnnouncementAPI(APIView):
def get(self, request):
announcements = Announcement.objects.filter(visible=True)
return self.success(self.paginate_data(request, announcements, AnnouncementSerializer))

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('conf', '0001_initial'),
]
operations = [
migrations.DeleteModel(
name='JudgeServerToken',
),
migrations.DeleteModel(
name='SMTPConfig',
),
migrations.DeleteModel(
name='WebsiteConfig',
),
migrations.AlterField(
model_name='judgeserver',
name='hostname',
field=models.CharField(max_length=128),
),
migrations.AlterField(
model_name='judgeserver',
name='judger_version',
field=models.CharField(max_length=32),
),
migrations.AlterField(
model_name='judgeserver',
name='service_url',
field=models.CharField(blank=True, max_length=256, null=True),
),
]

View File

@ -2,55 +2,24 @@ from django.db import models
from django.utils import timezone from django.utils import timezone
class SMTPConfig(models.Model):
server = models.CharField(max_length=128)
port = models.IntegerField(default=25)
email = models.CharField(max_length=128)
password = models.CharField(max_length=128)
tls = models.BooleanField()
class Meta:
db_table = "smtp_config"
class WebsiteConfig(models.Model):
base_url = models.CharField(max_length=128, default="http://127.0.0.1")
name = models.CharField(max_length=32, default="Online Judge")
name_shortcut = models.CharField(max_length=32, default="oj")
footer = models.TextField(default="Online Judge Footer")
# allow register
allow_register = models.BooleanField(default=True)
# submission list show all user's submission
submission_list_show_all = models.BooleanField(default=True)
class Meta:
db_table = "website_config"
class JudgeServer(models.Model): class JudgeServer(models.Model):
hostname = models.CharField(max_length=64) hostname = models.CharField(max_length=128)
ip = models.CharField(max_length=32, blank=True, null=True) ip = models.CharField(max_length=32, blank=True, null=True)
judger_version = models.CharField(max_length=24) judger_version = models.CharField(max_length=32)
cpu_core = models.IntegerField() cpu_core = models.IntegerField()
memory_usage = models.FloatField() memory_usage = models.FloatField()
cpu_usage = models.FloatField() cpu_usage = models.FloatField()
last_heartbeat = models.DateTimeField() last_heartbeat = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
task_number = models.IntegerField(default=0) task_number = models.IntegerField(default=0)
service_url = models.CharField(max_length=128, blank=True, null=True) service_url = models.CharField(max_length=256, blank=True, null=True)
@property @property
def status(self): def status(self):
if (timezone.now() - self.last_heartbeat).total_seconds() > 5: # 增加一秒延时,提高对网络环境的适应性
if (timezone.now() - self.last_heartbeat).total_seconds() > 6:
return "abnormal" return "abnormal"
return "normal" return "normal"
class Meta: class Meta:
db_table = "judge_server" db_table = "judge_server"
class JudgeServerToken(models.Model):
token = models.CharField(max_length=32)
class Meta:
db_table = "judge_server_token"

View File

@ -1,6 +1,6 @@
from utils.api import DateTimeTZField, serializers from utils.api import DateTimeTZField, serializers
from .models import JudgeServer, SMTPConfig, WebsiteConfig from .models import JudgeServer
class EditSMTPConfigSerializer(serializers.Serializer): class EditSMTPConfigSerializer(serializers.Serializer):
@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer):
password = serializers.CharField(max_length=128) password = serializers.CharField(max_length=128)
class SMTPConfigSerializer(serializers.ModelSerializer):
class Meta:
model = SMTPConfig
exclude = ["id", "password"]
class TestSMTPConfigSerializer(serializers.Serializer): class TestSMTPConfigSerializer(serializers.Serializer):
email = serializers.EmailField() email = serializers.EmailField()
class CreateEditWebsiteConfigSerializer(serializers.Serializer): class CreateEditWebsiteConfigSerializer(serializers.Serializer):
base_url = serializers.CharField(max_length=128) website_base_url = serializers.CharField(max_length=128)
name = serializers.CharField(max_length=32) website_name = serializers.CharField(max_length=64)
name_shortcut = serializers.CharField(max_length=32) website_name_shortcut = serializers.CharField(max_length=64)
footer = serializers.CharField(max_length=1024) website_footer = serializers.CharField(max_length=1024 * 1024)
allow_register = serializers.BooleanField() allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField() submission_list_show_all = serializers.BooleanField()
class WebsiteConfigSerializer(serializers.ModelSerializer):
class Meta:
model = WebsiteConfig
exclude = ["id"]
class JudgeServerSerializer(serializers.ModelSerializer): class JudgeServerSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField() create_time = DateTimeTZField()
last_heartbeat = DateTimeTZField() last_heartbeat = DateTimeTZField()
@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = JudgeServer model = JudgeServer
fields = "__all__"
class JudgeServerHeartbeatSerializer(serializers.Serializer): class JudgeServerHeartbeatSerializer(serializers.Serializer):
hostname = serializers.CharField(max_length=64) hostname = serializers.CharField(max_length=128)
judger_version = serializers.CharField(max_length=24) judger_version = serializers.CharField(max_length=32)
cpu_core = serializers.IntegerField(min_value=1) cpu_core = serializers.IntegerField(min_value=1)
memory = serializers.FloatField(min_value=0, max_value=100) memory = serializers.FloatField(min_value=0, max_value=100)
cpu = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100)
action = serializers.ChoiceField(choices=("heartbeat", )) action = serializers.ChoiceField(choices=("heartbeat", ))
service_url = serializers.CharField(max_length=128, required=False) service_url = serializers.CharField(max_length=256, required=False)

View File

@ -2,9 +2,9 @@ import hashlib
from django.utils import timezone from django.utils import timezone
from options.options import SysOptions
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from .models import JudgeServer
from .models import JudgeServer, JudgeServerToken, SMTPConfig
class SMTPConfigTest(APITestCase): class SMTPConfigTest(APITestCase):
@ -27,10 +27,6 @@ class SMTPConfigTest(APITestCase):
"tls": True} "tls": True}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, self.password)
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
def test_edit_without_password1(self): def test_edit_without_password1(self):
self.test_create_smtp_config() self.test_create_smtp_config()
@ -38,7 +34,6 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": ""} "tls": True, "password": ""}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(SMTPConfig.objects.first().password, self.password)
def test_edit_with_password(self): def test_edit_with_password(self):
self.test_create_smtp_config() self.test_create_smtp_config()
@ -46,18 +41,14 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": "newpassword"} "tls": True, "password": "newpassword"}
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
smtp = SMTPConfig.objects.first()
self.assertEqual(smtp.password, "newpassword")
self.assertEqual(smtp.server, "smtp1.test.com")
self.assertEqual(smtp.email, "test2@test.com")
class WebsiteConfigAPITest(APITestCase): class WebsiteConfigAPITest(APITestCase):
def test_create_website_config(self): def test_create_website_config(self):
self.create_super_admin() self.create_super_admin()
url = self.reverse("website_config_api") url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name", data = {"website_base_url": "http://test.com", "website_name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>", "website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False} "allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data) resp = self.client.post(url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
@ -65,8 +56,8 @@ class WebsiteConfigAPITest(APITestCase):
def test_edit_website_config(self): def test_edit_website_config(self):
self.create_super_admin() self.create_super_admin()
url = self.reverse("website_config_api") url = self.reverse("website_config_api")
data = {"base_url": "http://test.com", "name": "test name", data = {"website_base_url": "http://test.com", "website_name": "test name",
"name_shortcut": "test oj", "footer": "<a>test</a>", "website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"allow_register": True, "submission_list_show_all": False} "allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data) resp = self.client.post(url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
@ -76,7 +67,6 @@ class WebsiteConfigAPITest(APITestCase):
url = self.reverse("website_info_api") url = self.reverse("website_info_api")
resp = self.client.get(url) resp = self.client.get(url)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(resp.data["data"]["name_shortcut"], "oj")
class JudgeServerHeartbeatTest(APITestCase): class JudgeServerHeartbeatTest(APITestCase):
@ -86,7 +76,7 @@ class JudgeServerHeartbeatTest(APITestCase):
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"} "cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
self.token = "test" self.token = "test"
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
JudgeServerToken.objects.create(token=self.token) SysOptions.judge_server_token = self.token
def test_new_heartbeat(self): def test_new_heartbeat(self):
resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token}) resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token})
@ -122,11 +112,9 @@ class JudgeServerAPITest(APITestCase):
self.create_super_admin() self.create_super_admin()
def test_get_judge_server(self): def test_get_judge_server(self):
self.assertFalse(JudgeServerToken.objects.exists())
resp = self.client.get(self.url) resp = self.client.get(self.url)
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["servers"]), 1) self.assertEqual(len(resp.data["data"]["servers"]), 1)
self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"])
def test_delete_judge_server(self): def test_delete_judge_server(self):
resp = self.client.delete(self.url + "?hostname=testhostname") resp = self.client.delete(self.url + "?hostname=testhostname")

View File

@ -3,48 +3,43 @@ import hashlib
from django.utils import timezone from django.utils import timezone
from account.decorators import super_admin_required from account.decorators import super_admin_required
from judge.dispatcher import process_pending_task
from judge.languages import languages, spj_languages from judge.languages import languages, spj_languages
from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str from .models import JudgeServer
from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig
from .serializers import (CreateEditWebsiteConfigSerializer, from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer, CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
JudgeServerHeartbeatSerializer, JudgeServerHeartbeatSerializer,
JudgeServerSerializer, SMTPConfigSerializer, JudgeServerSerializer, TestSMTPConfigSerializer)
TestSMTPConfigSerializer, WebsiteConfigSerializer)
class SMTPAPI(APIView): class SMTPAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
smtp = SMTPConfig.objects.first() smtp = SysOptions.smtp_config
if not smtp: if not smtp:
return self.success(None) return self.success(None)
return self.success(SMTPConfigSerializer(smtp).data) smtp.pop("password")
return self.success(smtp)
@validate_serializer(CreateSMTPConfigSerializer) @validate_serializer(CreateSMTPConfigSerializer)
@super_admin_required @super_admin_required
def post(self, request): def post(self, request):
SMTPConfig.objects.all().delete() SysOptions.smtp_config = request.data
smtp = SMTPConfig.objects.create(**request.data) return self.success()
return self.success(SMTPConfigSerializer(smtp).data)
@validate_serializer(EditSMTPConfigSerializer) @validate_serializer(EditSMTPConfigSerializer)
@super_admin_required @super_admin_required
def put(self, request): def put(self, request):
smtp = SysOptions.smtp_config
data = request.data data = request.data
smtp = SMTPConfig.objects.first() for item in ["server", "port", "email", "tls"]:
if not smtp: smtp[item] = data[item]
return self.error("SMTP config is missing") if "password" in data:
smtp.server = data["server"] smtp["password"] = data["password"]
smtp.port = data["port"] SysOptions.smtp_config = smtp
smtp.email = data["email"] return self.success()
smtp.tls = data["tls"]
if data.get("password"):
smtp.password = data["password"]
smtp.save()
return self.success(SMTPConfigSerializer(smtp).data)
class SMTPTestAPI(APIView): class SMTPTestAPI(APIView):
@ -56,31 +51,24 @@ class SMTPTestAPI(APIView):
class WebsiteConfigAPI(APIView): class WebsiteConfigAPI(APIView):
def get(self, request): def get(self, request):
config = WebsiteConfig.objects.first() ret = {key: getattr(SysOptions, key) for key in
if not config: ["website_base_url", "website_name", "website_name_shortcut",
config = WebsiteConfig.objects.create() "website_footer", "allow_register", "submission_list_show_all"]}
return self.success(WebsiteConfigSerializer(config).data) return self.success(ret)
@validate_serializer(CreateEditWebsiteConfigSerializer) @validate_serializer(CreateEditWebsiteConfigSerializer)
@super_admin_required @super_admin_required
def post(self, request): def post(self, request):
data = request.data for k, v in request.data.items():
WebsiteConfig.objects.all().delete() setattr(SysOptions, k, v)
config = WebsiteConfig.objects.create(**data) return self.success()
return self.success(WebsiteConfigSerializer(config).data)
class JudgeServerAPI(APIView): class JudgeServerAPI(APIView):
@super_admin_required @super_admin_required
def get(self, request): def get(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
servers = JudgeServer.objects.all().order_by("-last_heartbeat") servers = JudgeServer.objects.all().order_by("-last_heartbeat")
return self.success({"token": token, return self.success({"token": SysOptions.judge_server_token,
"servers": JudgeServerSerializer(servers, many=True).data}) "servers": JudgeServerSerializer(servers, many=True).data})
@super_admin_required @super_admin_required
@ -94,15 +82,9 @@ class JudgeServerAPI(APIView):
class JudgeServerHeartbeatAPI(CSRFExemptAPIView): class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
@validate_serializer(JudgeServerHeartbeatSerializer) @validate_serializer(JudgeServerHeartbeatSerializer)
def post(self, request): def post(self, request):
judge_server_token = JudgeServerToken.objects.first()
if not judge_server_token:
token = rand_str(12)
JudgeServerToken.objects.create(token=token)
else:
token = judge_server_token.token
data = request.data data = request.data
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token: if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
return self.error("Invalid token") return self.error("Invalid token")
service_url = data.get("service_url") service_url = data.get("service_url")
@ -126,6 +108,9 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
service_url=service_url, service_url=service_url,
last_heartbeat=timezone.now(), last_heartbeat=timezone.now(),
) )
# 新server上线 处理队列中的防止没有新的提交而导致一直waiting
process_pending_task()
return self.success() return self.success()

View File

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-07-17 13:24
from __future__ import unicode_literals
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0003_auto_20170217_0820'),
]
operations = [
migrations.AlterModelOptions(
name='contest',
options={'ordering': ('-create_time',)},
),
migrations.AlterModelOptions(
name='contestannouncement',
options={'ordering': ('-create_time',)},
),
]

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-08-23 09:18
from __future__ import unicode_literals
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0004_auto_20170717_1324'),
]
operations = [
migrations.RenameField(
model_name='acmcontestrank',
old_name='total_ac_number',
new_name='accepted_number',
),
migrations.RenameField(
model_name='acmcontestrank',
old_name='total_submission_number',
new_name='submission_number',
),
migrations.RenameField(
model_name='oicontestrank',
old_name='total_submission_number',
new_name='submission_number',
),
]

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0005_auto_20170823_0918'),
]
operations = [
migrations.AlterField(
model_name='acmcontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='oicontestrank',
name='submission_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterModelOptions(
name='contest',
options={'ordering': ('-start_time',)},
),
]

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-11-06 09:02
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0006_auto_20171011_1214'),
]
operations = [
migrations.AddField(
model_name='contestannouncement',
name='visible',
field=models.BooleanField(default=True),
),
]

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-11-10 06:57
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('contest', '0007_contestannouncement_visible'),
]
operations = [
migrations.AddField(
model_name='contest',
name='allowed_ip_ranges',
field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
),
]

View File

@ -1,27 +1,13 @@
from utils.constants import ContestRuleType # noqa
from django.db import models from django.db import models
from django.utils.timezone import now from django.utils.timezone import now
from jsonfield import JSONField from utils.models import JSONField
from utils.constants import ContestStatus, ContestType
from account.models import User from account.models import User
from utils.models import RichTextField from utils.models import RichTextField
class ContestType(object):
PUBLIC_CONTEST = "Public"
PASSWORD_PROTECTED_CONTEST = "Password Protected"
class ContestStatus(object):
CONTEST_NOT_START = "Not Started"
CONTEST_ENDED = "Ended"
CONTEST_UNDERWAY = "Underway"
class ContestRuleType(object):
ACM = "ACM"
OI = "OI"
class Contest(models.Model): class Contest(models.Model):
title = models.CharField(max_length=40) title = models.CharField(max_length=40)
description = RichTextField() description = RichTextField()
@ -37,6 +23,7 @@ class Contest(models.Model):
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User)
# 是否可见 false的话相当于删除 # 是否可见 false的话相当于删除
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)
allowed_ip_ranges = JSONField(default=list)
@property @property
def status(self): def status(self):
@ -56,36 +43,44 @@ class Contest(models.Model):
return ContestType.PASSWORD_PROTECTED_CONTEST return ContestType.PASSWORD_PROTECTED_CONTEST
return ContestType.PUBLIC_CONTEST return ContestType.PUBLIC_CONTEST
# 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
def problem_details_permission(self, user):
return self.rule_type == ContestRuleType.ACM or \
self.status == ContestStatus.CONTEST_ENDED or \
user.is_authenticated() and user.is_contest_admin(self) or \
self.real_time_rank
class Meta: class Meta:
db_table = "contest" db_table = "contest"
ordering = ("-start_time",)
class ContestRank(models.Model): class AbstractContestRank(models.Model):
user = models.ForeignKey(User) user = models.ForeignKey(User)
contest = models.ForeignKey(Contest) contest = models.ForeignKey(Contest)
total_submission_number = models.IntegerField(default=0) submission_number = models.IntegerField(default=0)
class Meta: class Meta:
abstract = True abstract = True
class ACMContestRank(ContestRank): class ACMContestRank(AbstractContestRank):
total_ac_number = models.IntegerField(default=0) accepted_number = models.IntegerField(default=0)
# total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60 # total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60
total_time = models.IntegerField(default=0) total_time = models.IntegerField(default=0)
# {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}} # {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}}
# key is problem id # key is problem id
submission_info = JSONField(default={}) submission_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "acm_contest_rank" db_table = "acm_contest_rank"
class OIContestRank(ContestRank): class OIContestRank(AbstractContestRank):
total_score = models.IntegerField(default=0) total_score = models.IntegerField(default=0)
# {23: {"score": 80, "total_score": 100}} # {23: 333}}
# key is problem id # key is problem id, value is current score
submission_info = JSONField(default={}) submission_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "oi_contest_rank" db_table = "oi_contest_rank"
@ -96,7 +91,9 @@ class ContestAnnouncement(models.Model):
title = models.CharField(max_length=128) title = models.CharField(max_length=128)
content = RichTextField() content = RichTextField()
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User)
visible = models.BooleanField(default=True)
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
class Meta: class Meta:
db_table = "contest_announcement" db_table = "contest_announcement"
ordering = ("-create_time",)

View File

@ -1,6 +1,7 @@
from utils.api import DateTimeTZField, UsernameSerializer, serializers from utils.api import DateTimeTZField, UsernameSerializer, serializers
from .models import Contest, ContestAnnouncement, ContestRuleType from .models import Contest, ContestAnnouncement, ContestRuleType
from .models import ACMContestRank, OIContestRank
class CreateConetestSeriaizer(serializers.Serializer): class CreateConetestSeriaizer(serializers.Serializer):
@ -12,9 +13,22 @@ class CreateConetestSeriaizer(serializers.Serializer):
password = serializers.CharField(allow_blank=True, max_length=32) password = serializers.CharField(allow_blank=True, max_length=32)
visible = serializers.BooleanField() visible = serializers.BooleanField()
real_time_rank = serializers.BooleanField() real_time_rank = serializers.BooleanField()
allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=True)
class ContestSerializer(serializers.ModelSerializer): class EditConetestSeriaizer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=128)
description = serializers.CharField()
start_time = serializers.DateTimeField()
end_time = serializers.DateTimeField()
password = serializers.CharField(allow_blank=True, allow_null=True, max_length=32)
visible = serializers.BooleanField()
real_time_rank = serializers.BooleanField()
allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32))
class ContestAdminSerializer(serializers.ModelSerializer):
start_time = DateTimeTZField() start_time = DateTimeTZField()
end_time = DateTimeTZField() end_time = DateTimeTZField()
create_time = DateTimeTZField() create_time = DateTimeTZField()
@ -27,15 +41,10 @@ class ContestSerializer(serializers.ModelSerializer):
model = Contest model = Contest
class EditConetestSeriaizer(serializers.Serializer): class ContestSerializer(ContestAdminSerializer):
id = serializers.IntegerField() class Meta:
title = serializers.CharField(max_length=128) model = Contest
description = serializers.CharField() exclude = ("password", "visible", "allowed_ip_ranges")
start_time = serializers.DateTimeField()
end_time = serializers.DateTimeField()
password = serializers.CharField(allow_blank=True, allow_null=True, max_length=32)
visible = serializers.BooleanField()
real_time_rank = serializers.BooleanField()
class ContestAnnouncementSerializer(serializers.ModelSerializer): class ContestAnnouncementSerializer(serializers.ModelSerializer):
@ -47,6 +56,35 @@ class ContestAnnouncementSerializer(serializers.ModelSerializer):
class CreateContestAnnouncementSerializer(serializers.Serializer): class CreateContestAnnouncementSerializer(serializers.Serializer):
contest_id = serializers.IntegerField()
title = serializers.CharField(max_length=128) title = serializers.CharField(max_length=128)
content = serializers.CharField() content = serializers.CharField()
visible = serializers.BooleanField()
class EditContestAnnouncementSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=128, required=False)
content = serializers.CharField(required=False, allow_blank=True)
visible = serializers.BooleanField(required=False)
class ContestPasswordVerifySerializer(serializers.Serializer):
contest_id = serializers.IntegerField() contest_id = serializers.IntegerField()
password = serializers.CharField(max_length=30, required=True)
class ACMContestRankSerializer(serializers.ModelSerializer):
user = UsernameSerializer()
submission_info = serializers.JSONField()
class Meta:
model = ACMContestRank
class OIContestRankSerializer(serializers.ModelSerializer):
user = UsernameSerializer()
submission_info = serializers.JSONField()
class Meta:
model = OIContestRank

View File

@ -6,27 +6,33 @@ from django.utils import timezone
from utils.api._serializers import DateTimeTZField from utils.api._serializers import DateTimeTZField
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from .models import ContestAnnouncement, ContestRuleType from .models import ContestAnnouncement, ContestRuleType, Contest
DEFAULT_CONTEST_DATA = {"title": "test title", "description": "test description", DEFAULT_CONTEST_DATA = {"title": "test title", "description": "test description",
"start_time": timezone.localtime(timezone.now()), "start_time": timezone.localtime(timezone.now()),
"end_time": timezone.localtime(timezone.now()) + timedelta(days=1), "end_time": timezone.localtime(timezone.now()) + timedelta(days=1),
"rule_type": ContestRuleType.ACM, "rule_type": ContestRuleType.ACM,
"password": "123", "password": "123",
"allowed_ip_ranges": [],
"visible": True, "real_time_rank": True} "visible": True, "real_time_rank": True}
class ContestAPITest(APITestCase): class ContestAdminAPITest(APITestCase):
def setUp(self): def setUp(self):
self.create_super_admin() self.create_super_admin()
self.url = self.reverse("contest_api") self.url = self.reverse("contest_admin_api")
self.data = DEFAULT_CONTEST_DATA self.data = copy.deepcopy(DEFAULT_CONTEST_DATA)
def test_create_contest(self): def test_create_contest(self):
response = self.client.post(self.url, data=self.data) response = self.client.post(self.url, data=self.data)
self.assertSuccess(response) self.assertSuccess(response)
return response return response
def test_create_contest_with_invalid_cidr(self):
self.data["allowed_ip_ranges"] = ["127.0.0"]
resp = self.client.post(self.url, data=self.data)
self.assertTrue(resp.data["data"].endswith("is not a valid cidr network"))
def test_update_contest(self): def test_update_contest(self):
id = self.test_create_contest().data["data"]["id"] id = self.test_create_contest().data["data"]["id"]
update_data = {"id": id, "title": "update title", update_data = {"id": id, "title": "update title",
@ -55,15 +61,54 @@ class ContestAPITest(APITestCase):
self.assertSuccess(response) self.assertSuccess(response)
class ContestAnnouncementAPITest(APITestCase): class ContestAPITest(APITestCase):
def setUp(self):
user = self.create_admin()
self.contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA)
self.url = self.reverse("contest_api") + "?id=" + str(self.contest.id)
def test_get_contest_list(self):
url = self.reverse("contest_list_api")
response = self.client.get(url + "?limit=10")
self.assertSuccess(response)
self.assertEqual(len(response.data["data"]["results"]), 1)
def test_get_one_contest(self):
resp = self.client.get(self.url)
self.assertSuccess(resp)
def test_regular_user_validate_contest_password(self):
self.create_user("test", "test123")
url = self.reverse("contest_password_api")
resp = self.client.post(url, {"contest_id": self.contest.id, "password": "error_password"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
resp = self.client.post(url, {"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp)
def test_regular_user_access_contest(self):
self.create_user("test", "test123")
url = self.reverse("contest_access_api")
resp = self.client.get(url + "?contest_id=" + str(self.contest.id))
self.assertFalse(resp.data["data"]["access"])
password_url = self.reverse("contest_password_api")
resp = self.client.post(password_url,
{"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp)
resp = self.client.get(self.url)
self.assertSuccess(resp)
class ContestAnnouncementAdminAPITest(APITestCase):
def setUp(self): def setUp(self):
self.create_super_admin() self.create_super_admin()
self.url = self.reverse("contest_announcement_admin_api") self.url = self.reverse("contest_announcement_admin_api")
contest_id = self.create_contest().data["data"]["id"] contest_id = self.create_contest().data["data"]["id"]
self.data = {"title": "test title", "content": "test content", "contest_id": contest_id} self.data = {"title": "test title", "content": "test content", "contest_id": contest_id, "visible": True}
def create_contest(self): def create_contest(self):
url = self.reverse("contest_api") url = self.reverse("contest_admin_api")
data = DEFAULT_CONTEST_DATA data = DEFAULT_CONTEST_DATA
return self.client.post(url, data=data) return self.client.post(url, data=data)
@ -80,7 +125,7 @@ class ContestAnnouncementAPITest(APITestCase):
def test_get_contest_announcements(self): def test_get_contest_announcements(self):
self.test_create_contest_announcement() self.test_create_contest_announcement()
response = self.client.get(self.url) response = self.client.get(self.url + "?contest_id=" + str(self.data["contest_id"]))
self.assertSuccess(response) self.assertSuccess(response)
def test_get_one_contest_announcement(self): def test_get_one_contest_announcement(self):
@ -92,10 +137,10 @@ class ContestAnnouncementAPITest(APITestCase):
class ContestAnnouncementListAPITest(APITestCase): class ContestAnnouncementListAPITest(APITestCase):
def setUp(self): def setUp(self):
self.create_super_admin() self.create_super_admin()
self.url = self.reverse("contest_list_api") self.url = self.reverse("contest_announcement_api")
def create_contest_announcements(self): def create_contest_announcements(self):
contest_id = self.client.post(self.reverse("contest_api"), data=DEFAULT_CONTEST_DATA).data["data"]["id"] contest_id = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]["id"]
url = self.reverse("contest_announcement_admin_api") url = self.reverse("contest_announcement_admin_api")
self.client.post(url, data={"title": "test title1", "content": "test content1", "contest_id": contest_id}) self.client.post(url, data={"title": "test title1", "content": "test content1", "contest_id": contest_id})
self.client.post(url, data={"title": "test title2", "content": "test content2", "contest_id": contest_id}) self.client.post(url, data={"title": "test title2", "content": "test content2", "contest_id": contest_id})
@ -105,3 +150,15 @@ class ContestAnnouncementListAPITest(APITestCase):
contest_id = self.create_contest_announcements() contest_id = self.create_contest_announcements()
response = self.client.get(self.url, data={"contest_id": contest_id}) response = self.client.get(self.url, data={"contest_id": contest_id})
self.assertSuccess(response) self.assertSuccess(response)
class ContestRankAPITest(APITestCase):
def setUp(self):
user = self.create_admin()
self.acm_contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA)
self.create_user("test", "test123")
self.url = self.reverse("contest_rank_api")
def get_contest_rank(self):
resp = self.client.get(self.url + "?contest_id=" + self.acm_contest.id)
self.assertSuccess(resp)

View File

@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views.admin import ContestAnnouncementAPI, ContestAPI from ..views.admin import ContestAnnouncementAPI, ContestAPI
urlpatterns = [ urlpatterns = [
url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"), url(r"^contest/?$", ContestAPI.as_view(), name="contest_admin_api"),
url(r"^contest/announcement/?$", ContestAnnouncementAPI.as_view(), name="contest_announcement_admin_api") url(r"^contest/announcement/?$", ContestAnnouncementAPI.as_view(), name="contest_announcement_admin_api")
] ]

View File

@ -1,7 +1,15 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.oj import ContestAnnouncementListAPI from ..views.oj import ContestAnnouncementListAPI
from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI
from ..views.oj import ContestListAPI, ContestAPI
from ..views.oj import ContestRankAPI
urlpatterns = [ urlpatterns = [
url(r"^contest/?$", ContestAnnouncementListAPI.as_view(), name="contest_list_api"), url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"),
url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"),
url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"),
url(r"^contest/access/?$", ContestAccessAPI.as_view(), name="contest_access_api"),
url(r"^contest_rank/?$", ContestRankAPI.as_view(), name="contest_rank_api"),
] ]

View File

@ -1,12 +1,14 @@
from ipaddress import ip_network
import dateutil.parser import dateutil.parser
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from ..models import Contest, ContestAnnouncement from ..models import Contest, ContestAnnouncement
from ..serializers import (ContestAnnouncementSerializer, ContestSerializer, from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer,
CreateConetestSeriaizer, CreateConetestSeriaizer,
CreateContestAnnouncementSerializer, CreateContestAnnouncementSerializer,
EditConetestSeriaizer) EditConetestSeriaizer,
EditContestAnnouncementSerializer)
class ContestAPI(APIView): class ContestAPI(APIView):
@ -18,10 +20,15 @@ class ContestAPI(APIView):
data["created_by"] = request.user data["created_by"] = request.user
if data["end_time"] <= data["start_time"]: if data["end_time"] <= data["start_time"]:
return self.error("Start time must occur earlier than end time") return self.error("Start time must occur earlier than end time")
if not data["password"]: if data.get("password") and data["password"] == "":
data["password"] = None data["password"] = None
for ip_range in data["allowed_ip_ranges"]:
try:
ip_network(ip_range, strict=False)
except ValueError:
return self.error(f"{ip_range} is not a valid cidr network")
contest = Contest.objects.create(**data) contest = Contest.objects.create(**data)
return self.success(ContestSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
@validate_serializer(EditConetestSeriaizer) @validate_serializer(EditConetestSeriaizer)
def put(self, request): def put(self, request):
@ -38,10 +45,16 @@ class ContestAPI(APIView):
return self.error("Start time must occur earlier than end time") return self.error("Start time must occur earlier than end time")
if not data["password"]: if not data["password"]:
data["password"] = None data["password"] = None
for ip_range in data["allowed_ip_ranges"]:
try:
ip_network(ip_range, strict=False)
except ValueError as e:
return self.error(f"{ip_range} is not a valid cidr network")
for k, v in data.items(): for k, v in data.items():
setattr(contest, k, v) setattr(contest, k, v)
contest.save() contest.save()
return self.success(ContestSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
def get(self, request): def get(self, request):
contest_id = request.GET.get("id") contest_id = request.GET.get("id")
@ -50,7 +63,7 @@ class ContestAPI(APIView):
contest = Contest.objects.get(id=contest_id) contest = Contest.objects.get(id=contest_id)
if request.user.is_admin() and contest.created_by != request.user: if request.user.is_admin() and contest.created_by != request.user:
return self.error("Contest does not exist") return self.error("Contest does not exist")
return self.success(ContestSerializer(contest).data) return self.success(ContestAdminSerializer(contest).data)
except Contest.DoesNotExist: except Contest.DoesNotExist:
return self.error("Contest does not exist") return self.error("Contest does not exist")
@ -62,7 +75,7 @@ class ContestAPI(APIView):
if request.user.is_admin(): if request.user.is_admin():
contests = contests.filter(created_by=request.user) contests = contests.filter(created_by=request.user)
return self.success(self.paginate_data(request, contests, ContestSerializer)) return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
class ContestAnnouncementAPI(APIView): class ContestAnnouncementAPI(APIView):
@ -83,6 +96,23 @@ class ContestAnnouncementAPI(APIView):
announcement = ContestAnnouncement.objects.create(**data) announcement = ContestAnnouncement.objects.create(**data)
return self.success(ContestAnnouncementSerializer(announcement).data) return self.success(ContestAnnouncementSerializer(announcement).data)
@validate_serializer(EditContestAnnouncementSerializer)
def put(self, request):
"""
update contest_announcement
"""
data = request.data
try:
contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id"))
if request.user.is_admin() and contest_announcement.created_by != request.user:
return self.error("Contest announcement does not exist")
except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist")
for k, v in data.items():
setattr(contest_announcement, k, v)
contest_announcement.save()
return self.success()
def delete(self, request): def delete(self, request):
""" """
Delete one contest_announcement. Delete one contest_announcement.
@ -110,10 +140,13 @@ class ContestAnnouncementAPI(APIView):
except ContestAnnouncement.DoesNotExist: except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist") return self.error("Contest announcement does not exist")
contest_announcements = ContestAnnouncement.objects.all().order_by("-create_time") contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Paramater error")
contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id)
if request.user.is_admin(): if request.user.is_admin():
contest_announcements = contest_announcements.filter(created_by=request.user) contest_announcements = contest_announcements.filter(created_by=request.user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
contest_announcements = contest_announcements.filter(title__contains=keyword) contest_announcements = contest_announcements.filter(title__contains=keyword)
return self.success(self.paginate_data(request, contest_announcements, ContestAnnouncementSerializer)) return self.success(ContestAnnouncementSerializer(contest_announcements, many=True).data)

View File

@ -1,16 +1,115 @@
from utils.api import APIView from django.utils.timezone import now
from django.core.cache import cache
from utils.api import APIView, validate_serializer
from utils.constants import CacheKey
from utils.shortcuts import datetime2str
from account.decorators import login_required, check_contest_permission
from ..models import ContestAnnouncement from utils.constants import ContestRuleType, ContestStatus
from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestAnnouncementSerializer
from ..serializers import ContestSerializer, ContestPasswordVerifySerializer
from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
class ContestAnnouncementListAPI(APIView): class ContestAnnouncementListAPI(APIView):
@check_contest_permission(check_type="announcements")
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
return self.error("Invalid parameter") return self.error("Invalid parameter, contest_id is required")
data = ContestAnnouncement.objects.filter(contest_id=contest_id).order_by("-create_time") data = ContestAnnouncement.objects.select_related("created_by").filter(contest_id=contest_id, visible=True)
max_id = request.GET.get("max_id") max_id = request.GET.get("max_id")
if max_id: if max_id:
data = data.filter(id__gt=max_id) data = data.filter(id__gt=max_id)
return self.success(ContestAnnouncementSerializer(data, many=True).data) return self.success(ContestAnnouncementSerializer(data, many=True).data)
class ContestAPI(APIView):
def get(self, request):
id = request.GET.get("id")
if not id:
return self.error("Invalid parameter, id is required")
try:
contest = Contest.objects.get(id=id)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
data = ContestSerializer(contest).data
data["now"] = datetime2str(now())
return self.success(data)
class ContestListAPI(APIView):
def get(self, request):
contests = Contest.objects.select_related("created_by").filter(visible=True)
keyword = request.GET.get("keyword")
rule_type = request.GET.get("rule_type")
status = request.GET.get("status")
if keyword:
contests = contests.filter(title__contains=keyword)
if rule_type:
contests = contests.filter(rule_type=rule_type)
if status:
cur = now()
if status == ContestStatus.CONTEST_NOT_START:
contests = contests.filter(start_time__gt=cur)
elif status == ContestStatus.CONTEST_ENDED:
contests = contests.filter(end_time__lt=cur)
else:
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
data = self.paginate_data(request, contests, ContestSerializer)
return self.success(data)
class ContestPasswordVerifyAPI(APIView):
@validate_serializer(ContestPasswordVerifySerializer)
@login_required
def post(self, request):
data = request.data
try:
contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
if contest.password != data["password"]:
return self.error("Wrong password")
# password verify OK.
if "accessible_contests" not in request.session:
request.session["accessible_contests"] = []
request.session["accessible_contests"].append(contest.id)
# https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved
request.session.modified = True
return self.success(True)
class ContestAccessAPI(APIView):
@login_required
def get(self, request):
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error()
return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])})
class ContestRankAPI(APIView):
def get_rank(self):
if self.contest.rule_type == ContestRuleType.ACM:
return ACMContestRank.objects.filter(contest=self.contest). \
select_related("user").order_by("-accepted_number", "total_time")
else:
return OIContestRank.objects.filter(contest=self.contest). \
select_related("user").order_by("-total_score")
@check_contest_permission(check_type="ranks")
def get(self, request):
if self.contest.rule_type == ContestRuleType.OI:
serializer = OIContestRankSerializer
else:
serializer = ACMContestRankSerializer
cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
qs = cache.get(cache_key)
if not qs:
qs = self.get_rank()
cache.set(cache_key, qs)
return self.success(self.paginate_data(request, qs, serializer))

0
data/log/.gitkeep Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

0
data/ssl/.gitkeep Normal file
View File

0
data/test_case/.gitkeep Normal file
View File

View File

@ -1,5 +0,0 @@
FROM python:3.5
ADD requirements.txt /tmp
RUN pip install -r /tmp/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
WORKDIR /app
CMD python manage.py runserver 0.0.0.0:8085

20
deploy/nginx/common.conf Normal file
View File

@ -0,0 +1,20 @@
location /public {
root /data;
}
location /api {
proxy_pass http://backend;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $http_host;
client_max_body_size 200M;
}
location /admin {
root /app/dist/admin;
try_files $uri $uri/ /index.html =404;
}
location / {
root /app/dist;
try_files $uri $uri/ /index.html =404;
}

57
deploy/nginx/nginx.conf Normal file
View File

@ -0,0 +1,57 @@
user nobody;
daemon off;
pid /tmp/nginx.pid;
worker_processes auto;
pcre_jit on;
error_log /data/log/nginx_error.log warn;
events {
worker_connections 1024;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
server_tokens off;
keepalive_timeout 65;
sendfile on;
tcp_nodelay on;
gzip on;
gzip_vary on;
gzip_types application/javascript text/css;
client_body_temp_path /tmp 1 2;
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent" "$http_x_forwarded_for"';
access_log /data/log/nginx_access.log main;
upstream backend {
server 127.0.0.1:8080;
keepalive 32;
}
server {
listen 8000 default_server;
server_name _;
include common.conf;
}
server {
listen 1443 ssl http2 default_server;
server_name _;
ssl_certificate /data/ssl/server.crt;
ssl_certificate_key /data/ssl/server.key;
ssl_protocols TLSv1 TLSv1.1 TLSv1.2;
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:ECDHE-RSA-AES128-GCM-SHA256:AES256+EECDH:DHE-RSA-AES128-GCM-SHA256:AES256+EDH:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA:ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES256-GCM-SHA384:AES128-GCM-SHA256:AES256-SHA256:AES128-SHA256:AES256-SHA:AES128-SHA:DES-CBC3-SHA:HIGH:!aNULL:!eNULL:!EXPORT:!DES:!MD5:!PSK:!RC4";
ssl_prefer_server_ciphers on;
ssl_session_cache shared:SSL:10m;
include common.conf;
}
}

View File

@ -1,7 +1,6 @@
django==1.9.6 django==1.11.4
djangorestframework==3.4.0 djangorestframework==3.4.0
pillow pillow
jsonfield
otpauth otpauth
flake8-quotes flake8-quotes
pytz pytz
@ -11,3 +10,9 @@ celery
Envelopes Envelopes
qrcode qrcode
flake8-coding flake8-coding
requests
django-redis
psycopg2
gunicorn
jsonfield
XlsxWriter

33
deploy/run.sh Normal file
View File

@ -0,0 +1,33 @@
#!/bin/bash
APP=/app
DATA=/data
if [ ! -f "$APP/oj/custom_settings.py" ]; then
echo SECRET_KEY=\"$(cat /dev/urandom | head -1 | md5sum | head -c 32)\" >> $APP/oj/custom_settings.py
fi
mkdir -p $DATA/log $DATA/ssl $DATA/test_case $DATA/public/upload $DATA/public/avatar
SSL="$DATA/ssl"
if [ ! -f "$SSL/server.key" ]; then
openssl req -x509 -newkey rsa:2048 -keyout "$SSL/server.key" -out "$SSL/server.crt" -days 1000 \
-subj "/C=CN/ST=Beijing/L=Beijing/O=Beijing OnlineJudge Technology Co., Ltd./OU=Service Infrastructure Department/CN=`hostname`" -nodes
fi
cd $APP
n=0
while [ $n -lt 5 ]
do
python manage.py migrate --no-input &&
python manage.py inituser --username=root --password=rootroot --action=create_super_admin &&
break
n=$(($n+1))
echo "Failed to migrate, going to retry..."
sleep 8
done
cp data/public/avatar/default.png /data/public/avatar
chown -R nobody:nogroup $DATA $APP/dist
exec supervisord -c /app/deploy/supervisord.conf

52
deploy/supervisord.conf Normal file
View File

@ -0,0 +1,52 @@
[supervisord]
logfile=/data/log/supervisord.log
logfile_maxbytes=10MB
logfile_backups=10
loglevel=info
pidfile=/tmp/supervisord.pid
nodaemon=true
childlogdir=/data/log/
[inet_http_server]
port=127.0.0.1:9005
[rpcinterface:supervisor]
supervisor.rpcinterface_factory=supervisor.rpcinterface:make_main_rpcinterface
[supervisorctl]
serverurl=http://127.0.0.1:9005
[program:nginx]
command=nginx -c /app/deploy/nginx/nginx.conf
directory=/app/
stdout_logfile=/data/log/nginx.log
stderr_logfile=/data/log/nginx.log
autostart=true
autorestart=true
startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:gunicorn]
command=sh -c "gunicorn oj.wsgi --user nobody -b 127.0.0.1:8080 --reload -w `grep -c ^processor /proc/cpuinfo`"
directory=/app/
user=nobody
stdout_logfile=/data/log/gunicorn.log
stderr_logfile=/data/log/gunicorn.log
autostart=true
autorestart=true
startsecs=5
stopwaitsecs = 5
killasgroup=true
[program:celery]
command=celery -A oj worker -l warning
directory=/app/
user=nobody
stdout_logfile=/data/log/celery.log
stderr_logfile=/data/log/celery.log
autostart=true
autorestart=true
startsecs=5
stopwaitsecs = 5
killasgroup=true

352
judge/dispatcher.py Normal file
View File

@ -0,0 +1,352 @@
import hashlib
import json
import logging
from urllib.parse import urljoin
import requests
from django.db import transaction
from django.db.models import F
from django.conf import settings
from account.models import User
from conf.models import JudgeServer
from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus
from judge.languages import languages, spj_languages
from options.options import SysOptions
from problem.models import Problem, ProblemRuleType
from problem.utils import parse_problem_template
from submission.models import JudgeStatus, Submission
from utils.cache import cache
from utils.constants import CacheKey
logger = logging.getLogger(__name__)
# 继续处理在队列中的问题
def process_pending_task():
if cache.llen(CacheKey.waiting_queue):
# 防止循环引入
from judge.tasks import judge_task
data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data)
class DispatcherBase(object):
def __init__(self):
self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
def _request(self, url, data=None):
kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
if data:
kwargs["json"] = data
try:
return requests.post(url, **kwargs).json()
except Exception as e:
logger.exception(e)
@staticmethod
def choose_judge_server():
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
if servers:
server = servers[0]
server.used_instance_number = F("task_number") + 1
server.save()
return server
@staticmethod
def release_judge_server(judge_server_id):
with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.get(id=judge_server_id)
server.used_instance_number = F("task_number") - 1
server.save()
class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language):
super().__init__()
spj_compile_config = list(filter(lambda config: spj_language == config["name"], spj_languages))[0]["spj"][
"compile"]
self.data = {
"src": spj_code,
"spj_version": spj_version,
"spj_compile_config": spj_compile_config
}
def compile_spj(self):
server = self.choose_judge_server()
if not server:
return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
self.release_judge_server(server.id)
if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase):
def __init__(self, submission_id, problem_id):
super().__init__()
self.submission = Submission.objects.get(id=submission_id)
self.contest_id = self.submission.contest_id
if self.contest_id:
self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
self.contest = self.problem.contest
else:
self.problem = Problem.objects.get(id=problem_id)
def _compute_statistic_info(self, resp_data):
# 用时和内存占用保存为多个测试点中最长的那个
self.submission.statistic_info["time_cost"] = max([x["cpu_time"] for x in resp_data])
self.submission.statistic_info["memory_cost"] = max([x["memory"] for x in resp_data])
# sum up the score in OI mode
if self.problem.rule_type == ProblemRuleType.OI:
score = 0
try:
for i in range(len(resp_data)):
if resp_data[i]["result"] == JudgeStatus.ACCEPTED:
resp_data[i]["score"] = self.problem.test_case_score[i]["score"]
score += resp_data[i]["score"]
else:
resp_data[i]["score"] = 0
except IndexError:
logger.error(f"Index Error raised when summing up the score in problem {self.problem.id}")
self.submission.statistic_info["score"] = 0
return
self.submission.statistic_info["score"] = score
def judge(self, output=True):
server = self.choose_judge_server()
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], languages))[0]
spj_config = {}
if self.problem.spj_code:
for lang in spj_languages:
if lang["name"] == self.problem.spj_language:
spj_config = lang["spj"]
break
if language in self.problem.template:
template = parse_problem_template(self.problem.template[language])
code = f"{template['prepend']}\n{self.submission.code}\n{template['append']}"
else:
code = self.submission.code
data = {
"language_config": sub_config["config"],
"src": code,
"max_cpu_time": self.problem.time_limit,
"max_memory": 1024 * 1024 * self.problem.memory_limit,
"test_case_id": self.problem.test_case_id,
"output": output,
"spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code
}
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
service_url = server.service_url
# not set service_url, it should be a linked container
if not service_url:
service_url = settings.DEFAULT_JUDGE_SERVER_SERVICE_URL
resp = self._request(urljoin(service_url, "/judge"), data=data)
if resp["err"]:
self.submission.result = JudgeStatus.COMPILE_ERROR
self.submission.statistic_info["err_info"] = resp["data"]
self.submission.statistic_info["score"] = 0
else:
resp["data"].sort(key=lambda x: int(x["test_case"]))
self.submission.info = resp
self._compute_statistic_info(resp["data"])
error_test_case = list(filter(lambda case: case["result"] != 0, resp["data"]))
# ACM模式下,多个测试点全部正确则AC否则取第一个错误的测试点的状态
# OI模式下, 若多个测试点全部正确则AC 若全部错误则取第一个错误测试点状态,否则为部分正确
if not error_test_case:
self.submission.result = JudgeStatus.ACCEPTED
elif self.problem.rule_type == ProblemRuleType.ACM or len(error_test_case) == len(resp["data"]):
self.submission.result = error_test_case[0]["result"]
else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save()
self.release_judge_server(server.id)
if self.contest_id:
self.update_contest_problem_status()
self.update_contest_rank()
else:
self.update_problem_status()
# 至此判题结束,尝试处理任务队列中剩余的任务
process_pending_task()
def update_problem_status(self):
result = str(self.submission.result)
problem_id = str(self.problem.id)
with transaction.atomic():
# update problem status
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
problem.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED:
problem.accepted_number += 1
problem_info = problem.statistic_info
problem_info[result] = problem_info.get(result, 0) + 1
problem.save(update_fields=["accepted_number", "submission_number", "statistic_info"])
# update_userprofile
user = User.objects.select_for_update().get(id=self.submission.user_id)
user_profile = user.userprofile
user_profile.submission_number += 1
if problem.rule_type == ProblemRuleType.ACM:
acm_problems_status = user_profile.acm_problems_status.get("problems", {})
if problem_id not in acm_problems_status:
acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
if self.submission.result == JudgeStatus.ACCEPTED:
user_profile.accepted_number += 1
elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
acm_problems_status[problem_id]["status"] = self.submission.result
if self.submission.result == JudgeStatus.ACCEPTED:
user_profile.accepted_number += 1
user_profile.acm_problems_status["problems"] = acm_problems_status
user_profile.save(update_fields=["submission_number", "accepted_number", "acm_problems_status"])
else:
oi_problems_status = user_profile.oi_problems_status.get("problems", {})
score = self.submission.statistic_info["score"]
if problem_id not in oi_problems_status:
user_profile.add_score(score)
oi_problems_status[problem_id] = {"status": self.submission.result,
"_id": self.problem._id,
"score": score}
if self.submission.result == JudgeStatus.ACCEPTED:
user_profile.accepted_number += 1
else:
if oi_problems_status[problem_id]["status"] == JudgeStatus.ACCEPTED and \
self.submission.result != JudgeStatus.ACCEPTED:
user_profile.accepted_number -= 1
elif oi_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED and \
self.submission.result == JudgeStatus:
user_profile.accepted_number += 1
# minus last time score, add this time score
user_profile.add_score(this_time_score=score,
last_time_score=oi_problems_status[problem_id]["score"])
oi_problems_status[problem_id]["score"] = score
oi_problems_status[problem_id]["status"] = self.submission.result
user_profile.oi_problems_status["problems"] = oi_problems_status
user_profile.save(update_fields=["submission_number", "accepted_number", "oi_problems_status"])
def update_contest_problem_status(self):
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
logger.info("Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id)
return
with transaction.atomic():
user = User.objects.select_for_update().get(id=self.submission.user_id)
user_profile = user.userprofile
problem_id = str(self.problem.id)
if self.contest.rule_type == ContestRuleType.ACM:
contest_problems_status = user_profile.acm_problems_status.get("contest_problems", {})
if problem_id not in contest_problems_status:
contest_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
elif contest_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
contest_problems_status[problem_id]["status"] = self.submission.result
else:
# 如果已AC 直接跳过 不计入任何计数器
return
user_profile.acm_problems_status["contest_problems"] = contest_problems_status
user_profile.save(update_fields=["acm_problems_status"])
elif self.contest.rule_type == ContestRuleType.OI:
contest_problems_status = user_profile.oi_problems_status.get("contest_problems", {})
score = self.submission.statistic_info["score"]
if problem_id not in contest_problems_status:
contest_problems_status[problem_id] = {"status": self.submission.result,
"_id": self.problem._id,
"score": score}
else:
contest_problems_status[problem_id]["score"] = score
contest_problems_status[problem_id]["status"] = self.submission.result
user_profile.oi_problems_status["contest_problems"] = contest_problems_status
user_profile.save(update_fields=["oi_problems_status"])
problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
result = str(self.submission.result)
problem_info = problem.statistic_info
problem_info[result] = problem_info.get(result, 0) + 1
problem.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED:
problem.accepted_number += 1
problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"])
def update_contest_rank(self):
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
return
if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM:
acm_rank, _ = ACMContestRank.objects.select_for_update(). \
get_or_create(user_id=self.submission.user_id, contest=self.contest)
self._update_acm_contest_rank(acm_rank)
else:
oi_rank, _ = OIContestRank.objects.select_for_update(). \
get_or_create(user_id=self.submission.user_id, contest=self.contest)
self._update_oi_contest_rank(oi_rank)
def _update_acm_contest_rank(self, rank):
info = rank.submission_info.get(str(self.submission.problem_id))
# 因前面更改过,这里需要重新获取
problem = Problem.objects.get(contest_id=self.contest_id, id=self.problem.id)
# 此题提交过
if info:
if info["is_ac"]:
return
rank.submission_number += 1
if self.submission.result == JudgeStatus.ACCEPTED:
rank.accepted_number += 1
info["is_ac"] = True
info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds()
rank.total_time += info["ac_time"] + info["error_number"] * 20 * 60
if problem.accepted_number == 1:
info["is_first_ac"] = True
else:
info["error_number"] += 1
# 第一次提交
else:
rank.submission_number += 1
info = {"is_ac": False, "ac_time": 0, "error_number": 0, "is_first_ac": False}
if self.submission.result == JudgeStatus.ACCEPTED:
rank.accepted_number += 1
info["is_ac"] = True
info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds()
rank.total_time += info["ac_time"]
if problem.accepted_number == 1:
info["is_first_ac"] = True
else:
info["error_number"] = 1
rank.submission_info[str(self.submission.problem_id)] = info
rank.save()
def _update_oi_contest_rank(self, rank):
problem_id = str(self.submission.problem_id)
current_score = self.submission.statistic_info["score"]
last_score = rank.submission_info.get(problem_id)
if last_score:
rank.total_score = rank.total_score - last_score + current_score
else:
rank.total_score = rank.total_score + current_score
rank.submission_info[problem_id] = current_score
rank.save()

View File

@ -1,7 +1,7 @@
_c_lang_config = { _c_lang_config = {
"template": """//PREPEND START "template": """//PREPEND BEGIN
#include <stdio.h> #include <stdio.h>
//PREPEND END //PREPEND END
@ -12,7 +12,7 @@ int add(int a, int b) {
} }
//TEMPLATE END //TEMPLATE END
//APPEND START //APPEND BEGIN
int main() { int main() {
printf("%d", add(1, 2)); printf("%d", add(1, 2));
return 0; return 0;
@ -23,7 +23,7 @@ int main() {
"exe_name": "main", "exe_name": "main",
"max_cpu_time": 3000, "max_cpu_time": 3000,
"max_real_time": 5000, "max_real_time": 5000,
"max_memory": 128 * 1024 * 1024, "max_memory": 256 * 1024 * 1024,
"compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c99 {src_path} -lm -o {exe_path}", "compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c99 {src_path} -lm -o {exe_path}",
}, },
"run": { "run": {
@ -48,18 +48,29 @@ _c_lang_spj_config = {
} }
_cpp_lang_config = { _cpp_lang_config = {
"template": """/*--PREPEND START--*/ "template": """//PREPEND BEGIN
/*--PREPEND END--*/ #include <iostream>
/*--TEMPLATE BEGIN--*/ //PREPEND END
/*--TEMPLATE END--*/
/*--APPEND START--*/ //TEMPLATE BEGIN
/*--APPEND END--*/""", int add(int a, int b) {
// Please fill this blank
return ___________;
}
//TEMPLATE END
//APPEND BEGIN
int main() {
std::cout << add(1, 2);
return 0;
}
//APPEND END""",
"compile": { "compile": {
"src_name": "main.cpp", "src_name": "main.cpp",
"exe_name": "main", "exe_name": "main",
"max_cpu_time": 3000, "max_cpu_time": 3000,
"max_real_time": 5000, "max_real_time": 5000,
"max_memory": 128 * 1024 * 1024, "max_memory": 512 * 1024 * 1024,
"compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++11 {src_path} -lm -o {exe_path}", "compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++11 {src_path} -lm -o {exe_path}",
}, },
"run": { "run": {
@ -99,8 +110,8 @@ _java_lang_config = {
"compile_command": "/usr/bin/javac {src_path} -d {exe_dir} -encoding UTF8" "compile_command": "/usr/bin/javac {src_path} -d {exe_dir} -encoding UTF8"
}, },
"run": { "run": {
"command": "/usr/bin/java -cp {exe_dir} -Xss1M -XX:MaxPermSize=16M -XX:PermSize=8M -Xms16M -Xmx{max_memory}k " "command": "/usr/bin/java -cp {exe_dir} -Xss1M -Xms16M -Xmx{max_memory}k "
"-Djava.security.manager -Djava.security.policy==/etc/java_policy -Djava.awt.headless=true Main", "-Djava.security.manager -Djava.security.policy=/etc/java_policy -Djava.awt.headless=true Main",
"seccomp_rule": None, "seccomp_rule": None,
"env": ["MALLOC_ARENA_MAX=1"] "env": ["MALLOC_ARENA_MAX=1"]
} }

8
judge/tasks.py Normal file
View File

@ -0,0 +1,8 @@
from __future__ import absolute_import, unicode_literals
from celery import shared_task
from judge.dispatcher import JudgeDispatcher
@shared_task
def judge_task(submission_id, problem_id):
JudgeDispatcher(submission_id, problem_id).judge()

View File

@ -0,0 +1,6 @@
from __future__ import absolute_import, unicode_literals
# Django starts so that shared_task will use this app.
from .celery import app as celery_app
__all__ = ["celery_app"]

18
oj/celery.py Normal file
View File

@ -0,0 +1,18 @@
from __future__ import absolute_import, unicode_literals
import os
from celery import Celery
from django.conf import settings
# set the default Django settings module for the "celery" program.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
app = Celery("oj")
# Using a string here means the worker will not have to
# pickle the object when using Windows.
app.config_from_object("django.conf:settings")
# load task modules from all registered Django app configs.
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
# app.autodiscover_tasks()

View File

@ -1,19 +0,0 @@
class DBRouter(object):
def db_for_read(self, model, **hints):
if model._meta.app_label == "submission":
return "submission"
return "default"
def db_for_write(self, model, **hints):
if model._meta.app_label == "submission":
return "submission"
return "default"
def allow_relation(self, obj1, obj2, **hints):
return True
def allow_migrate(self, db, app_label, model=None, **hints):
if app_label == "submission":
return db == app_label
else:
return db == "default"

27
oj/dev_settings.py Normal file
View File

@ -0,0 +1,27 @@
# coding=utf-8
import os
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'HOST': '127.0.0.1',
'PORT': 5433,
'NAME': "onlinejudge",
'USER': "onlinejudge",
'PASSWORD': 'onlinejudge'
}
}
REDIS_CONF = {
"host": "127.0.0.1",
"port": "6379"
}
DEBUG = True
ALLOWED_HOSTS = ["*"]
DATA_DIR = f"{BASE_DIR}/data"

View File

@ -1,31 +0,0 @@
# coding=utf-8
import os
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
}
}
REDIS_CACHE = {
"host": "127.0.0.1",
"port": 6379,
"db": 1
}
REDIS_QUEUE = {
"host": "127.0.0.1",
"port": 6379,
"db": 2
}
DEBUG = True
ALLOWED_HOSTS = ["*"]
TEST_CASE_DIR = "/tmp"
LOG_PATH = "log/"

28
oj/production_settings.py Normal file
View File

@ -0,0 +1,28 @@
import os
def get_env(name, default=""):
return os.environ.get(name, default)
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql_psycopg2',
'HOST': get_env("POSTGRES_HOST", "oj-postgres"),
'PORT': get_env("POSTGRES_PORT", "5432"),
'NAME': get_env("POSTGRES_DB"),
'USER': get_env("POSTGRES_USER"),
'PASSWORD': get_env("POSTGRES_PASSWORD")
}
}
REDIS_CONF = {
"host": get_env("REDIS_HOST", "oj-redis"),
"port": get_env("REDIS_PORT", "6379")
}
DEBUG = False
ALLOWED_HOSTS = ['*']
DATA_DIR = "/data"

View File

@ -1,37 +0,0 @@
# coding=utf-8
import os
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.mysql',
'NAME': "oj",
'CONN_MAX_AGE': 0.1,
'HOST': os.environ["MYSQL_PORT_3306_TCP_ADDR"],
'PORT': 3306,
'USER': os.environ["MYSQL_ENV_MYSQL_USER"],
'PASSWORD': os.environ["MYSQL_ENV_MYSQL_ROOT_PASSWORD"]
}
}
REDIS_CACHE = {
"host": os.environ["REDIS_PORT_6379_TCP_ADDR"],
"port": 6379,
"db": 1
}
REDIS_QUEUE = {
"host": os.environ["REDIS_PORT_6379_TCP_ADDR"],
"port": 6379,
"db": 2
}
DEBUG = False
ALLOWED_HOSTS = ['*']
TEST_CASE_DIR = "/test_case"
LOG_PATH = "log/"

View File

@ -1,8 +1,7 @@
# coding=utf-8
""" """
Django settings for oj project. Django settings for oj project.
Generated by 'django-admin startproject' using Django 1.8. Generated by 'django-admin startproject' using Django 1.11.
For more information on this file, see For more information on this file, see
https://docs.djangoproject.com/en/1.8/topics/settings/ https://docs.djangoproject.com/en/1.8/topics/settings/
@ -10,59 +9,54 @@ https://docs.djangoproject.com/en/1.8/topics/settings/
For the full list of settings and their values, see For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.8/ref/settings/ https://docs.djangoproject.com/en/1.8/ref/settings/
""" """
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
import os import os
from copy import deepcopy
if os.environ.get("OJ_ENV") == "production":
from .production_settings import *
else:
from .dev_settings import *
from .custom_settings import * from .custom_settings import *
# 判断运行环境
ENV = os.environ.get("oj_env", "local")
if ENV == "local":
from .local_settings import *
elif ENV == "server":
from .server_settings import *
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Applications
# Quick-start development settings - unsuitable for production VENDOR_APPS = (
# See https://docs.djangoproject.com/en/1.8/howto/deployment/checklist/
# Application definition
INSTALLED_APPS = (
'django.contrib.auth', 'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.contenttypes',
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework',
)
LOCAL_APPS = (
'account', 'account',
'announcement', 'announcement',
'conf', 'conf',
'problem', 'problem',
'contest', 'contest',
'utils', 'utils',
'submission',
'rest_framework', 'options',
'judge',
) )
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware', 'account.middleware.APITokenAuthMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'account.middleware.AdminRoleRequiredMiddleware', 'account.middleware.AdminRoleRequiredMiddleware',
'account.middleware.SessionSecurityMiddleware', 'account.middleware.SessionRecordMiddleware',
'account.middleware.TimezoneMiddleware' # 'account.middleware.LogSqlMiddleware',
) )
ROOT_URLCONF = 'oj.urls' ROOT_URLCONF = 'oj.urls'
TEMPLATES = [ TEMPLATES = [
@ -80,9 +74,26 @@ TEMPLATES = [
}, },
}, },
] ]
WSGI_APPLICATION = 'oj.wsgi.application' WSGI_APPLICATION = 'oj.wsgi.application'
# Password validation
# https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [
{
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
},
]
# Internationalization # Internationalization
# https://docs.djangoproject.com/en/1.8/topics/i18n/ # https://docs.djangoproject.com/en/1.8/topics/i18n/
@ -96,35 +107,34 @@ USE_L10N = True
USE_TZ = True USE_TZ = True
# Static files (CSS, JavaScript, Images) # Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.8/howto/static-files/ # https://docs.djangoproject.com/en/1.8/howto/static-files/
STATIC_URL = '/static/' STATIC_URL = '/public/'
AUTH_USER_MODEL = 'account.User' AUTH_USER_MODEL = 'account.User'
TEST_CASE_DIR = os.path.join(DATA_DIR, "test_case")
LOG_PATH = os.path.join(DATA_DIR, "log")
AVATAR_URI_PREFIX = "/public/avatar"
AVATAR_UPLOAD_DIR = f"{DATA_DIR}{AVATAR_URI_PREFIX}"
UPLOAD_PREFIX = "/public/upload"
UPLOAD_DIR = f"{DATA_DIR}{UPLOAD_PREFIX}"
STATICFILES_DIRS = [os.path.join(DATA_DIR, "public")]
LOGGING = { LOGGING = {
'version': 1, 'version': 1,
'disable_existing_loggers': True, 'disable_existing_loggers': False,
'formatters': { 'formatters': {
'standard': { 'standard': {
'format': '%(asctime)s [%(threadName)s:%(thread)d] [%(name)s:%(lineno)d] [%(module)s:%(funcName)s] [%(levelname)s]- %(message)s'} 'format': '[%(asctime)s] - [%(levelname)s] - [%(name)s:%(lineno)d] - %(message)s',
# 日志格式 'datefmt': '%Y-%m-%d %H:%M:%S'
}
}, },
'handlers': { 'handlers': {
'django_error': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': os.path.join(LOG_PATH, 'django.log'),
'formatter': 'standard'
},
'app_info': {
'level': 'DEBUG',
'class': 'logging.handlers.RotatingFileHandler',
'filename': os.path.join(LOG_PATH, 'app_info.log'),
'formatter': 'standard'
},
'console': { 'console': {
'level': 'DEBUG', 'level': 'DEBUG',
'class': 'logging.StreamHandler', 'class': 'logging.StreamHandler',
@ -132,25 +142,24 @@ LOGGING = {
} }
}, },
'loggers': { 'loggers': {
'app_info': {
'handlers': ['app_info', "console"],
'level': 'DEBUG',
'propagate': True
},
'django.request': { 'django.request': {
'handlers': ['django_error', 'console'], 'handlers': ['console'],
'level': 'DEBUG', 'level': 'ERROR',
'propagate': True, 'propagate': True,
}, },
'django.db.backends': { 'django.db.backends': {
'handlers': ['console'], 'handlers': ['console'],
'level': 'ERROR', 'level': 'ERROR',
'propagate': True, 'propagate': True,
},
'': {
'handlers': ['console'],
'level': 'WARNING',
'propagate': True,
} }
}, },
} }
REST_FRAMEWORK = { REST_FRAMEWORK = {
'TEST_REQUEST_DEFAULT_FORMAT': 'json', 'TEST_REQUEST_DEFAULT_FORMAT': 'json',
'DEFAULT_RENDERER_CLASSES': ( 'DEFAULT_RENDERER_CLASSES': (
@ -158,17 +167,37 @@ REST_FRAMEWORK = {
) )
} }
# for celery REDIS_URL = "redis://%s:%s" % (REDIS_CONF["host"], REDIS_CONF["port"])
BROKER_URL = 'redis://%s:%s/%s' % (REDIS_QUEUE["host"], str(REDIS_QUEUE["port"]), str(REDIS_QUEUE["db"]))
def redis_config(db):
def make_key(key, key_prefix, version):
return key
return {
"BACKEND": "utils.cache.MyRedisCache",
"LOCATION": f"{REDIS_URL}/{db}",
"TIMEOUT": None,
"KEY_PREFIX": "",
"KEY_FUNCTION": make_key
}
CACHES = {
"default": redis_config(db=1)
}
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
CELERY_RESULT_BACKEND = f"{REDIS_URL}/2"
BROKER_URL = f"{REDIS_URL}/3"
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
CELERY_ACCEPT_CONTENT = ["json"] CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json" CELERY_TASK_SERIALIZER = "json"
DATABASE_ROUTERS = ['oj.db_router.DBRouter']
IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/')
# 用于限制用户恶意提交大量代码 # 用于限制用户恶意提交大量代码
TOKEN_BUCKET_DEFAULT_CAPACITY = 50 TOKEN_BUCKET_DEFAULT_CAPACITY = 20
# 单位:每分钟 # 单位:每分钟
TOKEN_BUCKET_FILL_RATE = 2 TOKEN_BUCKET_FILL_RATE = 2

View File

@ -3,12 +3,15 @@ from django.conf.urls import include, url
urlpatterns = [ urlpatterns = [
url(r"^api/", include("account.urls.oj")), url(r"^api/", include("account.urls.oj")),
url(r"^api/admin/", include("account.urls.admin")), url(r"^api/admin/", include("account.urls.admin")),
url(r"^api/account/", include("account.urls.user")), url(r"^api/", include("announcement.urls.oj")),
url(r"^api/admin/", include("announcement.urls.admin")), url(r"^api/admin/", include("announcement.urls.admin")),
url(r"^api/", include("conf.urls.oj")), url(r"^api/", include("conf.urls.oj")),
url(r"^api/admin/", include("conf.urls.admin")), url(r"^api/admin/", include("conf.urls.admin")),
url(r"^api/", include("problem.urls.oj")), url(r"^api/", include("problem.urls.oj")),
url(r"^api/admin/", include("problem.urls.admin")), url(r"^api/admin/", include("problem.urls.admin")),
url(r"^api/", include("contest.urls.oj")),
url(r"^api/admin/", include("contest.urls.admin")), url(r"^api/admin/", include("contest.urls.admin")),
url(r"^api/", include("contest.urls.oj")) url(r"^api/", include("submission.urls.oj")),
url(r"^api/admin/", include("submission.urls.admin")),
url(r"^api/admin/", include("utils.urls")),
] ]

0
options/__init__.py Normal file
View File

View File

@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-23 08:11
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='SysOptions',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('key', models.CharField(db_index=True, max_length=128, unique=True)),
('value', django.contrib.postgres.fields.jsonb.JSONField()),
],
),
]

View File

7
options/models.py Normal file
View File

@ -0,0 +1,7 @@
from django.db import models
from utils.models import JSONField
class SysOptions(models.Model):
key = models.CharField(max_length=128, unique=True, db_index=True)
value = JSONField()

185
options/options.py Normal file
View File

@ -0,0 +1,185 @@
import os
from django.core.cache import cache
from django.db import transaction, IntegrityError
from utils.constants import CacheKey
from utils.shortcuts import rand_str
from .models import SysOptions as SysOptionsModel
def default_token():
token = os.environ.get("JUDGE_SERVER_TOKEN")
return token if token else rand_str()
class OptionKeys:
website_base_url = "website_base_url"
website_name = "website_name"
website_name_shortcut = "website_name_shortcut"
website_footer = "website_footer"
allow_register = "allow_register"
submission_list_show_all = "submission_list_show_all"
smtp_config = "smtp_config"
judge_server_token = "judge_server_token"
class OptionDefaultValue:
website_base_url = "http://127.0.0.1"
website_name = "Online Judge"
website_name_shortcut = "oj"
website_footer = "Online Judge Footer"
allow_register = True
submission_list_show_all = True
smtp_config = {}
judge_server_token = default_token
class _SysOptionsMeta(type):
@classmethod
def _set_cache(mcs, option_key, option_value):
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
@classmethod
def _del_cache(mcs, option_key):
cache.delete(f"{CacheKey.option}:{option_key}")
@classmethod
def _get_keys(cls):
return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
def rebuild_cache(cls):
for key in cls._get_keys():
# get option 的时候会写 cache 的
cls._get_option(key, use_cache=False)
@classmethod
def _init_option(mcs):
for item in mcs._get_keys():
if not SysOptionsModel.objects.filter(key=item).exists():
default_value = getattr(OptionDefaultValue, item)
if callable(default_value):
default_value = default_value()
try:
SysOptionsModel.objects.create(key=item, value=default_value)
except IntegrityError:
pass
@classmethod
def _get_option(mcs, option_key, use_cache=True):
try:
if use_cache:
option = cache.get(f"{CacheKey.option}:{option_key}")
if option:
return option
option = SysOptionsModel.objects.get(key=option_key)
value = option.value
mcs._set_cache(option_key, value)
return value
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._get_option(option_key, use_cache=use_cache)
@classmethod
def _set_option(mcs, option_key: str, option_value):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
option.value = option_value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
mcs._set_option(option_key, option_value)
@classmethod
def _increment(mcs, option_key):
try:
with transaction.atomic():
option = SysOptionsModel.objects.select_for_update().get(key=option_key)
value = option.value + 1
option.value = value
option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist:
mcs._init_option()
return mcs._increment(option_key)
@classmethod
def set_options(mcs, options):
for key, value in options:
mcs._set_option(key, value)
@classmethod
def get_options(mcs, keys):
result = {}
for key in keys:
result[key] = mcs._get_option(key)
return result
@property
def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url)
@website_base_url.setter
def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value)
@property
def website_name(cls):
return cls._get_option(OptionKeys.website_name)
@website_name.setter
def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value)
@property
def website_name_shortcut(cls):
return cls._get_option(OptionKeys.website_name_shortcut)
@website_name_shortcut.setter
def website_name_shortcut(cls, value):
cls._set_option(OptionKeys.website_name_shortcut, value)
@property
def website_footer(cls):
return cls._get_option(OptionKeys.website_footer)
@website_footer.setter
def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value)
@property
def allow_register(cls):
return cls._get_option(OptionKeys.allow_register)
@allow_register.setter
def allow_register(cls, value):
cls._set_option(OptionKeys.allow_register, value)
@property
def submission_list_show_all(cls):
return cls._get_option(OptionKeys.submission_list_show_all)
@submission_list_show_all.setter
def submission_list_show_all(cls, value):
cls._set_option(OptionKeys.submission_list_show_all, value)
@property
def smtp_config(cls):
return cls._get_option(OptionKeys.smtp_config)
@smtp_config.setter
def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value)
@property
def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token)
@judge_server_token.setter
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
class SysOptions(metaclass=_SysOptionsMeta):
pass

1
options/tests.py Normal file
View File

@ -0,0 +1 @@
# Create your tests here.

1
options/views.py Normal file
View File

@ -0,0 +1 @@
# Create your views here.

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-05-01 06:37
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0003_auto_20170217_0820'),
]
operations = [
migrations.AlterField(
model_name='contestproblem',
name='total_accepted_number',
field=models.BigIntegerField(default=0),
),
migrations.AlterField(
model_name='contestproblem',
name='total_submit_number',
field=models.BigIntegerField(default=0),
),
migrations.AlterField(
model_name='problem',
name='total_accepted_number',
field=models.BigIntegerField(default=0),
),
migrations.AlterField(
model_name='problem',
name='total_submit_number',
field=models.BigIntegerField(default=0),
),
]

View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-08-15 12:58
from __future__ import unicode_literals
from django.db import migrations
import jsonfield.fields
class Migration(migrations.Migration):
dependencies = [
('problem', '0004_auto_20170501_0637'),
]
operations = [
migrations.AddField(
model_name='contestproblem',
name='statistic_info',
field=jsonfield.fields.JSONField(default={}),
),
migrations.AddField(
model_name='problem',
name='statistic_info',
field=jsonfield.fields.JSONField(default={}),
),
]

View File

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-08-23 09:18
from __future__ import unicode_literals
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0005_auto_20170815_1258'),
]
operations = [
migrations.RenameField(
model_name='contestproblem',
old_name='total_accepted_number',
new_name='accepted_number',
),
migrations.RenameField(
model_name='contestproblem',
old_name='total_submit_number',
new_name='submission_number',
),
migrations.RenameField(
model_name='problem',
old_name='total_accepted_number',
new_name='accepted_number',
),
migrations.RenameField(
model_name='problem',
old_name='total_submit_number',
new_name='submission_number',
),
]

View File

@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-09-23 13:18
from __future__ import unicode_literals
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('contest', '0005_auto_20170823_0918'),
('problem', '0006_auto_20170823_0918'),
]
operations = [
migrations.AddField(
model_name='contestproblem',
name='total_score',
field=models.IntegerField(blank=True, default=0),
),
migrations.AddField(
model_name='problem',
name='total_score',
field=models.IntegerField(blank=True, default=0),
),
migrations.AlterUniqueTogether(
name='contestproblem',
unique_together=set([]),
),
migrations.RemoveField(
model_name='contestproblem',
name='contest',
),
migrations.RemoveField(
model_name='contestproblem',
name='created_by',
),
migrations.RemoveField(
model_name='contestproblem',
name='tags',
),
migrations.AddField(
model_name='problem',
name='contest',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'),
preserve_default=False,
),
migrations.AddField(
model_name='problem',
name='is_public',
field=models.BooleanField(default=False),
),
migrations.AlterField(
model_name='problem',
name='_id',
field=models.CharField(db_index=True, max_length=24),
),
migrations.AlterUniqueTogether(
name='problem',
unique_together=set([('_id', 'contest')]),
),
migrations.DeleteModel(
name='ContestProblem',
),
]

View File

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-10-11 12:14
from __future__ import unicode_literals
import django.contrib.postgres.fields.jsonb
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('problem', '0008_auto_20170923_1318'),
]
operations = [
migrations.AlterField(
model_name='problem',
name='languages',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='samples',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='statistic_info',
field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
),
migrations.AlterField(
model_name='problem',
name='template',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterField(
model_name='problem',
name='test_case_score',
field=django.contrib.postgres.fields.jsonb.JSONField(),
),
migrations.AlterModelOptions(
name='problem',
options={'ordering': ('create_time',)},
),
]

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-11-16 12:42
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0009_auto_20171011_1214'),
]
operations = [
migrations.AddField(
model_name='problem',
name='spj_compile_ok',
field=models.BooleanField(default=False),
),
]

View File

@ -1,5 +1,5 @@
from django.db import models from django.db import models
from jsonfield import JSONField from utils.models import JSONField
from account.models import User from account.models import User
from contest.models import Contest from contest.models import Contest
@ -18,7 +18,18 @@ class ProblemRuleType(object):
OI = "OI" OI = "OI"
class AbstractProblem(models.Model): class ProblemDifficulty(object):
High = "High"
Mid = "Mid"
Low = "Low"
class Problem(models.Model):
# display ID
_id = models.CharField(max_length=24, db_index=True)
contest = models.ForeignKey(Contest, null=True, blank=True)
# for contest problem
is_public = models.BooleanField(default=False)
title = models.CharField(max_length=128) title = models.CharField(max_length=128)
# HTML # HTML
description = RichTextField() description = RichTextField()
@ -27,6 +38,7 @@ class AbstractProblem(models.Model):
# [{input: "test", output: "123"}, {input: "test123", output: "456"}] # [{input: "test", output: "123"}, {input: "test123", output: "456"}]
samples = JSONField() samples = JSONField()
test_case_id = models.CharField(max_length=32) test_case_id = models.CharField(max_length=32)
# [{"input_name": "1.in", "output_name": "1.out", "score": 0}]
test_case_score = JSONField() test_case_score = JSONField()
hint = RichTextField(blank=True, null=True) hint = RichTextField(blank=True, null=True)
languages = JSONField() languages = JSONField()
@ -44,37 +56,28 @@ class AbstractProblem(models.Model):
spj_language = models.CharField(max_length=32, blank=True, null=True) spj_language = models.CharField(max_length=32, blank=True, null=True)
spj_code = models.TextField(blank=True, null=True) spj_code = models.TextField(blank=True, null=True)
spj_version = models.CharField(max_length=32, blank=True, null=True) spj_version = models.CharField(max_length=32, blank=True, null=True)
spj_compile_ok = models.BooleanField(default=False)
rule_type = models.CharField(max_length=32) rule_type = models.CharField(max_length=32)
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)
difficulty = models.CharField(max_length=32) difficulty = models.CharField(max_length=32)
tags = models.ManyToManyField(ProblemTag) tags = models.ManyToManyField(ProblemTag)
source = models.CharField(max_length=200, blank=True, null=True) source = models.CharField(max_length=200, blank=True, null=True)
total_submit_number = models.IntegerField(default=0) # for OI mode
total_accepted_number = models.IntegerField(default=0) total_score = models.IntegerField(default=0, blank=True)
submission_number = models.BigIntegerField(default=0)
accepted_number = models.BigIntegerField(default=0)
# {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default=dict)
class Meta: class Meta:
db_table = "problem" db_table = "problem"
abstract = True unique_together = (("_id", "contest"),)
ordering = ("create_time",)
def add_submission_number(self): def add_submission_number(self):
self.accepted_problem_number = models.F("total_submit_number") + 1 self.submission_number = models.F("submission_number") + 1
self.save() self.save(update_fields=["submission_number"])
def add_ac_number(self): def add_ac_number(self):
self.accepted_problem_number = models.F("total_accepted_number") + 1 self.accepted_number = models.F("accepted_number") + 1
self.save() self.save(update_fields=["accepted_number"])
class Problem(AbstractProblem):
_id = models.CharField(max_length=24, unique=True, db_index=True)
class ContestProblem(AbstractProblem):
_id = models.CharField(max_length=24, db_index=True)
contest = models.ForeignKey(Contest)
# 是否已经公开了题目,防止重复公开
is_public = models.BooleanField(default=False)
class Meta:
db_table = "contest_problem"
unique_together = (("_id", "contest"), )

View File

@ -4,6 +4,7 @@ from judge.languages import language_names, spj_language_names
from utils.api import DateTimeTZField, UsernameSerializer, serializers from utils.api import DateTimeTZField, UsernameSerializer, serializers
from .models import Problem, ProblemRuleType, ProblemTag from .models import Problem, ProblemRuleType, ProblemTag
from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form): class TestCaseUploadForm(forms.Form):
@ -12,8 +13,8 @@ class TestCaseUploadForm(forms.Form):
class CreateSampleSerializer(serializers.Serializer): class CreateSampleSerializer(serializers.Serializer):
input = serializers.CharField() input = serializers.CharField(trim_whitespace=False)
output = serializers.CharField() output = serializers.CharField(trim_whitespace=False)
class CreateTestCaseScoreSerializer(serializers.Serializer): class CreateTestCaseScoreSerializer(serializers.Serializer):
@ -39,7 +40,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
input_description = serializers.CharField() input_description = serializers.CharField()
output_description = serializers.CharField() output_description = serializers.CharField()
samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False) samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False)
test_case_id = serializers.CharField(min_length=32, max_length=32) test_case_id = serializers.CharField(max_length=32)
test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=False) test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=False)
time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60) time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60)
memory_limit = serializers.IntegerField(min_value=1, max_value=1024) memory_limit = serializers.IntegerField(min_value=1, max_value=1024)
@ -49,6 +50,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
spj = serializers.BooleanField() spj = serializers.BooleanField()
spj_language = serializers.ChoiceField(choices=spj_language_names, allow_blank=True, allow_null=True) spj_language = serializers.ChoiceField(choices=spj_language_names, allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True) spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField() visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=[Difficulty.LOW, Difficulty.MID, Difficulty.HIGH]) difficulty = serializers.ChoiceField(choices=[Difficulty.LOW, Difficulty.MID, Difficulty.HIGH])
tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False)
@ -68,12 +70,23 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
contest_id = serializers.IntegerField() contest_id = serializers.IntegerField()
class EditContestProblemSerializer(CreateOrEditProblemSerializer):
id = serializers.IntegerField()
contest_id = serializers.IntegerField()
class TagSerializer(serializers.ModelSerializer): class TagSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ProblemTag model = ProblemTag
fields = "__all__"
class ProblemSerializer(serializers.ModelSerializer): class CompileSPJSerializer(serializers.Serializer):
spj_language = serializers.ChoiceField(choices=spj_language_names)
spj_code = serializers.CharField()
class BaseProblemSerializer(serializers.ModelSerializer):
samples = serializers.JSONField() samples = serializers.JSONField()
test_case_score = serializers.JSONField() test_case_score = serializers.JSONField()
languages = serializers.JSONField() languages = serializers.JSONField()
@ -82,6 +95,100 @@ class ProblemSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField() create_time = DateTimeTZField()
last_update_time = DateTimeTZField() last_update_time = DateTimeTZField()
created_by = UsernameSerializer() created_by = UsernameSerializer()
statistic_info = serializers.JSONField()
class ProblemAdminSerializer(BaseProblemSerializer):
class Meta:
model = Problem
fields = "__all__"
class ContestProblemAdminSerializer(BaseProblemSerializer):
class Meta:
model = Problem
fields = "__all__"
class ProblemSerializer(BaseProblemSerializer):
template = serializers.SerializerMethodField()
def get_template(self, obj):
ret = {}
for lang, code in obj.template.items():
ret[lang] = parse_problem_template(code)["template"]
return ret
class Meta: class Meta:
model = Problem model = Problem
exclude = ("contest", "test_case_score", "test_case_id", "visible", "is_public",
"template", "spj_code", "spj_version", "spj_compile_ok")
class ContestProblemSerializer(BaseProblemSerializer):
class Meta:
model = Problem
exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty")
class ContestProblemSafeSerializer(BaseProblemSerializer):
class Meta:
model = Problem
exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty",
"submission_number", "accepted_number", "statistic_info")
class ContestProblemMakePublicSerializer(serializers.Serializer):
id = serializers.IntegerField()
display_id = serializers.CharField(max_length=32)
class ExportProblemSerializer(serializers.ModelSerializer):
description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField()
time_limit = serializers.SerializerMethodField()
memory_limit = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField()
def get_description(self, obj):
return {"format": "html", "value": obj.description}
def get_input_description(self, obj):
return {"format": "html", "value": obj.input_description}
def get_output_description(self, obj):
return {"format": "html", "value": obj.output_description}
def get_hint(self, obj):
return {"format": "html", "value": obj.hint}
def get_test_case_score(self, obj):
return obj.test_case_score if obj.rule_type == ProblemRuleType.OI else []
def get_time_limit(self, obj):
return {"unit": "ms", "value": obj.time_limit}
def get_memory_limit(self, obj):
return {"unit": "MB", "value": obj.memory_limit}
def get_spj(self, obj):
return {"enabled": obj.spj,
"code": obj.spj_code if obj.spj else None,
"language": obj.spj_language if obj.spj else None}
def get_template(self, obj):
ret = {}
for k, v in obj.template.items():
ret[k] = parse_problem_template(v)
return ret
class Meta:
model = Problem
fields = ("_id", "title", "description",
"input_description", "output_description",
"test_case_score", "hint", "time_limit", "memory_limit", "samples",
"template", "spj", "rule_type", "source", "template")

View File

@ -1,6 +1,8 @@
import copy import copy
import hashlib
import os import os
import shutil import shutil
from datetime import timedelta
from zipfile import ZipFile from zipfile import ZipFile
from django.conf import settings from django.conf import settings
@ -8,7 +10,59 @@ from django.conf import settings
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from .models import ProblemTag from .models import ProblemTag
from .views.admin import TestCaseUploadAPI from .models import Problem, ProblemRuleType
from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA
from .views.admin import TestCaseAPI
from .utils import parse_problem_template
DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test",
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
class ProblemCreateTestBase(APITestCase):
@staticmethod
def add_problem(problem_data, created_by):
data = copy.deepcopy(problem_data)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
raise ValueError("Invalid spj")
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
raise ValueError("invalid score")
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = created_by
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
for item in tags:
try:
tag = ProblemTag.objects.get(name=item)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
return problem
class ProblemTagListAPITest(APITestCase): class ProblemTagListAPITest(APITestCase):
@ -17,17 +71,20 @@ class ProblemTagListAPITest(APITestCase):
ProblemTag.objects.create(name="name2") ProblemTag.objects.create(name="name2")
resp = self.client.get(self.reverse("problem_tag_list_api")) resp = self.client.get(self.reverse("problem_tag_list_api"))
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(resp.data["data"], ["name1", "name2"]) resp_data = resp.data["data"]
self.assertEqual(resp_data[0]["name"], "name1")
self.assertEqual(resp_data[1]["name"], "name2")
class TestCaseUploadAPITest(APITestCase): class TestCaseUploadAPITest(APITestCase):
def setUp(self): def setUp(self):
self.api = TestCaseUploadAPI() self.api = TestCaseAPI()
self.url = self.reverse("test_case_upload_api") self.url = self.reverse("test_case_api")
self.create_super_admin() self.create_super_admin()
def test_filter_file_name(self): def test_filter_file_name(self):
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False), ["1.in", "1.out"]) self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), []) self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"]) self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
@ -76,19 +133,11 @@ class TestCaseUploadAPITest(APITestCase):
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end") self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
class ProblemAPITest(APITestCase): class ProblemAdminAPITest(APITestCase):
def setUp(self): def setUp(self):
self.url = self.reverse("problem_api") self.url = self.reverse("problem_admin_api")
self.create_super_admin() self.create_super_admin()
self.data = {"_id": "A-110", "title": "test", "description": "<p>test</p>", "input_description": "test", self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
"output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
"visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
"samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
"spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}],
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}
def test_create_problem(self): def test_create_problem(self):
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
@ -128,3 +177,127 @@ class ProblemAPITest(APITestCase):
data["id"] = problem_id data["id"] = problem_id
resp = self.client.put(self.url, data=data) resp = self.client.put(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
class ProblemAPITest(ProblemCreateTestBase):
def setUp(self):
self.url = self.reverse("problem_api")
admin = self.create_admin(login=False)
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.create_user("test", "test123")
def test_get_problem_list(self):
resp = self.client.get(f"{self.url}?limit=10")
self.assertSuccess(resp)
def get_one_problem(self):
resp = self.client.get(self.url + "?id=" + self.problem._id)
self.assertSuccess(resp)
class ContestProblemAdminTest(APITestCase):
def setUp(self):
self.url = self.reverse("contest_problem_admin_api")
self.create_admin()
self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
def test_create_contest_problem(self):
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = self.contest["id"]
resp = self.client.post(self.url, data=data)
self.assertSuccess(resp)
return resp.data["data"]
def test_get_contest_problem(self):
self.test_create_contest_problem()
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["results"]), 1)
def test_get_one_contest_problem(self):
contest_problem = self.test_create_contest_problem()
contest_id = self.contest["id"]
problem_id = contest_problem["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp)
class ContestProblemTest(ProblemCreateTestBase):
def setUp(self):
admin = self.create_admin()
url = self.reverse("contest_admin_api")
contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
contest_data["password"] = ""
contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
self.contest = self.client.post(url, data=contest_data).data["data"]
self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
self.problem.contest_id = self.contest["id"]
self.problem.save()
self.url = self.reverse("contest_problem_api")
def test_admin_get_contest_problem_list(self):
contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1)
def test_admin_get_one_contest_problem(self):
contest_id = self.contest["id"]
problem_id = self.problem._id
resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
self.assertSuccess(resp)
def test_regular_user_get_not_started_contest_problem(self):
self.create_user("test", "test123")
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."})
def test_reguar_user_get_started_contest_problem(self):
self.create_user("test", "test123")
contest = Contest.objects.first()
contest.start_time = contest.start_time - timedelta(hours=1)
contest.save()
resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
self.assertSuccess(resp)
class ParseProblemTemplateTest(APITestCase):
def test_parse(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//TEMPLATE BEGIN
bbb
//TEMPLATE END
//APPEND BEGIN
ccc
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "bbb\n")
self.assertEqual(ret["append"], "ccc\n")
def test_parse1(self):
template_str = """
//PREPEND BEGIN
aaa
//PREPEND END
//APPEND BEGIN
ccc
//APPEND END
//APPEND BEGIN
ddd
//APPEND END
"""
ret = parse_problem_template(template_str)
self.assertEqual(ret["prepend"], "aaa\n")
self.assertEqual(ret["template"], "")
self.assertEqual(ret["append"], "ccc\n")

View File

@ -1,9 +1,12 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseUploadAPI from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView
from ..views.admin import CompileSPJAPI
urlpatterns = [ urlpatterns = [
url(r"^test_case/upload/?$", TestCaseUploadAPI.as_view(), name="test_case_upload_api"), url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"),
url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"), url(r"^compile_spj/?$", CompileSPJAPI.as_view(), name="compile_spj"),
url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api") url(r"^problem/?$", ProblemAPI.as_view(), name="problem_admin_api"),
url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"),
url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"),
] ]

View File

@ -1,7 +1,10 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.oj import ProblemTagAPI from ..views.oj import ProblemTagAPI, ProblemAPI, ContestProblemAPI, PickOneAPI
urlpatterns = [ urlpatterns = [
url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api") url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api"),
url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"),
url(r"^pickone/?$", PickOneAPI.as_view(), name="pick_one_api"),
url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api"),
] ]

10
problem/utils.py Normal file
View File

@ -0,0 +1,10 @@
import re
def parse_problem_template(template_str):
prepend = re.findall("//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str)
template = re.findall("//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str)
append = re.findall("//APPEND BEGIN\n([\s\S]+?)//APPEND END", template_str)
return {"prepend": prepend[0] if prepend else "",
"template": template[0] if template else "",
"append": append[0] if append else ""}

View File

@ -1,22 +1,27 @@
import hashlib import hashlib
import json import json
import os import os
import shutil
import zipfile import zipfile
from wsgiref.util import FileWrapper
from django.conf import settings from django.conf import settings
from django.http import StreamingHttpResponse
from account.decorators import problem_permission_required from account.decorators import problem_permission_required
from judge.dispatcher import SPJCompiler
from contest.models import Contest from contest.models import Contest
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str from utils.shortcuts import rand_str, natural_sort_key
from ..models import ContestProblem, Problem, ProblemRuleType, ProblemTag from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateContestProblemSerializer, from ..serializers import (CreateContestProblemSerializer, ContestProblemAdminSerializer, CompileSPJSerializer,
CreateProblemSerializer, EditProblemSerializer, CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
ProblemSerializer, TestCaseUploadForm) ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer)
class TestCaseUploadAPI(CSRFExemptAPIView): class TestCaseAPI(CSRFExemptAPIView):
request_parsers = () request_parsers = ()
def filter_name_list(self, name_list, spj): def filter_name_list(self, name_list, spj):
@ -30,7 +35,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
prefix += 1 prefix += 1
continue continue
else: else:
return sorted(ret) return sorted(ret, key=natural_sort_key)
else: else:
while True: while True:
in_name = str(prefix) + ".in" in_name = str(prefix) + ".in"
@ -41,7 +46,30 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
prefix += 1 prefix += 1
continue continue
else: else:
return sorted(ret) return sorted(ret, key=natural_sort_key)
@problem_permission_required
def get(self, request):
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("Parameter error, problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem does not exists")
test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if not os.path.isdir(test_case_dir):
return self.error("Test case does not exists")
name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
name_list.append("info")
file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
with zipfile.ZipFile(file_name, "w") as file:
for test_case in name_list:
file.write(f"{test_case_dir}/{test_case}", test_case)
response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")), content_type="application/zip")
response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
return response
@problem_permission_required @problem_permission_required
def post(self, request): def post(self, request):
@ -76,7 +104,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
content = zip_file.read(item).replace(b"\r\n", b"\n") content = zip_file.read(item).replace(b"\r\n", b"\n")
size_cache[item] = len(content) size_cache[item] = len(content)
if item.endswith(".out"): if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content).hexdigest() md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content) f.write(content)
test_case_info = {"spj": spj, "test_cases": {}} test_case_info = {"spj": spj, "test_cases": {}}
@ -109,44 +137,80 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj}) return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj})
class ProblemAPI(APIView): class CompileSPJAPI(APIView):
@validate_serializer(CompileSPJSerializer)
@problem_permission_required
def post(self, request):
data = request.data
spj_version = rand_str(8)
error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj()
if error:
return self.error(error)
else:
return self.success()
class ProblemBase(APIView):
def common_checks(self, request):
data = request.data
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return "Invalid spj"
if not data["spj_compile_ok"]:
return "SPJ code must be compiled successfully"
data["spj_version"] = hashlib.md5(
(data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
total_score = 0
for item in data["test_case_score"]:
if item["score"] <= 0:
return "Invalid score"
else:
total_score += item["score"]
data["total_score"] = total_score
data["created_by"] = request.user
data["languages"] = list(data["languages"])
@problem_permission_required
def delete(self, request):
id = request.GET.get("id")
if not id:
return self.error("Invalid parameter, id is requred")
try:
problem = Problem.objects.get(id=id)
except Problem.DoesNotExist:
return self.error("Problem does not exists")
if Submission.objects.filter(problem=problem).exists():
return self.error("Can't delete the problem as it has submissions")
d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if os.path.isdir(d):
shutil.rmtree(d, ignore_errors=True)
problem.delete()
return self.success()
class ProblemAPI(ProblemBase):
@validate_serializer(CreateProblemSerializer) @validate_serializer(CreateProblemSerializer)
@problem_permission_required @problem_permission_required
def post(self, request): def post(self, request):
data = request.data data = request.data
_id = data["_id"] _id = data["_id"]
if _id:
try:
Problem.objects.get(_id=_id)
return self.error("Display ID already exists")
except Problem.DoesNotExist:
pass
else:
data["_id"] = rand_str(8)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:
return self.error("Invalid spj")
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
for item in data["test_case_score"]:
if item["score"] <= 0:
return self.error("Invalid score")
# todo check filename and score info
data["created_by"] = request.user
tags = data.pop("tags")
data["languages"] = list(data["languages"])
problem = Problem.objects.create(**data)
if not _id: if not _id:
problem._id = str(problem.id) return self.error("Display ID is required")
problem.save() if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists():
return self.error("Display ID already exists")
error_info = self.common_checks(request)
if error_info:
return self.error(error_info)
# todo check filename and score info
tags = data.pop("tags")
problem = Problem.objects.create(**data)
for item in tags: for item in tags:
try: try:
@ -154,7 +218,7 @@ class ProblemAPI(APIView):
except ProblemTag.DoesNotExist: except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item) tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag) problem.tags.add(tag)
return self.success(ProblemSerializer(problem).data) return self.success(ProblemAdminSerializer(problem).data)
@problem_permission_required @problem_permission_required
def get(self, request): def get(self, request):
@ -165,17 +229,17 @@ class ProblemAPI(APIView):
problem = Problem.objects.get(id=problem_id) problem = Problem.objects.get(id=problem_id)
if not user.can_mgmt_all_problem() and problem.created_by != user: if not user.can_mgmt_all_problem() and problem.created_by != user:
return self.error("Problem does not exist") return self.error("Problem does not exist")
return self.success(ProblemSerializer(problem).data) return self.success(ProblemAdminSerializer(problem).data)
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist") return self.error("Problem does not exist")
problems = Problem.objects.all().order_by("-create_time") problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
if not user.can_mgmt_all_problem(): if not user.can_mgmt_all_problem():
problems = problems.filter(created_by=user) problems = problems.filter(created_by=user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
problems = problems.filter(title__contains=keyword) problems = problems.filter(title__contains=keyword)
return self.success(self.paginate_data(request, problems, ProblemSerializer)) return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))
@validate_serializer(EditProblemSerializer) @validate_serializer(EditProblemSerializer)
@problem_permission_required @problem_permission_required
@ -192,29 +256,17 @@ class ProblemAPI(APIView):
return self.error("Problem does not exist") return self.error("Problem does not exist")
_id = data["_id"] _id = data["_id"]
if _id: if not _id:
try: return self.error("Display ID is required")
Problem.objects.exclude(id=problem_id).get(_id=_id) if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
return self.error("Display ID already exists") return self.error("Display ID already exists")
except Problem.DoesNotExist:
pass
else:
data["_id"] = str(problem_id)
if data["spj"]: error_info = self.common_checks(request)
if not data["spj_language"] or not data["spj_code"]: if error_info:
return self.error("Invalid spj") return self.error(error_info)
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
for item in data["test_case_score"]:
if item["score"] <= 0:
return self.error("Invalid score")
# todo check filename and score info # todo check filename and score info
tags = data.pop("tags") tags = data.pop("tags")
data["languages"] = list(data["languages"])
for k, v in data.items(): for k, v in data.items():
setattr(problem, k, v) setattr(problem, k, v)
@ -231,11 +283,11 @@ class ProblemAPI(APIView):
return self.success() return self.success()
class ContestProblemAPI(APIView): class ContestProblemAPI(ProblemBase):
@validate_serializer(CreateContestProblemSerializer) @validate_serializer(CreateContestProblemSerializer)
@problem_permission_required
def post(self, request): def post(self, request):
data = request.data data = request.data
try: try:
contest = Contest.objects.get(id=data.pop("contest_id")) contest = Contest.objects.get(id=data.pop("contest_id"))
if request.user.is_admin() and contest.created_by != request.user: if request.user.is_admin() and contest.created_by != request.user:
@ -248,33 +300,19 @@ class ContestProblemAPI(APIView):
_id = data["_id"] _id = data["_id"]
if not _id: if not _id:
return self.error("Display id is required for contest problem") return self.error("Display ID is required")
try:
ContestProblem.objects.get(_id=_id, contest=contest) if Problem.objects.filter(_id=_id, contest=contest).exists():
return self.error("Duplicate Display id") return self.error("Duplicate Display id")
except ContestProblem.DoesNotExist:
pass
if data["spj"]: error_info = self.common_checks(request)
if not data["spj_language"] or not data["spj_code"]: if error_info:
return self.error("Invalid spj") return self.error(error_info)
data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
else:
data["spj_language"] = None
data["spj_code"] = None
if data["rule_type"] == ProblemRuleType.OI:
for item in data["test_case_score"]:
if item["score"] <= 0:
return self.error("Invalid score")
# todo check filename and score info # todo check filename and score info
data["created_by"] = request.user
data["contest"] = contest data["contest"] = contest
tags = data.pop("tags") tags = data.pop("tags")
data["languages"] = list(data["languages"]) problem = Problem.objects.create(**data)
problem = ContestProblem.objects.create(**data)
for item in tags: for item in tags:
try: try:
@ -282,28 +320,109 @@ class ContestProblemAPI(APIView):
except ProblemTag.DoesNotExist: except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item) tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag) problem.tags.add(tag)
return self.success(ProblemSerializer(problem).data) return self.success(ContestProblemAdminSerializer(problem).data)
@problem_permission_required
def get(self, request): def get(self, request):
problem_id = request.GET.get("id") problem_id = request.GET.get("id")
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
user = request.user user = request.user
if problem_id: if problem_id:
try: try:
problem = ContestProblem.objects.get(id=problem_id) problem = Problem.objects.get(id=problem_id)
if user.is_admin() and problem.contest.created_by != user: if user.is_admin() and problem.contest.created_by != user:
return self.error("Problem does not exist") return self.error("Problem does not exist")
except ContestProblem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist") return self.error("Problem does not exist")
return self.success(ProblemSerializer(problem).data) return self.success(ProblemAdminSerializer(problem).data)
if not contest_id: if not contest_id:
return self.error("Contest id is required") return self.error("Contest id is required")
problems = ContestProblem.objects.filter(contest_id=contest_id).order_by("-create_time") problems = Problem.objects.filter(contest_id=contest_id).order_by("-create_time")
if user.is_admin(): if user.is_admin():
problems = problems.filter(contest__created_by=user) problems = problems.filter(contest__created_by=user)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
if keyword: if keyword:
problems = problems.filter(title__contains=keyword) problems = problems.filter(title__contains=keyword)
return self.success(self.paginate_data(request, problems, ProblemSerializer)) return self.success(self.paginate_data(request, problems, ContestProblemAdminSerializer))
@validate_serializer(EditContestProblemSerializer)
@problem_permission_required
def put(self, request):
data = request.data
try:
contest = Contest.objects.get(id=data.pop("contest_id"))
if request.user.is_admin() and contest.created_by != request.user:
return self.error("Contest does not exist")
except Contest.DoesNotExist:
return self.error("Contest does not exist")
if data["rule_type"] != contest.rule_type:
return self.error("Invalid rule type")
problem_id = data.pop("id")
user = request.user
try:
problem = Problem.objects.get(id=problem_id)
if not user.can_mgmt_all_problem() and problem.created_by != user:
return self.error("Problem does not exist")
except Problem.DoesNotExist:
return self.error("Problem does not exist")
_id = data["_id"]
if not _id:
return self.error("Display ID is required")
if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
return self.error("Display ID already exists")
error_info = self.common_checks(request)
if error_info:
return self.error(error_info)
# todo check filename and score info
tags = data.pop("tags")
data["languages"] = list(data["languages"])
for k, v in data.items():
setattr(problem, k, v)
problem.save()
problem.tags.remove(*problem.tags.all())
for tag in tags:
try:
tag = ProblemTag.objects.get(name=tag)
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=tag)
problem.tags.add(tag)
return self.success()
class MakeContestProblemPublicAPIView(APIView):
@validate_serializer(ContestProblemMakePublicSerializer)
@problem_permission_required
def post(self, request):
data = request.data
display_id = data.get("display_id")
if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists():
return self.error("Duplicate display ID")
try:
problem = Problem.objects.get(id=data["id"])
except Problem.DoesNotExist:
return self.error("Problem does not exist")
if not problem.contest or problem.is_public:
return self.error("Alreay be a public problem")
problem.is_public = True
problem.save()
# https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
tags = problem.tags.all()
problem.pk = None
problem.contest = None
problem._id = display_id
problem.submission_number = problem.accepted_number = 0
problem.statistic_info = {}
problem.save()
problem.tags.set(tags)
return self.success()

View File

@ -1,8 +1,116 @@
import random
from django.db.models import Q
from utils.api import APIView from utils.api import APIView
from account.decorators import check_contest_permission
from ..models import ProblemTag from ..models import ProblemTag, Problem, ProblemRuleType
from ..serializers import ProblemSerializer, TagSerializer
from ..serializers import ContestProblemSerializer, ContestProblemSafeSerializer
from contest.models import ContestRuleType
class ProblemTagAPI(APIView): class ProblemTagAPI(APIView):
def get(self, request): def get(self, request):
return self.success([item.name for item in ProblemTag.objects.all().order_by("id")]) return self.success(TagSerializer(ProblemTag.objects.all(), many=True).data)
class PickOneAPI(APIView):
def get(self, request):
problems = Problem.objects.filter(contest_id__isnull=True, visible=True)
count = problems.count()
if count == 0:
return self.error("No problem to pick")
return self.success(problems[random.randint(0, count - 1)]._id)
class ProblemAPI(APIView):
@staticmethod
def _add_problem_status(request, queryset_values):
if request.user.is_authenticated():
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {})
# paginate data
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values, ]
for problem in problems:
if problem["rule_type"] == ProblemRuleType.ACM:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
else:
problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status")
def get(self, request):
# 问题详情页
problem_id = request.GET.get("problem_id")
if problem_id:
try:
problem = Problem.objects.select_related("created_by") \
.get(_id=problem_id, contest_id__isnull=True, visible=True)
problem_data = ProblemSerializer(problem).data
self._add_problem_status(request, problem_data)
return self.success(problem_data)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
limit = request.GET.get("limit")
if not limit:
return self.error("Limit is needed")
problems = Problem.objects.select_related("created_by").filter(contest_id__isnull=True, visible=True)
# 按照标签筛选
tag_text = request.GET.get("tag")
if tag_text:
problems = problems.filter(tags__name=tag_text)
# 搜索的情况
keyword = request.GET.get("keyword", "").strip()
if keyword:
problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
# 难度筛选
difficulty = request.GET.get("difficulty")
if difficulty:
problems = problems.filter(difficulty=difficulty)
# 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemSerializer)
self._add_problem_status(request, data)
return self.success(data)
class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values):
if request.user.is_authenticated():
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in queryset_values:
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
@check_contest_permission(check_type="problems")
def get(self, request):
problem_id = request.GET.get("problem_id")
if problem_id:
try:
problem = Problem.objects.select_related("created_by").get(_id=problem_id,
contest=self.contest,
visible=True)
except Problem.DoesNotExist:
return self.error("Problem does not exist.")
if self.contest.problem_details_permission(request.user):
problem_data = ContestProblemSerializer(problem).data
self._add_problem_status(request, [problem_data, ])
else:
problem_data = ContestProblemSafeSerializer(problem).data
return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
if self.contest.problem_details_permission(request.user):
data = ContestProblemSerializer(contest_problems, many=True).data
self._add_problem_status(request, data)
else:
data = ContestProblemSafeSerializer(contest_problems, many=True).data
return self.success(data)

View File

@ -1,10 +0,0 @@
django==1.9.6
djangorestframework==3.4.0
otpauth
pillow
python-dateutil
celery
Envelopes
pytz
jsonfield
qrcode

View File

@ -21,7 +21,7 @@ print("running flake8...")
if os.system("flake8 --statistics ."): if os.system("flake8 --statistics ."):
exit() exit()
ret = os.system("coverage run ./manage.py test {module} --settings={setting}".format(module=test_module, setting=setting)) ret = os.system("coverage run --include=\"$PWD/*\" manage.py test {module} --settings={setting}".format(module=test_module, setting=setting))
if not ret and is_coverage: if not ret and is_coverage:
os.system("coverage html && open htmlcov/index.html") os.system("coverage html && open htmlcov/index.html")

0
submission/__init__.py Normal file
View File

View File

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-05-09 06:41
from __future__ import unicode_literals
from django.db import migrations, models
import jsonfield.fields
import utils.models
import utils.shortcuts
class Migration(migrations.Migration):
initial = True
dependencies = [
]
operations = [
migrations.CreateModel(
name='Submission',
fields=[
('id', models.CharField(db_index=True, default=utils.shortcuts.rand_str, max_length=32, primary_key=True, serialize=False)),
('contest_id', models.IntegerField(db_index=True, null=True)),
('problem_id', models.IntegerField(db_index=True)),
('created_time', models.DateTimeField(auto_now_add=True)),
('user_id', models.IntegerField(db_index=True)),
('code', utils.models.RichTextField()),
('result', models.IntegerField(default=6)),
('info', jsonfield.fields.JSONField(default={})),
('language', models.CharField(max_length=20)),
('shared', models.BooleanField(default=False)),
('accepted_time', models.IntegerField(blank=True, null=True)),
('accepted_info', jsonfield.fields.JSONField(default={})),
],
options={
'db_table': 'submission',
},
),
]

View File

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.9.6 on 2017-05-09 12:03
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('submission', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='submission',
name='code',
field=models.TextField(),
),
migrations.RenameField(
model_name='submission',
old_name='accepted_info',
new_name='statistic_info',
),
migrations.RemoveField(
model_name='submission',
name='accepted_time',
),
migrations.RenameField(
model_name='submission',
old_name='created_time',
new_name='create_time',
),
migrations.AlterModelOptions(
name='submission',
options={'ordering': ('-create_time',)},
)
]

View File

@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-08-26 03:47
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('submission', '0002_auto_20170509_1203'),
]
operations = [
migrations.AddField(
model_name='submission',
name='username',
field=models.CharField(default="", max_length=30),
preserve_default=False,
),
]

Some files were not shown because too many files have changed in this diff Show More