diff --git a/.github/workflows/CD.yml b/.github/workflows/CD.yml
new file mode 100644
index 0000000..21e5239
--- /dev/null
+++ b/.github/workflows/CD.yml
@@ -0,0 +1,62 @@
+name: CD
+
+on:
+  push:
+    branches:
+      - main
+      - clavrat/proto-crepe
+    tags:
+      - '*.*.*'  # Adjust this pattern to match your tag format
+
+jobs:
+  build_and_release:
+    name: Build and Upload Release
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.12'
+
+      - name: Install Poetry
+        run: |
+          pip install poetry
+
+      - name: Build the package
+        run: |
+          poetry build
+
+      # Everything below only runs for tag pushes. On a branch push
+      # github.ref is refs/heads/<branch>, so there is no tag to release;
+      # those runs stop after verifying that the package builds.
+      - name: Create Release
+        id: create_release
+        if: startsWith(github.ref, 'refs/tags/')
+        uses: actions/create-release@v1.0.0
+        env:
+          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
+        with:
+          # ref_name is the short tag (e.g. 1.2.3), without the refs/tags/ prefix.
+          tag_name: ${{ github.ref_name }}
+          release_name: Release ${{ github.ref_name }}
+          draft: false
+          prerelease: false
+
+      # Export the path and file name of the built wheel for the upload step.
+      - name: Get Name of Artifact
+        if: startsWith(github.ref, 'refs/tags/')
+        run: |
+          ARTIFACT_PATHNAME=$(ls dist/*.whl | head -n 1)
+          ARTIFACT_NAME=$(basename "$ARTIFACT_PATHNAME")
+          echo "ARTIFACT_PATHNAME=${ARTIFACT_PATHNAME}" >> "$GITHUB_ENV"
+          echo "ARTIFACT_NAME=${ARTIFACT_NAME}" >> "$GITHUB_ENV"
+
+      - name: Upload Whl to Release Assets
+        id: upload-release-asset
+        if: startsWith(github.ref, 'refs/tags/')
+        uses: actions/upload-release-asset@v1.0.2
+        env:
+          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
+        with:
+          upload_url: ${{ steps.create_release.outputs.upload_url }}
+          asset_path: ${{ env.ARTIFACT_PATHNAME }}
+          asset_name: ${{ env.ARTIFACT_NAME }}
+          asset_content_type: application/x-wheel+zip
\ No newline at end of file
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
new file mode 100644
index 0000000..bb4aa2a
--- /dev/null
+++ b/.github/workflows/CI.yml
@@ -0,0 +1,28 @@
+name: CI
+
+on:
+  - push
+  - pull_request
+
+jobs:
+  linter:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      # Only the selected error classes (E9, F63, F7, F82) fail the build.
+      # Disabled advisory pass, kept for reference:
+      #   flake8 ./crepe --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+      - name: Lint with flake8
+        run: |
+          pip install flake8
+          flake8 ./crepe --count --select=E9,F63,F7,F82 --show-source --statistics
+
+  doc_coverage:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Check documentation coverage with pydocstyle
+        run: |
+          pip install pydocstyle
+          pydocstyle ./crepe
diff --git a/.gitignore b/.gitignore
index 82f9275..eaeb781 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+dataset/*
+wandb/*
+.prompts/*
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index 261eeb9..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/README.md b/README.md
index 2dc6323..3fa40cf 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,76 @@
-# TorchCrepe
-Implementation of the crepe pitch extractor in PyTorch
+
+The Torch-CREPE project re-implements the CREPE pitch estimation model in PyTorch, enabling its optimization and adaptation for real-time voice pitch detection tasks. By re-implementing this deep learning-based system, we unlock new research possibilities for music signal processing and audio analysis applications.
+
+## How it Works
+
+The **PyTorch CREPE** implementation uses the **Torch** and **Torchaudio** libraries to process and analyze audio signals. The project's core functionality is based on the CREPE model, which estimates fundamental frequencies from audio data.
+
+The model achieves this by classifying each 1024-sample audio frame (64 ms at 16 kHz) into one of 360 classes, each representing a pitch bin, in cents, within the range of observable fundamental frequencies.
+
+## Features
+
+- **Real-time pitch detection:** Processing is done in real time using the provided script.
+- **Optimized for instruments and voices:** Trained on instruments and voices to cover as many use cases as possible.
+- **Deep learning-based:** Fully implemented in PyTorch.
+- **Fast integration** with the Torchaudio library.
+- **Trainable on a consumer GPU** (a complete training run was done on an RTX 3080).
+
+## Run app locally
+
+To run the PyTorch CREPE demo locally, you can use the following Python code:
+
+```py
+import torchaudio
+from crepe.model import Crepe
+from crepe.utils import load_test_file
+
+crepe = Crepe(model_capacity="tiny", device='cpu')
+
+audio, sr = load_test_file()
+
+time, frequency, confidence, activation = crepe.predict(
+ audio=audio,
+ sr = sr
+)
+```
+
+## Python API
+
+For detailed documentation of the PyTorch CREPE implementation, including the API and usage guidelines, please refer to [this link].
+
+## Train
+
+The larger models are still in my training queue, so only the 'tiny' version of **Crepe** has been trained so far.
+
+## Datasets
+
+[MIR-1K](http://mirlab.org/dataset/public/MIR-1K.zip)
+
+## Contributing
+
+This project is an open-source project, and contributions are always welcome. If you would like to contribute to the project, you can do so by submitting a pull request or by creating an issue on the project's GitHub page.
+
+## License
+
+This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
diff --git a/TorchCREPE_banner.png b/TorchCREPE_banner.png
new file mode 100644
index 0000000..0a1b107
Binary files /dev/null and b/TorchCREPE_banner.png differ
diff --git a/crepe/crepe-tiny.pth b/crepe/crepe-tiny.pth
new file mode 100644
index 0000000..98a6a39
Binary files /dev/null and b/crepe/crepe-tiny.pth differ
diff --git a/crepe/dataset.py b/crepe/dataset.py
new file mode 100644
index 0000000..1855a71
--- /dev/null
+++ b/crepe/dataset.py
@@ -0,0 +1,332 @@
+"""This file contains various dataloader for training and processing audio data and labels."""
+
+from torch.utils.data import Dataset
+import os
+import torch
+import torchaudio
+import glob
+import json
+from torch.utils.data import Dataset, ConcatDataset
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+
+class MIR1KDataset(Dataset):
+    """
+    MIR-1K dataset: audio clips paired with pitch-label files.
+
+    Audio is globbed from ``<root_dir>/Wavfile/*.wav`` and labels from
+    ``<root_dir>/PitchLabel/*.pv``; each clip is matched to the label file
+    that shares its file stem.
+    """
+
+    def __init__(self, root_dir):
+        """
+        Index the audio files under ``root_dir`` and pair them with labels.
+
+        Args:
+            root_dir (str): The root directory containing the MIR-1K dataset.
+
+        Raises:
+            ValueError: If an audio file has no label file with the same stem.
+        """
+        self.root_dir = root_dir
+        self.files = sorted(glob.glob(
+            os.path.join(self.root_dir, "Wavfile", "*.wav")))
+        label_by_stem = {self._stem(path): path for path in glob.glob(
+            os.path.join(self.root_dir, "PitchLabel", "*.pv"))}
+
+        # Pair by stem rather than by sorted position: with two independently
+        # sorted lists, one missing or extra file would silently shift every
+        # following label onto the wrong clip.
+        unlabeled = [
+            path for path in self.files if self._stem(path) not in label_by_stem]
+        if unlabeled:
+            raise ValueError(
+                f"No pitch label found for {len(unlabeled)} audio file(s), "
+                f"e.g. {unlabeled[0]}")
+        self.labels = [label_by_stem[self._stem(path)] for path in self.files]
+
+    @staticmethod
+    def _stem(path):
+        """Return the file name of ``path`` without directory or extension."""
+        return os.path.splitext(os.path.basename(path))[0]
+
+    def __len__(self):
+        """
+        Retrieve the total number of samples in the dataset.
+
+        Returns:
+            int: The number of indexed audio files.
+        """
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        """
+        Load one clip and its pitch labels.
+
+        Args:
+            idx (int): The index of the sample to retrieve.
+
+        Returns:
+            tuple: ``(audio, labels)``. ``audio`` is the 1-D waveform of channel
+                index 1 of the file; ``labels`` is a 1-D tensor holding one float
+                per non-blank line of the label file.
+        """
+        audio, _ = torchaudio.load(os.path.abspath(self.files[idx]))
+
+        with open(self.labels[idx], 'r') as f:
+            # Skip blank lines so a stray empty line cannot crash float().
+            labels = [float(line) for line in f if line.strip()]
+
+        # NOTE(review): channel 1 is presumably the vocal track of the stereo
+        # MIR-1K recordings -- confirm against the dataset documentation.
+        return audio[1, :], torch.tensor(labels)
+
+
+class Back10Dataset(Dataset):
+    """
+    Bach10 dataset: single-instrument recordings paired with note labels.
+
+    The class name keeps its original spelling ("Back10") so that existing
+    imports continue to work.
+    """
+
+    # (instrument name, channel number used in the label files), in the order
+    # in which the recordings are indexed.
+    _INSTRUMENTS = (
+        ("violin", 1),
+        ("clarinet", 2),
+        ("saxophone", 3),
+        ("bassoon", 4),
+    )
+
+    def __init__(self, root_dir):
+        """
+        Index the per-instrument recordings found under ``root_dir``.
+
+        For every instrument, the files matching
+        ``<root_dir>/*/*<instrument>.wav`` are collected in sorted order. Each
+        one is registered in ``dataset_orga`` under a running integer index,
+        together with the label path obtained by replacing the
+        ``-<instrument>.wav`` suffix with ``.txt``.
+
+        Args:
+            root_dir (str): The root directory containing the Bach10 dataset.
+        """
+        self.root_dir = root_dir
+        self.dataset_orga = {}
+
+        for name, number in self._INSTRUMENTS:
+            paths = sorted(glob.glob(os.path.join(
+                self.root_dir, f"*/*{name}.wav"), recursive=True))
+            # Exposed as files_violin, files_clarinet, files_saxophone and
+            # files_bassoon.
+            setattr(self, f"files_{name}", paths)
+            for path in paths:
+                self.dataset_orga[len(self.dataset_orga)] = {
+                    "type": name,
+                    "number": number,
+                    "audio_path": path,
+                    "label_path": path.replace(f"-{name}.wav", ".txt"),
+                }
+
+        self.len = len(self.dataset_orga)
+
+        self.labels = sorted(glob.glob(os.path.join(
+            self.root_dir, "*/*.txt"), recursive=True))
+
+    def _load_data(self, file_path, instrument_number):
+        """
+        Read the note frequencies of one instrument from a label file.
+
+        Every non-blank line is split on whitespace; field 2 is read as a MIDI
+        pitch and field 3 as a channel number. Fields 0 and 1 are ignored.
+
+        Args:
+            file_path (str): Path to the label file.
+            instrument_number (int): Channel whose notes are kept.
+
+        Returns:
+            list[float]: One frequency in Hz per matching line, in file order.
+        """
+        frequencies = []
+        with open(file_path, 'r') as file:
+            for line in file:
+                fields = line.split()
+                if not fields:  # Ignore empty lines
+                    continue
+                midi_pitch, channel = int(fields[2]), int(fields[3])
+                if channel == instrument_number:
+                    # Convert MIDI pitch to frequency (MIDI 69 = 440 Hz)
+                    frequencies.append(440 * 2**((midi_pitch - 69) / 12))
+        return frequencies
+
+    def __len__(self):
+        """
+        Retrieve the total number of samples in the dataset.
+
+        Returns:
+            int: The number of indexed recordings, over all instruments.
+        """
+        return self.len
+
+    def __getitem__(self, idx):
+        """
+        Load one recording and the note frequencies of its instrument.
+
+        Args:
+            idx (int): The index of the sample to retrieve.
+
+        Returns:
+            tuple: ``(audio, label)``. ``audio`` is the 1-D waveform obtained by
+                averaging the channels of the file. ``label`` is a 1-D tensor
+                with one frequency in Hz per label-file line of the instrument
+                (one entry per labelled note, not one per audio frame).
+        """
+        entry = self.dataset_orga[idx]
+
+        audio, _ = torchaudio.load(entry['audio_path'])
+        label = torch.tensor(
+            self._load_data(entry['label_path'], entry['number']))
+
+        return torch.mean(audio, dim=0), label
+
+
+class NSynthDataset(Dataset):
+    """
+    NSynth dataset: groups of audio notes paired with constant pitch labels.
+
+    Audio is globbed from ``<root_dir>/*/*/*.wav`` and metadata from
+    ``<root_dir>/*/*.json``; notes are looked up by file stem.
+    """
+
+    def __init__(self, root_dir, n_samples=1):
+        """
+        Index the audio files and load all metadata found under ``root_dir``.
+
+        Args:
+            root_dir (str): The root directory containing the NSynth dataset.
+            n_samples (int): Number of notes concatenated into one item.
+        """
+        self.root_dir = root_dir
+        self.n_samples = n_samples
+
+        # Load file paths
+        self.files = sorted(glob.glob(os.path.join(
+            root_dir, "*/*/*.wav"), recursive=True))
+        self.infos = sorted(glob.glob(os.path.join(
+            root_dir, "*/*.json"), recursive=True))
+
+        # Precompute a mapping from filenames to their paths
+        self.filename_to_path = {os.path.splitext(os.path.basename(f))[
+            0]: f for f in self.files}
+
+        # Load all metadata
+        self.data = self._load_metadata()
+
+        # Create a mapping from filenames to their data for fast access
+        self.file_to_data = dict(self.data)
+
+        # Notes that can actually be served: listed in the metadata AND present
+        # on disk, in metadata order. Built once so that __getitem__ neither
+        # rebuilds the key list on every call nor fails with a KeyError on a
+        # metadata entry whose audio file is missing.
+        self._filenames = [
+            name for name in self.data if name in self.filename_to_path]
+
+    def _load_metadata(self):
+        """Merge the top-level mappings of all metadata JSON files into one dict."""
+        data = {}
+        for json_file in self.infos:
+            with open(json_file, 'r') as f:
+                data.update(json.load(f))
+        return data
+
+    def __len__(self):
+        """
+        Retrieve the total number of samples in the dataset.
+
+        Returns:
+            int: The number of complete groups of ``n_samples`` servable notes.
+        """
+        return len(self._filenames) // self.n_samples
+
+    def _load_audio_and_pitch(self, filename):
+        """
+        Load one note and build its constant pitch track.
+
+        Args:
+            filename (str): File stem of the note, a key of ``filename_to_path``.
+
+        Returns:
+            tuple: ``(audio, pitch)``. ``audio`` is the channel-averaged 1-D
+                waveform; ``pitch`` is a 1-D tensor of ``len(audio) // 40``
+                entries, all equal to the note frequency in Hz.
+        """
+        audio_path = self.filename_to_path[filename]
+        audio, sr = torchaudio.load(audio_path)
+        audio = torch.mean(audio, dim=0)  # Convert to mono
+
+        # Retrieve MIDI pitch and calculate frequency
+        midi_pitch = self.file_to_data.get(filename, {}).get(
+            "pitch", 69)  # Default to 69 if not found
+        pitch = 440 * 2 ** ((midi_pitch - 69) / 12)
+        # NOTE(review): one label per 40 samples -- confirm that this matches
+        # the frame hop expected by the training code.
+        pitch = torch.ones([audio.shape[0] // 40]) * pitch
+
+        return audio, pitch
+
+    def __getitem__(self, idx):
+        """
+        Retrieve one group of ``n_samples`` notes, concatenated.
+
+        Args:
+            idx (int): The index of the group to retrieve.
+
+        Returns:
+            tuple: ``(audio, pitch)``, the 1-D waveforms and the 1-D pitch
+                tracks of the notes in the group, each concatenated in
+                metadata order.
+
+        Raises:
+            IndexError: If the group does not lie fully inside the dataset.
+        """
+        start_idx = idx * self.n_samples
+        end_idx = start_idx + self.n_samples
+
+        # Ensure we don't go out of bounds
+        if start_idx < 0 or end_idx > len(self._filenames):
+            raise IndexError(
+                "Index out of bounds for the requested group of samples.")
+
+        filenames = self._filenames[start_idx:end_idx]
+
+        if len(filenames) == 1:
+            # A single note gains nothing from a thread pool.
+            loaded = [self._load_audio_and_pitch(filenames[0])]
+        else:
+            with ThreadPoolExecutor() as executor:
+                # map() returns results in submission order, so the
+                # concatenation order is deterministic.
+                loaded = list(executor.map(
+                    self._load_audio_and_pitch, filenames))
+
+        audio = torch.cat([note_audio for note_audio, _ in loaded])
+        pitch = torch.cat([note_pitch for _, note_pitch in loaded])
+
+        return audio, pitch
diff --git a/crepe/model.py b/crepe/model.py
new file mode 100644
index 0000000..bd8328e
--- /dev/null
+++ b/crepe/model.py
@@ -0,0 +1,242 @@
+"""This file contains the Crepe model and its block."""
+
+import torch
+import os
+import torchaudio
+import torch.nn as nn
+
+from crepe.utils import get_frame, activation_to_frequency
+
+
+class ConvBlock(nn.Module):
+    """
+    One convolutional stage of the network.
+
+    The stage zero-pads the height axis, convolves with a
+    ``(kernel_width, 1)`` kernel, then applies ReLU, batch normalization,
+    ``(2, 1)`` max pooling and dropout, in that order.
+    """
+
+    def __init__(self, out_channels, kernel_width, stride, in_channels):
+        """
+        Assemble the layers of the stage.
+
+        Args:
+            out_channels (int): The number of output channels.
+            kernel_width (int): The extent of the kernel along the height axis.
+            stride (tuple, int): The stride of the convolution.
+            in_channels (int): The number of input channels.
+
+        Attributes:
+            layer (nn.Sequential): The layers of the stage, in execution order.
+        """
+        super().__init__()
+
+        # kernel_width - 1 rows of zeros are split between top and bottom;
+        # when that total is odd, the extra row goes to the bottom.
+        total_padding = kernel_width - 1
+        rows_above = total_padding // 2
+        rows_below = total_padding - rows_above
+
+        stages = [
+            nn.ZeroPad2d((0, 0, rows_above, rows_below)),
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(kernel_width, 1),
+                stride=stride
+            ),
+            nn.ReLU(),
+            nn.BatchNorm2d(out_channels),
+            nn.MaxPool2d(kernel_size=(2, 1)),
+            nn.Dropout(p=0.25),  # regularization
+        ]
+        self.layer = nn.Sequential(*stages)
+
+    def forward(self, x):
+        """
+        Run ``x`` through every layer of the stage.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output of the last layer.
+        """
+        out = self.layer(x)
+        return out
+
+
+class Crepe(nn.Module):
+    """
+    CREPE pitch-estimation network.
+
+    Six ``ConvBlock`` stages are followed by a linear layer and a sigmoid that
+    output 360 activations per input frame.
+    """
+
+    def __init__(self, model_capacity="full", device='cpu'):
+        """
+        Build the network, load its pretrained weights and switch to eval mode.
+
+        Args:
+            model_capacity (str): The capacity of the network ('tiny', 'small',
+                'medium', 'large', or 'full'). It scales the number of filters
+                of every convolutional stage.
+            device (str): The torch device the model is placed on, e.g. 'cpu',
+                'cuda' or 'mps'.
+
+        Attributes:
+            model_capacity (str): The capacity of the network.
+            convolutional_blocks (nn.Sequential): The convolutional stages.
+            linear (nn.Linear): The final layer mapping features to 360 classes.
+            device (str): The device passed to the constructor.
+
+        Raises:
+            KeyError: If ``model_capacity`` is not one of the names above.
+        """
+        super(Crepe, self).__init__()
+
+        # Define a multiplier for the network's capacity based on the selected model size
+        self.model_capacity = model_capacity
+        capacity_multiplier = {
+            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
+        }[model_capacity]
+
+        # Number of filters per stage, scaled by the capacity multiplier and
+        # preceded by the single input channel.
+        filters = [1] + [n * capacity_multiplier for n in [32, 4, 4, 4, 8, 16]]
+
+        # Kernel width and stride of each stage
+        widths = [512, 64, 64, 64, 64, 64]
+        strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]
+
+        self.convolutional_blocks = nn.Sequential(*[
+            ConvBlock(
+                out_channels=filters[i + 1],
+                kernel_width=widths[i],
+                stride=strides[i],
+                in_channels=filters[i]
+            )
+            for i in range(len(filters) - 1)
+        ])
+
+        # Final linear layer mapping the flattened features to 360 classes
+        self.linear = nn.Linear(64 * capacity_multiplier, 360)
+
+        # Load the weights, then actually place the model on the requested
+        # device: map_location only affects the loaded tensors, which
+        # load_state_dict then copies into the existing parameters.
+        self.device = device
+        self.load_weight(model_capacity)
+        self.to(device)
+
+        # Set the model to evaluation mode by default
+        self.eval()
+
+    def load_weight(self, model_capacity):
+        """
+        Load the weights for a given model capacity.
+
+        The weights are read from ``crepe-<model_capacity>.pth`` in the
+        directory of this module. A missing file only triggers a warning,
+        leaving the freshly initialized parameters in place; any other failure
+        (unreadable file, mismatching parameter shapes, ...) is raised instead
+        of being reported as a missing file.
+
+        Args:
+            model_capacity (str): The capacity of the network ('tiny', 'small',
+                'medium', 'large', or 'full').
+        """
+        import warnings  # local import keeps this block self-contained
+
+        package_dir = os.path.dirname(os.path.realpath(__file__))
+        filename = "crepe-{}.pth".format(model_capacity)
+        weight_path = os.path.join(package_dir, filename)
+
+        if not os.path.isfile(weight_path):
+            warnings.warn(
+                f"{filename} not found in {package_dir}; "
+                "the model keeps its initial (untrained) weights.")
+            return
+
+        self.load_state_dict(torch.load(
+            weight_path, map_location=torch.device(self.device), weights_only=True))
+
+    def forward(self, x):
+        """
+        Pass an input tensor `x` through the network.
+
+        Args:
+            x (torch.Tensor): The input tensor; dimension 0 is the batch, all
+                remaining elements of a sample are treated as one frame.
+
+        Returns:
+            torch.Tensor: The sigmoid activations, 360 per sample.
+        """
+        # Reshape every sample to (channel=1, height=frame length, width=1)
+        x = x.view(x.shape[0], 1, -1, 1)
+
+        # Pass the input through each convolutional block sequentially
+        x = self.convolutional_blocks(x)
+
+        # Reorder dimensions and flatten before passing to the linear layer
+        x = x.permute(0, 3, 2, 1)
+        x = x.reshape(x.shape[0], -1)
+
+        # Apply the final linear layer and sigmoid activation
+        x = self.linear(x)
+        x = torch.sigmoid(x)
+
+        return x
+
+    def get_activation(self, audio, sr, center=True, step_size=10, batch_size=128):
+        """
+        Compute the activation stack for a given audio signal and sampling rate.
+
+        The audio is resampled to 16 kHz if needed, reduced to mono, cut into
+        frames by ``get_frame`` and passed through the network in batches.
+
+        Args:
+            audio (torch.Tensor): The input audio, either 1-D or 2-D with the
+                channels on dimension 0.
+            sr (int): The sampling rate of the audio signal.
+            center (bool): Forwarded to ``get_frame``. Defaults to True.
+            step_size (int): The hop between consecutive frames, in
+                milliseconds. Defaults to 10.
+            batch_size (int): The number of frames per forward pass.
+
+        Returns:
+            torch.Tensor: The activations of all frames, concatenated along
+                dimension 0 and moved to the CPU.
+        """
+        # resample to 16kHz if needed
+        if sr != 16000:
+            rs = torchaudio.transforms.Resample(sr, 16000)
+            audio = rs(audio)
+
+        # make mono if needed
+        if audio.dim() == 2:
+            if audio.shape[0] == 1:
+                audio = audio[0]
+            else:
+                audio = audio.mean(dim=0)
+
+        frames = get_frame(audio, step_size, center)
+        activation_stack = []
+        device = self.linear.weight.device
+
+        # NOTE(review): this loop records autograd history; wrap it in
+        # torch.no_grad() if get_activation is only ever used for inference --
+        # confirm against the training code.
+        for i in range(0, len(frames), batch_size):
+            batch = frames[i:i + batch_size].to(device)
+            activation_stack.append(self.forward(batch).cpu())
+        activation = torch.cat(activation_stack, dim=0)
+
+        return activation
+
+    def predict(self, audio, sr, center=True, step_size=10, batch_size=128):
+        """
+        Predict the pitch of every frame of an input audio signal.
+
+        Args:
+            audio (torch.Tensor): The input audio tensor.
+            sr (int): The sampling rate of the audio signal.
+            center (bool): Forwarded to ``get_activation``. Defaults to True.
+            step_size (int): The hop between consecutive frames, in
+                milliseconds. Defaults to 10.
+            batch_size (int): The batch size for computing activations.
+
+        Returns:
+            tuple: ``(time, frequency, confidence, activation)``. ``time`` holds
+                the frame offsets in seconds, ``frequency`` the output of
+                ``activation_to_frequency``, ``confidence`` the maximum
+                activation of each frame and ``activation`` the full
+                activation stack.
+        """
+        # `center` must be passed on explicitly, otherwise the caller's choice
+        # is silently replaced by get_activation's default.
+        activation = self.get_activation(
+            audio, sr, center=center, step_size=step_size, batch_size=batch_size)
+        frequency = activation_to_frequency(activation)
+        confidence = activation.max(dim=1)[0]
+        time = torch.arange(confidence.shape[0]) * step_size / 1000.0
+        return time, frequency, confidence, activation
diff --git a/crepe/utils.py b/crepe/utils.py
new file mode 100644
index 0000000..240d1d2
--- /dev/null
+++ b/crepe/utils.py
@@ -0,0 +1,165 @@
+"""This file contains various functions for processing and converting audio data and labels."""
+
+import torch
+import torch.nn as nn
+import subprocess
+import torchaudio
+import os
+
+
+def get_frame(audio, step_size, center):
+ """
+ Extract audio frames from a given audio signal.
+
+ Args:
+ audio (Tensor): The input audio signal.
+ step_size (float): The time step size in milliseconds. Audio will be divided into
+ 1024-sample frames with this hop length.
+ center (bool): If True, pads the audio to have equal number of samples on both sides.
+
+ Returns:
+ Tensor: A tensor containing the extracted audio frames, standardized to have zero mean and unit standard deviation.
+ """
+ if center:
+ audio = nn.functional.pad(audio, pad=(512, 512))
+ # make 1024-sample frames of the audio with hop length of 10 milliseconds
+ hop_length = int(16000 * step_size / 1000)
+ n_frames = 1 + (len(audio) - 1024) // hop_length
+ frames = torch.as_strided(audio, size=(
+ 1024, n_frames), stride=(1, hop_length))
+ frames = frames.transpose(0, 1).clone()
+
+ mean = torch.mean(frames, dim=1, keepdim=True)
+ # Adding epsilon to prevent division by zero
+ std = torch.std(frames, dim=1, keepdim=True) + 1e-8
+
+ frames -= mean
+ frames /= std
+ return frames
+
+
+def to_local_average_cents(salience, center=None):
+ """
+ Compute the weighted average cents near the argmax bin of a salience vector.
+
+ Args:
+ salience (Tensor): A 1D or 2D tensor representing the salience values.
+ center (int, optional): The index around which to compute the weighted average. Defaults to None.
+
+ Returns:
+ Tensor: The weighted average cents near the argmax bin.
+
+ Notes:
+ This function assumes that the input salience values are normalized such that their sum equals 1.
+ """
+ if not hasattr(to_local_average_cents, 'cents_mapping'):
+ # The bin number-to-cents mapping
+ to_local_average_cents.cents_mapping = (
+ torch.linspace(0,
+ 1200 * torch.log2(torch.tensor(3951.066/10)),
+ 360, dtype=salience.dtype,
+ device=salience.device) + 1200 * torch.log2(torch.tensor(32.70/10)))
+
+ if salience.ndim == 1:
+ if center is None:
+ center = int(torch.argmax(salience))
+ start = max(0, center - 4)
+ end = min(len(salience), center + 5)
+ salience_segment = salience[start:end]
+ mapping_segment = to_local_average_cents.cents_mapping[start:end]
+ product_sum = torch.sum(salience_segment * mapping_segment)
+ weight_sum = torch.sum(salience_segment)
+ return product_sum / weight_sum
+ elif salience.ndim == 2:
+ return torch.stack([to_local_average_cents(salience[i, :]) for i in range(salience.shape[0])])
+
+
+def activation_to_frequency(activations):
+ """
+ Convert activations to a corresponding frequency value.
+
+ Args:
+ activations (tensor): The input activations to convert.
+
+ Returns:
+ tensor: A tensor representing the frequency values.
+ """
+ cents = to_local_average_cents(activations)
+ frequency = 10 * 2 ** (cents / 1200)
+ frequency[torch.isnan(frequency)] = 0
+ frequency = torch.where(frequency < 32.71, torch.tensor(
+ 1e-7, device=frequency.device), frequency)
+ return frequency
+
+
+def frequency_to_activation(frequencies, num_bins=360):
+ """
+ Convert a tensor of frequencies to a binary activation map.
+
+ Args:
+ frequencies (torch.Tensor): The input frequencies.
+ num_bins (int, optional): The number of bins in the activation map. Defaults to 360.
+
+ Returns:
+ torch.Tensor: A binary activation map where each row corresponds to the frequency in the corresponding
+ row of `frequencies`.
+ """
+ # Convert frequency to cents
+ cents = 1200 * torch.log2(frequencies / 10)
+
+ # Create the cents-to-bin mapping if it doesn't already exist
+ if not hasattr(frequency_to_activation, 'cents_mapping'):
+ frequency_to_activation.cents_mapping = (
+ torch.linspace(0,
+ 1200 * torch.log2(torch.tensor(3951.066/10)),
+ num_bins, dtype=frequencies.dtype,
+ device=frequencies.device) + 1200 * torch.log2(torch.tensor(32.70/10)))
+
+ # Initialize activation map with zeros; expects batch input for frequencies
+ activations = torch.zeros(
+ frequencies.shape[0], num_bins, dtype=frequencies.dtype, device=frequencies.device)
+
+ # Find the closest bin to the calculated cents value for each frequency in the batch
+ for i in range(frequencies.shape[0]):
+ closest_bin = torch.argmin(
+ torch.abs(frequency_to_activation.cents_mapping - cents[i]))
+ activations[i, closest_bin] = 1.0
+
+ return activations
+
+
+def load_test_file(filename: str, mono: bool = False, normalize: bool = False):
+ """
+ Load a test audio file from a remote server into a PyTorch Audio tensor.
+
+ Args:
+ filename (str): The name of the audio file to download.
+ mono (bool, optional): If True, load the audio in monaural format. Defaults to False.
+ normalize (bool, optional): If True, normalize the audio signal. Defaults to False.
+
+ Returns:
+ Tuple[Tensor, int]: A tuple containing the loaded audio tensor and its sample rate.
+ If an error occurs during download or loading, returns (None, None).
+ """
+ # Construct the URL for the file on your server
+ url = f"https://openfileserver.chloelavrat.com/testfiles/audio/{filename}"
+
+ # Create a temporary directory to store the downloaded files
+ temp_dir = "/tmp" # You can change this to any other directory you like
+
+ # Construct the full path where the downloaded file will be saved
+ filepath = os.path.join(temp_dir, filename)
+
+ try:
+ # Use subprocess to run wget and download the file
+ subprocess.check_call(
+ ["wget", "-P", temp_dir, url, "--no-check-certificate", "-q"])
+
+ # Load the audio file into a PyTorch Audio tensor
+ audio, sr = torchaudio.load(filepath, normalize=normalize)
+
+ return audio, sr
+
+ except subprocess.CalledProcessError as e:
+ print(f"Failed to download {filename}: {e}")
+ return None
diff --git a/notebooks/dataset.ipynb b/notebooks/dataset.ipynb
new file mode 100644
index 0000000..caa4a10
--- /dev/null
+++ b/notebooks/dataset.ipynb
@@ -0,0 +1,261 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix relative import\n",
+ "import os, sys\n",
+ "dir2 = os.path.abspath('')\n",
+ "dir1 = os.path.dirname(dir2)\n",
+ "if not dir1 in sys.path: sys.path.append(dir1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "audio backend list: ['soundfile']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# import libs + list audio backend\n",
+ "from tqdm import tqdm\n",
+ "import torch\n",
+ "import torchaudio\n",
+ "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
+ "print(f\"audio backend list: {str(torchaudio.list_audio_backends())}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# dataset scan function with multithreading\n",
+ "def dataset_scan(data, desc, max_workers=4):\n",
+ " error = False\n",
+ " \n",
+ " def process_item(i):\n",
+ " try:\n",
+ " audio = data[i][0].shape\n",
+ " label = data[i][1].shape\n",
+ " except Exception as e:\n",
+ " return (i, data.files[i], str(e))\n",
+ " return None\n",
+ " \n",
+ " with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
+ " futures = {executor.submit(process_item, i): i for i in range(len(data))}\n",
+ " \n",
+ " for future in tqdm(as_completed(futures), total=len(data), desc=desc):\n",
+ " result = future.result()\n",
+ " if result is not None:\n",
+ " error = True\n",
+ " i, file, exception = result\n",
+ " print(f\"Error id: {i} file: {file} - Exception: {exception}\")\n",
+ " \n",
+ " return error"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "** mir_1k **\n",
+ "lenght of dataset: 1000\n",
+ "random sample id: 850\n",
+ "id_850 audio size: torch.Size([147969])\n",
+ "id_850 label size: torch.Size([461])\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "test mir_1k: 100%|██████████| 1000/1000 [00:01<00:00, 854.41it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset analyzed and ready to be used\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import: MIR-1K\n",
+ "from crepe.dataset import MIR1KDataset\n",
+ "mir_1k = MIR1KDataset(root_dir=os.path.join(dir1, \"dataset/MIR-1K\"))\n",
+ "\n",
+ "# Test: MIR-1K\n",
+ "print(\"** mir_1k **\")\n",
+ "print(f\"lenght of dataset: {len(mir_1k)}\")\n",
+ "idx = int(torch.randint(0, int(len(mir_1k)), (1,)))\n",
+ "print(f\"random sample id: {int(idx)}\")\n",
+ "print(f\"id_{idx} audio size: {mir_1k[idx][0].shape}\")\n",
+ "print(f\"id_{idx} label size: {mir_1k[idx][1].shape}\")\n",
+ "\n",
+ "error = dataset_scan(mir_1k, \"test mir_1k\")\n",
+ "\n",
+ "if not error:\n",
+ " print(\"Dataset analyzed and ready to be used\")\n",
+ "if error:\n",
+ " print(\"it seams that your dataset is not well formated.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "** back10 **\n",
+ "lenght of dataset: 30\n",
+ "random sample id: 23\n",
+ "id_23 audio size: torch.Size([1837433])\n",
+ "id_23 label size: torch.Size([65])\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "test back10: 100%|██████████| 30/30 [00:00<00:00, 347.92it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset analyzed and ready to be used\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import: Bach-10\n",
+ "from crepe.dataset import Back10Dataset\n",
+ "back10 = Back10Dataset(root_dir=os.path.join(dir1, \"dataset/Bach10\"))\n",
+ "\n",
+ "# Test: Bach-10\n",
+ "print(\"** back10 **\")\n",
+ "print(f\"lenght of dataset: {len(back10)}\")\n",
+ "idx = int(torch.randint(0, int(len(back10)), (1,)))\n",
+ "print(f\"random sample id: {int(idx)}\")\n",
+ "print(f\"id_{idx} audio size: {back10[idx][0].shape}\")\n",
+ "print(f\"id_{idx} label size: {back10[idx][1].shape}\")\n",
+ "\n",
+ "error = dataset_scan(back10, \"test back10\")\n",
+ "\n",
+ "if not error:\n",
+ " print(\"Dataset analyzed and ready to be used\")\n",
+ "if error:\n",
+ " print(\"it seams that your dataset is not well formated.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "** nsynth **\n",
+ "lenght of dataset: 10062\n",
+ "random sample id: 5275\n",
+ "id_5275 audio size: torch.Size([1920000])\n",
+ "id_5275 label size: torch.Size([48000])\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "test nsynth: 0%| | 16/10062 [00:00<03:22, 49.70it/s] \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import: Nsynth\n",
+ "from crepe.dataset import NSynthDataset\n",
+ "nsynth = NSynthDataset(root_dir=os.path.join(dir1, \"dataset/Nsynth-mixed\"), n_samples=30)\n",
+ "\n",
+ "# Test: NSynth\n",
+ "print(\"** nsynth **\")\n",
+ "print(f\"lenght of dataset: {len(nsynth)}\")\n",
+ "idx = int(torch.randint(0, int(len(nsynth)), (1,)))\n",
+ "print(f\"random sample id: {int(idx)}\")\n",
+ "print(f\"id_{idx} audio size: {nsynth[idx][0].shape}\")\n",
+ "print(f\"id_{idx} label size: {nsynth[idx][1].shape}\")\n",
+ "\n",
+ "error = dataset_scan(nsynth, \"test nsynth\")\n",
+ "\n",
+ "if not error:\n",
+ " print(\"Dataset analyzed and ready to be used\")\n",
+ "if error:\n",
+ " print(\"it seams that your dataset is not well formated.\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/model.ipynb b/notebooks/model.ipynb
new file mode 100644
index 0000000..bfd683e
--- /dev/null
+++ b/notebooks/model.ipynb
@@ -0,0 +1,303 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix relative import\n",
+ "import os, sys\n",
+ "dir2 = os.path.abspath('')\n",
+ "dir1 = os.path.dirname(dir2)\n",
+ "if not dir1 in sys.path: sys.path.append(dir1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "** convBlock **\n",
+ "layers: \n",
+ "ConvBlock(\n",
+ " (layer): Sequential(\n",
+ " (0): ZeroPad2d((0, 0, 1, 2))\n",
+ " (1): Conv2d(3, 6, kernel_size=(4, 1), stride=(2, 2))\n",
+ " (2): ReLU()\n",
+ " (3): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (4): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
+ " (5): Dropout(p=0.25, inplace=False)\n",
+ " )\n",
+ ")\n",
+ "verification: True\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ConvBlock\n",
+ "from crepe.model import ConvBlock\n",
+ "convBlock = ConvBlock(\n",
+ " in_channels=3, # Number of input channels\n",
+ " out_channels=6, # Number of output channels\n",
+ " kernel_width=4, # Width of the convolution kernel\n",
+ " stride=2 # Stride of the convolution\n",
+ ")\n",
+ "expected = \"\"\"ConvBlock(\n",
+ " (layer): Sequential(\n",
+ " (0): ZeroPad2d((0, 0, 1, 2))\n",
+ " (1): Conv2d(3, 6, kernel_size=(4, 1), stride=(2, 2))\n",
+ " (2): ReLU()\n",
+ " (3): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (4): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
+ " (5): Dropout(p=0.25, inplace=False)\n",
+ " )\n",
+ ")\"\"\"\n",
+ "print(\"** convBlock **\")\n",
+ "print(\"layers: \")\n",
+ "print(str(convBlock))\n",
+ "print(f\"verification: {(str(convBlock)==expected)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "crepe device: mps\n",
+ "crepe model_capacity: tiny\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import crepe\n",
+ "from crepe.model import Crepe\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available(\n",
+ ") else 'mps' if torch.backends.mps.is_available() else 'cpu')\n",
+ "\n",
+ "crepe = Crepe(model_capacity='tiny', device=device)\n",
+ "\n",
+ "print(f\"crepe device: \", crepe.device)\n",
+ "print(f\"crepe model_capacity: \", crepe.model_capacity)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/Users/chloelavrat/Documents/GitHub/TorchCrepe/dataset/MIR-1K/Wavfile/amy_10_04.wav\n",
+ "audio test\n",
+ "load audio MIR 1K id_136\n",
+ "sr: 16000 Hz\n",
+ "audio shape: torch.Size([101889])\n",
+ "\n",
+ "Process..\n",
+ "activation Shape: torch.Size([637, 360])\n",
+ "confidence Shape: torch.Size([637])\n",
+ "frequency Shape: torch.Size([637])\n",
+ "time Shape: torch.Size([637])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# load a MIR-1K sample into crepe\n",
+ "from crepe.dataset import MIR1KDataset\n",
+ "mir_1k = MIR1KDataset(root_dir=os.path.join(dir1, \"dataset/MIR-1K\"))\n",
+ "\n",
+ "sr = 16000\n",
+ "idx = 136#int(torch.randint(0, int(len(mir_1k)), (1,)))\n",
+ "audio = mir_1k[idx][0]\n",
+ "labels = mir_1k[idx][1]\n",
+ "\n",
+ "print(mir_1k.files[136])\n",
+ "\n",
+ "audio = (audio - torch.mean(audio) ) / (torch.max(audio))\n",
+ "print(\"audio test\")\n",
+ "print(f\"load audio MIR 1K id_{idx}\")\n",
+ "print(f\"sr: {sr} Hz\")\n",
+ "print(f\"audio shape: {audio.shape}\")\n",
+ "\n",
+ "time, frequency, confidence, activation = crepe.predict(\n",
+ " audio=audio,\n",
+ " sr = sr\n",
+ ")\n",
+ "print(\"\\nProcess..\")\n",
+ "print(f\"activation Shape: {activation.shape}\")\n",
+ "print(f\"confidence Shape: {frequency.shape}\")\n",
+ "print(f\"frequency Shape: {frequency.shape}\")\n",
+ "print(f\"time Shape: {frequency.shape}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "\n",
+ "salience = activation.flip(1)\n",
+ "salience_transposed = salience.transpose(0, 1) # Transpose the axes\n",
+ "plt.figure(figsize=(10, 6)) # Adjust the figure size\n",
+ "plt.imshow(salience_transposed.detach().numpy(), cmap='inferno', aspect='auto')\n",
+ "plt.colorbar(label='Activation') # Add a color bar for reference\n",
+ "plt.title('Salience Map')\n",
+ "plt.xlabel('Sample Index') # Adjusted based on transposition\n",
+ "plt.ylabel('Feature Dimension') # Adjusted based on transposition\n",
+ "plt.ylim(350, 300) # Set the y-axis range from 350 to 300\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Assuming frequency and labels are PyTorch tensors\n",
+ "frequency_np = frequency.detach().numpy()\n",
+ "\n",
+ "# Interpolate labels to match the size of frequency\n",
+ "labels_resized = torch.nn.functional.interpolate(\n",
+ " labels.unsqueeze(0).unsqueeze(0), # Add batch and channel dimensions\n",
+ " size=frequency.shape[0], # Target size\n",
+ " mode='linear', align_corners=True # Linear interpolation\n",
+ ").squeeze(0).squeeze(0) # Remove the added dimensions\n",
+ "\n",
+ "labels_np = labels_resized.detach().numpy()\n",
+ "\n",
+ "# Create the plot with two y-axes\n",
+ "fig, ax1 = plt.subplots(figsize=(10, 4))\n",
+ "\n",
+ "# Plot frequency on the primary y-axis\n",
+ "ax1.plot(frequency_np, color='b', label='Frequency')\n",
+ "ax1.set_xlabel('Time')\n",
+ "ax1.set_ylabel('Frequency (Hz)', color='b')\n",
+ "ax1.tick_params(axis='y', labelcolor='b')\n",
+ "\n",
+ "# Plot the interpolated labels on the same axis\n",
+ "ax1.plot(labels_np, color='r', label='Labels (Interpolated)')\n",
+ "ax1.set_ylabel('Labels', color='r')\n",
+ "ax1.tick_params(axis='y', labelcolor='r')\n",
+ "\n",
+ "# Add legends for clarity\n",
+ "ax1.legend(loc='upper left')\n",
+ "ax1.legend(loc='upper right')\n",
+ "\n",
+ "# Show the plot\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": " is not a generic class",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;66;03m# Download and load a built-in dataset\u001b[39;00m\n\u001b[1;32m 4\u001b[0m dataset \u001b[38;5;241m=\u001b[39m torchaudio\u001b[38;5;241m.\u001b[39mdatasets\u001b[38;5;241m.\u001b[39mSPEECHCOMMANDS\n\u001b[0;32m----> 5\u001b[0m data, sampling_rate \u001b[38;5;241m=\u001b[39m \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(data\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 8\u001b[0m time, frequency, confidence, activation \u001b[38;5;241m=\u001b[39m crepe\u001b[38;5;241m.\u001b[39mpredict(\n\u001b[1;32m 9\u001b[0m audio\u001b[38;5;241m=\u001b[39maudio,\n\u001b[1;32m 10\u001b[0m sr \u001b[38;5;241m=\u001b[39m sr\n\u001b[1;32m 11\u001b[0m )\n",
+ "File \u001b[0;32m/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/typing.py:398\u001b[0m, in \u001b[0;36m_tp_cache..decorator..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 397\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m \u001b[38;5;66;03m# All real errors (not unhashable args) are raised below.\u001b[39;00m\n\u001b[0;32m--> 398\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/opt/homebrew/Cellar/python@3.12/3.12.4/Frameworks/Python.framework/Versions/3.12/lib/python3.12/typing.py:1101\u001b[0m, in \u001b[0;36m_generic_class_getitem\u001b[0;34m(cls, params)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prepare \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1100\u001b[0m params \u001b[38;5;241m=\u001b[39m prepare(\u001b[38;5;28mcls\u001b[39m, params)\n\u001b[0;32m-> 1101\u001b[0m \u001b[43m_check_generic\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__parameters__\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1103\u001b[0m new_args \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 1104\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m param, new_arg \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m__parameters__, params):\n",
+ "File \u001b[0;32m~/Documents/GitHub/TorchCrepe/.venv/lib/python3.12/site-packages/typing_extensions.py:2922\u001b[0m, in \u001b[0;36m_check_generic\u001b[0;34m(cls, parameters, elen)\u001b[0m\n\u001b[1;32m 2917\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Check correct count for parameters of a generic cls (internal helper).\u001b[39;00m\n\u001b[1;32m 2918\u001b[0m \n\u001b[1;32m 2919\u001b[0m \u001b[38;5;124;03mThis gives a nice error message in case of count mismatch.\u001b[39;00m\n\u001b[1;32m 2920\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2921\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m elen:\n\u001b[0;32m-> 2922\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not a generic class\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 2923\u001b[0m alen \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(parameters)\n\u001b[1;32m 2924\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m alen \u001b[38;5;241m!=\u001b[39m elen:\n",
+ "\u001b[0;31mTypeError\u001b[0m: is not a generic class"
+ ]
+ }
+ ],
+ "source": [
+ "import torchaudio\n",
+ "\n",
+ "# Download and load a built-in dataset\n",
+ "dataset = torchaudio.datasets.SPEECHCOMMANDS\n",
+ "data, sampling_rate = dataset[0]\n",
+ "print(data.shape)\n",
+ "\n",
+ "time, frequency, confidence, activation = crepe.predict(\n",
+ " audio=audio,\n",
+ " sr = sr\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/train.ipynb b/notebooks/train.ipynb
new file mode 100644
index 0000000..40c1400
--- /dev/null
+++ b/notebooks/train.ipynb
@@ -0,0 +1,145 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix relative import\n",
+ "import os, sys, torch\n",
+ "dir2 = os.path.abspath('')\n",
+ "dir1 = os.path.dirname(dir2)\n",
+ "if not dir1 in sys.path: sys.path.append(dir1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "crepe device: mps\n",
+ "crepe model_capacity: tiny\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import crepe\n",
+ "from crepe.model import Crepe\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available(\n",
+ ") else 'mps' if torch.backends.mps.is_available() else 'cpu')\n",
+ "\n",
+ "crepe = Crepe(model_capacity='tiny', device=device)\n",
+ "\n",
+ "print(f\"crepe device: \", crepe.device)\n",
+ "print(f\"crepe model_capacity: \", crepe.model_capacity)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "random sample id: 359\n",
+ "id_359 audio size: torch.Size([176641])\n",
+ "id_359 label size: torch.Size([551])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import: MIR-1K (selection of one sample)\n",
+ "from crepe.dataset import MIR1KDataset\n",
+ "mir_1k = MIR1KDataset(root_dir=os.path.join(dir1, \"dataset/MIR-1K\"))\n",
+ "\n",
+ "sr = 16000\n",
+ "idx = int(torch.randint(0, int(len(mir_1k)), (1,)))\n",
+ "print(f\"random sample id: {int(idx)}\")\n",
+ "print(f\"id_{idx} audio size: {mir_1k[idx][0].shape}\")\n",
+ "print(f\"id_{idx} label size: {mir_1k[idx][1].shape}\")\n",
+ "audio = mir_1k[idx][0]\n",
+ "labels = mir_1k[idx][1]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "labels_activations shape: torch.Size([1105, 360])\n",
+ "model_activations shape: torch.Size([1105, 360])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# train step test\n",
+ "from train import epoch_step\n",
+ "labels_activations, model_activations = epoch_step(crepe, audio, labels, sr, device)\n",
+ "\n",
+ "print(f\"labels_activations shape: {labels_activations.shape}\")\n",
+ "print(f\"model_activations shape: {model_activations.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "loss is 0.694\n"
+ ]
+ }
+ ],
+ "source": [
+ "# test loss\n",
+ "import torch.nn as nn\n",
+ "criteron = nn.BCELoss()\n",
+ "loss = criteron(model_activations, labels_activations)\n",
+ "print(f\"loss is {float(loss):.3f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/usecrepe.ipynb b/notebooks/usecrepe.ipynb
new file mode 100644
index 0000000..2cc72cc
--- /dev/null
+++ b/notebooks/usecrepe.ipynb
@@ -0,0 +1,141 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# CREPE: How to use"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix relative import\n",
+ "import os, sys\n",
+ "dir2 = os.path.abspath('')\n",
+ "dir1 = os.path.dirname(dir2)\n",
+ "if not dir1 in sys.path: sys.path.append(dir1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "crepe device: mps\n",
+ "crepe model_capacity: tiny\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import crepe\n",
+ "import torch\n",
+ "from crepe.model import Crepe\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available(\n",
+ ") else 'mps' if torch.backends.mps.is_available() else 'cpu')\n",
+ "\n",
+ "crepe = Crepe(model_capacity='tiny', device=device)\n",
+ "\n",
+ "print(f\"crepe device: \", crepe.device)\n",
+ "print(f\"crepe model_capacity: \", crepe.model_capacity)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Process..\n",
+ "activation Shape: torch.Size([637, 360])\n",
+ "confidence Shape: torch.Size([637])\n",
+ "frequency Shape: torch.Size([637])\n",
+ "time Shape: torch.Size([637])\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Import test file\n",
+ "from crepe.utils import load_test_file\n",
+ "\n",
+ "audio, sr = load_test_file(\"amy_10_04.wav\", mono=True, normalize=True)\n",
+ "\n",
+ "time, frequency, confidence, activation = crepe.predict(\n",
+ " audio=audio,\n",
+ " sr = sr\n",
+ ")\n",
+ "print(\"\\nProcess..\")\n",
+ "print(f\"activation Shape: {activation.shape}\")\n",
+ "print(f\"confidence Shape: {frequency.shape}\")\n",
+ "print(f\"frequency Shape: {frequency.shape}\")\n",
+ "print(f\"time Shape: {frequency.shape}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# plot results\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "\n",
+ "salience = activation.flip(1)\n",
+ "salience_transposed = salience.transpose(0, 1) # Transpose the axes\n",
+ "plt.figure(figsize=(10, 6)) # Adjust the figure size\n",
+ "plt.imshow(salience_transposed.detach().numpy(), cmap='inferno', aspect='auto')\n",
+ "plt.colorbar(label='Activation') # Add a color bar for reference\n",
+ "plt.title('Salience Map')\n",
+ "plt.xlabel('Sample Index') # Adjusted based on transposition\n",
+ "plt.ylabel('Feature Dimension') # Adjusted based on transposition\n",
+ "plt.ylim(350, 300) # Set the y-axis range from 350 to 300\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/utils.ipynb b/notebooks/utils.ipynb
new file mode 100644
index 0000000..4ab74a4
--- /dev/null
+++ b/notebooks/utils.ipynb
@@ -0,0 +1,135 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Utils"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# fix relative import\n",
+ "import os, sys\n",
+ "dir2 = os.path.abspath('')\n",
+ "dir1 = os.path.dirname(dir2)\n",
+ "if not dir1 in sys.path: sys.path.append(dir1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from crepe.utils import activation_to_frequency, frequency_to_activation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 8000])"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sr = 16000\n",
+ "test_freq = torch.arange(0., float(sr//2), 1).unsqueeze(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# \n",
+ "activation = frequency_to_activation(test_freq[0,:])\n",
+ "freq_out = activation_to_frequency(activation)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# plot \n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "fig, (ax1, ax2) = plt.subplots(2, 1) # Create two subplots, one above the other\n",
+ "\n",
+ "%matplotlib inline\n",
+ "plt.figure(figsize=(10, 4)) # Adjust the figure size\n",
+ "\n",
+ "# First plot on the first set of axes\n",
+ "ax1.plot(test_freq[0, :], color='b')\n",
+ "ax1.set_ylabel('Test Frequency')\n",
+ "ax1.set_xlabel('X-Axis for Test Frequency')\n",
+ "\n",
+ "# Second plot on the second set of axes\n",
+ "ax2.plot(freq_out, color='r')\n",
+ "ax2.set_ylabel('Frequency Out')\n",
+ "ax2.set_xlabel('X-Axis for Frequency Out')\n",
+ "\n",
+ "plt.show()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..55da211
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,28 @@
+# Poetry project definition for the CrepeTorch package.
+[tool.poetry]
+name = "CrepeTorch"
+version = "1.0.0"
+description = "**CrepeTorch** is a fundamental frequency detector on audio."
+authors = ["Chloé Lavrat "]
+license = "Copyright (c) Chloé Lavrat 2024"
+readme = "README.md"
+packages = [{include = "crepe"}]
+
+# Runtime dependencies: installed for every user of the package.
+# pytest is intentionally NOT listed here; it is test tooling and is declared
+# under dev-dependencies below.
+[tool.poetry.dependencies]
+python = "^3.12"
+torch = "^2.4"
+torchaudio = "^2.4"
+PySoundFile = "^0.9"
+tqdm = "^4.66.4"
+
+
+# Development-only dependencies: not installed with the published package.
+[tool.poetry.dev-dependencies]
+pytest = "^6.2.5"
+
+[tool.poetry.urls]
+# Project URLs need an explicit scheme to be usable as links.
+homepage = "https://chloelavrat.com"
+repository = "https://github.com/chloelavrat/TorchCrepe"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..3a8b8fe
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,108 @@
+aiofiles==24.1.0
+aiohappyeyeballs==2.4.0
+aiohttp==3.10.5
+aiosignal==1.3.1
+altair==5.3.0
+appnope==0.1.4
+asttokens==2.4.1
+attrs==23.2.0
+audioread==3.0.1
+blinker==1.8.2
+cachetools==5.4.0
+certifi==2024.7.4
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+comm==0.2.2
+contourpy==1.2.1
+cycler==0.12.1
+debugpy==1.8.2
+decorator==5.1.1
+docker-pycreds==0.4.0
+executing==2.0.1
+filelock==3.15.4
+fonttools==4.53.1
+frozenlist==1.4.1
+fsspec==2024.6.1
+gitdb==4.0.11
+GitPython==3.1.43
+hmmlearn==0.3.2
+idna==3.7
+ipykernel==6.29.5
+ipython==8.26.0
+jedi==0.19.1
+Jinja2==3.1.4
+joblib==1.4.2
+jsonschema==4.23.0
+jsonschema-specifications==2023.12.1
+jupyter_client==8.6.2
+jupyter_core==5.7.2
+kiwisolver==1.4.5
+lazy_loader==0.4
+librosa==0.10.2.post1
+llvmlite==0.43.0
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.1
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.0.8
+multidict==6.0.5
+nest-asyncio==1.6.0
+networkx==3.3
+numba==0.60.0
+numpy==2.0.1
+packaging==24.1
+pandas==2.2.2
+parso==0.8.4
+pexpect==4.9.0
+pillow==10.4.0
+platformdirs==4.2.2
+pooch==1.8.2
+prompt_toolkit==3.0.47
+protobuf==5.27.2
+psutil==6.0.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyarrow==17.0.0
+pycparser==2.22
+pydeck==0.9.1
+Pygments==2.18.0
+pyparsing==3.1.2
+PySoundFile==0.9.0.post1
+python-dateutil==2.9.0.post0
+pytz==2024.1
+PyYAML==6.0.1
+pyzmq==26.0.3
+referencing==0.35.1
+requests==2.32.3
+rich==13.7.1
+rpds-py==0.19.1
+scikit-learn==1.5.1
+scipy==1.14.0
+sentry-sdk==2.11.0
+setproctitle==1.3.3
+setuptools==71.1.0
+six==1.16.0
+smmap==5.0.1
+soundfile==0.12.1
+soxr==0.4.0
+stack-data==0.6.3
+streamlit==1.37.0
+sympy==1.13.1
+tenacity==8.5.0
+threadpoolctl==3.5.0
+toml==0.10.2
+toolz==0.12.1
+torch==2.4.0
+torchaudio==2.4.0
+tornado==6.4.1
+tqdm==4.66.4
+traitlets==5.14.3
+typing_extensions==4.12.2
+tzdata==2024.1
+urllib3==2.2.2
+wandb==0.17.5
+wcwidth==0.2.13
+yarl==1.9.4
diff --git a/scripts/package_change_version.sh b/scripts/package_change_version.sh
new file mode 100644
index 0000000..94ce90a
--- /dev/null
+++ b/scripts/package_change_version.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Define the new version number as an argument
+NEW_VERSION="$1"
+
+# Check if a version number was provided
+if [ -z "$NEW_VERSION" ]; then
+ echo "Usage: $0 "
+ exit 1
+fi
+
+# Path to the pyproject.toml file
+PYPROJECT_FILE="pyproject.toml"
+
+# Use sed to replace the version number in the pyproject.toml file
+sed -i.bak "s/^version = \".*\"/version = \"$NEW_VERSION\"/" "$PYPROJECT_FILE"
+
+# Print a success message
+echo "Version updated to $NEW_VERSION in $PYPROJECT_FILE"
+
+# Optional: Remove backup file created by sed
+rm "${PYPROJECT_FILE}.bak"
diff --git a/scripts/package_generate_doc.sh b/scripts/package_generate_doc.sh
new file mode 100644
index 0000000..b1783a0
--- /dev/null
+++ b/scripts/package_generate_doc.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+SUCCESS='\033[0;32m'
+RESET='\033[0m'
+
+# Function to print messages in green
+print_success() {
+ echo -e "${SUCCESS}$1${RESET}"
+}
+
+# Function to activate the Python virtual environment
+activate_venv() {
+ print_success "Activating Python virtual environment..."
+ source .venv/bin/activate > /dev/null
+}
+
+# Function to check if a Python package is installed
+check_package() {
+ if ! python -c "import $1" 2>/dev/null; then
+ echo "Error: $1 package is missing."
+ exit 1
+ fi
+}
+
+# Function to upgrade pip
+upgrade_pip() {
+ print_success "Upgrading pip to the latest version..."
+ pip install --upgrade pip > /dev/null
+}
+
+# Function to install a Python package
+install_package() {
+ local package=$1
+ print_success "Installing ${package} package..."
+ pip install "$package" > /dev/null
+}
+
+# Function to generate API documentation using lazydocs
+generate_docs() {
+ print_success "Creating API documentation with lazydocs..."
+ lazydocs \
+ --output-path="./docs/api-docs" \
+ --overview-file="README.md" \
+ --src-base-url="https://github.com/" \
+ APOLLO_LIBRARY > /dev/null
+}
+
+# Main script execution
+activate_venv
+check_package "pip"
+upgrade_pip
+install_package "lazydocs"
+generate_docs
diff --git a/scripts/virtualenv_activate.sh b/scripts/virtualenv_activate.sh
new file mode 100644
index 0000000..28ee5d4
--- /dev/null
+++ b/scripts/virtualenv_activate.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+source .venv/bin/activate
diff --git a/scripts/virtualenv_create.sh b/scripts/virtualenv_create.sh
new file mode 100644
index 0000000..8c6d3de
--- /dev/null
+++ b/scripts/virtualenv_create.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+#
+# Create the project's virtual environment in ./.venv and install the
+# dependencies listed in requirements.txt. Paths are relative to the current
+# directory, so run this script from the repository root.
+
+# Define color codes
+SUCCESS='\033[0;32m'
+ERROR='\033[0;31m'
+RESET='\033[0m' # Reset color
+
+# Function to abort the setup with an error message
+fail() {
+    echo -e "${ERROR}Error: $1${RESET}" >&2
+    exit 1
+}
+
+# Function to create a virtual environment
+create_virtual_env() {
+    echo -e "${SUCCESS}Setting up virtual environment...${RESET}"
+    python3 -m venv .venv > /dev/null
+}
+
+# Function to activate the virtual environment
+activate_virtual_env() {
+    echo -e "${SUCCESS}Activating virtual environment...${RESET}"
+    source .venv/bin/activate > /dev/null
+}
+
+# Function to upgrade pip
+upgrade_pip() {
+    echo -e "${SUCCESS}Upgrading pip...${RESET}"
+    pip install --upgrade pip > /dev/null
+}
+
+# Function to install dependencies
+install_dependencies() {
+    echo -e "${SUCCESS}Installing dependencies from requirements file...${RESET}"
+    pip install -r requirements.txt
+}
+
+# Function to display completion message
+display_completion() {
+    echo -e "${SUCCESS}Setup complete!${RESET}"
+}
+
+# Main script execution
+# Creation and activation must succeed before pip is used: otherwise the pip
+# commands below would run against whichever pip is first on PATH.
+create_virtual_env || fail "could not create the virtual environment."
+activate_virtual_env || fail "could not activate the virtual environment."
+# A failed pip upgrade is not fatal; the existing pip is used instead.
+upgrade_pip || echo "Warning: pip upgrade failed; continuing with the current pip." >&2
+install_dependencies || fail "could not install dependencies from requirements.txt."
+display_completion
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e92e0bb
--- /dev/null
+++ b/train.py
@@ -0,0 +1,210 @@
+"""This file the code necessary to train the Crepe model."""
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+import torch.nn.functional as F
+from crepe.utils import frequency_to_activation
+from crepe.model import Crepe
+
+
+def epoch_step(model, audio, labels, sr, device):
+ """
+ Compute model and label activations for a given batch.
+
+ Args:
+ model: The Crepe model to use.
+ audio: A tensor of audio data.
+ labels: A tensor of label data.
+ sr: The sample rate of the audio data.
+ device: The device (CPU or GPU) to move the tensors to.
+
+ Returns:
+ A tuple containing the label activations and the model activations.
+ """
+ # send tensor to device
+ audio = audio.to(device)
+ labels = labels.to(device)
+
+ # compute label activations
+ labels_activations = frequency_to_activation(labels[0, :])
+
+ # compute model activations
+ model_activations = model.get_activation(audio, sr).to(device)
+
+ labels_activations = F.interpolate(labels_activations.unsqueeze(0).unsqueeze(0), size=(
+ model_activations.shape[0], labels_activations.shape[1]), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
+
+ # return activations
+ return labels_activations, model_activations
+
+
+def train_epoch(model, optimizer, dataloader, sr, device):
+ """
+ Train the Crepe model for one epoch.
+
+ Args:
+ model: The Crepe model to train.
+ optimizer: The optimizer to use.
+ dataloader: A DataLoader containing the training data.
+ sr: The sample rate of the audio data.
+ device: The device (CPU or GPU) to move the tensors to.
+
+ Returns:
+ The average loss over all batches in this epoch.
+ """
+ model.train() # Set the model to training mode
+ running_loss = 0.0
+
+ criterion = nn.BCELoss() # Initialize the loss function
+
+ for i, (audio, labels) in enumerate(tqdm(dataloader)):
+ audio, labels = audio.to(device), labels.to(
+ device) # Move data to the correct device
+
+ optimizer.zero_grad() # Zero the gradients
+
+ # Forward pass
+ labels_activations, model_activations = epoch_step(
+ model=model,
+ audio=audio,
+ labels=labels,
+ sr=sr,
+ device=device
+ )
+
+ # Compute the loss
+ loss = criterion(model_activations, labels_activations)
+ running_loss += loss.item() * audio.size(0)
+
+ # Backward pass and optimization
+ loss.backward()
+ optimizer.step()
+
+ return running_loss / len(dataloader.dataset)
+
+
+def validate_epoch(model, dataloader, sr, device):
+ """
+ Validate the Crepe model for one epoch.
+
+ Args:
+ model: The Crepe model to validate.
+ dataloader: A DataLoader containing the validation data.
+ sr: The sample rate of the audio data.
+ device: The device (CPU or GPU) to move the tensors to.
+
+ Returns:
+ The average loss over all batches in this epoch.
+ """
+ model.eval() # Set the model to evaluation mode
+ running_loss = 0.0
+ criterion = nn.BCELoss() # Initialize the loss function
+
+ with torch.no_grad(): # Disable gradient calculation
+ for i, (audio, labels) in enumerate(tqdm(dataloader)):
+ audio, labels = audio.to(device), labels.to(
+ device) # Move data to the correct device
+
+ # Forward pass
+ labels_activations, model_activations = epoch_step(
+ model=model,
+ audio=audio,
+ labels=labels,
+ sr=sr,
+ device=device
+ )
+
+ # Compute the loss
+ loss = criterion(model_activations, labels_activations)
+ running_loss += loss.item() * audio.size(0)
+
+ return running_loss / len(dataloader.dataset)
+
+
+if __name__ == "__main__":
+ from crepe.dataset import MIR1KDataset, Back10Dataset, NSynthDataset
+ from torch.utils.data import DataLoader, random_split, ConcatDataset
+
+ model_capacity = 'small'
+ learning_rate = 0.0002
+ num_epoch = 50000
+ num_batches_per_epoch = 8
+ sr = 16000
+ max_epochs_without_improvement = 32
+
+ device = torch.device('cuda' if torch.cuda.is_available(
+ ) else 'mps' if torch.backends.mps.is_available() else 'cpu')
+
+ model = Crepe(model_capacity=model_capacity).to(device)
+
+ # dataset
+ mir_1k = MIR1KDataset(root_dir="./dataset/MIR-1K")
+ back10 = Back10Dataset(root_dir="./dataset/Bach10")
+# nsynth = NSynthDataset(root_dir="./dataset/Nsynth-mixed", n_samples=30)
+ dataset = ConcatDataset([back10, mir_1k])
+
+ # set train, validation dataset sizes
+ train_size = int(0.8 * len(dataset))
+ val_size = len(dataset) - train_size
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+ train_loader = DataLoader(
+ train_dataset,
+ shuffle=True,
+ num_workers=4,
+ pin_memory=True
+ )
+ val_loader = DataLoader(
+ val_dataset,
+ shuffle=False,
+ num_workers=4,
+ pin_memory=True
+ )
+
+ # set optimizer
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+ best_val_loss = float('inf')
+ epochs_without_improvement = 0
+ accumulation_steps = 4
+
+ for epoch in range(1, num_epoch + 1):
+ train_loss = 0.0
+ for _ in range(num_batches_per_epoch):
+ tmp_loss = train_epoch(
+ model,
+ optimizer,
+ train_loader,
+ sr,
+ device
+ )
+ train_loss += tmp_loss
+ # compute train loss
+ train_loss /= num_batches_per_epoch
+
+ # validation step
+ val_loss = validate_epoch(
+ model,
+ val_loader,
+ sr,
+ device
+ )
+
+ print(f'Epoch {epoch}, Train Loss: {
+ train_loss:.4f}, Val Loss: {val_loss:.4f}')
+
+ if val_loss < best_val_loss:
+ best_val_loss = val_loss
+ torch.save(model.state_dict(),
+ f'crepe/crepe_{model_capacity}_best.pth')
+ epochs_without_improvement = 0
+ else:
+ epochs_without_improvement += 1
+
+ if epochs_without_improvement >= max_epochs_without_improvement:
+ print("Stopping early due to no improvement in validation loss.")
+ break
+
+ # save model
+ torch.save(model.state_dict(), f'crepe/crepe_{model_capacity}_final.pth')