Add safetensors support to LDSR

This commit is contained in:
wywywywy 2022-12-10 18:57:18 +00:00
parent 685f9631b5
commit 8bcdd50461
2 changed files with 14 additions and 4 deletions

View File

@ -1,3 +1,4 @@
import os
import gc import gc
import time import time
import warnings import warnings
@ -8,6 +9,7 @@ import torchvision
from PIL import Image from PIL import Image
from einops import rearrange, repeat from einops import rearrange, repeat
from omegaconf import OmegaConf from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
@ -28,8 +30,12 @@ class LDSR:
model: torch.nn.Module = cached_ldsr_model model: torch.nn.Module = cached_ldsr_model
else: else:
print(f"Loading model from {self.modelPath}") print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu") _, extension = os.path.splitext(self.modelPath)
sd = pl_sd["state_dict"] if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath) config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model) model: torch.nn.Module = instantiate_from_config(config.model)

View File

@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml") yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth") old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt") new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path): if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path) statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760: if statinfo.st_size >= 10485760:
@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path): if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt") print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path) os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path, if os.path.exists(safetensors_model_path):
file_name="model.ckpt", progress=True) model = safetensors_model_path
else:
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path, yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True) file_name="project.yaml", progress=True)