-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest_classifier.py
103 lines (86 loc) · 4.49 KB
/
test_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Test content classifiers to #
# make sure they are working well #
# enough to be used in NPI training #
# (Accuracy should be very high: #
# ideally 99% or more) #
# #
# Fulda, Brown, Wingate, Robinson #
# DRAGN #
# NPI Project #
# 2020 #
import argparse
import pickle as pkl
import numpy as np
import torch
from .train_classifier import Classifier, extract_needed_layers
if __name__ == "__main__":
    # Script entry point: for each requested training epoch, load the saved
    # classifier checkpoint and report its accuracy on the held-out test
    # pickles (per pickle, then averaged over all pickles).
    import os  # local import: only needed to build the checkpoint path below

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-dir-path",
                        default="classifiers/layers_5_11/",
                        help="path to directory containing classifiers")
    parser.add_argument("--data-path",
                        default="data/sentence_arrays",
                        help="path to data (standard file name witout pkl suffix, full or relative file path)")
    parser.add_argument("--test-pkls",
                        type=str,
                        default="53,54,55,56",  # See NOTE in arg-parsing section of train_classifier.py (line 378)
                        help="pkl numbers for data designated for testing: string of numbers separated by commas")
    parser.add_argument("--test-epochs",
                        type=str,
                        default="20,30,40,50,60,70",
                        help="epoch nums for class'n models we want to test: string of numbers separated by commas")
    parser.add_argument("--perturbation-indices",
                        type=str,
                        default="5,11",
                        help="indices for layers to extract from language model activations: string of numbers separated by commas")
    args = parser.parse_args()

    # Parse the comma-separated CLI lists once; none of them depend on the
    # checkpoint being evaluated.
    epoch_nums = [int(tok) for tok in args.test_epochs.split(',')]
    test_nums = [int(tok) for tok in args.test_pkls.split(',')]  # these pickles are designated for testing!!
    pred_inds = [int(tok) for tok in args.perturbation_indices.split(',')]
    model_dir = args.model_dir_path
    data_path = args.data_path

    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for epoch_num in epoch_nums:
        print("NEW FILE", model_dir, "epoch num", epoch_num, flush=True)

        # Load classifier.
        # We load the model onto the CPU first just in case it was trained on
        # a different GPU than the one we are using, then move it to `device`.
        # os.path.join tolerates a model dir given with or without a trailing
        # separator.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint_path = os.path.join(
            model_dir,
            "Classifier_classification_network_epoch{}.bin".format(epoch_num))
        classifier = torch.load(checkpoint_path,
                                map_location=torch.device('cpu')).to(device)
        # Put the network in inference mode so any train-only layers
        # (e.g. dropout) do not perturb the measured accuracy.
        classifier.eval()

        collected_accs = []
        for test_num in test_nums:
            # NOTE: pickle.load can execute arbitrary code; only point
            # --data-path at trusted files.
            with open(data_path + ".pkl_{}".format(test_num), 'rb') as f:
                money = pkl.load(f)

            if not money:
                # Avoid dividing by zero on an empty test file; it simply
                # contributes nothing to the overall average.
                print("WARNING: no records in test pkl {}; skipping".format(test_num), flush=True)
                continue

            num_correct = 0
            # Gradients are never needed while testing.
            with torch.no_grad():
                for i, record in enumerate(money):
                    # Record layout as used here: record[0] is handed to
                    # extract_needed_layers, record[1][1] is the label that is
                    # compared against 0/1, record[-1] is the text to display.
                    arr = extract_needed_layers(record[0], pred_inds)
                    arr = torch.Tensor(arr).to(device)
                    sent = record[-1]
                    truth = record[1][1]
                    yhat = classifier(arr).squeeze().cpu().item()

                    # A prediction is correct when it falls on the same side
                    # of the 0.5 threshold as the 0/1 label; labels other
                    # than 0 or 1 are never counted as correct.
                    if (truth == 1 and yhat >= .5) or (truth == 0 and yhat < .5):
                        num_correct += 1

                    # Show every 100th example as a progress/sanity check.
                    if i % 100 == 99:
                        print(sent.replace('\n', '\\n'))
                        print("truth", truth)
                        print("yhat", yhat)

            score = num_correct / len(money)
            print("ACCURACY FOR TEST {}: {}".format(test_num, score))
            collected_accs.append(score)

        # Unweighted mean of the per-pickle accuracies; NaN if every test
        # pickle was empty.
        avg_acc = np.mean(collected_accs) if collected_accs else float('nan')
        print('done')
        print("TOTAL ACCURACY OVERALL:", avg_acc)
        print("\n================================================\n", flush=True)