From 76ab31e18898d4c2aacb9725cfbe25b230bff974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:02:40 +0800 Subject: [PATCH] Fix wrong mps selection below MasOS 12.3 --- modules/devices.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 7511e1dc..9a3d29d7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,8 +3,15 @@ import contextlib import torch from modules import errors -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility -has_mps = getattr(torch, 'has_mps', False) +# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# check `getattr` and try it for compatibility +def has_mps() -> bool: + if getattr(torch, 'has_mps', False): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False cpu = torch.device("cpu") @@ -25,7 +32,7 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps: + if has_mps(): return torch.device("mps") return cpu