Description
🐛 Describe the bug
I was able to save a HF T5 model with FSDP SHARDED_STATE_DICT checkpointing. However, loading the model from that checkpoint fails, reporting that some of the layers are missing.
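One way to narrow down which layers are missing is to compare the parameter names the model expects with the names actually present after loading. The snippet below is only a minimal sketch of that comparison. The helper `diff_state_dict_keys` and the key names are hypothetical and not taken from the repro branch. In a real run the two lists would presumably come from the keys of the model's state dict and of the loaded checkpoint.

```python
def diff_state_dict_keys(expected_keys, loaded_keys):
    """Return (missing, unexpected) key lists, sorted for stable output.

    missing:    keys the model expects but the checkpoint did not provide
    unexpected: keys the checkpoint provided but the model does not expect
    """
    expected, loaded = set(expected_keys), set(loaded_keys)
    missing = sorted(expected - loaded)
    unexpected = sorted(loaded - expected)
    return missing, unexpected


# Hypothetical key names, for illustration only.
expected = [
    "shared.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight",
    "lm_head.weight",
]
loaded = [
    "shared.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight",
]

missing, unexpected = diff_state_dict_keys(expected, loaded)
print("missing:", missing)
print("unexpected:", unexpected)
```

Running this on every rank and comparing the output could show whether the same keys are missing everywhere or only on some shards.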
Repro Steps
git clone https://github.com/HamidShojanazeri/examples.git
cd examples
git checkout dist-checkpoint-repro
cd distributed/FSDP
pip install -r requirements.txt
sh download_dataset.sh
sbatch t5.slurm
Versions
Collecting environment information...
PyTorch version: 2.1.0.dev20230613+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.26.3
Libc version: glibc-2.31
Python version: 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1019-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 2999.998
BogoMIPS: 5999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] pytorch-triton==2.1.0+440fd1bf20
[pip3] torch==2.1.0.dev20230613+cu118
[pip3] torch-model-archiver==0.8.0
[pip3] torch-tb-profiler==0.4.1
[pip3] torch-workflow-archiver==0.2.8
[pip3] torchaudio==2.1.0.dev20230613+cu118
[pip3] torchpippy==0.1.1+3edf3ab
[pip3] torchserve==0.8.0
[pip3] torchvision==0.16.0.dev20230613+cu118
[pip3] triton==2.0.0
[pip3] vit-pytorch==1.2.2
[conda] numpy 1.23.5 pypi_0 pypi
[conda] pytorch-triton 2.1.0+440fd1bf20 pypi_0 pypi
[conda] torch 2.1.0.dev20230613+cu118 pypi_0 pypi
[conda] torch-model-archiver 0.8.0 pypi_0 pypi
[conda] torch-tb-profiler 0.4.1 pypi_0 pypi
[conda] torch-workflow-archiver 0.2.8 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230613+cu118 pypi_0 pypi
[conda] torchpippy 0.1.1+3edf3ab pypi_0 pypi
[conda] torchserve 0.8.0 pypi_0 pypi
[conda] torchvision 0.16.0.dev20230613+cu118 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi
[conda] vit-pytorch 1.2.2 pypi_0 pypi
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu