# test.py: predict on the test split and write a submission CSV.
# Repository generated from the ashleve/lightning-hydra-template template.
import os
import yaml
import argparse
from easydict import EasyDict
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import DataLoader
from src.utils import seed_everything
from src.dataset import GEM1Dataset
from src.model import GEM1Model
@torch.no_grad()
def test(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Gradient tracking is disabled by the ``torch.no_grad`` decorator and the
    model is switched to eval mode before iterating.

    Args:
        model: Module called as ``model(batch)``; its output is presumably a
            tensor whose first dimension indexes samples — TODO confirm.
        loader: Iterable of batches; each batch must support ``.to(device)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        numpy.ndarray: Per-batch outputs concatenated along the first axis,
        on the CPU, in loader order.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    preds = []
    for batch in tqdm(loader, desc='Predicting'):
        batch = batch.to(device)
        pred = model(batch)
        preds.append(pred.detach().cpu())
    if not preds:
        # Without this guard the concatenation fails with an unhelpful
        # "need at least one array to concatenate" message.
        raise ValueError('test loader yielded no batches; nothing to predict')
    # Concatenate in torch and convert to numpy once, instead of letting numpy
    # coerce every tensor individually.
    return torch.cat(preds).numpy()
def main(args):
    """Run inference with a trained GEM1Model and write a submission CSV.

    Reads the YAML config at ``args.config``, builds the test split of
    ``GEM1Dataset``, restores model weights from ``args.ckpt_path`` (or from
    ``outputs/<config_name>/ckpts/best.pt`` when that is empty), predicts on
    the test set and writes the predictions into a copy of
    ``data/sample_submission.csv`` saved as
    ``outputs/<config_name>/submissions/submission.csv``.

    Args:
        args: Parsed CLI namespace with ``config`` (path to the YAML config),
            ``gpu`` (CUDA device index, ignored when CUDA is unavailable) and
            ``ckpt_path`` (checkpoint path, or empty for the default location).
    """
    with open(args.config, encoding='utf-8') as f:
        cfg = EasyDict(yaml.safe_load(f))
    seed_everything(cfg.seed)

    # Run name = config file name up to its first '.', e.g. configs/base.yaml -> base.
    # NOTE(review): presumably mirrors how the training script names its output
    # directory (the default checkpoint path depends on it); keep them in sync.
    config_name = args.config.split('/')[-1].split('.')[0]
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

    test_dataset = GEM1Dataset(root=cfg.dataset.root, mode='test')
    test_loader = DataLoader(
        test_dataset, batch_size=cfg.dataset.batch_size, shuffle=False,
        num_workers=cfg.dataset.num_workers, pin_memory=cfg.dataset.pin_memory,
        follow_batch=['edge_attr'],
    )

    model = GEM1Model(**cfg.model).to(device)
    ckpt_path = args.ckpt_path or f'outputs/{config_name}/ckpts/best.pt'
    # map_location lets a checkpoint saved during a GPU run be restored when
    # `device` fell back to CPU; otherwise torch.load tries to recreate the
    # tensors on their original CUDA device and fails.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    test_preds = test(model, test_loader, device)

    sub_df = pd.read_csv('data/sample_submission.csv')
    # Every column after the first is overwritten with predictions.
    # Assumes the first column is an identifier and the rest are targets in the
    # model's output order — TODO confirm against the sample submission.
    sub_df[sub_df.columns[1:]] = test_preds
    submission_path = f'outputs/{config_name}/submissions/submission.csv'
    # Nothing in this script creates the submissions directory, and to_csv
    # does not create missing parent directories.
    os.makedirs(os.path.dirname(submission_path), exist_ok=True)
    sub_df.to_csv(submission_path, index=False)
if __name__ == '__main__':
    # Script entry point: a positional config path plus an optional GPU index
    # and an optional checkpoint override, forwarded straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('config', type=str)
    cli.add_argument('--gpu', default=0, type=int)
    cli.add_argument('--ckpt_path', default='', type=str)
    main(cli.parse_args())