A deep learning framework for generating chest X-ray images from electronic health record (EHR) data using transformer-based architectures and CLIP models.
This project implements multiple deep learning models for predicting chest X-ray images from patient EHR data using CXR-TFT, a transformer-based framework that fuses EHR and image data.
├── src/
│   ├── models/                 # Model implementations
│   │   ├── transformer.py      # Transformer model
│   │   ├── transformernn.py    # Transformer with neural network components
│   │   ├── clip.py             # CLIP model implementation
│   │   └── mlp.py              # MLP baseline model
│   ├── data/                   # Data processing
│   ├── training/               # Training logic
│   ├── configs/                # Configuration files
│   │   ├── config_tft.py       # Transformer configuration
│   │   └── config_clip.py      # CLIP configuration
│   └── utils.py                # Utility functions
├── slurmjobs/                  # HPC job scripts
├── tests/                      # Test files
├── docs/                       # Documentation
└── requirements.txt            # Dependencies
- Clone the repository:
git clone https://github.com/MehakArora/cxrgen.git
cd cxrgen
- Create and activate a virtual environment:
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies:
pip install -r requirements.txt
- Set up environment variables:
- DATA_DIR: Directory containing chest X-ray images
- CHECKPOINT_DIR: Directory for saving model checkpoints
- WANDB_API_KEY: Your Weights & Biases API key
- WANDB_LOCAL_SAVE: Local directory for W&B files
- MIMIC_CLASSIFIER: Path to MIMIC classifier model
- INTERMEDIATE_DIR: Directory for intermediate files
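The snippet below sketches one way to verify that these variables are set before launching a run. It is a hypothetical helper for convenience, not part of the repository's code.

```python
import os

# Hypothetical helper; the training scripts may read these variables differently.
REQUIRED_VARS = [
    "DATA_DIR",
    "CHECKPOINT_DIR",
    "WANDB_API_KEY",
    "WANDB_LOCAL_SAVE",
    "MIMIC_CLASSIFIER",
    "INTERMEDIATE_DIR",
]

def check_environment() -> dict:
    """Return the required environment variables, failing fast if any are unset."""
    missing = [name for name in REQUIRED_VARS if name not in os.environ]
    if missing:
        raise EnvironmentError(f"Missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in REQUIRED_VARS}

if __name__ == "__main__":
    print(check_environment())
```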
To train a model, use the following command:
python src/train.py
The training script supports different model types:
- MLP:
model_type='mlp'
- Transformer (EHR and CXR embeddings are added at the input):
model_type='transformer'
- Transformer with concatenation (EHR and CXR embeddings are concatenated at the input):
model_type='transformer_concat'
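To make the difference between the two transformer variants concrete, here is a minimal, illustrative sketch of the two input-fusion schemes. The tensor shapes are assumptions for the example; the actual layers live in src/models/ and may be implemented differently.

```python
import torch

# Assumed shapes: (batch, time steps, embedding dim)
ehr = torch.randn(8, 48, 256)  # EHR embeddings
cxr = torch.randn(8, 48, 256)  # CXR image embeddings

# model_type='transformer': embeddings are added element-wise at the input.
fused_add = ehr + cxr                         # shape (8, 48, 256)

# model_type='transformer_concat': embeddings are concatenated at the input.
fused_concat = torch.cat([ehr, cxr], dim=-1)  # shape (8, 48, 512)

print(fused_add.shape, fused_concat.shape)
```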
Model configurations can be modified in the respective config files:
- src/configs/config_tft.py for transformer models
- src/configs/config_clip.py for CLIP models
Key configuration parameters include:
- Model architecture parameters
- Training hyperparameters
- Data processing settings
- Logging and monitoring settings
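As a rough illustration only (the field names and default values below are assumptions, not the repository's actual settings), a config of this kind might be organized as follows:

```python
from dataclasses import dataclass

# Hypothetical config layout; see src/configs/config_tft.py for the real fields.
@dataclass
class TrainingConfig:
    # Model architecture parameters
    model_type: str = "transformer"
    hidden_dim: int = 512
    num_layers: int = 6
    num_heads: int = 8
    # Training hyperparameters
    batch_size: int = 32
    learning_rate: float = 1e-4
    num_epochs: int = 100
    # Data processing settings
    sequence_length: int = 48
    # Logging and monitoring settings
    wandb_project: str = "cxrgen"
    checkpoint_every: int = 1
```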
The project expects the following data structure:
- Chest X-ray images in the specified DATA_DIR
- EHR matrices in a sub-folder called longitudinal_data/ehr_matrices
- Image embeddings in a sub-folder called longitudinal_data/image_embeddings
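A minimal sketch of how these paths can be assembled, assuming the longitudinal_data sub-folders live under DATA_DIR:

```python
import os
from pathlib import Path

# Expected on-disk layout (file naming inside each folder is not shown here).
data_dir = Path(os.environ["DATA_DIR"])
ehr_dir = data_dir / "longitudinal_data" / "ehr_matrices"
img_dir = data_dir / "longitudinal_data" / "image_embeddings"

for folder in (ehr_dir, img_dir):
    print(folder, "exists:", folder.exists())
```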
- Data is split into train/validation/test sets
- Models are trained with configurable parameters
- Training progress is tracked using Weights & Biases
- Model checkpoints are saved periodically
- Best model is selected based on validation performance
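The loop below is a condensed, illustrative version of this pipeline, not the repository's actual src/train.py. It assumes the model's forward pass returns its loss and that the config exposes learning_rate and num_epochs fields.

```python
import os
import torch
import wandb

def train(model, train_loader, val_loader, config):
    wandb.init(project="cxrgen", config=vars(config))  # project name assumed
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    best_val_loss = float("inf")
    ckpt_dir = os.environ["CHECKPOINT_DIR"]

    for epoch in range(config.num_epochs):
        model.train()
        for ehr, cxr in train_loader:
            optimizer.zero_grad()
            loss = model(ehr, cxr)  # assumes the model returns its training loss
            loss.backward()
            optimizer.step()
            wandb.log({"train/loss": loss.item()})

        # Periodic checkpoint
        torch.save(model.state_dict(), os.path.join(ckpt_dir, f"epoch_{epoch}.pt"))

        # Best model is selected based on validation performance
        model.eval()
        with torch.no_grad():
            val_loss = sum(model(ehr, cxr).item() for ehr, cxr in val_loader) / len(val_loader)
        wandb.log({"val/loss": val_loss, "epoch": epoch})
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(ckpt_dir, "best_model.pt"))
```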
- PyTorch
- NumPy
- Pillow
- Weights & Biases
- Transformers
See requirements.txt for specific versions.