-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval.py
84 lines (74 loc) · 3.24 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import json
import torch
from torch import nn
from model.grud_model import GRUD
from model.custom_cross_entropy import CustomCrossEntropyLoss
from data.dataset import RegionDataset
from torch.utils.data import DataLoader
from utils.argument_parser import get_argument
from utils.imputation import evaluation, get_device, write_gen, write_output_Oxford_format, oxford_2_vcf
# torch.manual_seed(42)
def run(dataloader, dataset, imp_site_info_list, model_config, args, region):
    """Evaluate the saved GRUD model of one region and write its outputs.

    Loads the region's checkpoint from ``args.model_dir``, runs
    ``evaluation`` over ``dataloader``, prints the resulting loss and R2
    score, and hands the dosage and predictions to the output writers.

    Args:
        dataloader: Batches to evaluate; passed straight to ``evaluation``.
        dataset: Not used by this function; kept so existing callers that
            pass it positionally keep working.
        imp_site_info_list: Site-info records forwarded to the writers.
        model_config: Model configuration mapping. Mutated in place: its
            ``'device'`` entry is set to the selected device.
        args: Parsed arguments. Reads ``gpu``, ``model_dir``,
            ``result_gen_dir``, ``chromosome`` and ``best_model``.
        region: Region identifier used in checkpoint and output names.
    """
    device = get_device(args.gpu)
    model_dir = args.model_dir
    result_gen_dir = args.result_gen_dir
    chromosome = args.chromosome
    model_config['device'] = device

    # NOTE(review): a signed ``gamma`` used to be derived here from
    # ``args.gamma`` for the 'Lower'/'Higher' model types, but nothing ever
    # consumed it, so the dead computation was removed. Confirm whether
    # CustomCrossEntropyLoss was meant to receive it.

    # Init model first, then the loss functions handed to ``evaluation``.
    model = GRUD(model_config, device).to(device)
    loss_fns = {
        'CustomCrossEntropy': CustomCrossEntropyLoss(),
        'BCEWithLogitsLoss': nn.BCEWithLogitsLoss(),
    }

    # Only the file name differs between the "best" and the regular
    # checkpoint, so pick the name and load once.
    if args.best_model:
        checkpoint_name = f'Best_grud_region_{region}.pt'
    else:
        checkpoint_name = f'grud_region_{region}.pt'
    checkpoint_path = os.path.join(model_dir, checkpoint_name)

    # SECURITY: torch.load is pickle-based; only load checkpoints that come
    # from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    print(f"Loaded grud_{region} model")

    # The labels returned by ``evaluation`` are not needed for the outputs.
    test_loss, r2, (predictions, _labels, dosage) = evaluation(
        dataloader, model, device, loss_fns
    )
    print(f"[Evaluate] Loss: {test_loss} \t R2 Score: {r2}")

    write_output_Oxford_format(dosage, imp_site_info_list, chromosome, region, result_gen_dir)
    write_gen(predictions, imp_site_info_list, chromosome, region, result_gen_dir)
def main():
    """Evaluate every region in the requested range, then export a VCF.

    Reads the command-line arguments via ``get_argument``, writes "0" into
    ``<root_dir>/<regions>_GRUD.txt``, and for each region loads its JSON
    model config, builds the dataset and loader, and calls ``run`` with the
    site-info records whose ``array_marker_flag`` is falsy. After the last
    region the results are converted with ``oxford_2_vcf``.
    """
    args = get_argument()
    data_root = args.root_dir
    config_root = args.model_config_dir
    loader_batch_size = args.batch_size
    chrom = args.chromosome
    region_bounds = args.regions.split("-")
    index_name = args.regions + "_GRUD"

    # The index file is (re)created with the content "0" before any dataset
    # is built. NOTE(review): presumably RegionDataset reads it via
    # ``index_name`` -- confirm against data/dataset.py.
    index_path = os.path.join(data_root, f'{index_name}.txt')
    with open(index_path, 'w+') as index_handle:
        index_handle.write("0")

    # The first and last "-"-separated fields are the inclusive bounds.
    first_region = int(region_bounds[0])
    last_region = int(region_bounds[-1])

    for region in range(first_region, last_region + 1):
        print(f"----------Testing Region {region}----------")

        config_path = os.path.join(config_root, f'region_{region}_config.json')
        with open(config_path, "r") as config_handle:
            model_config = json.load(config_handle)
        model_config['region'] = region
        model_config['type_model'] = args.type_model

        dataset = RegionDataset(data_root, index_name, region, chrom, dataset=args.dataset)
        testloader = DataLoader(dataset, batch_size=loader_batch_size, shuffle=False)

        # Keep only the records not flagged as array markers.
        imp_site_info_list = [
            info for info in dataset.site_info_list if not info.array_marker_flag
        ]

        run(testloader, dataset, imp_site_info_list, model_config, args, region)

    print("----------Imputation Done----------")
    print(f"Writing to gen_{chrom}.vcf.gz file----------")
    gen_prefix = os.path.join(args.result_gen_dir, 'gen')
    oxford_2_vcf(gen_prefix, args.result_gen_dir, args.sample, chrom)
    print("----------Writing to VCF Done----------")
# Run the evaluation pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()