  AI Platform Core Components / AIPCC-8330

NotImplementedError when using AsyncCollectiveTensor with torch.func.jvp

    • Type: Bug
    • Resolution: Unresolved
    • Component: PyTorch
    • Sprint: PyTorch Sprint 23

      🐛 Describe the bug

        Summary

      When using `torch.func.jvp` with a custom `autograd.Function` whose `jvp` rule returns an `AsyncCollectiveTensor` (via `_maybe_wrap_tensor`), PyTorch raises `NotImplementedError: Cannot access storage of TensorWrapper`. This happens because `AsyncCollectiveTensor.__torch_dispatch__` unwraps the tensor by calling `wait_tensor` on its wrapped element, which during JVP execution is a functorch `TensorWrapper`, and `wait_tensor` cannot access the storage of a wrapped tensor.

      Similar to https://github.com/pytorch/pytorch/issues/161943 and https://github.com/pytorch/pytorch/issues/138422.

        Code to Reproduce

      ```python
      import torch
      from torch.distributed._functional_collectives import _maybe_wrap_tensor

      inp = torch.randn(1)

      class F(torch.autograd.Function):
          @staticmethod
          def forward(input):
              return input

          @staticmethod
          def backward(ctx, grad_output):
              return grad_output

          @staticmethod
          def setup_context(ctx, inputs, output):
              pass

          @staticmethod
          def jvp(ctx, input_tangent):
              return _maybe_wrap_tensor(input_tangent)

      torch.func.jvp(F.apply, (inp,), (inp,))
      ```
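
      For comparison (not part of the original report): the same `autograd.Function` appears to run under `torch.func.jvp` when the `jvp` rule returns the tangent directly instead of wrapping it with `_maybe_wrap_tensor`, which suggests the `AsyncCollectiveTensor` wrapping itself is the trigger. A minimal sketch of that control case, with the class renamed `FNoWrap` for illustration:

      ```python
      import torch

      inp = torch.randn(1)

      class FNoWrap(torch.autograd.Function):
          @staticmethod
          def forward(input):
              return input

          @staticmethod
          def backward(ctx, grad_output):
              return grad_output

          @staticmethod
          def setup_context(ctx, inputs, output):
              pass

          @staticmethod
          def jvp(ctx, input_tangent):
              # Return the tangent as-is; no AsyncCollectiveTensor wrapping.
              return input_tangent

      # Expected to return (inp, inp) without raising.
      out, tangent_out = torch.func.jvp(FNoWrap.apply, (inp,), (inp,))
      ```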

        Error Traceback

      ```
      NotImplementedError Traceback (most recent call last)
      Cell In[14], line 20
      16 @staticmethod
      17 def jvp(ctx, input_tangent):
      18 return _maybe_wrap_tensor(input_tangent)
      ---> 20 torch.func.jvp(F.apply, (inp,), (inp,))

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1042, in jvp(func, primals, tangents, strict, has_aux)
      983 @exposed_in("torch.func")
      984 def jvp(
      985 func: Callable,
      (...) 990 has_aux: bool = False,
      991 ):
      992 """
      993 Standing for the Jacobian-vector product, returns a tuple containing
      994 the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
      (...) 1039
      1040 """
      -> 1042 return _jvp_with_argnums(
      1043 func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux
      1044 )

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1101, in _jvp_with_argnums(func, primals, tangents, argnums, strict, has_aux)
      1099 primals = _wrap_all_tensors(primals, level)
      1100 duals = _replace_args(primals, duals, argnums)
      -> 1101 result_duals = func(*duals)
      1102 if has_aux:
      1103 if not (isinstance(result_duals, tuple) and len(result_duals) == 2):

      File /opt/python3.12/lib/python3.12/site-packages/torch/autograd/function.py:591, in Function.apply(cls, *args, **kwargs)
      583 if not is_setup_ctx_defined:
      584 raise RuntimeError(
      585 "In order to use an autograd.Function with functorch transforms "
      586 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
      587 "staticmethod. For more details, please see "
      588 "https://pytorch.org/docs/main/notes/extending.func.html"
      589 )
      --> 591 return custom_function_call(cls, *args, **kwargs)

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/autograd_function.py:49, in CustomFunctionHigherOrderOperator.__call__(self, autograd_function, *args, **kwargs)
      36 def __call__(self, autograd_function, *args, **kwargs):
      37 # When custom_function_call is done dispatching through functorch,
      38 # it should just invoke the autograd.Function. This is consistent
      (...) 46 # (because autograd.Function happens before the Python dispatch key)
      47 # and only traces the forward pass.
      48 if torch._C._are_functorch_transforms_active():
      ---> 49 return super().__call__(autograd_function, *args, **kwargs)
      50 return autograd_function.apply(*args, **kwargs)

      File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:536, in HigherOrderOperator.__call__(self, *args, **kwargs)
      531 dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
      532 return self.dispatch(
      533 dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
      534 )
      --> 536 return wrapper()

      File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:532, in HigherOrderOperator.__call__.<locals>.wrapper()
      527 return torch.overrides.handle_torch_function(
      528 self, flat_args, *args, **kwargs
      529 )
      531 dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
      --> 532 return self.dispatch(
      533 dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
      534 )

      File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:384, in HigherOrderOperator.dispatch(self, dispatch_key, *args, **kwargs)
      381 return kernel(*args, **kwargs)
      383 if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
      --> 384 return dispatch_functorch(self, args, kwargs)
      386 if dispatch_key == DispatchKey.Python:
      387 # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
      388 # in torch/csrc/utils/python_arg_parser.cpp
      390 overloaded_args_list = []

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/pyfunctorch.py:312, in dispatch_functorch(op, args, kwargs)
      304 # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
      305 # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
      306 # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
      307 # transforms, so we manually unwrap the dead tensors here.
      308 # This logic won't need to exist when we have mode-only functorch.
      309 args, kwargs = pytree.tree_map_only(
      310 torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
      311 )
      --> 312 return interpreter.process(op, args, kwargs)

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/pyfunctorch.py:222, in JvpInterpreter.process(self, op, args, kwargs)
      220 kernel = op.functorch_table[TransformType.Jvp]
      221 args, kwargs = self.lift(args, kwargs)
      --> 222 return kernel(self, *args, **kwargs)

      File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/autograd_function.py:93, in custom_function_call_grad(interpreter, autograd_function, *operands)
      91 Generated = generate_single_level_function(interpreter, autograd_function)
      92 with enable_single_level_autograd_function():
      ---> 93 flat_out = Generated.apply(*operands)
      94 return flat_out

      File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:654, in AsyncCollectiveTensor.__torch_dispatch__(cls, func, types, args, kwargs)
      651 res = AsyncCollectiveTensor(e)
      652 return res
      --> 654 unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
      655 unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
      657 # we don't wrap the result as it doesn't need to be waited on.

      File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:716, in tree_map_only(type_or_types_or_pred, func, tree, is_leaf)
      709 def tree_map_only(
      710 type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
      711 /,
      (...) 714 is_leaf: Optional[Callable[[PyTree], bool]] = None,
      715 ) -> PyTree:
      --> 716 return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

      File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:523, in tree_map(func, tree, is_leaf, *rests)
      484 def tree_map(
      485 func: Callable[..., Any],
      486 tree: PyTree,
      487 *rests: PyTree,
      488 is_leaf: Optional[Callable[[PyTree], bool]] = None,
      489 ) -> PyTree:
      490 """Map a multi-input function over pytree args to produce a new pytree.
      491
      492 See also :func:`tree_map_`.
      (...) 521 is the tuple of values at corresponding nodes in ``rests``.
      522 """
      --> 523 return optree.tree_map(
      524 func,
      525 tree,
      526 *rests,
      527 is_leaf=is_leaf,
      528 none_is_leaf=True,
      529 namespace="torch",
      530 )

      File /opt/python3.12/lib/python3.12/site-packages/optree/ops.py:766, in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
      764 leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
      765 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
      --> 766 return treespec.unflatten(map(func, *flat_args))

      File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:651, in map_only.<locals>.wrapper.<locals>.wrapped
      648 @functools.wraps(func)
      649 def wrapped(x: T) -> Any:
      650 if pred:
      --> 651 return func
      652 return x

      File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:645, in AsyncCollectiveTensor.__torch_dispatch__.<locals>.unwrap(e)
      642 def unwrap(e: AsyncCollectiveTensor):
      643 # wait_tensor is idepotent and will do stream sync only once
      644 if not is_view_op:
      --> 645 return e.trigger_wait()
      646 return e.elem

      File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:618, in AsyncCollectiveTensor.trigger_wait(self)
      616 def trigger_wait(self):
      617 if not self.completed:
      --> 618 out = wait_tensor(self.elem)
      619 self.completed = True
      620 return out

      File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:135, in wait_tensor(tensor)
      129 def wait_tensor(tensor):
      130 """
      131 Wait on a tensor returned by the collectives ops.
      132
      133 Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
      134 """
      --> 135 return torch.ops._c10d_functional.wait_tensor(tensor)

      File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:1255, in OpOverloadPacket.__call__(self, *args, **kwargs)
      1253 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
      1254 return _call_overload_packet_from_python(self, *args, **kwargs)
      -> 1255 return self._op(*args, **kwargs)

      NotImplementedError: Cannot access storage of TensorWrapper
      ```

      Versions

      PyTorch 2.9.1

      cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @Chillee @samdow @kshitij12345 @guilhermeleobas

              Assignee: Arkadip Maitra (rh-ee-amaitra)
              Reporter: Arkadip Maitra (rh-ee-amaitra)
              PyTorch Distributed