AI Platform Core Components / AIPCC-7497

Bottleneck in create_read_items_for_chunk_list when resharding checkpoints at scale


    • Type: Bug
    • Resolution: Unresolved
    • Priority: Undefined
    • PyTorch
    • PyTorch Sprint 20

          1. 🐛 Describe the bug

      The create_read_items_for_chunk_list function compares every local chunk against every saved chunk, so it runs in O(L×C) time per tensor, where L is the number of local chunks and C is the number of saved chunks. This becomes a bottleneck when resharding checkpoints with many chunks. For large-scale distributed training (e.g., LLaMA-70B with 80+ layers, resharding from 64 to 128 GPUs), this overlap search adds significant overhead during checkpoint loading.
      The root cause is the nested loop in create_read_items_for_chunk_list, which checks all pairs of local and saved chunks:

      ```
      # this is a naive quadratic algo that can be optimized later
      for idx, shard in enumerate(local_chunks):
          for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
              if not _check_shard_metadata_pair_overlap(shard, storage_md):
                  continue
              # Process overlap...
      ```
        This results in L×C comparisons where L = local chunks and C = saved chunks (e.g., 80×80 = 6,400 comparisons per tensor).
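
To make the pairwise cost concrete, here is a minimal sketch in plain Python (no torch dependency) that contrasts the all-pairs search with one possible pruning strategy: sort the saved chunks by their dim-0 offset, then use binary search so each local chunk is only tested against saved chunks whose dim-0 range could intersect it. Everything in it is an illustrative assumption. The `Chunk` tuple type and the names `overlaps`, `overlapping_pairs_naive`, `overlapping_pairs_pruned`, and `grid` are hypothetical, `overlaps` is only a stand-in for `_check_shard_metadata_pair_overlap`, and this is not presented as the fix adopted in PyTorch.

```python
from bisect import bisect_left, bisect_right
from typing import List, Tuple

# Hypothetical stand-in for ChunkStorageMetadata: (offsets, sizes), sizes > 0.
Chunk = Tuple[Tuple[int, ...], Tuple[int, ...]]


def overlaps(a: Chunk, b: Chunk) -> bool:
    """True if the half-open boxes of a and b intersect in every dimension."""
    (a_off, a_size), (b_off, b_size) = a, b
    return all(
        ao < bo + bs and bo < ao + asz
        for ao, asz, bo, bs in zip(a_off, a_size, b_off, b_size)
    )


def overlapping_pairs_naive(local: List[Chunk], saved: List[Chunk]) -> List[Tuple[int, int]]:
    """All-pairs search: exactly len(local) * len(saved) overlap checks."""
    return [
        (i, j)
        for i, loc in enumerate(local)
        for j, sav in enumerate(saved)
        if overlaps(loc, sav)
    ]


def overlapping_pairs_pruned(local: List[Chunk], saved: List[Chunk]) -> List[Tuple[int, int]]:
    """Same result, but only checks saved chunks whose dim-0 start is in range."""
    # Sort saved chunk indices by their starting offset along dim 0.
    order = sorted(range(len(saved)), key=lambda j: saved[j][0][0])
    starts = [saved[j][0][0] for j in order]
    # Largest dim-0 extent bounds how far before a local chunk a hit can start.
    max_extent = max((s[1][0] for s in saved), default=0)

    pairs = []
    for i, loc in enumerate(local):
        l_start, l_end = loc[0][0], loc[0][0] + loc[1][0]
        # An overlapping saved chunk must start in (l_start - max_extent, l_end).
        lo = bisect_right(starts, l_start - max_extent)
        hi = bisect_left(starts, l_end)
        for j in order[lo:hi]:
            if overlaps(loc, saved[j]):
                pairs.append((i, j))
    return pairs


def grid(n: int, shape: Tuple[int, int] = (64, 64)) -> List[Chunk]:
    """n x n grid of equally sized 2D chunks (integer division, as in the repro)."""
    h, w = shape[0] // n, shape[1] // n
    return [((r * h, c * w), (h, w)) for r in range(n) for c in range(n)]


saved, local = grid(4), grid(8)
naive = overlapping_pairs_naive(local, saved)
pruned = overlapping_pairs_pruned(local, saved)
assert sorted(pruned) == naive
print(len(naive), len(pruned))
```

This sketch prunes along dim 0 only. If every chunk spans the same dim-0 range (for example, purely column-wise sharding), the candidate window covers all saved chunks and the cost falls back to L×C, so a real fix would likely need to pick the sharded dimension or use a multi-dimensional index.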

      *Reproduction*
      ```
      import time

      import torch
      from torch.distributed.checkpoint.metadata import (
          ChunkStorageMetadata,
          TensorProperties,
          TensorStorageMetadata,
      )
      from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list


      def create_chunks_2d(grid_size: int, tensor_shape=(10000, 10000)):
          """Create 2D grid-sharded chunks."""
          chunks = []
          # Integer division: trailing rows/columns are left uncovered when the
          # shape is not divisible by grid_size, which is fine for this benchmark.
          chunk_height = tensor_shape[0] // grid_size
          chunk_width = tensor_shape[1] // grid_size

          for row in range(grid_size):
              for col in range(grid_size):
                  chunk = ChunkStorageMetadata(
                      offsets=torch.Size([row * chunk_height, col * chunk_width]),
                      sizes=torch.Size([chunk_height, chunk_width]),
                  )
                  chunks.append(chunk)
          return chunks


      def benchmark_resharding(saved_grid: int, local_grid: int, num_tensors: int = 100):
          """Benchmark checkpoint resharding scenario."""
          print(f"Scenario: Resharding from {saved_grid}x{saved_grid} to {local_grid}x{local_grid}")
          print(f"Saved chunks: {saved_grid**2}, Local chunks: {local_grid**2}")
          print(f"Comparisons per tensor: {saved_grid**2 * local_grid**2:,}")
          print()

          saved_chunks = create_chunks_2d(saved_grid)
          local_chunks = create_chunks_2d(local_grid)

          tensor_md = TensorStorageMetadata(
              properties=TensorProperties.create_from_tensor(torch.empty([10000, 10000])),
              size=torch.Size([10000, 10000]),
              chunks=saved_chunks,
          )

          # Time only the read-item planning step, repeated once per tensor.
          start = time.time()
          for i in range(num_tensors):
              read_items = create_read_items_for_chunk_list("tensor", tensor_md, local_chunks)
          total = time.time() - start
          print(f"Total time: {total:.2f}s for {num_tensors} tensors")
          print(f"Average: {total/num_tensors*1000:.2f}ms per tensor")
          print("=" * 60)


      if __name__ == "__main__":
          # Simulate LLaMA-70B scale checkpoint resharding
          benchmark_resharding(saved_grid=8, local_grid=8, num_tensors=100)
          benchmark_resharding(saved_grid=8, local_grid=32, num_tensors=1000)
      ```
        Run with: python script.py
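
For reference, the number of pairwise overlap checks implied by the two benchmark calls can be worked out without running torch. The short sketch below only counts comparisons and makes no claim about wall-clock time; the helper name `comparisons` is hypothetical and simply restates the `saved_grid**2 * local_grid**2` formula printed by the script.

```python
from typing import Tuple


def comparisons(saved_grid: int, local_grid: int, num_tensors: int) -> Tuple[int, int]:
    """Return (checks per tensor, checks across all tensors) for the all-pairs loop."""
    per_tensor = saved_grid**2 * local_grid**2
    return per_tensor, per_tensor * num_tensors


# The two scenarios exercised by the reproduction script.
for saved_grid, local_grid, num_tensors in [(8, 8, 100), (8, 32, 1000)]:
    per_tensor, total = comparisons(saved_grid, local_grid, num_tensors)
    print(f"{saved_grid}x{saved_grid} -> {local_grid}x{local_grid}: "
          f"{per_tensor:,} per tensor, {total:,} total")
# 8x8 -> 8x8:   4,096 per tensor, 409,600 total
# 8x8 -> 32x32: 65,536 per tensor, 65,536,000 total
```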
          1. Versions

      Collecting environment information...
      PyTorch version: N/A
      Is debug build: N/A
      CUDA used to build PyTorch: N/A
      ROCM used to build PyTorch: N/A

      OS: Fedora Linux 43 (Workstation Edition) (x86_64)
      GCC version: (GCC) 15.2.1 20251111 (Red Hat 15.2.1-4)
      Clang version: Could not collect
      CMake version: version 4.1.2
      Libc version: glibc-2.42

      Python version: 3.11.14 (main, Oct 10 2025, 00:00:00) [GCC 15.2.1 20250924 (Red Hat 15.2.1-2)] (64-bit runtime)
      Python platform: Linux-6.17.7-300.fc43.x86_64-x86_64-with-glibc2.42
      Is CUDA available: N/A
      CUDA runtime version: Could not collect
      CUDA_MODULE_LOADING set to: N/A
      GPU models and configuration: Could not collect
      Nvidia driver version: Could not collect
      cuDNN version: Could not collect
      Is XPU available: N/A
      HIP runtime version: N/A
      MIOpen runtime version: N/A
      Is XNNPACK available: N/A

      CPU:
      Architecture: x86_64
      CPU op-mode(s): 32-bit, 64-bit
      Address sizes: 48 bits physical, 48 bits virtual
      Byte Order: Little Endian
      CPU(s): 16
      On-line CPU(s) list: 0-15
      Vendor ID: AuthenticAMD
      Model name: AMD Ryzen 7 PRO 7840HS w/ Radeon 780M Graphics
      CPU family: 25
      Model: 116
      Thread(s) per core: 2
      Core(s) per socket: 8
      Socket(s): 1
      Stepping: 1
      Frequency boost: enabled
      CPU(s) scaling MHz: 75%
      CPU max MHz: 5137.9038
      CPU min MHz: 419.4210
      BogoMIPS: 7585.52
      Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d amd_lbr_pmc_freeze
      Virtualization: AMD-V
      L1d cache: 256 KiB (8 instances)
      L1i cache: 256 KiB (8 instances)
      L2 cache: 8 MiB (8 instances)
      L3 cache: 16 MiB (1 instance)
      NUMA node(s): 1
      NUMA node0 CPU(s): 0-15
      Vulnerability Gather data sampling: Not affected
      Vulnerability Ghostwrite: Not affected
      Vulnerability Indirect target selection: Not affected
      Vulnerability Itlb multihit: Not affected
      Vulnerability L1tf: Not affected
      Vulnerability Mds: Not affected
      Vulnerability Meltdown: Not affected
      Vulnerability Mmio stale data: Not affected
      Vulnerability Old microcode: Not affected
      Vulnerability Reg file data sampling: Not affected
      Vulnerability Retbleed: Not affected
      Vulnerability Spec rstack overflow: Mitigation; Safe RET
      Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
      Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
      Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
      Vulnerability Srbds: Not affected
      Vulnerability Tsa: Mitigation; Clear CPU buffers
      Vulnerability Tsx async abort: Not affected
      Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

      Versions of relevant libraries:
      [pip3] flake8==7.3.0
      [pip3] flake8-bugbear==24.12.12
      [pip3] flake8-comprehensions==3.16.0
      [pip3] flake8-executable==2.1.3
      [pip3] flake8-logging-format==2024.24.12
      [pip3] flake8-pyi==25.5.0
      [pip3] flake8_simplify==0.22.0
      [pip3] mypy==1.16.0
      [pip3] mypy_extensions==1.1.0
      [pip3] numpy==1.26.4
      [pip3] optree==0.17.0
      [pip3] torch==2.9.0+cpu

              rh-ee-managarw Mansi Agarwal
              PyTorch Distributed