Merge pull request #8518 from brkirch/remove-bool-test
Fix image generation on macOS 13.3 betas
This commit is contained in:
commit
bbc4b0478a
@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
if output_dtype == torch.int64:
|
if output_dtype == torch.int64:
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
@ -45,7 +45,6 @@ if has_mps:
|
|||||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
|
Loading…
Reference in New Issue
Block a user