AI Platform Core Components / AIPCC-5759

[torch.compile] Side-Effectful Tokens don't work well with torch.cond


    • Type: Bug
    • Resolution: Unresolved
    • Priority: Normal
    • PyTorch
    • AIPCC Accelerators 17

          1. 🐛 Describe the bug

      Repro:
      ```python
      import torch
      import torch.nn as nn

      torch.manual_seed(123)

      def _record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
          print(f"{prefix} = {x}")

      @torch.library.custom_op("mylib::record_scalar_tensor", mutates_args=())
      def record_scalar_tensor(x: torch.Tensor, prefix: str) -> None:
          _record_scalar_tensor(x.clone(), prefix)  # side effect only, no output

      @record_scalar_tensor.register_fake
      def _(x: torch.Tensor, prefix: str) -> None:
          return
      torch._higher_order_ops.effects._register_effectful_op(  # mark the op as effectful (ordered)
          torch.ops.mylib.record_scalar_tensor.default,
          torch._higher_order_ops.effects._EffectType.ORDERED,
      )

      @torch.compile
      class MyModule(nn.Module):
          def __init__(self):
              super().__init__()
              self.linear = nn.Linear(5, 5)

          def forward(self, x, y):
              def true_fn(x):  # branch that calls the effectful op
                  torch.ops.mylib.record_scalar_tensor(x.mean(), "True : x.mean ")
                  return x.clone()

              def false_fn(x):
                  return x.clone()

              x = self.linear(x)
              x = torch.relu(x)
              x = torch.cond(pred=y, true_fn=true_fn, false_fn=false_fn, operands=(x,))
              return x

      x = torch.randn(5, 5).to("cuda")
      y = torch.tensor([True]).to("cuda")
      mod = MyModule().to("cuda")

      out = mod(x, y)
      print(out[0])

      y = torch.tensor([False]).to("cuda")
      out = mod(x, y)
      print(out[0])
      ```

      Error stack:
      ```
      /usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py:824: UserWarning: Error detected in CondAutogradOpBackward. Traceback of forward call that caused the error:
      File "/mlx_devbox/users/yanbo.liang/playground/debug/debug3.py", line 37, in forward
      x = torch.cond(pred=y, true_fn=true_fn, false_fn=false_fn, operands=(x,))
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 137, in cond
      return cond_op(pred, true_fn, false_fn, operands)
      (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
      return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
      I0917 04:48:39.934000 6191 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
      V0917 04:48:39.936000 6191 torch/_dynamo/convert_frame.py:1395] skipping: remove_dynamo_frames (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/_dynamo/exc.py)
      V0917 04:48:39.936000 6191 torch/_dynamo/convert_frame.py:1395] skipping: __getattr__ (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
      V0917 04:48:39.937000 6191 torch/_dynamo/convert_frame.py:1395] skipping: _get_alias_val (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
      V0917 04:48:39.937000 6191 torch/_dynamo/convert_frame.py:1395] skipping: _get_alias_module_and_name (reason: in skipfiles, file: /usr/local/lib/python3.11/dist-packages/torch/utils/_config_module.py)
      Traceback (most recent call last):
      File "/mlx_devbox/users/yanbo.liang/playground/debug/debug3.py", line 44, in <module>
      out = mod(x, y)
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
      raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1544, in _call_user_compiler
      raise BackendCompilerFailed(
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py", line 1519, in _call_user_compiler
      compiled_fn = compiler_fn(gm, self.example_inputs())
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 2347, in __call__
      return compile_fx(model_, inputs_, config_patches=self.config)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py", line 2088, in compile_fx
      return aot_autograd(
      ^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/backends/common.py", line 101, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1168, in aot_module_simplified
      compiled_fn = AOTAutogradCache.load(
      ^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 775, in load
      compiled_fn = dispatch_and_compile()
      ^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 1153, in dispatch_and_compile
      compiled_fn, _ = create_aot_dispatcher_function(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
      return _create_aot_dispatcher_function(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
      compiled_fn, fw_metadata = compiler_fn(
      ^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 783, in aot_dispatch_autograd
      fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 318, in aot_dispatch_autograd_graph
      fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 55, in _create_graph
      fx_g = make_fx(
      ^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2240, in wrapped
      return make_fx_tracer.trace(f, *args)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2178, in trace
      return self._trace_inner(f, *args)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2149, in _trace_inner
      t = dispatch_trace(
      ^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_compile.py", line 51, in inner
      return disable_fn(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
      return fn(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1174, in dispatch_trace
      graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
      return fn(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 837, in trace
      (self.create_arg(fn(*args)),),
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 691, in flatten_fn
      tree_out = root_fn(*tree_args)
      ^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1229, in wrapped
      out = f(*tensors) # type:ignore[call-arg]
      ^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 717, in inner_fn
      outs = fn(*args)
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 668, in joint_helper
      return _functionalized_f_helper(primals, tangents)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 416, in _functionalized_f_helper
      f_outs = fn(*f_args)
      ^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 283, in inner_fn_with_anomaly
      return inner_fn(*args)
      ^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 268, in inner_fn
      backward_out = torch.autograd.grad(
      ^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 451, in grad
      return handle_torch_function(
      ^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/overrides.py", line 1721, in handle_torch_function
      result = mode.__torch_function__(public_api, types, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1277, in __torch_function__
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/autograd/__init__.py", line 502, in grad
      result = _engine_run_backward(
      ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
      return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 307, in apply
      return user_fn(self, *args)
      ^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 328, in backward
      grads = cond_op(
      ^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
      return super().__call__(pred, true_fn, false_fn, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
      return wrapper()
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
      ^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 327, in dispatch
      return kernel(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 355, in cond_autograd
      flat_out = CondAutogradOp.apply(
      ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
      return super().apply(*args, **kwargs) # type: ignore[misc]
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 322, in forward
      return cond_op(pred, fw_true_graph, fw_false_graph, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
      return super().__call__(pred, true_fn, false_fn, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
      return wrapper()
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
      ^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
      result = handler(mode, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 179, in functionalize_dispatch_mode_fn
      return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 652, in cond_func
      cond_return = cond_op(
      ^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
      return super().__call__(pred, true_fn, false_fn, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
      return wrapper()
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
      ^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
      result = handler(mode, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 368, in inner
      return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 287, in trace_cond
      out = func_overload(pred, true_graph, false_graph, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 61, in __call__
      return super().__call__(pred, true_fn, false_fn, operands)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 471, in __call__
      return wrapper()
      ^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 467, in wrapper
      return self.dispatch(
      ^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_ops.py", line 363, in dispatch
      result = handler(mode, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/_higher_order_ops/cond.py", line 382, in cond_fake_tensor_mode
      flat_true_outs, true_out_spec = pytree.tree_flatten(true_fn(*operands))
      ^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
      return self.call_module(mod, forward, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1047, in call_module
      return forward(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 805, in forward
      return _orig_module_call(mod, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py", line 126, in _lazy_forward
      return self(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
      return self._wrapped_call(self, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 406, in __call__
      raise e
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 393, in __call__
      return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
      return self.call_module(mod, forward, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1047, in call_module
      return forward(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/fx/_symbolic_trace.py", line 805, in forward
      return _orig_module_call(mod, *args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
      return forward_call(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
      TypeError: forward() missing 1 required positional argument: 'tangents_token'

      Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
      ```

      This seems like a reasonable use case: we just want to log metrics during the forward pass rather than during the activation checkpointing (AC) recomputation. However, it seems we can't register an effectful op and put it into a subgraph of `torch.cond`.
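
To illustrate why logging "during the forward rather than the AC recomputation" matters, here is a minimal eager-mode sketch, separate from the repro above. It assumes `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`. The names `logged` and `block` are hypothetical stand-ins for a metrics logger and a checkpointed region, and are not part of the original report. Because AC re-runs the region during backward, a plain Python side effect inside it should fire once in forward and again during recomputation.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a metrics logger: records each value it is given.
logged = []


def block(x: torch.Tensor) -> torch.Tensor:
    # Plain Python side effect; runs every time `block` is executed.
    logged.append(x.detach().mean().item())
    # cos() saves its input (an intermediate) for backward, so AC has to
    # recompute that intermediate when backward runs.
    return torch.sin(x).cos()


x = torch.randn(4, 4, requires_grad=True)

# Forward pass: `block` runs once and its intermediates are not kept.
out = checkpoint(block, x, use_reentrant=False)
calls_after_forward = len(logged)

# Backward pass: AC re-runs `block` to rebuild the needed intermediates.
out.sum().backward()
calls_after_backward = len(logged)

print(calls_after_forward, calls_after_backward)
```

The second entry in `logged` comes from the recomputation, which is the duplicate the report wants to avoid. The report attempts this by gating the logging op behind `torch.cond`.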

          1. Versions

      torch 2.7.1

      cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh

              rh-ee-krastogi Kushagra Rastogi
              PyTorch Compile