Remove test, use bool tensor fix by default
The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default.
This commit is contained in:
parent
27e319dc4f
commit
a4cb96d4ae
@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
output_dtype = kwargs.get('dtype', input.dtype)
|
output_dtype = kwargs.get('dtype', input.dtype)
|
||||||
if output_dtype == torch.int64:
|
if output_dtype == torch.int64:
|
||||||
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
|
||||||
elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
|
||||||
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
|
||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
@ -45,7 +45,6 @@ if has_mps:
|
|||||||
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
|
||||||
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
elif version.parse(torch.__version__) > version.parse("1.13.1"):
|
||||||
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
|
||||||
cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
|
|
||||||
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
|
||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
|
Loading…
Reference in New Issue
Block a user