Bug · Resolution: Unresolved
### 🐛 Describe the bug

#### Summary
When using `torch.func.jvp` with a custom `autograd.Function` whose `jvp` staticmethod returns an `AsyncCollectiveTensor` (via `_maybe_wrap_tensor`), PyTorch raises `NotImplementedError: Cannot access storage of TensorWrapper`. This occurs because `AsyncCollectiveTensor.__torch_dispatch__` unwraps its `AsyncCollectiveTensor` arguments by calling `wait_tensor` on the inner tensor, and during JVP execution that inner tensor is a functorch `TensorWrapper`, whose storage `wait_tensor` cannot access.
Similar to https://github.com/pytorch/pytorch/issues/161943 and https://github.com/pytorch/pytorch/issues/138422.
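As additional context, a minimal sketch of the underlying limitation, independent of the collectives code (`peek` is just a throwaway helper I made up): any direct storage access on the functorch-wrapped operands inside a `torch.func` transform appears to raise the same error that `wait_tensor` hits in the traceback below.

```python
import torch

def peek(x):
    # Inside torch.func.jvp, `x` is (as far as I can tell) a functorch TensorWrapper.
    # Accessing its storage raises:
    #   NotImplementedError: Cannot access storage of TensorWrapper
    x.untyped_storage()
    return x

torch.func.jvp(peek, (torch.randn(1),), (torch.randn(1),))
```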

#### Code to Reproduce
```python
import torch
from torch.distributed._functional_collectives import _maybe_wrap_tensor

inp = torch.randn(1)

class F(torch.autograd.Function):
    @staticmethod
    def forward(input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def jvp(ctx, input_tangent):
        return _maybe_wrap_tensor(input_tangent)

torch.func.jvp(F.apply, (inp,), (inp,))
```

#### Error Traceback
```
NotImplementedError Traceback (most recent call last)
Cell In[14], line 20
16 @staticmethod
17 def jvp(ctx, input_tangent):
18 return _maybe_wrap_tensor(input_tangent)
---> 20 torch.func.jvp(F.apply, (inp,), (inp,))
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1042, in jvp(func, primals, tangents, strict, has_aux)
983 @exposed_in("torch.func")
984 def jvp(
985 func: Callable,
(...) 990 has_aux: bool = False,
991 ):
992 """
993 Standing for the Jacobian-vector product, returns a tuple containing
994 the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
(...) 1039
1040 """
-> 1042 return _jvp_with_argnums(
1043 func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux
1044 )
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py:1101, in _jvp_with_argnums(func, primals, tangents, argnums, strict, has_aux)
1099 primals = _wrap_all_tensors(primals, level)
1100 duals = _replace_args(primals, duals, argnums)
-> 1101 result_duals = func(*duals)
1102 if has_aux:
1103 if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
File /opt/python3.12/lib/python3.12/site-packages/torch/autograd/function.py:591, in Function.apply(cls, *args, **kwargs)
583 if not is_setup_ctx_defined:
584 raise RuntimeError(
585 "In order to use an autograd.Function with functorch transforms "
586 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
587 "staticmethod. For more details, please see "
588 "https://pytorch.org/docs/main/notes/extending.func.html"
589 )
--> 591 return custom_function_call(cls, *args, **kwargs)
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/autograd_function.py:49, in CustomFunctionHigherOrderOperator.__call__(self, autograd_function, *args, **kwargs)
36 def __call__(self, autograd_function, *args, **kwargs):
37 # When custom_function_call is done dispatching through functorch,
38 # it should just invoke the autograd.Function. This is consistent
(...) 46 # (because autograd.Function happens before the Python dispatch key)
47 # and only traces the forward pass.
48 if torch._C._are_functorch_transforms_active():
---> 49 return super().__call__(autograd_function, *args, **kwargs)
50 return autograd_function.apply(*args, **kwargs)
File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:536, in HigherOrderOperator.__call__(self, *args, **kwargs)
531 dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
532 return self.dispatch(
533 dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
534 )
--> 536 return wrapper()
File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:532, in HigherOrderOperator.__call__.<locals>.wrapper()
527 return torch.overrides.handle_torch_function(
528 self, flat_args, *args, **kwargs
529 )
531 dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys)
--> 532 return self.dispatch(
533 dispatch_key_set.highestPriorityTypeId(), *args, **kwargs
534 )
File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:384, in HigherOrderOperator.dispatch(self, dispatch_key, *args, **kwargs)
381 return kernel(*args, **kwargs)
383 if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode:
--> 384 return dispatch_functorch(self, args, kwargs)
386 if dispatch_key == DispatchKey.Python:
387 # Keep the following 1:1 with handle_torch_function_no_python_arg_parser
388 # in torch/csrc/utils/python_arg_parser.cpp
390 overloaded_args_list = []
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/pyfunctorch.py:312, in dispatch_functorch(op, args, kwargs)
304 # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
305 # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
306 # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
307 # transforms, so we manually unwrap the dead tensors here.
308 # This logic won't need to exist when we have mode-only functorch.
309 args, kwargs = pytree.tree_map_only(
310 torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
311 )
--> 312 return interpreter.process(op, args, kwargs)
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/pyfunctorch.py:222, in JvpInterpreter.process(self, op, args, kwargs)
220 kernel = op.functorch_table[TransformType.Jvp]
221 args, kwargs = self.lift(args, kwargs)
--> 222 return kernel(self, *args, **kwargs)
File /opt/python3.12/lib/python3.12/site-packages/torch/_functorch/autograd_function.py:93, in custom_function_call_grad(interpreter, autograd_function, *operands)
91 Generated = generate_single_level_function(interpreter, autograd_function)
92 with enable_single_level_autograd_function():
---> 93 flat_out = Generated.apply(*operands)
94 return flat_out
File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:654, in AsyncCollectiveTensor.__torch_dispatch__(cls, func, types, args, kwargs)
651 res = AsyncCollectiveTensor(e)
652 return res
--> 654 unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
655 unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
657 # we don't wrap the result as it doesn't need to be waited on.
File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:716, in tree_map_only(type_or_types_or_pred, func, tree, is_leaf)
709 def tree_map_only(
710 type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
711 /,
(...) 714 is_leaf: Optional[Callable[[PyTree], bool]] = None,
715 ) -> PyTree:
--> 716 return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:523, in tree_map(func, tree, is_leaf, *rests)
484 def tree_map(
485 func: Callable[..., Any],
486 tree: PyTree,
487 *rests: PyTree,
488 is_leaf: Optional[Callable[[PyTree], bool]] = None,
489 ) -> PyTree:
490 """Map a multi-input function over pytree args to produce a new pytree.
491
492 See also :func:`tree_map_`.
(...) 521 is the tuple of values at corresponding nodes in ``rests``.
522 """
--> 523 return optree.tree_map(
524 func,
525 tree,
526 *rests,
527 is_leaf=is_leaf,
528 none_is_leaf=True,
529 namespace="torch",
530 )
File /opt/python3.12/lib/python3.12/site-packages/optree/ops.py:766, in tree_map(func, tree, is_leaf, none_is_leaf, namespace, *rests)
764 leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
765 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 766 return treespec.unflatten(map(func, *flat_args))
File /opt/python3.12/lib/python3.12/site-packages/torch/utils/_cxx_pytree.py:651, in map_only.<locals>.wrapper.<locals>.wrapped(x)
648 @functools.wraps(func)
649 def wrapped(x: T) -> Any:
650 if pred(x):
--> 651 return func(x)
652 return x
File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:645, in AsyncCollectiveTensor.__torch_dispatch__.<locals>.unwrap(e)
642 def unwrap(e: AsyncCollectiveTensor):
643 # wait_tensor is idepotent and will do stream sync only once
644 if not is_view_op:
--> 645 return e.trigger_wait()
646 return e.elem
File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:618, in AsyncCollectiveTensor.trigger_wait(self)
616 def trigger_wait(self):
617 if not self.completed:
--> 618 out = wait_tensor(self.elem)
619 self.completed = True
620 return out
File /opt/python3.12/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py:135, in wait_tensor(tensor)
129 def wait_tensor(tensor):
130 """
131 Wait on a tensor returned by the collectives ops.
132
133 Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
134 """
--> 135 return torch.ops._c10d_functional.wait_tensor(tensor)
File /opt/python3.12/lib/python3.12/site-packages/torch/_ops.py:1255, in OpOverloadPacket.__call__(self, *args, **kwargs)
1253 if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1254 return _call_overload_packet_from_python(self, *args, **kwargs)
-> 1255 return self._op(*args, **kwargs)
NotImplementedError: Cannot access storage of TensorWrapper
```

### Versions

PyTorch 2.9.1
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @Chillee @samdow @kshitij12345 @guilhermeleobas