import hashlib import json import os import shutil import zipfile from wsgiref.util import FileWrapper from django.conf import settings from django.http import StreamingHttpResponse from account.decorators import problem_permission_required from judge.dispatcher import SPJCompiler from contest.models import Contest from submission.models import Submission from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.shortcuts import rand_str, natural_sort_key from ..models import Problem, ProblemRuleType, ProblemTag from ..serializers import (CreateContestProblemSerializer, ContestProblemAdminSerializer, CompileSPJSerializer, CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer, ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer) class TestCaseAPI(CSRFExemptAPIView): request_parsers = () def filter_name_list(self, name_list, spj): ret = [] prefix = 1 if spj: while True: in_name = str(prefix) + ".in" if in_name in name_list: ret.append(in_name) prefix += 1 continue else: return sorted(ret, key=natural_sort_key) else: while True: in_name = str(prefix) + ".in" out_name = str(prefix) + ".out" if in_name in name_list and out_name in name_list: ret.append(in_name) ret.append(out_name) prefix += 1 continue else: return sorted(ret, key=natural_sort_key) @problem_permission_required def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: return self.error("Parameter error, problem_id is required") try: problem = Problem.objects.get(id=problem_id) except Problem.DoesNotExist: return self.error("Problem does not exists") test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if not os.path.isdir(test_case_dir): return self.error("Test case does not exists") name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj) name_list.append("info") file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip") with zipfile.ZipFile(file_name, "w") as file: for test_case in name_list: file.write(f"{test_case_dir}/{test_case}", test_case) response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")), content_type="application/zip") response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip" return response @problem_permission_required def post(self, request): form = TestCaseUploadForm(request.POST, request.FILES) if form.is_valid(): spj = form.cleaned_data["spj"] == "true" file = form.cleaned_data["file"] else: return self.error("Upload failed") tmp_file = os.path.join("/tmp", rand_str() + ".zip") with open(tmp_file, "wb") as f: for chunk in file: f.write(chunk) try: zip_file = zipfile.ZipFile(tmp_file) except zipfile.BadZipFile: return self.error("Bad zip file") name_list = zip_file.namelist() test_case_list = self.filter_name_list(name_list, spj=spj) if not test_case_list: return self.error("Empty file") test_case_id = rand_str() test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id) os.mkdir(test_case_dir) size_cache = {} md5_cache = {} for item in test_case_list: with open(os.path.join(test_case_dir, item), "wb") as f: content = zip_file.read(item).replace(b"\r\n", b"\n") size_cache[item] = len(content) if item.endswith(".out"): md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest() f.write(content) test_case_info = {"spj": spj, "test_cases": {}} hint = None diff = set(name_list).difference(set(test_case_list)) if diff: hint = ", ".join(diff) + " are ignored" ret = [] if spj: for index, item in enumerate(test_case_list): data = {"input_name": item, "input_size": size_cache[item]} ret.append(data) test_case_info["test_cases"][str(index + 1)] = data else: # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")] test_case_list = zip(*[test_case_list[i::2] for i in range(2)]) for index, item in enumerate(test_case_list): data = {"stripped_output_md5": md5_cache[item[1]], "input_size": size_cache[item[0]], "output_size": size_cache[item[1]], "input_name": item[0], "output_name": item[1]} ret.append(data) test_case_info["test_cases"][str(index + 1)] = data with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f: f.write(json.dumps(test_case_info, indent=4)) return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj}) class CompileSPJAPI(APIView): @validate_serializer(CompileSPJSerializer) @problem_permission_required def post(self, request): data = request.data spj_version = rand_str(8) error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj() if error: return self.error(error) else: return self.success() class ProblemBase(APIView): def common_checks(self, request): data = request.data if data["spj"]: if not data["spj_language"] or not data["spj_code"]: return "Invalid spj" if not data["spj_compile_ok"]: return "SPJ code must be compiled successfully" data["spj_version"] = hashlib.md5( (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest() else: data["spj_language"] = None data["spj_code"] = None if data["rule_type"] == ProblemRuleType.OI: total_score = 0 for item in data["test_case_score"]: if item["score"] <= 0: return "Invalid score" else: total_score += item["score"] data["total_score"] = total_score data["created_by"] = request.user data["languages"] = list(data["languages"]) @problem_permission_required def delete(self, request): id = request.GET.get("id") if not id: return self.error("Invalid parameter, id is requred") try: problem = Problem.objects.get(id=id) except Problem.DoesNotExist: return self.error("Problem does not exists") if Submission.objects.filter(problem=problem).exists(): return self.error("Can't delete the problem as it has submissions") d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if os.path.isdir(d): shutil.rmtree(d, ignore_errors=True) problem.delete() return self.success() class ProblemAPI(ProblemBase): @validate_serializer(CreateProblemSerializer) @problem_permission_required def post(self, request): data = request.data _id = data["_id"] if not _id: return self.error("Display ID is required") if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists(): return self.error("Display ID already exists") error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info tags = data.pop("tags") problem = Problem.objects.create(**data) for item in tags: try: tag = ProblemTag.objects.get(name=item) except ProblemTag.DoesNotExist: tag = ProblemTag.objects.create(name=item) problem.tags.add(tag) return self.success(ProblemAdminSerializer(problem).data) @problem_permission_required def get(self, request): problem_id = request.GET.get("id") user = request.user if problem_id: try: problem = Problem.objects.get(id=problem_id) if not user.can_mgmt_all_problem() and problem.created_by != user: return self.error("Problem does not exist") return self.success(ProblemAdminSerializer(problem).data) except Problem.DoesNotExist: return self.error("Problem does not exist") problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time") if not user.can_mgmt_all_problem(): problems = problems.filter(created_by=user) keyword = request.GET.get("keyword") if keyword: problems = problems.filter(title__contains=keyword) return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) @validate_serializer(EditProblemSerializer) @problem_permission_required def put(self, request): data = request.data problem_id = data.pop("id") user = request.user try: problem = Problem.objects.get(id=problem_id) if not user.can_mgmt_all_problem() and problem.created_by != user: return self.error("Problem does not exist") except Problem.DoesNotExist: return self.error("Problem does not exist") _id = data["_id"] if not _id: return self.error("Display ID is required") if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists(): return self.error("Display ID already exists") error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info tags = data.pop("tags") data["languages"] = list(data["languages"]) for k, v in data.items(): setattr(problem, k, v) problem.save() problem.tags.remove(*problem.tags.all()) for tag in tags: try: tag = ProblemTag.objects.get(name=tag) except ProblemTag.DoesNotExist: tag = ProblemTag.objects.create(name=tag) problem.tags.add(tag) return self.success() class ContestProblemAPI(ProblemBase): @validate_serializer(CreateContestProblemSerializer) @problem_permission_required def post(self, request): data = request.data try: contest = Contest.objects.get(id=data.pop("contest_id")) if request.user.is_admin() and contest.created_by != request.user: return self.error("Contest does not exist") except Contest.DoesNotExist: return self.error("Contest does not exist") if data["rule_type"] != contest.rule_type: return self.error("Invalid rule type") _id = data["_id"] if not _id: return self.error("Display ID is required") if Problem.objects.filter(_id=_id, contest=contest).exists(): return self.error("Duplicate Display id") error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info data["contest"] = contest tags = data.pop("tags") problem = Problem.objects.create(**data) for item in tags: try: tag = ProblemTag.objects.get(name=item) except ProblemTag.DoesNotExist: tag = ProblemTag.objects.create(name=item) problem.tags.add(tag) return self.success(ContestProblemAdminSerializer(problem).data) @problem_permission_required def get(self, request): problem_id = request.GET.get("id") contest_id = request.GET.get("contest_id") user = request.user if problem_id: try: problem = Problem.objects.get(id=problem_id) if user.is_admin() and problem.contest.created_by != user: return self.error("Problem does not exist") except Problem.DoesNotExist: return self.error("Problem does not exist") return self.success(ProblemAdminSerializer(problem).data) if not contest_id: return self.error("Contest id is required") problems = Problem.objects.filter(contest_id=contest_id).order_by("-create_time") if user.is_admin(): problems = problems.filter(contest__created_by=user) keyword = request.GET.get("keyword") if keyword: problems = problems.filter(title__contains=keyword) return self.success(self.paginate_data(request, problems, ContestProblemAdminSerializer)) @validate_serializer(EditContestProblemSerializer) @problem_permission_required def put(self, request): data = request.data try: contest = Contest.objects.get(id=data.pop("contest_id")) if request.user.is_admin() and contest.created_by != request.user: return self.error("Contest does not exist") except Contest.DoesNotExist: return self.error("Contest does not exist") if data["rule_type"] != contest.rule_type: return self.error("Invalid rule type") problem_id = data.pop("id") user = request.user try: problem = Problem.objects.get(id=problem_id) if not user.can_mgmt_all_problem() and problem.created_by != user: return self.error("Problem does not exist") except Problem.DoesNotExist: return self.error("Problem does not exist") _id = data["_id"] if not _id: return self.error("Display ID is required") if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists(): return self.error("Display ID already exists") error_info = self.common_checks(request) if error_info: return self.error(error_info) # todo check filename and score info tags = data.pop("tags") data["languages"] = list(data["languages"]) for k, v in data.items(): setattr(problem, k, v) problem.save() problem.tags.remove(*problem.tags.all()) for tag in tags: try: tag = ProblemTag.objects.get(name=tag) except ProblemTag.DoesNotExist: tag = ProblemTag.objects.create(name=tag) problem.tags.add(tag) return self.success() class MakeContestProblemPublicAPIView(APIView): @validate_serializer(ContestProblemMakePublicSerializer) @problem_permission_required def post(self, request): data = request.data display_id = data.get("display_id") if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists(): return self.error("Duplicate display ID") try: problem = Problem.objects.get(id=data["id"]) except Problem.DoesNotExist: return self.error("Problem does not exist") if not problem.contest or problem.is_public: return self.error("Alreay be a public problem") problem.is_public = True problem.save() # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances tags = problem.tags.all() problem.pk = None problem.contest = None problem._id = display_id problem.submission_number = problem.accepted_number = 0 problem.statistic_info = {} problem.save() problem.tags.set(tags) return self.success()