Fix warning of 'has_mps' is deprecated from PyTorch
This commit is contained in:
parent
fab73f2e7d
commit
daf41a2734
@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
|
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||||
# check `getattr` and try it for compatibility
|
# use check `getattr` and try it for compatibility.
|
||||||
|
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
||||||
|
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||||
def check_for_mps() -> bool:
|
def check_for_mps() -> bool:
|
||||||
if not getattr(torch, 'has_mps', False):
|
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||||
return False
|
if not getattr(torch, 'has_mps', False):
|
||||||
try:
|
return False
|
||||||
torch.zeros(1).to(torch.device("mps"))
|
try:
|
||||||
return True
|
torch.zeros(1).to(torch.device("mps"))
|
||||||
except Exception:
|
return True
|
||||||
return False
|
except Exception:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
|
||||||
has_mps = check_for_mps()
|
has_mps = check_for_mps()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user