This repo contains a collection of scripts for building a text generator by training a recurrent neural network on a large text dataset. For this project, the New York Times Comments dataset was used to train the model. The dataset consists of 2+ million comments made on New York Times articles, collected in various time windows. The data should be downloaded from here and stored in the `data` folder as shown in the repository structure paragraph further down this README.
The text generator is created by training a recurrent neural network model on the comments. The model is created in the TensorFlow framework and is a sequential model with an input embedding layer, an LSTM (Long Short-Term Memory) hidden layer and an output layer.
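The architecture described above can be sketched in Keras as follows. The hyperparameter values here (vocabulary size, embedding dimension, LSTM units) are illustrative assumptions, not the values used in `train_rnn_model.py`:

```python
import tensorflow as tf

# Illustrative hyperparameters -- the actual values live in train_rnn_model.py
VOCAB_SIZE = 10_000  # number of distinct tokens kept by the tokenizer
EMBED_DIM = 100      # size of the learned word embeddings

model = tf.keras.Sequential([
    # Input embedding layer: maps token ids to dense vectors
    tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM),
    # LSTM hidden layer
    tf.keras.layers.LSTM(128),
    # Output layer: one softmax unit per word in the vocabulary
    tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
```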
The code pipeline consists of a training script, `train_rnn_model.py`, and a text generation script, `generate_text.py`. Furthermore, a series of helper functions are used at various steps of the pipeline. These functions are located in the `helper_functions.py` script in the `utils` folder.
The `train_rnn_model.py` script follows these steps:
- Import dependencies
- Load data
- Preprocess data
- Tokenize data
- Initialize and train RNN model
- Save tokenizer and model to the `models` folder
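The preprocessing and tokenization steps can be sketched like this. This is one common scheme for next-word prediction (n-gram prefixes with pre-padding), shown on a toy corpus; the repo's actual preprocessing lives in `helper_functions.py` and may differ:

```python
import numpy as np

# Toy corpus standing in for the NYT comments -- purely illustrative
comments = ["the article was great", "the article was awful"]

# Build a minimal word index (the repo uses a saved tokenizer for this);
# ids start at 1 so that 0 can be reserved for padding
words = sorted({w for c in comments for w in c.split()})
word_index = {w: i + 1 for i, w in enumerate(words)}

# Turn each comment into n-gram prefixes so the model can learn
# to predict the next word from the words that precede it
sequences = []
for c in comments:
    ids = [word_index[w] for w in c.split()]
    for i in range(2, len(ids) + 1):
        sequences.append(ids[:i])

# Pre-pad to a uniform length; the last token of each row is the target
max_len = max(len(s) for s in sequences)
padded = np.array([[0] * (max_len - len(s)) + s for s in sequences])
X, y = padded[:, :-1], padded[:, -1]
```

Each four-word comment yields three training pairs, so the toy corpus produces six `(X, y)` rows.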
The `generate_text.py` script follows these steps:
- Load tokenizer and model
- Generate text based on input prompt
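The generation step can be sketched as a greedy decoding loop: tokenize the prompt, predict the next word, append it, and repeat. The tokenizer state and the tiny untrained model below are stand-ins for the objects loaded from the `models` folder, and the exact decoding strategy in `generate_text.py` may differ:

```python
import numpy as np
import tensorflow as tf

# Toy tokenizer state -- in the repo this is loaded from the models folder
word_index = {"hi": 1, "how": 2, "are": 3, "you": 4}
index_word = {i: w for w, i in word_index.items()}
vocab_size = len(word_index) + 1
seq_len = 5

# Tiny untrained stand-in for the trained model
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, 8),
    tf.keras.layers.LSTM(8),
    tf.keras.layers.Dense(vocab_size, activation="softmax"),
])

def generate_text(prompt, num_words):
    """Greedy decoding: repeatedly predict the next word and append it."""
    words = prompt.lower().replace(",", "").replace("?", "").split()
    for _ in range(num_words):
        ids = [word_index.get(w, 0) for w in words][-seq_len:]
        ids = [0] * (seq_len - len(ids)) + ids  # pre-pad, as in training
        probs = model.predict(np.array([ids]), verbose=0)[0]
        next_id = int(np.argmax(probs))
        words.append(index_word.get(next_id, "<unk>"))
    return " ".join(words)

print(generate_text("Hi, how are you?", 5))
```

With an untrained model the appended words are meaningless, which is also why the one-epoch example at the bottom of this README degenerates into repeated words.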
The code is tested on Python 3.11.2. Furthermore, if your OS is not UNIX-based, a bash-compatible terminal is required for running shell scripts (such as Git for Windows). The repo was set up to work with Windows (the `WIN_` files), macOS and Linux (the `MACL_` files).
git clone https://github.com/alekswael/text_generation_with_RNNs
cd text_generation_with_RNNs
NOTE: Depending on your OS, run either `WIN_setup.sh` or `MACL_setup.sh`.
The setup script does the following:
- Creates a virtual environment for the project
- Activates the virtual environment
- Installs the correct versions of the packages required
- Deactivates the virtual environment
bash WIN_setup.sh
NOTE: Depending on your OS, run either the `WIN_` or `MACL_` version of each run script.
- Run the `*train_model.sh` script. The script does the following:
  - Activates the virtual environment
  - Runs `train_rnn_model.py` located in the `src` folder
  - Deactivates the virtual environment
bash WIN_train_model.sh
- Run the `*generate_text.sh` script. The script does the following:
  - Activates the virtual environment
  - Runs `generate_text.py` located in the `src` folder
  - Deactivates the virtual environment
bash WIN_generate_text.sh
Generating text through the bash script will default to using "Hi, how are you?" as the prompt. However, if you wish to generate text with a different prompt, this can be specified as an argument when running the `generate_text.py` script.
Further model parameters can be set through the `argparse` module. However, this requires running the Python script separately OR altering the `run*.sh` file to include the arguments. The Python script is located in the `src` folder. Make sure to activate the environment before running the Python script.
# Activate venv
source ./rnn_for_text_classification_venv/bin/activate # MacOS and Linux
source ./rnn_for_text_classification_venv/Scripts/activate # Windows
# Run the script with a different prompt
python3 ./src/generate_text.py -p "Do you think Hercules is ripped?" # MacOS and Linux
python ./src/generate_text.py -p "Do you think Hercules is ripped?" # Windows
generate_text.py [-h] [-p PROMPT] [-w NUM_WORDS]
options:
-h, --help show this help message and exit
-p PROMPT, --prompt PROMPT
The prompt to start the text generation from. (default: Hi, how are you?)
-w NUM_WORDS, --num_words NUM_WORDS
Number of words to generate after the prompt. (default: 5)
train_rnn_model.py [-h] [--epochs EPOCHS] [--batch_size BATCH_SIZE]
options:
-h, --help show this help message and exit
--epochs EPOCHS Number of epochs to train the model for. (default: 100)
--batch_size BATCH_SIZE
Batch size to use when training the model. (default: 128)
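The options above could be declared with `argparse` roughly as follows; this is a sketch matching the documented flags and defaults, not necessarily the exact code in the scripts:

```python
import argparse

# Mirror the documented interface of generate_text.py
parser = argparse.ArgumentParser(description="Generate text from a trained RNN.")
parser.add_argument("-p", "--prompt",
                    default="Hi, how are you?",
                    help="The prompt to start the text generation from.")
parser.add_argument("-w", "--num_words", type=int, default=5,
                    help="Number of words to generate after the prompt.")

# Parse an explicit argument list instead of sys.argv, for demonstration
args = parser.parse_args(["-p", "Do you think Hercules is ripped?", "-w", "3"])
print(args.prompt, args.num_words)
```

Note that prompts containing spaces must be quoted on the command line, otherwise `argparse` treats each word as a separate argument.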
This repository has the following structure:
│ .gitignore
│ MACL_generate_text.sh
│ MACL_setup.sh
│ MACL_train_model.sh
│ README.md
│ requirements.txt
│ WIN_generate_text.sh
│ WIN_setup.sh
│ WIN_train_model.sh
│
├───.github
│ .keep
│
├───data
│ .keep
│ ArticlesApril2017.csv
│ ...
│
├───models
│ .keep
│
├───src
│ generate_text.py
│ train_rnn_model.py
│
└───utils
helper_functions.py
__init__.py
NOTE: Trained for 1 epoch - this is just an example.
PROMPT: Hi, how are you?
1/1 [==============================] - 0s 382ms/step
1/1 [==============================] - 0s 29ms/step
1/1 [==============================] - 0s 27ms/step
1/1 [==============================] - 0s 24ms/step
1/1 [==============================] - 0s 26ms/step
GENERATED TEXT: Hi, How Are You? The The The The The