Merge pull request #117 from QingdaoU/import_export_problem

add import and export problem
This commit is contained in:
李扬 2018-04-18 00:11:32 +08:00 committed by GitHub
commit cb5b9692b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 483 additions and 119 deletions

View File

@ -1,26 +1,33 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import base64 import base64
import copy import copy
import random import random
import string import string
import hashlib
import json
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
class FPSParser(object): class FPSParser(object):
def __init__(self, fps_path): def __init__(self, fps_path=None, string_data=None):
self.fps_path = fps_path if fps_path:
self._etree = ET.parse(fps_path).getroot()
@property elif string_data:
def _root(self): self._ertree = ET.fromstring(string_data).getroot()
root = ET.ElementTree(file=self.fps_path).getroot() else:
version = root.attrib.get("version", "No Version") raise ValueError("You must tell me the file path or directly give me the data for the file")
version = self._etree.attrib.get("version", "No Version")
if version not in ["1.1", "1.2"]: if version not in ["1.1", "1.2"]:
raise ValueError("Unsupported version '" + version + "'") raise ValueError("Unsupported version '" + version + "'")
return root
@property
def etree(self):
return self._etree
def parse(self): def parse(self):
ret = [] ret = []
for node in self._root: for node in self._etree:
if node.tag == "item": if node.tag == "item":
ret.append(self._parse_one_problem(node)) ret.append(self._parse_one_problem(node))
return ret return ret
@ -112,20 +119,50 @@ class FPSHelper(object):
_problem[item] = _problem[item].replace(img["src"], os.path.join(base_url, file_name)) _problem[item] = _problem[item].replace(img["src"], os.path.join(base_url, file_name))
return _problem return _problem
def save_test_case(self, problem, base_dir, input_preprocessor=None, output_preprocessor=None): # {
# "spj": false,
# "test_cases": {
# "1": {
# "stripped_output_md5": "84f244e41d3c8fd4bdb43ed0e1f7a067",
# "input_size": 12,
# "output_size": 7,
# "input_name": "1.in",
# "output_name": "1.out"
# }
# }
# }
def save_test_case(self, problem, base_dir):
spj = problem.get("spj", {})
test_cases = {}
for index, item in enumerate(problem["test_cases"]): for index, item in enumerate(problem["test_cases"]):
input_content = item.get("input")
output_content = item.get("output")
if input_content:
with open(os.path.join(base_dir, str(index + 1) + ".in"), "w", encoding="utf-8") as f: with open(os.path.join(base_dir, str(index + 1) + ".in"), "w", encoding="utf-8") as f:
if input_preprocessor:
input_content = input_preprocessor(item["input"])
else:
input_content = item["input"]
f.write(input_content) f.write(input_content)
if output_content:
with open(os.path.join(base_dir, str(index + 1) + ".out"), "w", encoding="utf-8") as f: with open(os.path.join(base_dir, str(index + 1) + ".out"), "w", encoding="utf-8") as f:
if output_preprocessor:
output_content = output_preprocessor(item["output"])
else:
output_content = item["output"]
f.write(output_content) f.write(output_content)
if spj:
one_info = {
"input_size": len(input_content),
"input_name": f"{index}.in"
}
else:
one_info = {
"input_size": len(input_content),
"input_name": f"{index}.in",
"output_size": len(output_content),
"output_name": f"{index}.out",
"stripped_output_md5": hashlib.md5(output_content.rstrip()).hexdigest()
}
test_cases[index] = one_info
info = {
"spj": True if spj else False,
"test_cases": test_cases
}
with open(os.path.join(base_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(info, indent=4))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,7 +1,9 @@
from django import forms from django import forms
from options.options import SysOptions
from judge.languages import language_names, spj_language_names from judge.languages import language_names, spj_language_names
from utils.api import UsernameSerializer, serializers from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty
from .models import Problem, ProblemRuleType, ProblemTag from .models import Problem, ProblemRuleType, ProblemTag
from .utils import parse_problem_template from .utils import parse_problem_template
@ -27,12 +29,6 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer):
pass pass
class Difficulty(object):
LOW = "Low"
MID = "Mid"
HIGH = "High"
class CreateOrEditProblemSerializer(serializers.Serializer): class CreateOrEditProblemSerializer(serializers.Serializer):
_id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) _id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True)
title = serializers.CharField(max_length=128) title = serializers.CharField(max_length=128)
@ -41,7 +37,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
output_description = serializers.CharField() output_description = serializers.CharField()
samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False) samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False)
test_case_id = serializers.CharField(max_length=32) test_case_id = serializers.CharField(max_length=32)
test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=False) test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=True)
time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60) time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60)
memory_limit = serializers.IntegerField(min_value=1, max_value=1024) memory_limit = serializers.IntegerField(min_value=1, max_value=1024)
languages = serializers.MultipleChoiceField(choices=language_names) languages = serializers.MultipleChoiceField(choices=language_names)
@ -52,7 +48,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
spj_code = serializers.CharField(allow_blank=True, allow_null=True) spj_code = serializers.CharField(allow_blank=True, allow_null=True)
spj_compile_ok = serializers.BooleanField(default=False) spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField() visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=[Difficulty.LOW, Difficulty.MID, Difficulty.HIGH]) difficulty = serializers.ChoiceField(choices=Difficulty.choices())
tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False)
hint = serializers.CharField(allow_blank=True, allow_null=True) hint = serializers.CharField(allow_blank=True, allow_null=True)
source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True)
@ -128,41 +124,42 @@ class ContestProblemMakePublicSerializer(serializers.Serializer):
class ExportProblemSerializer(serializers.ModelSerializer): class ExportProblemSerializer(serializers.ModelSerializer):
display_id = serializers.SerializerMethodField()
description = serializers.SerializerMethodField() description = serializers.SerializerMethodField()
input_description = serializers.SerializerMethodField() input_description = serializers.SerializerMethodField()
output_description = serializers.SerializerMethodField() output_description = serializers.SerializerMethodField()
test_case_score = serializers.SerializerMethodField() test_case_score = serializers.SerializerMethodField()
hint = serializers.SerializerMethodField() hint = serializers.SerializerMethodField()
time_limit = serializers.SerializerMethodField()
memory_limit = serializers.SerializerMethodField()
spj = serializers.SerializerMethodField() spj = serializers.SerializerMethodField()
template = serializers.SerializerMethodField() template = serializers.SerializerMethodField()
source = serializers.SerializerMethodField()
tags = serializers.SlugRelatedField(many=True, slug_field="name", read_only=True)
def get_display_id(self, obj):
return obj._id
def _html_format_value(self, value):
return {"format": "html", "value": value}
def get_description(self, obj): def get_description(self, obj):
return {"format": "html", "value": obj.description} return self._html_format_value(obj.description)
def get_input_description(self, obj): def get_input_description(self, obj):
return {"format": "html", "value": obj.input_description} return self._html_format_value(obj.input_description)
def get_output_description(self, obj): def get_output_description(self, obj):
return {"format": "html", "value": obj.output_description} return self._html_format_value(obj.output_description)
def get_hint(self, obj): def get_hint(self, obj):
return {"format": "html", "value": obj.hint} return self._html_format_value(obj.hint)
def get_test_case_score(self, obj): def get_test_case_score(self, obj):
return obj.test_case_score if obj.rule_type == ProblemRuleType.OI else [] return [{"score": item["score"], "input_name": item["input_name"]}
for item in obj.test_case_score] if obj.rule_type == ProblemRuleType.OI else None
def get_time_limit(self, obj):
return {"unit": "ms", "value": obj.time_limit}
def get_memory_limit(self, obj):
return {"unit": "MB", "value": obj.memory_limit}
def get_spj(self, obj): def get_spj(self, obj):
return {"enabled": obj.spj, return {"code": obj.spj_code,
"code": obj.spj_code if obj.spj else None, "language": obj.spj_language} if obj.spj else None
"language": obj.spj_language if obj.spj else None}
def get_template(self, obj): def get_template(self, obj):
ret = {} ret = {}
@ -170,9 +167,12 @@ class ExportProblemSerializer(serializers.ModelSerializer):
ret[k] = parse_problem_template(v) ret[k] = parse_problem_template(v)
return ret return ret
def get_source(self, obj):
return obj.source or f"{SysOptions.website_name} {SysOptions.website_base_url}"
class Meta: class Meta:
model = Problem model = Problem
fields = ("_id", "title", "description", fields = ("display_id", "title", "description", "tags",
"input_description", "output_description", "input_description", "output_description",
"test_case_score", "hint", "time_limit", "memory_limit", "samples", "test_case_score", "hint", "time_limit", "memory_limit", "samples",
"template", "spj", "rule_type", "source", "template") "template", "spj", "rule_type", "source", "template")
@ -182,3 +182,76 @@ class AddContestProblemSerializer(serializers.Serializer):
contest_id = serializers.IntegerField() contest_id = serializers.IntegerField()
problem_id = serializers.IntegerField() problem_id = serializers.IntegerField()
display_id = serializers.CharField() display_id = serializers.CharField()
class ExportProblemRequestSerialzier(serializers.Serializer):
problem_id = serializers.ListField(child=serializers.IntegerField(), allow_empty=False)
class UploadProblemForm(forms.Form):
file = forms.FileField()
class FormatValueSerializer(serializers.Serializer):
format = serializers.ChoiceField(choices=["html", "markdown"])
value = serializers.CharField(allow_blank=True)
class TestCaseScoreSerializer(serializers.Serializer):
score = serializers.IntegerField(min_value=1)
input_name = serializers.CharField(max_length=32)
class TemplateSerializer(serializers.Serializer):
prepend = serializers.CharField()
template = serializers.CharField()
append = serializers.CharField()
class SPJSerializer(serializers.Serializer):
code = serializers.CharField()
language = serializers.ChoiceField(choices=spj_language_names)
class AnswerSerializer(serializers.Serializer):
code = serializers.CharField()
language = serializers.ChoiceField(choices=language_names)
class ImportProblemSerializer(serializers.Serializer):
display_id = serializers.CharField(max_length=128)
title = serializers.CharField(max_length=128)
description = FormatValueSerializer()
input_description = FormatValueSerializer()
output_description = FormatValueSerializer()
hint = FormatValueSerializer()
test_case_score = serializers.ListField(child=TestCaseScoreSerializer(), allow_null=True)
time_limit = serializers.IntegerField(min_value=1, max_value=60000)
memory_limit = serializers.IntegerField(min_value=1, max_value=10240)
samples = serializers.ListField(child=CreateSampleSerializer())
template = serializers.DictField(child=TemplateSerializer())
spj = SPJSerializer(allow_null=True)
rule_type = serializers.ChoiceField(choices=ProblemRuleType.choices())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
answers = serializers.ListField(child=AnswerSerializer())
tags = serializers.ListField(child=serializers.CharField())
class FPSProblemSerializer(serializers.Serializer):
class UnitSerializer(serializers.Serializer):
unit = serializers.ChoiceField(choices=["MB", "s", "ms"])
value = serializers.IntegerField(min_value=1, max_value=60000)
title = serializers.CharField(max_length=128)
description = serializers.CharField()
input = serializers.CharField()
output = serializers.CharField()
hint = serializers.CharField(allow_blank=True, allow_null=True)
time_limit = UnitSerializer()
memory_limit = UnitSerializer()
samples = serializers.ListField(child=CreateSampleSerializer())
source = serializers.CharField(max_length=200, allow_blank=True, allow_null=True)
spj = SPJSerializer(allow_null=True)
template = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True)
append = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True)
prepend = serializers.ListField(child=serializers.DictField(), allow_empty=True, allow_null=True)

