diff --git a/Dockerfile b/Dockerfile index a1ca9f99..41cabf83 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,4 +15,4 @@ RUN curl -L $(curl -s https://api.github.com/repos/QingdaoU/OnlineJudgeFE/rele unzip dist.zip && \ rm dist.zip -CMD sh /app/deploy/run.sh +ENTRYPOINT /app/deploy/entrypoint.sh diff --git a/account/tasks.py b/account/tasks.py index 3e7c1d2f..12c1587d 100644 --- a/account/tasks.py +++ b/account/tasks.py @@ -1,33 +1,23 @@ import logging from celery import shared_task -from envelopes import Envelope from options.options import SysOptions +from utils.shortcuts import send_email logger = logging.getLogger(__name__) -def send_email(from_name, to_email, to_name, subject, content): - smtp = SysOptions.smtp_config - if not smtp: - return - envlope = Envelope(from_addr=(smtp["email"], from_name), - to_addr=(to_email, to_name), - subject=subject, - html_body=content) - try: - envlope.send(smtp["server"], - login=smtp["email"], - password=smtp["password"], - port=smtp["port"], - tls=smtp["tls"]) - return True - except Exception as e: - logger.exception(e) - return False - - @shared_task def send_email_async(from_name, to_email, to_name, subject, content): - send_email(from_name, to_email, to_name, subject, content) + if not SysOptions.smtp_config: + return + try: + send_email(smtp_config=SysOptions.smtp_config, + from_name=from_name, + to_email=to_email, + to_name=to_name, + subject=subject, + content=content) + except Exception as e: + logger.exception(e) diff --git a/account/templates/reset_password_email.html b/account/templates/reset_password_email.html index b2f76f88..54dcc55a 100644 --- a/account/templates/reset_password_email.html +++ b/account/templates/reset_password_email.html @@ -1,77 +1,31 @@ - - - - - - -
- - - - - - -
- {{ website_name }} -
- - - - - +
+
+ + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
+ {{ website_name }}
+
+

Hello, {{ username }}:

+

+ Please click {{ link }} to reset your password in 20 minutes. +

+

+ To protect your account, please do not use simple passwords. +

+

+ If you still have any questions, please contract system administrator. +

+

+

{{ website_name }}

+
+
- Hello, {{ username }}: -
- We received a request to reset your password for {{ website_name }}. -
- You can use the following link to reset your password in 20 minutes. -
- Reset Password -
- If the button above doesn't work, please copy the following link to your browser and press enter. -
- - {{ link }} - -
- If you did not ask that, please ignore this email. It will expire and become useless in 20 minutes. -
-
\ No newline at end of file + + + \ No newline at end of file diff --git a/account/views/oj.py b/account/views/oj.py index 4aaec7cb..79b5622e 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -302,11 +302,11 @@ class ApplyResetPasswordAPI(APIView): "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" } email_html = render_to_string("reset_password_email.html", render_data) - send_email_async.delay(SysOptions.website_name, - user.email, - user.username, - f"{SysOptions.website_name} 登录信息找回邮件", - email_html) + send_email_async.delay(from_name=SysOptions.website_name_shortcut, + to_email=user.email, + to_username=user.username, + subject=f"Reset your password", + content=email_html) return self.success("Succeeded") diff --git a/conf/migrations/0003_judgeserver_is_disabled.py b/conf/migrations/0003_judgeserver_is_disabled.py new file mode 100644 index 00000000..6a571c01 --- /dev/null +++ b/conf/migrations/0003_judgeserver_is_disabled.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.4 on 2017-12-24 03:44 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('conf', '0002_auto_20171011_1214'), + ] + + operations = [ + migrations.AddField( + model_name='judgeserver', + name='is_disabled', + field=models.BooleanField(default=False), + ), + ] diff --git a/conf/models.py b/conf/models.py index 4c6348d5..2317f81c 100644 --- a/conf/models.py +++ b/conf/models.py @@ -13,6 +13,7 @@ class JudgeServer(models.Model): create_time = models.DateTimeField(auto_now_add=True) task_number = models.IntegerField(default=0) service_url = models.CharField(max_length=256, blank=True, null=True) + is_disabled = models.BooleanField(default=False) @property def status(self): diff --git a/conf/serializers.py b/conf/serializers.py index 3c741b27..50f259c1 100644 --- a/conf/serializers.py +++ b/conf/serializers.py @@ -43,4 +43,9 @@ class JudgeServerHeartbeatSerializer(serializers.Serializer): memory = serializers.FloatField(min_value=0, max_value=100) cpu = serializers.FloatField(min_value=0, max_value=100) action = serializers.ChoiceField(choices=("heartbeat", )) - service_url = serializers.CharField(max_length=256, required=False) + service_url = serializers.CharField(max_length=256) + + +class EditJudgeServerSerializer(serializers.Serializer): + id = serializers.IntegerField() + is_disabled = serializers.BooleanField() diff --git a/conf/tests.py b/conf/tests.py index c0b22f38..55f4853d 100644 --- a/conf/tests.py +++ b/conf/tests.py @@ -43,6 +43,14 @@ class SMTPConfigTest(APITestCase): resp = self.client.put(self.url, data=data) self.assertSuccess(resp) + @mock.patch("conf.views.send_email") + def test_test_smtp(self, mocked_send_email): + url = self.reverse("smtp_test_api") + self.test_create_smtp_config() + resp = self.client.post(url, data={"email": "test@test.com"}) + self.assertSuccess(resp) + mocked_send_email.assert_called_once() + class WebsiteConfigAPITest(APITestCase): def test_create_website_config(self): @@ -58,10 +66,11 @@ class WebsiteConfigAPITest(APITestCase): self.create_super_admin() url = self.reverse("website_config_api") data = {"website_base_url": "http://test.com", "website_name": "test name", - "website_name_shortcut": "test oj", "website_footer": "test", + "website_name_shortcut": "test oj", "website_footer": "", "allow_register": True, "submission_list_show_all": False} resp = self.client.post(url, data=data) self.assertSuccess(resp) + self.assertEqual(SysOptions.website_footer, "") def test_get_website_config(self): # do not need to login @@ -74,7 +83,7 @@ class JudgeServerHeartbeatTest(APITestCase): def setUp(self): self.url = self.reverse("judge_server_heartbeat_api") self.data = {"hostname": "testhostname", "judger_version": "1.0.4", "cpu_core": 4, - "cpu": 90.5, "memory": 80.3, "action": "heartbeat"} + "cpu": 90.5, "memory": 80.3, "action": "heartbeat", "service_url": "http://127.0.0.1"} self.token = "test" self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest() SysOptions.judge_server_token = self.token @@ -85,16 +94,6 @@ class JudgeServerHeartbeatTest(APITestCase): self.assertSuccess(resp) server = JudgeServer.objects.first() self.assertEqual(server.ip, "127.0.0.1") - self.assertEqual(server.service_url, None) - - def test_new_heartbeat_service_url(self): - service_url = "http://1.2.3.4:8000/api/judge" - data = self.data - data["service_url"] = service_url - resp = self.client.post(self.url, data=self.data, **self.headers) - self.assertSuccess(resp) - server = JudgeServer.objects.first() - self.assertEqual(server.service_url, service_url) def test_update_heartbeat(self): self.test_new_heartbeat() @@ -107,9 +106,9 @@ class JudgeServerHeartbeatTest(APITestCase): class JudgeServerAPITest(APITestCase): def setUp(self): - JudgeServer.objects.create(**{"hostname": "testhostname", "judger_version": "1.0.4", - "cpu_core": 4, "cpu_usage": 90.5, "memory_usage": 80.3, - "last_heartbeat": timezone.now()}) + self.server = JudgeServer.objects.create(**{"hostname": "testhostname", "judger_version": "1.0.4", + "cpu_core": 4, "cpu_usage": 90.5, "memory_usage": 80.3, + "last_heartbeat": timezone.now()}) self.url = self.reverse("judge_server_api") self.create_super_admin() @@ -123,6 +122,11 @@ class JudgeServerAPITest(APITestCase): self.assertSuccess(resp) self.assertFalse(JudgeServer.objects.filter(hostname="testhostname").exists()) + def test_disabled_judge_server(self): + resp = self.client.put(self.url, data={"is_disabled": True, "id": self.server.id}) + self.assertSuccess(resp) + self.assertTrue(JudgeServer.objects.get(id=self.server.id).is_disabled) + class LanguageListAPITest(APITestCase): def test_get_languages(self): diff --git a/conf/urls/admin.py b/conf/urls/admin.py index 469ecc22..5b194b84 100644 --- a/conf/urls/admin.py +++ b/conf/urls/admin.py @@ -1,9 +1,10 @@ from django.conf.urls import url -from ..views import SMTPAPI, JudgeServerAPI, WebsiteConfigAPI, TestCasePruneAPI +from ..views import SMTPAPI, JudgeServerAPI, WebsiteConfigAPI, TestCasePruneAPI, SMTPTestAPI urlpatterns = [ url(r"^smtp/?$", SMTPAPI.as_view(), name="smtp_admin_api"), + url(r"^smtp_test/?$", SMTPTestAPI.as_view(), name="smtp_test_api"), url(r"^website/?$", WebsiteConfigAPI.as_view(), name="website_config_api"), url(r"^judge_server/?$", JudgeServerAPI.as_view(), name="judge_server_api"), url(r"^prune_test_case/?$", TestCasePruneAPI.as_view(), name="prune_test_case_api"), diff --git a/conf/views.py b/conf/views.py index 5fb22414..b4e47bb9 100644 --- a/conf/views.py +++ b/conf/views.py @@ -12,11 +12,13 @@ from judge.dispatcher import process_pending_task from judge.languages import languages, spj_languages from options.options import SysOptions from utils.api import APIView, CSRFExemptAPIView, validate_serializer +from utils.shortcuts import send_email +from utils.xss_filter import XSSHtml from .models import JudgeServer from .serializers import (CreateEditWebsiteConfigSerializer, CreateSMTPConfigSerializer, EditSMTPConfigSerializer, JudgeServerHeartbeatSerializer, - JudgeServerSerializer, TestSMTPConfigSerializer) + JudgeServerSerializer, TestSMTPConfigSerializer, EditJudgeServerSerializer) class SMTPAPI(APIView): @@ -51,7 +53,25 @@ class SMTPTestAPI(APIView): @super_admin_required @validate_serializer(TestSMTPConfigSerializer) def post(self, request): - return self.success({"result": True}) + if not SysOptions.smtp_config: + return self.error("Please setup SMTP config at first") + try: + send_email(smtp_config=SysOptions.smtp_config, + from_name=SysOptions.website_name_shortcut, + to_name=request.user.username, + to_email=request.data["email"], + subject="You have successfully configured SMTP", + content="You have successfully configured SMTP") + except Exception as e: + # guess error message encoding + msg = e.smtp_error + try: + # qq mail + msg = msg.decode("gbk") + except Exception: + msg = msg.decode("utf-8", "ignore") + return self.error(msg) + return self.success() class WebsiteConfigAPI(APIView): @@ -65,6 +85,9 @@ class WebsiteConfigAPI(APIView): @super_admin_required def post(self, request): for k, v in request.data.items(): + if k == "website_footer": + with XSSHtml() as parser: + v = parser.clean(v) setattr(SysOptions, k, v) return self.success() @@ -83,6 +106,12 @@ class JudgeServerAPI(APIView): JudgeServer.objects.filter(hostname=hostname).delete() return self.success() + @validate_serializer(EditJudgeServerSerializer) + @super_admin_required + def put(self, request): + JudgeServer.objects.filter(id=request.data["id"]).update(is_disabled=request.data["is_disabled"]) + return self.success() + class JudgeServerHeartbeatAPI(CSRFExemptAPIView): @validate_serializer(JudgeServerHeartbeatSerializer) @@ -91,7 +120,6 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN") if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token: return self.error("Invalid token") - service_url = data.get("service_url") try: server = JudgeServer.objects.get(hostname=data["hostname"]) @@ -99,7 +127,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): server.cpu_core = data["cpu_core"] server.memory_usage = data["memory"] server.cpu_usage = data["cpu"] - server.service_url = service_url + server.service_url = data["service_url"] server.ip = request.META["HTTP_X_REAL_IP"] server.last_heartbeat = timezone.now() server.save() @@ -110,7 +138,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): memory_usage=data["memory"], cpu_usage=data["cpu"], ip=request.META["REMOTE_ADDR"], - service_url=service_url, + service_url=data["service_url"], last_heartbeat=timezone.now(), ) # 新server上线 处理队列中的,防止没有新的提交而导致一直waiting diff --git a/deploy/run.sh b/deploy/entrypoint.sh similarity index 100% rename from deploy/run.sh rename to deploy/entrypoint.sh diff --git a/deploy/requirements.txt b/deploy/requirements.txt index 3d83ebfd..554250a8 100644 --- a/deploy/requirements.txt +++ b/deploy/requirements.txt @@ -16,3 +16,4 @@ psycopg2 gunicorn jsonfield XlsxWriter +raven \ No newline at end of file diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..5953a03a --- /dev/null +++ b/docs/README.md @@ -0,0 +1,3 @@ +# DOCUMENT + +[Here](http://docs.onlinejudge.me/) \ No newline at end of file diff --git a/docs/data.json b/docs/data.json new file mode 100644 index 00000000..53d164a9 --- /dev/null +++ b/docs/data.json @@ -0,0 +1,16 @@ +{ + "update": [ + { + "version": "2017-12-25", + "level": 1, + "title": "Update at 2017-12-25", + "details": [ + "Fix some issues under IE/Edge", + "Add backend error reporter", + "New email template", + "A more flexible throttling function", + "Other bugs and enhancements" + ] + } + ] +} \ No newline at end of file diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 289e50a1..ecb5dbf6 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -6,7 +6,6 @@ from urllib.parse import urljoin import requests from django.db import transaction from django.db.models import F -from django.conf import settings from account.models import User from conf.models import JudgeServer @@ -47,7 +46,7 @@ class DispatcherBase(object): @staticmethod def choose_judge_server(): with transaction.atomic(): - servers = JudgeServer.objects.select_for_update().all().order_by("task_number") + servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number") servers = [s for s in servers if s.status == "normal"] if servers: server = servers[0] @@ -154,11 +153,7 @@ class JudgeDispatcher(DispatcherBase): Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) - service_url = server.service_url - # not set service_url, it should be a linked container - if not service_url: - service_url = settings.DEFAULT_JUDGE_SERVER_SERVICE_URL - resp = self._request(urljoin(service_url, "/judge"), data=data) + resp = self._request(urljoin(server.service_url, "/judge"), data=data) if resp["err"]: self.submission.result = JudgeStatus.COMPILE_ERROR self.submission.statistic_info["err_info"] = resp["data"] diff --git a/oj/settings.py b/oj/settings.py index 91683a65..2cb64ad2 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -10,6 +10,7 @@ For the full list of settings and their values, see https://docs.djangoproject.com/en/1.8/ref/settings/ """ import os +import raven from copy import deepcopy if os.environ.get("OJ_ENV") == "production": @@ -29,6 +30,7 @@ VENDOR_APPS = ( 'django.contrib.messages', 'django.contrib.staticfiles', 'rest_framework', + 'raven.contrib.django.raven_compat' ) LOCAL_APPS = ( 'account', @@ -125,6 +127,8 @@ UPLOAD_DIR = f"{DATA_DIR}{UPLOAD_PREFIX}" STATICFILES_DIRS = [os.path.join(DATA_DIR, "public")] + +LOGGING_HANDLERS = ['console'] if DEBUG else ['console', 'sentry'] LOGGING = { 'version': 1, 'disable_existing_loggers': False, @@ -139,21 +143,26 @@ LOGGING = { 'level': 'DEBUG', 'class': 'logging.StreamHandler', 'formatter': 'standard' + }, + 'sentry': { + 'level': 'ERROR', + 'class': 'raven.contrib.django.raven_compat.handlers.SentryHandler', + 'formatter': 'standard' } }, 'loggers': { 'django.request': { - 'handlers': ['console'], + 'handlers': LOGGING_HANDLERS, 'level': 'ERROR', 'propagate': True, }, 'django.db.backends': { - 'handlers': ['console'], + 'handlers': LOGGING_HANDLERS, 'level': 'ERROR', 'propagate': True, }, '': { - 'handlers': ['console'], + 'handlers': LOGGING_HANDLERS, 'level': 'WARNING', 'propagate': True, } @@ -201,3 +210,7 @@ TOKEN_BUCKET_DEFAULT_CAPACITY = 10 # 单位:每分钟 TOKEN_BUCKET_FILL_RATE = 2 + +RAVEN_CONFIG = { + 'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057' +} \ No newline at end of file diff --git a/options/options.py b/options/options.py index 7d8b9a9e..57169bb7 100644 --- a/options/options.py +++ b/options/options.py @@ -21,6 +21,7 @@ class OptionKeys: submission_list_show_all = "submission_list_show_all" smtp_config = "smtp_config" judge_server_token = "judge_server_token" + throttling = "throttling" class OptionDefaultValue: @@ -32,6 +33,8 @@ class OptionDefaultValue: submission_list_show_all = True smtp_config = {} judge_server_token = default_token + throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50}, + "user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}} class _SysOptionsMeta(type): @@ -180,6 +183,14 @@ class _SysOptionsMeta(type): def judge_server_token(cls, value): cls._set_option(OptionKeys.judge_server_token, value) + @property + def throttling(cls): + return cls._get_option(OptionKeys.throttling) + + @throttling.setter + def throttling(cls, value): + cls._set_option(OptionKeys.throttling, value) + class SysOptions(metaclass=_SysOptionsMeta): pass diff --git a/submission/views/oj.py b/submission/views/oj.py index 5d6351b0..4fb077b6 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -1,6 +1,5 @@ import ipaddress -from django.conf import settings from account.decorators import login_required, check_contest_permission from judge.tasks import judge_task # from judge.dispatcher import JudgeDispatcher @@ -8,7 +7,7 @@ from problem.models import Problem, ProblemRuleType from contest.models import Contest, ContestStatus, ContestRuleType from options.options import SysOptions from utils.api import APIView, validate_serializer -from utils.throttling import TokenBucket, BucketController +from utils.throttling import TokenBucket from utils.captcha import Captcha from utils.cache import cache from ..models import Submission @@ -19,29 +18,16 @@ from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerialize class SubmissionAPI(APIView): def throttling(self, request): - user_controller = BucketController(factor=request.user.id, - redis_conn=cache, - default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY) - user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE, - capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY, - last_capacity=user_controller.last_capacity, - last_timestamp=user_controller.last_timestamp) - if user_bucket.consume(): - user_controller.last_capacity -= 1 - else: - return "Please wait %d seconds" % int(user_bucket.expected_time() + 1) + user_bucket = TokenBucket(key=str(request.user.id), + redis_conn=cache, **SysOptions.throttling["user"]) + can_consume, wait = user_bucket.consume() + if not can_consume: + return "Please wait %d seconds" % (int(wait)) - ip_controller = BucketController(factor=request.session["ip"], - redis_conn=cache, - default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3) - - ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3, - capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3, - last_capacity=ip_controller.last_capacity, - last_timestamp=ip_controller.last_timestamp) - if ip_bucket.consume(): - ip_controller.last_capacity -= 1 - else: + ip_bucket = TokenBucket(key=request.session["ip"], + redis_conn=cache, **SysOptions.throttling["ip"]) + can_consume, wait = ip_bucket.consume() + if not can_consume: return "Captcha is required" @validate_serializer(CreateSubmissionSerializer) diff --git a/utils/models.py b/utils/models.py index 3c114522..9a1cd0ba 100644 --- a/utils/models.py +++ b/utils/models.py @@ -1,14 +1,10 @@ from django.contrib.postgres.fields import JSONField # NOQA from django.db import models -from utils.xss_filter import XssHtml +from utils.xss_filter import XSSHtml class RichTextField(models.TextField): def get_prep_value(self, value): - if not value: - value = "" - parser = XssHtml() - parser.feed(value) - parser.close() - return parser.getHtml() + with XSSHtml() as parser: + return parser.clean(value or "") diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 9340564e..111c31b8 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -5,6 +5,7 @@ from base64 import b64encode from io import BytesIO from django.utils.crypto import get_random_string +from envelopes import Envelope def rand_str(length=32, type="lower_hex"): @@ -63,3 +64,15 @@ def timestamp2utcstr(value): def natural_sort_key(s, _nsre=re.compile(r"(\d+)")): return [int(text) if text.isdigit() else text.lower() for text in re.split(_nsre, s)] + + +def send_email(smtp_config, from_name, to_email, to_name, subject, content): + envelope = Envelope(from_addr=(smtp_config["email"], from_name), + to_addr=(to_email, to_name), + subject=subject, + html_body=content) + return envelope.send(smtp_config["server"], + login=smtp_config["email"], + password=smtp_config["password"], + port=smtp_config["port"], + tls=smtp_config["tls"]) diff --git a/utils/throttling.py b/utils/throttling.py index 7c5f54a9..bab1184a 100644 --- a/utils/throttling.py +++ b/utils/throttling.py @@ -1,90 +1,72 @@ -from __future__ import print_function import time class TokenBucket: - def __init__(self, fill_rate, capacity, last_capacity, last_timestamp): - self.capacity = float(capacity) - self._left_tokens = last_capacity - self.fill_rate = float(fill_rate) - self.timestamp = last_timestamp + """ + 注意:对于单个key的操作不是线程安全的 + """ + def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn): + """ + :param capacity: 最大容量 + :param fill_rate: 填充速度/每秒 + :param default_capacity: 初始容量 + :param redis_conn: redis connection + """ + self._key = key + self._capacity = capacity + self._fill_rate = fill_rate + self._default_capacity = default_capacity + self._redis_conn = redis_conn - def consume(self, tokens=1): - if tokens <= self.tokens: - self._left_tokens -= tokens - return True - return False + self._last_capacity_key = "last_capacity" + self._last_timestamp_key = "last_timestamp" - def expected_time(self, tokens=1): - _tokens = self.tokens - tokens = max(tokens, _tokens) - return (tokens - _tokens) / self.fill_rate * 60 + def _init_key(self): + self._last_capacity = self._default_capacity + now = time.time() + self._last_timestamp = now + return self._default_capacity, now @property - def tokens(self): - if self._left_tokens < self.capacity: + def _last_capacity(self): + last_capacity = self._redis_conn.hget(self._key, self._last_capacity_key) + if last_capacity is None: + return self._init_key()[0] + else: + return float(last_capacity) + + @_last_capacity.setter + def _last_capacity(self, value): + self._redis_conn.hset(self._key, self._last_capacity_key, value) + + @property + def _last_timestamp(self): + return float(self._redis_conn.hget(self._key, self._last_timestamp_key)) + + @_last_timestamp.setter + def _last_timestamp(self, value): + self._redis_conn.hset(self._key, self._last_timestamp_key, value) + + def _try_to_fill(self, now): + delta = self._fill_rate * (now - self._last_timestamp) + return min(self._last_capacity + delta, self._capacity) + + def consume(self, num=1): + """ + 消耗 num 个 token,返回是否成功 + :param num: + :return: result: bool, wait_time: float + """ + # print("capacity ", self.fill(time.time())) + if self._last_capacity >= num: + self._last_capacity -= num + return True, 0 + else: now = time.time() - delta = self.fill_rate * ((now - self.timestamp) / 60) - self._left_tokens = min(self.capacity, self._left_tokens + delta) - self.timestamp = now - return self._left_tokens - - -class BucketController: - def __init__(self, factor, redis_conn, default_capacity): - self.default_capacity = default_capacity - self.redis = redis_conn - self.key = "bucket_" + str(factor) - - @property - def last_capacity(self): - value = self.redis.hget(self.key, "last_capacity") - if value is None: - self.last_capacity = self.default_capacity - return self.default_capacity - return int(value) - - @last_capacity.setter - def last_capacity(self, value): - self.redis.hset(self.key, "last_capacity", value) - - @property - def last_timestamp(self): - value = self.redis.hget(self.key, "last_timestamp") - if value is None: - timestamp = int(time.time()) - self.last_timestamp = timestamp - return timestamp - return int(value) - - @last_timestamp.setter - def last_timestamp(self, value): - self.redis.hset(self.key, "last_timestamp", value) - - -""" -# # Token bucket, to limit submission rate -# # Demo - -success = failure = 0 -current_user_id = 1 -token_bucket_default_capacity = 50 -token_bucket_fill_rate = 10 -for i in range(5000): - controller = BucketController(user_id=current_user_id, - redis_conn=redis.Redis(), - default_capacity=token_bucket_default_capacity) - bucket = TokenBucket(fill_rate=token_bucket_fill_rate, - capacity=token_bucket_default_capacity, - last_capacity=controller.last_capacity, - last_timestamp=controller.last_timestamp) - time.sleep(0.05) - if bucket.consume(): - success += 1 - print(i, ": Accepted") - controller.last_capacity -= 1 - else: - failure += 1 - print(i, "Dropped, time left ", bucket.expected_time()) -print(success, failure) -""" + cur_num = self._try_to_fill(now) + if cur_num >= num: + self._last_capacity = cur_num - num + self._last_timestamp = now + return True, 0 + else: + return False, (num - cur_num) / self._fill_rate diff --git a/utils/xss_filter.py b/utils/xss_filter.py index 34d65a8b..1b45d89c 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -30,7 +30,7 @@ import copy from html.parser import HTMLParser -class XssHtml(HTMLParser): +class XSSHtml(HTMLParser): allow_tags = ['a', 'img', 'br', 'strong', 'b', 'code', 'pre', 'p', 'div', 'em', 'span', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'blockquote', 'ul', 'ol', 'tr', 'th', 'td', @@ -53,7 +53,17 @@ class XssHtml(HTMLParser): self.start = [] self.data = [] - def getHtml(self): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + super().close() + + def clean(self, content): + self.feed(content) + return self.get_html() + + def get_html(self): """ Get the safe html code """ @@ -188,11 +198,11 @@ class XssHtml(HTMLParser): if "__main__" == __name__: - parser = XssHtml() - parser.feed("""

-
hehe
-

>M - MM

- """) - parser.close() - print(parser.getHtml()) + with XSSHtml() as parser: + ret = parser.clean("""

+
hehe
+

>M + MM

+ + """) + print(ret)