From 7ba2972ac68b88c84605b807ca5ab69923f462fe Mon Sep 17 00:00:00 2001 From: dkimpara Date: Wed, 23 Oct 2024 16:27:08 -0600 Subject: [PATCH] working on device bugs --- credit/models/crossformer.py | 2 ++ credit/postblock.py | 7 +++++-- tests/test_postblock.py | 23 ++++++++++++++++++++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/credit/models/crossformer.py b/credit/models/crossformer.py index 31f523f..525a5f0 100644 --- a/credit/models/crossformer.py +++ b/credit/models/crossformer.py @@ -438,10 +438,12 @@ def __init__( apply_spectral_norm(self) if freeze_base_model_weights: + logger.warning("freezing all base model weights due to skebs config") for param in self.parameters(): param.requires_grad = False if self.use_post_block: + logger.info("using postblock") self.postblock = PostBlock(post_conf) def forward(self, x): diff --git a/credit/postblock.py b/credit/postblock.py index 984589d..a4c8425 100644 --- a/credit/postblock.py +++ b/credit/postblock.py @@ -660,7 +660,8 @@ def initialize_pattern(self, y_pred): persistent=False) # initialize pattern todo: how many iters? - for _ in range(10): + for i in range(10): + logger.debug(f'cycle {i}') self.spec_coef = self.cycle_pattern(self.spec_coef) def cycle_pattern(self, spec_coef): @@ -668,8 +669,10 @@ def cycle_pattern(self, spec_coef): b = torch.sqrt((4.0 * PI * RAD_EARTH**2.0) / (self.variance * Gamma) * self.alpha * self.dE) # scalar g_n = b * self.lrange ** self.p # (lmax, 1) logger.debug(f"g_n: {g_n.shape}") - noise = self.variance * torch.randn(spec_coef.shape) # (b, 1, 1, lmax, mmax) std normal noise diff for all n? + device = spec_coef.device + logger.info(f"spec_coef device: {device}") + noise = self.variance * torch.randn(self.spec_coef.shape, device=device) # (b, 1, 1, lmax, mmax) std normal noise diff for all n? new_coef = (1.0 - self.alpha) * spec_coef + g_n * torch.sqrt(self.alpha) * noise # (lmax, mmax) return new_coef diff --git a/tests/test_postblock.py b/tests/test_postblock.py index 67b91a3..c4c40ae 100644 --- a/tests/test_postblock.py +++ b/tests/test_postblock.py @@ -1,6 +1,7 @@ import pytest import yaml import os +import logging import torch from credit.models.crossformer import CrossFormer @@ -8,6 +9,7 @@ from credit.postblock import SKEBS, TracerFixer, GlobalMassFixer, GlobalEnergyFixer from credit.parser import CREDIT_main_parser + TEST_FILE_DIR = "/".join(os.path.abspath(__file__).split("/")[:-1]) CONFIG_FILE_DIR = os.path.join("/".join(os.path.abspath(__file__).split("/")[:-2]), "config") @@ -18,6 +20,7 @@ def test_SKEBS_integration(): integration testing to make sure everything goes on GPU, is loaded properly etc requires loading weights ''' + logging.info("integration testing SKEBS") config = os.path.join(CONFIG_FILE_DIR, "example_skebs.yml") with open(config) as cf: conf = yaml.load(cf, Loader=yaml.FullLoader) @@ -42,9 +45,12 @@ def test_SKEBS_integration(): y_pred[:, sp_index] = torch.ones_like(y_pred[:, sp_index]) * 1013 model = CrossFormer(**conf["model"]) - model.to("cpu") + device = torch.device(f"cuda:{1 % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu") model = model.load_model(conf) - pred = model(x) + model.to(device) + logging.info(f"model: {device}") + + pred = model(x.to(device)) assert pred.shape == y_pred.shape @@ -244,5 +250,16 @@ def test_SKEBS_era5(): pass if __name__ == "__main__": - # test_SKEBS_integration() + # Set up logger to print stuff + root = logging.getLogger() + root.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s") + + # Stream output to stdout + ch = logging.StreamHandler() + # ch.setLevel(logging.INFO) + ch.setFormatter(formatter) + root.addHandler(ch) + + test_SKEBS_integration() test_SKEBS_rand() \ No newline at end of file