### 🐛 Describe the bug
If the optimizer is given parameters that are never used in the loss computation, those parameters never receive internal optimizer state (e.g. the `step`, `exp_avg`, and `exp_avg_sq` buffers that Adam creates lazily on the first step). `torch.distributed.checkpoint.state_dict.set_state_dict` does not handle this case and raises an error when loading. The error occurs even with `StateDictOptions(strict=False)` and `DefaultLoadPlanner(allow_partial_load=True)`.
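To illustrate the root cause, here is a minimal snippet (separate from the repro below) showing that AdamW only creates per-parameter state for parameters that actually received a gradient:

```python
import torch
import torch.optim as optim

m = torch.nn.Linear(4, 2, bias=True)
opt = optim.AdamW(m.parameters(), lr=0.01)
torch.mean(m.weight).backward()  # the bias receives no gradient
opt.step()
print(len(opt.state))  # 1 -> state exists only for the weight, not the bias
```

Full repro: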
```python
import pathlib
import shutil
import torch
import torch.optim as optim
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, StateDictOptions
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint import DefaultLoadPlanner
class AppState(Stateful):
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQNs, as well as sets the
        # default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {"model": model_state_dict, "optim": optimizer_state_dict}

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            options=StateDictOptions(strict=False),
        )

tmp_dir = pathlib.Path("test/tmp/test_dataset_datamodule_resume")
shutil.rmtree(tmp_dir, ignore_errors=True)

# Model + optimizer
model = torch.nn.Linear(4, 2, bias=True)
optimizer = optim.AdamW(model.parameters(), lr=0.01)

# One training step; the bias is not part of the loss, so it never gets a
# gradient and AdamW never creates optimizer state for it
loss = torch.mean(model.weight)
loss.backward()
optimizer.step()

# Save checkpoint
dcp.save({"app": AppState(model, optimizer)}, checkpoint_id=tmp_dir)

# Recreate model + optimizer and load
model2 = torch.nn.Linear(4, 2, bias=True)
optimizer2 = optim.AdamW(model2.parameters(), lr=0.01)

# RuntimeError: Missing key in checkpoint state_dict: optimizer.state.bias.step.
dcp.load(
    {"app": AppState(model2, optimizer2)},
    checkpoint_id=tmp_dir,
    planner=DefaultLoadPlanner(allow_partial_load=True),
)
```
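Running the script produces the following output: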
```
W0930 10:27:20.827000 79300 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_saver.py:167: UserWarning: torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process.
warnings.warn(
/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/filesystem.py:131: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
if tensor.storage().size() != tensor.numel():
/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_loader.py:153: UserWarning: torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process.
warnings.warn(
Traceback (most recent call last):
File "test3.py", line 54, in <module>
dcp.load(
, checkpoint_id=tmp_dir,
~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
planner=DefaultLoadPlanner(allow_partial_load=True))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
result = func(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/utils.py", line 475, in inner_func
return func(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 191, in load
elem.load_state_dict(statetful_sd[key])
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
File "test3.py", line 26, in load_state_dict
set_state_dict(
~~~~~~~~~~~~~~^
self.model,
^^^^^^^^^^^
...<3 lines>...
options=StateDictOptions(strict=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 1377, in set_state_dict
_load_optim_state_dict(model, optimizers, optim_state_dict, info)
~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 946, in _load_optim_state_dict
optim_state_dict = _split_optim_state_dict(
model, optim, state_dict, info
)
File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 890, in _split_optim_state_dict
state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: 'bias'
```
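Not a fix, but a possible save-side workaround (a sketch, assuming it is acceptable to give every parameter a zero-valued gradient): make every parameter participate in the loss so AdamW creates state for all of them before the checkpoint is saved.

```python
# Hypothetical workaround, not an official API: add a zero-valued term over
# all parameters so every parameter receives a (zero) gradient and AdamW
# lazily creates state for it on the first step. The extra term changes
# neither the loss value nor the gradients of the real loss.
loss = torch.mean(model.weight)
loss = loss + 0.0 * sum(p.sum() for p in model.parameters())
loss.backward()
optimizer.step()
```

With this change the saved checkpoint should contain `optimizer.state.bias.*` entries, so the `dcp.load` call above no longer hits the `KeyError`.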
### Versions
```
Collecting environment information...
PyTorch version: 2.8.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.0 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.3.19.1)
CMake version: version 4.1.1
Libc version: N/A
Python version: 3.13.1 | packaged by conda-forge | (main, Jan 13 2025, 09:45:31) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-26.0-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3
Versions of relevant libraries:
[pip3] mypy==1.15.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.3.3
[pip3] pytorch-lightning==2.5.0.post0
[pip3] torch==2.8.0
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.22.0
[conda] numpy 2.3.3 pypi_0 pypi
[conda] pytorch-lightning 2.5.0.post0 pypi_0 pypi
[conda] torch 2.8.0 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchmetrics 1.8.2 pypi_0 pypi
[conda] torchvision 0.22.0 pypi_0 pypi
```
cc @LucasLLC @pradeepfn