Merge pull request #152 from NCAR/distributed
Updates to grad clipping, seed, and checkpointing
djgagne authored Feb 3, 2025
2 parents 814e1fc + 8128397 commit b27e22f
Showing 6 changed files with 212 additions and 98 deletions.
42 changes: 27 additions & 15 deletions applications/train_universal.py
@@ -27,15 +27,16 @@
from credit.scheduler import load_scheduler
from credit.trainers import load_trainer
from credit.parser import credit_main_parser, training_data_check
from credit.datasets.load_dataset_and_dataloader import (
load_dataset,
load_dataloader
)
from credit.datasets.load_dataset_and_dataloader import load_dataset, load_dataloader

from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
from credit.models import load_model
from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO, load_state_dict_error_handler
from credit.models.checkpoint import (
FSDPOptimizerWrapper,
TorchFSDPCheckpointIO,
load_state_dict_error_handler,
)


warnings.filterwarnings("ignore")
@@ -46,6 +47,9 @@
os.environ["MKL_NUM_THREADS"] = "1"
# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["NCCL_P2P_DISABLE"] = "1"
# os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def load_model_states_and_optimizer(conf, model, device):
"""
@@ -141,15 +145,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

# Load the learning rate scheduler and mixed precision grad scaler
@@ -195,15 +201,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

optimizer = torch.optim.AdamW(
@@ -287,8 +295,12 @@ def main(rank, world_size, conf, backend, trial=False):
valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)

# Load the dataloader
train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)
train_loader = load_dataloader(
conf, train_dataset, rank=rank, world_size=world_size, is_train=True
)
valid_loader = load_dataloader(
conf, valid_dataset, rank=rank, world_size=world_size, is_train=False
)

# model
m = load_model(conf)
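
Note on the checkpoint-loading changes above: model.load_state_dict(..., strict=False) returns a named tuple carrying missing_keys and unexpected_keys, and that is what gets handed to load_state_dict_error_handler (imported from credit.models.checkpoint but not shown in this diff). A minimal sketch of what such a helper could look like, assuming it only logs the mismatched keys rather than raising:

import logging

def load_state_dict_error_handler(load_msg):
    # load_msg is the NamedTuple returned by load_state_dict(strict=False),
    # with the lists missing_keys and unexpected_keys.
    if load_msg.missing_keys:
        logging.warning(f"Checkpoint is missing keys: {load_msg.missing_keys}")
    if load_msg.unexpected_keys:
        logging.warning(f"Checkpoint has unexpected keys: {load_msg.unexpected_keys}")
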
8 changes: 5 additions & 3 deletions config/example-v2025.2.0.yml
@@ -153,6 +153,8 @@ trainer:
cpu_offload: False
# save forward pass activation to checkpoints and free GPU memory
activation_checkpoint: True
# Set to True for all layers with activations otherwise use custom (see credit/distributed.py)
checkpoint_all_layers: False

# (optional) use torch.compile: False. May not be compatible with custom models
compile: False
@@ -239,11 +241,11 @@ trainer:
reduce_dtype: "float32"
buffer_dtype: "float32"

# rescale loss as loss = loss / grad_accum_every
# rescale loss as loss = loss / grad_accum_every -- currently is not being used
grad_accum_every: 1

# gradient clipping. Set to 0 to ignore
grad_max_norm: 1.0
# gradient clipping. Set to 'dynamic' to compute global norm (default). Set to 0 to ignore entirely.
grad_max_norm: 'dynamic'

# number of CPU workers used in datasets/dataloaders
thread_workers: 4
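
The trainer code that consumes grad_max_norm is not part of this diff. The sketch below is one plausible reading of the three settings, assuming 'dynamic' means "compute the global gradient norm every step without imposing a fixed ceiling" (clip_grad_norm_ with an infinite max returns the norm but rescales nothing); treat it as an illustration, not the repository's implementation:

import torch

def apply_grad_clipping(model, grad_max_norm):
    # Hypothetical helper mapping the config value onto torch's clipping utility.
    if grad_max_norm == 0:
        return None  # clipping disabled entirely
    if grad_max_norm == "dynamic":
        # Measure the global L2 norm of all gradients; with max_norm=inf the
        # clip coefficient clamps to 1, so gradients are left unscaled.
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    # A numeric value behaves as before: clip to a fixed max norm.
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(grad_max_norm))
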
151 changes: 84 additions & 67 deletions credit/distributed.py
@@ -1,4 +1,5 @@
import torch.distributed as dist
import torch.nn as nn
import numpy as np
import socket
import torch
@@ -19,6 +20,7 @@
apply_activation_checkpointing,
)
from credit.models.checkpoint import TorchFSDPModel
from credit.models import load_fsdp_or_checkpoint_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from credit.mixed_precision import parse_dtype
import functools
@@ -79,7 +81,10 @@ def get_rank_info(trainer_mode):
WORLD_SIZE = int(os.environ["PMI_SIZE"])
WORLD_RANK = int(os.environ["PMI_RANK"])
else:
sys.exit("Can't find the environment variables for local rank")
sys.exit(
"Can't find the environment variables for local rank. "
"If you are on casper you'll want to use torchrun for now."
)

# Set MASTER_ADDR and MASTER_PORT if not already set
if "MASTER_ADDR" not in os.environ:
@@ -94,6 +99,42 @@ def get_rank_info(trainer_mode):
return LOCAL_RANK, WORLD_RANK, WORLD_SIZE


def should_not_checkpoint(module):
exclude_types = (
# Regularization & Normalization
nn.Dropout,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
# Activations (stateless, cheap to recompute)
nn.ReLU,
nn.GELU,
nn.SiLU,
nn.Sigmoid,
nn.Tanh,
# Pooling (lightweight, no significant memory savings)
nn.MaxPool1d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.AdaptiveMaxPool1d,
nn.AdaptiveMaxPool2d,
nn.AdaptiveMaxPool3d,
nn.AdaptiveAvgPool1d,
nn.AdaptiveAvgPool2d,
nn.AdaptiveAvgPool3d,
# Embeddings (usually large but don’t require recomputation)
nn.Embedding,
# Identity & Reshaping (no computation)
nn.Identity,
nn.Flatten,
)
return isinstance(module, exclude_types)


def distributed_model_wrapper(conf, neural_network, device):
"""Wraps the neural network model for distributed training.
@@ -109,52 +150,34 @@ def distributed_model_wrapper(conf, neural_network, device):
# convert $USER to the actual user name
conf["save_loc"] = os.path.expandvars(conf["save_loc"])

# FSDP polices
if conf["trainer"]["mode"] == "fsdp":
# Define the sharding policies
# crossformer
if "crossformer" in conf["model"]["type"]:
from credit.models.crossformer import (
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
)
mode = conf["trainer"]["mode"]

transformer_layers_cls = {
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
}

# FuXi
# FuXi supports "spectral_nrom = True" only
elif "fuxi" in conf["model"]["type"]:
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage

transformer_layers_cls = {SwinTransformerV2Stage}

# Swin by itself
elif "swin" in conf["model"]["type"]:
from credit.models.swin import (
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
activation_checkpoint = (
conf["trainer"]["activation_checkpoint"]
if "activation_checkpoint" in conf["trainer"]
else False
)
checkpoint_all_layers = conf["trainer"].get("checkpoint_all_layers", False)

# Configure FSDP layers for parallel policies AND/OR activation checkpointing
# in either DDP or FSDP
if mode == "fsdp" or activation_checkpoint:
transformer_layers_cls = load_fsdp_or_checkpoint_policy(conf)

# logger announcement
if activation_checkpoint:
logging.info(f"Activation checkpointing on {mode}: {activation_checkpoint}")
if checkpoint_all_layers:
logging.info("Checkpointing all available layers in your model")
logging.warning(
"This may cause performance degredation -- consider supplying a list to checkpoint"
)

transformer_layers_cls = {
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
}

# other models not supported
else:
raise OSError(
"You asked for FSDP but only crossformer and fuxi are currently supported."
)
logging.info(f"Checkpointing custom layers {transformer_layers_cls}")

# FSDP polices
if conf["trainer"]["mode"] == "fsdp":
# Define the sharding policies
auto_wrap_policy1 = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=transformer_layers_cls
)
@@ -208,38 +231,32 @@ def combined_auto_wrap_policy(module, recurse, nonwrapped_numel):
cpu_offload=CPUOffload(offload_params=cpu_offload),
)

# activation checkpointing on the transformer blocks

activation_checkpoint = (
conf["trainer"]["activation_checkpoint"]
if "activation_checkpoint" in conf["trainer"]
else False
)
elif conf["trainer"]["mode"] == "ddp":
model = DDP(neural_network, device_ids=[device])

logging.info(f"Activation checkpointing: {activation_checkpoint}")
else:
model = neural_network

if activation_checkpoint:
# https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
if activation_checkpoint:
# https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/

non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

if checkpoint_all_layers:
check_fn = lambda submodule: not should_not_checkpoint(submodule)
else:
check_fn = lambda submodule: any(
isinstance(submodule, cls) for cls in transformer_layers_cls
)

apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

# attempting to get around the launch issue we are having
torch.distributed.barrier()
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

elif conf["trainer"]["mode"] == "ddp":
model = DDP(neural_network, device_ids=[device])
else:
model = neural_network
# attempting to get around the launch issue we are having
torch.distributed.barrier()

return model
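
For reference, a minimal and purely hypothetical configuration slice that exercises the new DDP-plus-activation-checkpointing path; the real YAML in config/example-v2025.2.0.yml carries many more keys, and torch.distributed must already be initialized (e.g. via torchrun) because of the barrier at the end of distributed_model_wrapper:

from credit.distributed import distributed_model_wrapper

conf = {
    "save_loc": "./credit-run",          # hypothetical output directory
    "model": {"type": "crossformer"},
    "trainer": {
        "mode": "ddp",                    # activation checkpointing now also applies outside FSDP
        "activation_checkpoint": True,
        "checkpoint_all_layers": False,   # wrap only the custom transformer layers
    },
}
# neural_network: a credit model instance already moved to GPU 0
model = distributed_model_wrapper(conf, neural_network, device=0)
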
51 changes: 50 additions & 1 deletion credit/models/__init__.py
@@ -49,10 +49,59 @@
"fuxi": (Fuxi, "Loading Fuxi model"),
"swin": (SwinTransformerV2Cr, "Loading the minimal Swin model"),
"graph": (GraphResTransfGRU, "Loading Graph Residual Transformer GRU model"),
"debugger": (DebuggerModel, "Loading the debugger model")
"debugger": (DebuggerModel, "Loading the debugger model"),
}


# Define FSDP sharding and/or checkpointing policy
def load_fsdp_or_checkpoint_policy(conf):
# crossformer
if "crossformer" in conf["model"]["type"]:
from credit.models.crossformer import (
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
)

transformer_layers_cls = {
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
}

# FuXi
# FuXi supports "spectral_norm = True" only
elif "fuxi" in conf["model"]["type"]:
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage

transformer_layers_cls = {SwinTransformerV2Stage}

# Swin by itself
elif "swin" in conf["model"]["type"]:
from credit.models.swin import (
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
)

transformer_layers_cls = {
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
}

# other models not supported
else:
raise OSError(
"You asked for FSDP but only crossformer, swin, and fuxi are currently supported.",
"See credit/models/__init__.py for examples on adding new models",
)

return transformer_layers_cls


def load_model(conf, load_weights=False):
conf = copy.deepcopy(conf)
model_conf = conf["model"]
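
A quick usage sketch of the new load_fsdp_or_checkpoint_policy helper above; supporting another architecture means adding an elif branch that returns that model's heavy block classes. The config dict here is a minimal slice for illustration only:

from credit.models import load_fsdp_or_checkpoint_policy

conf = {"model": {"type": "fuxi"}}   # only the model type is inspected
layer_classes = load_fsdp_or_checkpoint_policy(conf)
# -> {SwinTransformerV2Stage}: the blocks that get sharded and/or activation-checkpointed
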
7 changes: 6 additions & 1 deletion credit/seed.py
@@ -7,9 +7,14 @@
def seed_everything(seed=1234):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required for deterministic cuBLAS kernels when use_deterministic_algorithms(True) is set
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.allow_tf32 = False # Disable TensorFloat32 for exact FP32 math
torch.backends.cuda.matmul.allow_tf32 = False
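
With the stricter settings above (fixed cuBLAS workspace, cuDNN determinism, TF32 disabled), two identically seeded runs should match bit-for-bit on the same hardware. A small illustrative check, not part of this commit:

import torch
from credit.seed import seed_everything

def tiny_forward():
    seed_everything(1234)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, device=device)      # same draw after each re-seed
    layer = torch.nn.Linear(8, 8).to(device)  # same init after each re-seed
    return layer(x)

assert torch.equal(tiny_forward(), tiny_forward())
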
