Multi-Node Training
Overview
This tutorial demonstrates how to train a PyTorch Lightning model across multiple GPU nodes using the Slurm workload manager and the micromamba package manager for environment management.
Goals
- Launch distributed training using PyTorch Lightning’s built-in DDP support.
- Use `srun` to run jobs across multiple nodes.
- Run everything from a clean `micromamba` environment without requiring shell activation.
Components
1. PyTorch Lightning Script
A standard Lightning training script that creates its trainer with `Trainer(..., strategy="ddp", num_nodes=N)`, where `N` is the number of nodes in the job.
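One way to keep `num_nodes` and `devices` in step with the batch script is to read them from the variables Slurm sets inside the job instead of hard-coding `N`. The sketch below assumes the job was submitted with `--ntasks-per-node`, which makes Slurm export `SLURM_NTASKS_PER_NODE`, and that there is one task per GPU. The helper name `trainer_kwargs_from_slurm` is hypothetical; it only builds the keyword arguments, so it runs without Lightning installed.

```python
import os
from typing import Mapping


def trainer_kwargs_from_slurm(env: Mapping[str, str] = os.environ) -> dict:
    """Build Trainer keyword arguments from Slurm job variables (hypothetical helper).

    Assumes one Slurm task per GPU, so devices == tasks per node.
    Falls back to a single node / single device outside of Slurm.
    """
    num_nodes = int(env.get("SLURM_NNODES", "1"))
    devices = int(env.get("SLURM_NTASKS_PER_NODE", "1"))
    return {
        "accelerator": "gpu",
        "strategy": "ddp",
        "num_nodes": num_nodes,
        "devices": devices,
    }


if __name__ == "__main__":
    # Values a 2-node job with --ntasks-per-node=8 would see
    print(trainer_kwargs_from_slurm({"SLURM_NNODES": "2", "SLURM_NTASKS_PER_NODE": "8"}))
```

In a training script this would be used as `Trainer(**trainer_kwargs_from_slurm(), max_epochs=5)`.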
2. Micromamba Environment
Create a lightweight environment for PyTorch Lightning:
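```bash
# Load the MAMBA module to get access to micromamba
module load MAMBA

# Create the environment with Python 3.10 and uv from conda-forge
micromamba create -y -n lightning-env python=3.10 uv -c conda-forge

# Run uv inside lightning-env via `micromamba run` (no shell activation),
# so the packages are installed into that environment
micromamba run -n lightning-env uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
micromamba run -n lightning-env uv pip install pytorch-lightning

# Turn off lock files
micromamba config set use_lockfiles False
```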
Lock files cause conflicts when multiple nodes initialize the same environment at the same time, as every `srun` task does at job start. We turn them off to support this use case.
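After creating the environment, a quick check from inside it can confirm that the packages landed where `micromamba run` will find them. This is a minimal sketch, assuming you save it under a hypothetical name such as `check_env.py` and run it with `micromamba run -n lightning-env python check_env.py`. It uses only the standard library's `importlib.metadata`, so a missing package is reported instead of raising an import error.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names installed in the steps above
    for name, version in installed_versions(["torch", "torchvision", "pytorch-lightning"]).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```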
3. Slurm Batch Script
A job script sets up the environment and launches training. Refer to the SLURM Tutorial for a more in-depth explanation of writing batch scripts; here we focus on requesting GPUs across multiple nodes, in this case 2 nodes. Request GPUs in proportion to the number of nodes (8 GPUs for every node), with one task per GPU:
```bash
#SBATCH --nodes=2
#SBATCH --gpus=16
#SBATCH --ntasks-per-node=8
#SBATCH --mem-per-gpu=128G
#SBATCH --cpus-per-gpu=16
```
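These directives must agree with each other: the total GPU count has to split evenly across the nodes, and with one task per GPU the tasks per node must equal the GPUs per node. The sketch below is a hypothetical helper (not part of Slurm) that checks this arithmetic and shows the per-node totals the request above implies.

```python
def per_node_request(nodes: int, gpus: int, ntasks_per_node: int,
                     mem_per_gpu_gb: int, cpus_per_gpu: int) -> dict:
    """Check a multi-node GPU request and return per-node totals.

    Assumes one task (training process) per GPU, as in this tutorial.
    """
    if gpus % nodes != 0:
        raise ValueError(f"--gpus={gpus} does not split evenly over --nodes={nodes}")
    gpus_per_node = gpus // nodes
    if ntasks_per_node != gpus_per_node:
        raise ValueError(
            f"--ntasks-per-node={ntasks_per_node} should equal GPUs per node ({gpus_per_node})"
        )
    return {
        "gpus_per_node": gpus_per_node,
        "mem_per_node_gb": gpus_per_node * mem_per_gpu_gb,
        "cpus_per_node": gpus_per_node * cpus_per_gpu,
        "world_size": nodes * ntasks_per_node,
    }


if __name__ == "__main__":
    # Values from the #SBATCH lines above
    print(per_node_request(nodes=2, gpus=16, ntasks_per_node=8,
                           mem_per_gpu_gb=128, cpus_per_gpu=16))
    # → {'gpus_per_node': 8, 'mem_per_node_gb': 1024, 'cpus_per_node': 128, 'world_size': 16}
```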
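Most of the environment variables configure NCCL, the library that carries DDP's GPU-to-GPU communication, for this cluster's network. For example, `NCCL_IB_HCA` selects which InfiniBand adapters NCCL may use and `NCCL_SOCKET_IFNAME` selects the network interface for its socket traffic, so that NCCL does not run into conflicts when searching for routes between GPUs. The adapter and interface names are specific to this cluster's hardware. If you run into connectivity problems, uncomment `NCCL_DEBUG=INFO` to make NCCL log which interfaces and transports it chooses.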
```bash
export OMP_NUM_THREADS=64
# export NCCL_DEBUG=INFO
export NCCL_NVLS_ENABLE=1
export NCCL_IB_ADAPTIVE_ROUTING=1
export NCCL_IB_SL=1
export NCCL_IB_QPS_PER_CONNECTION=2
export NCCL_IB_SPLIT_DATA_ON_QPS=0
export NCCL_IB_HCA=mlx5_15,mlx5_10,mlx5_14,mlx5_13,mlx5_8,mlx5_7,mlx5_9,mlx5_4
export NCCL_SOCKET_IFNAME=bond0
export NCCL_ALGO=RING
export UCX_TLS=rc
```
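Because `srun` passes the batch script's environment on to each task by default, every training process should see the same NCCL and UCX settings. The sketch below is one way to confirm this from inside `train.py`, printing each setting next to the process's `SLURM_PROCID`. The function name `comm_settings` is hypothetical; it simply filters environment variables by prefix.

```python
import os
from typing import Dict, Mapping, Sequence


def comm_settings(env: Mapping[str, str] = os.environ,
                  prefixes: Sequence[str] = ("NCCL_", "UCX_")) -> Dict[str, str]:
    """Return the communication-related variables set in env, sorted by name."""
    return {k: env[k] for k in sorted(env) if k.startswith(tuple(prefixes))}


if __name__ == "__main__":
    # Tag each line with the task's global rank so output from 16 tasks can be told apart
    rank = os.environ.get("SLURM_PROCID", "?")
    for key, value in comm_settings().items():
        print(f"[rank {rank}] {key}={value}")
```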
Finally, `srun` launches the training script once per task, which here means one process per GPU (8 per node, 16 in total), through the `micromamba run` command:

```bash
srun micromamba run -n lightning-env python train.py
```
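Using `micromamba run` avoids shell-activation issues in batch scripts. PyTorch Lightning detects that it was started by Slurm and sets up distributed training internally from environment variables such as `SLURM_NTASKS`, `SLURM_PROCID`, `SLURM_LOCALID`, and `SLURM_NODEID`. For this to work, the `Trainer`'s `devices` must equal `--ntasks-per-node` and its `num_nodes` must equal `--nodes`.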
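Conceptually, the mapping from Slurm variables to DDP ranks looks like the following. This is a simplified sketch for illustration, not Lightning's implementation (Lightning's `SLURMEnvironment` class handles this, along with details such as finding the main node's address). The `RankInfo` type and `ranks_from_slurm` function are hypothetical.

```python
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class RankInfo:
    world_size: int   # total training processes across all nodes
    global_rank: int  # unique id of this process, 0..world_size-1
    local_rank: int   # index of this process (and its GPU) on its node
    node_rank: int    # index of this node within the job


def ranks_from_slurm(env: Mapping[str, str]) -> RankInfo:
    """Derive DDP rank information from the variables srun sets for each task."""
    return RankInfo(
        world_size=int(env["SLURM_NTASKS"]),
        global_rank=int(env["SLURM_PROCID"]),
        local_rank=int(env["SLURM_LOCALID"]),
        node_rank=int(env["SLURM_NODEID"]),
    )


if __name__ == "__main__":
    # Example: the second task on the second node of a 2 x 8 job
    info = ranks_from_slurm({"SLURM_NTASKS": "16", "SLURM_PROCID": "9",
                             "SLURM_LOCALID": "1", "SLURM_NODEID": "1"})
    print(info)
```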
Files
The complete files below combine the pieces above. You still need to create the micromamba environment as described earlier; once it exists, submit the job with:

```bash
sbatch train.sbatch
```
train.py
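```python
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Let float32 matmuls use faster, lower-precision Tensor Core math
torch.set_float32_matmul_precision('medium')


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Small MLP classifier for 28x28 MNIST images
        self.layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        # Lightning calls this on one process per node (local rank 0) before
        # training, so the download does not race across all 16 ranks.
        # If ./data is on a shared filesystem, setting
        # self.prepare_data_per_node = False in __init__ limits it to global rank 0.
        MNIST('./data', train=True, download=True)

    def train_dataloader(self):
        # Data was downloaded in prepare_data; Lightning adds a
        # DistributedSampler so each rank reads its own shard
        return DataLoader(
            MNIST('./data', train=True, download=False, transform=transforms.ToTensor()),
            batch_size=64,
            num_workers=4
        )


if __name__ == "__main__":
    trainer = Trainer(
        accelerator='gpu',
        # Must equal --ntasks-per-node (8): each srun task drives one GPU
        devices=torch.cuda.device_count(),
        strategy=DDPStrategy(find_unused_parameters=False),
        # Must equal --nodes; Slurm sets SLURM_NNODES inside the job
        num_nodes=int(os.environ.get("SLURM_NNODES", 1)),
        max_epochs=5,
    )
    model = LitModel()
    trainer.fit(model)
```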
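Note that `batch_size=64` in `train.py` is per process. Under DDP, Lightning wraps the dataloader with a `DistributedSampler` by default, so each of the 16 processes reads its own shard of MNIST's 60,000 training images. The standalone sketch below (the function name is hypothetical) works out the effective global batch size and steps per epoch for this job, assuming the sampler's default of padding rather than dropping the remainder, and a dataloader that keeps the last partial batch.

```python
def ddp_batch_math(dataset_size: int, world_size: int, per_device_batch: int) -> dict:
    """Effective batch size and per-rank steps per epoch under DDP.

    Mirrors DistributedSampler(drop_last=False), which pads so every rank
    gets ceil(dataset_size / world_size) samples, and a DataLoader that keeps
    the final partial batch (drop_last=False).
    """
    # Integer ceiling division avoids floating-point rounding
    samples_per_rank = (dataset_size + world_size - 1) // world_size
    steps_per_epoch = (samples_per_rank + per_device_batch - 1) // per_device_batch
    return {
        "global_batch": per_device_batch * world_size,
        "samples_per_rank": samples_per_rank,
        "steps_per_epoch": steps_per_epoch,
    }


if __name__ == "__main__":
    # MNIST train split, 2 nodes x 8 GPUs, batch_size=64 per process
    print(ddp_batch_math(60_000, 16, 64))
    # → {'global_batch': 1024, 'samples_per_rank': 3750, 'steps_per_epoch': 59}
```

Scaling from one GPU to 16 therefore multiplies the effective batch size by 16, which may call for adjusting the learning rate.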
train.sbatch
```bash
#!/bin/bash
#SBATCH --job-name=lightning-ddp
#SBATCH --nodes=2
#SBATCH --gpus=16
#SBATCH --ntasks-per-node=8
#SBATCH --mem-per-gpu=128G
#SBATCH --cpus-per-gpu=16
#SBATCH --time=02:00:00
#SBATCH --output=slurm-%j.out

# Load the MAMBA module so micromamba is available
module load MAMBA

# Threading and NCCL/UCX network settings for this cluster
# export NCCL_DEBUG=INFO
export OMP_NUM_THREADS=64
export NCCL_NVLS_ENABLE=1
export NCCL_IB_ADAPTIVE_ROUTING=1
export NCCL_IB_SL=1
export NCCL_IB_QPS_PER_CONNECTION=2
export NCCL_IB_SPLIT_DATA_ON_QPS=0
export NCCL_IB_HCA=mlx5_15,mlx5_10,mlx5_14,mlx5_13,mlx5_8,mlx5_7,mlx5_9,mlx5_4
export NCCL_SOCKET_IFNAME=bond0
export NCCL_ALGO=RING
export UCX_TLS=rc

# Launch 8 tasks per node (one per GPU); Lightning reads the Slurm
# variables to set up DDP across both nodes
srun micromamba run -n lightning-env python train.py
```
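If you submit jobs from a script, you can capture the job ID and locate the log file named by `--output=slurm-%j.out` (Slurm replaces `%j` with the job ID). This is a minimal sketch using Python's `subprocess`. It relies on `sbatch --parsable`, which prints only the job ID (followed by `;cluster` on multi-cluster setups); the function names `parse_job_id` and `submit` are hypothetical.

```python
import subprocess
from pathlib import Path


def parse_job_id(parsable_output: str) -> str:
    """Extract the job ID from `sbatch --parsable` output ("12345" or "12345;cluster")."""
    return parsable_output.strip().split(";", 1)[0]


def submit(script: str = "train.sbatch") -> Path:
    """Submit the batch script and return the expected log path (slurm-<jobid>.out)."""
    result = subprocess.run(
        ["sbatch", "--parsable", script],
        check=True, capture_output=True, text=True,
    )
    job_id = parse_job_id(result.stdout)
    return Path(f"slurm-{job_id}.out")
```

On a login node, `log = submit()` submits `train.sbatch` and returns the path where the training output will appear once the job starts.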