forked from pmuens/alphago
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathalphago_policy_sl.py
37 lines (29 loc) · 1.31 KB
/
alphago_policy_sl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import h5py
from keras.callbacks import ModelCheckpoint
from dlgo.data.parallel_processor import GoDataProcessor
from dlgo.encoders.alphago import AlphaGoEncoder
from dlgo.agent.predict import DeepLearningAgent
from dlgo.networks.alphago import alphago_model
# Supervised-learning training script for the AlphaGo-style policy network.
#
# Builds the policy model, trains it from generator-fed Go game data, writes a
# checkpoint per epoch, and finally serializes the trained model wrapped in a
# DeepLearningAgent to 'alphago_sl_policy.h5'.

# Board geometry; the network predicts one class per board point.
rows, cols = 19, 19
num_classes = rows * cols
num_games = 10000

# The data processor is configured by encoder *name*, while the model and the
# agent below use the encoder instance itself.
encoder = AlphaGoEncoder()
processor = GoDataProcessor(encoder=encoder.name())

# use_generator=True yields generator objects (exposing generate() and
# get_num_samples()) instead of fully materialized arrays.
generator = processor.load_go_data('train', num_games, use_generator=True)
test_generator = processor.load_go_data('test', num_games, use_generator=True)

# NOTE(review): plane-first shape; presumably the model expects
# channels-first input -- confirm against dlgo.networks.alphago.
input_shape = (encoder.num_planes, rows, cols)
alphago_sl_policy = alphago_model(input_shape, is_policy_net=True)
alphago_sl_policy.compile(
    'sgd',
    'categorical_crossentropy',
    metrics=['accuracy'],
)

epochs = 200
batch_size = 128

# Keras documents step counts as integers, so use floor division rather than
# true division (which produces a float on Python 3). Clamp to at least one
# step so a data set smaller than a single batch still trains/validates.
# NOTE(review): assumes get_num_samples() returns an int -- confirm.
train_steps = max(1, generator.get_num_samples() // batch_size)
validation_steps = max(1, test_generator.get_num_samples() // batch_size)

# fit_generator is a method of the model. Its return value is deliberately not
# assigned back to alphago_sl_policy: the name must keep referring to the
# model, which is handed to DeepLearningAgent below.
# NOTE(review): newer Keras releases deprecate fit_generator in favour of
# fit -- confirm against the Keras version this project pins.
alphago_sl_policy.fit_generator(
    generator=generator.generate(batch_size, num_classes),
    epochs=epochs,
    steps_per_epoch=train_steps,
    validation_data=test_generator.generate(batch_size, num_classes),
    validation_steps=validation_steps,
    # '{epoch}' in the file name is filled in by the checkpoint callback.
    callbacks=[ModelCheckpoint('alphago_sl_policy_{epoch}.h5')],
)

# Persist the trained policy together with its encoder as an agent.
alphago_sl_agent = DeepLearningAgent(alphago_sl_policy, encoder)
with h5py.File('alphago_sl_policy.h5', 'w') as sl_agent_out:
    alphago_sl_agent.serialize(sl_agent_out)