-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecode.py
66 lines (45 loc) · 1.76 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#! /usr/bin/python3
# Script to decode an encoded ensemble string into the list of its member models
# Usage: python decode.py <encoded_str>
# Example: python decode.py 010000000000010100000
import sys
# Base model names. decode() looks a name up by a digit's position within
# its segment of the encoding, so this order must match the encoder's.
model_list = [
    "ConvNet",
    "DeconvNet",
    "MobileNet",
    "ResNet18",
    "ResNet50",
    "VGG11",
    "VGG16",
]
def is_valid_encoding(encoded_arr, total_models, ens_size, num_div_ops):
    """Check that an encoded ensemble has the expected length and total.

    Args:
        encoded_arr: Sequence of per-position integer counts.
        total_models: Number of base models per segment.
        ens_size: Required sum of all entries (the ensemble size).
        num_div_ops: Number of segments in the encoding.

    Returns:
        True if ``encoded_arr`` has ``total_models * num_div_ops`` entries
        and they sum to ``ens_size``; False otherwise.
    """
    expected_len = total_models * num_div_ops
    if len(encoded_arr) != expected_len:
        # Wrong length: reject without summing, as the original short-circuit did.
        return False
    return sum(encoded_arr) == ens_size
def decode_str(encoded_str):
    """Turn a digit string into a list of ints, one per character.

    Raises:
        ValueError: if a character cannot be parsed by ``int``.
    """
    return list(map(int, encoded_str))
def decode(encoded_arr, ens_size, *, num_div_ops=3, models=None):
    """Translate an encoded ensemble into human-readable member descriptions.

    The encoding is split into ``num_div_ops`` equal segments, with one entry
    per model in each segment:

    * segment 0: a non-zero entry adds the base model name once, whatever
      the entry's value;
    * segment 1: the entry is how many "Data Div Op on <model>" members to add;
    * segment 2: the entry is how many "Snapshot Div Op on <model>" members
      to add.

    Entries beyond the third segment are ignored.

    Args:
        encoded_arr: Per-position integer counts, e.g. from ``decode_str``.
        ens_size: Expected ensemble size. It is not needed to decode and is
            kept only for backward compatibility; ``is_valid_encoding`` checks
            the total.
        num_div_ops: Number of segments the encoding is split into. Defaults
            to 3, the value ``main`` validates against.
        models: Model names indexed by position within a segment. Defaults to
            the module-level ``model_list``.

    Returns:
        List of member descriptions, in encoding order.
    """
    if models is None:
        models = model_list
    # Segment length = entries per segment. It must come from the number of
    # segments, not the ensemble size. The old code divided by ens_size,
    # which only worked because both values happened to be 3.
    seg_len = len(encoded_arr) // num_div_ops

    decoded = []
    for idx, count in enumerate(encoded_arr):
        if count == 0:
            continue
        if idx < seg_len:
            # Base models are included at most once, regardless of the count.
            decoded.append(models[idx])
        elif idx < 2 * seg_len:
            decoded.extend(["Data Div Op on " + models[idx - seg_len]] * count)
        elif idx < 3 * seg_len:
            decoded.extend(["Snapshot Div Op on " + models[idx - 2 * seg_len]] * count)
    return decoded
def main(argv=None):
    """Decode the ensemble encoding given on the command line and print it.

    Args:
        argv: Argument vector, program name first. Defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        # Previously a missing argument crashed with a bare IndexError.
        sys.exit("Usage: python decode.py <encoded_str>")
    encoded_str = argv[1]
    try:
        encoded_arr = decode_str(encoded_str)
    except ValueError:
        # Non-digit characters can never form a valid encoding; report it
        # instead of letting the traceback escape.
        print("Invalid ensemble encoding supplied!")
        return
    total_models = len(model_list)
    ens_size = 3
    num_div_ops = 3
    if is_valid_encoding(encoded_arr, total_models, ens_size, num_div_ops):
        decoded_arr = decode(encoded_arr, ens_size)
        print("Ensemble combination decoded: ", decoded_arr)
    else:
        print("Invalid ensemble encoding supplied!")
# Run the command-line entry point only when executed as a script, not on import.
if "__main__" == __name__:
    main()