-
Story
-
Resolution: Done
-
Undefined
-
None
-
None
-
-
- 🐛 Describe the bug
-
```python
import torch

torch.set_default_device("cuda")


def test_dynamo_trace_vmap_indexing():
    """Repro: torch.compile(fullgraph=True) fails on indexing inside nested torch.vmap.

    Eagerly this computes data[b, n] for every (b, n) pair. Under Dynamo, the
    fake-tensor run of the getitem raises the vmap ".item()" error shown below.
    """
    data = torch.arange(20).reshape(2, 10)  # [B, N]

    def vmap_index_fn(data_in, b_indices, n_indices):
        def index_fn(b, n):
            # b and n are 0-d index tensors here; data_in is captured by closure.
            return data_in[b, n]

        # Outer vmap maps over b (BatchedTensor lvl=1), inner vmap maps over n
        # (lvl=2) -- the nesting that appears in the reported error/traceback.
        batched_index_fn = torch.vmap(
            torch.vmap(index_fn, in_dims=(None, 0)),  # inner: over n_indices
            in_dims=(0, None),  # outer: over b_indices
        )
        return batched_index_fn(b_indices, n_indices)

    b_indices = torch.arange(2)
    n_indices = torch.arange(10)
    torch.compile(vmap_index_fn, fullgraph=True)(data, b_indices, n_indices)


test_dynamo_trace_vmap_indexing()
```
Running this fails with:
```
TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(2, 10), dtype=torch.int64), (BatchedTensor(lvl=1, bdim=0, value=
FakeTensor(..., device='cuda:0', size=(2,), dtype=torch.int64)
), BatchedTensor(lvl=2, bdim=0, value=
FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64)
))), **{}): got RuntimeError("vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.")
from user code:
File "/tmp/ipykernel_402368/1714913198.py", line 22, in vmap_index_fn
return torch.vmap(torch.vmap(index_fn, in_dims=(None, 0)), in_dims=(0, None))(
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/apis.py", line 208, in wrapped
return vmap_impl(
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 283, in vmap_impl
return _flat_vmap(
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 433, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/apis.py", line 208, in wrapped
return vmap_impl(
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 283, in vmap_impl
return _flat_vmap(
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 433, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/tmp/ipykernel_402368/72028094.py", line 20, in index_fn
return data_in[b, n]
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/device.py", line 109, in __torch_function__
return func(*args, **kwargs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
-
-
- Versions
-
main
cc @chauhang @penguinwu @Chillee @samdow @kshitij12345 @ezyang @bobrenjc93