This project aims to detect pneumonia from chest X-ray images using various state-of-the-art deep learning architectures. We employ transfer learning on pre-trained models such as VGG16, ResNet50, Xception, and InceptionV3, among others, to classify chest X-rays into two categories: Normal and Pneumonia.
We use the publicly available Chest X-Ray dataset from Kaggle. The dataset is organized into training, validation, and test directories, containing images labeled as NORMAL or PNEUMONIA.
Dataset structure:
- `train/`: Contains training images.
- `val/`: Contains validation images.
- `test/`: Contains test images.
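Before training, it helps to confirm that the dataset was extracted into the expected layout and to see how many images each class has in each split. The sketch below assumes that every split folder contains one sub-folder per label (`NORMAL/` and `PNEUMONIA/`) and that images use `.jpeg`, `.jpg`, or `.png` extensions. Both are assumptions, and `count_images` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Assumption: each split folder holds one sub-folder per class label.
SPLITS = ("train", "val", "test")
CLASSES = ("NORMAL", "PNEUMONIA")
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}  # assumption: image file extensions


def count_images(data_dir="data"):
    """Return {split: {class: number of image files}} for the assumed layout."""
    counts = {}
    for split in SPLITS:
        counts[split] = {}
        for label in CLASSES:
            folder = Path(data_dir) / split / label
            if folder.is_dir():
                n = sum(
                    1
                    for p in folder.iterdir()
                    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
                )
            else:
                n = 0  # a missing folder is reported as empty instead of raising
            counts[split][label] = n
    return counts
```

Calling `count_images("data")` after extracting the archive gives a quick check that the folders are in place and shows the class balance of each split.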
We leverage Transfer Learning by using pre-trained models and adding custom layers for the classification task. The following architectures have been implemented:
- VGG16 🧑‍💻
- ResNet50 🧑‍🔬
- InceptionV3 🧙‍♂️
- MobileNetV2 🦸‍♀️
Each architecture is loaded with ImageNet weights, and the final layers are customized to handle the binary classification task (Normal vs Pneumonia).
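A minimal sketch of this transfer-learning setup with `tf.keras.applications`. The README does not specify the custom head, so the layers used here (global average pooling, a 128-unit dense layer, 0.5 dropout, and a single sigmoid output) are assumptions, as are freezing the backbone and the 224×224 input size. `build_model` and `BACKBONES` are hypothetical names. `Xception` is included because the training command uses it.

```python
import tensorflow as tf

# Hypothetical mapping from architecture names to Keras application constructors.
BACKBONES = {
    "VGG16": tf.keras.applications.VGG16,
    "ResNet50": tf.keras.applications.ResNet50,
    "InceptionV3": tf.keras.applications.InceptionV3,
    "MobileNetV2": tf.keras.applications.MobileNetV2,
    "Xception": tf.keras.applications.Xception,
}


def build_model(name, input_shape=(224, 224, 3), weights="imagenet"):
    """Pre-trained convolutional base (frozen) plus a small binary classification head."""
    # include_top=False drops the 1000-class ImageNet classifier.
    base = BACKBONES[name](include_top=False, weights=weights, input_shape=input_shape)
    base.trainable = False  # assumption: keep ImageNet features fixed at first

    inputs = tf.keras.Input(shape=input_shape)
    x = base(inputs, training=False)  # run BatchNorm layers in inference mode
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(128, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    # One sigmoid unit: the output is the predicted probability of PNEUMONIA.
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs, name=f"{name}_pneumonia")
```

With the base frozen, only the head's weights are updated during training. A common follow-up is to unfreeze the top backbone layers and fine-tune them at a lower learning rate, but the README does not say whether this project does so.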
- Data augmentation with `ImageDataGenerator` is applied to reduce overfitting and improve generalization.
- Early stopping ends training once validation performance stops improving, so the model does not overtrain.
- Adam Optimizer with a learning rate of 0.0001 is used to train the models.
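The three training choices above can be combined as in the sketch below. The specific augmentation ranges, the `val_loss` monitor, `patience=3`, the image size, the batch size, and the epoch count are all assumptions, not values taken from the project. Horizontal flips are left out on purpose because chest X-rays are not left/right symmetric. Rescaling by 1/255 stands in for each backbone's own `preprocess_input` as a simplification. `make_generators`, `compile_for_training`, and `train` are hypothetical names.

```python
import tensorflow as tf

IMG_SIZE = (224, 224)  # assumption: matches the backbone input shape
BATCH_SIZE = 32        # assumption
ImageDataGenerator = tf.keras.preprocessing.image.ImageDataGenerator


def make_generators():
    """Augment training images only; validation/test images are just rescaled."""
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=10,       # small random rotations (degrees)
        zoom_range=0.1,
        width_shift_range=0.1,
        height_shift_range=0.1,
    )
    eval_datagen = ImageDataGenerator(rescale=1.0 / 255)
    return train_datagen, eval_datagen


def compile_for_training(model, learning_rate=1e-4):
    """Compile with Adam (lr = 0.0001) and return an early-stopping callback list."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")],
    )
    # Stop when validation loss has not improved for 3 epochs and keep the best weights.
    return [tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True)]


def train(model, data_dir="data", epochs=30):
    """Fit on data/train, validate on data/val; returns the Keras History object."""
    train_datagen, eval_datagen = make_generators()
    flow_args = dict(target_size=IMG_SIZE, batch_size=BATCH_SIZE,
                     class_mode="binary", classes=["NORMAL", "PNEUMONIA"])  # NORMAL=0, PNEUMONIA=1
    train_gen = train_datagen.flow_from_directory(f"{data_dir}/train", **flow_args)
    val_gen = eval_datagen.flow_from_directory(f"{data_dir}/val", shuffle=False, **flow_args)
    callbacks = compile_for_training(model)
    return model.fit(train_gen, validation_data=val_gen, epochs=epochs, callbacks=callbacks)
```

The `History` object returned by `train` exposes `history.history['accuracy']` and `history.history['val_accuracy']`, which are the series plotted in the accuracy curves.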
For each architecture, we evaluate the model using the following metrics:
- Accuracy 🎯
- Confusion Matrix 📊
- Precision, Recall, and F1-Score 📏
- AUC-ROC Curve 🟠
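The AUC-ROC curve and confusion matrix are plotted as follows:

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import auc, confusion_matrix, roc_curve
fpr, tpr, _ = roc_curve(y_true, y_pred_probs)  # y_pred_probs: probabilities for the Pneumonia class
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.legend(loc='lower right')
plt.figure()  # new figure so the heatmap is not drawn over the ROC curve
conf_matrix = confusion_matrix(y_true, y_pred_classes)
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
```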
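The plotting code expects `y_true`, `y_pred_probs`, and `y_pred_classes` to exist already. The sketch below shows one way to derive `y_pred_classes` by thresholding the predicted probabilities and to compute accuracy, precision, recall, and F1 with scikit-learn. The 0.5 threshold is an assumption. Moving it trades precision against recall. `threshold_metrics` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def threshold_metrics(y_true, y_pred_probs, threshold=0.5):
    """Convert Pneumonia probabilities to 0/1 labels and compute summary metrics."""
    y_true = np.asarray(y_true)
    # 1 = PNEUMONIA, 0 = NORMAL; threshold=0.5 is an assumption.
    y_pred_classes = (np.asarray(y_pred_probs) >= threshold).astype(int)
    metrics = {
        "accuracy": accuracy_score(y_true, y_pred_classes),
        # precision/recall/F1 treat label 1 (PNEUMONIA) as the positive class
        "precision": precision_score(y_true, y_pred_classes),
        "recall": recall_score(y_true, y_pred_classes),
        "f1": f1_score(y_true, y_pred_classes),
    }
    return y_pred_classes, metrics
```

With a Keras `flow_from_directory` test generator created with `shuffle=False`, `y_true` can be taken from the generator's `classes` attribute and `y_pred_probs` from `model.predict(test_gen).ravel()`. Without `shuffle=False`, the predictions would not line up with the labels.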
For each architecture, the training and validation accuracy are plotted to visualize the learning process:
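```python
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'], label='accuracy')  # history: returned by model.fit
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.legend()
```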
- Clone the repository:
  ```bash
  git clone https://github.com/yourusername/pneumonia-detection.git
  cd pneumonia-detection
  ```
- Install the required dependencies:
  ```bash
  pip install -r requirements.txt
  ```
- Download the dataset from Kaggle and extract it to the `data/` folder.
- Train the model:
  ```bash
  python train.py --model Xception
  ```
- Evaluate the model:
  ```bash
  python evaluate.py --model Xception
  ```
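The source of `train.py` and `evaluate.py` is not included here. The sketch below is a hypothetical example of how a `--model` flag like the one in these commands could be parsed with the standard-library `argparse` module, so that misspelled architecture names are rejected before any training starts. The list of accepted names is an assumption based on the architectures mentioned in this README.

```python
import argparse

# Hypothetical: accepted names assumed to match the architectures in this README.
MODEL_CHOICES = ["VGG16", "ResNet50", "InceptionV3", "MobileNetV2", "Xception"]


def parse_args(argv=None):
    """Parse command-line arguments; argv=None reads them from sys.argv."""
    parser = argparse.ArgumentParser(description="Train a pneumonia classifier.")
    parser.add_argument(
        "--model",
        choices=MODEL_CHOICES,  # argparse exits with an error for any other value
        required=True,
        help="Pre-trained backbone to use",
    )
    return parser.parse_args(argv)
```

Using `choices` means an unknown name such as `--model AlexNet` produces a usage error and exit code 2, rather than a failure partway through loading the data.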
| Model | Accuracy | Precision | Recall | F1-Score | AUC |
|---|---|---|---|---|---|
| VGG16 | 94.5% | 93.2% | 95.0% | 94.1% | 0.96 |
| ResNet50 | 95.1% | 94.6% | 95.7% | 95.1% | 0.97 |
| InceptionV3 | 95.8% | 94.9% | 96.0% | 95.4% | 0.97 |
| MobileNetV2 | 94.0% | 92.8% | 94.2% | 93.5% | 0.95 |
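F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). So the F1-Score column can be checked against the Precision and Recall columns. The short check below uses only the values from the results table and reproduces the reported F1 scores when rounded to one decimal.

```python
def f1_from_precision_recall(precision, recall):
    """Harmonic mean of precision and recall (same units as the inputs)."""
    return 2 * precision * recall / (precision + recall)


# Precision and recall (%) copied from the results table.
RESULTS = {
    "VGG16": (93.2, 95.0),
    "ResNet50": (94.6, 95.7),
    "InceptionV3": (94.9, 96.0),
    "MobileNetV2": (92.8, 94.2),
}

for name, (p, r) in RESULTS.items():
    print(f"{name}: F1 = {f1_from_precision_recall(p, r):.1f}%")
```

The printed values are 94.1, 95.1, 95.4, and 93.5, which match the table. Accuracy and AUC cannot be recovered this way because they also depend on the true negatives and on the full ranking of predicted probabilities.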
This project is licensed under the MIT License.