"""Admin-side API views for managing problems and their test cases."""
import hashlib
import json
import os
import shutil
import zipfile
from wsgiref.util import FileWrapper

from django.conf import settings
from django.http import StreamingHttpResponse, HttpResponse

from account.decorators import problem_permission_required, ensure_created_by
from judge.dispatcher import SPJCompiler
from contest.models import Contest, ContestStatus
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str, natural_sort_key
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
                           CreateProblemSerializer, EditProblemSerializer,
                           EditContestProblemSerializer, ProblemAdminSerializer,
                           TestCaseUploadForm, ContestProblemMakePublicSerializer,
                           AddContestProblemSerializer)


class TestCaseAPI(CSRFExemptAPIView):
    """Download (GET) and upload (POST) the test-case archive of a problem."""

    request_parsers = ()

    def filter_name_list(self, name_list, spj):
        """Select the consecutively numbered test-case files from ``name_list``.

        Numbering starts at 1 and stops at the first missing index. For
        special-judge problems (``spj`` truthy) only ``N.in`` is required;
        otherwise both ``N.in`` and ``N.out`` must be present for index N
        to be accepted.

        Returns the accepted file names sorted with ``natural_sort_key``.
        """
        # Build the lookup set once so every membership test is O(1).
        available = set(name_list)
        ret = []
        prefix = 1
        while True:
            in_name = f"{prefix}.in"
            out_name = f"{prefix}.out"
            if in_name not in available:
                break
            if not spj:
                if out_name not in available:
                    break
                ret.extend((in_name, out_name))
            else:
                ret.append(in_name)
            prefix += 1
        return sorted(ret, key=natural_sort_key)

    def get(self, request):
        """Zip the test cases of ``problem_id`` and return them as a download.

        When the ``OJ_ENV`` environment variable equals ``"production"`` the
        response only carries an ``X-Accel-Redirect`` header pointing at the
        archive; otherwise the archive is streamed by Django itself.
        """
        problem_id = request.GET.get("problem_id")
        if not problem_id:
            return self.error("Parameter error, problem_id is required")
        try:
            problem = Problem.objects.get(id=problem_id)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")

        # For contest problems the check is made against the contest,
        # not against the problem itself.
        if problem.contest:
            ensure_created_by(problem.contest, request.user)
        else:
            ensure_created_by(problem, request.user)

        test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if not os.path.isdir(test_case_dir):
            return self.error("Test case does not exists")
        name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
        # The "info" metadata file written by post() is shipped as well.
        name_list.append("info")
        file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
        with zipfile.ZipFile(file_name, "w") as file:
            for test_case in name_list:
                file.write(os.path.join(test_case_dir, test_case), test_case)

        if os.environ.get("OJ_ENV") == "production":
            response = HttpResponse()
            response["X-Accel-Redirect"] = file_name
        else:
            response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")),
                                             content_type="application/octet-stream")

        response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
        response["Content-Length"] = os.path.getsize(file_name)
        return response

    def post(self, request):
        """Accept an uploaded zip, extract its test cases and write ``info``.

        The accepted files are stored in a new directory below
        ``settings.TEST_CASE_DIR`` with ``\\r\\n`` normalised to ``\\n``.
        Responds with the new test-case id, the per-case metadata, a hint
        listing archive members that were ignored, and the spj flag.
        """
        form = TestCaseUploadForm(request.POST, request.FILES)
        if not form.is_valid():
            return self.error("Upload failed")
        spj = form.cleaned_data["spj"] == "true"
        file = form.cleaned_data["file"]

        tmp_file = os.path.join("/tmp", rand_str() + ".zip")
        try:
            with open(tmp_file, "wb") as f:
                for chunk in file:
                    f.write(chunk)
            try:
                zip_file = zipfile.ZipFile(tmp_file)
            except zipfile.BadZipFile:
                return self.error("Bad zip file")

            # The context manager closes the archive handle on every path.
            with zip_file:
                name_list = zip_file.namelist()
                test_case_list = self.filter_name_list(name_list, spj=spj)
                if not test_case_list:
                    return self.error("Empty file")

                test_case_id = rand_str()
                test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
                os.mkdir(test_case_dir)

                size_cache = {}
                md5_cache = {}
                for item in test_case_list:
                    content = zip_file.read(item).replace(b"\r\n", b"\n")
                    size_cache[item] = len(content)
                    if item.endswith(".out"):
                        # The digest is taken over the right-stripped output.
                        md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
                    with open(os.path.join(test_case_dir, item), "wb") as f:
                        f.write(content)
        finally:
            # The temporary upload is never needed after extraction; remove
            # it on success and on every error path so /tmp does not fill up.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

        test_case_info = {"spj": spj, "test_cases": {}}

        hint = None
        diff = set(name_list).difference(set(test_case_list))
        if diff:
            # Sorted so the message is deterministic.
            hint = ", ".join(sorted(diff)) + " are ignored"

        ret = []
        if spj:
            for index, item in enumerate(test_case_list):
                data = {"input_name": item, "input_size": size_cache[item]}
                ret.append(data)
                test_case_info["test_cases"][str(index + 1)] = data
        else:
            # ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
            pairs = zip(test_case_list[::2], test_case_list[1::2])
            for index, (in_name, out_name) in enumerate(pairs):
                data = {"stripped_output_md5": md5_cache[out_name],
                        "input_size": size_cache[in_name],
                        "output_size": size_cache[out_name],
                        "input_name": in_name,
                        "output_name": out_name}
                ret.append(data)
                test_case_info["test_cases"][str(index + 1)] = data

        with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
            f.write(json.dumps(test_case_info, indent=4))
        return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj})


