Under `torch.compile`, the functional collective `_c10d_functional.all_reduce` returns a tensor whose strides differ from eager mode when the input is a non-contiguous (transposed) view: eager produces a contiguous result, while the compiled graph keeps the view's strides.

Repro:
cmd: `torchrun --nproc_per_node=2 --local-ranks-filter=0 nccl_pytorch_demo.py`
```python
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

def fn(x):
    # Transpose makes x a non-contiguous view with strides (1, 256).
    x = x.t()
    y = torch.ops._c10d_functional.all_reduce.default(x, "sum", "0")
    z = torch.ops._c10d_functional.wait_tensor.default(y)
    return y, z

eager_x = torch.randn((16384, 256), device="cuda")
eager_y, eager_z = fn(eager_x)

compiled_f = torch.compile(fn)
compiled_y, compiled_z = compiled_f(eager_x)

assert compiled_y.stride() == eager_y.stride(), (
    f"Compiled y stride {compiled_y.stride()} does not match "
    f"eager y stride {eager_y.stride()}"
)
# AssertionError: Compiled y stride (1, 256) does not match eager y stride (16384, 1)
assert compiled_z.stride() == eager_z.stride(), (
    f"Compiled z stride {compiled_z.stride()} does not match "
    f"eager z stride {eager_z.stride()}"
)
dist.destroy_process_group()
```
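For context, the two stride tuples in the assertion message can be reproduced without any process group. This is a minimal sketch added for illustration (not part of the original repro), assuming eager `all_reduce` materializes the non-contiguous input into a contiguous buffer while the compiled graph propagates the view's strides:

```python
import torch

x = torch.randn(16384, 256)      # contiguous, strides (256, 1)
xt = x.t()                       # view with shape (256, 16384)
print(xt.stride())               # (1, 256)   -> the stride reported for compiled y
print(xt.contiguous().stride())  # (16384, 1) -> the stride reported for eager y
```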
cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @muchulee8 @amjames @aakhundov @coconutruben