Skip to content

Commit

Permalink
normalization of hrrr pred and targets should use same statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 8, 2023
1 parent c533fcc commit 84c336a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions metnet3_pytorch/metnet3_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@ def forward(

# use a batchnorm to normalize each channel to mean zero and unit variance

normed_hrrr_target = self.batchnorm_hrrr(hrrr_target)

with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
normed_hrrr_pred = frozen_batchnorm(hrrr_pred)

normed_hrrr_target = self.batchnorm_hrrr(hrrr_target)

# proposed loss gradient rescaler from section 4.3.2

normed_hrrr_pred = self.mse_loss_scaler(normed_hrrr_pred)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'metnet3-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.3',
version = '0.0.4',
license='MIT',
description = 'MetNet 3 - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 84c336a

Please sign in to comment.