AI Platform Core Components · AIPCC-7498

Dynamo overguards on checkpoint policy functions that are closed over by compiled region

    • Type: Story
    • Resolution: Unresolved
    • Priority: Undefined
    • PyTorch
    • PyTorch Sprint 20

          🐛 Describe the bug

      OrderedSet example:

      ```
      import torch
      import functools
      from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts, CheckpointPolicy
      from torch.utils._ordered_set import OrderedSet

      def policy(compute_heavy_ops, ctx, func, *args, **kwargs):
          if func in compute_heavy_ops:
              return CheckpointPolicy.MUST_SAVE
          return CheckpointPolicy.PREFER_RECOMPUTE

      def g(x):
          return torch.mm(x, x).sin().exp()

      @torch.compile(fullgraph=True, backend="eager")
      def f(x, policy):
          return checkpoint(g, x, use_reentrant=False, context_fn=policy)

      x = torch.randn(4, 4, requires_grad=True)
      f(x, functools.partial(create_selective_checkpoint_contexts, functools.partial(policy, OrderedSet([torch.ops.aten.mm.default]))))
      f(x, functools.partial(create_selective_checkpoint_contexts, functools.partial(policy, OrderedSet([torch.ops.aten.mm.default]))))

      ```

      Recompiles:

      ```
      $ TORCH_LOGS=recompiles python sett.py
      W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200]
      W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
      W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu).
      W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200]
      V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] Recompiling function f in /Users/ezyang/Dev/pytorch-z1/sett.py:14
      V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] triggered by the following guard failure(s):
      V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] - 0/0: ___check_obj_id(policy.args[0].args[0], 4321535568) # return checkpoint(g, x, use_reentrant=False, context_fn=policy) # sett.py:16 in f
      ```
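      The guard failure above is an identity check (`___check_obj_id`) on the innermost partial, i.e. `functools.partial(policy, OrderedSet(...))`. Because each call to `f` builds that partial fresh, the object id changes on every call even though the resulting policy is behaviorally identical, so the guard misses and Dynamo recompiles. A minimal, torch-free sketch of this root cause (the identity-guard framing is an assumption about Dynamo's behavior inferred from the log above):

      ```python
      import functools

      def policy(ops, func):
          # Stand-in for the checkpoint policy: save if func is "compute heavy".
          return func in ops

      # Each functools.partial call creates a distinct object with a new id,
      # so an identity-based guard fails across calls even though the
      # partials are behaviorally equivalent.
      p1 = functools.partial(policy, frozenset({"mm"}))
      p2 = functools.partial(policy, frozenset({"mm"}))
      assert p1 is not p2            # distinct ids -> identity guard miss
      assert p1("mm") == p2("mm")    # yet behaviorally identical

      # Hoisting the partial so the same object is reused keeps the id stable,
      # which is one way a caller can avoid the recompile today.
      stable_policy = functools.partial(policy, frozenset({"mm"}))
      assert stable_policy("mm") is True
      ```

      In the repro this would correspond to constructing the `context_fn` partial once outside the call sites and passing the same object to both invocations of `f`; the report's point is that Dynamo should not need that, since it overguards on the closed-over policy function.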

          Versions

      main

      cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @Lucaskabela

              rh-ee-krastogi Kushagra Rastogi
              PyTorch Compile
