-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
106 lines (81 loc) · 4.05 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import cv2
from Models import MultiClassMobileNetV2, MultiClassMobileNetV3Small
from CAM import get_cam
from Preprocess import apply_clahe
import operator
from Box import find_largest_similar_rectangle, overlay_rectangles
from Augment import data_transforms
# Initialize Models
# Both networks are built on the same device and restored from local checkpoint files.
# NOTE(review): Streamlit re-executes this script on every widget interaction, so both
# checkpoints are re-read from disk on each rerun. st.cache_resource would avoid that,
# but the cached models would then be shared across reruns/sessions -- only safe if
# get_cam (CAM module) leaves no hooks or gradient state behind; confirm before caching.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _restore_weights(network, checkpoint_path):
    """Load checkpoint weights into *network* and switch it to eval mode.

    Args:
        network: Model instance already moved to ``device``.
        checkpoint_path: Path to a checkpoint whose ``'model_state_dict'`` entry
            holds the weights; tensors are mapped onto ``device`` while loading.

    Returns:
        The same *network* object, so construction and loading can be chained.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    network.load_state_dict(checkpoint['model_state_dict'])
    network.eval()
    return network


model_v2 = _restore_weights(MultiClassMobileNetV2().to(device), 'bucket/MobileNetV2_4.pth')
model_v3s = _restore_weights(MultiClassMobileNetV3Small().to(device), 'bucket/MobileNetV3Small_1.pth')
# Streamlit App
# Index i of each model's softmax output is reported as CLASS_LABELS[i].
CLASS_LABELS = ['Normal lungs', 'TBC', 'Bacterial Pneumonia', 'Viral Pneumonia', 'COVID Pneumonia']
# Forward passes averaged per model.
# NOTE(review): both models are in eval mode, run under torch.no_grad() and receive the
# same fixed tensor, so these passes are identical unless the classes in Models are
# stochastic at inference time (e.g. MC dropout) -- confirm; if not, 1 pass is equivalent.
# If test-time augmentation was intended instead, data_transforms would have to be
# applied inside the pass loop (it is applied once below).
N_INFERENCE_PASSES = 20
# Secondary classes are listed only above this probability
# (keep the "> 1%" in the heading text below in sync with it).
SECONDARY_CLASS_THRESHOLD = 0.01


def _ensemble_probabilities(models, image_tensor, n_passes=N_INFERENCE_PASSES):
    """Return soft-voting ensemble class probabilities for one input.

    Each model's softmax output (over dim 1) is averaged across ``n_passes``
    forward passes, then the per-model averages are averaged with equal weight.

    Args:
        models: Non-empty sequence of models, called in the given order on
            every pass; their outputs are treated as logits.
        image_tensor: Input batch passed unchanged to every model.
        n_passes: Number of forward passes per model; must be >= 1.

    Returns:
        Flattened NumPy array with one averaged probability per output entry.

    Raises:
        ValueError: If ``models`` is empty or ``n_passes`` < 1.
    """
    if not models:
        raise ValueError("at least one model is required")
    if n_passes < 1:
        raise ValueError(f"n_passes must be >= 1, got {n_passes}")
    samples = [[] for _ in models]
    # Inference only: no autograd graph is needed for the class probabilities.
    with torch.no_grad():
        for _ in range(n_passes):
            for model_samples, model in zip(samples, models):
                logits = model(image_tensor)
                model_samples.append(torch.softmax(logits, dim=1).cpu().numpy().flatten())
    per_model_means = [np.mean(model_samples, axis=0) for model_samples in samples]
    return np.mean(per_model_means, axis=0)


def _report_prediction(probabilities):
    """Write the top class and its confidence, then other classes above the threshold.

    Args:
        probabilities: Per-class probabilities aligned with CLASS_LABELS.

    Returns:
        The label with the highest probability.
    """
    top_index = int(np.argmax(probabilities))
    pred_label = CLASS_LABELS[top_index]
    st.write(f"Prediction: **{pred_label}**")
    st.write(f"Confidence: **{probabilities[top_index]:.4f}**")

    others = [
        (label, prob)
        for index, (label, prob) in enumerate(zip(CLASS_LABELS, probabilities))
        if index != top_index and prob > SECONDARY_CLASS_THRESHOLD
    ]
    # Skip the heading entirely when no secondary class clears the threshold.
    if others:
        st.write("Class-wise confidence levels (only showing > 1%):")
        for label, prob in others:
            st.write(f"{label}: {prob:.4f}")
    return pred_label


def _show_cam(image, image_tensor):
    """Display CAM-derived region boxes and a heatmap overlay for *image*.

    Args:
        image: Original grayscale ('L') PIL image; sets the display size and
            is the base of the overlay.
        image_tensor: Preprocessed batch that was classified; fed to get_cam.
    """
    # Average the class activation maps of both models.
    cam_v2 = get_cam(model_v2, image_tensor, target_layer_name='base_model.features.18.2')
    cam_v3s = get_cam(model_v3s, image_tensor, target_layer_name='base_model.features.12')
    combined_cam = (cam_v2 + cam_v3s) / 2

    # PIL's image.size is (width, height), which is the dsize order cv2.resize expects.
    cam_upscaled = cv2.resize(combined_cam, image.size)
    # NOTE(review): presumably get_cam yields values in [0, 1]; clipping keeps
    # 255 * (1 - cam) inside the uint8 range instead of wrapping if it does not.
    cam_upscaled = np.clip(cam_upscaled, 0, 1)
    # NOTE(review): applyColorMap returns BGR, but this array is later shown as RGB;
    # inverting the CAM appears to compensate for that channel swap so that high
    # activation still renders warm -- change both halves together, or neither.
    heatmap = cv2.applyColorMap(np.uint8(255 * (1 - cam_upscaled)), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255

    # Overlay rectangles on the original image using heatmap analysis.
    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour of
    # use_container_width; kept here for compatibility with the installed version.
    image_with_rectangles = overlay_rectangles(image, combined_cam)
    st.image(image_with_rectangles, caption='Image with highlighted regions.', use_column_width=True)

    # Blend the heatmap with the original (not CLAHE-enhanced) grayscale image,
    # broadcasting the single gray channel across the three heatmap channels.
    image_np = np.float32(np.array(image)) / 255
    cam_overlay = heatmap + np.expand_dims(image_np, axis=2)
    cam_overlay = cam_overlay / np.max(cam_overlay)
    cam_overlay_image = Image.fromarray(np.uint8(255 * cam_overlay))
    st.image(cam_overlay_image, caption='Stacked CAM overlay.', use_column_width=True)


st.title("LungInsight for X-ray Classification")
st.write("Upload an X-ray image and get the prediction with confidence levels.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert('L')
    st.write("")
    st.write("Classifying...")
    image_clahe = apply_clahe(image)
    image_tensor = data_transforms(image_clahe).unsqueeze(0).to(device)

    stacked_prob = _ensemble_probabilities((model_v2, model_v3s), image_tensor)
    pred_label = _report_prediction(stacked_prob)

    if pred_label == 'Normal lungs':
        # Show only the original image without heatmaps or bounding boxes.
        st.image(image, caption='Original Image (Normal lungs)', use_column_width=True)
    else:
        _show_cam(image, image_tensor)