diff --git a/applications/train_universal.py b/applications/train_universal.py
index d7937c6..122bb1a 100644
--- a/applications/train_universal.py
+++ b/applications/train_universal.py
@@ -27,15 +27,16 @@
 from credit.scheduler import load_scheduler
 from credit.trainers import load_trainer
 from credit.parser import credit_main_parser, training_data_check
-from credit.datasets.load_dataset_and_dataloader import (
-    load_dataset,
-    load_dataloader
-)
+from credit.datasets.load_dataset_and_dataloader import load_dataset, load_dataloader
 from credit.metrics import LatWeightedMetrics
 from credit.pbs import launch_script, launch_script_mpi
 from credit.models import load_model
-from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO, load_state_dict_error_handler
+from credit.models.checkpoint import (
+    FSDPOptimizerWrapper,
+    TorchFSDPCheckpointIO,
+    load_state_dict_error_handler,
+)
 
 warnings.filterwarnings("ignore")
 
 
@@ -46,6 +47,9 @@
 os.environ["MKL_NUM_THREADS"] = "1"
 # https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+# os.environ["NCCL_P2P_DISABLE"] = "1"
+# os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+
 
 def load_model_states_and_optimizer(conf, model, device):
     """
@@ -141,15 +145,17 @@ def load_model_states_and_optimizer(conf, model, device):
             logging.info(
                 f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
             )
-            load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
-                                                    strict=False)
+            load_msg = model.module.load_state_dict(
+                checkpoint["model_state_dict"], strict=False
+            )
             load_state_dict_error_handler(load_msg)
         else:
             logging.info(
                 f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
             )
-            load_msg = model.load_state_dict(checkpoint["model_state_dict"],
-                                             strict=False)
+            load_msg = model.load_state_dict(
+                checkpoint["model_state_dict"], strict=False
+            )
             load_state_dict_error_handler(load_msg)
 
         # Load the learning rate scheduler and mixed precision grad scaler
@@ -195,15 +201,17 @@
             logging.info(
                 f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
             )
-            load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
-                                                    strict=False)
+            load_msg = model.module.load_state_dict(
+                checkpoint["model_state_dict"], strict=False
+            )
             load_state_dict_error_handler(load_msg)
         else:
             logging.info(
                 f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
             )
-            load_msg = model.load_state_dict(checkpoint["model_state_dict"],
-                                             strict=False)
+            load_msg = model.load_state_dict(
+                checkpoint["model_state_dict"], strict=False
+            )
             load_state_dict_error_handler(load_msg)
 
         optimizer = torch.optim.AdamW(
@@ -287,8 +295,12 @@ def main(rank, world_size, conf, backend, trial=False):
     valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)
 
     # Load the dataloader
-    train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
-    valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)
+    train_loader = load_dataloader(
+        conf, train_dataset, rank=rank, world_size=world_size, is_train=True
+    )
+    valid_loader = load_dataloader(
+        conf, valid_dataset, rank=rank, world_size=world_size, is_train=False
+    )
 
     # model
     m = load_model(conf)
diff --git a/config/example-v2025.2.0.yml b/config/example-v2025.2.0.yml
index 2d63ee6..8f58013 100644
--- a/config/example-v2025.2.0.yml
+++ b/config/example-v2025.2.0.yml
@@ -153,6 +153,8 @@ trainer:
     cpu_offload: False
     # save forward pass activation to checkpoints and free GPU memory
     activation_checkpoint: True
+    # Set to True to checkpoint all supported layers; otherwise a custom per-model list is used (see credit/distributed.py)
+    checkpoint_all_layers: False
     # (optional) use torch.compile: False. May not be compatible with custom models
     compile: False
 
@@ -239,11 +241,11 @@ trainer:
     reduce_dtype: "float32"
     buffer_dtype: "float32"
 
-    # rescale loss as loss = loss / grad_accum_every
+    # rescale loss as loss = loss / grad_accum_every -- currently not used
     grad_accum_every: 1
 
-    # gradient clipping. Set to 0 to ignore
-    grad_max_norm: 1.0
+    # gradient clipping. Set to 'dynamic' (default) to clip using the global gradient norm, a positive number for a fixed max norm, or 0 to disable
+    grad_max_norm: 'dynamic'
 
     # number of CPU workers used in datasets/dataloaders
     thread_workers: 4
diff --git a/credit/distributed.py b/credit/distributed.py
index 09e302e..ffe5b25 100644
--- a/credit/distributed.py
+++ b/credit/distributed.py
@@ -1,4 +1,5 @@
 import torch.distributed as dist
+import torch.nn as nn
 import numpy as np
 import socket
 import torch
@@ -19,6 +20,7 @@
     apply_activation_checkpointing,
 )
 from credit.models.checkpoint import TorchFSDPModel
+from credit.models import load_fsdp_or_checkpoint_policy
 from torch.nn.parallel import DistributedDataParallel as DDP
 from credit.mixed_precision import parse_dtype
 import functools
@@ -79,7 +81,10 @@ def get_rank_info(trainer_mode):
         WORLD_SIZE = int(os.environ["PMI_SIZE"])
         WORLD_RANK = int(os.environ["PMI_RANK"])
     else:
-        sys.exit("Can't find the environment variables for local rank")
+        sys.exit(
+            "Can't find the environment variables for local rank. "
+            "If you are on Casper you'll want to use torchrun for now."
+        )
 
     # Set MASTER_ADDR and MASTER_PORT if not already set
     if "MASTER_ADDR" not in os.environ:
@@ -94,6 +99,42 @@ def get_rank_info(trainer_mode):
     return LOCAL_RANK, WORLD_RANK, WORLD_SIZE
 
 
+def should_not_checkpoint(module):
+    exclude_types = (
+        # Regularization & Normalization
+        nn.Dropout,
+        nn.BatchNorm1d,
+        nn.BatchNorm2d,
+        nn.BatchNorm3d,
+        nn.LayerNorm,
+        # Activations (stateless, cheap to recompute)
+        nn.ReLU,
+        nn.GELU,
+        nn.SiLU,
+        nn.Sigmoid,
+        nn.Tanh,
+        # Pooling (lightweight, no significant memory savings)
+        nn.MaxPool1d,
+        nn.MaxPool2d,
+        nn.MaxPool3d,
+        nn.AvgPool1d,
+        nn.AvgPool2d,
+        nn.AvgPool3d,
+        nn.AdaptiveMaxPool1d,
+        nn.AdaptiveMaxPool2d,
+        nn.AdaptiveMaxPool3d,
+        nn.AdaptiveAvgPool1d,
+        nn.AdaptiveAvgPool2d,
+        nn.AdaptiveAvgPool3d,
+        # Embeddings (usually large but don't require recomputation)
+        nn.Embedding,
+        # Identity & Reshaping (no computation)
+        nn.Identity,
+        nn.Flatten,
+    )
+    return isinstance(module, exclude_types)
+
+
 def distributed_model_wrapper(conf, neural_network, device):
     """Wraps the neural network model for distributed training.
