From e394671ea59b722bfc525a5a21d7318f14ba9170 Mon Sep 17 00:00:00 2001 From: John Schreck Date: Mon, 9 Dec 2024 09:09:16 -0700 Subject: [PATCH] Addition of new multi-step dataset that allows batch size > 1, refactor of data config loading --- credit/datasets/__init__.py | 296 ++++++++++++++ credit/datasets/era5_multistep.py | 452 +++------------------- credit/datasets/era5_multistep_batcher.py | 425 ++++++++++++++++++++ 3 files changed, 785 insertions(+), 388 deletions(-) create mode 100644 credit/datasets/era5_multistep_batcher.py diff --git a/credit/datasets/__init__.py b/credit/datasets/__init__.py index e69de29b..38c0c693 100644 --- a/credit/datasets/__init__.py +++ b/credit/datasets/__init__.py @@ -0,0 +1,296 @@ +import os +import sys +import glob +import logging + + +logger = logging.getLogger(__name__) + + +def set_globals(data_config, namespace=None): + """ + Sets global variables from the provided configuration dictionary in the specified namespace. + + This method updates the global variables in either the given `namespace` or the + caller's namespace (if `namespace` is not provided). If the `namespace` is not specified, + it uses the global namespace of the caller (using `sys._getframe(1).f_globals`). + + Parameters: + - data_config (dict): A dictionary where the keys are the global variable names + and the values are the corresponding values to set. + - namespace (dict, optional): The namespace (or dictionary) where the global variables + should be set. If not provided, the caller's global namespace is used. + + The method logs each global variable being created and its name. + """ + + target = namespace or sys._getframe(1).f_globals + target.update(data_config) + + # Identify if this is the __main__ namespace + name = target.get('__name__') + + for key in data_config: + logger.info(f"Creating global variable in {name}: {key}") + + +def setup_data_loading(conf): + """ + Sets up the data loading configuration by reading and processing data paths, + surface, dynamic forcing, and diagnostic files based on the given configuration. + + The function processes the configuration dictionary (`conf`) and performs the following: + - Globs and filters data files (ERA5, surface, dynamic forcing, diagnostic). + - Determines the training and validation file sets based on specified years. + - Sets up variables like historical data length, forecast length, and additional metadata. + - Returns a dictionary containing all the paths and configuration details for further use. + + Parameters: + - conf (dict): A dictionary containing configuration details, including data paths, + variable names, forecast details, and other settings. + + Returns: + - data_config (dict): A dictionary containing paths to various datasets and other + configuration values used in data loading, such as: + - all_ERA_files: All ERA5 dataset files. + - train_files: Filtered training dataset files. + - valid_files: Filtered validation dataset files. + - surface_files: Surface data files, if available. + - dyn_forcing_files: Dynamic forcing files, if available. + - diagnostic_files: Diagnostic files, if available. + - varname_upper_air, varname_surface, varname_dyn_forcing, etc.: Variable names for + each data type. + - history_len: Length of the history data for training. + - forecast_len: Number of steps ahead to forecast. + - Other configuration values related to skipping periods, one-shot learning, etc. 
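+
+    Example (a minimal sketch; "model.yml" is an illustrative path and real
+    configs carry many more keys than shown):
+        >>> import yaml
+        >>> with open("model.yml") as f:
+        ...     conf = yaml.safe_load(f)
+        >>> data_config = setup_data_loading(conf)
+        >>> sorted(data_config)[:3]
+        ['all_ERA_files', 'diagnostic_files', 'dyn_forcing_files']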
+ """ + + all_ERA_files = sorted(glob.glob(conf["data"]["save_loc"])) + + # <------------------------------------------ std_new + if conf["data"]["scaler_type"] == "std_new": + # check and glob surface files + if ("surface_variables" in conf["data"]) and ( + len(conf["data"]["surface_variables"]) > 0 + ): + surface_files = sorted(glob.glob(conf["data"]["save_loc_surface"])) + + else: + surface_files = None + + # check and glob dyn forcing files + if ("dynamic_forcing_variables" in conf["data"]) and ( + len(conf["data"]["dynamic_forcing_variables"]) > 0 + ): + dyn_forcing_files = sorted(glob.glob(conf["data"]["save_loc_dynamic_forcing"])) + + else: + dyn_forcing_files = None + + # check and glob diagnostic files + if ("diagnostic_variables" in conf["data"]) and ( + len(conf["data"]["diagnostic_variables"]) > 0 + ): + diagnostic_files = sorted(glob.glob(conf["data"]["save_loc_diagnostic"])) + + else: + diagnostic_files = None + + # -------------------------------------------------- # + # import training / validation years from conf + + if "train_years" in conf["data"]: + train_years_range = conf["data"]["train_years"] + else: + train_years_range = [1979, 2014] + + if "valid_years" in conf["data"]: + valid_years_range = conf["data"]["valid_years"] + else: + valid_years_range = [2014, 2018] + + # convert year info to str for file name search + train_years = [str(year) for year in range(train_years_range[0], train_years_range[1])] + valid_years = [str(year) for year in range(valid_years_range[0], valid_years_range[1])] + + # Filter the files for training / validation + train_files = [ + file for file in all_ERA_files if any(year in file for year in train_years) + ] + valid_files = [ + file for file in all_ERA_files if any(year in file for year in valid_years) + ] + + # <----------------------------------- std_new + if conf["data"]["scaler_type"] == "std_new": + if surface_files is not None: + train_surface_files = [ + file for file in surface_files if any(year in file for year in train_years) + ] + valid_surface_files = [ + file for file in surface_files if any(year in file for year in valid_years) + ] + + # ---------------------------- # + # check total number of files + assert ( + len(train_surface_files) == len(train_files) + ), "Mismatch between the total number of training set [surface files] and [upper-air files]" + assert ( + len(valid_surface_files) == len(valid_files) + ), "Mismatch between the total number of validation set [surface files] and [upper-air files]" + + else: + train_surface_files = None + valid_surface_files = None + + if dyn_forcing_files is not None: + train_dyn_forcing_files = [ + file + for file in dyn_forcing_files + if any(year in file for year in train_years) + ] + valid_dyn_forcing_files = [ + file + for file in dyn_forcing_files + if any(year in file for year in valid_years) + ] + + # ---------------------------- # + # check total number of files + assert ( + len(train_dyn_forcing_files) == len(train_files) + ), "Mismatch between the total number of training set [dynamic forcing files] and [upper-air files]" + assert ( + len(valid_dyn_forcing_files) == len(valid_files) + ), "Mismatch between the total number of validation set [dynamic forcing files] and [upper-air files]" + + else: + train_dyn_forcing_files = None + valid_dyn_forcing_files = None + + if diagnostic_files is not None: + train_diagnostic_files = [ + file + for file in diagnostic_files + if any(year in file for year in train_years) + ] + valid_diagnostic_files = [ + file + for file in 
diagnostic_files + if any(year in file for year in valid_years) + ] + + # ---------------------------- # + # check total number of files + assert ( + len(train_diagnostic_files) == len(train_files) + ), "Mismatch between the total number of training set [diagnostic files] and [upper-air files]" + assert ( + len(valid_diagnostic_files) == len(valid_files) + ), "Mismatch between the total number of validation set [diagnostic files] and [upper-air files]" + + else: + train_diagnostic_files = None + valid_diagnostic_files = None + + # convert $USER to the actual user name + conf["save_loc"] = os.path.expandvars(conf["save_loc"]) + + # ======================================================== # + # parse inputs + + # upper air variables + varname_upper_air = conf["data"]["variables"] + + if ("forcing_variables" in conf["data"]) and ( + len(conf["data"]["forcing_variables"]) > 0 + ): + forcing_files = conf["data"]["save_loc_forcing"] + varname_forcing = conf["data"]["forcing_variables"] + else: + forcing_files = None + varname_forcing = None + + if ("static_variables" in conf["data"]) and (len(conf["data"]["static_variables"]) > 0): + static_files = conf["data"]["save_loc_static"] + varname_static = conf["data"]["static_variables"] + else: + static_files = None + varname_static = None + + # get surface variable names + if surface_files is not None: + varname_surface = conf["data"]["surface_variables"] + else: + varname_surface = None + + # get dynamic forcing variable names + if dyn_forcing_files is not None: + varname_dyn_forcing = conf["data"]["dynamic_forcing_variables"] + else: + varname_dyn_forcing = None + + # get diagnostic variable names + if diagnostic_files is not None: + varname_diagnostic = conf["data"]["diagnostic_variables"] + else: + varname_diagnostic = None + + # number of previous lead time inputs + history_len = conf["data"]["history_len"] + valid_history_len = conf["data"]["valid_history_len"] + + # number of lead times to forecast + forecast_len = conf["data"]["forecast_len"] + valid_forecast_len = conf["data"]["valid_forecast_len"] + + # max_forecast_len + if "max_forecast_len" not in conf["data"]: + max_forecast_len = None + else: + max_forecast_len = conf["data"]["max_forecast_len"] + + # skip_periods + if "skip_periods" not in conf["data"]: + skip_periods = None + else: + skip_periods = conf["data"]["skip_periods"] + + # one_shot + if "one_shot" not in conf["data"]: + one_shot = None + else: + one_shot = conf["data"]["one_shot"] + + data_config = { + 'all_ERA_files': all_ERA_files, + 'train_files': train_files, + 'valid_files': valid_files, + 'surface_files': surface_files, + 'dyn_forcing_files': dyn_forcing_files, + 'diagnostic_files': diagnostic_files, + 'forcing_files': forcing_files, + 'static_files': static_files, + 'train_surface_files': train_surface_files, + 'valid_surface_files': valid_surface_files, + 'train_dyn_forcing_files': train_dyn_forcing_files, + 'valid_dyn_forcing_files': valid_dyn_forcing_files, + 'train_diagnostic_files': train_diagnostic_files, + 'valid_diagnostic_files': valid_diagnostic_files, + 'varname_upper_air': varname_upper_air, + 'varname_surface': varname_surface, + 'varname_dyn_forcing': varname_dyn_forcing, + 'varname_forcing': varname_forcing, + 'varname_static': varname_static, + 'varname_diagnostic': varname_diagnostic, + 'history_len': history_len, + 'valid_history_len': valid_history_len, + 'forecast_len': forecast_len, + 'valid_forecast_len': valid_forecast_len, + 'max_forecast_len': max_forecast_len, + 'skip_periods': skip_periods, 
+ 'one_shot': one_shot + } + + return data_config diff --git a/credit/datasets/era5_multistep.py b/credit/datasets/era5_multistep.py index 6cf7c5ca..91a990e2 100644 --- a/credit/datasets/era5_multistep.py +++ b/credit/datasets/era5_multistep.py @@ -321,7 +321,7 @@ def __init__( # ------------------------------------------------------------------ # # blocks that can handle no-sharing (each group has it own file) - ## surface + # surface if filename_surface is not None: surface_files = [] filename_surface = sorted(filename_surface) @@ -339,7 +339,7 @@ def __init__( else: self.surface_files = False - ## dynamic forcing + # dynamic forcing if filename_dyn_forcing is not None: dyn_forcing_files = [] filename_dyn_forcing = sorted(filename_dyn_forcing) @@ -357,7 +357,7 @@ def __init__( else: self.dyn_forcing_files = False - ## diagnostics + # diagnostics if filename_diagnostic is not None: diagnostic_files = [] filename_diagnostic = sorted(filename_diagnostic) @@ -511,388 +511,64 @@ def __getitem__(self, index): return sample -# class ERA5_and_Forcing_MultiStep(torch.utils.data.Dataset): -# ''' -# A Pytorch Dataset class that works on: -# - upper-air variables (time, level, lat, lon) -# - surface variables (time, lat, lon) -# - dynamic forcing variables (time, lat, lon) -# - foring variables (time, lat, lon) -# - diagnostic variables (time, lat, lon) -# - static variables (lat, lon) -# ''' - -# def __init__( -# self, -# varname_upper_air, -# varname_surface, -# varname_dyn_forcing, -# varname_forcing, -# varname_static, -# varname_diagnostic, -# filenames, -# filename_surface=None, -# filename_dyn_forcing=None, -# filename_forcing=None, -# filename_static=None, -# filename_diagnostic=None, -# history_len=2, -# forecast_len=0, -# transform=None, -# seed=42, -# skip_periods=None, -# one_shot=None, -# max_forecast_len=None, -# rank=0, -# world_size=1 -# ): - -# ''' -# Initialize the ERA5_and_Forcing_Dataset - -# Parameters: -# - varname_upper_air (list): List of upper air variable names. -# - varname_surface (list): List of surface variable names. -# - varname_dyn_forcing (list): List of dynamic forcing variable names. -# - varname_forcing (list): List of forcing variable names. -# - varname_static (list): List of static variable names. -# - varname_diagnostic (list): List of diagnostic variable names. -# - filenames (list): List of filenames for upper air data. -# - filename_surface (list, optional): List of filenames for surface data. -# - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing data. -# - filename_forcing (str, optional): Filename for forcing data. -# - filename_static (str, optional): Filename for static data. -# - filename_diagnostic (list, optional): List of filenames for diagnostic data. -# - history_len (int, optional): Length of the history sequence. Default is 2. -# - forecast_len (int, optional): Length of the forecast sequence. Default is 0. -# - transform (callable, optional): Transformation function to apply to the data. -# - seed (int, optional): Random seed for reproducibility. Default is 42. -# - skip_periods (int, optional): Number of periods to skip between samples. -# - one_shot(bool, optional): Whether to return all states or just -# the final state of the training target. Default is None -# - max_forecast_len (int, optional): Maximum length of the forecast sequence. -# - shuffle (bool, optional): Whether to shuffle the data. Default is True. 
- -# Returns: -# - sample (dict): A dictionary containing historical_ERA5_images, -# target_ERA5_images, -# datetime index, and additional information. -# ''' - -# self.history_len = history_len -# self.forecast_len = forecast_len -# self.transform = transform - -# # skip periods -# self.skip_periods = skip_periods -# if self.skip_periods is None: -# self.skip_periods = 1 - -# # one shot option -# self.one_shot = one_shot - -# # total number of needed forecast lead times -# self.total_seq_len = self.history_len + self.forecast_len - -# # set random seed -# self.rng = np.random.default_rng(seed=seed) - -# # max possible forecast len -# self.max_forecast_len = max_forecast_len - -# # ======================================================== # -# # upper-air files - -# all_files = [] -# filenames = sorted(filenames) - -# for fn in filenames: -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename=fn) -# xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_upper_air) - -# # collect yearly datasets within a list -# all_files.append(xarray_dataset) - -# self.all_files = all_files - -# # get sample indices from ERA5 upper-air files: -# ind_start = 0 -# self.ERA5_indices = {} # <------ change -# for ind_file, ERA5_xarray in enumerate(self.all_files): -# # [number of samples, ind_start, ind_end] -# self.ERA5_indices[str(ind_file)] = [len(ERA5_xarray['time']), -# ind_start, -# ind_start + len(ERA5_xarray['time'])] -# ind_start += len(ERA5_xarray['time']) + 1 - -# # ======================================================== # -# # surface files -# if filename_surface is not None: - -# surface_files = [] -# filename_surface = sorted(filename_surface) - -# for fn in filename_surface: - -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename=fn) -# xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_surface) - -# surface_files.append(xarray_dataset) - -# self.surface_files = surface_files - -# else: -# self.surface_files = False - - -# # ======================================================== # -# # dynamic forcing files -# if filename_dyn_forcing is not None: - -# dyn_forcing_files = [] -# filename_dyn_forcing = sorted(filename_dyn_forcing) - -# for fn in filename_dyn_forcing: - -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename=fn) -# xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_dyn_forcing) - -# dyn_forcing_files.append(xarray_dataset) - -# self.dyn_forcing_files = dyn_forcing_files - -# else: -# self.dyn_forcing_files = False - -# # ======================================================== # -# # diagnostic file -# self.filename_diagnostic = filename_diagnostic - -# if self.filename_diagnostic is not None: - -# diagnostic_files = [] -# filename_diagnostic = sorted(filename_diagnostic) - -# for fn in filename_diagnostic: - -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename=fn) -# xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_diagnostic) - -# diagnostic_files.append(xarray_dataset) - -# self.diagnostic_files = diagnostic_files - -# else: -# self.diagnostic_files = False - -# # ======================================================== # -# # forcing file -# self.filename_forcing = filename_forcing - -# if self.filename_forcing is not None: -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename_forcing) -# xarray_dataset = 
drop_var_from_dataset(xarray_dataset, varname_forcing) - -# self.xarray_forcing = xarray_dataset -# else: -# self.xarray_forcing = False - -# # ======================================================== # -# # static file -# self.filename_static = filename_static - -# if self.filename_static is not None: -# # drop variables if they are not in the config -# xarray_dataset = get_forward_data(filename_static) -# xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_static) - -# self.xarray_static = xarray_dataset -# else: -# self.xarray_static = False - -# self.start_index = self._get_random_start_index() -# self.forecast_step = 0 -# self.total_length = len(self.ERA5_indices) - -# def _get_random_start_index(self): -# """Generate a random start index based on the length of the dataset.""" -# dataset_length = len(self) -# return 0 #random.randint(0, dataset_length - 1) - -# def __post_init__(self): -# # Total sequence length of each sample. -# self.total_seq_len = self.history_len + self.forecast_len - -# def __len__(self): -# # compute the total number of length -# total_len = 0 -# for ERA5_xarray in self.all_files: -# total_len += len(ERA5_xarray['time']) - self.total_seq_len + 1 -# return total_len - -# def set_epoch(self, epoch): -# self.current_epoch = epoch -# self.forecast_step_count = 0 -# self.current_index = None -# self.initial_index = None - -# def _get_new_start_index(self, worker_id=0, num_workers=1): -# # Divide the data among workers such that there's no overlap -# total_steps = len(self.ERA5_indices) // num_workers -# worker_offset = worker_id * total_steps -# return worker_offset + (self.start_index % total_steps) - -# def __getitem__(self, index): - -# if (self.forecast_step_count == self.forecast_len + 1) or (self.current_index is None): -# # We've completed the last forecast or we're starting for the first time -# # Start a new forecast using the sampler index -# self.current_index = index # self._get_random_start_index() -# self.forecast_step_count = 0 -# index = self.current_index -# self.initial_index = self.current_index -# else: -# # Ignore the sampler index and continue the forecast -# self.current_index += 1 -# index = self.current_index - -# # select the ind_file based on the iter index -# ind_file = find_key_for_number(index, self.ERA5_indices) - -# # get the ind within the current file -# ind_start = self.ERA5_indices[ind_file][1] -# ind_start_in_file = index - ind_start - -# # handle out-of-bounds -# ind_largest = len(self.all_files[int(ind_file)]['time'])-(self.history_len+self.forecast_len+1) -# if ind_start_in_file > ind_largest: -# ind_start_in_file = ind_largest - -# # ========================================================================== # -# # subset xarray on time dimension - -# ind_end_in_file = ind_start_in_file+self.history_len - -# ## ERA5_subset: a xarray dataset that contains training input and target (for the current batch) -# ERA5_subset = self.all_files[int(ind_file)].isel( -# time=slice(ind_start_in_file, ind_end_in_file+1)) #.load() NOT load into memory - -# # ========================================================================== # -# # merge surface into the dataset - -# if self.surface_files: -# ## subset surface variables -# surface_subset = self.surface_files[int(ind_file)].isel( -# time=slice(ind_start_in_file, ind_end_in_file+1)) #.load() NOT load into memory - -# ## merge upper-air and surface here: -# ERA5_subset = ERA5_subset.merge(surface_subset) # <-- lazy merge, ERA5 and surface both not loaded - -# # 
==================================================== # -# # split ERA5_subset into training inputs and targets -# # + merge with dynamic forcing, forcing, and static - -# # the ind_end of the ERA5_subset -# ind_end_time = len(ERA5_subset['time']) - -# # datetiem information as int number (used in some normalization methods) -# datetime_as_number = ERA5_subset.time.values.astype('datetime64[s]').astype(int) - -# # ==================================================== # -# # xarray dataset as input -# ## historical_ERA5_images: the final input - -# historical_ERA5_images = ERA5_subset.isel( -# time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory - -# # ========================================================================== # -# # merge dynamic forcing inputs -# if self.dyn_forcing_files: -# dyn_forcing_subset = self.dyn_forcing_files[int(ind_file)].isel( -# time=slice(ind_start_in_file, ind_end_in_file+1)) -# dyn_forcing_subset = dyn_forcing_subset.isel( -# time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory - -# historical_ERA5_images = historical_ERA5_images.merge(dyn_forcing_subset) - -# # ========================================================================== # -# # merge forcing inputs -# if self.xarray_forcing: -# # ------------------------------------------------------------------------------- # -# # matching month, day, hour between forcing and upper air [time] -# # this approach handles leap year forcing file and non-leap-year upper air file -# month_day_forcing = extract_month_day_hour(np.array(self.xarray_forcing['time'])) -# month_day_inputs = extract_month_day_hour(np.array(historical_ERA5_images['time'])) # <-- upper air -# # indices to subset -# ind_forcing, _ = find_common_indices(month_day_forcing, month_day_inputs) -# forcing_subset_input = self.xarray_forcing.isel(time=ind_forcing).load() # <-- load into memory -# # forcing and upper air have different years but the same mon/day/hour -# # safely replace forcing time with upper air time -# forcing_subset_input['time'] = historical_ERA5_images['time'] -# # ------------------------------------------------------------------------------- # - -# # merge -# historical_ERA5_images = historical_ERA5_images.merge(forcing_subset_input) - -# # ========================================================================== # -# # merge static inputs -# if self.xarray_static: -# # expand static var on time dim -# N_time_dims = len(ERA5_subset['time']) -# static_subset_input = self.xarray_static.expand_dims(dim={"time": N_time_dims}) -# # assign coords 'time' -# static_subset_input = static_subset_input.assign_coords({'time': ERA5_subset['time']}) - -# # slice + load to the GPU -# static_subset_input = static_subset_input.isel( -# time=slice(0, self.history_len, self.skip_periods)).load() # <-- load into memory - -# # update -# static_subset_input['time'] = historical_ERA5_images['time'] - -# # merge -# historical_ERA5_images = historical_ERA5_images.merge(static_subset_input) - -# # ==================================================== # -# # xarray dataset as target -# ## target_ERA5_images: the final target - -# target_ERA5_images = ERA5_subset.isel(time=slice(-1, None)).load() # <-- load into memory - -# ## merge diagnoisc input here: -# if self.diagnostic_files: -# diagnostic_subset = self.diagnostic_files[int(ind_file)].isel( -# time=slice(ind_start_in_file, ind_end_in_file+1)) - -# diagnostic_subset = diagnostic_subset.isel( -# time=slice(-1, None)).load() # <-- load into 
memory - -# target_ERA5_images = target_ERA5_images.merge(diagnostic_subset) - -# # pipe xarray datasets to the sampler -# sample = Sample( -# historical_ERA5_images=historical_ERA5_images, -# target_ERA5_images=target_ERA5_images, -# datetime_index=datetime_as_number -# ) - -# # ==================================== # -# # data normalization -# if self.transform: -# sample = self.transform(sample) - -# # assign sample index -# sample["datetime"] = datetime_as_number -# sample["forecast_step"] = self.forecast_step + 1 -# sample["index"] = index -# sample["stop_forecast"] = (self.forecast_step == self.forecast_len) - -# # update the step count -# self.forecast_step += 1 - -# return sample +if __name__ == "__main__": + + import torch + import yaml + from torch.utils.data import DataLoader + from credit.transforms import load_transforms + from credit.parser import credit_main_parser, training_data_check + from credit.datasets import setup_data_loading, set_globals + + with open( + "/glade/derecho/scratch/schreck/repos/miles-credit/production/multistep/wxformer_6h/model.yml" + ) as cf: + conf = yaml.load(cf, Loader=yaml.FullLoader) + + conf = credit_main_parser( + conf, parse_training=True, parse_predict=False, print_summary=False + ) + training_data_check(conf, print_summary=False) + + data_config = setup_data_loading(conf) + + data_config["forecast_len"] = 6 + batch_size = 2 + + set_globals(data_config, namespace=globals()) + + dataset_multi = ERA5_and_Forcing_MultiStep( + varname_upper_air=data_config['varname_upper_air'], + varname_surface=data_config['varname_surface'], + varname_dyn_forcing=data_config['varname_dyn_forcing'], + varname_forcing=data_config['varname_forcing'], + varname_static=data_config['varname_static'], + varname_diagnostic=data_config['varname_diagnostic'], + filenames=data_config['all_ERA_files'], + filename_surface=data_config['surface_files'], + filename_dyn_forcing=data_config['dyn_forcing_files'], + filename_forcing=data_config['forcing_files'], + filename_static=data_config['static_files'], + filename_diagnostic=data_config['diagnostic_files'], + history_len=data_config['history_len'], + forecast_len=data_config['forecast_len'], + skip_periods=data_config['skip_periods'], + one_shot=False, + max_forecast_len=data_config['max_forecast_len'], + transform=load_transforms(conf) + ) + + dataloader = DataLoader( + dataset_multi, + batch_size=1, # Adjust the batch size as needed + shuffle=True, # Shuffle the dataset if needed + num_workers=1, # Number of subprocesses to use for data loading (adjust as needed) + drop_last=True, # Drop the last incomplete batch if not divisible by batch_size, + prefetch_factor=4 + ) + + dataloader.dataset.set_epoch(0) + for (k, sample) in enumerate(dataloader): + print(k, sample['index'], sample['datetime'], sample['forecast_step'], sample['stop_forecast']) + if k == 20: + break diff --git a/credit/datasets/era5_multistep_batcher.py b/credit/datasets/era5_multistep_batcher.py new file mode 100644 index 00000000..9b7c4a4c --- /dev/null +++ b/credit/datasets/era5_multistep_batcher.py @@ -0,0 +1,425 @@ +import torch +import logging +import numpy as np +from functools import partial +from credit.data import ( + drop_var_from_dataset, + get_forward_data +) +from credit.datasets.era5_multistep import worker + + +logger = logging.getLogger(__name__) + + +class ERA5_MultiStep_Batcher(torch.utils.data.Dataset): + """ + A Pytorch Dataset class that works on: + - upper-air variables (time, level, lat, lon) + - surface variables (time, lat, lon) + - 
dynamic forcing variables (time, lat, lon)
+        - forcing variables (time, lat, lon)
+        - diagnostic variables (time, lat, lon)
+        - static variables (lat, lon)
+    """
+
+    def __init__(
+        self,
+        varname_upper_air,
+        varname_surface,
+        varname_dyn_forcing,
+        varname_forcing,
+        varname_static,
+        varname_diagnostic,
+        filenames,
+        filename_surface=None,
+        filename_dyn_forcing=None,
+        filename_forcing=None,
+        filename_static=None,
+        filename_diagnostic=None,
+        history_len=2,
+        forecast_len=0,
+        transform=None,
+        seed=42,
+        rank=0,
+        world_size=1,
+        skip_periods=None,
+        max_forecast_len=None,
+        batch_size=1
+    ):
+        """
+        Initialize the ERA5_MultiStep_Batcher.
+
+        Parameters:
+        - varname_upper_air (list): List of upper air variable names.
+        - varname_surface (list): List of surface variable names.
+        - varname_dyn_forcing (list): List of dynamic forcing variable names.
+        - varname_forcing (list): List of forcing variable names.
+        - varname_static (list): List of static variable names.
+        - varname_diagnostic (list): List of diagnostic variable names.
+        - filenames (list): List of filenames for upper air data.
+        - filename_surface (list, optional): List of filenames for surface data.
+        - filename_dyn_forcing (list, optional): List of filenames for dynamic forcing data.
+        - filename_forcing (str, optional): Filename for forcing data.
+        - filename_static (str, optional): Filename for static data.
+        - filename_diagnostic (list, optional): List of filenames for diagnostic data.
+        - history_len (int, optional): Length of the history sequence. Default is 2.
+        - forecast_len (int, optional): Length of the forecast sequence. Default is 0.
+        - transform (callable, optional): Transformation function to apply to the data.
+        - seed (int, optional): Random seed for reproducibility. Default is 42.
+        - rank (int, optional): Rank of the current process. Default is 0.
+        - world_size (int, optional): Number of processes in the job. Default is 1.
+        - skip_periods (int, optional): Number of periods to skip between samples.
+        - max_forecast_len (int, optional): Maximum length of the forecast sequence.
+        - batch_size (int, optional): Number of forecast sequences per batch. Default is 1.
+
+        Each __getitem__ call returns a batch dictionary containing
+        historical_ERA5_images, target_ERA5_images, the datetime index,
+        and additional bookkeeping information.
+        """
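+
+        # Windowing arithmetic, illustrated: with history_len=2, forecast_len=6,
+        # and skip_periods=1, total_seq_len = history_len + forecast_len = 8.
+        # Each call advances the window by one step: step 0 of a sequence uses
+        # [t0, t1] as input with t2 as target, step 1 uses [t1, t2] -> t3, and
+        # so on, until forecast_len + 1 steps have been served and the batch is
+        # re-seeded.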
+ """ + + self.history_len = history_len + self.forecast_len = forecast_len + self.transform = transform + self.seed = seed + self.rank = rank + self.world_size = world_size + + # skip periods + self.skip_periods = skip_periods + if self.skip_periods is None: + self.skip_periods = 1 + + # total number of needed forecast lead times + self.total_seq_len = self.history_len + self.forecast_len + + # set random seed + self.rng = np.random.default_rng(seed=seed) + + # max possible forecast len + self.max_forecast_len = max_forecast_len + + # =================================================================== # + # flags to determin if any of the [surface, dyn_forcing, diagnostics] + # variable groups share the same file as upper air variables + flag_share_surf = False + flag_share_dyn = False + flag_share_diag = False + + all_files = [] + filenames = sorted(filenames) + + # ------------------------------------------------------------------ # + # blocks that can handle no-sharing (each group has it own file) + # surface + if filename_surface is not None: + surface_files = [] + filename_surface = sorted(filename_surface) + + if filenames == filename_surface: + flag_share_surf = True + else: + for fn in filename_surface: + # drop variables if they are not in the config + ds = get_forward_data(filename=fn) + ds_surf = drop_var_from_dataset(ds, varname_surface) + surface_files.append(ds_surf) + + self.surface_files = surface_files + else: + self.surface_files = False + + # dynamic forcing + if filename_dyn_forcing is not None: + dyn_forcing_files = [] + filename_dyn_forcing = sorted(filename_dyn_forcing) + + if filenames == filename_dyn_forcing: + flag_share_dyn = True + else: + for fn in filename_dyn_forcing: + # drop variables if they are not in the config + ds = get_forward_data(filename=fn) + ds_dyn = drop_var_from_dataset(ds, varname_dyn_forcing) + dyn_forcing_files.append(ds_dyn) + + self.dyn_forcing_files = dyn_forcing_files + else: + self.dyn_forcing_files = False + + # diagnostics + if filename_diagnostic is not None: + diagnostic_files = [] + filename_diagnostic = sorted(filename_diagnostic) + + if filenames == filename_diagnostic: + flag_share_diag = True + else: + for fn in filename_diagnostic: + # drop variables if they are not in the config + ds = get_forward_data(filename=fn) + ds_diag = drop_var_from_dataset(ds, varname_diagnostic) + diagnostic_files.append(ds_diag) + + self.diagnostic_files = diagnostic_files + else: + self.diagnostic_files = False + + # ------------------------------------------------------------------ # + # blocks that can handle file sharing (share with upper air file) + for fn in filenames: + # drop variables if they are not in the config + ds = get_forward_data(filename=fn) + ds_upper = drop_var_from_dataset(ds, varname_upper_air) + + if flag_share_surf: + ds_surf = drop_var_from_dataset(ds, varname_surface) + surface_files.append(ds_surf) + + if flag_share_dyn: + ds_dyn = drop_var_from_dataset(ds, varname_dyn_forcing) + dyn_forcing_files.append(ds_dyn) + + if flag_share_diag: + ds_diag = drop_var_from_dataset(ds, varname_diagnostic) + diagnostic_files.append(ds_diag) + + all_files.append(ds_upper) + + self.all_files = all_files + + if flag_share_surf: + self.surface_files = surface_files + if flag_share_dyn: + self.dyn_forcing_files = dyn_forcing_files + if flag_share_diag: + self.diagnostic_files = diagnostic_files + + # -------------------------------------------------------------------------- # + # get sample indices from ERA5 upper-air files: + ind_start = 
+        # -------------------------------------------------------------------------- #
+        # get sample indices from ERA5 upper-air files:
+        ind_start = 0
+        self.ERA5_indices = {}
+        for ind_file, ERA5_xarray in enumerate(self.all_files):
+            # [number of samples, ind_start, ind_end]
+            self.ERA5_indices[str(ind_file)] = [
+                len(ERA5_xarray["time"]),
+                ind_start,
+                ind_start + len(ERA5_xarray["time"]),
+            ]
+            ind_start += len(ERA5_xarray["time"]) + 1
+
+        # ======================================================== #
+        # forcing file
+        self.filename_forcing = filename_forcing
+
+        if self.filename_forcing is not None:
+            # drop variables if they are not in the config
+            xarray_dataset = get_forward_data(filename_forcing)
+            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_forcing)
+
+            self.xarray_forcing = xarray_dataset
+        else:
+            self.xarray_forcing = False
+
+        # ======================================================== #
+        # static file
+        self.filename_static = filename_static
+
+        if self.filename_static is not None:
+            # drop variables if they are not in the config
+            xarray_dataset = get_forward_data(filename_static)
+            xarray_dataset = drop_var_from_dataset(xarray_dataset, varname_static)
+
+            self.xarray_static = xarray_dataset
+        else:
+            self.xarray_static = False
+
+        self.worker = partial(
+            worker,
+            ERA5_indices=self.ERA5_indices,
+            all_files=self.all_files,
+            surface_files=self.surface_files,
+            dyn_forcing_files=self.dyn_forcing_files,
+            diagnostic_files=self.diagnostic_files,
+            xarray_forcing=self.xarray_forcing,
+            xarray_static=self.xarray_static,
+            history_len=self.history_len,
+            forecast_len=self.forecast_len,
+            skip_periods=self.skip_periods,
+            transform=self.transform,
+        )
+
+        self.total_length = len(self.ERA5_indices)
+        self.current_epoch = None
+        self.current_index = None
+        self.initial_index = None
+
+        # Initialize state variables for batch management
+        self.batch_size = batch_size
+        self.batch_indices = None  # To track initial indices for each batch item
+        self.time_steps = None  # Tracks time steps for each batch index
+        self.forecast_step_counts = None  # Track forecast step counts for each batch item
+        self.initial_indices = None  # Tracks the initial index for each forecast item in the batch
+
+        # Initialize batch once when the dataset is created
+        self.initialize_batch()
+
+    def initialize_batch(self):
+        """
+        Initializes random starting indices for each batch item and resets their
+        time steps and forecast counts. This must be called before accessing the
+        dataset or when resetting the batch.
+        """
+        # Randomly sample indices for the batch
+        self.batch_indices = np.random.choice(
+            range(self.__len__() - self.forecast_len),
+            size=self.batch_size,
+            replace=False,
+        )
+        self.time_steps = [0 for idx in self.batch_indices]  # Initialize time to 0 for each item
+        self.forecast_step_counts = [0 for idx in self.batch_indices]  # Initialize forecast step counts
+        self.initial_indices = list(self.batch_indices)  # Track initial indices for each batch item
+
+    def __post_init__(self):
+        # Total sequence length of each sample.
+        self.total_seq_len = self.history_len + self.forecast_len
+
+    def __len__(self):
+        # compute the total number of samples across all files
+        total_len = 0
+        for ERA5_xarray in self.all_files:
+            total_len += len(ERA5_xarray["time"]) - self.total_seq_len + 1
+        return total_len
+
+    def set_epoch(self, epoch):
+        self.current_epoch = epoch
+        self.current_index = None
+        self.initial_index = None
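+
+    # Stepping protocol implemented in __getitem__ below: each batch slot k
+    # keeps a fixed start index batch_indices[k] and a running offset
+    # time_steps[k]. The pair (initial_indices[k], batch_indices[k] +
+    # time_steps[k]) is handed to the shared `worker` from era5_multistep,
+    # which slices out one history window and its target. Once
+    # forecast_step_counts[0] reaches forecast_len + 1, initialize_batch()
+    # re-seeds every slot, so the sampler index supplied by the DataLoader is
+    # intentionally ignored (hence the `_` argument).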
+ """ + batch = {} + + # If the forecast_step_count exceeds forecast_len, reset the item + # If one exceeds, they all exceed and all neet reset + if self.forecast_step_counts[0] == self.forecast_len + 1: + # Get a new starting index for this item (randomly selected) + self.initialize_batch() + + for k, idx in enumerate(self.batch_indices): + # Get the current time step for this batch item + current_t = self.time_steps[k] + initial_idx = self.initial_indices[k] # Correctly find the initial index + index_pair = (initial_idx, idx + current_t) # Correctly construct the index_pair + + # Fetch the current sample for this batch item + sample = self.worker(index_pair) + + # Add index to the sample + sample["index"] = idx + + # Concatenate data by common keys in sample + for key, value in sample.items(): + if isinstance(value, np.ndarray): # If it's a numpy array + value = torch.tensor(value) + elif isinstance(value, np.int64): # If it's a numpy scalar (int64) + value = torch.tensor(value, dtype=torch.int64) + elif isinstance(value, (int, float)): # If it's a native Python scalar (int/float) + value = torch.tensor(value, dtype=torch.float32) # Ensure tensor is float for scalar + elif not isinstance(value, torch.Tensor): # If it's not already a tensor + value = torch.tensor(value) # Ensure conversion to tensor + + # Convert zero-dimensional tensor (scalar) to 1D tensor + if value.ndimension() == 0: + value = value.unsqueeze(0) # Unsqueeze to make it a 1D tensor + + if key not in batch: + batch[key] = value # Initialize the key in the batch dictionary + else: + batch[key] = torch.cat((batch[key], value), dim=0) # Concatenate values along the batch dimension + + # Increment time steps and forecast step counts for this batch item + self.time_steps[k] += 1 + self.forecast_step_counts[k] += 1 + + batch["forecast_step"] = self.forecast_step_counts[0] + batch["stop_forecast"] = batch["forecast_step"] == self.forecast_len + 1 + batch["datetime"] = batch["datetime"].view(-1, self.batch_size) # reshape + + return batch + + +if __name__ == "__main__": + + import logging + import torch + import yaml + from torch.utils.data import DataLoader + from credit.transforms import load_transforms + from credit.parser import credit_main_parser, training_data_check + from credit.datasets import setup_data_loading, set_globals + + # Set up the logger + logging.basicConfig( + level=logging.INFO, # Set the logging level + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + logger = logging.getLogger(__name__) # Create a logger with the module name + + with open( + "/glade/derecho/scratch/schreck/repos/miles-credit/production/multistep/wxformer_6h/model.yml" + ) as cf: + conf = yaml.load(cf, Loader=yaml.FullLoader) + + conf = credit_main_parser( + conf, parse_training=True, parse_predict=False, print_summary=False + ) + training_data_check(conf, print_summary=False) + data_config = setup_data_loading(conf) + + batch_size = 2 + data_config["forecast_len"] = 6 + + set_globals(data_config, namespace=globals()) + + # globals().update(data_config) + # for key, value in data_config.items(): + # globals()[key] = value + # logger.info(f"Creating global variable in the namespace: {key}") + + dataset_multi = ERA5_MultiStep_Batcher( + varname_upper_air=data_config['varname_upper_air'], + varname_surface=data_config['varname_surface'], + varname_dyn_forcing=data_config['varname_dyn_forcing'], + varname_forcing=data_config['varname_forcing'], + varname_static=data_config['varname_static'], + 
+
+if __name__ == "__main__":
+
+    import logging
+    import torch
+    import yaml
+    from torch.utils.data import DataLoader
+    from credit.transforms import load_transforms
+    from credit.parser import credit_main_parser, training_data_check
+    from credit.datasets import setup_data_loading, set_globals
+
+    # Set up the logger
+    logging.basicConfig(
+        level=logging.INFO,  # Set the logging level
+        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+    )
+    logger = logging.getLogger(__name__)  # Create a logger with the module name
+
+    with open(
+        "/glade/derecho/scratch/schreck/repos/miles-credit/production/multistep/wxformer_6h/model.yml"
+    ) as cf:
+        conf = yaml.load(cf, Loader=yaml.FullLoader)
+
+    conf = credit_main_parser(
+        conf, parse_training=True, parse_predict=False, print_summary=False
+    )
+    training_data_check(conf, print_summary=False)
+    data_config = setup_data_loading(conf)
+
+    batch_size = 2
+    data_config["forecast_len"] = 6
+
+    set_globals(data_config, namespace=globals())
+
+    dataset_multi = ERA5_MultiStep_Batcher(
+        varname_upper_air=data_config['varname_upper_air'],
+        varname_surface=data_config['varname_surface'],
+        varname_dyn_forcing=data_config['varname_dyn_forcing'],
+        varname_forcing=data_config['varname_forcing'],
+        varname_static=data_config['varname_static'],
+        varname_diagnostic=data_config['varname_diagnostic'],
+        filenames=data_config['all_ERA_files'],
+        filename_surface=data_config['surface_files'],
+        filename_dyn_forcing=data_config['dyn_forcing_files'],
+        filename_forcing=data_config['forcing_files'],
+        filename_static=data_config['static_files'],
+        filename_diagnostic=data_config['diagnostic_files'],
+        history_len=data_config['history_len'],
+        forecast_len=data_config['forecast_len'],
+        skip_periods=data_config['skip_periods'],
+        max_forecast_len=data_config['max_forecast_len'],
+        transform=load_transforms(conf),
+        batch_size=batch_size
+    )
+
+    dataloader = DataLoader(
+        dataset_multi,
+        batch_size=1,      # keep at 1: batching happens inside the dataset
+        shuffle=False,     # the dataset draws its own random start indices
+        num_workers=1,     # number of subprocesses for data loading (adjust as needed)
+        drop_last=True,    # drop the last incomplete batch if not divisible by batch_size
+        prefetch_factor=4
+    )
+
+    dataloader.dataset.set_epoch(0)
+    for (k, sample) in enumerate(dataloader):
+        print(k, sample['index'], sample['datetime'], sample['forecast_step'], sample['stop_forecast'])
+        if k == 20:
+            break