This repository has been archived by the owner on Mar 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
99 lines (84 loc) · 2.9 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Train or test a neural network.
To train a network:
'$ python3 main.py train name_of_net name_of_dataset --epochs 10'
To test a trained network:
'$ python3 main.py test name_of_net name_of_dataset'
For more information see README.md
"""
import os
from argparse import ArgumentParser, Namespace
from src.asf_cnn import test_model_masked, train_model
from src.model import load_model, path_from_model_name
from src.model.architecture.masked import create_model_masked
from src.plots import edit_predictions, plot_predictions
def train_wrapper(args: Namespace) -> None:
    """Train the network named ``args.model`` on ``args.dataset``.

    Two modes, selected by ``args.cont``:

    * continue (``args.cont`` truthy): load the saved model and resume
      training with the history stored on it;
    * fresh (default): build a new masked model with an empty history.
      If a file already exists at the model's path and ``args.overwrite``
      is not set, print a message and return without training.

    Training itself is delegated to ``train_model`` for ``args.epochs``
    epochs.
    """
    name = args.model

    if not args.cont:
        target = path_from_model_name(name)
        # Refuse to clobber an existing model file unless -o/--overwrite.
        if not args.overwrite and os.path.isfile(target):
            print(f"File {name} already exists!")
            return
        net = create_model_masked(name)
        past = {
            metric: []
            for metric in ("loss", "accuracy", "val_loss", "val_accuracy")
        }
    else:
        net = load_model(name)
        # Not name-mangled: this is a module-level function, not a class
        # body. Presumably the attribute is attached by load_model — confirm.
        past = net.__asf_model_history

    train_model(net, past, args.dataset, args.epochs)
def test_wrapper(args: Namespace) -> None:
    """Evaluate a trained network and present its predictions.

    Loads the model named ``args.model`` and runs ``test_model_masked`` on
    ``args.dataset``. When ``args.edit`` is set, the results go to
    ``edit_predictions``; otherwise they go to ``plot_predictions``.
    """
    trained = load_model(args.model)
    outcome = test_model_masked(trained, args.dataset, args.edit)

    if not args.edit:
        # Plot mode: the result unpacks into (predictions, iterator).
        predictions, test_iter = outcome
        plot_predictions(predictions, test_iter)
        return

    # Edit mode: the result additionally carries metadata.
    predictions, data_iter, metadata = outcome
    edit_predictions(predictions, data_iter, metadata)
if __name__ == '__main__':
    from argparse import RawDescriptionHelpFormatter

    # Use the module docstring (which contains the usage examples) as the
    # parser description; the raw formatter preserves its line breaks.
    p = ArgumentParser(
        description=__doc__,
        formatter_class=RawDescriptionHelpFormatter,
    )
    sp = p.add_subparsers()

    # Arguments for train mode
    train = sp.add_parser('train', help='Train a new model')
    train.add_argument('model', help='Name of the model to save: example_net')
    train.add_argument(
        'dataset',
        nargs='?',
        default='dataset_calibrated',
        help='Name of the dataset to train on (default: %(default)s)',
    )
    train.add_argument(
        '--overwrite',
        '-o',
        action='store_true',
        help='Replace the file if it exists',
    )
    train.add_argument(
        '--continue',
        '-c',
        action='store_true',
        dest='cont',  # 'continue' is a Python keyword, so store as 'cont'
        help='Continue training from existing model',
    )
    train.add_argument(
        '--epochs',
        '-e',
        type=int,
        default=10,
        help='Number of training epochs (default: %(default)s)',
    )
    train.set_defaults(func=train_wrapper)

    # Arguments for test mode
    test = sp.add_parser('test', help='Test an existing model')
    test.add_argument('model', help='Name of the trained model')
    test.add_argument(
        'dataset',
        nargs='?',
        default='dataset_calibrated',
        help='Name of the dataset to test on (default: %(default)s)',
    )
    test.add_argument(
        '--edit',
        '-e',
        help="Replace the mask with the network's prediction",
        action='store_true',
    )
    test.set_defaults(func=test_wrapper)

    # Parse and execute selected function; with no subcommand there is no
    # 'func' attribute, so show the help text instead.
    args = p.parse_args()
    if hasattr(args, 'func'):
        args.func(args)
    else:
        p.print_help()