-
Story
-
Resolution: Unresolved
-
-
- 🐛 Describe the bug
-
When `torch.compile` with the `inductor` backend compiles a function that applies `add_(xx).reshape(xx)` to a transposed tensor, the compiled function returns results that differ from eager execution.
Here is the code to reproduce:
```python
import torch
torch._inductor.config.fallback_random = True
def fn(x, y):
    # In-place add on the transposed input, then reshape the result.
    return x.add_(y).reshape(-1, 1, 3)
torch.manual_seed(0)
x1 = torch.ones([1, 2, 3]).transpose(1, 2)
y1 = torch.randn([1, 3, 1])
print(x1, y1)
'''
tensor([[[1., 1.],
[1., 1.],
[1., 1.]]]),
tensor([[[ 1.5410],
[-0.2934],
[-2.1788]]])
'''
x2 = x1.clone()
y2 = y1.clone()
cfunc = torch.compile(fn, backend='inductor')
out1 = fn(x1, y1)
print(out1)
'''
tensor([[[ 2.5410, 2.5410, 0.7066]],
[[ 0.7066, -1.1788, -1.1788]]])
'''
out2 = cfunc(x2, y2)
print(out2)
'''
tensor([[[ 4.0820, 4.0820, 0.4131]],
[[ 0.4131, -3.3576, -3.3576]]])
'''
torch.testing.assert_close(out1, out2, equal_nan=True)
```
Notes:
1. The issue does not occur when the backend is set to `eager` or `aot_eager`.
2. The issue does not occur if `x1 = torch.ones([1, 2, 3]).transpose(1, 2)` is changed to `x1 = torch.ones([1, 3, 2])`.
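
The two inputs in note 2 have the same shape and values, so the difference that matters is presumably memory layout. The sketch below only inspects that layout in eager mode; it does not diagnose the inductor bug, and the variable names `x_t`, `x_c`, `r_t`, and `r_c` are introduced here for illustration.

```python
import torch

# Failing input from the report: a transposed, non-contiguous view.
x_t = torch.ones([1, 2, 3]).transpose(1, 2)
# Passing input from note 2: same shape and values, contiguous layout.
x_c = torch.ones([1, 3, 2])

# Same shape, different strides.
print(x_t.shape, x_t.stride(), x_t.is_contiguous())
print(x_c.shape, x_c.stride(), x_c.is_contiguous())

# In eager mode, reshape returns a view when the layout allows it and a copy
# otherwise. The transposed input cannot be viewed as (-1, 1, 3), so it is copied.
r_t = x_t.reshape(-1, 1, 3)
r_c = x_c.reshape(-1, 1, 3)
print("transposed reshape shares storage:", r_t.data_ptr() == x_t.data_ptr())
print("contiguous reshape shares storage:", r_c.data_ptr() == x_c.data_ptr())
```

If the layout is indeed the trigger, the failing case is the one where `reshape` has to copy from a tensor that was just mutated in place.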
-
-
- Error logs
-
AssertionError: Tensor-likes are not close!
Mismatched elements: 6 / 6 (100.0%)
Greatest absolute difference: 2.1787893772125244 at index (1, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.7102370262145996 at index (0, 0, 2) (up to 1.3e-06 allowed)
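
As a reading aid, the printed values are consistent with `y` being added twice in the compiled path: the eager output matches `1 + y`, while the compiled output matches `1 + 2 * y`. This is an inference from the rounded numbers in the report, not a confirmed root cause. A minimal check in plain Python, using values copied from the printed tensors:

```python
# Values copied (rounded to 4 decimals) from the printed tensors in the report.
y = [1.5410, -0.2934, -2.1788]
eager = [2.5410, 0.7066, -1.1788]      # distinct values in out1
compiled = [4.0820, 0.4131, -3.3576]   # distinct values in out2

for yi, e, c in zip(y, eager, compiled):
    # Eager agrees with 1 + y; compiled agrees with 1 + 2*y, within print rounding.
    assert abs((1.0 + yi) - e) < 1e-3
    assert abs((1.0 + 2 * yi) - c) < 1e-3

# The largest gap between the two outputs equals |y| at its largest entry,
# which lines up with the greatest absolute difference in the error log.
max_abs_diff = max(abs(e - c) for e, c in zip(eager, compiled))
print(max_abs_diff)
```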
-
-
- Versions
-
[pip3] torch==2.11.0.dev20260209+cpu
[pip3] triton==3.6.0
[conda] torch 2.11.0.dev20260209+cpu pypi_0 pypi
[conda] triton 3.6.0 pypi_0 pypi
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo