-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_baselines.py
68 lines (58 loc) · 2.41 KB
/
test_baselines.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Script to test several algorithmic baselines for TSP.

Usage:
  test_baselines.py (--algorithm A) (--test-set TS) [options]

Options:
  -h --help        Show this screen.
  --algorithm A    TSP algorithm to run: random, greedy, beam_search,
                   christofides, LKH or optimal.
  --test-set TS    Which test set to use (test|test_100|test_200|test_1000).
  --num-samples N  Only evaluate the first N instances of the test set
                   (default: the whole test set).
"""
import os
import sys
import time
from statistics import mean, stdev as std

import schema
from docopt import docopt

from datasets.constants import _DATASET_ROOTS, _DATASET_CLASSES
from datasets._configs import CONFIGS
from baselines import (
    random_baseline,
    greedy_baseline,
    christofides_baseline,
    LKH_baseline,
    beam_search_baseline,
    optimal_baseline,
)

# Dataset family that every test set is loaded from.
_DATASET = 'tsp_large'


def _run_random(data):
    """Run the random baseline and print mean ± stdev of both sequences it returns."""
    means, stds = random_baseline(data, return_ratio=True)
    print(f'avg. tour len {round(mean(means), 4)} ± {round(std(means), 4)}')
    print(f'avg. std tour len {round(mean(stds), 4)} ± {round(std(stds), 4)}')


def _run_pair(baseline, **kwargs):
    """Return a runner for a baseline that yields a single (mean, std) pair.

    Args:
        baseline: callable invoked as ``baseline(data, **kwargs)``.
        **kwargs: extra keyword arguments forwarded to ``baseline``.
    """
    def run(data):
        avg, std_dev = baseline(data, **kwargs)
        print(f'avg. tour len {round(avg, 4)} ± {round(std_dev, 4)}')
    return run


def _run_beam_search(data):
    """Run beam search and additionally report the elapsed time in seconds."""
    start = time.perf_counter()
    avg, std_dev = beam_search_baseline(data, return_ratio=True)
    elapsed = time.perf_counter() - start
    print(f'avg. tour len {round(avg, 4)} ± {round(std_dev, 4)} which took {elapsed}')


# Maps each accepted --algorithm value to the function that runs it and prints
# its statistics. LKH and optimal are called without return_ratio.
_RUNNERS = {
    'random': _run_random,
    'greedy': _run_pair(greedy_baseline, return_ratio=True),
    'beam_search': _run_beam_search,
    'christofides': _run_pair(christofides_baseline, return_ratio=True),
    'LKH': _run_pair(LKH_baseline),
    'optimal': _run_pair(optimal_baseline),
}


def _parse_args(argv=None):
    """Parse and validate the command line; exit with a message on bad input.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The docopt dict with '--num-samples' converted to ``int`` (or ``None``).
    """
    # Named arg_schema so the imported ``schema`` module is not shadowed.
    arg_schema = schema.Schema({
        '--help': bool,
        '--algorithm': schema.And(
            schema.Use(str), lambda name: name in _RUNNERS,
            error=f"--algorithm must be one of: {', '.join(_RUNNERS)}"),
        '--test-set': schema.Use(str),
        '--num-samples': schema.Or(None, schema.And(
            schema.Use(int), lambda n: n > 0,
            error='--num-samples must be a positive integer')),
    })
    try:
        return arg_schema.validate(docopt(__doc__, argv=argv))
    except schema.SchemaError as err:
        sys.exit(str(err))


def main():
    """Load the requested test set, run one baseline on it and print its statistics."""
    # Disable XLA's up-front GPU memory preallocation (documented JAX/XLA env var).
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

    # Validate arguments before the (potentially expensive) dataset load.
    args = _parse_args()
    test_set = args['--test-set']
    try:
        split_cfg = CONFIGS[_DATASET][test_set]
    except KeyError:
        sys.exit(f'unknown test set {test_set!r}')

    data = _DATASET_CLASSES[_DATASET](
        root=_DATASET_ROOTS[_DATASET],
        split=test_set,
        num_nodes=split_cfg['num_nodes'],
        num_samples=split_cfg['num_samples'],
    )

    limit = args['--num-samples']
    if limit is not None and limit < split_cfg['num_samples']:
        # Evaluate only a prefix of the test set. num_samples is kept in sync
        # with the slice, presumably because the baselines read it.
        data = data[:limit]
        data.num_samples = limit

    _RUNNERS[args['--algorithm']](data)


if __name__ == '__main__':
    main()