View File

@ -1,7 +1,8 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView from ..views.admin import (ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView,
from ..views.admin import CompileSPJAPI, AddContestProblemAPI CompileSPJAPI, AddContestProblemAPI, ExportProblemAPI, ImportProblemAPI,
FPSProblemImport)
urlpatterns = [ urlpatterns = [
url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"), url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"),
@ -10,4 +11,7 @@ urlpatterns = [
url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"), url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"),
url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"), url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"),
url(r"^contest/add_problem_from_public/?$", AddContestProblemAPI.as_view(), name="add_contest_problem_from_public_api"), url(r"^contest/add_problem_from_public/?$", AddContestProblemAPI.as_view(), name="add_contest_problem_from_public_api"),
url(r"^export_problem/?$", ExportProblemAPI.as_view(), name="export_problem_api"),
url(r"^import_problem/?$", ImportProblemAPI.as_view(), name="import_problem_api"),
url(r"^import_fps/?$", FPSProblemImport.as_view(), name="fps_problem_api"),
] ]

View File

@ -1,5 +1,17 @@
import re import re
TEMPLATE_BASE = """//PREPEND BEGIN
{}
//PREPEND END
//TEMPLATE BEGIN
{}
//TEMPLATE END
//APPEND BEGIN
{}
//APPEND END"""
def parse_problem_template(template_str): def parse_problem_template(template_str):
prepend = re.findall("//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) prepend = re.findall("//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str)
@ -8,3 +20,7 @@ def parse_problem_template(template_str):
return {"prepend": prepend[0] if prepend else "", return {"prepend": prepend[0] if prepend else "",
"template": template[0] if template else "", "template": template[0] if template else "",
"append": append[0] if append else ""} "append": append[0] if append else ""}
def build_problem_template(prepend, template, append):
return TEMPLATE_BASE.format(prepend, template, append)

View File

@ -3,35 +3,91 @@ import json
import os import os
import shutil import shutil
import zipfile import zipfile
import tempfile
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
from django.conf import settings from django.conf import settings
from django.http import StreamingHttpResponse, HttpResponse from django.http import StreamingHttpResponse, HttpResponse, FileResponse
from django.db import transaction
from account.decorators import problem_permission_required, ensure_created_by from account.decorators import problem_permission_required, ensure_created_by
from judge.dispatcher import SPJCompiler from judge.dispatcher import SPJCompiler
from judge.languages import language_names
from contest.models import Contest, ContestStatus from contest.models import Contest, ContestStatus
from submission.models import Submission from submission.models import Submission, JudgeStatus
from utils.api import APIView, CSRFExemptAPIView, validate_serializer from fps.parser import FPSHelper, FPSParser
from utils.api import APIView, CSRFExemptAPIView, validate_serializer, APIError
from utils.shortcuts import rand_str, natural_sort_key from utils.shortcuts import rand_str, natural_sort_key
from utils.tasks import delete_files
from utils.constants import Difficulty
from ..utils import TEMPLATE_BASE, build_problem_template
from ..models import Problem, ProblemRuleType, ProblemTag from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer, from ..serializers import (CreateContestProblemSerializer, CompileSPJSerializer,
CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer, CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer, ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer,
AddContestProblemSerializer) AddContestProblemSerializer, ExportProblemSerializer,
ExportProblemRequestSerialzier, UploadProblemForm, ImportProblemSerializer,
FPSProblemSerializer)
class TestCaseAPI(CSRFExemptAPIView): class TestCaseZipProcessor(object):
request_parsers = () def process_zip(self, uploaded_zip_file, spj, dir=""):
try:
zip_file = zipfile.ZipFile(uploaded_zip_file, "r")
except zipfile.BadZipFile:
raise APIError("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj, dir=dir)
if not test_case_list:
raise APIError("Empty file")
def filter_name_list(self, name_list, spj): test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
size_cache = {}
md5_cache = {}
for item in test_case_list:
with open(os.path.join(test_case_dir, item), "wb") as f:
content = zip_file.read(f"{dir}{item}").replace(b"\r\n", b"\n")
size_cache[item] = len(content)
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
info = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
info.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
return info, test_case_id
def filter_name_list(self, name_list, spj, dir=""):
ret = [] ret = []
prefix = 1 prefix = 1
if spj: if spj:
while True: while True:
in_name = str(prefix) + ".in" in_name = f"{prefix}.in"
if in_name in name_list: if f"{dir}{in_name}" in name_list:
ret.append(in_name) ret.append(in_name)
prefix += 1 prefix += 1
continue continue
@ -39,9 +95,9 @@ class TestCaseAPI(CSRFExemptAPIView):
return sorted(ret, key=natural_sort_key) return sorted(ret, key=natural_sort_key)
else: else:
while True: while True:
in_name = str(prefix) + ".in" in_name = f"{prefix}.in"
out_name = str(prefix) + ".out" out_name = f"{prefix}.out"
if in_name in name_list and out_name in name_list: if f"{dir}{in_name}" in name_list and f"{dir}{out_name}" in name_list:
ret.append(in_name) ret.append(in_name)
ret.append(out_name) ret.append(out_name)
prefix += 1 prefix += 1
@ -49,6 +105,10 @@ class TestCaseAPI(CSRFExemptAPIView):
else: else:
return sorted(ret, key=natural_sort_key) return sorted(ret, key=natural_sort_key)
class TestCaseAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def get(self, request): def get(self, request):
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
if not problem_id: if not problem_id:
@ -90,62 +150,13 @@ class TestCaseAPI(CSRFExemptAPIView):
file = form.cleaned_data["file"] file = form.cleaned_data["file"]
else: else:
return self.error("Upload failed") return self.error("Upload failed")
tmp_file = os.path.join("/tmp", rand_str() + ".zip") zip_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f: with open(zip_file, "wb") as f:
for chunk in file: for chunk in file:
f.write(chunk) f.write(chunk)
try: info, test_case_id = self.process_zip(zip_file, spj=spj)
zip_file = zipfile.ZipFile(tmp_file) os.remove(zip_file)
except zipfile.BadZipFile: return self.success({"id": test_case_id, "info": info, "spj": spj})
return self.error("Bad zip file")
name_list = zip_file.namelist()
test_case_list = self.filter_name_list(name_list, spj=spj)
if not test_case_list:
return self.error("Empty file")
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
size_cache = {}
md5_cache = {}
for item in test_case_list:
with open(os.path.join(test_case_dir, item), "wb") as f:
content = zip_file.read(item).replace(b"\r\n", b"\n")
size_cache[item] = len(content)
if item.endswith(".out"):
md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
hint = None
diff = set(name_list).difference(set(test_case_list))
if diff:
hint = ", ".join(diff) + " are ignored"
ret = []
if spj:
for index, item in enumerate(test_case_list):
data = {"input_name": item, "input_size": size_cache[item]}
ret.append(data)
test_case_info["test_cases"][str(index + 1)] = data
else:
# ["1.in", "1.out", "2.in", "2.out"] => [("1.in", "1.out"), ("2.in", "2.out")]
test_case_list = zip(*[test_case_list[i::2] for i in range(2)])
for index, item in enumerate(test_case_list):
data = {"stripped_output_md5": md5_cache[item[1]],
"input_size": size_cache[item[0]],
"output_size": size_cache[item[1]],
"input_name": item[0],
"output_name": item[1]}
ret.append(data)
test_case_info["test_cases"][str(index + 1)] = data
with open(os.path.join(test_case_dir, "info"), "w", encoding="utf-8") as f:
f.write(json.dumps(test_case_info, indent=4))
return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj})
class CompileSPJAPI(APIView): class CompileSPJAPI(APIView):
@ -466,3 +477,204 @@ class AddContestProblemAPI(APIView):
problem.save() problem.save()
problem.tags.set(tags) problem.tags.set(tags)
return self.success() return self.success()
class ExportProblemAPI(APIView):
def choose_answers(self, user, problem):
ret = []
for item in problem.languages:
submission = Submission.objects.filter(problem=problem,
user_id=user.id,
language=item,
result=JudgeStatus.ACCEPTED).order_by("-create_time").first()
if submission:
ret.append({"language": submission.language, "code": submission.code})
return ret
def process_one_problem(self, zip_file, user, problem, index):
info = ExportProblemSerializer(problem).data
info["answers"] = self.choose_answers(user, problem=problem)
compression = zipfile.ZIP_DEFLATED
zip_file.writestr(zinfo_or_arcname=f"{index}/problem.json",
data=json.dumps(info, indent=4),
compress_type=compression)
problem_test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
with open(os.path.join(problem_test_case_dir, "info")) as f:
info = json.load(f)
for k, v in info["test_cases"].items():
zip_file.write(filename=os.path.join(problem_test_case_dir, v["input_name"]),
arcname=f"{index}/testcase/{v['input_name']}",
compress_type=compression)
if not info["spj"]:
zip_file.write(filename=os.path.join(problem_test_case_dir, v["output_name"]),
arcname=f"{index}/testcase/{v['output_name']}",
compress_type=compression)
@validate_serializer(ExportProblemRequestSerialzier)
def get(self, request):
problems = Problem.objects.filter(id__in=request.data["problem_id"])
for problem in problems:
if problem.contest:
ensure_created_by(problem.contest, request.user)
else:
ensure_created_by(problem, request.user)
path = f"/tmp/{rand_str()}.zip"
with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.apply_async((path,), countdown=300)
resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename=problem-export.zip"
return resp
class ImportProblemAPI(CSRFExemptAPIView, TestCaseZipProcessor):
request_parsers = ()
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
tmp_file = f"/tmp/{rand_str()}.zip"
with open(tmp_file, "wb") as f:
for chunk in file:
f.write(chunk)
else:
return self.error("Upload failed")
count = 0
with zipfile.ZipFile(tmp_file, "r") as zip_file:
name_list = zip_file.namelist()
for item in name_list:
if "/problem.json" in item:
count += 1
with transaction.atomic():
for i in range(1, count + 1):
with zip_file.open(f"{i}/problem.json") as f:
problem_info = json.load(f)
serializer = ImportProblemSerializer(data=problem_info)
if not serializer.is_valid():
return self.error(f"Invalid problem format, error is {serializer.errors}")
else:
problem_info = serializer.data
for item in problem_info["template"].keys():
if item not in language_names:
return self.error(f"Unsupported language {item}")
problem_info["display_id"] = problem_info["display_id"][:24]
for k, v in problem_info["template"].items():
problem_info["template"][k] = build_problem_template(v["prepend"], v["template"],
v["append"])
spj = problem_info["spj"] is not None
rule_type = problem_info["rule_type"]
test_case_score = problem_info["test_case_score"]
# process test case
_, test_case_id = self.process_zip(tmp_file, spj=spj, dir=f"{i}/testcase/")
problem_obj = Problem.objects.create(_id=problem_info["display_id"],
title=problem_info["title"],
description=problem_info["description"]["value"],
input_description=problem_info["input_description"][
"value"],
output_description=problem_info["output_description"][
"value"],
hint=problem_info["hint"]["value"],
test_case_score=test_case_score if test_case_score else [],
time_limit=problem_info["time_limit"],
memory_limit=problem_info["memory_limit"],
samples=problem_info["samples"],
template=problem_info["template"],
rule_type=problem_info["rule_type"],
source=problem_info["source"],
spj=spj,
spj_code=problem_info["spj"]["code"] if spj else None,
spj_language=problem_info["spj"][
"language"] if spj else None,
spj_version=rand_str(8) if spj else "",
languages=language_names,
created_by=request.user,
visible=False,
difficulty=Difficulty.MID,
total_score=sum(item["score"] for item in test_case_score)
if rule_type == ProblemRuleType.OI else 0,
test_case_id=test_case_id
)
for tag_name in problem_info["tags"]:
tag_obj, _ = ProblemTag.objects.get_or_create(name=tag_name)
problem_obj.tags.add(tag_obj)
return self.success({"import_count": count})
class FPSProblemImport(CSRFExemptAPIView):
request_parsers = ()
def _create_problem(self, problem_data, creator):
if problem_data["time_limit"]["unit"] == "ms":
time_limit = problem_data["time_limit"]["value"]
else:
time_limit = problem_data["time_limit"]["value"] * 1000
template = {}
prepend = {}
append = {}
for t in problem_data["prepend"]:
prepend[t["language"]] = t["code"]
for t in problem_data["append"]:
append[t["language"]] = t["code"]
for t in problem_data["template"]:
our_lang = lang = t["language"]
if lang == "Python":
our_lang = "Python3"
template[our_lang] = TEMPLATE_BASE.format(prepend.get(lang, ""), t["code"], append.get(lang, ""))
spj = problem_data["spj"] is not None
Problem.objects.create(_id=f"fps-{rand_str(4)}",
title=problem_data["title"],
description=problem_data["description"],
input_description=problem_data["input"],
output_description=problem_data["output"],
hint=problem_data["hint"],
test_case_score=[],
time_limit=time_limit,
memory_limit=problem_data["memory_limit"]["value"],
samples=problem_data["samples"],
template=template,
rule_type=ProblemRuleType.ACM,
source=problem_data.get("source", ""),
spj=spj,
spj_code=problem_data["spj"]["code"] if spj else None,
spj_language=problem_data["spj"]["language"] if spj else None,
spj_version=rand_str(8) if spj else "",
visible=False,
languages=language_names,
created_by=creator,
difficulty=Difficulty.MID,
test_case_id=problem_data["test_case_id"])
def post(self, request):
form = UploadProblemForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
with tempfile.NamedTemporaryFile("wb") as tf:
for chunk in file.chunks(4096):
tf.file.write(chunk)
problems = FPSParser(tf.name).parse()
else:
return self.error("Parse upload file error")
helper = FPSHelper()
with transaction.atomic():
for _problem in problems:
test_case_id = rand_str()
test_case_dir = os.path.join(settings.TEST_CASE_DIR, test_case_id)
os.mkdir(test_case_dir)
helper.save_test_case(_problem, test_case_dir)
problem_data = helper.save_image(_problem, settings.UPLOAD_DIR, settings.UPLOAD_PREFIX)
s = FPSProblemSerializer(data=problem_data)
if not s.is_valid():
return self.error(f"Parse FPS file error: {s.errors}")
problem_data = s.data
problem_data["test_case_id"] = test_case_id
self._create_problem(problem_data, request.user)
return self.success({"import_count": len(problems)})

