Remove test, use bool tensor fix by default

The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default.
This commit is contained in:
brkirch 2023-03-11 17:35:17 -05:00
parent 27e319dc4f
commit a4cb96d4ae

View File

@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype) output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64: if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs) return cumsum_func(input, *args, **kwargs)
@ -45,7 +45,6 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"): elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)