-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMain.py
102 lines (87 loc) · 2.91 KB
/
Main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import pickle as pk
import sys
import Constants
import DailyData
import Execute
from Model import LSTM
# #####################################################################
# PROCESS COMMAND LINE ARGUMENTS
# #####################################################################
# Usage text shown on any malformed invocation. Only the four options
# below are parsed, so no other option is advertised here.
instructions = """
Incorrect options:
-generate [or] -g ['y', 'n']
-restore [or] -r ['y', 'n']
-train [or] -t ['y', 'n']
-simulate [or] -s ['y', 'n']
(All must be specified)
"""
# Every accepted spelling of a flag, mapped to the short key stored in `opts`.
flagKeys = {
    '-g': 'g', '-generate': 'g',
    '-r': 'r', '-restore': 'r',
    '-t': 't', '-train': 't',
    '-s': 's', '-simulate': 's',
}
# The script name plus exactly four flag/value pairs.
if len(sys.argv) != 9:
    print(instructions)
    sys.exit(1)
opts = {}
for flag, value in zip(sys.argv[1::2], sys.argv[2::2]):
    key = flagKeys.get(flag)
    # Only the first character of the value is inspected, so 'yes'/'No' work.
    choice = value[:1].lower()
    # Reject unknown flags, repeated flags (which would leave another option
    # unset and crash later with a KeyError) and values that are not y/n.
    if key is None or key in opts or choice not in ('y', 'n'):
        print(instructions)
        sys.exit(1)
    opts[key] = choice
# #####################################################################
# RETRIEVE DATA
# #####################################################################
# Optionally regenerate the pickled data set before loading it.
if opts['g'] == 'y':
    DailyData.generateDataSet()
# These pickles are presumably produced locally by DailyData; never point
# Constants.dataDir at files from an untrusted source (pickle can run code).
with open(Constants.dataDir + 'dailyPrices.p', 'rb') as pricesFile:
    prices = pk.load(pricesFile)
with open(Constants.dataDir + 'dailyData.p', 'rb') as dataFile:
    data = pk.load(dataFile)
# Split the rows so the last `onlineLength` rows form the online segment.
# The online slice also keeps the `sequenceLength - 1` rows that precede
# them, presumably so the first online step has a full input window.
# Non-negative bounds are used so that onlineLength == 0 puts every row
# offline; `data[:-0]` would be empty. The bounds are clamped at 0 so
# oversized lengths behave exactly like the original negative slices.
numRows = len(data)
offlineData = data[:max(0, numRows - Constants.onlineLength)]
onlineStart = max(
    0, numRows - Constants.onlineLength - Constants.sequenceLength + 1
)
onlineData = data[onlineStart:]
numLabels = Constants.numLabels
# Every column that is not a label is treated as an input feature.
numFeatures = data.shape[1] - numLabels
# #####################################################################
# GENERATE LSTM MODEL
# #####################################################################
# NOTE(review): sequenceLength is hard-coded to 100 while the online data
# split uses Constants.sequenceLength. Confirm the two agree; changing this
# value also alters the architecture that a restore expects.
lstm = LSTM(
    numFeatures=numFeatures,
    numOutputs=numLabels,
    sequenceLength=100,
    unitsPerLayer=[250, 100],
    regularise=True
)
# Defaults used when no stored model is restored.
bestLoss = 1.0
bestEpoch = 0
if opts['r'] == 'y':
    try:
        lstm.restore()
        with open(Constants.modelDir + 'bestLoss.p', 'rb') as lossFile:
            bestLoss = pk.load(lossFile)
        with open(Constants.modelDir + 'bestEpoch.p', 'rb') as epochFile:
            bestEpoch = pk.load(epochFile)
    except Exception as err:
        # The restore backend's exception types are not visible here, so
        # catch broadly, but report the real cause instead of hiding it.
        print("""
ERROR:
Unable to restore model.
Does a stored model exist?
Have you changed the LSTM architecture?
""")
        print("Cause: {!r}".format(err))
        sys.exit(1)
    else:
        print("\nMODEL LOADED (Loss: {})".format(bestLoss))
# #####################################################################
# TRAIN MODEL
# #####################################################################
# Train on the offline split. The restored epoch/loss are passed through,
# or the defaults (0 / 1.0) when no model was restored.
runTraining = opts['t'] == 'y'
if runTraining:
    Execute.train(lstm, offlineData, bestEpoch, bestLoss)
# #####################################################################
# SIMULATE PREDICTIONS
# #####################################################################
# Run the model over the trailing online window. The loaded prices and the
# configured ticker are handed through to Execute.simulate.
runSimulation = opts['s'] == 'y'
if runSimulation:
    Execute.simulate(lstm, onlineData, prices, Constants.ticker)