Updates to grad clipping, seed, and checkpointing #152

Merged · 5 commits · Feb 3, 2025
42 changes: 27 additions & 15 deletions applications/train_universal.py
@@ -27,15 +27,16 @@
from credit.scheduler import load_scheduler
from credit.trainers import load_trainer
from credit.parser import credit_main_parser, training_data_check
from credit.datasets.load_dataset_and_dataloader import (
load_dataset,
load_dataloader
)
from credit.datasets.load_dataset_and_dataloader import load_dataset, load_dataloader

from credit.metrics import LatWeightedMetrics
from credit.pbs import launch_script, launch_script_mpi
from credit.models import load_model
from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO, load_state_dict_error_handler
from credit.models.checkpoint import (
FSDPOptimizerWrapper,
TorchFSDPCheckpointIO,
load_state_dict_error_handler,
)


warnings.filterwarnings("ignore")
@@ -46,6 +47,9 @@
os.environ["MKL_NUM_THREADS"] = "1"
# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# os.environ["NCCL_P2P_DISABLE"] = "1"
# os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


def load_model_states_and_optimizer(conf, model, device):
"""
@@ -141,15 +145,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

# Load the learning rate scheduler and mixed precision grad scaler
@@ -195,15 +201,17 @@ def load_model_states_and_optimizer(conf, model, device):
logging.info(
f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.module.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.module.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)
else:
logging.info(
f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}"
)
load_msg = model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
load_msg = model.load_state_dict(
checkpoint["model_state_dict"], strict=False
)
load_state_dict_error_handler(load_msg)

optimizer = torch.optim.AdamW(
@@ -287,8 +295,12 @@ def main(rank, world_size, conf, backend, trial=False):
valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)

# Load the dataloader
train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)
train_loader = load_dataloader(
conf, train_dataset, rank=rank, world_size=world_size, is_train=True
)
valid_loader = load_dataloader(
conf, valid_dataset, rank=rank, world_size=world_size, is_train=False
)

# model
m = load_model(conf)
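For context on the strict=False loads in this file: model.load_state_dict(..., strict=False) returns a NamedTuple with missing_keys and unexpected_keys fields, and that object is what gets passed to load_state_dict_error_handler as load_msg. A minimal sketch of what such a handler could do, assuming the goal is to warn rather than abort; the name handle_load_result and the reporting behavior are illustrative, not the repo's actual implementation:

```python
import logging


def handle_load_result(load_msg, max_report=5):
    """Illustrative stand-in for load_state_dict_error_handler: inspect the
    NamedTuple returned by load_state_dict(..., strict=False) and log key
    mismatches instead of aborting the restart."""
    missing, unexpected = load_msg.missing_keys, load_msg.unexpected_keys
    if missing:
        logging.warning(
            "%d checkpoint keys missing, e.g. %s", len(missing), missing[:max_report]
        )
    if unexpected:
        logging.warning(
            "%d unexpected checkpoint keys, e.g. %s", len(unexpected), unexpected[:max_report]
        )
    if not (missing or unexpected):
        logging.info("Checkpoint keys matched the model exactly.")
```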
6 changes: 3 additions & 3 deletions config/example-v2025.2.0.yml
@@ -239,11 +239,11 @@ trainer:
reduce_dtype: "float32"
buffer_dtype: "float32"

# rescale loss as loss = loss / grad_accum_every
# rescale loss as loss = loss / grad_accum_every -- currently is not being used
grad_accum_every: 1

# gradient clipping. Set to 0 to ignore
grad_max_norm: 1.0
# gradient clipping. Set to 'dynamic' to compute global norm (default). Set to 0 to ignore entirely.
grad_max_norm: 'dynamic'

# number of CPU workers used in datasets/dataloaders
thread_workers: 4
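For reference, the loss = loss / grad_accum_every rescaling mentioned in the comment above is the standard gradient-accumulation pattern sketched below. Per that comment the option is currently unused, and this generic loop is not code from this repo; grad_accum_every=4 is an illustrative value, the config above ships with 1:

```python
def train_with_accumulation(model, optimizer, criterion, loader, grad_accum_every=4):
    """Generic gradient-accumulation loop: dividing each micro-batch loss by
    grad_accum_every makes the summed gradients match one large batch."""
    model.train()
    for i, (x, y) in enumerate(loader):
        loss = criterion(model(x), y) / grad_accum_every
        loss.backward()  # gradients accumulate across micro-batches
        if (i + 1) % grad_accum_every == 0:
            optimizer.step()
            optimizer.zero_grad()
```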
7 changes: 6 additions & 1 deletion credit/seed.py
@@ -7,9 +7,14 @@
def seed_everything(seed=1234):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # Add this line
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.allow_tf32 = False # Disable TensorFloat32 for exact FP32 math
torch.backends.cuda.matmul.allow_tf32 = False
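A note on the seed.py change above: torch.use_deterministic_algorithms(True) makes PyTorch raise on ops that lack a deterministic implementation, and deterministic cuBLAS GEMMs additionally require CUBLAS_WORKSPACE_CONFIG (the ":4096:8" setting added here) per the PyTorch reproducibility docs. A minimal usage sketch, calling the helper once per process before any model or dataloader is built:

```python
import torch

from credit.seed import seed_everything

# Seed before constructing models, optimizers, or dataloaders so weight
# initialization and shuffling are reproducible run to run.
seed_everything(1234)

model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)
# With deterministic algorithms enforced and TF32 disabled, this value is
# bit-identical across runs on the same hardware/software stack.
print(model(x).sum().item())
```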
51 changes: 40 additions & 11 deletions credit/trainers/trainerERA5_multistep_grad_accum.py
@@ -74,13 +74,14 @@ def train_one_epoch(
"""

batches_per_epoch = conf["trainer"]["batches_per_epoch"]
grad_max_norm = conf["trainer"]["grad_max_norm"]
grad_max_norm = conf["trainer"].get("grad_max_norm", 0.0)
amp = conf["trainer"]["amp"]
distributed = True if conf["trainer"]["mode"] in ["fsdp", "ddp"] else False
forecast_length = conf["data"]["forecast_len"]
ensemble_size = conf["trainer"].get("ensemble_size", 1)
if ensemble_size > 1:
logger.info(f"ensemble training with ensemble_size {ensemble_size}")
logger.info(f"Using grad-max-norm value: {grad_max_norm}")

# number of diagnostic variables
varnum_diag = len(conf["data"]["diagnostic_variables"])
@@ -190,10 +191,10 @@
else:
# no x_surf
x = reshape_only(batch["x"]).to(self.device) # .float()

# --------------------------------------------- #
# ensemble x and x_surf on initialization
# copies each sample in the batch ensemble_size number of times.
# copies each sample in the batch ensemble_size number of times.
# if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
if ensemble_size > 1:
@@ -207,8 +208,10 @@ def train_one_epoch(
) # .float()
# ---------------- ensemble ----------------- #
# ensemble x_forcing_batch for concat. see above for explanation of code
if ensemble_size > 1:
x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
if ensemble_size > 1:
x_forcing_batch = torch.repeat_interleave(
x_forcing_batch, ensemble_size, 0
)
# --------------------------------------------- #

# concat on var dimension
@@ -320,10 +323,34 @@
if distributed:
torch.distributed.barrier()

# Grad norm clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=grad_max_norm
)
if grad_max_norm == "dynamic":
# Compute local L2 norm
local_norm = torch.norm(
torch.stack(
[
p.grad.detach().norm(2)
for p in self.model.parameters()
if p.grad is not None
]
)
)

# All-reduce to get global norm across ranks
dist.all_reduce(local_norm, op=dist.ReduceOp.SUM)
global_norm = local_norm.sqrt() # Compute total global norm

# Clip gradients using the global norm
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=global_norm
)
elif grad_max_norm > 0.0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=grad_max_norm
)

# Step optimizer
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
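On the "dynamic" branch above: with FSDP each rank only holds a shard of the gradients, so a collective is needed before the ranks can agree on a single norm to clip against. Below is a standalone sketch of the usual global L2 norm convention (sum the squared per-rank norms, then take the root); it assumes a process group is already initialized and gradients are sharded across ranks, and it is illustrative rather than the exact computation in this diff:

```python
import torch
import torch.distributed as dist


def global_grad_norm(model):
    """Global L2 norm of all gradients across ranks: accumulate the squared
    local norm, all-reduce the sum of squares, then take the square root."""
    device = next(model.parameters()).device
    local_sq = torch.zeros((), device=device)
    for p in model.parameters():
        if p.grad is not None:
            local_sq = local_sq + p.grad.detach().norm(2) ** 2
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
    return local_sq.sqrt()
```

Passing a value like this as max_norm to clip_grad_norm_, as the diff does with global_norm, makes the clipping threshold track the current step's norm rather than a fixed constant such as 1.0.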
Expand Down Expand Up @@ -507,7 +534,7 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
x = reshape_only(batch["x"]).to(self.device) # .float()
# --------------------------------------------- #
# ensemble x and x_surf on initialization
# copies each sample in the batch ensemble_size number of times.
# copies each sample in the batch ensemble_size number of times.
# if samples in the batch are ordered (x,y,z) then the result tensor is (x, x, ..., y, y, ..., z,z ...)
# WARNING: needs to be used with a loss that can handle x with b * ensemble_size samples and y with b samples
if ensemble_size > 1:
@@ -523,8 +550,10 @@ def validate(self, epoch, conf, valid_loader, criterion, metrics):
) # .float()
# ---------------- ensemble ----------------- #
# ensemble x_forcing_batch for concat. see above for explanation of code
if ensemble_size > 1:
x_forcing_batch = torch.repeat_interleave(x_forcing_batch, ensemble_size, 0)
if ensemble_size > 1:
x_forcing_batch = torch.repeat_interleave(
x_forcing_batch, ensemble_size, 0
)
# --------------------------------------------- #

# concat on var dimension