diff --git a/modules/lowvram.py b/modules/lowvram.py index 4b78deab..bd117491 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -2,9 +2,12 @@ import torch module_in_gpu = None cpu = torch.device("cpu") -gpu = torch.device("cuda") -device = gpu if torch.cuda.is_available() else cpu - +if torch.has_cuda: + device = gpu = torch.device("cuda") +elif torch.has_mps: + device = gpu = torch.device("mps") +else: + device = gpu = torch.device("cpu") def setup_for_low_vram(sd_model, use_medvram): parents = {}