@@ -109,52 +150,34 @@ def distributed_model_wrapper(conf, neural_network, device):
     # convert $USER to the actual user name
     conf["save_loc"] = os.path.expandvars(conf["save_loc"])
 
-    # FSDP polices
-    if conf["trainer"]["mode"] == "fsdp":
-        # Define the sharding policies
-        # crossformer
-        if "crossformer" in conf["model"]["type"]:
-            from credit.models.crossformer import (
-                Attention,
-                DynamicPositionBias,
-                FeedForward,
-                CrossEmbedLayer,
-            )
+    mode = conf["trainer"]["mode"]
 
-            transformer_layers_cls = {
-                Attention,
-                DynamicPositionBias,
-                FeedForward,
-                CrossEmbedLayer,
-            }
-
-        # FuXi
-        # FuXi supports "spectral_nrom = True" only
-        elif "fuxi" in conf["model"]["type"]:
-            from timm.models.swin_transformer_v2 import SwinTransformerV2Stage
-
-            transformer_layers_cls = {SwinTransformerV2Stage}
-
-        # Swin by itself
-        elif "swin" in conf["model"]["type"]:
-            from credit.models.swin import (
-                SwinTransformerV2CrBlock,
-                WindowMultiHeadAttentionNoPos,
-                WindowMultiHeadAttention,
+    activation_checkpoint = (
+        conf["trainer"]["activation_checkpoint"]
+        if "activation_checkpoint" in conf["trainer"]
+        else False
+    )
+    checkpoint_all_layers = conf["trainer"].get("checkpoint_all_layers", False)
+
+    # Configure FSDP layers for parallel policies AND/OR activation checkpointing
+    # in either DDP or FSDP
+    if mode == "fsdp" or activation_checkpoint:
+        transformer_layers_cls = load_fsdp_or_checkpoint_policy(conf)
+
+    # logger announcement
+    if activation_checkpoint:
+        logging.info(f"Activation checkpointing on {mode}: {activation_checkpoint}")
+        if checkpoint_all_layers:
+            logging.info("Checkpointing all available layers in your model")
+            logging.warning(
+                "This may cause performance degradation -- consider supplying a list to checkpoint"
             )
-
-            transformer_layers_cls = {
-                SwinTransformerV2CrBlock,
-                WindowMultiHeadAttentionNoPos,
-                WindowMultiHeadAttention,
-            }
-
-        # other models not supported
         else:
-            raise OSError(
-                "You asked for FSDP but only crossformer and fuxi are currently supported."
-            )
+            logging.info(f"Checkpointing custom layers {transformer_layers_cls}")
+    # FSDP policies
+    if conf["trainer"]["mode"] == "fsdp":
+        # Define the sharding policies
 
         auto_wrap_policy1 = functools.partial(
             transformer_auto_wrap_policy, transformer_layer_cls=transformer_layers_cls
         )
@@ -208,38 +231,32 @@ def combined_auto_wrap_policy(module, recurse, nonwrapped_numel):
             cpu_offload=CPUOffload(offload_params=cpu_offload),
         )
 
