Skip to content

Commit

Permalink
update unet init and call in models/__init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Oct 17, 2024
1 parent 7357df0 commit 998c1e3
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 82 deletions.
106 changes: 61 additions & 45 deletions config/unet_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,53 @@
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# --------------------------------------------------------------------------------------------------------------------- #
save_loc: '/glade/work/$USER/CREDIT_runs/post_test/'
save_loc: '/glade/work/$USER/CREDIT_runs/test_1dg_unet/'
seed: 1000

data:
# upper-air variables
variables: ['U','V','T','Q']
save_loc: '/glade/derecho/scratch/wchapman/y_TOTAL*'
save_loc: '/glade/derecho/scratch/dkimpara/y_ONEdeg*.zarr'

# surface variables
surface_variables: ['SP','t2m', 'V500','U500','T500','Z500','Q500']
save_loc_surface: '/glade/derecho/scratch/wchapman/y_TOTAL*'
save_loc_surface: '/glade/derecho/scratch/dkimpara/y_ONEdeg*.zarr'

# dynamic forcing variables
dynamic_forcing_variables: [] #['tsi']
save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_1h_0.25deg/*.nc'
dynamic_forcing_variables: ['tsi']
save_loc_dynamic_forcing: '/glade/derecho/scratch/dkimpara/credit_solar_1h_1deg/*.nc'

# static variables
static_variables: [] #['Z_GDS4_SFC','LSM']
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
static_variables: ['Z_GDS4_SFC','LSM']
save_loc_static: '/glade/derecho/scratch/dkimpara/LSM_static_variables_ERA5_zhght_onedeg.nc'

# mean / std path
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_1h_1979_2018_16lev_0.25deg.nc'
# regular z-score version
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_1h_1979_2018_16lev_0.25deg.nc'
# std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_1h_1979_2018_16lev_0.25deg.nc'
# residual norm version
# std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_1h_1979_2018_16lev_0.25deg.nc'
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_1h_1979_2018_16lev_0.25deg.nc'

# train / validation split
train_years: [1979, 2018]
valid_years: [2018, 2019]

# data workflow
scaler_type: 'std_new'

# state-in-state-out
history_len: 2
valid_history_len: 2

# single step

forecast_len: 0
valid_forecast_len: 0

one_shot: True

# 1 for hourly model
lead_time_periods: 1

# do not use skip_period
skip_periods: null

# compatible with the old 'std'
Expand All @@ -59,14 +64,14 @@ trainer:

mode: none
cpu_offload: False
activation_checkpoint: False
activation_checkpoint: True

load_weights: False
load_optimizer: False
load_scaler: False
load_sheduler: False
skip_validation: True

skip_validation: False
update_learning_rate: False

save_backup_weights: True
Expand All @@ -78,41 +83,52 @@ trainer:
train_batch_size: 1
valid_batch_size: 1

batches_per_epoch: 1 # use 50% of full epoch; full epoch = 10681(batches) * 32(samples)
valid_batches_per_epoch: 0
batches_per_epoch: 1 # Total number of samples = 341,880
valid_batches_per_epoch: 1
stopping_patience: 999

start_epoch: 0
# num_epoch: 6 # optional
num_epoch: 1
reload_epoch: True
epochs: &epochs 1
epochs: &epochs 70

use_scheduler: True
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}

# Automatic Mixed Precision: False
amp: False
grad_accum_every: 1

# rescale loss as loss = loss / grad_accum_every
grad_accum_every: 1
# gradient clipping
grad_max_norm: 1.0

# number of workers
thread_workers: 4
valid_thread_workers: 0

# compile
# compile: True

model:
type: "unet"
image_height: 640
image_width: 1280
image_height: 192 #640
image_width: 288 #1280
frames: 2
channels: 4
surface_channels: 7
input_only_channels: 3
output_only_channels: 0
levels: 16
rk4_integration: False
architecture:
name: "unet"
encoder_name: "resnet34"
encoder_weights: "imagenet"
post_conf:
activate: True
activate: False
skebs:
activate: True
activate: False

tracer_fixer:
activate: False
Expand All @@ -133,7 +149,7 @@ loss:

# use latitude weighting
use_latitude_weights: True
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
latitude_weights: "/glade/derecho/scratch/dkimpara/LSM_static_variables_ERA5_zhght_onedeg.nc"

# turn-off variable weighting
use_variable_weights: False
Expand All @@ -159,25 +175,25 @@ predict:
# deprecated
# save_format: "nc"

# pbs: # casper
# conda: "/glade/work/ksha/miniconda3/envs/credit"
# job_name: 'unet'
# nodes: 1
# ncpus: 8
# ngpus: 1
# mem: '128GB'
# walltime: '24:00:00'
# gpu_type: 'a100'
# project: 'NAML0001'
# queue: 'casper'

pbs: #derecho
conda: "credit-derecho"
project: "NAML0001"
job_name: "unet_skebs"
walltime: "00:30:00"
pbs: # casper
conda: "$HOME/credit"
job_name: 'test_1dg_unet'
project: 'NAML0001'
nodes: 1
ncpus: 64
ngpus: 4
mem: '480GB'
queue: 'main'
ncpus: 8
ngpus: 1
mem: '64GB'
walltime: '00:10:00'
gpu_type: 'a100'
queue: 'casper'

# pbs: #derecho
# conda: "credit-derecho"
# project: "NAML0001"
# job_name: "unet_data"
# walltime: "12:00:00"
# nodes: 1
# ncpus: 64
# ngpus: 4
# mem: '480GB'
# queue: 'main'
3 changes: 1 addition & 2 deletions credit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def load_model(conf, load_weights=False):

model_type = model_conf.pop("type")

#if model_type == 'unet':
if model_type in ('unet', 'unet404'):
if model_type in ('unet404',):
import torch
model, message = model_types[model_type]
logger.info(message)
Expand Down
80 changes: 51 additions & 29 deletions credit/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import copy
import os
import torch.nn.functional as F
#from credit.models.base_model import BaseModel
from torch import nn
from credit.models.base_model import BaseModel
from credit.postblock import PostBlock

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
Expand Down Expand Up @@ -40,46 +41,67 @@ def load_premade_encoder_model(model_conf):
f"Model name {name} not recognized. Please choose from {supported_models.keys()}")


class SegmentationModel(torch.nn.Module):

def __init__(self, conf):
class SegmentationModel(BaseModel):

def __init__(self,
image_height=640,
image_width=1280,
frames=2,
channels=4,
surface_channels=7,
input_only_channels=3,
output_only_channels=0,
levels=16,
rk4_integration=False,
architecture=
{
"name": "unet",
"encoder_name": "resnet34",
"encoder_weights": "imagenet",
},
post_conf={"use_skebs": False},
**kwargs
):

super(SegmentationModel, self).__init__()

self.variables = len(conf["data"]["variables"])
self.levels = conf["model"]["levels"]
self.frames = conf["model"]["frames"]
self.surface_variables = len(conf["data"]["surface_variables"])
self.static_variables = len(conf["data"]["static_variables"])
self.use_codebook = False
self.rk4_integration = conf["model"]["rk4_integration"]
self.channels = 1
self.image_height = image_height
self.image_width = image_width
self.frames = frames
self.channels = channels
self.surface_channels = surface_channels
self.levels = levels
self.rk4_integration = rk4_integration

# input channels
input_channels = channels * levels + surface_channels + input_only_channels

in_out_channels = int(self.variables*self.levels + self.surface_variables + self.static_variables)
# output channels
output_channels = channels * levels + surface_channels + output_only_channels

if conf['model']['architecture']['name'] == 'unet':
conf['model']['architecture']['decoder_attention_type'] = 'scse'
conf['model']['architecture']['in_channels'] = in_out_channels
conf['model']['architecture']['classes'] = in_out_channels
if architecture['name'] == 'unet':
architecture['decoder_attention_type'] = 'scse'
architecture['in_channels'] = input_channels
architecture['classes'] = output_channels

self.model = load_premade_encoder_model(conf['model']['architecture'])
self.model = load_premade_encoder_model(architecture)
# Additional layers for testing

self.use_post_block = conf["model"]["post_conf"]["activate"]
self.use_post_block = post_conf["activate"]
if self.use_post_block:
self.postblock = PostBlock(conf["model"]["post_conf"])
self.postblock = PostBlock(post_conf)


def concat_and_reshape(self, x1, x2):
x1 = x1.view(x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5])
x_concat = torch.cat((x1, x2), dim=2)
return x_concat.permute(0, 2, 1, 3, 4)
# def concat_and_reshape(self, x1, x2):
# x1 = x1.view(x1.shape[0], x1.shape[1], x1.shape[2] * x1.shape[3], x1.shape[4], x1.shape[5])
# x_concat = torch.cat((x1, x2), dim=2)
# return x_concat.permute(0, 2, 1, 3, 4)

def split_and_reshape(self, tensor):
tensor1 = tensor[:, :int(self.channels * self.levels), :, :, :]
tensor2 = tensor[:, -int(self.surface_channels):, :, :, :]
tensor1 = tensor1.view(tensor1.shape[0], self.channels, self.levels, tensor1.shape[2], tensor1.shape[3], tensor1.shape[4])
return tensor1, tensor2
# def split_and_reshape(self, tensor):
# tensor1 = tensor[:, :int(self.channels * self.levels), :, :, :]
# tensor2 = tensor[:, -int(self.surface_channels):, :, :, :]
# tensor1 = tensor1.view(tensor1.shape[0], self.channels, self.levels, tensor1.shape[2], tensor1.shape[3], tensor1.shape[4])
# return tensor1, tensor2

def forward(self, x):
if self.use_post_block: # copy tensor to feed into postBlock later
Expand Down
18 changes: 12 additions & 6 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,20 @@ def test_unet():
levels = conf["model"]["levels"]
frames = conf["model"]["frames"]
surface_variables = len(conf["data"]["surface_variables"])
static_variables = len(conf["data"]["static_variables"])
input_only_variables = (len(conf["data"]["static_variables"])
+ len(conf["data"]["dynamic_forcing_variables"]))
output_only_variables = conf["model"]["output_only_channels"]

in_channels = int(variables*levels + surface_variables + input_only_variables)
out_channels = int(variables*levels + surface_variables + output_only_variables)

assert in_channels != out_channels

in_channels = int(variables*levels + surface_variables + static_variables)
input_tensor = torch.randn(1, in_channels, frames, image_height, image_width)

y_pred = model(input_tensor)

assert y_pred.shape == torch.Size([1, in_channels, 1, image_height, image_width])
assert y_pred.shape == torch.Size([1, out_channels, 1, image_height, image_width])
assert not torch.isnan(y_pred).any()

def test_crossformer():
Expand Down Expand Up @@ -65,9 +71,9 @@ def test_crossformer():
assert y_pred.shape == torch.Size([1, in_channels - input_only_channels, 1, image_height, image_width])
assert not torch.isnan(y_pred).any()

# if __name__ == "__main__":
# # test_unet()
# test_crossformer()
if __name__ == "__main__":
test_unet()
# test_crossformer()



0 comments on commit 998c1e3

Please sign in to comment.