send tensors to the correct device when loading from safetensors file with memmap disabled for #11260
This commit is contained in:
parent
14196548c5
commit
24129368f1
@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename):
|
|||||||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
_, extension = os.path.splitext(checkpoint_file)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
|
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||||
|
|
||||||
if not shared.opts.disable_mmap_load_safetensors:
|
if not shared.opts.disable_mmap_load_safetensors:
|
||||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
|
||||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||||
else:
|
else:
|
||||||
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
|
||||||
|
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
|
||||||
else:
|
else:
|
||||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||||
|
|
||||||
|
@ -376,7 +376,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."),
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
Loading…
Reference in New Issue
Block a user