-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlearn.py
80 lines (61 loc) · 2.17 KB
/
learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import tensorflow as tf
import keras
from data_generator import DataGenerator
# Disable TensorFloat-32 execution so matrix multiplies run at full float32
# precision (TF32 would otherwise trade precision for speed on Ampere+ GPUs).
tf.config.experimental.enable_tensor_float_32_execution(False)
'''
Load and filter data
'''
def filt(sim):
    """Keep only simulations whose temperature parameter T exceeds 1.0."""
    return sim["params"]["T"] > 1.0
# Load the pickled simulation dictionary.  NOTE(review): allow_pickle=True
# means this file must come from a trusted source — np.load will unpickle
# arbitrary objects.
simData = np.load("data/LJ.npy", allow_pickle=True).item()
# Build the batch generator: density profiles ("rho") windowed by
# windowSigma, scalar parameters T and L_inv, target "c1", filtered to
# T > 1.0 via filt.  Assumes DataGenerator exposes inputShape/outputShape/
# batch_size and __getitem__/__len__ — confirm against data_generator.py.
trainingGenerator = DataGenerator(simData, batch_size=256, windowSigma=3.5, inputKeys=["rho"], paramsKeys=["T", "L_inv"], outputKeys=["c1"], filt=filt)
'''
Define model
'''
# Named inputs: the windowed density profile plus one scalar per parameter.
profileInputs = {"rho": keras.Input(shape=trainingGenerator.inputShape, name="rho")}
paramsInputs = {key: keras.Input(shape=(1,), name=key) for key in ("T", "L_inv")}
allInputs = profileInputs | paramsInputs

# Concatenate everything and pass it through four softplus layers of width 512.
hidden = keras.layers.Concatenate()(list(allInputs.values()))
for _ in range(4):
    hidden = keras.layers.Dense(512, activation="softplus")(hidden)
outputs = {"c1": keras.layers.Dense(trainingGenerator.outputShape[0], name="c1")(hidden)}

model = keras.Model(inputs=allInputs, outputs=outputs)
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanAbsoluteError()],
)
model.summary()
'''
Prepare data pipeline
'''
def gen():
    """Yield (inputs, targets) batch dicts from the training generator."""
    for i in range(len(trainingGenerator)):
        yield trainingGenerator[i]

# The signature must match what gen() yields.  The "c1" spec previously
# hard-coded a width of 1; derive it from trainingGenerator.outputShape[0]
# so it stays consistent with the model's "c1" Dense layer and with how the
# "rho" spec uses inputShape[0].  (Identical behavior when outputShape[0]
# is 1; a fix otherwise.)
# NOTE(review): fixed batch_size in the specs assumes DataGenerator never
# yields a short final batch — confirm against data_generator.py.
_batch = trainingGenerator.batch_size
train_dataset = tf.data.Dataset.from_generator(gen, output_signature=(
    {
        "rho": tf.TensorSpec(shape=(_batch, trainingGenerator.inputShape[0]), dtype=tf.float32),
        "T": tf.TensorSpec(shape=(_batch, 1), dtype=tf.float32),
        "L_inv": tf.TensorSpec(shape=(_batch, 1), dtype=tf.float32),
    },
    {
        "c1": tf.TensorSpec(shape=(_batch, trainingGenerator.outputShape[0]), dtype=tf.float32),
    }
)).prefetch(tf.data.AUTOTUNE)
'''
Do training
'''
# Exponentially decay the learning rate (factor 0.92 per epoch) and write a
# checkpoint to models/current.keras after every epoch.
trainingCallbacks = [
    keras.callbacks.LearningRateScheduler(lambda epoch, lr: 0.92 * lr),
    keras.callbacks.ModelCheckpoint(filepath="models/current.keras"),
]
model.fit(train_dataset, epochs=200, callbacks=trainingCallbacks)