Fixed safe.py for PyTorch 1.13 ckpt files: the zip filename check now accepts any top-level directory name instead of only 'archive'

This commit is contained in:
SmirkingFace 2022-12-02 11:12:13 +01:00
parent 4b3c5bc24b
commit e461477869

View File

@@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
raise Exception(f"global '{module}/{name}' is forbidden") raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"] # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^archive/data/\d+$") allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names): def check_zip_filenames(filename, names):
for name in names: for name in names:
if name in allowed_zip_names:
continue
if allowed_zip_names_re.match(name): if allowed_zip_names_re.match(name):
continue continue
@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
# new pytorch format is a zip file # new pytorch format is a zip file
with zipfile.ZipFile(filename) as z: with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist()) check_zip_filenames(filename, z.namelist())
with z.open('archive/data.pkl') as file: # find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler unpickler.extra_handler = extra_handler
unpickler.load() unpickler.load()