diff --git a/conf/models.py b/conf/models.py index 9a88fdfb..9fe3cc51 100644 --- a/conf/models.py +++ b/conf/models.py @@ -41,7 +41,8 @@ class JudgeServer(models.Model): @property def status(self): - if (timezone.now() - self.last_heartbeat).total_seconds() > 5: + # 增加一秒延时,提高对网络环境的适应性 + if (timezone.now() - self.last_heartbeat).total_seconds() > 6: return "abnormal" return "normal" diff --git a/conf/views.py b/conf/views.py index 3dc99591..bc03c010 100644 --- a/conf/views.py +++ b/conf/views.py @@ -1,9 +1,11 @@ import hashlib from django.utils import timezone +from django_redis import get_redis_connection from account.decorators import super_admin_required from judge.languages import languages, spj_languages +from judge.dispatcher import process_pending_task from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.shortcuts import rand_str @@ -126,6 +128,10 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): service_url=service_url, last_heartbeat=timezone.now(), ) + # 新server上线 处理队列中的,防止没有新的提交而导致一直waiting + conn = get_redis_connection("JudgeQueue") + process_pending_task(conn) + return self.success() diff --git a/judge/tasks.py b/judge/dispatcher.py similarity index 63% rename from judge/tasks.py rename to judge/dispatcher.py index 1a83ce44..48fa33e5 100644 --- a/judge/tasks.py +++ b/judge/dispatcher.py @@ -9,23 +9,32 @@ from django.db.models import F from django_redis import get_redis_connection from judge.languages import languages -from account.models import User, UserProfile +from account.models import User from conf.models import JudgeServer, JudgeServerToken from problem.models import Problem, ProblemRuleType -from submission.models import JudgeStatus +from submission.models import JudgeStatus, Submission logger = logging.getLogger(__name__) WAITING_QUEUE = "waiting_queue" +# 继续处理在队列中的问题 +def process_pending_task(redis_conn): + if redis_conn.llen(WAITING_QUEUE): + # 防止循环引入 + from submission.tasks import judge_task + data = json.loads(redis_conn.rpop(WAITING_QUEUE)) + judge_task.delay(**data) + + class JudgeDispatcher(object): - def __init__(self, submission_obj, problem_obj): + def __init__(self, submission_id, problem_id): token = JudgeServerToken.objects.first().token self.token = hashlib.sha256(token.encode("utf-8")).hexdigest() self.redis_conn = get_redis_connection("JudgeQueue") - self.submission_obj = submission_obj - self.problem_obj = problem_obj + self.submission_obj = Submission.objects.get(pk=submission_id) + self.problem_obj = Problem.objects.get(pk=problem_id) def _request(self, url, data=None): kwargs = {"headers": {"X-Judge-Server-Token": self.token, @@ -41,10 +50,10 @@ class JudgeDispatcher(object): def choose_judge_server(): with transaction.atomic(): # TODO: use more reasonable way - servers = JudgeServer.objects.select_for_update().filter( - status="normal").order_by("task_number") - if servers.exists(): - server = servers.first() + servers = JudgeServer.objects.select_for_update().all().order_by('task_number') + servers = [s for s in servers if s.status == "normal"] + if servers: + server = servers[0] server.used_instance_number = F("task_number") + 1 server.save() return server @@ -60,28 +69,31 @@ class JudgeDispatcher(object): def judge(self, output=False): server = self.choose_judge_server() if not server: - self.redis_conn.lpush(WAITING_QUEUE, self.submission_obj.id) + data = {'submission_id': self.submission_obj.id, 'problem_id': self.problem_obj.id} + self.redis_conn.lpush(WAITING_QUEUE, json.dumps(data)) return language = list(filter(lambda item: self.submission_obj.language == item['name'], languages))[0] - data = {"language_config": language['config'], - "src": self.submission_obj.code, - "max_cpu_time": self.problem_obj.time_limit, - "max_memory": self.problem_obj.memory_limit, - "test_case_id": self.problem_obj.test_case_id, - "output": output} + data = { + "language_config": language['config'], + "src": self.submission_obj.code, + "max_cpu_time": self.problem_obj.time_limit, + "max_memory": 1024 * 1024 * self.problem_obj.memory_limit, + "test_case_id": self.problem_obj.test_case_id, + "output": output + } # TODO: try catch resp = self._request(urljoin(server.service_url, "/judge"), data=data) self.submission_obj.info = resp if resp['err']: self.submission_obj.result = JudgeStatus.COMPILE_ERROR else: - error_test_case = list(filter(lambda case: case['result'] != 0, resp)) + error_test_case = list(filter(lambda case: case['result'] != 0, resp['data'])) # 多个测试点全部正确AC,否则ACM模式下取第一个测试点状态 if not error_test_case: self.submission_obj.result = JudgeStatus.ACCEPTED - elif self.problem_obj.rule_tyle == ProblemRuleType.ACM: - self.submission_obj.result = error_test_case[0].result + elif self.problem_obj.rule_type == ProblemRuleType.ACM: + self.submission_obj.result = error_test_case[0]['result'] else: self.submission_obj.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission_obj.save() @@ -92,37 +104,36 @@ class JudgeDispatcher(object): pass else: self.update_problem_status() - # 取redis中等待中的提交 - if self.redis_conn.llen(WAITING_QUEUE): - pass + process_pending_task(self.redis_conn) def compile_spj(self, service_url, src, spj_version, spj_compile_config, test_case_id): data = {"src": src, "spj_version": spj_version, - "spj_compile_config": spj_compile_config, "test_case_id": test_case_id} - return self._request(service_url + "/compile_spj", data=data) + "spj_compile_config": spj_compile_config, + "test_case_id": test_case_id} + return self._request(urljoin(service_url, "compile_spj"), data=data) def update_problem_status(self): with transaction.atomic(): - problem = Problem.objects.select_for_update().get(id=self.problem_obj.problem_id) - # 更新普通题目的计数器 - problem.add_submission_number() - - # 更新用户做题状态 + problem = Problem.objects.select_for_update().get(id=self.problem_obj.id) user = User.objects.select_for_update().get(id=self.submission_obj.user_id) - problems_status = UserProfile.objects.get(user=user).problem_status + # 更新提交计数器 + problem.add_submission_number() + user_profile = user.userprofile + user_profile.add_submission_number() + if self.submission_obj.result == JudgeStatus.ACCEPTED: + problem.add_ac_number() + + problems_status = user_profile.problems_status if "problems" not in problems_status: problems_status["problems"] = {} - # 增加用户提交计数器 - user.userprofile.add_submission_number() - # 之前状态不是ac, 现在是ac了 需要更新用户ac题目数量计数器,这里需要判重 if problems_status["problems"].get(str(problem.id), JudgeStatus.WRONG_ANSWER) != JudgeStatus.ACCEPTED: if self.submission_obj.result == JudgeStatus.ACCEPTED: - user.userprofile.add_accepted_problem_number() + user_profile.add_accepted_problem_number() problems_status["problems"][str(problem.id)] = JudgeStatus.ACCEPTED else: problems_status["problems"][str(problem.id)] = JudgeStatus.WRONG_ANSWER - user.problems_status = problems_status - user.save(update_fields=["problems_status"]) + user_profile.problems_status = problems_status + user_profile.save(update_fields=["problems_status"]) diff --git a/oj/__init__.py b/oj/__init__.py index e69de29b..1ed42e9c 100644 --- a/oj/__init__.py +++ b/oj/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import, unicode_literals + +# Django starts so that shared_task will use this app. +from .celery import app as celery_app + +__all__ = ['celery_app'] diff --git a/oj/celery.py b/oj/celery.py new file mode 100644 index 00000000..b5edaa38 --- /dev/null +++ b/oj/celery.py @@ -0,0 +1,18 @@ +from __future__ import absolute_import, unicode_literals +import os +from celery import Celery +from django.conf import settings + +# set the default Django settings module for the 'celery' program. +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'oj.settings') + + +app = Celery('oj') + +# Using a string here means the worker will not have to +# pickle the object when using Windows. +app.config_from_object('django.conf:settings') + +# load task modules from all registered Django app configs. +app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) +# app.autodiscover_tasks() diff --git a/oj/db_router.py b/oj/db_router.py deleted file mode 100644 index 823d7843..00000000 --- a/oj/db_router.py +++ /dev/null @@ -1,19 +0,0 @@ -class DBRouter(object): - def db_for_read(self, model, **hints): - if model._meta.app_label == "submission": - return "submission" - return "default" - - def db_for_write(self, model, **hints): - if model._meta.app_label == "submission": - return "submission" - return "default" - - def allow_relation(self, obj1, obj2, **hints): - return True - - def allow_migrate(self, db, app_label, model=None, **hints): - if app_label == "submission": - return db == app_label - else: - return db == "default" diff --git a/oj/local_settings.py b/oj/local_settings.py index f5212ff6..dfad51bc 100644 --- a/oj/local_settings.py +++ b/oj/local_settings.py @@ -34,16 +34,11 @@ CACHES = { } } -REDIS_CACHE = { - "host": "127.0.0.1", - "port": 6379, - "db": 1 -} - +# For celery REDIS_QUEUE = { "host": "127.0.0.1", "port": 6379, - "db": 2 + "db": 4 } DEBUG = True diff --git a/oj/settings.py b/oj/settings.py index c33efe61..c0d6cf71 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -164,8 +164,6 @@ BROKER_URL = 'redis://%s:%s/%s' % (REDIS_QUEUE["host"], str(REDIS_QUEUE["port"]) CELERY_ACCEPT_CONTENT = ["json"] CELERY_TASK_SERIALIZER = "json" -DATABASE_ROUTERS = ['oj.db_router.DBRouter'] - IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/') # 用于限制用户恶意提交大量代码 diff --git a/problem/models.py b/problem/models.py index 61cc8f44..d0708ed8 100644 --- a/problem/models.py +++ b/problem/models.py @@ -64,11 +64,11 @@ class AbstractProblem(models.Model): def add_submission_number(self): self.total_submit_number = models.F("total_submit_number") + 1 - self.save() + self.save(update_fields=['total_submit_number']) def add_ac_number(self): self.total_accepted_number = models.F("total_accepted_number") + 1 - self.save() + self.save(update_fields=['total_accepted_number']) class Problem(AbstractProblem): diff --git a/submission/migrations/0002_auto_20170509_1203.py b/submission/migrations/0002_auto_20170509_1203.py new file mode 100644 index 00000000..7ca58163 --- /dev/null +++ b/submission/migrations/0002_auto_20170509_1203.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.6 on 2017-05-09 12:03 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('submission', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='submission', + name='code', + field=models.TextField(), + ), + ] diff --git a/submission/models.py b/submission/models.py index 9c6d823e..2687004c 100644 --- a/submission/models.py +++ b/submission/models.py @@ -1,7 +1,6 @@ from django.db import models from jsonfield import JSONField -from utils.models import RichTextField from utils.shortcuts import rand_str @@ -25,7 +24,7 @@ class Submission(models.Model): problem_id = models.IntegerField(db_index=True) created_time = models.DateTimeField(auto_now_add=True) user_id = models.IntegerField(db_index=True) - code = RichTextField() + code = models.TextField() result = models.IntegerField(default=JudgeStatus.PENDING) # 判题结果的详细信息 info = JSONField(default={}) diff --git a/submission/tasks.py b/submission/tasks.py index ea9d6c8f..eda9e0f4 100644 --- a/submission/tasks.py +++ b/submission/tasks.py @@ -1,7 +1,8 @@ +from __future__ import absolute_import, unicode_literals from celery import shared_task -from judge.tasks import JudgeDispatcher +from judge.dispatcher import JudgeDispatcher @shared_task -def _judge(submission_obj, problem_obj): - return JudgeDispatcher(submission_obj, problem_obj).judge() +def judge_task(submission_id, problem_id): + JudgeDispatcher(submission_id, problem_id).judge() diff --git a/submission/views/oj.py b/submission/views/oj.py index c34ee5cd..733e0df9 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -4,53 +4,43 @@ from django_redis import get_redis_connection from account.decorators import login_required from account.models import AdminType, User from problem.models import Problem - +from submission.tasks import judge_task from utils.api import APIView, validate_serializer from utils.shortcuts import build_query_string from utils.throttling import TokenBucket, BucketController - from ..models import Submission from ..serializers import CreateSubmissionSerializer -from ..tasks import _judge - - -def _submit_code(response, user, problem_id, language, code): - controller = BucketController(user_id=user.id, - redis_conn=get_redis_connection("Throttling"), - default_capacity=30) - bucket = TokenBucket(fill_rate=10, - capacity=20, - last_capacity=controller.last_capacity, - last_timestamp=controller.last_timestamp) - if bucket.consume(): - controller.last_capacity -= 1 - else: - return response.error("Please wait %d seconds" % int(bucket.expected_time() + 1)) - - try: - problem = Problem.objects.get(id=problem_id) - except Problem.DoesNotExist: - return response.error("Problem not exist") - - submission = Submission.objects.create(user_id=user.id, - language=language, - code=code, - problem_id=problem.id) - - try: - _judge.delay(submission, problem) - except Exception: - return response.error("Failed") - - return response.success({"submission_id": submission.id}) class SubmissionAPI(APIView): @validate_serializer(CreateSubmissionSerializer) + # TODO: login # @login_required def post(self, request): + controller = BucketController(user_id=request.user.id, + redis_conn=get_redis_connection("Throttling"), + default_capacity=30) + bucket = TokenBucket(fill_rate=10, capacity=20, + last_capacity=controller.last_capacity, + last_timestamp=controller.last_timestamp) + if bucket.consume(): + controller.last_capacity -= 1 + else: + return self.error("Please wait %d seconds" % int(bucket.expected_time() + 1)) + data = request.data - return _submit_code(self, request.user, data["problem_id"], data["language"], data["code"]) + try: + problem = Problem.objects.get(id=data['problem_id']) + except Problem.DoesNotExist: + return self.error("Problem not exist") + # TODO: user_id + submission = Submission.objects.create(user_id=1, + language=data['language'], + code=data['code'], + problem_id=problem.id) + judge_task.delay(submission.id, problem.id) + # JudgeDispatcher(submission.id, problem.id).judge() + return self.success({"submission_id": submission.id}) @login_required def get(self, request):