Allow bf16 in safe unpickler

This commit is contained in:
catboxanon 2023-05-13 11:04:26 -04:00 committed by GitHub
parent 2b3fc246b0
commit cb5f61281a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)