class CompileSPJAPI(APIView):
    """Compile submitted special-judge code and report the outcome."""

    @validate_serializer(CompileSPJSerializer)
    def post(self, request):
        """Run ``SPJCompiler.compile_spj`` and return its error, if any."""
        data = request.data
        spj_version = rand_str(8)
        error = SPJCompiler(data["spj_code"], spj_version,
                            data["spj_language"]).compile_spj()
        if error:
            return self.error(error)
        return self.success()


class ProblemBase(APIView):
    """Shared validation, tagging and deletion logic for problem views."""

    def common_checks(self, request):
        """Validate and normalise ``request.data`` in place.

        Returns an error message string when validation fails and ``None``
        otherwise. On success the data gains ``spj_version`` (spj only),
        ``total_score`` (OI rule type only), ``created_by`` and a
        list-typed ``languages``.
        """
        data = request.data
        if data["spj"]:
            if not data["spj_language"] or not data["spj_code"]:
                return "Invalid spj"
            if not data["spj_compile_ok"]:
                return "SPJ code must be compiled successfully"
            data["spj_version"] = hashlib.md5(
                (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
        else:
            data["spj_language"] = None
            data["spj_code"] = None

        if data["rule_type"] == ProblemRuleType.OI:
            total_score = 0
            for item in data["test_case_score"]:
                if item["score"] <= 0:
                    return "Invalid score"
                total_score += item["score"]
            data["total_score"] = total_score

        data["created_by"] = request.user
        data["languages"] = list(data["languages"])

    def _add_tags(self, problem, tag_names):
        """Attach a tag for each name in ``tag_names``, creating missing ones."""
        for name in tag_names:
            tag, _ = ProblemTag.objects.get_or_create(name=name)
            problem.tags.add(tag)

    @problem_permission_required
    def delete(self, request):
        """Delete a problem that has no contest and no submissions.

        The problem's test-case directory is removed as well.
        """
        problem_id = request.GET.get("id")
        if not problem_id:
            return self.error("Invalid parameter, id is required")
        try:
            # Only problems without a contest are matched here.
            problem = Problem.objects.get(id=problem_id, contest_id__isnull=True)
        except Problem.DoesNotExist:
            return self.error("Problem does not exists")
        ensure_created_by(problem, request.user)
        if Submission.objects.filter(problem=problem).exists():
            return self.error("Can't delete the problem as it has submissions")
        d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
        if os.path.isdir(d):
            shutil.rmtree(d, ignore_errors=True)
        problem.delete()
        return self.success()


class ProblemAPI(ProblemBase):
    """Create, list/retrieve and edit problems that belong to no contest."""

    @problem_permission_required
    @validate_serializer(CreateProblemSerializer)
    def post(self, request):
        """Create a problem; its display id must be unique among such problems."""
        data = request.data
        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists():
            return self.error("Display ID already exists")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")
        problem = Problem.objects.create(**data)
        self._add_tags(problem, tags)
        return self.success(ProblemAdminSerializer(problem).data)

    @problem_permission_required
    def get(self, request):
        """Return one problem (``id`` given) or a paginated, filtered list.

        The list can be narrowed by ``rule_type`` and ``keyword`` (title
        substring) and is limited to the user's own problems unless
        ``user.can_mgmt_all_problem()`` is true.
        """
        problem_id = request.GET.get("id")
        rule_type = request.GET.get("rule_type")
        user = request.user
        if problem_id:
            try:
                problem = Problem.objects.get(id=problem_id)
                ensure_created_by(problem, request.user)
                return self.success(ProblemAdminSerializer(problem).data)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")

        problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
        if rule_type:
            if rule_type not in ProblemRuleType.choices():
                return self.error("Invalid rule_type")
            problems = problems.filter(rule_type=rule_type)

        if not user.can_mgmt_all_problem():
            problems = problems.filter(created_by=user)
        keyword = request.GET.get("keyword")
        if keyword:
            problems = problems.filter(title__contains=keyword)
        return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))

    @problem_permission_required
    @validate_serializer(EditProblemSerializer)
    def put(self, request):
        """Update an existing problem and replace its tag set."""
        data = request.data
        problem_id = data.pop("id")

        try:
            problem = Problem.objects.get(id=problem_id)
            ensure_created_by(problem, request.user)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
            return self.error("Display ID already exists")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")
        data["languages"] = list(data["languages"])

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        # Drop every current tag, then attach the submitted ones.
        problem.tags.remove(*problem.tags.all())
        self._add_tags(problem, tags)
        return self.success()


