Use other MPS optimization for large q.shape[0] * q.shape[1]

Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory usage MPS optimization if it is. This should prevent most crashes that were occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048).

Also included is a change to check slice_size and prevent it from being divisible by 4096 which also results in a crash. Otherwise a crash can occur at 1024x512 or 512x1024 resolution.
This commit is contained in:
brkirch 2022-12-19 17:25:14 -05:00
parent 685f9631b5
commit 35b1775b32

View File

@ -127,7 +127,7 @@ def check_for_psutil():
invokeAI_mps_available = check_for_psutil() invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI -- # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available: if invokeAI_mps_available:
import psutil import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30) mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r return r
def einsum_op_mps_v1(q, k, v): def einsum_op_mps_v1(q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
if slice_size % 4096 == 0:
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size) return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v): def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[1] <= 4096: if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v) return einsum_op_compvis(q, k, v)
else: else:
return einsum_op_slice_0(q, k, v, 1) return einsum_op_slice_0(q, k, v, 1)
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v) return einsum_op_cuda(q, k, v)
if q.device.type == 'mps': if q.device.type == 'mps':
if mem_total_gb >= 32: if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v) return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v) return einsum_op_mps_v2(q, k, v)