AI Platform Core Components / AIPCC-8914

Data-dependent error in torch.compile

    • Type: Story
    • Resolution: Done
    • Priority: Undefined
    • Component: PyTorch
    • Sprints: PyTorch Sprint 24, PyTorch Sprint 25

🐛 Describe the bug

```python
import torch

torch.set_default_device("cuda")

def test_dynamo_trace_vmap_indexing():
    data = torch.arange(20).reshape(2, 10)  # [B, N]

    def vmap_index_fn(data_in, b_indices, n_indices):
        def index_fn(b, n):
            return data_in[b, n]

        # Nested vmap: outer maps over b_indices (dim 0), inner over n_indices (dim 0).
        return torch.vmap(torch.vmap(index_fn, in_dims=(None, 0)), in_dims=(0, None))(
            b_indices, n_indices
        )

    b_indices = torch.arange(2)
    n_indices = torch.arange(10)

    torch.compile(vmap_index_fn, fullgraph=True)(data, b_indices, n_indices)

test_dynamo_trace_vmap_indexing()
```
      Errors with:
      ```
      TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in function getitem>(*(FakeTensor(..., device='cuda:0', size=(2, 10), dtype=torch.int64), (BatchedTensor(lvl=1, bdim=0, value=
      FakeTensor(..., device='cuda:0', size=(2,), dtype=torch.int64)
      ), BatchedTensor(lvl=2, bdim=0, value=
      FakeTensor(..., device='cuda:0', size=(10,), dtype=torch.int64)
      ))), **{}): got RuntimeError("vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.")

      from user code:
      File "/tmp/ipykernel_402368/1714913198.py", line 22, in vmap_index_fn
      return torch.vmap(torch.vmap(index_fn, in_dims=(None, 0)), in_dims=(0, None))(
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/apis.py", line 208, in wrapped
      return vmap_impl(
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 283, in vmap_impl
      return _flat_vmap(
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 433, in _flat_vmap
      batched_outputs = func(*batched_inputs, **kwargs)
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/apis.py", line 208, in wrapped
      return vmap_impl(
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 283, in vmap_impl
      return _flat_vmap(
      File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/_functorch/vmap.py", line 433, in _flat_vmap
      batched_outputs = func(*batched_inputs, **kwargs)
      File "/tmp/ipykernel_402368/72028094.py", line 20, in index_fn
      return data_in[b, n]
File "/data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2742/bento_kernel_pytorch_binary-inplace#link-tree/torch/utils/device.py", line 109, in __torch_function__
      return func(*args, **kwargs)

      Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
      ```
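As a possible workaround (a sketch, not part of the original report, and not a fix for the underlying Dynamo issue): the same per-pair gather can be expressed with broadcasted advanced indexing, which sidesteps vmap and the per-element scalar indexing that triggers the `.item()` error under fake-tensor tracing. The `index_grid` helper name is illustrative, and the example runs on CPU for simplicity.

```python
import torch

def index_grid(data_in, b_indices, n_indices):
    # result[i, j] = data_in[b_indices[i], n_indices[j]], computed via
    # broadcasted advanced indexing -- no vmap, no scalar extraction.
    return data_in[b_indices[:, None], n_indices[None, :]]

data = torch.arange(20).reshape(2, 10)  # [B, N]
b_indices = torch.arange(2)
n_indices = torch.arange(10)

out = index_grid(data, b_indices, n_indices)
print(torch.equal(out, data))  # True: these indices form an identity gather
```

Since this formulation contains no data-dependent scalar extraction, wrapping it in `torch.compile(index_grid, fullgraph=True)` should not hit the vmap/`.item()` path shown above.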

Versions

      main

      cc @chauhang @penguinwu @Chillee @samdow @kshitij12345 @ezyang @bobrenjc93

rh-ee-parsshar Parshant Sharma
PyTorch Compile