From 0ab0a50f9ae14bd7ce7ec518323ebd31c7971155 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 12 Nov 2022 10:00:49 +0300
Subject: [PATCH] change formatting to match the main program in devices.py

---
 modules/devices.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/modules/devices.py b/modules/devices.py
index bd3e4ffb..67165bf6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,23 +3,27 @@ import contextlib
 import torch
 from modules import errors
 
+
 # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
 # check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False): return False
+    if not getattr(torch, 'has_mps', False):
+        return False
     try:
         torch.zeros(1).to(torch.device("mps"))
         return True
     except Exception:
         return False
 
-cpu = torch.device("cpu")
 
 def extract_device_id(args, name):
     for x in range(len(args)):
-        if name in args[x]: return args[x+1]
+        if name in args[x]:
+            return args[x + 1]
+
     return None
 
+
 def get_optimal_device():
     if torch.cuda.is_available():
         from modules import shared
@@ -52,10 +56,12 @@ def enable_tf32():
 
 errors.run(enable_tf32, "Enabling TF32")
 
+cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 
+
 def randn(seed, shape):
     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
@@ -89,6 +95,11 @@ def autocast(disable=False):
 
     return torch.autocast("cuda")
 
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+def mps_contiguous(input_tensor, device):
+    return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+
+
+def mps_contiguous_to(input_tensor, device):
+    return mps_contiguous(input_tensor, device).to(device)
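
For context, a minimal usage sketch of the helpers this patch reformats (the patch changes formatting only, not behavior). The argument list and tensor values below are hypothetical, and the import assumes the webui package layout shown in the diff header.

# Usage sketch; hypothetical values, not part of the patch.
import torch

from modules import devices

# extract_device_id() scans a flat argument list for a flag name and
# returns the element that follows it, or None if the flag is absent.
args = ["--device-id", "1"]                                   # hypothetical CLI args
device_id = devices.extract_device_id(args, "--device-id")    # -> "1"

# has_mps() probes the Metal backend by actually allocating a tensor there;
# mps_contiguous_to() makes a tensor contiguous before moving it when the
# target device is MPS (workaround for pytorch issue 79383).
target = torch.device("mps") if devices.has_mps() else devices.cpu
x = devices.mps_contiguous_to(torch.zeros(4), target)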