-        # activation checkpointing on the transformer blocks
-
-        activation_checkpoint = (
-            conf["trainer"]["activation_checkpoint"]
-            if "activation_checkpoint" in conf["trainer"]
-            else False
-        )
+    elif conf["trainer"]["mode"] == "ddp":
+        model = DDP(neural_network, device_ids=[device])
 
-        logging.info(f"Activation checkpointing: {activation_checkpoint}")
+    else:
+        model = neural_network
 
-        if activation_checkpoint:
-            # https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
+    if activation_checkpoint:
+        # https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/
 
-            non_reentrant_wrapper = functools.partial(
-                checkpoint_wrapper,
-                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-            )
+        non_reentrant_wrapper = functools.partial(
+            checkpoint_wrapper,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+        )
 
+        if checkpoint_all_layers:
+            check_fn = lambda submodule: not should_not_checkpoint(submodule)
+        else:
             check_fn = lambda submodule: any(
                 isinstance(submodule, cls) for cls in transformer_layers_cls
            )
 
-            apply_activation_checkpointing(
-                model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
-            )
-
-            # attempting to get around the launch issue we are having
-            torch.distributed.barrier()
+        apply_activation_checkpointing(
+            model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
+        )
 
-    elif conf["trainer"]["mode"] == "ddp":
-        model = DDP(neural_network, device_ids=[device])
-    else:
-        model = neural_network
+        # attempting to get around the launch issue we are having
+        torch.distributed.barrier()
 
     return model
diff --git a/credit/models/__init__.py b/credit/models/__init__.py
index 6169593..c06e810 100644
--- a/credit/models/__init__.py
+++ b/credit/models/__init__.py
@@ -49,10 +49,59 @@
     "fuxi": (Fuxi, "Loading Fuxi model"),
     "swin": (SwinTransformerV2Cr, "Loading the minimal Swin model"),
     "graph": (GraphResTransfGRU, "Loading Graph Residual Transformer GRU model"),
-    "debugger": (DebuggerModel, "Loading the debugger model")
+    "debugger": (DebuggerModel, "Loading the debugger model"),
 }
 
 
+# Define FSDP sharding and/or checkpointing policy
+def load_fsdp_or_checkpoint_policy(conf):
+    # crossformer
+    if "crossformer" in conf["model"]["type"]:
+        from credit.models.crossformer import (
+            Attention,
+            DynamicPositionBias,
+            FeedForward,
+            CrossEmbedLayer,
+        )
+
+        transformer_layers_cls = {
+            Attention,
+            DynamicPositionBias,
+            FeedForward,
+            CrossEmbedLayer,
+        }
+
+    # FuXi
+    # FuXi supports "spectral_norm = True" only
+    elif "fuxi" in conf["model"]["type"]:
+        from timm.models.swin_transformer_v2 import SwinTransformerV2Stage
+
+        transformer_layers_cls = {SwinTransformerV2Stage}
+
+    # Swin by itself
+    elif "swin" in conf["model"]["type"]:
+        from credit.models.swin import (
+            SwinTransformerV2CrBlock,
+            WindowMultiHeadAttentionNoPos,
+            WindowMultiHeadAttention,
+        )
+
+        transformer_layers_cls = {
+            SwinTransformerV2CrBlock,
+            WindowMultiHeadAttentionNoPos,
+            WindowMultiHeadAttention,
+        }
+
+    # other models not supported
+    else:
+        raise OSError(
+            "You asked for FSDP but only crossformer, swin, and fuxi are currently supported. "
+            "See credit/models/__init__.py for examples on adding new models"
+        )
+
+    return transformer_layers_cls
+
+
 def load_model(conf, load_weights=False):
     conf = copy.deepcopy(conf)
     model_conf = conf["model"]
