# Revision faed465a0b
# Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is
# contiguous for tensor views before sending to MPS device.
# 88 lines, 2.5 KiB, Python
import sys, os, shlex
|
|
import contextlib
|
|
import torch
|
|
from modules import errors
|
|
|
|
# `torch.has_mps` only exists in nightly PyTorch builds (for now), so it is read
# through `getattr` with a False fallback to stay compatible with other builds.
has_mps = getattr(torch, "has_mps", False)

# Shared CPU device object reused by the helpers in this module.
cpu = torch.device("cpu")
|
|
|
|
def extract_device_id(args, name):
    """Return the argument that follows the first entry of *args* containing *name*.

    Matching is by substring (``name in entry``), so an entry only needs to
    contain *name*, not equal it.

    Args:
        args: Sequence of command-line argument strings.
        name: Text to look for inside each argument.

    Returns:
        The argument immediately after the first matching entry, or ``None``
        when nothing matches or the first match is the last argument and
        therefore has no value after it.
    """
    # Pair every argument with its successor; the final argument has no
    # successor, so a match there yields None instead of an IndexError.
    for flag, value in zip(args, args[1:]):
        if name in flag:
            return value

    return None
|
|
|
|
def get_optimal_device():
    """Pick the torch device to use: CUDA first, then MPS, then CPU.

    Returns:
        ``cuda:<id>`` when CUDA is available and ``shared.cmd_opts.device_id``
        is set, plain ``cuda`` when CUDA is available without an id, ``mps``
        when ``has_mps`` is truthy, and the CPU device otherwise.
    """
    if not torch.cuda.is_available():
        return torch.device("mps") if has_mps else cpu

    # Imported lazily, and only on the CUDA path, exactly where it is needed.
    from modules import shared

    selected_id = shared.cmd_opts.device_id
    if selected_id is None:
        return torch.device("cuda")

    return torch.device(f"cuda:{selected_id}")
|
|
|
|
|
|
def torch_gc():
    """Call ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()`` if CUDA is available.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
|
|
|
|
|
|
def enable_tf32():
    """Turn on the TF32 flags for CUDA matmul and cuDNN when CUDA is available.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
# Invoke enable_tf32 at import time through the project's errors.run helper.
# NOTE(review): presumably errors.run reports a failure under the given label
# instead of letting it propagate — confirm in modules.errors.
errors.run(enable_tf32, "Enabling TF32")
|
|
|
|
# Device slots, all unset at import time. `device` must be assigned a real
# torch device before randn()/randn_without_seed() are called, since they read
# `device.type`; presumably startup code outside this file fills these in.
device = None
device_interrogate = None
device_gfpgan = None
device_swinir = None
device_esrgan = None
device_scunet = None
device_codeformer = None

# Default dtypes; autocast() compares `dtype` against torch.float32.
dtype = torch.float16
dtype_vae = torch.float16
|
|
|
|
def randn(seed, shape):
    """Generate seeded standard-normal noise of *shape* on the module-level ``device``.

    Args:
        seed: Seed used for the random number generator.
        shape: Size of the tensor to generate, as accepted by ``torch.randn``.

    Returns:
        A tensor of random values located on ``device``.
    """
    if device.type != 'mps':
        torch.manual_seed(seed)
        return torch.randn(shape, device=device)

    # Pytorch currently doesn't handle setting randomness correctly when the
    # metal backend is used, so sample with a seeded CPU generator and move
    # the result to the device afterwards.
    cpu_rng = torch.Generator(device=cpu)
    cpu_rng.manual_seed(seed)
    return torch.randn(shape, generator=cpu_rng, device=cpu).to(device)
|
|
|
|
|
|
def randn_without_seed(shape):
    """Generate standard-normal noise of *shape* on the module-level ``device`` without reseeding.

    On the metal (MPS) backend the noise is sampled on the CPU and then moved
    to the device, because Pytorch currently doesn't handle randomness
    correctly when that backend is used.

    Args:
        shape: Size of the tensor to generate, as accepted by ``torch.randn``.

    Returns:
        A tensor of random values located on ``device``.
    """
    if device.type == 'mps':
        # Sample from the global CPU RNG. A freshly constructed, never-seeded
        # torch.Generator always starts from the same default seed, so building
        # one per call would hand back identical "random" noise every time.
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)
|
|
|
|
|
|
def autocast(disable=False):
    """Return the context manager to wrap inference in.

    Args:
        disable: When true, autocasting is skipped unconditionally.

    Returns:
        ``contextlib.nullcontext()`` when *disable* is true, when the module
        ``dtype`` is ``torch.float32``, or when ``shared.cmd_opts.precision``
        is ``"full"``; otherwise ``torch.autocast("cuda")``.
    """
    from modules import shared

    # `or` short-circuits, so the checks run in the same order as separate ifs.
    skip_autocast = (
        disable
        or dtype == torch.float32
        or shared.cmd_opts.precision == "full"
    )
    return contextlib.nullcontext() if skip_autocast else torch.autocast("cuda")
|
|
|
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
|
def mps_contiguous(input_tensor, device):
    """Return ``input_tensor.contiguous()`` when *device* is MPS, else the tensor unchanged.

    Args:
        input_tensor: Tensor to prepare.
        device: Target device; only its ``type`` attribute is inspected.

    Returns:
        The result of ``input_tensor.contiguous()`` for an ``'mps'`` device,
        otherwise *input_tensor* itself.
    """
    if device.type != 'mps':
        return input_tensor
    return input_tensor.contiguous()
|
|
def mps_contiguous_to(input_tensor, device):
    """Pass *input_tensor* through ``mps_contiguous`` and move the result to *device*.

    Args:
        input_tensor: Tensor to transfer.
        device: Destination device.

    Returns:
        The result of calling ``.to(device)`` on the prepared tensor.
    """
    prepared = mps_contiguous(input_tensor, device)
    return prepared.to(device)
|