Merge pull request #113 from QingdaoU/fix1

fix some issues
This commit is contained in:
李扬 2017-12-26 13:43:07 +08:00 committed by GitHub
commit d94a512edb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 287 additions and 258 deletions

View File

@ -15,4 +15,4 @@ RUN curl -L $(curl -s https://api.github.com/repos/QingdaoU/OnlineJudgeFE/rele
unzip dist.zip && \
rm dist.zip
CMD sh /app/deploy/run.sh
ENTRYPOINT /app/deploy/entrypoint.sh

View File

@ -1,33 +1,23 @@
import logging
from celery import shared_task
from envelopes import Envelope
from options.options import SysOptions
from utils.shortcuts import send_email
logger = logging.getLogger(__name__)
def send_email(from_name, to_email, to_name, subject, content):
smtp = SysOptions.smtp_config
if not smtp:
return
envlope = Envelope(from_addr=(smtp["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
try:
envlope.send(smtp["server"],
login=smtp["email"],
password=smtp["password"],
port=smtp["port"],
tls=smtp["tls"])
return True
except Exception as e:
logger.exception(e)
return False
@shared_task
def send_email_async(from_name, to_email, to_name, subject, content):
send_email(from_name, to_email, to_name, subject, content)
if not SysOptions.smtp_config:
return
try:
send_email(smtp_config=SysOptions.smtp_config,
from_name=from_name,
to_email=to_email,
to_name=to_name,
subject=subject,
content=content)
except Exception as e:
logger.exception(e)

View File

@ -1,77 +1,31 @@
<table cellpadding="0" cellspacing="0" align="center" style="text-align:left;font-family:'微软雅黑','黑体',arial;"
width="742">
<div>
<table cellpadding="0" align="center"
style="overflow:hidden;background:#fff;margin:0 auto;text-align:left;position:relative;font-size:14px; font-family:'lucida Grande',Verdana;line-height:1.5;box-shadow:0 0 3px #ccc;border:1px solid #ccc;border-radius:5px;border-collapse:collapse;">
<tbody>
<tr>
<th valign="middle"
style="height:38px;color:#fff; font-size:14px;line-height:38px; font-weight:bold;text-align:left;padding:10px 24px 6px; border-bottom:1px solid #467ec3;background:#518bcb;border-radius:5px 5px 0 0;">
{{ website_name }}</th>
</tr>
<tr>
<td>
<table cellpadding="0" cellspacing="0"
style="text-align:left;border:1px solid #50a5e6;color:#fff;font-size:18px;" width="740">
<tbody>
<tr height="39" style="background-color:#50a5e6;">
<td style="padding-left:15px;font-family:'微软雅黑','黑体',arial;">
{{ website_name }}
</td>
</tr>
</tbody>
</table>
<table cellpadding="0" cellspacing="0"
style="text-align:left;border:1px solid #f0f0f0;border-top:none;color:#585858;background-color:#fafafa;"
width="740">
<tbody>
<tr height="25">
<td></td>
</tr>
<tr height="40">
<td style="padding-left:25px;padding-right:25px;font-size:18px;font-family:'微软雅黑','黑体',arial;">
Hello, {{ username }}:
<div style="padding:20px 35px 40px;">
<h2 style="font-weight:bold;margin-bottom:5px;font-size:14px;">Hello, {{ username }}:</h2>
<p style="margin-top:20px">
Please click <a href="{{ link }}">{{ link }}</a> to reset your password in 20 minutes.
</p>
<p style="margin-top:20px">
To protect your account, please do not use simple passwords.
</p>
<p style="margin-top:20px">
If you still have any questions, please contract system administrator.
</p>
<p style="margin-left:2em;"></p>
<p style="text-indent:0;text-align:right;">{{ website_name }}</p>
</div>
</td>
</tr>
<tr height="15">
<td></td>
</tr>
<tr height="30">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
We received a request to reset your password for {{ website_name }}.
</td>
</tr>
<tr height="30">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
You can use the following link to reset your password in <span style="color:rgb(255,0,0)">20 minutes.</span>
</td>
</tr>
<tr height="60">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:14px;">
<a href="{{ link }}" target="_blank"
style="color: rgb(255,255,255);text-decoration: none;display: block;min-height: 39px;width: 158px;line-height: 39px;background-color:rgb(80,165,230);font-size:20px;text-align:center;">Reset Password</a>
</td>
</tr>
<tr height="10">
<td></td>
</tr>
<tr height="20">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;">
If the button above doesn't work, please copy the following link to your browser and press enter.
</td>
</tr>
<tr height="30">
<td style="padding-left:55px;padding-right:65px;font-family:'微软雅黑','黑体',arial;">
<a href="{{ link }}" target="_blank" style="color:#0c94de;font-size:12px;">
{{ link }}
</a>
</td>
</tr>
<tr height="20">
<td style="padding-left:55px;padding-right:55px;font-family:'微软雅黑','黑体',arial;font-size:12px;">
If you did not ask that, please ignore this email. It will expire and become useless in 20 minutes.
</td>
</tr>
<tr height="20">
<td></td>
</tr>
</tbody>
</table>
</td>
</tr>
</tbody>
</table>
</div>

View File

@ -302,11 +302,11 @@ class ApplyResetPasswordAPI(APIView):
"link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
}
email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(SysOptions.website_name,
user.email,
user.username,
f"{SysOptions.website_name} 登录信息找回邮件",
email_html)
send_email_async.delay(from_name=SysOptions.website_name_shortcut,
to_email=user.email,
to_username=user.username,
subject=f"Reset your password",
content=email_html)
return self.success("Succeeded")

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.4 on 2017-12-24 03:44
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('conf', '0002_auto_20171011_1214'),
]
operations = [
migrations.AddField(
model_name='judgeserver',
name='is_disabled',
field=models.BooleanField(default=False),
),
]

View File

@ -13,6 +13,7 @@ class JudgeServer(models.Model):
create_time = models.DateTimeField(auto_now_add=True)
task_number = models.IntegerField(default=0)
service_url = models.CharField(max_length=256, blank=True, null=True)
is_disabled = models.BooleanField(default=False)
@property
def status(self):

View File

@ -43,4 +43,9 @@ class JudgeServerHeartbeatSerializer(serializers.Serializer):
memory = serializers.FloatField(min_value=0, max_value=100)
cpu = serializers.FloatField(min_value=0, max_value=100)
action = serializers.ChoiceField(choices=("heartbeat", ))
service_url = serializers.CharField(max_length=256, required=False)
service_url = serializers.CharField(max_length=256)
class EditJudgeServerSerializer(serializers.Serializer):
id = serializers.IntegerField()
is_disabled = serializers.BooleanField()

View File

@ -43,6 +43,14 @@ class SMTPConfigTest(APITestCase):
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
@mock.patch("conf.views.send_email")
def test_test_smtp(self, mocked_send_email):
url = self.reverse("smtp_test_api")
self.test_create_smtp_config()
resp = self.client.post(url, data={"email": "test@test.com"})
self.assertSuccess(resp)
mocked_send_email.assert_called_once()
class WebsiteConfigAPITest(APITestCase):
def test_create_website_config(self):
@ -58,10 +66,11 @@ class WebsiteConfigAPITest(APITestCase):
self.create_super_admin()
url = self.reverse("website_config_api")
data = {"website_base_url": "http://test.com", "website_name": "test name",
"website_name_shortcut": "test oj", "website_footer": "<a>test</a>",
"website_name_shortcut": "test oj", "website_footer": "<img onerror=alert(1) src=#>",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
self.assertEqual(SysOptions.website_footer, "<img src=\"#\" />")
def test_get_website_config(self):
# do not need to login
@ -74,7 +83,7 @@ class JudgeServerHeartbeatTest(APITestCase):
def setUp(self):
self.url = self.reverse("judge_server_heartbeat_api")
self.data = {"hostname": "testhostname", "judger_version": "1.0.4", "cpu_core": 4,
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
"cpu": 90.5, "memory": 80.3, "action": "heartbeat", "service_url": "http://127.0.0.1"}
self.token = "test"
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
SysOptions.judge_server_token = self.token
@ -85,16 +94,6 @@ class JudgeServerHeartbeatTest(APITestCase):
self.assertSuccess(resp)
server = JudgeServer.objects.first()
self.assertEqual(server.ip, "127.0.0.1")
self.assertEqual(server.service_url, None)
def test_new_heartbeat_service_url(self):
service_url = "http://1.2.3.4:8000/api/judge"
data = self.data
data["service_url"] = service_url
resp = self.client.post(self.url, data=self.data, **self.headers)
self.assertSuccess(resp)
server = JudgeServer.objects.first()
self.assertEqual(server.service_url, service_url)
def test_update_heartbeat(self):
self.test_new_heartbeat()
@ -107,7 +106,7 @@ class JudgeServerHeartbeatTest(APITestCase):
class JudgeServerAPITest(APITestCase):
def setUp(self):
JudgeServer.objects.create(**{"hostname": "testhostname", "judger_version": "1.0.4",
self.server = JudgeServer.objects.create(**{"hostname": "testhostname", "judger_version": "1.0.4",
"cpu_core": 4, "cpu_usage": 90.5, "memory_usage": 80.3,
"last_heartbeat": timezone.now()})
self.url = self.reverse("judge_server_api")
@ -123,6 +122,11 @@ class JudgeServerAPITest(APITestCase):
self.assertSuccess(resp)
self.assertFalse(JudgeServer.objects.filter(hostname="testhostname").exists())
def test_disabled_judge_server(self):
resp = self.client.put(self.url, data={"is_disabled": True, "id": self.server.id})
self.assertSuccess(resp)
self.assertTrue(JudgeServer.objects.get(id=self.server.id).is_disabled)
class LanguageListAPITest(APITestCase):
def test_get_languages(self):

View File

@ -1,9 +1,10 @@
from django.conf.urls import url
from ..views import SMTPAPI, JudgeServerAPI, WebsiteConfigAPI, TestCasePruneAPI
from ..views import SMTPAPI, JudgeServerAPI, WebsiteConfigAPI, TestCasePruneAPI, SMTPTestAPI
urlpatterns = [
url(r"^smtp/?$", SMTPAPI.as_view(), name="smtp_admin_api"),
url(r"^smtp_test/?$", SMTPTestAPI.as_view(), name="smtp_test_api"),
url(r"^website/?$", WebsiteConfigAPI.as_view(), name="website_config_api"),
url(r"^judge_server/?$", JudgeServerAPI.as_view(), name="judge_server_api"),
url(r"^prune_test_case/?$", TestCasePruneAPI.as_view(), name="prune_test_case_api"),

View File

@ -12,11 +12,13 @@ from judge.dispatcher import process_pending_task
from judge.languages import languages, spj_languages
from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import send_email
from utils.xss_filter import XSSHtml
from .models import JudgeServer
from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
JudgeServerHeartbeatSerializer,
JudgeServerSerializer, TestSMTPConfigSerializer)
JudgeServerSerializer, TestSMTPConfigSerializer, EditJudgeServerSerializer)
class SMTPAPI(APIView):
@ -51,7 +53,25 @@ class SMTPTestAPI(APIView):
@super_admin_required
@validate_serializer(TestSMTPConfigSerializer)
def post(self, request):
return self.success({"result": True})
if not SysOptions.smtp_config:
return self.error("Please setup SMTP config at first")
try:
send_email(smtp_config=SysOptions.smtp_config,
from_name=SysOptions.website_name_shortcut,
to_name=request.user.username,
to_email=request.data["email"],
subject="You have successfully configured SMTP",
content="You have successfully configured SMTP")
except Exception as e:
# guess error message encoding
msg = e.smtp_error
try:
# qq mail
msg = msg.decode("gbk")
except Exception:
msg = msg.decode("utf-8", "ignore")
return self.error(msg)
return self.success()
class WebsiteConfigAPI(APIView):
@ -65,6 +85,9 @@ class WebsiteConfigAPI(APIView):
@super_admin_required
def post(self, request):
for k, v in request.data.items():
if k == "website_footer":
with XSSHtml() as parser:
v = parser.clean(v)
setattr(SysOptions, k, v)
return self.success()
@ -83,6 +106,12 @@ class JudgeServerAPI(APIView):
JudgeServer.objects.filter(hostname=hostname).delete()
return self.success()
@validate_serializer(EditJudgeServerSerializer)
@super_admin_required
def put(self, request):
JudgeServer.objects.filter(id=request.data["id"]).update(is_disabled=request.data["is_disabled"])
return self.success()
class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
@validate_serializer(JudgeServerHeartbeatSerializer)
@ -91,7 +120,6 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
return self.error("Invalid token")
service_url = data.get("service_url")
try:
server = JudgeServer.objects.get(hostname=data["hostname"])
@ -99,7 +127,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
server.cpu_core = data["cpu_core"]
server.memory_usage = data["memory"]
server.cpu_usage = data["cpu"]
server.service_url = service_url
server.service_url = data["service_url"]
server.ip = request.META["HTTP_X_REAL_IP"]
server.last_heartbeat = timezone.now()
server.save()
@ -110,7 +138,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
memory_usage=data["memory"],
cpu_usage=data["cpu"],
ip=request.META["REMOTE_ADDR"],
service_url=service_url,
service_url=data["service_url"],
last_heartbeat=timezone.now(),
)
# 新server上线 处理队列中的防止没有新的提交而导致一直waiting

View File

@ -16,3 +16,4 @@ psycopg2
gunicorn
jsonfield
XlsxWriter
raven

3
docs/README.md Normal file
View File

@ -0,0 +1,3 @@
# DOCUMENT
[Here](http://docs.onlinejudge.me/)

16
docs/data.json Normal file
View File

@ -0,0 +1,16 @@
{
"update": [
{
"version": "2017-12-25",
"level": 1,
"title": "Update at 2017-12-25",
"details": [
"Fix some issues under IE/Edge",
"Add backend error reporter",
"New email template",
"A more flexible throttling function",
"Other bugs and enhancements"
]
}
]
}

View File

@ -6,7 +6,6 @@ from urllib.parse import urljoin
import requests
from django.db import transaction
from django.db.models import F
from django.conf import settings
from account.models import User
from conf.models import JudgeServer
@ -47,7 +46,7 @@ class DispatcherBase(object):
@staticmethod
def choose_judge_server():
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
if servers:
server = servers[0]
@ -154,11 +153,7 @@ class JudgeDispatcher(DispatcherBase):
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
service_url = server.service_url
# not set service_url, it should be a linked container
if not service_url:
service_url = settings.DEFAULT_JUDGE_SERVER_SERVICE_URL
resp = self._request(urljoin(service_url, "/judge"), data=data)
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if resp["err"]:
self.submission.result = JudgeStatus.COMPILE_ERROR
self.submission.statistic_info["err_info"] = resp["data"]

View File

@ -10,6 +10,7 @@ For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.8/ref/settings/
"""
import os
import raven
from copy import deepcopy
if os.environ.get("OJ_ENV") == "production":
@ -29,6 +30,7 @@ VENDOR_APPS = (
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'raven.contrib.django.raven_compat'
)
LOCAL_APPS = (
'account',
@ -125,6 +127,8 @@ UPLOAD_DIR = f"{DATA_DIR}{UPLOAD_PREFIX}"
STATICFILES_DIRS = [os.path.join(DATA_DIR, "public")]
LOGGING_HANDLERS = ['console'] if DEBUG else ['console', 'sentry']
LOGGING = {
'version': 1,
'disable_existing_loggers': False,
@ -139,21 +143,26 @@ LOGGING = {
'level': 'DEBUG',
'class': 'logging.StreamHandler',
'formatter': 'standard'
},
'sentry': {
'level': 'ERROR',
'class': 'raven.contrib.django.raven_compat.handlers.SentryHandler',
'formatter': 'standard'
}
},
'loggers': {
'django.request': {
'handlers': ['console'],
'handlers': LOGGING_HANDLERS,
'level': 'ERROR',
'propagate': True,
},
'django.db.backends': {
'handlers': ['console'],
'handlers': LOGGING_HANDLERS,
'level': 'ERROR',
'propagate': True,
},
'': {
'handlers': ['console'],
'handlers': LOGGING_HANDLERS,
'level': 'WARNING',
'propagate': True,
}
@ -201,3 +210,7 @@ TOKEN_BUCKET_DEFAULT_CAPACITY = 10
# 单位:每分钟
TOKEN_BUCKET_FILL_RATE = 2
RAVEN_CONFIG = {
'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057'
}

View File

@ -21,6 +21,7 @@ class OptionKeys:
submission_list_show_all = "submission_list_show_all"
smtp_config = "smtp_config"
judge_server_token = "judge_server_token"
throttling = "throttling"
class OptionDefaultValue:
@ -32,6 +33,8 @@ class OptionDefaultValue:
submission_list_show_all = True
smtp_config = {}
judge_server_token = default_token
throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50},
"user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}}
class _SysOptionsMeta(type):
@ -180,6 +183,14 @@ class _SysOptionsMeta(type):
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
@property
def throttling(cls):
return cls._get_option(OptionKeys.throttling)
@throttling.setter
def throttling(cls, value):
cls._set_option(OptionKeys.throttling, value)
class SysOptions(metaclass=_SysOptionsMeta):
pass

View File

@ -1,6 +1,5 @@
import ipaddress
from django.conf import settings
from account.decorators import login_required, check_contest_permission
from judge.tasks import judge_task
# from judge.dispatcher import JudgeDispatcher
@ -8,7 +7,7 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType
from options.options import SysOptions
from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController
from utils.throttling import TokenBucket
from utils.captcha import Captcha
from utils.cache import cache
from ..models import Submission
@ -19,29 +18,16 @@ from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerialize
class SubmissionAPI(APIView):
def throttling(self, request):
user_controller = BucketController(factor=request.user.id,
redis_conn=cache,
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY)
user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY,
last_capacity=user_controller.last_capacity,
last_timestamp=user_controller.last_timestamp)
if user_bucket.consume():
user_controller.last_capacity -= 1
else:
return "Please wait %d seconds" % int(user_bucket.expected_time() + 1)
user_bucket = TokenBucket(key=str(request.user.id),
redis_conn=cache, **SysOptions.throttling["user"])
can_consume, wait = user_bucket.consume()
if not can_consume:
return "Please wait %d seconds" % (int(wait))
ip_controller = BucketController(factor=request.session["ip"],
redis_conn=cache,
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3)
ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3,
last_capacity=ip_controller.last_capacity,
last_timestamp=ip_controller.last_timestamp)
if ip_bucket.consume():
ip_controller.last_capacity -= 1
else:
ip_bucket = TokenBucket(key=request.session["ip"],
redis_conn=cache, **SysOptions.throttling["ip"])
can_consume, wait = ip_bucket.consume()
if not can_consume:
return "Captcha is required"
@validate_serializer(CreateSubmissionSerializer)

View File

@ -1,14 +1,10 @@
from django.contrib.postgres.fields import JSONField # NOQA
from django.db import models
from utils.xss_filter import XssHtml
from utils.xss_filter import XSSHtml
class RichTextField(models.TextField):
def get_prep_value(self, value):
if not value:
value = ""
parser = XssHtml()
parser.feed(value)
parser.close()
return parser.getHtml()
with XSSHtml() as parser:
return parser.clean(value or "")

View File

@ -5,6 +5,7 @@ from base64 import b64encode
from io import BytesIO
from django.utils.crypto import get_random_string
from envelopes import Envelope
def rand_str(length=32, type="lower_hex"):
@ -63,3 +64,15 @@ def timestamp2utcstr(value):
def natural_sort_key(s, _nsre=re.compile(r"(\d+)")):
return [int(text) if text.isdigit() else text.lower()
for text in re.split(_nsre, s)]
def send_email(smtp_config, from_name, to_email, to_name, subject, content):
envelope = Envelope(from_addr=(smtp_config["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
return envelope.send(smtp_config["server"],
login=smtp_config["email"],
password=smtp_config["password"],
port=smtp_config["port"],
tls=smtp_config["tls"])

View File

@ -1,90 +1,72 @@
from __future__ import print_function
import time
class TokenBucket:
def __init__(self, fill_rate, capacity, last_capacity, last_timestamp):
self.capacity = float(capacity)
self._left_tokens = last_capacity
self.fill_rate = float(fill_rate)
self.timestamp = last_timestamp
"""
注意对于单个key的操作不是线程安全的
"""
def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn):
"""
:param capacity: 最大容量
:param fill_rate: 填充速度/每秒
:param default_capacity: 初始容量
:param redis_conn: redis connection
"""
self._key = key
self._capacity = capacity
self._fill_rate = fill_rate
self._default_capacity = default_capacity
self._redis_conn = redis_conn
def consume(self, tokens=1):
if tokens <= self.tokens:
self._left_tokens -= tokens
return True
return False
self._last_capacity_key = "last_capacity"
self._last_timestamp_key = "last_timestamp"
def expected_time(self, tokens=1):
_tokens = self.tokens
tokens = max(tokens, _tokens)
return (tokens - _tokens) / self.fill_rate * 60
@property
def tokens(self):
if self._left_tokens < self.capacity:
def _init_key(self):
self._last_capacity = self._default_capacity
now = time.time()
delta = self.fill_rate * ((now - self.timestamp) / 60)
self._left_tokens = min(self.capacity, self._left_tokens + delta)
self.timestamp = now
return self._left_tokens
class BucketController:
def __init__(self, factor, redis_conn, default_capacity):
self.default_capacity = default_capacity
self.redis = redis_conn
self.key = "bucket_" + str(factor)
self._last_timestamp = now
return self._default_capacity, now
@property
def last_capacity(self):
value = self.redis.hget(self.key, "last_capacity")
if value is None:
self.last_capacity = self.default_capacity
return self.default_capacity
return int(value)
@last_capacity.setter
def last_capacity(self, value):
self.redis.hset(self.key, "last_capacity", value)
@property
def last_timestamp(self):
value = self.redis.hget(self.key, "last_timestamp")
if value is None:
timestamp = int(time.time())
self.last_timestamp = timestamp
return timestamp
return int(value)
@last_timestamp.setter
def last_timestamp(self, value):
self.redis.hset(self.key, "last_timestamp", value)
"""
# # Token bucket, to limit submission rate
# # Demo
success = failure = 0
current_user_id = 1
token_bucket_default_capacity = 50
token_bucket_fill_rate = 10
for i in range(5000):
controller = BucketController(user_id=current_user_id,
redis_conn=redis.Redis(),
default_capacity=token_bucket_default_capacity)
bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
capacity=token_bucket_default_capacity,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
time.sleep(0.05)
if bucket.consume():
success += 1
print(i, ": Accepted")
controller.last_capacity -= 1
def _last_capacity(self):
last_capacity = self._redis_conn.hget(self._key, self._last_capacity_key)
if last_capacity is None:
return self._init_key()[0]
else:
failure += 1
print(i, "Dropped, time left ", bucket.expected_time())
print(success, failure)
return float(last_capacity)
@_last_capacity.setter
def _last_capacity(self, value):
self._redis_conn.hset(self._key, self._last_capacity_key, value)
@property
def _last_timestamp(self):
return float(self._redis_conn.hget(self._key, self._last_timestamp_key))
@_last_timestamp.setter
def _last_timestamp(self, value):
self._redis_conn.hset(self._key, self._last_timestamp_key, value)
def _try_to_fill(self, now):
delta = self._fill_rate * (now - self._last_timestamp)
return min(self._last_capacity + delta, self._capacity)
def consume(self, num=1):
"""
消耗 num token返回是否成功
:param num:
:return: result: bool, wait_time: float
"""
# print("capacity ", self.fill(time.time()))
if self._last_capacity >= num:
self._last_capacity -= num
return True, 0
else:
now = time.time()
cur_num = self._try_to_fill(now)
if cur_num >= num:
self._last_capacity = cur_num - num
self._last_timestamp = now
return True, 0
else:
return False, (num - cur_num) / self._fill_rate

View File

@ -30,7 +30,7 @@ import copy
from html.parser import HTMLParser
class XssHtml(HTMLParser):
class XSSHtml(HTMLParser):
allow_tags = ['a', 'img', 'br', 'strong', 'b', 'code', 'pre',
'p', 'div', 'em', 'span', 'h1', 'h2', 'h3', 'h4',
'h5', 'h6', 'blockquote', 'ul', 'ol', 'tr', 'th', 'td',
@ -53,7 +53,17 @@ class XssHtml(HTMLParser):
self.start = []
self.data = []
def getHtml(self):
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
super().close()
def clean(self, content):
self.feed(content)
return self.get_html()
def get_html(self):
"""
Get the safe html code
"""
@ -188,11 +198,11 @@ class XssHtml(HTMLParser):
if "__main__" == __name__:
parser = XssHtml()
parser.feed("""<p><img src=1 onerror=alert(/xss/)></p><div class="left">
with XSSHtml() as parser:
ret = parser.clean("""<p><img src=1 onerror=alert(/xss/)></p><div class="left">
<a href='javascript:prompt(1)'><br />hehe</a></div>
<p id="test" onmouseover="alert(1)">&gt;M<svg>
<a href="https://www.baidu.com" target="self">MM</a></p>
<embed src='javascript:alert(/hehe/)' allowscriptaccess=always />""")
parser.close()
print(parser.getHtml())
<embed src='javascript:alert(/hehe/)' allowscriptaccess=always />
<img onerror=alert(1) src=#>""")
print(ret)