Story
Resolution: Unresolved
-
- 🐛 Describe the bug
-
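Minimal repro using `OrderedSet`: calling `f` twice, each time with a freshly constructed but equal selective-checkpointing policy, recompiles `f`: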
```
import torch
import functools
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts, CheckpointPolicy
from torch.utils._ordered_set import OrderedSet

def policy(compute_heavy_ops, ctx, func, *args, **kwargs):
    if func in compute_heavy_ops:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def g(x):
    return torch.mm(x, x).sin().exp()

@torch.compile(fullgraph=True, backend="eager")
def f(x, policy):
    return checkpoint(g, x, use_reentrant=False, context_fn=policy)

x = torch.randn(4, 4, requires_grad=True)
# Each call builds a new (but equal) OrderedSet inside the context_fn partials.
f(x, functools.partial(create_selective_checkpoint_contexts, functools.partial(policy, OrderedSet([torch.ops.aten.mm.default]))))
f(x, functools.partial(create_selective_checkpoint_contexts, functools.partial(policy, OrderedSet([torch.ops.aten.mm.default]))))
```
Running with `TORCH_LOGS=recompiles` shows that the second call recompiles `f`:
```
$ TORCH_LOGS=recompiles python sett.py
W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200]
W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
W1119 11:44:41.438000 73793 torch/_logging/_internal.py:1200]
V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] Recompiling function f in /Users/ezyang/Dev/pytorch-z1/sett.py:14
V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] triggered by the following guard failure(s):
V1119 11:44:41.452000 73793 torch/_dynamo/guards.py:4418] [0/1] [__recompiles] - 0/0: ___check_obj_id(policy.args[0].args[0], 4321535568) # return checkpoint(g, x, use_reentrant=False, context_fn=policy) # sett.py:16 in f
```
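The failing guard is `___check_obj_id(policy.args[0].args[0], ...)`, an identity check on the `OrderedSet` passed to `policy`. Because each call to `f` builds a new (if equal) `OrderedSet`, its id differs and the guard fails. The plain-Python sketch below only illustrates identity versus equality guarding; it is not Dynamo's guard implementation. `policy_fn`, `create_contexts`, `build_context_fn`, `id_guard` and `eq_guard` are hypothetical stand-ins, strings stand in for op overloads, and an insertion-ordered `dict` stands in for `OrderedSet`.

```python
import functools


def policy_fn(compute_heavy_ops, func):
    # Hypothetical simplified policy: "save" ops in the set, else "recompute".
    return "save" if func in compute_heavy_ops else "recompute"


def create_contexts(policy):
    # Hypothetical stand-in for create_selective_checkpoint_contexts.
    return policy


def build_context_fn():
    # Mirrors the nesting in the repro; an ordered dict stands in for OrderedSet,
    # and a fresh one is built on every call.
    ops = dict.fromkeys(["aten.mm.default"])
    return functools.partial(create_contexts, functools.partial(policy_fn, ops))


def id_guard(context_fn):
    # Pins the identity of the op set, analogous to ___check_obj_id in the log.
    expected = id(context_fn.args[0].args[0])
    return lambda new_fn: id(new_fn.args[0].args[0]) == expected


def eq_guard(context_fn):
    # Compares the op set by value instead of by identity.
    expected = context_fn.args[0].args[0]
    return lambda new_fn: new_fn.args[0].args[0] == expected


first = build_context_fn()
second = build_context_fn()  # equal contents, different object
print(id_guard(first)(second))  # False
print(eq_guard(first)(second))  # True
```

Under the equality check, the second, freshly built policy would match; under the identity check recorded in the log, it does not.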
-
-
- Versions
-
main
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @Lucaskabela