-
Bug
-
Resolution: Unresolved
-
Normal
-
None
-
None
-
None
-
-
- 🐛 Describe the bug
-
Repro:
```python
import torch
import torch.nn as nn

torch.manual_seed(123)


def _record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
    """Print ``x`` to stdout, labelled with ``prefix``."""
    # NOTE(review): the format string was garbled in the report (line breaks
    # inserted, the ``{prefix}`` placeholder lost); this is the most likely
    # original form -- confirm against the reporter's script.
    print(f"\n{prefix} = \n{x}")


@torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
    """Custom op that logs a copy of ``x``; it returns nothing."""
    _record_scalar_tensor(x.clone(), prefix)


@record_scalar_tensor.register_fake
def _(x: torch.Tensor, prefix: str) -> None:
    """Fake implementation: nothing to compute because the op has no output."""
    return


# Register the op as an ORDERED effectful op. It is registered here and then
# called inside a ``torch.cond`` branch below -- the combination this report
# is about.
torch._higher_order_ops.effects._register_effectful_op(
    torch.ops.mylib.record_scalar_tensor.default,
    torch._higher_order_ops.effects._EffectType.ORDERED,
)


@torch.compile
class MyModule(nn.Module):
    """Linear + ReLU followed by a ``torch.cond`` whose true branch logs."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)

    def forward(self, x, y):
        """Apply linear and ReLU to ``x``, then branch on predicate ``y``."""

        def true_fn(x):
            # The effectful op is invoked inside the cond branch.
            torch.ops.mylib.record_scalar_tensor(x.mean(), "True : x.mean ")
            return x.clone()

        def false_fn(x):
            return x.clone()

        x = self.linear(x)
        x = torch.relu(x)
        x = torch.cond(pred=y, true_fn=true_fn, false_fn=false_fn, operands=(x,))
        return x


x = torch.randn(5, 5).to("cuda")
y = torch.tensor([True]).to("cuda")
mod = MyModule().to("cuda")
out = mod(x, y)
print(out[0])
y = torch.tensor([False]).to("cuda")
out = mod(x, y)
print(out[0])
```
Error stack:
```
/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py:824: UserWarning: Error detected in CondAutogradOpBackward. Traceback of forward call that caused the error:
File "/mlx_devbox/users/yanbo.liang/playground/debug/debug3.py", line 37, in forward
x = torch.cond(pred=y, true_fn=true_fn, false_fn=false_fn, operands=(x,))
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 137, in cond
return cond_op(pred, true_fn, false_fn, operands)
(Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
I0917 04:48:39.934000 6191 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0917 04:48:39.936000 6191 torch/_dynamo/convert_frame.py:1395] skipping: remove_dynamo_frames (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/_dynamo/exc.py)
V0917 04:48:39.936000 6191 torch/_dynamo/convert_frame.py:1395] skipping: __getattr__ (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
V0917 04:48:39.937000 6191 torch/_dynamo/convert_frame.py:1395] skipping: _get_alias_val (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
V0917 04:48:39.937000 6191 torch/_dynamo/convert_frame.py:1395] skipping: _get_alias_module_and_name (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
Traceback (most recent call last):
File "/mlx_devbox/users/yanbo.liang/playground/debug/debug3.py", line 44, in <module>
out = mod(x, y)
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1544, in _call_user_compiler
raise BackendCompilerFailed(
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1519, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 2347, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2088, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 101, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1168, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 775, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1153, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 783, in aot_dispatch_autograd
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 318, in aot_dispatch_autograd_graph
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
fx_g = make_fx(
^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2240, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2178, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1174, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 837, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 691, in flatten_fn
tree_out = root_fn(*tree_args)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1229, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 717, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 668, in joint_helper
return _functionalized_f_helper(primals, tangents)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 416, in _functionalized_f_helper
f_outs = fn(*f_args)
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 283, in inner_fn_with_anomaly
return inner_fn(*args)
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 268, in inner_fn
backward_out = torch.autograd.grad(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 451, in grad
return handle_torch_function(
^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/overrides.py", line 1721, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1277, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 502, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 328, in backward
grads = cond_op(
^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
return super().__call__(pred, true_fn, false_fn, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
return wrapper()
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 327, in dispatch
return kernel(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 355, in cond_autograd
flat_out = CondAutogradOp.apply(
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 322, in forward
return cond_op(pred, fw_true_graph, fw_false_graph, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
return super().__call__(pred, true_fn, false_fn, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
return wrapper()
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
result = handler(mode, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 179, in functionalize_dispatch_mode_fn
return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 652, in cond_func
cond_return = cond_op(
^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
return super().__call__(pred, true_fn, false_fn, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
return wrapper()
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
result = handler(mode, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 368, in inner
return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 287, in trace_cond
out = func_overload(pred, true_graph, false_graph, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
return super().__call__(pred, true_fn, false_fn, operands)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
return wrapper()
^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
return self.dispatch(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
result = handler(mode, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 382, in cond_fake_tensor_mode
flat_true_outs, true_out_spec = pytree.tree_flatten(true_fn(*operands))
^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1047, in call_module
return forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 805, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py", line 126, in _lazy_forward
return self(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 406, in __call__
raise e
File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 393, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1047, in call_module
return forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 805, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: forward() missing 1 required positional argument: 'tangents_token'
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
This seems like a reasonable use case: we want to log metrics only during the forward pass, not during the activation-checkpointing (AC) recomputation. However, it appears that we can't register an effectful op and then call it inside a subgraph of `torch.cond`.
-
-
- Versions
-
torch 2.7.1
cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh