Skip to content

Commit

Permalink
bugfix on crossformer and parser
Browse files Browse the repository at this point in the history
  • Loading branch information
yingkaisha committed Oct 16, 2024
1 parent cfe2bce commit 08bebe9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 68 deletions.
79 changes: 11 additions & 68 deletions credit/models/crossformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def __init__(
attn_dropout=0.,
ff_dropout=0.,
use_spectral_norm=True,
interp=True,
padding_conf={'activate': False},
post_conf={"activate": False},
**kwargs
Expand All @@ -338,7 +339,10 @@ def __init__(
self.surface_channels = surface_channels
self.levels = levels
self.use_spectral_norm = use_spectral_norm

self.use_interp = interp
self.use_padding = padding_conf['activate']
self.use_post_block = post_conf['activate']

# input channels
input_channels = channels * levels + surface_channels + input_only_channels

Expand All @@ -360,7 +364,6 @@ def __init__(
assert len(cross_embed_strides) == 4

# dimensions

last_dim = dim[-1]
first_dim = input_channels if (patch_height == 1 and patch_width == 1) else dim[0]
dims = [first_dim, *dim]
Expand Down Expand Up @@ -404,57 +407,19 @@ def __init__(
transformer_layer
])
)

# =================================================================================== #
# this block handles boundary padding
self.use_padding = padding_conf['activate']

if self.use_padding:
self.padding_opt = TensorPadding(padding_conf)

# =================================================================================== #
# This block handles I/O sizes that cannot be divcided by cross_embed_strides

# total downsampling factors
total_dsample_factor_H = patch_height
total_dsample_factor_W = patch_width

for s in cross_embed_strides:
total_dsample_factor_H *= s
total_dsample_factor_W *= s

# compute I/O sizes that can be accepted
self.image_height_adjust = (image_height // total_dsample_factor_H) * total_dsample_factor_H
self.image_width_adjust = (image_width // total_dsample_factor_W) * total_dsample_factor_W

# acceptable sizes are at least the size of the total factor
if self.image_height_adjust == 0:
self.image_height_adjust = total_dsample_factor_H

if self.image_width_adjust == 0:
self.image_width_adjust = total_dsample_factor_W

if (
self.image_height != self.image_height_adjust or
self.image_width != self.image_width_adjust
):
logger.infos(
'Configured input sizes before padding: ({} {}); acceptable input sizes before padding: ({} {}). '
'Bilinear interpolation will be used to handle the size differences'.format(
self.image_height, self.image_width,
self.image_height_adjust, self.image_width_adjust
)
)


# define embedding layer using adjusted sizes
# if the original sizes were good, adjusted sizes should == original sizes
self.cube_embedding = CubeEmbedding(
(frames, self.image_height_adjust, self.image_width_adjust),
(frames, image_height, image_width),
(frames, patch_height, patch_width),
input_channels,
dim[0]
)

# =================================================================================== #

self.up_block1 = UpBlock(1 * last_dim, last_dim // 2, dim[0])
Expand All @@ -465,39 +430,15 @@ def __init__(
if self.use_spectral_norm:
logger.info("Adding spectral norm to all conv and linear layers")
apply_spectral_norm(self)


self.use_post_block = post_conf['activate']

if self.use_post_block:
self.postblock = PostBlock(post_conf)

def forward(self, x):

if self.use_post_block: # copy tensor to feed into postBlock later
x_copy = x.clone().detach()

# ===================================================================== #
# this block does the input interpolation, if adjusted sizes are needed

# get the current size
B, C, T, H_in, W_in = x.shape

# if current size needs to be adjusted
if H_in != self.image_height_adjust or W_in != self.image_width_adjust:

# merge batch and time becuase F.interpolate works for 4D tensor
x = x.permute(0, 2, 1, 3, 4).reshape(B * T, C, H_in, W_in)

# F.interpolate to the adjusted sizes
x = F.interpolate(x,
size=(self.image_height_adjust, self.image_width_adjust),
mode='bilinear',
align_corners=False)

# reshape back
x = x.reshape(B, T, C, self.image_height_adjust, self.image_width_adjust).permute(0, 2, 1, 3, 4)
# ===================================================================== #

if self.use_padding:
x = self.padding_opt.pad(x)

Expand All @@ -524,8 +465,10 @@ def forward(self, x):

if self.use_padding:
x = self.padding_opt.unpad(x)

if self.use_interp:
x = F.interpolate(x, size=(self.image_height, self.image_width), mode="bilinear")

x = F.interpolate(x, size=(self.image_height, self.image_width), mode="bilinear")
x = x.unsqueeze(2)

if self.use_post_block:
Expand Down
5 changes: 5 additions & 0 deletions credit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,11 @@ def CREDIT_main_parser(conf, parse_training=True, parse_predict=True, print_summ

# --------------------------------------------------------- #
# conf['model'] section

# use interpolation
if 'interp' not in conf['model']:
conf['model']['interp'] = True

# ======================================================== #
# padding opts
if 'padding_conf' not in conf['model']:
Expand Down

0 comments on commit 08bebe9

Please sign in to comment.