Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update README.md #1

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
# zooms_classifier
# ZooMS Classifier

## Model Training and Explainability

**Note:** Access to Google Drive is required to utilize this Colab notebook.

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Ljos45IErs819W3ynY5-u9ach85y9J9G?usp=sharing)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
12,648 changes: 12,648 additions & 0 deletions ZooMS_1DCNN_model_explainability.ipynb

Large diffs are not rendered by default.

Binary file added windows_app/.DS_Store
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
Binary file not shown.
15 changes: 15 additions & 0 deletions windows_app/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# ZooMS Windows Classifer App


Please use pyinstaller to compile the app on a Windows machine.


MyApp/
|-- model/
| |-- weights.pth # 1DCNN weights (PyTorch)
|-- src/
| |-- main.py
| |-- model.py
| |-- file_ops.py
| |-- gui.py
|-- requirements.txt
Binary file added windows_app/model/model.pth
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is quite large, and It's probably better for us to distribute this in a different manner than bundled with the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you guys. In general, this size is considered very small.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but it's also about hygiene. We prefer the weights be stored separately from the code. The idea is that at runtime, the user can provide the path to the weights. If they don't, then the code should check if the weights are present in a predefined location in the package, and if they're not, then they are automatically downloaded, using a method like this:

import requests
import hashlib
import logging
import os

