# LSTM Text Generation.py
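"""Character-level text generation with stacked stateful LSTMs in TensorFlow/Keras.

Trains on the Shakespeare corpus, saves the weights, then completes a
user-supplied prompt one character at a time.
"""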
import numpy as np
import tensorflow as tf
def process_text(file_path):
    """Read a UTF-8 text file and encode it as integer character IDs."""
    with open(file_path, 'rb') as f:
        text = f.read().decode(encoding='utf-8')
    vocab = sorted(set(text))  # The unique characters in the file
    # Create a mapping from unique characters to indices and vice versa
    char2idx = {u: i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)
    text_as_int = np.array([char2idx[c] for c in text])
    return text_as_int, vocab, char2idx, idx2char
def split_input_target(chunk):
    # Shift by one character: the model learns to predict the next character.
    input_text, target_text = chunk[:-1], chunk[1:]
    return input_text, target_text
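# For example, given the chunk "Hello":
#   input_text  = "Hell"
#   target_text = "ello"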
def create_dataset(text_as_int, seq_length=100, batch_size=64, buffer_size=10000):
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    # Each sequence holds seq_length + 1 characters so that the input/target
    # split above yields pairs of length seq_length.
    dataset = char_dataset.batch(seq_length + 1, drop_remainder=True).map(split_input_target)
    dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)
    return dataset
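# Each dataset element is an (input, target) pair of shape (batch_size, seq_length).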
def build_model(vocab_size, embedding_dim=256, rnn_units=1024, batch_size=64):
    # Two stateful LSTM layers; statefulness requires a fixed batch size.
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True,
                             recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LSTM(rnn_units, return_sequences=True, stateful=True,
                             recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(vocab_size)  # logits over the vocabulary, no softmax
    ])
    return model
def loss(labels, logits):
    # from_logits=True because the final Dense layer applies no softmax.
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
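# Sanity check: with untrained weights this loss should be close to
# ln(vocab_size) (roughly 4.17 for the ~65-character Shakespeare vocabulary).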
def generate_text(model, char2idx, idx2char, start_string, generate_char_num=1000, temperature=1.0):
    # Evaluation step: generate text using the learned model.
    # Low temperatures result in more predictable text; higher temperatures
    # result in more surprising text.
    # Convert the start string to numbers (vectorize it)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []  # List to store the generated characters
    model.reset_states()
    for i in range(generate_char_num):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)  # remove the batch dimension
        predictions /= temperature
        # Sample from a categorical distribution over the logits of the last timestep
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        # Pass the predicted character as the next input to the model, along
        # with the previous hidden state (carried by the stateful LSTMs)
        input_eval = tf.expand_dims([predicted_id], axis=0)
        text_generated.append(idx2char[predicted_id])
    return start_string + ''.join(text_generated)
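# A lower temperature gives more conservative samples; for example (hypothetical call):
#   generate_text(model, char2idx, idx2char, start_string='ROMEO: ', temperature=0.5)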
# Alternative corpus:
# path_to_file = tf.keras.utils.get_file('nietzsche.txt', 'https://s3.amazonaws.com/text-datasets/nietzsche.txt')
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
text_as_int, vocab, char2idx, idx2char = process_text(path_to_file)
dataset = create_dataset(text_as_int)
model = build_model(vocab_size=len(vocab))
model.compile(optimizer='adam', loss=loss)
model.summary()
history = model.fit(dataset, epochs=50)
model.save_weights("gen_text_weights.h5", save_format='h5')
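# Optionally, weights could instead be checkpointed during training, e.g. with
# a callback along these lines (illustrative, not part of the original script):
#   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('ckpt_{epoch}.h5', save_weights_only=True)
#   model.fit(dataset, epochs=50, callbacks=[checkpoint_cb])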
# To keep the prediction step simple, rebuild the model with a batch size of 1;
# the trained weights are independent of batch size, so they load cleanly.
model = build_model(vocab_size=len(vocab), batch_size=1)
model.load_weights("gen_text_weights.h5")
model.summary()
user_input = input("Write the beginning of the text and the program will complete it. Your input: ")
generated_text = generate_text(model, char2idx, idx2char, start_string=user_input, generate_char_num=2000)
print(generated_text)