class ContestProblemAPI(ProblemBase):
    """Create, list/retrieve and edit problems that belong to a contest."""

    @validate_serializer(CreateContestProblemSerializer)
    def post(self, request):
        """Create a problem inside a contest.

        The rule type must match the contest's and the display id must be
        unique within that contest.
        """
        data = request.data
        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, request.user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")

        if Problem.objects.filter(_id=_id, contest=contest).exists():
            return self.error("Duplicate Display id")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        data["contest"] = contest
        tags = data.pop("tags")
        problem = Problem.objects.create(**data)
        self._add_tags(problem, tags)
        return self.success(ProblemAdminSerializer(problem).data)

    @problem_permission_required
    def get(self, request):
        """Return one problem (``id`` given) or the problems of ``contest_id``."""
        problem_id = request.GET.get("id")
        contest_id = request.GET.get("contest_id")
        user = request.user
        if problem_id:
            try:
                problem = Problem.objects.get(id=problem_id)
                ensure_created_by(problem, user)
            except Problem.DoesNotExist:
                return self.error("Problem does not exist")
            return self.success(ProblemAdminSerializer(problem).data)

        if not contest_id:
            return self.error("Contest id is required")
        problems = Problem.objects.filter(contest_id=contest_id).order_by("-create_time")
        # Users for whom is_admin() is true only see problems of contests
        # they created themselves.
        if user.is_admin():
            problems = problems.filter(contest__created_by=user)
        keyword = request.GET.get("keyword")
        if keyword:
            problems = problems.filter(title__contains=keyword)
        return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))

    @validate_serializer(EditContestProblemSerializer)
    @problem_permission_required
    def put(self, request):
        """Update an existing contest problem and replace its tag set."""
        data = request.data
        user = request.user

        try:
            contest = Contest.objects.get(id=data.pop("contest_id"))
            ensure_created_by(contest, user)
        except Contest.DoesNotExist:
            return self.error("Contest does not exist")

        if data["rule_type"] != contest.rule_type:
            return self.error("Invalid rule type")

        problem_id = data.pop("id")

        try:
            problem = Problem.objects.get(id=problem_id)
            ensure_created_by(problem, user)
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        _id = data["_id"]
        if not _id:
            return self.error("Display ID is required")
        if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
            return self.error("Display ID already exists")

        error_info = self.common_checks(request)
        if error_info:
            return self.error(error_info)

        # todo check filename and score info
        tags = data.pop("tags")
        data["languages"] = list(data["languages"])

        for k, v in data.items():
            setattr(problem, k, v)
        problem.save()

        # Drop every current tag, then attach the submitted ones.
        problem.tags.remove(*problem.tags.all())
        self._add_tags(problem, tags)
        return self.success()


class MakeContestProblemPublicAPIView(APIView):
    """Publish a contest problem as a separate problem without a contest."""

    @validate_serializer(ContestProblemMakePublicSerializer)
    @problem_permission_required
    def post(self, request):
        """Mark the contest problem public and save a contest-less copy.

        The copy gets ``display_id`` as its display id, is created
        invisible and starts with zeroed statistics.
        """
        data = request.data
        display_id = data.get("display_id")
        if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists():
            return self.error("Duplicate display ID")

        try:
            problem = Problem.objects.get(id=data["id"])
        except Problem.DoesNotExist:
            return self.error("Problem does not exist")

        if not problem.contest or problem.is_public:
            return self.error("Already be a public problem")
        problem.is_public = True
        problem.save()
        # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
        tags = problem.tags.all()
        # Clearing the primary key makes the next save() insert a new row.
        problem.pk = None
        problem.contest = None
        problem._id = display_id
        problem.visible = False
        problem.submission_number = problem.accepted_number = 0
        problem.statistic_info = {}
        problem.save()
        problem.tags.set(tags)
        return self.success()


class AddContestProblemAPI(APIView):
    """Copy an existing problem into a contest."""

    # NOTE(review): unlike the sibling views, no permission decorator or
    # ensure_created_by() call is visible here -- confirm this is intended.
    @validate_serializer(AddContestProblemSerializer)
    def post(self, request):
        """Save a copy of ``problem_id`` inside ``contest_id``.

        Rejected when the contest has ended or when ``display_id`` is
        already used in that contest. The copy is public, visible and
        starts with zeroed statistics.
        """
        data = request.data
        try:
            contest = Contest.objects.get(id=data["contest_id"])
            problem = Problem.objects.get(id=data["problem_id"])
        except (Contest.DoesNotExist, Problem.DoesNotExist):
            return self.error("Contest or Problem does not exist")

        if contest.status == ContestStatus.CONTEST_ENDED:
            return self.error("Contest has ended")
        if Problem.objects.filter(contest=contest, _id=data["display_id"]).exists():
            return self.error("Duplicate display id in this contest")

        tags = problem.tags.all()
        # Clearing the primary key makes the next save() insert a new row.
        problem.pk = None
        problem.contest = contest
        problem.is_public = True
        problem.visible = True
        problem._id = data["display_id"]
        problem.submission_number = problem.accepted_number = 0
        problem.statistic_info = {}
        problem.save()
        problem.tags.set(tags)
        return self.success()