a way to add an exception to unpickler without explicitly calling load_with_extra

This commit is contained in:
AUTOMATIC 2022-12-25 09:03:56 +03:00
parent c5bdba2089
commit 8eef9d8e78

View File

@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs)
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
return unsafe_torch_load(filename, *args, **kwargs)
class Extra:
"""
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this:
```
import torch
from modules import safe
def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name)
return None
with safe.Extra(handler):
x = torch.load('model.pt')
```
"""
def __init__(self, handler):
self.handler = handler
def __enter__(self):
global global_extra_handler
assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler
def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler
global_extra_handler = None
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None