-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun.py
127 lines (113 loc) · 4.01 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations
import argparse
from coco_froc_analysis.count import generate_bootstrap_count_curves
from coco_froc_analysis.count import generate_count_curve
from coco_froc_analysis.froc import generate_bootstrap_froc_curves
from coco_froc_analysis.froc import generate_froc_curve
from coco_froc_analysis.utils import bounds
from coco_froc_analysis.utils import test_point
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the FROC / count-curve entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--bootstrap',
        type=int,
        default=1,
        help=(
            'Number of bootstrap samples. Values greater than 1 generate '
            'bootstrap curves; otherwise a single curve is generated.'
        ),
    )
    parser.add_argument('--gt_ann', type=str, required=True)
    parser.add_argument('--pr_ann', type=str, required=True)
    parser.add_argument(
        '--use_iou',
        default=False,
        action='store_true',
        help='Use IoU score to decide based on `proximity`',
    )
    parser.add_argument(
        '--iou_thres',
        default=0.5,
        type=float,
        required=False,
        help='If IoU score is used the default threshold is set to .5',
    )
    parser.add_argument(
        '--n_sample_points',
        type=int,
        default=50,
        help='Number of points to evaluate the FROC curve at.',
    )
    parser.add_argument('--plot_title', type=str)
    parser.add_argument('--plot_output_path', type=str)
    parser.add_argument(
        '--test_ann',
        action='append',
        help='Extra ground-truth like annotations with annotator name/ID.',
        type=test_point,
        dest='test_ann',
        required=False,
    )
    parser.add_argument(
        '--counts',
        default=False,
        action='store_true',
        help='Generate count curves instead of FROC curves.',
    )
    parser.add_argument(
        '--weighted',
        default=False,
        action='store_true',
        help='Use weighted counts (only used together with --counts).',
    )
    # NOTE(review): --bounds is parsed but never forwarded to any generator in
    # run(); confirm whether the generators accept it before wiring it through.
    parser.add_argument(
        '--bounds',
        type=bounds,
        default=None,
        required=False,
    )
    return parser


def _with_default(value: str | None, default: str) -> str:
    """Return *value* unless it is None, in which case return *default*.

    Uses an explicit ``is None`` test (not ``or``) so that an explicitly
    passed empty string is still honored.
    """
    return default if value is None else value


def run() -> None:
    """Parse command-line arguments and generate FROC or count curves.

    With ``--counts``, a count curve is generated; otherwise a FROC curve.
    In either mode, ``--bootstrap`` values greater than 1 select the bootstrap
    variant with that many samples. ``--plot_title`` and ``--plot_output_path``
    fall back to mode-specific defaults when omitted.
    """
    args = _build_parser().parse_args()

    if args.counts:
        if args.bootstrap > 1:
            generate_bootstrap_count_curves(
                gt_ann=args.gt_ann,
                pr_ann=args.pr_ann,
                n_bootstrap_samples=args.bootstrap,
                n_sample_points=args.n_sample_points,
                plot_title=_with_default(args.plot_title, 'Counts PR (bootstrap)'),
                plot_output_path=_with_default(args.plot_output_path, 'counts_bootstrap.png'),
                weighted=args.weighted,
                test_ann=args.test_ann,
            )
        else:
            generate_count_curve(
                gt_ann=args.gt_ann,
                pr_ann=args.pr_ann,
                weighted=args.weighted,
                plot_title=_with_default(args.plot_title, 'Counts PR'),
                plot_output_path=_with_default(args.plot_output_path, 'counts.png'),
                test_ann=args.test_ann,
            )
        # Count mode is finished; return normally so the process exits with
        # status 0. (Do not fall through to FROC generation, and do not signal
        # failure after a successful run.)
        return

    if args.bootstrap > 1:
        print('Generating bootstrap curves... (this may take a while)')
        generate_bootstrap_froc_curves(
            gt_ann=args.gt_ann,
            pr_ann=args.pr_ann,
            n_bootstrap_samples=args.bootstrap,
            use_iou=args.use_iou,
            iou_thres=args.iou_thres,
            n_sample_points=args.n_sample_points,
            plot_title=_with_default(args.plot_title, 'FROC (bootstrap)'),
            plot_output_path=_with_default(args.plot_output_path, 'froc_bootstrap.png'),
            test_ann=args.test_ann,
        )
    else:
        print('Generating single FROC curve...')
        generate_froc_curve(
            gt_ann=args.gt_ann,
            pr_ann=args.pr_ann,
            use_iou=args.use_iou,
            iou_thres=args.iou_thres,
            n_sample_points=args.n_sample_points,
            plot_title=_with_default(args.plot_title, 'FROC'),
            plot_output_path=_with_default(args.plot_output_path, 'froc.png'),
            test_ann=args.test_ann,
        )
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run()