From 8f6b24ce5922174d96eb9776126488cb28694ff8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:16:42 +0300 Subject: [PATCH 1/2] Add correct logger name --- modules/mac_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 2c2f15ca..328b5973 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -5,7 +5,7 @@ import platform from modules.sd_hijack_utils import CondFunc from packaging import version -log = logging.getLogger() +log = logging.getLogger(__name__) # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, From 3d524fd3f1bdb17946bf6fa8a3cdf7b10859c495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:17:13 +0300 Subject: [PATCH 2/2] Don't do MPS GC when there's a latent that could still be sampled --- modules/mac_specific.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 328b5973..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -30,6 +30,10 @@ has_mps = check_for_mps() def torch_mps_gc() -> None: try: + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return from torch.mps import empty_cache empty_cache() except Exception: