This repository contains a PyTorch-based deep learning model that generates images from textual descriptions. The model combines LSTM-based text encoding, multi-head self-attention, and convolutional upsampling to turn input text sequences into images.
Combining LSTM with self-attention allows the model to leverage the strengths of both architectures while compensating for their weaknesses.
- LSTMs process text sequentially, which means they are good at capturing local dependencies and order-sensitive information.
- They handle long-term dependencies better than vanilla RNNs, but they still struggle with very long sequences because all context must pass through a fixed-size hidden state.
- Self-attention (like in Transformers) allows the model to look at the entire sequence at once instead of processing it step by step.
- This helps capture long-range dependencies and relationships between words that might be far apart.
- Unlike LSTMs, self-attention does not suffer from vanishing gradients over long sequences.
By combining an LSTM with self-attention, we get:
✅ LSTM for local context & sequential understanding
✅ Self-attention for capturing global dependencies
- The LSTM processes the input sequentially, learning relationships between neighboring tokens.
- Then, the self-attention mechanism (via `MultiHeadSelfAttention`) refines these representations by looking at all words at once and adjusting their importance.
- This makes the text encoding more informative before it is mapped to image features.
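As an illustration, a minimal encoder that combines the two pieces might look like the sketch below. It uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the repository's `MultiHeadSelfAttention` module, and the dimensions are arbitrary defaults rather than the actual configuration:

```python
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    """LSTM for local, order-sensitive context, followed by multi-head
    self-attention for global dependencies (illustrative sketch)."""

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=256, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The LSTM reads the sequence token by token, capturing local order.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Self-attention re-weights every position against all others at once.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

    def forward(self, token_ids):
        # token_ids: (batch, seq_len) integer tensor
        x = self.embedding(token_ids)     # (batch, seq_len, embed_dim)
        x, _ = self.lstm(x)               # (batch, seq_len, hidden_dim)
        x, _ = self.attention(x, x, x)    # refine with global context
        return x.mean(dim=1)              # pooled sentence representation

# Example: encode a batch of 2 sequences of 12 token ids.
features = TextEncoder(vocab_size=5000)(torch.randint(0, 5000, (2, 12)))
```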
- Text Normalization & Tokenization: Preprocesses text inputs by normalizing and tokenizing.
- Custom Dataset Class: Handles loading image-text pairs (a rough sketch, together with the augmentation transforms, follows this list).
- Data Augmentation: Includes color jittering and Gaussian noise.
- Residual Blocks & Self-Attention: Enhances image generation quality.
- Training & Validation Pipeline: Supports training with loss tracking.
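A rough sketch of the dataset and augmentation pieces is shown below; the caption format, image resolution, and noise level are assumptions for illustration rather than the repository's exact implementation:

```python
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class AddGaussianNoise:
    """Adds zero-mean Gaussian noise to a tensor image (assumed std)."""
    def __init__(self, std=0.05):
        self.std = std
    def __call__(self, img):
        return img + torch.randn_like(img) * self.std

# Color jitter and Gaussian noise augmentation, as described above.
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),  # assumed training resolution
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    AddGaussianNoise(std=0.05),
])

class TextImageDataset(Dataset):
    """Loads (caption, image) pairs; the pairing scheme is illustrative."""
    def __init__(self, pairs, transform=None):
        # pairs: list of (image_path, caption_string) tuples
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path, caption = self.pairs[idx]
        image = Image.open(path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return caption, image
```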
To set up the required dependencies, use the provided Anaconda environment:
Save the following YAML as `environment.yaml`:
```yaml
name: text_to_image_env
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.8
  - pytorch
  - torchvision
  - torchaudio
  - tqdm
  - pillow
  - pip
  - pip:
      - numpy
      - matplotlib
```
Create and activate the environment:
```bash
conda env create -f environment.yaml
conda activate text_to_image_env
```
Alternatively, install dependencies using pip:
```bash
pip install torch torchvision torchaudio tqdm pillow numpy matplotlib
```
Place your training images in a directory of your choice and provide a matching textual description for each image.
Run the training script:
```bash
python train.py
```
After training, generate images from new text inputs by running:
```bash
python generate.py --input "your text here"
```
The model follows a Text-to-Image Pipeline (the final stages are sketched after the list):
- Text Embedding: Tokenized text is passed through an LSTM.
- Multi-Head Self-Attention: Captures textual dependencies.
- Fully Connected Layer: Maps text features to image features.
- Convolutional Decoder: Upsamples features to reconstruct an image.
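For illustration, the last two stages (fully connected mapping and convolutional decoder) could be sketched as below, taking a pooled text feature such as the encoder sketch above produces; all layer sizes and the 64×64 output resolution are assumptions, not the repository's exact architecture:

```python
import torch
import torch.nn as nn

class ImageDecoder(nn.Module):
    """Maps a pooled text feature to an image through a fully connected
    layer and transposed-convolution upsampling (illustrative sizes)."""

    def __init__(self, text_dim=256, base_channels=128):
        super().__init__()
        self.base_channels = base_channels
        # Fully connected layer: text features -> an 8x8 spatial feature map.
        self.fc = nn.Linear(text_dim, base_channels * 8 * 8)
        # Convolutional decoder: upsample 8x8 -> 16x16 -> 32x32 -> 64x64.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(base_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # image values in [-1, 1]
        )

    def forward(self, text_feature):
        # text_feature: (batch, text_dim), e.g. the pooled encoder output
        x = self.fc(text_feature).view(-1, self.base_channels, 8, 8)
        return self.decoder(x)

# Example: one 256-dim text feature -> one 3x64x64 image tensor.
image = ImageDecoder()(torch.randn(1, 256))
```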
The trained model outputs generated images based on input text. Example results can be found in the `results_short/` directory.
To save the trained model:
```bash
python save_model.py
```
To load and use the trained model:
```python
import torch
from model import TextToImageModel

# Instantiate the model and load the trained weights.
model = TextToImageModel()
model.load_state_dict(torch.load("results_short/text_to_image_model.pth"))
model.eval()  # switch to evaluation mode for inference
```
This project is open-source under the MIT License.
Feel free to open an issue or a pull request for improvements!