diff --git a/credit/seed.py b/credit/seed.py
index 4f06cc4..9c20633 100644
--- a/credit/seed.py
+++ b/credit/seed.py
@@ -7,9 +7,14 @@ def seed_everything(seed=1234):
     random.seed(seed)
     os.environ["PYTHONHASHSEED"] = str(seed)
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required for deterministic cuBLAS with use_deterministic_algorithms
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
-        torch.backends.cudnn.benchmark = True
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
+        torch.use_deterministic_algorithms(True)
+        torch.backends.cudnn.allow_tf32 = False  # Disable TensorFloat32 for exact FP32 math
+        torch.backends.cuda.matmul.allow_tf32 = False
 
diff --git a/credit/trainers/trainerERA5_multistep_grad_accum.py b/credit/trainers/trainerERA5_multistep_grad_accum.py
index 86cb984..9ffce7a 100644
--- a/credit/trainers/trainerERA5_multistep_grad_accum.py
+++ b/credit/trainers/trainerERA5_multistep_grad_accum.py
@@ -74,13 +74,14 @@ def train_one_epoch(
         """
 
         batches_per_epoch = conf["trainer"]["batches_per_epoch"]
-        grad_max_norm = conf["trainer"]["grad_max_norm"]
+        grad_max_norm = conf["trainer"].get("grad_max_norm", 0.0)
         amp = conf["trainer"]["amp"]
         distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
        forecast_length = conf["data"]["forecast_len"]
         ensemble_size = conf["trainer"].get("ensemble_size", 1)
         if ensemble_size > 1:
             logger.info(f"ensemble training with ensemble_size {ensemble_size}")
+        logger.info(f"Using grad-max-norm value: {grad_max_norm}")
 
         # number of diagnostic variables
         varnum_diag = len(conf["data"]["diagnostic_variables"])
@@ -190,10 +191,10 @@ def train_one_epoch(
                 else:
                     # no x_surf
                     x = reshape_only(batch["x"]).to(self.device)  # .float()
-
+
                 # --------------------------------------------- #
                 # ensemble x and x_surf on initialization
-                # copies each sample in the batch ensemble_size number of times. 
+                # copies each sample in the batch ensemble_size number of times.
                 # if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
                 # WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
                 if ensemble_size > 1:
@@ -207,8 +208,10 @@ def train_one_epoch(
                 )  # .float()
                 # ---------------- ensemble ----------------- #
                 # ensemble x_forcing_batch for concat. see above for explanation of code
-                if ensemble_size > 1: 
-                    x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
+                if ensemble_size > 1:
+                    x_forcing_batch = torch.repeat_interleave(
+                        x_forcing_batch, ensemble_size, 0
+                    )
 
                 # --------------------------------------------- #
                 # concat on var dimension
@@ -320,10 +323,34 @@
             if distributed:
                 torch.distributed.barrier()
 
+            # Grad norm clipping
             scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(), max_norm=grad_max_norm
-            )
+            if grad_max_norm == "dynamic":
+                # Compute local L2 norm
+                local_norm = torch.norm(
+                    torch.stack(
+                        [
+                            p.grad.detach().norm(2)
+                            for p in self.model.parameters()
+                            if p.grad is not None
+                        ]
+                    )
+                )
+
+                # All-reduce the squared norm so the result is the true global L2 norm across ranks
+                dist.all_reduce(local_norm.pow_(2), op=dist.ReduceOp.SUM)
+                global_norm = local_norm.sqrt()  # total global norm
+
+                # Clip gradients using the global norm
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=global_norm
+                )
+            elif grad_max_norm > 0.0:
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=grad_max_norm
+                )
+
+            # Step optimizer
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
@@ -507,7 +534,7 @@
                     x = reshape_only(batch["x"]).to(self.device)  # .float()
                 # --------------------------------------------- #
                 # ensemble x and x_surf on initialization
-                # copies each sample in the batch ensemble_size number of times. 
+                # copies each sample in the batch ensemble_size number of times.
                 # if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
                 # WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
                 if ensemble_size > 1:
@@ -523,8 +550,10 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
                 )  # .float()
                 # ---------------- ensemble ----------------- #
                 # ensemble x_forcing_batch for concat. see above for explanation of code
-                if ensemble_size > 1: 
-                    x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
+                if ensemble_size > 1:
+                    x_forcing_batch = torch.repeat_interleave(
+                        x_forcing_batch, ensemble_size, 0
+                    )
 
                 # --------------------------------------------- #
                 # concat on var dimension