# (removed: GitHub page chrome and line-number gutter captured by the scraper)
"""evaluate_cloud_cover.py"""
"""
Author: Yimin Yang
Last revision date: Jan 18, 2022
Function: Run this to evaluate the trained model
Ref: https://github.com/HansBambel/SmaAt-UNet
"""
import torch
from torch import nn
import numpy as np
import os
import pickle
from tqdm import tqdm
def compute_loss(model, test_dl, loss="mse"):
model.eval()
model.to("cuda")
if loss.lower() == "mse":
loss_func = nn.functional.mse_loss
elif loss.lower() == "mae":
loss_func = nn.functional.l1_loss
elif loss.lower() == "bcewl":
loss_func = nn.functional.binary_cross_entropy_with_logits
with torch.no_grad():
loss_model = 0.0
for x, y_true in tqdm(test_dl, leave=False):
x = x.to("cuda")
y_true = y_true.to("cuda")
y_pred = model(x)
loss_model += loss_func(y_pred.squeeze(), y_true)
loss_model /= len(test_dl)
return np.array(loss_model.cpu())
def evaluate(data_file, model_folder, loss):
    """Evaluate the saved AA-TransUNet checkpoint on the cloud-cover test set.

    Args:
        data_file: folder containing the preprocessed cloud-cover dataset.
        model_folder: results folder (currently unused here; kept for
            interface compatibility — the checkpoint path below is
            hard-coded).
        loss: loss name forwarded to ``compute_loss`` ("mse"/"mae"/"bcewl").

    Returns:
        dict mapping model name -> scalar loss (0-d numpy array).
    """
    test_losses = {}

    # NOTE(review): cloud_maps and AA_TransUnet are not imported in this
    # file — presumably they come from a sibling module; confirm before use.
    dataset = cloud_maps(
        folder=data_file,
        input_imgs=4,
        output_imgs=6,
        train=False,
    )
    test_dl = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    # Load the trained model from its checkpoint.
    model_name = 'AA_TransUNet'
    model = AA_TransUnet.load_from_checkpoint(
        '/AA_TransUNet/results/Model_Saved/T21_CBAM_end_100.ckpt')

    # Bug fix: the original called get_model_loss(), which is not defined
    # anywhere; the helper defined above is compute_loss().
    model_loss = compute_loss(model, test_dl, loss)
    test_losses[model_name] = model_loss

    # Bug fix: label the printed metric with the loss actually requested
    # instead of a hard-coded "MSE".
    print(f"Model Name: {model_name}, Loss({loss.upper()}): {model_loss}")
    return test_losses
if __name__ == '__main__':
    # Evaluation configuration: which loss to report and where the
    # preprocessed dataset / saved metrics live.
    loss = "mse"
    denormalize = True
    model_folder = '/AA_TransUNet/dataset/Data_cloud_cover_preprocessed'
    data_file = "AA_TransUNet/dataset/Data_cloud_cover_preprocessed"
    load = False

    if load:
        # Reload previously-saved metrics instead of re-running evaluation.
        with open(model_folder + f"/results/Metrics_Saved/model_losses_{loss.upper()}_denormalized_11_26_50_TransUnet.pkl", "rb") as f:
            test_losses = pickle.load(f)
        print(test_losses)
    else:
        # Bug fix: the original called get_model_losses(model_folder,
        # data_file, loss), which is not defined anywhere. The function
        # defined above is evaluate(data_file, model_folder, loss) —
        # note the swapped argument order relative to the broken call.
        test_losses = evaluate(data_file, model_folder, loss)
        with open(
                model_folder + f"/results/Metrics_Saved/model_losses_{loss.upper()}_{f'de' if denormalize else ''}_normalized_1.pkl",
                "wb") as f:
            pickle.dump(test_losses, f)