def download_artifact(url: str, download_path: str, expected_hash: str=None, hash_algorithm: str='sha256', chunk_size: int=2**20):
    """
    Downloads an artifact from a given URL and optionally checks its hash.
    
    Parameters
    ----------
    url : str
        The URL of the artifact.
    download_path : str
        The path where the artifact will be downloaded.
    expected_hash : str, optional
        The expected hash of the artifact. (default: `None`)
    hash_algorithm : str, optional
        The hash algorithm to use. (default: 'sha256')
    chunk_size : int, optional
        The size of the chunks to use when downloading the artifact. (default: 2**20)
    
    Returns
    -------
    None
    
    Raises:
    - ValueError: If the computed hash does not match the expected hash.
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for any errors and raise an exception if found.
    
    hasher = hashlib.new(hash_algorithm)
    with open(download_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=chunk_size):
            file.write(chunk)
            hasher.update(chunk)
    
    computed_hash = hasher.hexdigest()
    if expected_hash:
        if computed_hash != expected_hash:
            os.remove(download_path)  # Delete the downloaded file if the hash does not match.
            raise ValueError(f'Hash mismatch: expected {expected_hash}, but got {computed_hash}.')
    else:
        logging.debug(f'Computed hash: {computed_hash}')

As for where to place the weights, a simple approach is to make a release in this repo, and add the weights as an artifact to the release. Then the URL of the artifact and the hash can be specified in the code in the next commit.

In terms of how to do this, I suggest you send us the weights now through another means (e-mail, Google Drive, etc), and I can take care of setting up the artifact. For code, I only ask for now that you add placeholder code to automatically download the weights to the right place once we have the URL.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer the weights be stored separately from the code.

This is already the case; the weights are not part of the package. A user has to select weights from the drive, so we do not need to recompile the app every time. This way, I can just send Sam new updated weights, and she can still use the app.

The idea of downloading the weights is cool, but what if a user has no internet connection?

Anyway, currently, I do not have access to a Windows machine, so I won't be able to test the implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer the weights be stored separately from the code.

This is already the case; the weights are not part of the package. A user has to select weights from the drive, so we do not need to recompile the app every time. This way, I can just send Sam new updated weights, and she can still use the app.

Can you tell me a little more about this workflow then? I see there's a windows_app/model/weights.path documented in the readme, but no such file in the repo, but there is a model.path. If the weights are not part of the package, then what is this model.path file, and why is it included? When the user launches the app, must they have already downloaded weights from the drive, or is there then an interface to do that? I also don't have a windows machine to test it for myself.

RE the windows app, how keyed to windows is it? Is there a way it could be run on other platforms?

The idea of downloading the weights is cool, but what if a user has no internet connection?

Then if weights are unavailable, an error would be thrown instructing the user how to get the weights. How would this be less of an issue if the weights are stored in Google Drive though?

Binary file not shown.
3 changes: 3 additions & 0 deletions windows_app/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pandas >= 1.4.4
torch >= 1.10.2
pyinstaller >= 5.13.2
8 changes: 8 additions & 0 deletions windows_app/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cx_Freeze import setup, Executable

setup(
name="AI_ZooMS",
version="0.0.1",
description="Find homininis with AI",
executables=[Executable("src/main.py")]
)
45 changes: 45 additions & 0 deletions windows_app/src/file_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pandas as pd
import os
from datetime import datetime

def read_csv_files(directory):
file_paths = []
file_names = []
for filename in os.listdir(directory):
if filename.endswith('.csv'):
filepath = os.path.join(directory, filename)
file_paths.append(filepath)
file_names.append(filename)
return file_paths, file_names



def save_results(results, output_directory, file_names, targets):
# Column names
column_names = [
'Canidae', 'Cervidae', 'CervidaeGazellaSaiga', 'Ovis', 'Equidae',
'CrocutaPanthera', 'BisonYak', 'Capra', 'Ursidae', 'Vulpes vulpes',
'Elephantidae', 'Others', 'Rhinocerotidae', 'Rangifer tarandus', 'Hominins'
]

# Updated the dataframe creation line to handle numpy arrays
concatenated_df = pd.concat([pd.DataFrame(result) for result in results], ignore_index=True)

concatenated_df.columns = column_names
concatenated_df['Most Probable Class'] = [column_names[i] for i in targets]

# Reorder columns to place 'Most Probable Class' as the second column
cols = ['Most Probable Class'] + [col for col in concatenated_df.columns if col != 'Most Probable Class']
concatenated_df = concatenated_df[cols]

# Insert file names as the first column
concatenated_df.insert(0, 'File Name', file_names)

# Get current date and time
current_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Create a unique filename with current date and time
output_path = os.path.join(output_directory, f'results_{current_datetime}.csv')

# Save the concatenated dataframe to the unique output path
concatenated_df.to_csv(output_path, index=False)
92 changes: 92 additions & 0 deletions windows_app/src/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QLineEdit, QPushButton, QFileDialog
from PyQt5.QtGui import QIcon, QFont

def create_gui(main_func):
app = QApplication([])

window = QWidget()
window.setWindowTitle("ML-Based Mass Spectra Species Identifier")
window.setWindowIcon(QIcon('icon.png'))
window.setFixedSize(700, 350)

layoutV = QVBoxLayout()
font_label = QFont("Arial", 16, QFont.Bold)
font_button = QFont("Arial", 14)

layoutV.setContentsMargins(40, 40, 40, 40)

layoutH3 = QHBoxLayout()
model_file_path = QLineEdit(placeholderText="Select your ML model file...")
model_file_path.setFont(font_button)
browse_model_btn = QPushButton("Select Model")
browse_model_btn.setFont(font_button)
label3 = QLabel("Model File:")
label3.setFont(font_label)
layoutH3.addWidget(label3)
layoutH3.addWidget(model_file_path)
layoutH3.addWidget(browse_model_btn)
layoutV.addLayout(layoutH3)

layoutH1 = QHBoxLayout()
input_directory = QLineEdit(placeholderText="Select a directory with CSV files...")
input_directory.setFont(font_button)
browse_input_btn = QPushButton("Browse")
browse_input_btn.setFont(font_button)
label1 = QLabel("Input Directory:")
label1.setFont(font_label)
layoutH1.addWidget(label1)
layoutH1.addWidget(input_directory)
layoutH1.addWidget(browse_input_btn)
layoutV.addLayout(layoutH1)

layoutH2 = QHBoxLayout()
output_directory = QLineEdit(placeholderText="Select a directory for results...")
output_directory.setFont(font_button)
browse_output_btn = QPushButton("Browse")
browse_output_btn.setFont(font_button)
label2 = QLabel("Output Directory:")
label2.setFont(font_label)
layoutH2.addWidget(label2)
layoutH2.addWidget(output_directory)
layoutH2.addWidget(browse_output_btn)
layoutV.addLayout(layoutH2)

classify_btn = QPushButton("Classify")
classify_btn.setFont(font_button)
layoutV.addWidget(classify_btn)

window.setStyleSheet("""
QWidget {
background-color: #fafafa;
font-size: 18px;
color: #333;
}
QPushButton {
background-color: #11a611; /* Green */
color: white;
border: none;
border-radius: 10px;
padding: 14px 28px;
}
QPushButton:pressed {
background-color: #005900; /* Darker green on click */
}
QLineEdit {
background-color: #fff;
border: 1px solid #ccc;
border-radius: 10px;
padding: 14px;
}
""")

browse_model_btn.clicked.connect(lambda: model_file_path.setText(QFileDialog.getOpenFileName()[0]))
browse_input_btn.clicked.connect(lambda: input_directory.setText(QFileDialog.getExistingDirectory()))
browse_output_btn.clicked.connect(lambda: output_directory.setText(QFileDialog.getExistingDirectory()))
classify_btn.clicked.connect(lambda: main_func(model_file_path.text(), input_directory.text(), output_directory.text()))

window.setLayout(layoutV)
window.show()
app.exec_()

if __name__ == "__main__":
create_gui(main)
74 changes: 74 additions & 0 deletions windows_app/src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from file_ops import read_csv_files, save_results
from gui import create_gui
import torch
from PyQt5.QtWidgets import QMessageBox
import numpy as np
import pandas as pd
import os
from model import CNN1D

def mean_intensity(temp_df, bin_resolution=0.5):
bins = np.arange(899.9, 3500, bin_resolution)
temp_df['bin'] = pd.cut(temp_df['mass'], bins=bins)
return temp_df.groupby('bin')['intensity'].mean().values

def normalize(tensor):
tensor[torch.isnan(tensor)] = 0
mean = tensor.mean()
std = tensor.std()
return (tensor - mean) / (std + torch.finfo(torch.float32).eps)

def load_model(weight_path):
model = torch.load(weight_path, map_location=torch.device('cpu'))
model.eval()
return model

def main(model_file_path, input_directory, output_directory):
if not model_file_path or not input_directory or not output_directory:
show_missing_paths_message()
return

model = load_model(model_file_path)
file_paths, file_names = read_csv_files(input_directory)
results, file_names, targets = make_predictions(file_paths, model)
save_results(results, output_directory, file_names, targets)
show_done_message()

def show_missing_paths_message():
msg = QMessageBox()
msg.setIcon(QMessageBox.Critical)
msg.setWindowTitle("Missing Path")
msg.setText("Please provide all required paths (Model, Input Directory, Output Directory).")
msg.exec_()

def show_done_message():
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Process Completed")
msg.setText("The classification process is complete.")
msg.exec_()

def make_predictions(file_paths, model):
results = []
file_names = []
targets = []
for i, file_path in enumerate(file_paths):
temp_df = pd.read_csv(file_path)
file_name = os.path.basename(file_path)
file_names.append(file_name)

intensities = mean_intensity(temp_df)
tensor_data = torch.tensor(intensities, dtype=torch.float32)
tensor_data = normalize(tensor_data)

output = model(tensor_data.unsqueeze(0).unsqueeze(0))
probabilities = torch.softmax(output, dim=1).detach().numpy().round(3)
results.append(probabilities)

target = np.argmax(probabilities)
targets.append(target)

return results, file_names, targets

if __name__ == "__main__":
create_gui(main)
43 changes: 43 additions & 0 deletions windows_app/src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn

class CNN1D(nn.Module):
def __init__(self, input_size, num_classes):
super(CNN1D, self).__init__()

self.conv1 = nn.Conv1d(1, 32, kernel_size=5, stride=1)

self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1)
#self.bn2 = nn.BatchNorm1d(32)

self.pool = nn.AvgPool1d(kernel_size=3)

output_size = (input_size - 5 + 1) // 3 # After conv1 and pool
output_size = (output_size - 5 + 1) // 3 # After conv2 and pool

self.fc1 = nn.Linear(64 * output_size, 128)
self.dropout1 = nn.Dropout(0.25)

self.fc2 = nn.Linear(128, num_classes)

self.relu = nn.ReLU()

def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))

x = x.view(x.size(0), -1)

x = self.relu(self.fc1(x))
x = self.dropout1(x)

x = self.fc2(x)
return x


def load_model(weight_path):
model = torch.load(weight_path, map_location=torch.device('cpu'))
model.eval()
return model


84 changes: 84 additions & 0 deletions windows_app/test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "55bd4cb8-4bf1-48f8-bcb3-11375f947e31",
"metadata": {},
"outputs": [],
"source": [
"import tkinter as tk\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ab248bcb-4c77-442b-a12d-9bb89965d029",
"metadata": {},
"outputs": [],
"source": [
"from cx_Freeze import setup, Executable\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bc3e8b53-753d-4d63-9f31-ae485850acbd",
"metadata": {},
"outputs": [],
"source": [
"import torch, pandas "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "74fb7e63-4b89-43ce-a98d-06871d43f416",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.12.1+cu116', '1.5.0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__, pandas.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20a31617-7550-424c-8119-fd76d66c1cfa",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions windows_app/untitled.md
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@