diff --git a/applications/train_universal.py b/applications/train_universal.py new file mode 100644 index 0000000..f01f696 --- /dev/null +++ b/applications/train_universal.py @@ -0,0 +1,482 @@
+"""
+train_universal.py
+-------------------------------------------------------
+"""
+
+import os
+import sys
+import yaml
+import wandb
+import optuna
+import shutil
+import logging
+import warnings
+
+from pathlib import Path
+from argparse import ArgumentParser
+from echo.src.base_objective import BaseObjective
+
+import torch
+from torch.cuda.amp import GradScaler
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from credit.distributed import distributed_model_wrapper, setup, get_rank_info
+
+from credit.seed import seed_everything
+from credit.loss import VariableTotalLoss2D
+
+from credit.scheduler import load_scheduler
+from credit.trainers import load_trainer
+from credit.parser import credit_main_parser, training_data_check
+from credit.datasets.load_dataset_and_dataloader import (
+    load_dataset,
+    load_dataloader
+)
+
+from credit.metrics import LatWeightedMetrics
+from credit.pbs import launch_script, launch_script_mpi
+from credit.models import load_model
+from credit.models.checkpoint import FSDPOptimizerWrapper, TorchFSDPCheckpointIO
+
+
+warnings.filterwarnings("ignore")
+
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+os.environ["OMP_NUM_THREADS"] = "1"
+os.environ["MKL_NUM_THREADS"] = "1"
+# https://stackoverflow.com/questions/59129812/how-to-avoid-cuda-out-of-memory-in-pytorch
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+
+def load_model_states_and_optimizer(conf, model, device):
+    """
+    Load the model states, optimizer, scheduler, and gradient scaler.
+
+    Args:
+        conf (dict): Configuration dictionary containing training parameters.
+        model (torch.nn.Module): The model to be trained.
+        device (torch.device): The device (CPU or GPU) where the model is located.
+
+    Returns:
+        tuple: A tuple containing the updated configuration, model, optimizer, scheduler, and scaler.
+ """ + + # convert $USER to the actual user name + conf["save_loc"] = save_loc = os.path.expandvars(conf["save_loc"]) + + # training hyperparameters + learning_rate = float(conf["trainer"]["learning_rate"]) + weight_decay = float(conf["trainer"]["weight_decay"]) + amp = conf["trainer"]["amp"] + + # load weights / states flags + load_weights = ( + False + if "load_weights" not in conf["trainer"] + else conf["trainer"]["load_weights"] + ) + load_optimizer_conf = ( + False + if "load_optimizer" not in conf["trainer"] + else conf["trainer"]["load_optimizer"] + ) + load_scaler_conf = ( + False + if "load_scaler" not in conf["trainer"] + else conf["trainer"]["load_scaler"] + ) + load_scheduler_conf = ( + False + if "load_scheduler" not in conf["trainer"] + else conf["trainer"]["load_scheduler"] + ) + + # Load an optimizer, gradient scaler, and learning rate scheduler, the optimizer must come after wrapping model using FSDP + if not load_weights: # Loaded after loading model weights when reloading + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.95), + ) + if conf["trainer"]["mode"] == "fsdp": + optimizer = FSDPOptimizerWrapper(optimizer, model) + scheduler = load_scheduler(optimizer, conf) + scaler = ( + ShardedGradScaler(enabled=amp) + if conf["trainer"]["mode"] == "fsdp" + else GradScaler(enabled=amp) + ) + + # Multi-step training case -- when starting, only load the model weights (then after load all states) + elif load_weights and not ( + load_optimizer_conf or load_scaler_conf or load_scheduler_conf + ): + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.95), + ) + # FSDP checkpoint settings + if conf["trainer"]["mode"] == "fsdp": + logging.info( + f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.95), + ) + optimizer = FSDPOptimizerWrapper(optimizer, model) + checkpoint_io = TorchFSDPCheckpointIO() + checkpoint_io.load_unsharded_model( + model, os.path.join(save_loc, "model_checkpoint.pt") + ) + else: + # DDP settings + ckpt = os.path.join(save_loc, "checkpoint.pt") + checkpoint = torch.load(ckpt, map_location=device) + if conf["trainer"]["mode"] == "ddp": + logging.info( + f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + model.module.load_state_dict(checkpoint["model_state_dict"]) + else: + logging.info( + f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + model.load_state_dict(checkpoint["model_state_dict"]) + # Load the learning rate scheduler and mixed precision grad scaler + scheduler = load_scheduler(optimizer, conf) + scaler = ( + ShardedGradScaler(enabled=amp) + if conf["trainer"]["mode"] == "fsdp" + else GradScaler(enabled=amp) + ) + + # load optimizer and grad scaler states + else: + ckpt = os.path.join(save_loc, "checkpoint.pt") + checkpoint = torch.load(ckpt, map_location=device) + + # FSDP checkpoint settings + if conf["trainer"]["mode"] == "fsdp": + logging.info( + f"Loading FSDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.95), + ) + optimizer = FSDPOptimizerWrapper(optimizer, model) + checkpoint_io = 
TorchFSDPCheckpointIO() + checkpoint_io.load_unsharded_model( + model, os.path.join(save_loc, "model_checkpoint.pt") + ) + if ( + "load_optimizer" in conf["trainer"] + and conf["trainer"]["load_optimizer"] + ): + checkpoint_io.load_unsharded_optimizer( + optimizer, os.path.join(save_loc, "optimizer_checkpoint.pt") + ) + + else: + # DDP settings + if conf["trainer"]["mode"] == "ddp": + logging.info( + f"Loading DDP model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + model.module.load_state_dict(checkpoint["model_state_dict"]) + else: + logging.info( + f"Loading model, optimizer, grad scaler, and learning rate scheduler states from {save_loc}" + ) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay, + betas=(0.9, 0.95), + ) + if ( + "load_optimizer" in conf["trainer"] + and conf["trainer"]["load_optimizer"] + ): + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + + scheduler = load_scheduler(optimizer, conf) + scaler = ( + ShardedGradScaler(enabled=amp) + if conf["trainer"]["mode"] == "fsdp" + else GradScaler(enabled=amp) + ) + + # Update the config file to the current epoch + if "reload_epoch" in conf["trainer"] and conf["trainer"]["reload_epoch"]: + conf["trainer"]["start_epoch"] = checkpoint["epoch"] + 1 + + if conf["trainer"]["start_epoch"] > 0: + # Only reload the scheduler state if not starting over from epoch 0 + if scheduler is not None: + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + + # reload the AMP gradient scaler + scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + # Enable updating the lr if not using a policy + if ( + conf["trainer"]["update_learning_rate"] + if "update_learning_rate" in conf["trainer"] + else False + ): + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + + return conf, model, optimizer, scheduler, scaler + + +def main(rank, world_size, conf, backend, trial=False): + """ + Main function to set up training and validation processes. + + Args: + rank (int): Rank of the current process. + world_size (int): Number of processes participating in the job. + conf (dict): Configuration dictionary containing model, data, and training parameters. + backend (str): Backend to be used for distributed training. + trial (bool, optional): Flag for whether this is an Optuna trial. Defaults to False. + + Returns: + Any: The result of the training process. 
+    """
+
+    # convert $USER to the actual user name
+    conf["save_loc"] = os.path.expandvars(conf["save_loc"])
+
+    if conf["trainer"]["mode"] in ["fsdp", "ddp"]:
+        setup(rank, world_size, conf["trainer"]["mode"], backend)
+
+    # infer device id from rank
+    device = (
+        torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    torch.cuda.set_device(rank % torch.cuda.device_count())
+
+    # Config settings
+    seed = conf["seed"]
+    seed_everything(seed)
+
+    # Load the dataset using the provided dataset_type
+    train_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=True)
+    valid_dataset = load_dataset(conf, rank=rank, world_size=world_size, is_train=False)
+
+    # Load the dataloader
+    train_loader = load_dataloader(conf, train_dataset, rank=rank, world_size=world_size, is_train=True)
+    valid_loader = load_dataloader(conf, valid_dataset, rank=rank, world_size=world_size, is_train=False)
+
+    # model
+    m = load_model(conf)
+
+    # have to send the module to the correct device first
+    m.to(device)
+
+    # move out of eager-mode
+    if conf["trainer"].get("compile", False):
+        m = torch.compile(m)
+
+    # Wrap in DDP or FSDP module, or none
+    model = distributed_model_wrapper(conf, m, device)
+
+    # Load model weights (if any), an optimizer, scheduler, and gradient scaler
+    conf, model, optimizer, scheduler, scaler = load_model_states_and_optimizer(
+        conf, model, device
+    )
+
+    # Train and validation losses
+    train_criterion = VariableTotalLoss2D(conf)
+    valid_criterion = VariableTotalLoss2D(conf, validation=True)
+
+    # Set up some metrics
+    metrics = LatWeightedMetrics(conf)
+
+    # Initialize a trainer object
+    trainer_cls = load_trainer(conf)
+    trainer = trainer_cls(model, rank, module=(conf["trainer"]["mode"] == "ddp"))
+
+    # Fit the model
+    result = trainer.fit(
+        conf,
+        train_loader=train_loader,
+        valid_loader=valid_loader,
+        optimizer=optimizer,
+        train_criterion=train_criterion,
+        valid_criterion=valid_criterion,
+        scaler=scaler,
+        scheduler=scheduler,
+        metrics=metrics,
+        trial=trial,  # Optional
+    )
+
+    return result
+
+
+class Objective(BaseObjective):
+    """
+    Optuna objective class for hyperparameter optimization.
+
+    Attributes:
+        config (dict): Configuration dictionary containing training parameters.
+        metric (str): Metric to optimize, defaults to "val_loss".
+        device (str): Device for training, defaults to "cpu".
+    """
+
+    def __init__(self, config, metric="val_loss", device="cpu"):
+        """
+        Initialize the Objective class.
+
+        Args:
+            config (dict): Configuration dictionary containing training parameters.
+            metric (str, optional): Metric to optimize. Defaults to "val_loss".
+            device (str, optional): Device for training. Defaults to "cpu".
+        """
+
+        # Initialize the base class
+        BaseObjective.__init__(self, config, metric, device)
+
+    def train(self, trial, conf):
+        """
+        Train the model using the given trial configuration.
+
+        Args:
+            trial (optuna.trial.Trial): Optuna trial object.
+            conf (dict): Configuration dictionary for the current trial.
+
+        Returns:
+            Any: The result of the training process.
+        """
+
+        conf["model"]["dim_head"] = conf["model"]["dim"]
+        conf["model"]["vq_codebook_dim"] = conf["model"]["dim"]
+
+        try:
+            # main() requires a backend argument; fall back to the CLI default
+            return main(0, 1, conf, backend="nccl", trial=trial)
+
+        except Exception as E:
+            if "CUDA" in str(E):
+                logging.warning(
+                    f"Pruning trial {trial.number} due to CUDA memory overflow: {str(E)}."
+                )
+                raise optuna.TrialPruned()
+            elif "non-singleton" in str(E):
+                logging.warning(
+                    f"Pruning trial {trial.number} due to shape mismatch: {str(E)}."
+                )
+                raise optuna.TrialPruned()
+            else:
+                logging.warning(f"Trial {trial.number} failed due to error: {str(E)}.")
+                raise E
+
+
+if __name__ == "__main__":
+    description = "Train a CREDIT model from a YAML configuration file"
+    parser = ArgumentParser(description=description)
+    parser.add_argument(
+        "-c",
+        "--config",
+        dest="model_config",
+        type=str,
+        default=False,
+        help="Path to the model configuration (yml) containing your inputs.",
+    )
+    parser.add_argument(
+        "-l",
+        dest="launch",
+        type=int,
+        default=0,
+        help="Submit workers to PBS.",
+    )
+    parser.add_argument(
+        "-w",
+        "--wandb",
+        dest="wandb",
+        type=int,
+        default=0,
+        help="Use wandb. Default = False",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        help="Backend for distributed training.",
+        default="nccl",
+        choices=["nccl", "gloo", "mpi"],
+    )
+    args = parser.parse_args()
+    args_dict = vars(args)
+    config = args_dict.pop("model_config")
+    launch = int(args_dict.pop("launch"))
+    backend = args_dict.pop("backend")
+    use_wandb = int(args_dict.pop("wandb"))
+
+    # Set up logger to print stuff
+    root = logging.getLogger()
+    root.setLevel(logging.DEBUG)
+    formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
+
+    # Stream output to stdout
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.INFO)
+    ch.setFormatter(formatter)
+    root.addHandler(ch)
+
+    # Load the configuration and get the relevant variables
+    with open(config) as cf:
+        conf = yaml.load(cf, Loader=yaml.FullLoader)
+
+    # ======================================================== #
+    # handling config args
+    conf = credit_main_parser(
+        conf, parse_training=True, parse_predict=False, print_summary=False
+    )
+    training_data_check(conf, print_summary=False)
+    # ======================================================== #
+
+    # Create directories if they do not exist and copy yml file
+    save_loc = os.path.expandvars(conf["save_loc"])
+    os.makedirs(save_loc, exist_ok=True)
+
+    if not os.path.exists(os.path.join(save_loc, "model.yml")):
+        shutil.copy(config, os.path.join(save_loc, "model.yml"))
+
+    # Launch PBS jobs
+    if launch:
+        # Where does this script live?
+        script_path = Path(__file__).absolute()
+        if conf["pbs"]["queue"] == "casper":
+            logging.info("Launching to PBS on Casper")
+            launch_script(config, script_path)
+        else:
+            logging.info("Launching to PBS on Derecho")
+            launch_script_mpi(config, script_path)
+        sys.exit()
+
+    if use_wandb:  # this needs to be updated
+        wandb.init(
+            # set the wandb project where this run will be logged
+            project="Derecho parallelism",
+            name=f"Worker {os.environ['RANK']} {os.environ['WORLD_SIZE']}",
+            # track hyperparameters and run metadata
+            config=conf,
+        )
+
+    seed = conf["seed"]
+    seed_everything(seed)
+
+    local_rank, world_rank, world_size = get_rank_info(conf["trainer"]["mode"])
+    main(world_rank, world_size, conf, backend)
diff --git a/config/multi_step.yml b/config/multi_step.yml
index c686ad3..6351545 100644
--- a/config/multi_step.yml
+++ b/config/multi_step.yml
@@ -15,16 +15,14 @@ data:
     train_years: [1979, 2014]
     valid_years: [2014, 2018]
     scaler_type: 'std_new'
-    history_len: 1
-    valid_history_len: 1
-    forecast_len: 4
-    valid_forecast_len: 0
+    history_len: 0
+    valid_history_len: 0
     one_shot: True
     lead_time_periods: 6
     skip_periods: null
    static_first: True
-    dataset_type: MultiprocessingBatcherPrefetch # ERA5_and_Forcing_MultiStep, ERA5_MultiStep_Batcher, MultiprocessingBatcher, MultiprocessingBatcherPrefetch
-
+    dataset_type: ERA5_and_Forcing_SingleStep # ERA5_and_Forcing_SingleStep, ERA5_and_Forcing_MultiStep, ERA5_MultiStep_Batcher, MultiprocessingBatcher, MultiprocessingBatcherPrefetch
+
 trainer:
     type: multi-step # <---------- change to your type
     mode: fsdp
@@ -36,8 +34,8 @@ trainer:
     train_one_epoch: True
     learning_rate: 1.0e-04 # <-- change to your lr
     weight_decay: 0
-    train_batch_size: 2
-    valid_batch_size: 2
+    train_batch_size: 1
+    valid_batch_size: 1
     batches_per_epoch: 0
     valid_batches_per_epoch: 50
     stopping_patience: 500
diff --git a/credit/datasets/__init__.py b/credit/datasets/__init__.py
index 38c0c69..b6a7d52 100644
--- a/credit/datasets/__init__.py
+++ b/credit/datasets/__init__.py
@@ -11,15 +11,15 @@ def set_globals(data_config, namespace=None):
     """
     Sets global variables from the provided configuration dictionary in the specified namespace.
 