View File

@ -26,3 +26,9 @@ class CacheKey:
contest_rank_cache = "contest_rank_cache" contest_rank_cache = "contest_rank_cache"
website_config = "website_config" website_config = "website_config"
option = "option" option = "option"
class Difficulty(Choices):
LOW = "Low"
MID = "Mid"
HIGH = "High"

View File

@ -5,6 +5,7 @@ import re
import json import json
import django import django
import hashlib import hashlib
from json.decoder import JSONDecodeError
sys.path.append("../") sys.path.append("../")
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
@ -59,8 +60,8 @@ def set_problem_display_id_prefix():
def get_stripped_output_md5(test_case_id, output_name): def get_stripped_output_md5(test_case_id, output_name):
output_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, output_name) output_path = os.path.join(settings.TEST_CASE_DIR, test_case_id, output_name)
with open(output_path, "r") as f: with open(output_path, 'r') as f:
return hashlib.md5(f.read().encode("utf-8").rstrip()).hexdigest() return hashlib.md5(f.read().rstrip().encode('utf-8')).hexdigest()
def get_test_case_score(test_case_id): def get_test_case_score(test_case_id):
@ -190,8 +191,12 @@ if __name__ == "__main__":
print("Data file does not exist") print("Data file does not exist")
exit(1) exit(1)
try:
with open(data_path, "r") as data_file: with open(data_path, "r") as data_file:
old_data = json.load(data_file) old_data = json.load(data_file)
except JSONDecodeError:
print("Data file format error, ensure it's a valid json file!")
exit(1)
print("Read old data successfully.\n") print("Read old data successfully.\n")
for obj in old_data: for obj in old_data:

11
utils/tasks.py Normal file
View File

@ -0,0 +1,11 @@
import os
from celery import shared_task
@shared_task
def delete_files(*args):
for item in args:
try:
os.remove(item)
except Exception:
pass