diff --git a/README.md b/README.md
index 9a3c8fb..bbc5f74 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,9 @@ metnet3 = MetNet3(
         omo_wind_component_y = 256,
         omo_wind_direction = 180
     ),
-    hrrr_loss_weight = 10
+    hrrr_loss_weight = 10,
+    hrrr_norm_strategy = 'sync_batchnorm',   # use a sync batchnorm to normalize the input hrrr and target, without having to precalculate the per-channel mean and variance of the hrrr dataset
+    hrrr_norm_statistics = None              # you can also set `hrrr_norm_strategy = "precalculated"` and pass in the mean and variance as a tensor of shape `(2, 617)` through this keyword argument
 )

 # inputs
@@ -107,8 +109,8 @@ surface_preds, hrrr_pred, precipitation_preds = metnet3(
 - [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
 - [x] allow researcher to pass in their own normalization variables for HRRR
 - [x] build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
+- [x] make sure model can be easily saved and loaded, with different ways of handling hrrr norm

-- [ ] make sure model can be easily saved and loaded, with different ways of handling hrrr norm
 - [ ] figure out the topological embedding, consult a neural weather researcher

 ## Citations
diff --git a/metnet3_pytorch/metnet3_pytorch.py b/metnet3_pytorch/metnet3_pytorch.py
index e0f5c92..7a0af9c 100644
--- a/metnet3_pytorch/metnet3_pytorch.py
+++ b/metnet3_pytorch/metnet3_pytorch.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from contextlib import contextmanager
 from functools import partial
 from collections import namedtuple
@@ -13,7 +14,9 @@ from einops.layers.torch import Rearrange, Reduce

 from beartype import beartype
-from beartype.typing import Tuple, Union, List, Optional, Dict
+from beartype.typing import Tuple, Union, List, Optional, Dict, Literal
+
+import pickle

 # helpers
@@ -609,6 +612,11 @@ def __init__(
             omo_wind_component_y = 256,
             omo_wind_direction = 180
         ),
+        hrrr_norm_strategy: Union[
+            Literal['none'],
+            Literal['precalculated'],
+            Literal['sync_batchnorm']
+        ] = 'none',
         hrrr_channels = 617,
         hrrr_norm_statistics: Optional[Tensor] = None,
         hrrr_loss_weight = 10,
@@ -616,6 +624,15 @@ def __init__(
         resnet_block_depth = 2,
     ):
         super().__init__()
+
+        # for autosaving the config
+
+        _locals = locals()
+        _locals.pop('self', None)
+        _locals.pop('__class__', None)
+        _locals.pop('hrrr_norm_statistics', None)
+        self._configs = pickle.dumps(_locals)
+
         self.hrrr_input_2496_shape = (hrrr_channels, input_spatial_size, input_spatial_size)
         self.input_2496_shape = (input_2496_channels, input_spatial_size, input_spatial_size)
         self.input_4996_shape = (input_4996_channels, input_spatial_size, input_spatial_size)
@@ -732,15 +749,60 @@ def __init__(

         self.hrrr_loss_weight = hrrr_loss_weight / hrrr_channels

-        self.has_hrrr_norm_statistics = exists(hrrr_norm_statistics)
+        self.mse_loss_scaler = LossScaler()
+
+        # norm statistics

-        if self.has_hrrr_norm_statistics:
-            assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {normed_hrrr_target}), containing mean and variance of each target calculated from the dataset'
+        default_hrrr_statistics = torch.empty((2, hrrr_channels), dtype = torch.float32)
+
+        if hrrr_norm_strategy == 'none':
+            self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)
+
+        elif hrrr_norm_strategy == 'precalculated':
+            assert exists(hrrr_norm_statistics), 'hrrr_norm_statistics must be passed in, if normalizing input hrrr as well as target with precalculated dataset mean and variance'
+            assert hrrr_norm_statistics.shape == (2, hrrr_channels), f'normalization statistics must be of shape (2, {hrrr_channels}), containing mean and variance of each target calculated from the dataset'
             self.register_buffer('hrrr_norm_statistics', hrrr_norm_statistics)
-        else:
+
+        elif hrrr_norm_strategy == 'sync_batchnorm':
+            self.register_buffer('hrrr_norm_statistics', default_hrrr_statistics, persistent = False)
             self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_channels, affine = False)

-        self.mse_loss_scaler = LossScaler()
+        self.hrrr_norm_strategy = hrrr_norm_strategy
+
+    @classmethod
+    def init_and_load_from(cls, path, strict = True):
+        path = Path(path)
+        assert path.exists()
+        pkg = torch.load(str(path), map_location = 'cpu')
+
+        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'
+
+        config = pickle.loads(pkg['config'])
+        model = cls(**config)
+        model.load(path, strict = strict)
+        return model
+
+    def save(self, path, overwrite = True):
+        path = Path(path)
+        assert overwrite or not path.exists(), f'{str(path)} already exists'
+
+        pkg = dict(
+            model_state_dict = self.state_dict(),
+            config = self._configs
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path, strict = True):
+        path = Path(path)
+        assert path.exists()
+
+        pkg = torch.load(str(path), map_location = 'cpu')
+        state_dict = pkg.get('model_state_dict')
+
+        assert exists(state_dict)
+
+        self.load_state_dict(state_dict, strict = strict)

     @beartype
     def forward(
@@ -763,9 +825,9 @@ def forward(
         assert input_2496.shape[1:] == self.input_2496_shape
         assert input_4996.shape[1:] == self.input_4996_shape

-        # normalize HRRR input and target as needed
+        # normalize HRRR input and target, if needed

-        if self.has_hrrr_norm_statistics:
+        if self.hrrr_norm_strategy == 'precalculated':
             mean, variance = self.hrrr_norm_statistics
             mean = rearrange(mean, 'c -> c 1 1')
             variance = rearrange(variance, 'c -> c 1 1')
@@ -776,7 +838,7 @@ def forward(
             if exists(hrrr_target):
                 normed_hrrr_target = (hrrr_target - mean) * inv_std

-        else:
+        elif self.hrrr_norm_strategy == 'sync_batchnorm':
             # use a batchnorm to normalize each channel to mean zero and unit variance

             with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
@@ -785,6 +847,12 @@ def forward(
             if exists(hrrr_target):
                 normed_hrrr_target = frozen_batchnorm(hrrr_target)

+        elif self.hrrr_norm_strategy == 'none':
+            normed_hrrr_input = hrrr_input_2496
+
+            if exists(hrrr_target):
+                normed_hrrr_target = hrrr_target
+
         # main network

         cond = self.lead_time_embedding(lead_times)
@@ -899,7 +967,7 @@ def forward(

         # update hrrr normalization statistics, if using batchnorm way

-        if not self.has_hrrr_norm_statistics and self.training:
+        if self.training and self.hrrr_norm_strategy == 'sync_batchnorm':
             _ = self.batchnorm_hrrr(hrrr_target)

         # total loss
diff --git a/setup.py b/setup.py
index 0c0a22e..4f6e21b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'metnet3-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.11',
   license='MIT',
   description = 'MetNet 3 - Pytorch',
   author = 'Phil Wang',
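For reference, a minimal sketch of how the three `hrrr_norm_strategy` options in this diff are selected at construction time. The statistics file path is hypothetical, and the remaining constructor kwargs (elided as a comment) follow the README example above:

```python
import torch
from metnet3_pytorch import MetNet3

# 'sync_batchnorm': running mean / variance of the hrrr target are tracked
# during training with a (sync) batchnorm, so no dataset statistics are
# needed up front

metnet3 = MetNet3(
    # ... remaining kwargs as in the README example above ...
    hrrr_norm_strategy = 'sync_batchnorm'
)

# 'precalculated': per-channel mean and variance computed offline from the
# hrrr dataset, passed in as a tensor of shape (2, 617), 617 being the
# default hrrr_channels

hrrr_stats = torch.load('./hrrr_mean_variance.pt')  # hypothetical path

metnet3 = MetNet3(
    # ... remaining kwargs ...
    hrrr_norm_strategy = 'precalculated',
    hrrr_norm_statistics = hrrr_stats
)

# 'none' (the default): hrrr input and target are used unnormalized
```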
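And a sketch of the save / load round trip this diff adds (checkpoint path hypothetical). One caveat that follows from the code above: `hrrr_norm_statistics` is popped before the config is pickled, so `init_and_load_from` can only rebuild models constructed with the `'none'` or `'sync_batchnorm'` strategies; for `'precalculated'`, reconstruct the model with its statistics first and then load the weights with `.load`:

```python
# save the state dict plus the pickled config in one checkpoint file

metnet3.save('./checkpoints/metnet3.pt')

# for 'none' / 'sync_batchnorm', the config stored in the checkpoint is
# enough to rebuild the model before the weights are loaded into it

metnet3 = MetNet3.init_and_load_from('./checkpoints/metnet3.pt')

# for 'precalculated', the statistics tensor is not in the pickled config,
# so construct the model manually, then load the state dict (which contains
# the persistent hrrr_norm_statistics buffer) into it

metnet3 = MetNet3(
    # ... remaining kwargs ...
    hrrr_norm_strategy = 'precalculated',
    hrrr_norm_statistics = hrrr_stats
)
metnet3.load('./checkpoints/metnet3.pt')
```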