-    This method updates the global variables in either the given `namespace` or the
-    caller's namespace (if `namespace` is not provided). If the `namespace` is not specified,
-    it uses the global namespace of the caller (using `sys._getframe(1).f_globals`).
+    This method updates the global variables in either the given `namespace` or the
+    caller's namespace (if `namespace` is not provided). If the `namespace` is not specified,
+    it uses the global namespace of the caller (using `sys._getframe(1).f_globals`).
 
     Parameters:
-    - data_config (dict): A dictionary where the keys are the global variable names
-      and the values are the corresponding values to set.
-    - namespace (dict, optional): The namespace (or dictionary) where the global variables
-      should be set. If not provided, the caller's global namespace is used.
+    - data_config (dict): A dictionary where the keys are the global variable names
+      and the values are the corresponding values to set.
+    - namespace (dict, optional): The namespace (or dictionary) where the global variables
+      should be set. If not provided, the caller's global namespace is used.
 
     The method logs each global variable being created and its name.
""" @@ -36,8 +36,8 @@ def set_globals(data_config, namespace=None): def setup_data_loading(conf): """ - Sets up the data loading configuration by reading and processing data paths, - surface, dynamic forcing, and diagnostic files based on the given configuration. + Sets up the data loading configuration by reading and processing data paths, + surface, dynamic forcing, and diagnostic files based on the given configuration. The function processes the configuration dictionary (`conf`) and performs the following: - Globs and filters data files (ERA5, surface, dynamic forcing, diagnostic). @@ -46,20 +46,20 @@ def setup_data_loading(conf): - Returns a dictionary containing all the paths and configuration details for further use. Parameters: - - conf (dict): A dictionary containing configuration details, including data paths, - variable names, forecast details, and other settings. + - conf (dict): A dictionary containing configuration details, including data paths, + variable names, forecast details, and other settings. Returns: - - data_config (dict): A dictionary containing paths to various datasets and other - configuration values used in data loading, such as: + - data_config (dict): A dictionary containing paths to various datasets and other + configuration values used in data loading, such as: - all_ERA_files: All ERA5 dataset files. - train_files: Filtered training dataset files. - valid_files: Filtered validation dataset files. - surface_files: Surface data files, if available. - dyn_forcing_files: Dynamic forcing files, if available. - diagnostic_files: Diagnostic files, if available. - - varname_upper_air, varname_surface, varname_dyn_forcing, etc.: Variable names for - each data type. + - varname_upper_air, varname_surface, varname_dyn_forcing, etc.: Variable names for + each data type. - history_len: Length of the history data for training. - forecast_len: Number of steps ahead to forecast. - Other configuration values related to skipping periods, one-shot learning, etc. @@ -263,6 +263,14 @@ def setup_data_loading(conf): else: one_shot = conf["data"]["one_shot"] + if conf["data"]["sst_forcing"]["activate"]: + sst_forcing = { + "varname_skt": conf["data"]["sst_forcing"]["varname_skt"], + "varname_ocean_mask": conf["data"]["sst_forcing"]["varname_ocean_mask"], + } + else: + sst_forcing = None + data_config = { 'all_ERA_files': all_ERA_files, 'train_files': train_files, @@ -290,7 +298,8 @@ def setup_data_loading(conf): 'valid_forecast_len': valid_forecast_len, 'max_forecast_len': max_forecast_len, 'skip_periods': skip_periods, - 'one_shot': one_shot + 'one_shot': one_shot, + 'sst_forcing': sst_forcing } return data_config diff --git a/credit/datasets/era5_multistep.py b/credit/datasets/era5_multistep.py index 82e27e3..dda9a71 100644 --- a/credit/datasets/era5_multistep.py +++ b/credit/datasets/era5_multistep.py @@ -29,6 +29,7 @@ def worker( forecast_len: int, skip_periods: int, transform: Optional[Callable], + sst_forcing: Optional[Any] = None, ) -> Dict[str, Any]: """ Processes a given index to extract and transform data for a specific time slice. @@ -45,7 +46,6 @@ def worker( - skip_periods (int): Number of periods to skip between samples. - xarray_forcing (Optional[Any]): xarray dataset containing forcing data. - xarray_static (Optional[Any]): xarray dataset containing static data. - - transform (Optional[Callable]): Transformation function to apply to the data. 
Returns: @@ -187,6 +187,38 @@ def worker( # merge into the target dataset target_ERA5_images = target_ERA5_images.merge(diagnostic_subset) + # sst forcing operations + if sst_forcing is not None: + # get xr.dataset keys + varname_skt = sst_forcing["varname_skt"] + varname_ocean_mask = sst_forcing["varname_ocean_mask"] + + # get xr.dataarray from the dataset + ocean_mask = historical_ERA5_images[varname_ocean_mask] + input_skt = historical_ERA5_images[varname_skt] + target_skt = target_ERA5_images[varname_skt] + + # for multi-input cases, use time=-1 ocean mask for all times + if history_len > 1: + ocean_mask[: history_len - 1] = ocean_mask.isel(time=-1) + + # get ocean mask + ocean_mask_bool = ocean_mask.isel(time=-1) == 0 + + # for multi-input cases, use time=-1 ocean SKT for all times + if history_len > 1: + input_skt[: history_len - 1] = input_skt[ + : history_len - 1 + ].where(~ocean_mask_bool, input_skt.isel(time=-1)) + + # for target skt, replace ocean values using time=-1 input SKT + target_skt = target_skt.where(~ocean_mask_bool, input_skt.isel(time=-1)) + + # Update the target_ERA5_images dataset with the modified target_skt + historical_ERA5_images[varname_ocean_mask] = ocean_mask + historical_ERA5_images[varname_skt] = input_skt + target_ERA5_images[varname_skt] = target_skt + # create a dict object with input/output tensors sample = Sample( historical_ERA5_images=historical_ERA5_images, @@ -252,6 +284,7 @@ def __init__( skip_periods=None, one_shot=None, max_forecast_len=None, + sst_forcing=None ): """ Initialize the ERA5_and_Forcing_Dataset @@ -278,7 +311,7 @@ def __init__( the final state of the training target. Default is None - max_forecast_len (int, optional): Maximum length of the forecast sequence. - shuffle (bool, optional): Whether to shuffle the data. Default is True. - + - sst_forcing (optional): Returns: - sample (dict): A dictionary containing historical_ERA5_images, target_ERA5_images, @@ -303,6 +336,9 @@ def __init__( # total number of needed forecast lead times self.total_seq_len = self.history_len + self.forecast_len + # sst forcing + self.sst_forcing = sst_forcing + # set random seed self.rng = np.random.default_rng(seed=seed) @@ -457,6 +493,7 @@ def __init__( forecast_len=self.forecast_len, skip_periods=self.skip_periods, transform=self.transform, + sst_forcing=self.sst_forcing ) self.total_length = len(self.ERA5_indices) @@ -555,6 +592,7 @@ def __getitem__(self, index): skip_periods=data_config['skip_periods'], one_shot=False, max_forecast_len=data_config['max_forecast_len'], + sst_forcing=data_config['sst_forcing'], transform=load_transforms(conf) ) diff --git a/credit/datasets/era5_multistep_batcher.py b/credit/datasets/era5_multistep_batcher.py index e938967..568acfe 100644 --- a/credit/datasets/era5_multistep_batcher.py +++ b/credit/datasets/era5_multistep_batcher.py @@ -44,6 +44,7 @@ def __init__( filename_forcing=None, filename_static=None, filename_diagnostic=None, + sst_forcing=None, history_len=2, forecast_len=0, transform=None, @@ -78,7 +79,7 @@ def __init__( - skip_periods (int, optional): Number of periods to skip between samples. - max_forecast_len (int, optional): Maximum length of the forecast sequence. - shuffle (bool, optional): Whether to shuffle the data. Default is True. 
- + - sst_forcing (optional): Returns: - sample (dict): A dictionary containing historical_ERA5_images, target_ERA5_images, @@ -107,6 +108,9 @@ def __init__( # max possible forecast len self.max_forecast_len = max_forecast_len + # sst forcing + self.sst_forcing = sst_forcing + # =================================================================== # # flags to determin if any of the [surface, dyn_forcing, diagnostics] # variable groups share the same file as upper air variables @@ -251,6 +255,7 @@ def __init__( diagnostic_files=self.diagnostic_files, xarray_forcing=self.xarray_forcing, xarray_static=self.xarray_static, + sst_forcing=self.sst_forcing, history_len=self.history_len, forecast_len=self.forecast_len, skip_periods=self.skip_periods, @@ -398,8 +403,6 @@ def __getitem__(self, _): batch["stop_forecast"] = batch["forecast_step"] == self.forecast_len + 1 batch["datetime"] = batch["datetime"].view(-1, self.batch_size) # reshape - # print(batch['index'], batch['datetime'], batch['forecast_step'], batch['stop_forecast'], batch['x'].shape, batch['x_surf'].shape) - return batch @@ -523,7 +526,7 @@ def __init__(self, *args, num_workers=4, prefetch_factor=4, **kwargs): # Register signal handler self.stop_event = multiprocessing.Event() - #signal.signal(signal.SIGINT, self.handle_signal) + # signal.signal(signal.SIGINT, self.handle_signal) signal.signal(signal.SIGTERM, self.handle_signal) self.prefetch_thread = None @@ -797,6 +800,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): skip_periods=data_config['skip_periods'], max_forecast_len=data_config['max_forecast_len'], transform=load_transforms(conf), + sst_forcing=data_config['sst_forcing'], batch_size=batch_size, shuffle=shuffle, rank=rank, @@ -824,6 +828,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): filename_forcing=data_config['forcing_files'], filename_static=data_config['static_files'], filename_diagnostic=data_config['diagnostic_files'], + sst_forcing=data_config['sst_forcing'], history_len=data_config['history_len'], forecast_len=data_config['forecast_len'], skip_periods=data_config['skip_periods'], @@ -885,6 +890,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): skip_periods=data_config['skip_periods'], max_forecast_len=data_config['max_forecast_len'], transform=load_transforms(conf), + sst_forcing=data_config['sst_forcing'], batch_size=batch_size, shuffle=shuffle, rank=rank, @@ -913,6 +919,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): filename_forcing=data_config['forcing_files'], filename_static=data_config['static_files'], filename_diagnostic=data_config['diagnostic_files'], + sst_forcing=data_config['sst_forcing'], history_len=data_config['history_len'], forecast_len=data_config['forecast_len'], skip_periods=data_config['skip_periods'], @@ -967,6 +974,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): filename_forcing=data_config['forcing_files'], filename_static=data_config['static_files'], filename_diagnostic=data_config['diagnostic_files'], + sst_forcing=data_config['sst_forcing'], history_len=data_config['history_len'], forecast_len=data_config['forecast_len'], skip_periods=data_config['skip_periods'], @@ -1000,6 +1008,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): filename_forcing=data_config['forcing_files'], filename_static=data_config['static_files'], filename_diagnostic=data_config['diagnostic_files'], + sst_forcing=data_config['sst_forcing'], history_len=data_config['history_len'], forecast_len=data_config['forecast_len'], skip_periods=data_config['skip_periods'], diff --git 
a/credit/datasets/load_dataset_and_dataloader.py b/credit/datasets/load_dataset_and_dataloader.py index 2f6ca4b..7fbbd5a 100644 --- a/credit/datasets/load_dataset_and_dataloader.py +++ b/credit/datasets/load_dataset_and_dataloader.py @@ -1,9 +1,11 @@ from credit.datasets.era5_multistep import ERA5_and_Forcing_MultiStep +from credit.datasets.era5_singlestep import ERA5_and_Forcing_SingleStep from credit.datasets.era5_multistep_batcher import ( ERA5_MultiStep_Batcher, MultiprocessingBatcher, MultiprocessingBatcherPrefetch ) +from credit.data import ERA5_and_Forcing_Dataset from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from credit.transforms import load_transforms @@ -70,7 +72,30 @@ def load_dataset(conf, rank=0, world_size=1, is_train=True): prefetch_factor = 4 # Instantiate the dataset based on the provided class name - if dataset_type == "ERA5_and_Forcing_MultiStep": + if dataset_type == "ERA5_and_Forcing_SingleStep": # forecast-len = 0 dataset + dataset = ERA5_and_Forcing_SingleStep( + varname_upper_air=conf["data"]["variables"], + varname_surface=conf["data"]["surface_variables"], + varname_dyn_forcing=conf["data"]["dynamic_forcing_variables"], + varname_forcing=conf["data"]["forcing_variables"], + varname_static=conf["data"]["static_variables"], + varname_diagnostic=conf["data"]["diagnostic_variables"], + filenames=data_config["all_ERA_files"], + filename_surface=data_config["surface_files"], + filename_dyn_forcing=data_config["dyn_forcing_files"], + filename_forcing=conf["data"]["save_loc_forcing"], + filename_static=conf["data"]["save_loc_static"], + filename_diagnostic=data_config["diagnostic_files"], + history_len=data_config['history_len'], + forecast_len=data_config['forecast_len'], + skip_periods=conf["data"]["skip_periods"], + one_shot=conf["data"]["one_shot"], + max_forecast_len=conf["data"]["max_forecast_len"], + transform=load_transforms(conf), + sst_forcing=data_config['sst_forcing'] + ) + # All datasets from here on are multi-step examples + elif dataset_type == "ERA5_and_Forcing_MultiStep": dataset = ERA5_and_Forcing_MultiStep( varname_upper_air=conf["data"]["variables"], varname_surface=conf["data"]["surface_variables"], @@ -194,6 +219,10 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True): DataLoader: The loaded DataLoader. 
""" seed = conf["seed"] + training_type = "train" if is_train else "valid" + batch_size = conf["trainer"][f"{training_type}_batch_size"] + shuffle = is_train + num_workers = conf["trainer"]["thread_workers"] if is_train else conf["trainer"]["valid_thread_workers"] prefetch_factor = conf["trainer"].get("prefetch_factor") if prefetch_factor is None: logging.warning( @@ -202,7 +231,26 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True): ) prefetch_factor = 4 - if type(dataset) is ERA5_and_Forcing_MultiStep: + if type(dataset) is ERA5_and_Forcing_SingleStep: + # This is the single-step dataset, original version + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + seed=seed, + shuffle=shuffle, + drop_last=True, + ) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + sampler=sampler, + pin_memory=True, + persistent_workers=True if num_workers > 0 else False, + num_workers=num_workers + ) + elif type(dataset) is ERA5_and_Forcing_MultiStep: # This is the deprecated dataset sampler = DistributedSampler( dataset, @@ -284,6 +332,7 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True): # options dataset_type = [ + "ERA5_and_Forcing_SingleStep", "ERA5_and_Forcing_MultiStep", "ERA5_MultiStep_Batcher", "MultiprocessingBatcher", @@ -294,13 +343,13 @@ def load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True): rank = 0 world_size = 2 conf["trainer"]["start_epoch"] = epoch - conf["trainer"]["train_batch_size"] = 2 # batch_size - conf["trainer"]["valid_batch_size"] = 2 # batch_size - conf["trainer"]["thread_workers"] = 2 # num_workers - conf["trainer"]["valid_thread_workers"] = 2 # num_workers + conf["trainer"]["train_batch_size"] = 1 # batch_size + conf["trainer"]["valid_batch_size"] = 1 # batch_size + conf["trainer"]["thread_workers"] = 0 # num_workers + conf["trainer"]["valid_thread_workers"] = 0 # num_workers conf["trainer"]["prefetch_factor"] = 4 # Add prefetch_factor - conf["data"]["forecast_len"] = 6 - conf["data"]["valid_forecast_len"] = 6 + conf["data"]["forecast_len"] = 0 + conf["data"]["valid_forecast_len"] = 0 conf["data"]["dataset_type"] = dataset_type set_globals(data_config, namespace=globals())