diff --git a/oj/settings.py b/oj/settings.py index c8a07ec2..6e01c103 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -50,6 +50,7 @@ INSTALLED_APPS = ( 'announcement', 'utils', 'group', + 'problem', 'rest_framework', 'rest_framework_swagger', diff --git a/oj/urls.py b/oj/urls.py index b2eecb9e..9e4f08d5 100644 --- a/oj/urls.py +++ b/oj/urls.py @@ -9,6 +9,7 @@ from account.views import (UserLoginAPIView, UsernameCheckAPIView, UserRegisterA from announcement.views import AnnouncementAPIView, AnnouncementAdminAPIView from group.views import GroupAdminAPIView from admin.views import AdminTemplateView +from problem.views import ProblemAdminAPIView urlpatterns = [ url("^$", TemplateView.as_view(template_name="oj/index.html"), name="index_page"), @@ -33,4 +34,6 @@ urlpatterns = [ url(r'^problems/$', TemplateView.as_view(template_name="oj/problem/problem_list.html"), name="problem_list_page"), url(r'^admin/template/(?P\w+)/(?P\w+).html', AdminTemplateView.as_view(), name="admin_template"), url(r'^api/admin/group/$', GroupAdminAPIView.as_view(), name="group_admin_api"), + + url(r'api/admin/problem/$', ProblemAdminAPIView.as_view(), name="problem_admin_api"), ] diff --git a/problem/models.py b/problem/models.py index e8255cde..01320722 100644 --- a/problem/models.py +++ b/problem/models.py @@ -18,21 +18,19 @@ class AbstractProblem(models.Model): # 问题描述 HTML 格式 description = models.TextField() # 样例输入 可能会存储 json 格式的数据 - sample_input = models.TextField(blank=True) - # 样例输出 同上 - sample_output = models.TextField(blank=True) + sample = models.TextField(blank=True) # 测试用例id 这个id 可以用来拼接得到测试用例的文件存储位置 test_case_id = models.CharField(max_length=40) # 提示 - hint = models.TextField(blank=True) + hint = models.TextField(blank=True, null=True) # 创建时间 - create_time = models.DateTimeField(auth_now_add=True) + create_time = models.DateTimeField(auto_now_add=True) # 最后更新时间 last_update_time = models.DateTimeField(auto_now=True) # 这个题是谁创建的 created_by = models.ForeignKey(User) # 来源 - source = models.CharField(max_length=30, blank=True) + source = models.CharField(max_length=30, blank=True, null=True) # 时间限制 单位是毫秒 time_limit = models.IntegerField() # 内存限制 单位是MB diff --git a/problem/serizalizers.py b/problem/serizalizers.py new file mode 100644 index 00000000..441c9c55 --- /dev/null +++ b/problem/serizalizers.py @@ -0,0 +1,62 @@ +# coding=utf-8 +import json + +from rest_framework import serializers + +from account.models import User +from .models import Problem + + +class ProblemSampleSerializer(serializers.ListField): + input = serializers.CharField(max_length=3000) + output = serializers.CharField(max_length=3000) + + +class JSONField(serializers.Field): + def to_representation(self, value): + print value, type(value) + return json.loads(value) + + +class CreateProblemSerializer(serializers.Serializer): + title = serializers.CharField(max_length=50) + description = serializers.CharField(max_length=10000) + # [{"input": "1 1", "output": "2"}] + sample = ProblemSampleSerializer() + test_case_id = serializers.CharField(max_length=40) + source = serializers.CharField(max_length=30, required=False, default=None) + time_limit = serializers.IntegerField() + memory_limit = serializers.IntegerField() + difficulty = serializers.IntegerField() + tags = serializers.ListField(child=serializers.IntegerField()) + hint = serializers.CharField(max_length=3000, required=False, default=None) + + +class ProblemSerializer(serializers.ModelSerializer): + sample = JSONField() + + class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ["username"] + + created_by = UserSerializer() + + class Meta: + model = Problem + + +class EditProblemSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=50) + description = serializers.CharField(max_length=10000) + test_case_id = serializers.CharField(max_length=40) + source = serializers.CharField(max_length=30) + time_limit = serializers.IntegerField() + memory_limit = serializers.IntegerField() + difficulty = serializers.IntegerField() + tags = serializers.ListField(child=serializers.IntegerField()) + sample = ProblemSampleSerializer() + hint = serializers.CharField(max_length=10000) + visible = serializers.BooleanField() + diff --git a/problem/tests.py b/problem/tests.py index b66a4940..9fb04063 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -1,6 +1,83 @@ # coding=utf-8 from django.test import TestCase +from django.core.urlresolvers import reverse +from rest_framework.test import APITestCase, APIClient + +from account.models import User, SUPER_ADMIN +from problem.models import Problem, ProblemTag + class ProblemPageTest(TestCase): pass + + +class ProblemAdminTest(APITestCase): + def setUp(self): + self.client = APIClient() + self.url = reverse("problem_admin_api") + user = User.objects.create(username="test", admin_type=SUPER_ADMIN) + user.set_password("testaa") + user.save() + + # 以下是发布题目的测试 + def test_invalid_format(self): + self.client.login(username="test", password="testaa") + data = {"title": "test1"} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + def test_success_problem(self): + self.client.login(username="test", password="testaa") + ProblemTag.objects.create(name="tag1", description="destag1") + data = {"title": "title1", "description": "des1", "test_case_id": "1", "source": "source1", + "sample": [{"input": "1 1", "output": "2"}], "time_limit": "100", "memory_limit": "1000", + "difficulty": "1", "hint": "hint1", "tags": [1]} + response = self.client.post(self.url, data=data) + self.assertEqual(response.data["code"], 0) + + # 以下是编辑题目的测试 + def test_put_invalid_data(self): + self.client.login(username="test", password="testaa") + data = {"title": "test0"} + response = self.client.put(self.url, data=data) + self.assertEqual(response.data["code"], 1) + + def test_problem_does_not_exist(self): + self.client.login(username="test", password="testaa") + ProblemTag.objects.create(name="tag1", description="destag1") + tags = ProblemTag.objects.filter(id__in=[1]) + problem = Problem.objects.create(title="title1", description="des1", + test_case_id="1", source="source1", + sample=[{"input": "1 1", "output": "2"}], + time_limit=100, memory_limit=1000, + difficulty=1, hint="hint1", + created_by=User.objects.get(username="test")) + problem.tags.add(*tags) + data = {"id": 2, "title": "title1", "description": "des1", "test_case_id": "1", "source": "source1", + "sample": [{"input": "1 1", "output": "2"}], "time_limit": "100", "memory_limit": "1000", + "difficulty": "1", "hint": "hint1", "tags": [1]} + response = self.client.put(self.url, data=data) + self.assertEqual(response.data, {"code": 1, "data": u"该题目不存在!"}) + + def test_success_edit_problem(self): + self.client.login(username="test", password="testaa") + self.client.login(username="test", password="testaa") + ProblemTag.objects.create(name="tag1", description="destag1") + ProblemTag.objects.create(name="tag2", description="destag2") + tags = ProblemTag.objects.filter(id__in=[1]) + problem0 = Problem.objects.create(title="title1", description="des1", + test_case_id="1", source="source1", + sample=[{"input": "1 1", "output": "2"}], + time_limit=100, memory_limit=1000, + difficulty=1, hint="hint1", + created_by=User.objects.get(username="test")) + problem0.tags.add(*tags) + data = {"id": 1, "title": "title1", "description": "des1", "test_case_id": "1", "source": "source1", + "sample": [{"input": "1 1", "output": "2"}], "time_limit": "100", "memory_limit": "1000", + "difficulty": "1", "hint": "hint1", "visible": True, "tags": [1, 2]} + problem = Problem.objects.get(id=data["id"]) + problem.tags.remove(*problem.tags.all()) + problem.tags.add(*ProblemTag.objects.filter(id__in=data["tags"])) + response = self.client.put(self.url, data=data) + self.assertEqual(response.data["code"], 0) diff --git a/problem/views.py b/problem/views.py index 58fb9e6d..cf45d938 100644 --- a/problem/views.py +++ b/problem/views.py @@ -1,7 +1,99 @@ # coding=utf-8 +import json from django.shortcuts import render +from rest_framework.views import APIView + +from django.db.models import Q + +from serizalizers import CreateProblemSerializer, EditProblemSerializer, ProblemSerializer +from .models import Problem, ProblemTag +from utils.shortcuts import serializer_invalid_response, error_response, success_response, paginate + def problem_page(request, problem_id): # todo return render(request, "oj/problem/problem.html") + + +class ProblemAdminAPIView(APIView): + def post(self, request): + """ + 题目发布json api接口 + --- + request_serializer: CreateProblemSerializer + response_serializer: ProblemSerializer + """ + serializer = CreateProblemSerializer(data=request.data) + if serializer.is_valid(): + data = serializer.data + problem = Problem.objects.create(title=data["title"], + description=data["description"], + test_case_id=data["test_case_id"], + source=data["source"], + sample=json.dumps(data["sample"]), + time_limit=data["time_limit"], + memory_limit=data["memory_limit"], + difficulty=data["difficulty"], + created_by=request.user, + hint=data["hint"]) + + tags = ProblemTag.objects.filter(id__in=data["tags"]) + problem.tags.add(*tags) + return success_response(ProblemSerializer(problem).data) + else: + return serializer_invalid_response(serializer) + + def put(self, request): + """ + 题目编辑json api接口 + --- + request_serializer: EditProblemSerializer + response_serializer: ProblemSerializer + """ + serializer = EditProblemSerializer(data=request.data) + if serializer.is_valid(): + data = serializer.data + print request.data + try: + problem = Problem.objects.get(id=data["id"]) + except Problem.DoesNotExist: + return error_response(u"该题目不存在!") + + problem.title = data["title"] + problem.description = data["description"] + problem.test_case_id = data["test_case_id"] + problem.source = data["source"] + problem.time_limit = data["time_limit"] + problem.memory_limit = data["memory_limit"] + problem.difficulty = data["difficulty"] + problem.sample = json.dumps(data["sample"]) + problem.hint = data["hint"] + problem.visible = data["visible"] + + # 删除原有的标签的对应关系 + problem.tags.remove(*problem.tags.all()) + # 重新添加所有的标签 + problem.tags.add(*ProblemTag.objects.filter(id__in=data["tags"])) + problem.save() + return success_response(ProblemSerializer(problem).data) + else: + return serializer_invalid_response(serializer) + + +class ProblemAPIView(APIView): + def get(self, request): + """ + 题目分页json api接口 + --- + response_serializer: ProblemSerializer + """ + problem = Problem.objects.all().order_by("-last_update_time") + visible = request.GET.get("visible", None) + if visible: + problem = problem.filter(visible=(visible == "true")) + keyword = request.GET.get("keyword", None) + if keyword: + problem = problem.filter(Q(difficulty__contains=keyword)) + + return paginate(request, problem, ProblemSerializer)