From 98ca437edfbf71dd956d67d37f2136b12d13be0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 12 Nov 2022 02:17:55 -0500 Subject: [PATCH 1/4] Refactor and instead check if mps is being used, not availability --- modules/sd_hijack.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bf..ce583950 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -182,11 +182,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - - if devices.has_mps(): - attr = attr.to(device="mps", dtype=torch.float32) - else: - attr = attr.to(devices.device) + attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) setattr(self, name, attr) From 21effd629d0fdfdbbff2b20a9f4a3767e7e8bd33 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 28 Nov 2022 21:24:06 -0500 Subject: [PATCH 2/4] Add workaround for using MPS with torchsde --- modules/sd_samplers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 5fefb227..8b11f569 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -6,6 +6,7 @@ import tqdm from PIL import Image import inspect import k_diffusion.sampling +import torchsde._brownian.brownian_interval import ldm.models.diffusion.ddim import ldm.models.diffusion.plms from modules import prompt_parser, devices, processing, images @@ -367,6 +368,19 @@ class TorchHijack: return torch.randn_like(x) +# MPS fix for randn in torchsde +def torchsde_randn(size, dtype, device, seed): + if device.type == 'mps': + generator = torch.Generator(devices.cpu).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) + else: + generator = torch.Generator(device).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=device, generator=generator) + + +torchsde._brownian.brownian_interval._randn = torchsde_randn + + class KDiffusionSampler: def __init__(self, funcname, sd_model): denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser From 4d5f1691dda971ec7b461dd880426300fd54ccee Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 28 Nov 2022 21:36:35 -0500 Subject: [PATCH 3/4] Use devices.autocast instead of torch.autocast --- modules/hypernetworks/hypernetwork.py | 2 +- modules/interrogate.py | 3 +-- modules/swinir_model.py | 6 +----- modules/textual_inversion/dataset.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8466887f..eb5ae372 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.state.interrupted: break - with torch.autocast("cuda"): + with devices.autocast(): x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) if tag_drop_out != 0 or shuffle_tags: shared.sd_model.cond_stage_model.to(devices.device) diff --git a/modules/interrogate.py b/modules/interrogate.py index 9769aa34..40c6b082 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -148,8 +148,7 @@ class InterrogateModels: clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) - precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext - with torch.no_grad(), precision_scope("cuda"): + with torch.no_grad(), devices.autocast(): image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features /= image_features.norm(dim=-1, keepdim=True) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index facd262d..483eabd4 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData -precision_scope = ( - torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext -) - class UpscalerSwinIR(Upscaler): def __init__(self, dirname): @@ -112,7 +108,7 @@ def upscale( img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_swinir) - with torch.no_grad(), precision_scope("cuda"): + with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e5725f33..2dc64c3c 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -82,7 +82,7 @@ class PersonalizedBase(Dataset): torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32) latent_sample = None - with torch.autocast("cuda"): + with devices.autocast(): latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0)) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): @@ -101,7 +101,7 @@ class PersonalizedBase(Dataset): entry.cond_text = self.create_text(filename_text) if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): - with torch.autocast("cuda"): + with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) self.dataset.append(entry) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4eb75cb5..daf8d1b8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -316,7 +316,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if shared.state.interrupted: break - with torch.autocast("cuda"): + with devices.autocast(): # c = stack_conds(batch.cond).to(devices.device) # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) # print(mask) From 0fddb4a1c06a6e2122add7eee3b001a6d473baee Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 30 Nov 2022 08:02:39 -0500 Subject: [PATCH 4/4] Rework MPS randn fix, add randn_like fix torch.manual_seed() already sets a CPU generator, so there is no reason to create a CPU generator manually. torch.randn_like also needs a MPS fix for k-diffusion, but a torch hijack with randn_like already exists so it can also be used for that. --- modules/devices.py | 15 +++------------ modules/sd_samplers.py | 8 +++++--- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index f00079c6..046460fa 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -66,24 +66,15 @@ dtype_vae = torch.float16 def randn(seed, shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. - if device.type == 'mps': - generator = torch.Generator(device=cpu) - generator.manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - torch.manual_seed(seed) + if device.type == 'mps': + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) def randn_without_seed(shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': - generator = torch.Generator(device=cpu) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8b11f569..4c123d3b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -365,7 +365,10 @@ class TorchHijack: if noise.shape == x.shape: return noise - return torch.randn_like(x) + if x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + else: + return torch.randn_like(x) # MPS fix for randn in torchsde @@ -429,8 +432,7 @@ class KDiffusionSampler: self.model_wrap.step = 0 self.eta = p.eta or opts.eta_ancestral - if self.sampler_noises is not None: - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises) + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) extra_params_kwargs = {} for param_name in self.extra_params: