使用 TokenBucket 机制限制用户恶意提交代码

This commit is contained in:
virusdefender 2016-01-17 14:51:14 +08:00
parent 17ed05cb4c
commit 2097698560
3 changed files with 117 additions and 1 deletions

View File

@ -185,3 +185,8 @@ TEST_CASE_DIR = os.path.join(BASE_DIR, 'test_case/')
IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/') IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/')
# 用于限制用户恶意提交大量代码
TOKEN_BUCKET_DEFAULT_CAPACITY = 50
# 单位:每分钟
TOKEN_BUCKET_FILL_RATE = 2

View File

@ -3,8 +3,10 @@ import json
import logging import logging
import redis import redis
from django.shortcuts import render from django.shortcuts import render
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.conf import settings
from rest_framework.views import APIView from rest_framework.views import APIView
from account.decorators import login_required, super_admin_required from account.decorators import login_required, super_admin_required
@ -13,6 +15,7 @@ from problem.models import Problem
from contest.models import ContestProblem, Contest from contest.models import ContestProblem, Contest
from contest.decorators import check_user_contest_permission from contest.decorators import check_user_contest_permission
from utils.shortcuts import serializer_invalid_response, error_response, success_response, error_page, paginate from utils.shortcuts import serializer_invalid_response, error_response, success_response, error_page, paginate
from utils.throttling import TokenBucket, BucketController
from .tasks import _judge from .tasks import _judge
from .models import Submission from .models import Submission
from .serializers import (CreateSubmissionSerializer, SubmissionSerializer, from .serializers import (CreateSubmissionSerializer, SubmissionSerializer,
@ -30,6 +33,20 @@ class SubmissionAPIView(APIView):
--- ---
request_serializer: CreateSubmissionSerializer request_serializer: CreateSubmissionSerializer
""" """
controller = BucketController(user_id=request.user.id,
redis_conn=redis.Redis(host=settings.REDIS_CACHE["host"],
port=settings.REDIS_CACHE["port"],
db=settings.REDIS_CACHE["db"]),
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY)
bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
if bucket.consume():
controller.last_capacity -= 1
else:
return error_response(u"您提交的频率过快, 请等待%d" % int(bucket.expected_time() + 1))
serializer = CreateSubmissionSerializer(data=request.data) serializer = CreateSubmissionSerializer(data=request.data)
if serializer.is_valid(): if serializer.is_valid():
data = serializer.data data = serializer.data
@ -107,7 +124,7 @@ def problem_my_submissions_list_page(request, problem_id):
except Problem.DoesNotExist: except Problem.DoesNotExist:
return error_page(request, u"问题不存在") return error_page(request, u"问题不存在")
submissions = Submission.objects.filter(user_id=request.user.id, problem_id=problem.id,contest_id__isnull=True).\ submissions = Submission.objects.filter(user_id=request.user.id, problem_id=problem.id, contest_id__isnull=True). \
order_by("-create_time"). \ order_by("-create_time"). \
values("id", "result", "create_time", "accepted_answer_time", "language") values("id", "result", "create_time", "accepted_answer_time", "language")

94
utils/throttling.py Normal file
View File

@ -0,0 +1,94 @@
# coding=utf-8
import time
import redis
class TokenBucket(object):
def __init__(self, fill_rate, capacity, last_capacity, last_timestamp):
self.capacity = float(capacity)
self._left_tokens = last_capacity
self.fill_rate = float(fill_rate)
self.timestamp = last_timestamp
def consume(self, tokens=1):
if tokens <= self.tokens:
self._left_tokens -= tokens
return True
return False
def expected_time(self, tokens=1):
_tokens = self.tokens
tokens = max(tokens, _tokens)
return (tokens - _tokens) / self.fill_rate * 60
@property
def tokens(self):
if self._left_tokens < self.capacity:
now = time.time()
delta = self.fill_rate * ((now - self.timestamp) / 60)
self._left_tokens = min(self.capacity, self._left_tokens + delta)
self.timestamp = now
return self._left_tokens
class BucketController(object):
def __init__(self, user_id, redis_conn, default_capacity):
self.user_id = user_id
self.default_capacity = default_capacity
self.redis = redis_conn
self.key = "bucket_" + str(self.user_id)
@property
def last_capacity(self):
value = self.redis.hget(self.key, "last_capacity")
if value is None:
self.last_capacity = self.default_capacity
return self.default_capacity
return int(value)
@last_capacity.setter
def last_capacity(self, value):
self.redis.hset(self.key, "last_capacity", value)
@property
def last_timestamp(self):
value = self.redis.hget(self.key, "last_timestamp")
if value is None:
timestamp = int(time.time())
self.last_timestamp = timestamp
return timestamp
return int(value)
@last_timestamp.setter
def last_timestamp(self, value):
self.redis.hset(self.key, "last_timestamp", value)
"""
# token bucket 机制限制用户提交大量代码
# demo
success = failure = 0
current_user_id = 1
token_bucket_default_capacity = 50
token_bucket_fill_rate = 10
for i in range(5000):
controller = BucketController(user_id=current_user_id,
redis_conn=redis.Redis(),
default_capacity=token_bucket_default_capacity)
bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
capacity=token_bucket_default_capacity,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
time.sleep(0.05)
if bucket.consume():
success += 1
print i, ": Accepted"
controller.last_capacity -= 1
else:
failure += 1
print i, "Dropped, time left ", bucket.expected_time()
print success, failure
"""