diff --git a/modules/devices.py b/modules/devices.py new file mode 100644 index 00000000..25008a04 --- /dev/null +++ b/modules/devices.py @@ -0,0 +1,12 @@ +import torch + + +# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility +has_mps = getattr(torch, 'has_mps', False) + +def get_optimal_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + if has_mps: + return torch.device("mps") + return torch.device("cpu") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index e86ad775..7f3baf31 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -9,12 +9,13 @@ from PIL import Image import modules.esrgam_model_arch as arch from modules import shared from modules.shared import opts +from modules.devices import has_mps import modules.images def load_model(filename): # this code is adapted from https://github.com/xinntao/ESRGAN - pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None) + pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) if 'conv_first.weight' in pretrained_net: diff --git a/modules/lowvram.py b/modules/lowvram.py index bd117491..079386c3 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,13 +1,9 @@ import torch +from modules.devices import get_optimal_device module_in_gpu = None cpu = torch.device("cpu") -if torch.has_cuda: - device = gpu = torch.device("cuda") -elif torch.has_mps: - device = gpu = torch.device("mps") -else: - device = gpu = torch.device("cpu") +device = gpu = get_optimal_device() def setup_for_low_vram(sd_model, use_medvram): parents = {} diff --git a/modules/shared.py b/modules/shared.py index 6ca9106c..74b0ad89 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,6 +9,7 @@ import tqdm import modules.artists from modules.paths import script_path, sd_path +from modules.devices import get_optimal_device import modules.styles config_filename = "config.json" @@ -43,12 +44,8 @@ parser.add_argument("--ui-config-file", type=str, help="filename to use for ui c cmd_opts = parser.parse_args() -if torch.has_cuda: - device = torch.device("cuda") -elif torch.has_mps: - device = torch.device("mps") -else: - device = torch.device("cpu") +device = get_optimal_device() + batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram