Bug
Resolution: Unresolved
### 🐛 Describe the bug
The `create_read_items_for_chunk_list` function has O(L×C) time complexity, which degrades performance when resharding checkpoints with many chunks. In large-scale distributed training (e.g., LLaMA-70B with 80+ layers, resharding from 64 to 128 GPUs), this pairwise overlap check adds significant overhead to checkpoint loading.
The root cause is the nested loop in `create_read_items_for_chunk_list`, which checks every pair of local and saved chunks:
```python
# this is a naive quadratic algo that can be optimized later
for idx, shard in enumerate(local_chunks):
    for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
        if not _check_shard_metadata_pair_overlap(shard, storage_md):
            continue
        # Process overlap...
```
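To make the cost of the excerpt above concrete without the PyTorch source at hand, here is a plain-Python approximation. This is a minimal sketch, assuming `_check_shard_metadata_pair_overlap` treats each chunk as a half-open box `[offset, offset + size)` in every dimension; `Chunk`, `chunks_overlap`, `naive_overlaps`, and `grid_2d` are hypothetical stand-ins, not PyTorch APIs. It counts comparisons so the L×C cost is visible.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Chunk:
    """Hypothetical stand-in for ChunkStorageMetadata (offsets and sizes per dim)."""
    offsets: Tuple[int, ...]
    sizes: Tuple[int, ...]


def chunks_overlap(a: Chunk, b: Chunk) -> bool:
    # Assumed semantics: half-open ranges [offset, offset + size) must
    # intersect in every dimension for the two chunks to overlap.
    return all(
        a_off < b_off + b_len and b_off < a_off + a_len
        for a_off, a_len, b_off, b_len in zip(a.offsets, a.sizes, b.offsets, b.sizes)
    )


def naive_overlaps(
    local: Sequence[Chunk], saved: Sequence[Chunk]
) -> Tuple[List[Tuple[int, int]], int]:
    """Check every (local, saved) pair, like the quadratic loop; return (pairs, comparisons)."""
    pairs: List[Tuple[int, int]] = []
    comparisons = 0
    for idx, shard in enumerate(local):
        for storage_idx, storage_md in enumerate(saved):
            comparisons += 1
            if not chunks_overlap(shard, storage_md):
                continue
            pairs.append((idx, storage_idx))
    return pairs, comparisons


def grid_2d(grid: int, shape: Tuple[int, int] = (10000, 10000)) -> List[Chunk]:
    """Hypothetical helper mirroring create_chunks_2d from the reproduction script."""
    h, w = shape[0] // grid, shape[1] // grid
    return [Chunk((r * h, c * w), (h, w)) for r in range(grid) for c in range(grid)]


pairs, comparisons = naive_overlaps(grid_2d(8), grid_2d(8))
print(len(pairs), comparisons)  # 64 4096
```

With identical 8×8 grids only the 64 diagonal pairs overlap, yet all 4,096 pairs are compared.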
This results in L×C comparisons, where L is the number of local chunks and C is the number of saved chunks (e.g., 80×80 = 6,400 comparisons per tensor).
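For the 2D grid layouts used in the reproduction below, with $g_s \times g_s$ saved chunks and $g_l \times g_l$ local chunks, the comparison counts follow directly from the script's parameters (plain arithmetic, not measured data):

```math
\begin{aligned}
L \times C &= g_l^2 \, g_s^2 \\
g_s = 8,\; g_l = 8:&\quad 64 \times 64 = 4{,}096 \text{ per tensor};\; \times 100 \text{ tensors} = 409{,}600 \\
g_s = 8,\; g_l = 32:&\quad 1{,}024 \times 64 = 65{,}536 \text{ per tensor};\; \times 1{,}000 \text{ tensors} = 65{,}536{,}000
\end{aligned}
```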
*Reproduction*
```python
import time

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list


def create_chunks_2d(grid_size: int, tensor_shape=(10000, 10000)):
    """Create 2D grid-sharded chunks."""
    chunks = []
    chunk_height = tensor_shape[0] // grid_size
    chunk_width = tensor_shape[1] // grid_size
    for row in range(grid_size):
        for col in range(grid_size):
            chunk = ChunkStorageMetadata(
                offsets=torch.Size([row * chunk_height, col * chunk_width]),
                sizes=torch.Size([chunk_height, chunk_width]),
            )
            chunks.append(chunk)
    return chunks


def benchmark_resharding(saved_grid: int, local_grid: int, num_tensors: int = 100):
    """Benchmark checkpoint resharding scenario."""
    print(f"Scenario: Resharding from {saved_grid}x{saved_grid} to {local_grid}x{local_grid}")
    print(f"Saved chunks: {saved_grid**2}, Local chunks: {local_grid**2}")
    print(f"Comparisons per tensor: {saved_grid**2 * local_grid**2}")
    print()

    saved_chunks = create_chunks_2d(saved_grid)
    local_chunks = create_chunks_2d(local_grid)
    tensor_md = TensorStorageMetadata(
        properties=TensorProperties.create_from_tensor(torch.empty([10000, 10000])),
        size=torch.Size([10000, 10000]),
        chunks=saved_chunks,
    )

    # Time only the read-item planning (the L x C overlap search).
    start = time.time()
    for _ in range(num_tensors):
        read_items = create_read_items_for_chunk_list("tensor", tensor_md, local_chunks)
    total = time.time() - start

    print(f"Total time: {total:.3f}s for {num_tensors} tensors")
    print(f"Average: {total / num_tensors * 1000:.3f}ms per tensor")
    print("=" * 60)


if __name__ == "__main__":
    # Simulate LLaMA-70B scale checkpoint resharding
    benchmark_resharding(saved_grid=8, local_grid=8, num_tensors=100)
    benchmark_resharding(saved_grid=8, local_grid=32, num_tensors=1000)
```
Run with: python script.py
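The source comment says the quadratic loop "can be optimized later". One possible direction, offered only as a sketch and not as what PyTorch implements or plans, is a sort-and-sweep along dimension 0 (a standard broad-phase technique from collision detection): chunks whose dim-0 ranges never intersect are never compared. `Chunk`, `chunks_overlap`, `sweep_overlaps`, and `grid_2d` are hypothetical names, and the half-open overlap semantics are an assumption about `_check_shard_metadata_pair_overlap`.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class Chunk:
    """Hypothetical stand-in for ChunkStorageMetadata (offsets and sizes per dim)."""
    offsets: Tuple[int, ...]
    sizes: Tuple[int, ...]


def chunks_overlap(a: Chunk, b: Chunk) -> bool:
    # Assumed semantics: half-open ranges must intersect in every dimension.
    return all(
        a_off < b_off + b_len and b_off < a_off + a_len
        for a_off, a_len, b_off, b_len in zip(a.offsets, a.sizes, b.offsets, b.sizes)
    )


def _end0(c: Chunk) -> int:
    # Exclusive end of the chunk along dimension 0.
    return c.offsets[0] + c.sizes[0]


def sweep_overlaps(
    local: Sequence[Chunk], saved: Sequence[Chunk]
) -> Tuple[List[Tuple[int, int]], int]:
    """Sort-and-sweep on dim 0; only pairs with intersecting dim-0 ranges
    reach the full overlap check. Returns (sorted pairs, full checks)."""
    # Events ordered by dim-0 start; kind 0 = local chunk, 1 = saved chunk.
    events = sorted(
        [(c.offsets[0], 0, i) for i, c in enumerate(local)]
        + [(c.offsets[0], 1, j) for j, c in enumerate(saved)]
    )
    active_local: List[int] = []
    active_saved: List[int] = []
    pairs: List[Tuple[int, int]] = []
    checks = 0
    for start, kind, idx in events:
        # Starts are non-decreasing, so a chunk ending at or before `start`
        # cannot overlap this chunk or any later one along dim 0.
        active_local = [i for i in active_local if _end0(local[i]) > start]
        active_saved = [j for j in active_saved if _end0(saved[j]) > start]
        if kind == 0:
            for j in active_saved:
                checks += 1
                if chunks_overlap(local[idx], saved[j]):
                    pairs.append((idx, j))
            active_local.append(idx)
        else:
            for i in active_local:
                checks += 1
                if chunks_overlap(local[i], saved[idx]):
                    pairs.append((i, idx))
            active_saved.append(idx)
    return sorted(pairs), checks


def grid_2d(grid: int, shape: Tuple[int, int] = (10000, 10000)) -> List[Chunk]:
    """Hypothetical helper mirroring create_chunks_2d from the reproduction script."""
    h, w = shape[0] // grid, shape[1] // grid
    return [Chunk((r * h, c * w), (h, w)) for r in range(grid) for c in range(grid)]


if __name__ == "__main__":
    saved, local = grid_2d(8), grid_2d(32)
    brute = sorted(
        (i, j) for i, a in enumerate(local) for j, b in enumerate(saved) if chunks_overlap(a, b)
    )
    pairs, checks = sweep_overlaps(local, saved)
    assert pairs == brute  # same result as the nested loop
    print(f"full checks: {checks} (sweep) vs {len(local) * len(saved)} (nested loop)")
```

The worst case is still L×C (for example, when every chunk spans all of dimension 0). For row-banded grids like the reproduction's 8×8 to 32×32 case, however, each local row band intersects at most two saved row bands, so each local chunk is fully checked against at most 16 of the 64 saved chunks.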
### Versions
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A
OS: Fedora Linux 43 (Workstation Edition) (x86_64)
GCC version: (GCC) 15.2.1 20251111 (Red Hat 15.2.1-4)
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.42
Python version: 3.11.14 (main, Oct 10 2025, 00:00:00) [GCC 15.2.1 20250924 (Red Hat 15.2.1-2)] (64-bit runtime)
Python platform: Linux-6.17.7-300.fc43.x86_64-x86_64-with-glibc2.42
Is CUDA available: N/A
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: N/A
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 PRO 7840HS w/ Radeon 780M Graphics
CPU family: 25
Model: 116
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 75%
CPU max MHz: 5137.9038
CPU min MHz: 419.4210
BogoMIPS: 7585.52
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d amd_lbr_pmc_freeze
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Ghostwrite: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Old microcode: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Mitigation; Clear CPU buffers
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace
Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] flake8-bugbear==24.12.12
[pip3] flake8-comprehensions==3.16.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==2024.24.12
[pip3] flake8-pyi==25.5.0
[pip3] flake8_simplify==0.22.0
[pip3] mypy==1.16.0
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] optree==0.17.0
[pip3] torch==2.9.0+cpu