Skip to content

Commit

Permalink
working on device bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Oct 23, 2024
1 parent 4f1c74b commit 7ba2972
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
2 changes: 2 additions & 0 deletions credit/models/crossformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,12 @@ def __init__(
apply_spectral_norm(self)

if freeze_base_model_weights:
logger.warning("freezing all base model weights due to skebs config")
for param in self.parameters():
param.requires_grad = False

if self.use_post_block:
logger.info("using postblock")
self.postblock = PostBlock(post_conf)

def forward(self, x):
Expand Down
7 changes: 5 additions & 2 deletions credit/postblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,16 +660,19 @@ def initialize_pattern(self, y_pred):
persistent=False)

# initialize pattern todo: how many iters?
for _ in range(10):
for i in range(10):
logger.debug(f'cycle {i}')
self.spec_coef = self.cycle_pattern(self.spec_coef)

def cycle_pattern(self, spec_coef):
Gamma = torch.sum(self.lrange * (self.lrange + 1.0) * (self.lrange + 2.0) * self.lrange ** (2.0 * self.p)) # scalar
b = torch.sqrt((4.0 * PI * RAD_EARTH**2.0) / (self.variance * Gamma) * self.alpha * self.dE) # scalar
g_n = b * self.lrange ** self.p # (lmax, 1)
logger.debug(f"g_n: {g_n.shape}")
noise = self.variance * torch.randn(spec_coef.shape) # (b, 1, 1, lmax, mmax) std normal noise diff for all n?

device = spec_coef.device
logger.info(f"spec_coef device: {device}")
noise = self.variance * torch.randn(self.spec_coef.shape, device=device) # (b, 1, 1, lmax, mmax) std normal noise diff for all n?
new_coef = (1.0 - self.alpha) * spec_coef + g_n * torch.sqrt(self.alpha) * noise # (lmax, mmax)
return new_coef

Expand Down
23 changes: 20 additions & 3 deletions tests/test_postblock.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest
import yaml
import os
import logging

import torch
from credit.models.crossformer import CrossFormer
from credit.postblock import PostBlock, Backscatter_FCNN
from credit.postblock import SKEBS, TracerFixer, GlobalMassFixer, GlobalEnergyFixer
from credit.parser import CREDIT_main_parser


TEST_FILE_DIR = "/".join(os.path.abspath(__file__).split("/")[:-1])
CONFIG_FILE_DIR = os.path.join("/".join(os.path.abspath(__file__).split("/")[:-2]),
"config")
Expand All @@ -18,6 +20,7 @@ def test_SKEBS_integration():
integration testing to make sure everything goes on GPU, is loaded properly etc
requires loading weights
'''
logging.info("integration testing SKEBS")
config = os.path.join(CONFIG_FILE_DIR, "example_skebs.yml")
with open(config) as cf:
conf = yaml.load(cf, Loader=yaml.FullLoader)
Expand All @@ -42,9 +45,12 @@ def test_SKEBS_integration():
y_pred[:, sp_index] = torch.ones_like(y_pred[:, sp_index]) * 1013

model = CrossFormer(**conf["model"])
model.to("cpu")
device = torch.device(f"cuda:{1 % torch.cuda.device_count()}") if torch.cuda.is_available() else torch.device("cpu")
model = model.load_model(conf)
pred = model(x)
model.to(device)
logging.info(f"model: {device}")

pred = model(x.to(device))

assert pred.shape == y_pred.shape

Expand Down Expand Up @@ -244,5 +250,16 @@ def test_SKEBS_era5():
pass

if __name__ == "__main__":
# test_SKEBS_integration()
# Set up logger to print stuff
root = logging.getLogger()
root.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")

# Stream output to stdout
ch = logging.StreamHandler()
# ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
root.addHandler(ch)

test_SKEBS_integration()
test_SKEBS_rand()

0 comments on commit 7ba2972

Please sign in to comment.