Updates to grad clipping, seed, and checkpointing #152

Merged · 5 commits · Feb 3, 2025
42 changes: 27 additions & 15 deletions applications/train_universal.py
@@ -27,15 +27,16 @@
from credit.scheduler import load_scheduler
from credit.trainers import load_trainer
from credit.parser import credit_main_parser, training_data_check
from credit.datasets.load_dataset_and_dataloader import (
load_dataset,
load_dataloader
)
from credit.datasets.load_dataset_and_dataloader import load_dataset, load_dataloader

from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
from credit.models import load_model
from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO, load_state_dict_error_handler
from credit.models.checkpoint import (
FSDPOptimizerWrapper,
TorchFSDPCheckpointIO,
load_state_dict_error_handler,
)


warnings.filterwarnings("ignore")
@@ -46,6 +47,9 @@
os.environ["MKL_NUM_THREADS"] = "1"
# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["NCCL_P2P_DISABLE"] = "1"
# os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def load_model_states_and_optimizer(conf, model, device):
"""
@@ -141,15 +145,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

# Load the learning rate scheduler and mixed precision grad scaler
@@ -195,15 +201,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

optimizer = torch.optim.AdamW(
@@ -287,8 +295,12 @@ def main(rank, world_size, conf, backend, trial=False):
valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)

# Load the dataloader
train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)
train_loader = load_dataloader(
conf, train_dataset, rank=rank, world_size=world_size, is_train=True
)
valid_loader = load_dataloader(
conf, valid_dataset, rank=rank, world_size=world_size, is_train=False
)

# model
m = load_model(conf)
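The checkpoint-loading hunks above pass `strict=False` to `load_state_dict`, which returns the missing and unexpected keys rather than raising, and hand that result to `load_state_dict_error_handler`. A minimal sketch of the pattern with a hypothetical stand-in handler (the real handler lives in `credit.models.checkpoint` and may log or raise differently):

```python
import logging

import torch.nn as nn


def handle_load_msg(load_msg):
    """Hypothetical stand-in for load_state_dict_error_handler: just report keys."""
    if load_msg.missing_keys:
        logging.warning(f"Missing keys when loading checkpoint: {load_msg.missing_keys}")
    if load_msg.unexpected_keys:
        logging.warning(f"Unexpected keys in checkpoint: {load_msg.unexpected_keys}")


model = nn.Linear(4, 2)
state = {"weight": model.state_dict()["weight"]}  # deliberately omit "bias"
load_msg = model.load_state_dict(state, strict=False)  # returns missing/unexpected keys
handle_load_msg(load_msg)
```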
8 changes: 5 additions & 3 deletions config/example-v2025.2.0.yml
@@ -153,6 +153,8 @@ trainer:
cpu_offload: False
# save forward pass activation to checkpoints and free GPU memory
activation_checkpoint: True
# Set to True to checkpoint all layers with activations; otherwise a custom per-model set is used (see credit/distributed.py)
checkpoint_all_layers: False

# (optional) use torch.compile: False. May not be compatible with custom models
compile: False
@@ -239,11 +241,11 @@ trainer:
reduce_dtype: "float32"
buffer_dtype: "float32"

# rescale loss as loss = loss / grad_accum_every
# rescale loss as loss = loss / grad_accum_every -- currently not used
grad_accum_every: 1

# gradient clipping. Set to 0 to ignore
grad_max_norm: 1.0
# gradient clipping. Set to 'dynamic' to compute global norm (default). Set to 0 to ignore entirely.
grad_max_norm: 'dynamic'

# number of CPU workers used in datasets/dataloaders
thread_workers: 4
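The config change above replaces the fixed `grad_max_norm: 1.0` with `'dynamic'`. A short, hedged sketch of how a trainer could branch on that value, assuming it ultimately calls `torch.nn.utils.clip_grad_norm_`; the actual CREDIT trainer logic is not shown in this diff and may differ:

```python
import torch


def clip_gradients(model, grad_max_norm):
    """Hypothetical helper illustrating the three grad_max_norm cases above."""
    if not grad_max_norm:
        # grad_max_norm: 0 -- skip clipping entirely
        return None
    if grad_max_norm == "dynamic":
        # One plausible reading of 'dynamic': measure the global L2 norm of all
        # gradients (an infinite cap clips nothing and simply returns the total
        # norm), leaving any rescaling decision to the trainer.
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    # Fixed threshold, e.g. grad_max_norm: 1.0
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(grad_max_norm))


# usage inside a training step (sketch):
# loss.backward()
# norm = clip_gradients(model, conf["trainer"]["grad_max_norm"])
# optimizer.step()
```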
151 changes: 84 additions & 67 deletions credit/distributed.py
@@ -1,4 +1,5 @@
import torch.distributed as dist
import torch.nn as nn
import numpy as np
import socket
import torch
@@ -19,6 +20,7 @@
apply_activation_checkpointing,
)
from credit.models.checkpoint import TorchFSDPModel
from credit.models import load_fsdp_or_checkpoint_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from credit.mixed_precision import parse_dtype
import functools
@@ -79,7 +81,10 @@ def get_rank_info(trainer_mode):
WORLD_SIZE = int(os.environ["PMI_SIZE"])
WORLD_RANK = int(os.environ["PMI_RANK"])
else:
sys.exit("Can't find the environment variables for local rank")
sys.exit(
"Can't find the environment variables for local rank. "
"If you are on casper you'll want to use torchrun for now."
)

# Set MASTER_ADDR and MASTER_PORT if not already set
if "MASTER_ADDR" not in os.environ:
@@ -94,6 +99,42 @@ def get_rank_info(trainer_mode):
return LOCAL_RANK, WORLD_RANK, WORLD_SIZE


def should_not_checkpoint(module):
exclude_types = (
# Regularization & Normalization
nn.Dropout,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm,
# Activations (stateless, cheap to recompute)
nn.ReLU,
nn.GELU,
nn.SiLU,
nn.Sigmoid,
nn.Tanh,
# Pooling (lightweight, no significant memory savings)
nn.MaxPool1d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.AdaptiveMaxPool1d,
nn.AdaptiveMaxPool2d,
nn.AdaptiveMaxPool3d,
nn.AdaptiveAvgPool1d,
nn.AdaptiveAvgPool2d,
nn.AdaptiveAvgPool3d,
# Embeddings (usually large but don’t require recomputation)
nn.Embedding,
# Identity & Reshaping (no computation)
nn.Identity,
nn.Flatten,
)
return isinstance(module, exclude_types)


def distributed_model_wrapper(conf, neural_network, device):
"""Wraps the neural network model for distributed training.

@@ -109,52 +150,34 @@ def distributed_model_wrapper(conf, neural_network, device):
# convert $USER to the actual user name
conf["save_loc"] = os.path.expandvars(conf["save_loc"])

# FSDP polices
if conf["trainer"]["mode"] == "fsdp":
# Define the sharding policies
# crossformer
if "crossformer" in conf["model"]["type"]:
from credit.models.crossformer import (
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
)
mode = conf["trainer"]["mode"]

transformer_layers_cls = {
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
}

# FuXi
# FuXi supports "spectral_nrom = True" only
elif "fuxi" in conf["model"]["type"]:
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage

transformer_layers_cls = {SwinTransformerV2Stage}

# Swin by itself
elif "swin" in conf["model"]["type"]:
from credit.models.swin import (
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
activation_checkpoint = (
conf["trainer"]["activation_checkpoint"]
if "activation_checkpoint" in conf["trainer"]
else False
)
checkpoint_all_layers = conf["trainer"].get("checkpoint_all_layers", False)

# Configure FSDP layers for parallel policies AND/OR activation checkpointing
# in either DDP or FSDP
if mode == "fsdp" or activation_checkpoint:
transformer_layers_cls = load_fsdp_or_checkpoint_policy(conf)

# logger announcement
if activation_checkpoint:
logging.info(f"Activation checkpointing on {mode}: {activation_checkpoint}")
if checkpoint_all_layers:
logging.info("Checkpointing all available layers in your model")
logging.warning(
"This may cause performance degredation -- consider supplying a list to checkpoint"
)

transformer_layers_cls = {
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
}

# other models not supported
else:
raise OSError(
"You asked for FSDP but only crossformer and fuxi are currently supported."
)
logging.info(f"Checkpointing custom layers {transformer_layers_cls}")

# FSDP policies
if conf["trainer"]["mode"] == "fsdp":
# Define the sharding policies
auto_wrap_policy1 = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=transformer_layers_cls
)
@@ -208,38 +231,32 @@ def combined_auto_wrap_policy(module, recurse, nonwrapped_numel):
cpu_offload=CPUOffload(offload_params=cpu_offload),
)

# activation checkpointing on the transformer blocks

activation_checkpoint = (
conf["trainer"]["activation_checkpoint"]
if "activation_checkpoint" in conf["trainer"]
else False
)
elif conf["trainer"]["mode"] == "ddp":
model = DDP(neural_network, device_ids=[device])

logging.info(f"Activation checkpointing: {activation_checkpoint}")
else:
model = neural_network

if activation_checkpoint:
# https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
if activation_checkpoint:
# https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/

non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

if checkpoint_all_layers:
check_fn = lambda submodule: not should_not_checkpoint(submodule)
else:
check_fn = lambda submodule: any(
isinstance(submodule, cls) for cls in transformer_layers_cls
)

apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

# attempting to get around the launch issue we are having
torch.distributed.barrier()
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

elif conf["trainer"]["mode"] == "ddp":
model = DDP(neural_network, device_ids=[device])
else:
model = neural_network
# attempting to get around the launch issue we are having
torch.distributed.barrier()

return model
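The refactored block above applies activation checkpointing under both DDP and FSDP, wrapping either every layer that survives `should_not_checkpoint` (when `checkpoint_all_layers` is true) or only the custom classes returned by `load_fsdp_or_checkpoint_policy`. A self-contained sketch of that wrapping pattern on a toy module, using the same PyTorch utilities; `TinyBlock` is invented purely for illustration:

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class TinyBlock(nn.Module):
    """Stand-in 'transformer block' used only for this illustration."""

    def __init__(self, dim=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.net(x)


model = nn.Sequential(TinyBlock(), TinyBlock(), nn.LayerNorm(32))

# Wrap only TinyBlock instances with non-reentrant checkpointing; LayerNorm is
# skipped, mirroring the check_fn / should_not_checkpoint filtering above.
wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=wrapper, check_fn=lambda m: isinstance(m, TinyBlock)
)

x = torch.randn(4, 32, requires_grad=True)
model(x).sum().backward()  # activations inside TinyBlock are recomputed here
```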
51 changes: 50 additions & 1 deletion credit/models/__init__.py
@@ -49,10 +49,59 @@
"fuxi": (Fuxi, "Loading Fuxi model"),
"swin": (SwinTransformerV2Cr, "Loading the minimal Swin model"),
"graph": (GraphResTransfGRU, "Loading Graph Residual Transformer GRU model"),
"debugger": (DebuggerModel, "Loading the debugger model")
"debugger": (DebuggerModel, "Loading the debugger model"),
}


# Define FSDP sharding and/or checkpointing policy
def load_fsdp_or_checkpoint_policy(conf):
# crossformer
if "crossformer" in conf["model"]["type"]:
from credit.models.crossformer import (
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
)

transformer_layers_cls = {
Attention,
DynamicPositionBias,
FeedForward,
CrossEmbedLayer,
}

# FuXi
# FuXi supports "spectral_norm = True" only
elif "fuxi" in conf["model"]["type"]:
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage

transformer_layers_cls = {SwinTransformerV2Stage}

# Swin by itself
elif "swin" in conf["model"]["type"]:
from credit.models.swin import (
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
)

transformer_layers_cls = {
SwinTransformerV2CrBlock,
WindowMultiHeadAttentionNoPos,
WindowMultiHeadAttention,
}

# other models not supported
else:
raise OSError(
"You asked for FSDP but only crossformer, swin, and fuxi are currently supported.",
"See credit/models/__init__.py for examples on adding new models",
)

return transformer_layers_cls


def load_model(conf, load_weights=False):
conf = copy.deepcopy(conf)
model_conf = conf["model"]
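`distributed_model_wrapper` consumes the class set returned by `load_fsdp_or_checkpoint_policy` for both FSDP sharding and activation checkpointing. A rough sketch of the FSDP side of that wiring, assuming the `credit` package is importable; the minimal `conf` dict below is a stand-in for a full training config:

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from credit.models import load_fsdp_or_checkpoint_policy

# Minimal stand-in config; a real CREDIT config carries many more keys.
conf = {"model": {"type": "crossformer"}}

# The returned set of layer classes becomes FSDP's transformer wrap policy,
# mirroring auto_wrap_policy1 in credit/distributed.py.
transformer_layers_cls = load_fsdp_or_checkpoint_policy(conf)
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls=transformer_layers_cls
)
```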
7 changes: 6 additions & 1 deletion credit/seed.py
@@ -7,9 +7,14 @@
def seed_everything(seed=1234):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Add this line
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.allow_tf32 = False # Disable TensorFloat32 for exact FP32 math
torch.backends.cuda.matmul.allow_tf32 = False
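The updated `seed_everything` enables fully deterministic execution (deterministic cuDNN, TF32 disabled, `torch.use_deterministic_algorithms(True)`). For end-to-end reproducibility the dataloader also needs a seeded generator; a small usage sketch, assuming the `credit` package is importable:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from credit.seed import seed_everything

seed_everything(1234)

# Seed the shuffle RNG explicitly so batch order is reproducible too.
generator = torch.Generator()
generator.manual_seed(1234)

dataset = TensorDataset(torch.arange(10.0).unsqueeze(1))
loader = DataLoader(dataset, batch_size=2, shuffle=True, generator=generator)

for (batch,) in loader:
    print(batch)
```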