diff --git a/oj/settings.py b/oj/settings.py index bc87af4a..b4173ec5 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -185,3 +185,8 @@ TEST_CASE_DIR = os.path.join(BASE_DIR, 'test_case/') IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/') +# 用于限制用户恶意提交大量代码 +TOKEN_BUCKET_DEFAULT_CAPACITY = 50 + +# 单位:每分钟 +TOKEN_BUCKET_FILL_RATE = 2 diff --git a/submission/views.py b/submission/views.py index 41f5e3b5..8763e1e3 100644 --- a/submission/views.py +++ b/submission/views.py @@ -3,8 +3,10 @@ import json import logging import redis + from django.shortcuts import render from django.core.paginator import Paginator +from django.conf import settings from rest_framework.views import APIView from account.decorators import login_required, super_admin_required @@ -13,6 +15,7 @@ from problem.models import Problem from contest.models import ContestProblem, Contest from contest.decorators import check_user_contest_permission from utils.shortcuts import serializer_invalid_response, error_response, success_response, error_page, paginate +from utils.throttling import TokenBucket, BucketController from .tasks import _judge from .models import Submission from .serializers import (CreateSubmissionSerializer, SubmissionSerializer, @@ -30,6 +33,20 @@ class SubmissionAPIView(APIView): --- request_serializer: CreateSubmissionSerializer """ + controller = BucketController(user_id=request.user.id, + redis_conn=redis.Redis(host=settings.REDIS_CACHE["host"], + port=settings.REDIS_CACHE["port"], + db=settings.REDIS_CACHE["db"]), + default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY) + bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE, + capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY, + last_capacity=controller.last_capacity, + last_timestamp=controller.last_timestamp) + if bucket.consume(): + controller.last_capacity -= 1 + else: + return error_response(u"您提交的频率过快, 请等待%d秒" % int(bucket.expected_time() + 1)) + serializer = CreateSubmissionSerializer(data=request.data) if serializer.is_valid(): data = serializer.data @@ -107,7 +124,7 @@ def problem_my_submissions_list_page(request, problem_id): except Problem.DoesNotExist: return error_page(request, u"问题不存在") - submissions = Submission.objects.filter(user_id=request.user.id, problem_id=problem.id,contest_id__isnull=True).\ + submissions = Submission.objects.filter(user_id=request.user.id, problem_id=problem.id, contest_id__isnull=True). \ order_by("-create_time"). \ values("id", "result", "create_time", "accepted_answer_time", "language") diff --git a/utils/throttling.py b/utils/throttling.py new file mode 100644 index 00000000..7d020ddf --- /dev/null +++ b/utils/throttling.py @@ -0,0 +1,94 @@ +# coding=utf-8 +import time +import redis + + +class TokenBucket(object): + def __init__(self, fill_rate, capacity, last_capacity, last_timestamp): + self.capacity = float(capacity) + self._left_tokens = last_capacity + self.fill_rate = float(fill_rate) + self.timestamp = last_timestamp + + def consume(self, tokens=1): + if tokens <= self.tokens: + self._left_tokens -= tokens + return True + return False + + def expected_time(self, tokens=1): + _tokens = self.tokens + tokens = max(tokens, _tokens) + return (tokens - _tokens) / self.fill_rate * 60 + + @property + def tokens(self): + if self._left_tokens < self.capacity: + now = time.time() + delta = self.fill_rate * ((now - self.timestamp) / 60) + self._left_tokens = min(self.capacity, self._left_tokens + delta) + self.timestamp = now + return self._left_tokens + + +class BucketController(object): + def __init__(self, user_id, redis_conn, default_capacity): + self.user_id = user_id + self.default_capacity = default_capacity + self.redis = redis_conn + self.key = "bucket_" + str(self.user_id) + + @property + def last_capacity(self): + value = self.redis.hget(self.key, "last_capacity") + if value is None: + self.last_capacity = self.default_capacity + return self.default_capacity + return int(value) + + @last_capacity.setter + def last_capacity(self, value): + self.redis.hset(self.key, "last_capacity", value) + + @property + def last_timestamp(self): + value = self.redis.hget(self.key, "last_timestamp") + if value is None: + timestamp = int(time.time()) + self.last_timestamp = timestamp + return timestamp + return int(value) + + @last_timestamp.setter + def last_timestamp(self, value): + self.redis.hset(self.key, "last_timestamp", value) + + +""" +# token bucket 机制限制用户提交大量代码 +# demo +success = failure = 0 +current_user_id = 1 +token_bucket_default_capacity = 50 +token_bucket_fill_rate = 10 + + +for i in range(5000): + controller = BucketController(user_id=current_user_id, + redis_conn=redis.Redis(), + default_capacity=token_bucket_default_capacity) + bucket = TokenBucket(fill_rate=token_bucket_fill_rate, + capacity=token_bucket_default_capacity, + last_capacity=controller.last_capacity, + last_timestamp=controller.last_timestamp) + + time.sleep(0.05) + if bucket.consume(): + success += 1 + print i, ": Accepted" + controller.last_capacity -= 1 + else: + failure += 1 + print i, "Dropped, time left ", bucket.expected_time() +print success, failure +""" \ No newline at end of file