  AI Platform Core Components / AIPCC-5724

distributed checkpoint errors with unused weights and stateful optimizer

    • Type: Bug
    • Resolution: Unresolved
    • Priority: Undefined
    • Component: PyTorch

      ### 🐛 Describe the bug

      If the optimizer holds parameters that are never used in the loss computation, those parameters never receive internal optimizer state (for example the per-parameter `step`, `exp_avg`, and `exp_avg_sq` tensors Adam maintains). This confuses `torch.distributed.checkpoint.state_dict.set_state_dict`, which raises during load. The error occurs even with `StateDictOptions(strict=False)` and `DefaultLoadPlanner(allow_partial_load=True)`.
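
      For context, a minimal check independent of distributed checkpointing: Adam-family optimizers create per-parameter state lazily, on the first `step()` in which the parameter has a gradient, so a parameter that never contributes to the loss has no entry in `optimizer.state_dict()["state"]`.

      ```python
      import torch
      import torch.optim as optim

      model = torch.nn.Linear(4, 2, bias=True)
      optimizer = optim.AdamW(model.parameters(), lr=0.01)

      # The loss depends only on the weight, so the bias never receives a gradient.
      loss = torch.mean(model.weight)
      loss.backward()
      optimizer.step()

      # Only the weight (param index 0) has optimizer state; the bias has none.
      print(optimizer.state_dict()["state"].keys())  # dict_keys([0])
      ```

      The full reproducer: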

      ```python
      import pathlib
      import shutil

      import torch
      import torch.optim as optim
      import torch.distributed.checkpoint as dcp
      from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, StateDictOptions
      from torch.distributed.checkpoint.stateful import Stateful
      from torch.distributed.checkpoint import DefaultLoadPlanner

      class AppState(Stateful):
          def __init__(self, model, optimizer):
              self.model = model
              self.optimizer = optimizer

          def state_dict(self):
              # this line automatically manages FSDP FQN's, as well as sets the
              # default state dict type to FSDP.SHARDED_STATE_DICT
              model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
              return {"model": model_state_dict, "optim": optimizer_state_dict}

          def load_state_dict(self, state_dict):
              # sets our state dicts on the model and optimizer, now that we've loaded
              set_state_dict(
                  self.model,
                  self.optimizer,
                  model_state_dict=state_dict["model"],
                  optim_state_dict=state_dict["optim"],
                  options=StateDictOptions(strict=False),
              )

      tmp_dir = pathlib.Path("test/tmp/test_dataset_datamodule_resume")
      shutil.rmtree(tmp_dir, ignore_errors=True)

      # Model + optimizer
      model = torch.nn.Linear(4, 2, bias=True)
      optimizer = optim.AdamW(model.parameters(), lr=0.01)

      # One training step; the loss never touches the bias, so the bias
      # gets no optimizer state
      loss = torch.mean(model.weight)
      loss.backward()
      optimizer.step()

      # Save checkpoint
      dcp.save({'app': AppState(model, optimizer)}, checkpoint_id=tmp_dir)

      # Recreate model + optimizer and load
      model2 = torch.nn.Linear(4, 2, bias=True)
      optimizer2 = optim.AdamW(model2.parameters(), lr=0.01)

      # RuntimeError: Missing key in checkpoint state_dict: optimizer.state.bias.step.
      dcp.load({'app': AppState(model2, optimizer2)},
               checkpoint_id=tmp_dir,
               planner=DefaultLoadPlanner(allow_partial_load=True))
      ```

      ```
      W0930 10:27:20.827000 79300 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
      /opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_saver.py:167: UserWarning: torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process.
      warnings.warn(
      /opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/filesystem.py:131: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
      if tensor.storage().size() != tensor.numel():
      /opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_loader.py:153: UserWarning: torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process.
      warnings.warn(
      Traceback (most recent call last):
      File "test3.py", line 54, in <module>
      dcp.load({'app': AppState(model, optimizer)}, checkpoint_id=tmp_dir,
      ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      planner=DefaultLoadPlanner(allow_partial_load=True))
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/logger.py", line 87, in wrapper
      result = func(*args, **kwargs)
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/utils.py", line 475, in inner_func
      return func(*args, **kwargs)
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 191, in load
      elem.load_state_dict(statetful_sd[key])
      ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
      File "test3.py", line 26, in load_state_dict
      set_state_dict(
      ~~~~~~~~~~~~~~^
      self.model,
      ^^^^^^^^^^^
      ...<3 lines>...
      options=StateDictOptions(strict=False)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 1377, in set_state_dict
      _load_optim_state_dict(model, optimizers, optim_state_dict, info)
      ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 946, in _load_optim_state_dict
      optim_state_dict = _split_optim_state_dict(
      model, optim, state_dict, info
      )
      File "/opt/homebrew/Caskroom/miniforge/base/envs/py313/lib/python3.13/site-packages/torch/distributed/checkpoint/state_dict.py", line 890, in _split_optim_state_dict
      state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
      KeyError: 'bias'
      ```
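
      Per the traceback, the KeyError comes from `_split_optim_state_dict`, which iterates the parameter FQNs known to the live model/optimizer and indexes the checkpoint's optimizer state unconditionally; neither `strict=False` nor `allow_partial_load=True` guards that lookup. An illustrative reduction of the failing pattern (hypothetical, not the actual PyTorch source):

      ```python
      # Hypothetical reduction of the failing lookup: state was only ever
      # recorded for parameters the optimizer touched ("weight"), but the
      # split walks every parameter FQN, including the unused "bias".
      optim_state = {"weight": {"step": 1}}   # no "bias" entry was ever saved
      for fqn in ("weight", "bias"):          # FQNs from the live model
          state = optim_state[fqn]            # KeyError: 'bias'
      ```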

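      A possible workaround on the save side (a sketch, assuming a one-off no-op step right after constructing the optimizer is acceptable; `materialize_optimizer_state` below is a hypothetical helper, not a PyTorch API): force the optimizer to materialize state for every parameter by stepping once with zero gradients and the learning rate temporarily set to 0, so neither the parameters nor AdamW's decoupled weight decay change. The per-parameter `step` counters then start at 1 instead of 0, which may matter to schedulers.

      ```python
      import torch

      def materialize_optimizer_state(optimizer):
          """Take one no-op step (zero grads, lr=0) so Adam-style optimizers
          create per-parameter state even for parameters the loss never uses.
          Call right after constructing the optimizer, before real training."""
          saved_lrs = [group["lr"] for group in optimizer.param_groups]
          for group in optimizer.param_groups:
              group["lr"] = 0.0
              for p in group["params"]:
                  if p.grad is None:
                      p.grad = torch.zeros_like(p)
          optimizer.step()                       # creates state; lr=0 and zero grads => no update
          optimizer.zero_grad(set_to_none=True)  # drop the dummy gradients
          for group, lr in zip(optimizer.param_groups, saved_lrs):
              group["lr"] = lr
      ```

      Applying this before the first real step ensures the checkpoint contains state for every parameter; applying it to the reloading optimizer as well keeps both sides covering the same set of FQNs, sidestepping the missing-key error above.
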
      ### Versions

      ```
      Collecting environment information...
      PyTorch version: 2.8.0
      Is debug build: False
      CUDA used to build PyTorch: None
      ROCM used to build PyTorch: N/A

      OS: macOS 26.0 (arm64)
      GCC version: Could not collect
      Clang version: 17.0.0 (clang-1700.3.19.1)
      CMake version: version 4.1.1
      Libc version: N/A

      Python version: 3.13.1 | packaged by conda-forge | (main, Jan 13 2025, 09:45:31) [Clang 18.1.8 ] (64-bit runtime)
      Python platform: macOS-26.0-arm64-arm-64bit-Mach-O
      Is CUDA available: False
      CUDA runtime version: No CUDA
      CUDA_MODULE_LOADING set to: N/A
      GPU models and configuration: No CUDA
      Nvidia driver version: No CUDA
      cuDNN version: No CUDA
      Is XPU available: False
      HIP runtime version: N/A
      MIOpen runtime version: N/A
      Is XNNPACK available: True

      CPU:
      Apple M3

      Versions of relevant libraries:
      [pip3] mypy==1.15.0
      [pip3] mypy-extensions==1.0.0
      [pip3] numpy==2.3.3
      [pip3] pytorch-lightning==2.5.0.post0
      [pip3] torch==2.8.0
      [pip3] torchdata==0.11.0
      [pip3] torchmetrics==1.8.2
      [pip3] torchvision==0.22.0
      [conda] numpy 2.3.3 pypi_0 pypi
      [conda] pytorch-lightning 2.5.0.post0 pypi_0 pypi
      [conda] torch 2.8.0 pypi_0 pypi
      [conda] torchdata 0.11.0 pypi_0 pypi
      [conda] torchmetrics 1.8.2 pypi_0 pypi
      [conda] torchvision 0.22.0 pypi_0 pypi
      ```

      cc @LucasLLC @pradeepfn

              Assignee: Arkadip Maitra (rh-ee-amaitra)
              Reporter: Arkadip Maitra (rh-ee-amaitra)