From e72f16a51bd4b64931315817cb3e88238deae315 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 18 Dec 2023 20:56:37 +0300 Subject: [PATCH 01/15] Create empty branch From def592a4f88cba45ea8476dcf57edf39d5eadbe4 Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Fri, 29 Dec 2023 16:22:09 +0300 Subject: [PATCH 02/15] refac: full - run_docker, almost - dataset, inference --- run_docker.sh | 2 +- separator/config/config.py | 2 + separator/data/dataset.py | 342 ++++++++++++++++++++++++------------- separator/inference.py | 90 ++++++---- 4 files changed, 288 insertions(+), 148 deletions(-) diff --git a/run_docker.sh b/run_docker.sh index 30385bf..46e60a1 100644 --- a/run_docker.sh +++ b/run_docker.sh @@ -1,6 +1,6 @@ #!/bin/bash -app=$PWD +app=$(pwd) docker run --name pmunet -it --rm \ --net=host --ipc=host \ diff --git a/separator/config/config.py b/separator/config/config.py index c2e621b..e4a290e 100644 --- a/separator/config/config.py +++ b/separator/config/config.py @@ -79,6 +79,8 @@ class InferenceConfig: # weights weights_dir: Path = Path("/app/separator/inference/weights") + weights_LSTM_filename: str = "weight_LSTM.pt" + weights_conv_filename: str = "weight_conv.pt" gdrive_weights_LSTM: str = f"{GDRIVE_PREFIX}18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" gdrive_weights_conv: str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" diff --git a/separator/data/dataset.py b/separator/data/dataset.py index 90b23b2..0f3567c 100644 --- a/separator/data/dataset.py +++ b/separator/data/dataset.py @@ -1,4 +1,6 @@ from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass import hashlib import math import json @@ -10,52 +12,69 @@ import musdb import julius import torch as th -from torch import distributed import torchaudio as ta from torch.nn import functional as F +@dataclass class File: - MIXTURE = "mixture" - EXT = ".wav" + MIXTURE: str = "mixture" + EXT: str = ".wav" def get_musdb_wav_datasets( musdb="musdb18hq", - musdb_samplerate=44100, - use_musdb=True, segment=11, shift=1, train_valid=False, - full_cv=True, samplerate=44100, channels=2, normalize=True, metadata="./metadata", sources=["drums", "bass", "other", "vocals"], - backend=None, data_type="train", ): """ - Extract the musdb dataset from the XP arguments. + Prepares and retrieves the MUSDB18-HQ dataset for audio source separation. + + This function handles the dataset preparation by creating necessary metadata files and setting up the data loader configuration. It then returns a dataset instance ready for use in training or evaluation. + + Args: + musdb (str): Path to the MUSDB18-HQ dataset directory. + segment (int): Length in seconds of each audio segment for processing. + shift (int): Stride in seconds between consecutive audio segments. + train_valid (bool): Flag to determine if training or validation data should be used. + samplerate (int): Target sample rate for audio processing. + channels (int): Number of audio channels (e.g., 1 for mono, 2 for stereo). + normalize (bool): Whether to normalize audio based on the entire track. + metadata (str): Path for saving the generated metadata. + sources (list[str]): List of source names to be included in the dataset. + data_type (str): Type of data to process ('train' or 'test'). + + Returns: + Wavset: An instance of the Wavset class configured for the specified dataset. """ + + # Create a unique identifier for the dataset configuration. 
sig = hashlib.sha1(str(musdb).encode()).hexdigest()[:8] - metadata_file = Path(metadata) / ("musdb_" + sig + ".json") - root = Path(musdb) / data_type + + metadata_file = Path(metadata) / f"musdb_{sig}.json" + root = Path(musdb) / data_type + + # Build metadata if not already present. if not metadata_file.is_file(): - metadata_file.parent.mkdir(exist_ok=True, parents=True) - metadata = MetaData.build_metadata(root, sources) - json.dump(metadata, open(metadata_file, "w")) + metadata_file.parent.mkdir(exist_ok=True) + metadata_content = MetaData.build_metadata(root, sources) + json.dump(metadata_content, open(metadata_file, "w")) + + # Load metadata from the file. metadata = json.load(open(metadata_file)) - valid_tracks = _get_musdb_valid() - if train_valid: - metadata_train = metadata - else: - metadata_train = { - name: meta for name, meta in metadata.items() if name not in valid_tracks - } + # Filter tracks for training or validation based on the configuration. + valid_tracks = _get_musdb_valid() # Retrieve a list of valid track names. + metadata_train = metadata if train_valid else {name: meta for name, meta in metadata.items() if name not in valid_tracks} + # Configure and return the dataset instance. data_set = Wavset( root, metadata_train, @@ -93,28 +112,22 @@ def __init__( ext=File.EXT, ): """ - Waveset (or mp3 set for that matter). - Can be used to train with arbitrary sources. - Each track should be one folder inside of `path`. - The folder should contain files named `{source}.{ext}`. + A dataset class for audio source separation, compatible with WAV (or MP3) files. + This class allows training with arbitrary sources, where each audio track is represented by a separate folder within the specified root directory. Each folder should contain audio files for different sources, named as `{source}.{ext}`. Args: - root (Path or str): root folder for the dataset. - metadata (dict): output from `build_metadata`. - sources (list[str]): list of source names. - segment (None or float): segment length in seconds. - If `None`, returns entire tracks. - shift (None or float): stride in seconds bewteen samples. - normalize (bool): normalizes input audio, - **based on the metadata content**, - i.e. the entire track is normalized, not individual extracts. - samplerate (int): target sample rate. if the file sample rate - is different, it will be resampled on the fly. - channels (int): target nb of channels. if different, will be - changed onthe fly. - ext (str): extension for audio files (default is .wav). - - samplerate and channels are converted on the fly. + root (Path or str): The root directory of the dataset where audio tracks are stored. + metadata (dict): Metadata information generated by the `build_metadata` function. It contains details like track names and lengths. + sources (list[str]): A list of source names to be separated, e.g., ['drums', 'vocals']. + segment (Optional[float]): The length of each audio segment in seconds. If `None`, the entire track is used. + shift (Optional[float]): The stride in seconds between samples. Determines the overlap between consecutive audio segments. + normalize (bool): If True, normalizes the input audio based on the entire track's statistics, not just individual segments. + samplerate (int): The target sample rate. Audio files with a different sample rate will be resampled to this rate. + channels (int): The target number of audio channels. If different, audio will be converted accordingly. 
+ ext (str): The file extension of the audio files, default is '.wav'. + + Note: + The `samplerate` and `channels` parameters are used to ensure consistency across the dataset. They allow on-the-fly conversion of audio properties to match the target specifications. """ self.root = Path(root) self.metadata = OrderedDict(metadata) @@ -131,9 +144,7 @@ def __init__( if segment is None or track_duration < segment: examples = 1 else: - examples = int( - math.ceil((track_duration - self.segment) / self.shift) + 1 - ) + examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) self.num_examples.append(examples) def __len__(self): @@ -143,82 +154,116 @@ def get_file(self, name, source): return self.root / name / f"{source}{self.ext}" def __getitem__(self, index): + """ + Get an audio example by index with applied transformations. + + Args: + index (int): The index of the audio example in the dataset. + + Returns: + Tensor: The processed audio example as a tensor. + """ + # Iterate over each audio source and adjust the index for each source for name, examples in zip(self.metadata, self.num_examples): if index >= examples: index -= examples continue + + # Access metadata for the current source meta = self.metadata[name] - num_frames = -1 - offset = 0 + + # Calculate offset and number of frames if segmenting is enabled + num_frames, offset = -1, 0 if self.segment is not None: offset = int(meta["samplerate"] * self.shift * index) num_frames = int(math.ceil(meta["samplerate"] * self.segment)) + + # Load and process audio from each source wavs = [] for source in self.sources: - file = self.get_file(name, source) - wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) + file_path = self.get_file(name, source) + wav, _ = ta.load(str(file_path), frame_offset=offset, num_frames=num_frames) wav = self.__convert_audio_channels(wav, self.channels) wavs.append(wav) + # Stack, resample, and normalize the audio examples example = th.stack(wavs) example = julius.resample_frac(example, meta["samplerate"], self.samplerate) if self.normalize: example = (example - meta["mean"]) / meta["std"] + + # Pad the audio example if segmenting is used if self.segment: length = int(self.segment * self.samplerate) example = example[..., :length] example = F.pad(example, (0, length - example.shape[-1])) + return example - def __convert_audio_channels(self, wav, channels=2): - """Convert audio to the given number of channels.""" + def __convert_audio_channels(self, wav, desired_channels=2): + """ + Convert an audio waveform to the specified number of channels. + + Args: + wav (Tensor): The input waveform tensor with shape (..., channels, length). + desired_channels (int, optional): The number of channels for the output waveform. + Defaults to 2 for stereo output. + + Returns: + Tensor: The waveform with the desired number of channels. + + Raises: + ValueError: If the input audio has fewer channels than requested and is not mono. + + Description: + - If the input already has the desired number of channels, it is returned as is. + - If a mono to stereo conversion is needed, the mono channel is duplicated. + - If downmixing is needed (e.g., from 5.1 to stereo), only the first 'desired_channels' are kept. + - If upmixing is required (e.g., mono to 5.1), the single channel is replicated across all desired channels. 
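+
+        Example (illustrative, values assumed): a mono waveform of shape (1, length)
+        with desired_channels=2 is expanded to (2, length); a 6-channel waveform with
+        desired_channels=2 is sliced down to its first two channels.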
+ """ + *shape, src_channels, length = wav.shape - if src_channels == channels: - pass - elif channels == 1: - # Case 1: - # The caller asked 1-channel audio, but the stream have multiple - # channels, downmix all channels. - wav = wav.mean(dim=-2, keepdim=True) + + if src_channels == desired_channels: + # No change needed + return wav + elif src_channels > desired_channels: + # Downmix by slicing to the desired number of channels + return wav[..., :desired_channels, :] elif src_channels == 1: - # Case 2: - # The caller asked for multiple channels, but the input file have - # one single channel, replicate the audio over all channels. - wav = wav.expand(*shape, channels, length) - elif src_channels >= channels: - # Case 3: - # The caller asked for multiple channels, and the input file have - # more channels than requested. - # In that case return the first channels. - wav = wav[..., :channels, :] + # Upmix by replicating the mono channel + return wav.expand(*shape, desired_channels, length) else: - # Case 4: What is a reasonable choice here? - raise ValueError( - "The audio file has less channels than requested \ - but is not mono." - ) - return wav + # Invalid case: input has fewer channels than desired and is not mono + raise ValueError("Cannot upmix from fewer than 1 channel unless the source is mono.") class MetaData: + @staticmethod def __track_metadata(track, sources, normalize=True, ext=File.EXT): - track_length = None - track_samplerate = None - mean = 0 - std = 1 + """ + Process and return the metadata for a single track. + + Args: + track (Path): Path to the track directory. + sources (list[str]): List of sources to look for. + normalize (bool): If True, calculates normalization values. + ext (str): Extension of audio files. + + Returns: + dict: Dictionary containing the track's metadata. + + Raises: + RuntimeError: If an audio file is invalid. + ValueError: If audio files have inconsistent lengths or sample rates. + """ + track_length, track_samplerate = None, None + mean, std = 0, 1 + for source in sources + [File.MIXTURE]: source_file = track / f"{source}{ext}" if source == File.MIXTURE and not source_file.exists(): - audio = 0 - for sub_source in sources: - sub_file = track / f"{sub_source}{ext}" - sub_audio, sr = ta.load(sub_file) - audio += sub_audio - would_clip = audio.abs().max() >= 1 - if would_clip: - assert ( - ta.get_audio_backend() == "soundfile" - ), "use dset.backend=soundfile" + audio, sr = MetaData.__create_mixture(track, sources, ext) ta.save(source_file, audio, sr, encoding="PCM_F") try: @@ -226,30 +271,13 @@ def __track_metadata(track, sources, normalize=True, ext=File.EXT): except RuntimeError: logging.error(f"{source_file} is invalid") raise - length = info.num_frames + + length, sample_rate = MetaData.__validate_track(info, track_length, track_samplerate, source_file) if track_length is None: - track_length = length - track_samplerate = info.sample_rate - elif track_length != length: - raise ValueError( - f"Invalid length for file {source_file}: " - f"expecting {track_length} but got {length}." - ) - elif info.sample_rate != track_samplerate: - raise ValueError( - f"Invalid sample rate for file {source_file}: " - f"expecting {track_samplerate} but got \ - {info.sample_rate}." 
- ) + track_length, track_samplerate = length, sample_rate + if source == File.MIXTURE and normalize: - try: - wav, _ = ta.load(str(source_file)) - except RuntimeError: - logging.error(f"{source_file} is invalid") - raise - wav = wav.mean(0) - mean = wav.mean().item() - std = wav.std().item() + mean, std = MetaData.__calculate_normalization(source_file) return { "length": length, @@ -258,27 +286,28 @@ def __track_metadata(track, sources, normalize=True, ext=File.EXT): "samplerate": track_samplerate, } + @staticmethod def build_metadata(path, sources, normalize=True, ext=File.EXT): """ - Build the metadata for `Wavset`. + Build and return the metadata for the entire dataset. Args: - path (str or Path): path to dataset. - sources (list[str]): list of sources to look for. - normalize (bool): if True, loads full track and store normalization - values based on the mixture file. - ext (str): extension of audio files (default is .wav). - """ + path (str or Path): Path to the dataset. + sources (list[str]): List of sources to look for. + normalize (bool): If True, calculates normalization values. + ext (str): Extension of audio files. + Returns: + dict: Dictionary containing metadata for each track in the dataset. + """ meta = {} path = Path(path) pendings = [] - from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(8) as pool: - for root, folders, files in os.walk(path, followlinks=True): + for root, _, _ in os.walk(path, followlinks=True): root = Path(root) - if root.name.startswith(".") or folders or root == path: + if root.name.startswith(".") or root == path: continue name = str(root.relative_to(path)) pendings.append( @@ -292,4 +321,81 @@ def build_metadata(path, sources, normalize=True, ext=File.EXT): for name, pending in tqdm.tqdm(pendings, ncols=120): meta[name] = pending.result() + return meta + + @staticmethod + def __create_mixture(track, sources, ext): + """ + Create and return the audio mixture from individual sources. + + Args: + track (Path): Path to the track directory. + sources (list[str]): List of sources to look for. + ext (str): Extension of audio files. + + Returns: + Tuple[Tensor, int]: The mixture audio tensor and its sample rate. + """ + audio = 0 + for sub_source in sources: + sub_file = track / f"{sub_source}{ext}" + sub_audio, sr = ta.load(sub_file) + audio += sub_audio + + would_clip = audio.abs().max() >= 1 + if would_clip: + assert ( + ta.get_audio_backend() == "soundfile" + ), "use dset.backend=soundfile" + + return audio, sr + + @staticmethod + def __validate_track(info, track_length, track_samplerate, source_file): + """ + Validate the track's length and sample rate. + + Args: + info (AudioMetaData): Metadata of the audio file. + track_length (int): Expected length of the track. + track_samplerate (int): Expected sample rate of the track. + source_file (Path): Path to the source file. + + Returns: + Tuple[int, int]: Length and sample rate of the track. + + Raises: + ValueError: If the track's length or sample rate is inconsistent. + """ + length = info.num_frames + if track_length is not None and track_length != length: + raise ValueError( + f"Invalid length for file {source_file}: " + f"expecting {track_length} but got {length}." + ) + elif track_samplerate is not None and info.sample_rate != track_samplerate: + raise ValueError( + f"Invalid sample rate for file {source_file}: " + f"expecting {track_samplerate} but got {info.sample_rate}." 
+ ) + return length, info.sample_rate + + @staticmethod + def __calculate_normalization(source_file): + """ + Calculate and return the mean and standard deviation for normalization. + + Args: + source_file (Path): Path to the source file. + + Returns: + Tuple[float, float]: Mean and standard deviation of the waveform. + """ + try: + wav, _ = ta.load(str(source_file)) + except RuntimeError: + logging.error(f"{source_file} is invalid") + raise + wav = wav.mean(0) + return wav.mean().item(), wav.std().item() diff --git a/separator/inference.py b/separator/inference.py index c71844c..62b2ea3 100644 --- a/separator/inference.py +++ b/separator/inference.py @@ -32,10 +32,10 @@ def __init__(self, config, model_bottlneck_lstm=True): def resolve_weigths(self): if self.model_bottlneck_lstm: - self.weights_path = self.config.weights_dir / "weight_LSTM.pt" + self.weights_path = self.config.weights_dir / self.config.weights_LSTM_filename gdrive_url = self.config.gdrive_weights_LSTM else: - self.weights_path = self.config.weights_dir / "weight_conv.pt" + self.weights_path = self.config.weights_dir / self.config.weights_conv_filename gdrive_url = self.config.gdrive_weights_conv try: @@ -62,11 +62,14 @@ def track(self, sample_mixture_path): end = sr * (offset + duration) if duration else None mixture = waveform[:, start:end] + # Normalize ref = mixture.mean(0) - mixture = (mixture - ref.mean()) / ref.std() # normalization + mixture = (mixture - ref.mean()) / ref.std() + # Do separation sources = self.separate_sources(mixture[None], sample_rate=sr) + # Denormalize sources = sources * ref.std() + ref.mean() sources_list = ["drums", "bass", "other", "vocals"] B, S, C, T = sources.shape @@ -81,43 +84,72 @@ def track(self, sample_mixture_path): return audios def separate_sources(self, mix, sample_rate): - device = self.config.device - device = torch.device(device) if device else mix.device + """ + Separates the audio mix into its constituent sources. + Args: + mix (Tensor): The input mixed audio signal tensor of shape (batch, channels, length). + sample_rate (int): The sample rate of the audio signal. + + Returns: + Tensor: The separated audio sources as a tensor. 
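+                Per the implementation below, the output shape is
+                (batch, num_sources, channels, length), with sources ordered as
+                [drums, bass, other, vocals].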
+ """ + # Set the device based on the configuration or input mix + device = torch.device(self.config.device) if self.config.device else mix.device + + # Get the shape of the input mix batch, channels, length = mix.shape + # Calculate chunk length for processing and overlap frames chunk_len = int(sample_rate * self.segment * (1 + self.overlap)) - start = 0 - end = chunk_len - overlap_frames = self.overlap * sample_rate - fade = Fade( - fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear" - ) + overlap_frames = int(self.overlap * sample_rate) + fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape="linear") - final = torch.zeros( - batch, - len(["drums", "bass", "other", "vocals"]), - channels, - length, - device=device, - ) + # Initialize the tensor to hold the final separated sources + num_sources = 4 # ["drums", "bass", "other", "vocals"] + final = torch.zeros(batch, num_sources, channels, length, device=device) + start, end = 0, chunk_len while start < length - overlap_frames: + # Process each chunk with model and apply fade chunk = mix[:, :, start:end] with torch.no_grad(): - out = self.model.forward(chunk) - out = fade(out) - final[:, :, :, start:end] += out - if start == 0: - fade.fade_in_len = int(overlap_frames) - start += int(chunk_len - overlap_frames) - else: - start += chunk_len - end += chunk_len - if end >= length: - fade.fade_out_len = 0 + separated_sources = self.model.forward(chunk) + separated_sources = fade(separated_sources) + final[:, :, :, start:end] += separated_sources + + # Adjust the start and end for the next chunk, and update fade parameters + start, end = self.__update_chunk_indices(start, end, chunk_len, overlap_frames, length, fade) + return final + @staticmethod + def __update_chunk_indices(start, end, chunk_len, overlap_frames, length, fade): + """ + Update the chunk indices for the next iteration and adjust fade parameters. + + Args: + start (int): Current start index of the chunk. + end (int): Current end index of the chunk. + chunk_len (int): Length of each chunk. + overlap_frames (int): Number of overlapping frames. + length (int): Total length of the audio signal. + fade (Fade): The Fade object used for applying fade in/out. + + Returns: + Tuple[int, int]: The updated start and end indices for the next chunk. 
+ """ + if start == 0: + fade.fade_in_len = overlap_frames + start += chunk_len - overlap_frames + else: + start += chunk_len + + end = min(end + chunk_len, length) + fade.fade_out_len = 0 if end >= length else overlap_frames + + return start, end + def resolve_default_sample(self): default_input_dir = self.config.default_input_dir Path(default_input_dir).mkdir(parents=True, exist_ok=True) From 2f869eb56392ac211c15ce4925f7546e04f2ecec Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Fri, 29 Dec 2023 17:00:02 +0300 Subject: [PATCH 03/15] refac: full - runner, stream_class -> tf_lite_stream, almost - converter --- streaming/config/config.py | 6 ++++++ streaming/converter.py | 11 +++++------ streaming/runner.py | 3 ++- streaming/{stream_class.py => tf_lite_stream.py} | 0 4 files changed, 13 insertions(+), 7 deletions(-) rename streaming/{stream_class.py => tf_lite_stream.py} (100%) diff --git a/streaming/config/config.py b/streaming/config/config.py index 6989748..789d9e1 100644 --- a/streaming/config/config.py +++ b/streaming/config/config.py @@ -5,6 +5,8 @@ @dataclass class ConverterConfig: weights_dir: Path = Path("/app/streaming/weights") + weights_LSTM_filename: str = "weight_LSTM.pt" + weights_conv_filename: str = "weight_conv.pt" gdrive_weights_LSTM_id: str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" gdrive_weights_conv_id: str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" @@ -14,9 +16,13 @@ class ConverterConfig: model_class_name: str = "Model_Unet" tflite_model_dst: str = "tflite_model" + sample_rate: int = 44100 + segment_duration: float = 1. + @dataclass class StreamConfig: + converter_script: str = "/app/streaming/converter.py" sample_rate: int = 44100 nfft: int = 4096 stft_py_module: str = "model.STFT" diff --git a/streaming/converter.py b/streaming/converter.py index 2a11e5f..3bcdd11 100644 --- a/streaming/converter.py +++ b/streaming/converter.py @@ -190,10 +190,10 @@ def main(args, config): ) if model.bottlneck_lstm: - weights_path = config.weights_dir / "weight_LSTM.pt" + weights_path = config.weights_dir / config.weights_LSTM_filename gdrive_id = config.gdrive_weights_LSTM_id else: - weights_path = config.weights_dir / "weight_conv.pt" + weights_path = config.weights_dir / config.weights_conv_filename gdrive_id = config.gdrive_weights_conv_id try: config.weights_dir.mkdir(parents=True, exist_ok=False) @@ -224,7 +224,7 @@ def stft(self, wave): def istft(self, z): return self.model.stft.istft(z, self.length_wave) - SEGMENT_WAVE = 44100 + SEGMENT_WAVE = config.sample_rate * config.segment_duration dummy_wave = torch.rand(size=(1, 2, SEGMENT_WAVE)) dummy_spectr = OuterSTFT(SEGMENT_WAVE, model).stft(dummy_wave) @@ -235,9 +235,8 @@ def istft(self, z): inputs_channel_order=ChannelOrder.PYTORCH, ) - model_path = str( - args.out_dir + f"/{args.class_name}_outer_stft_{SEGMENT_WAVE / 44100:.1f}" - ) + model_filename = f"{args.class_name}_outer_stft_{config.segment_duration:.1f}" + model_path = args.out_dir + '/' + model_filename keras_model.save(model_path + ".h5") custom_objects = {"WeightLayer": WeightLayer} diff --git a/streaming/runner.py b/streaming/runner.py index 5dfa5b6..59f1bd5 100644 --- a/streaming/runner.py +++ b/streaming/runner.py @@ -35,7 +35,8 @@ def main(args, config): if start_converter: subprocess.Popen( - ["python3", "/app/converter.py"], executable="/bin/bash", shell=True + ["python3", config.StreamConfig.converter_script], + executable="/bin/bash", shell=True ) converter_outputs = os.listdir(config.ConverterConfig.tflite_model_dst) diff --git a/streaming/stream_class.py 
b/streaming/tf_lite_stream.py similarity index 100% rename from streaming/stream_class.py rename to streaming/tf_lite_stream.py From 6cec29a1202b8ae355d2901f08e995a988908f1d Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Fri, 29 Dec 2023 17:08:05 +0300 Subject: [PATCH 04/15] refac: runner --- streaming/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/runner.py b/streaming/runner.py index 59f1bd5..b0992cb 100644 --- a/streaming/runner.py +++ b/streaming/runner.py @@ -6,7 +6,7 @@ import subprocess from pathlib import Path -from stream_class import TFLiteTorchStream +from tf_lite_stream import TFLiteTorchStream def resolve_default_sample(config): From 1fdc0da771723377f0e3f4f075da64d3d4a1e0d9 Mon Sep 17 00:00:00 2001 From: maks00170 Date: Fri, 29 Dec 2023 21:19:20 +0300 Subject: [PATCH 05/15] rafactor_separator: modules, PM_unet, pl_model,STFT --- separator/config/config.py | 3 + separator/model/PM_Unet.py | 29 +++--- separator/model/STFT.py | 5 +- separator/model/modules.py | 155 +++++++++++++++++++------------- separator/pl_model.py | 149 ++++++++++++++++-------------- separator/train/augment.py | 179 +++++++++++++++++++++---------------- separator/train/loss.py | 14 +-- 7 files changed, 303 insertions(+), 231 deletions(-) diff --git a/separator/config/config.py b/separator/config/config.py index e4a290e..6b27417 100644 --- a/separator/config/config.py +++ b/separator/config/config.py @@ -36,7 +36,10 @@ class TrainConfig: layers: int = 2 stft_flag: bool = True # augments + proba_shift: float = 0.5 shift: int = 8192 + proba_flip_channel: float = 1 + proba_flip_sign: float = 1 pitchshift_proba: float = 0.2 vocals_min_semitones: int = -5 vocals_max_semitones: int = 5 diff --git a/separator/model/PM_Unet.py b/separator/model/PM_Unet.py index f02f43e..5a5894f 100644 --- a/separator/model/PM_Unet.py +++ b/separator/model/PM_Unet.py @@ -2,7 +2,6 @@ import torch as th import torch.nn as nn from model.STFT import STFT -from functools import partial from model.modules import Encoder, Decoder, Bottleneck_v2, Bottleneck from typing import List, Optional @@ -22,20 +21,20 @@ def __init__( stft_flag: bool = True, ): """ - depth - (int) number of layers encoder and decoder - source - (list[str]) list of source names - channel - (int) initial number of hidden channels - is_mono - (bool) mono input/output audio channel - mask_mode - (bool) mask inference - skip_mode - (concat or add) types skip connection - concat: concatenates output encoder and decoder - add: add output encoder and decoder - nfft - (int) number of fft bins - bottlneck_lstm - (bool) lstm bottlneck - True: bottlneck_lstm - bilstm bottlneck - False: bottlneck_conv - convolution bottlneck - layers - (int) number bottlneck_lstm layers - stft_flag - (bool) use stft + depth (int): Number of layers in both the encoder and decoder. + source (list[str]): List of source names. + channel (int): Initial number of hidden channels. + is_mono (bool): Indicates whether the input/output audio channel is mono. + mask_mode (bool): Enables or disables mask inference. + skip_mode (str): Type of skip connection, either 'concat' or 'add'. + concat: Concatenates the outputs of the encoder and decoder. + add: Adds the outputs of the encoder and decoder. + nfft (int): Number of FFT (Fast Fourier Transform) bins. + bottleneck_lstm (bool): Determines the type of bottleneck to use. + True: Uses a BiLSTM (bidirectional Long Short-Term Memory) bottleneck. + False: Uses a convolutional bottleneck. 
+ layers (int): Number of bottleneck LSTM layers. + stft_flag (bool): Indicates whether to use STFT (Short-Time Fourier Transform). """ super().__init__() self.sources = source diff --git a/separator/model/STFT.py b/separator/model/STFT.py index 2f47566..666bccc 100644 --- a/separator/model/STFT.py +++ b/separator/model/STFT.py @@ -18,8 +18,9 @@ def __pad1d( mode: str = "constant", value: float = 0.0, ): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + Tiny wrapper around F.pad, designed to allow reflect padding on small inputs. + If the input is too small for reflect padding, we first add extra zero padding to the right before reflection occurs. """ x0 = x length = x.shape[-1] diff --git a/separator/model/modules.py b/separator/model/modules.py index b8d27d1..781d118 100644 --- a/separator/model/modules.py +++ b/separator/model/modules.py @@ -1,11 +1,21 @@ import torch -import torch as th import torch.nn as nn import math from torch.nn import functional as F class DownSample(nn.Module): + """ + DownSample - dimensionality reduction block that includes layer normalization, activation layer, and Conv2d layer. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): Kernel size. + stride (int, tuple): Stride of the convolution. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ def __init__( self, input_channel, @@ -16,13 +26,7 @@ def __init__( activation, normalization, ): - """ - DownSample - block include layer normalization, layer activation and Conv2d layer - Args: - scale - kernel size - activation - activation layer - normalization - normalization layer - """ + super().__init__() self.conv_layer = nn.Sequential( @@ -43,6 +47,17 @@ def forward(self, x): class UpSample(nn.Module): + """ + UpSample - dimensionality boosting block that includes layer normalization, activation layer, and Conv2d layer. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): Kernel size. + stride (int, tuple): Stride of the convolution. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ def __init__( self, input_channel, @@ -53,13 +68,7 @@ def __init__( activation, normalization, ): - """ - UpSample - block include layer normalization, layer activation and Conv2d layer - Args: - scale - kernel size - activation - activation layer - normalization - normalization layer - """ + super().__init__() self.convT_layer = nn.Sequential( @@ -80,12 +89,16 @@ def forward(self, x): class InceptionBlock(nn.Module): - def __init__(self, input_channel, out_channels, activation, normalization): - """InceptionBlock - The block includes 3 branches consisting of normalization layers, activation layers and 2d convolution with 1, 3 and 5 core sizes respectively. - Args: - activation - activation layer - normalization - normalization layer - """ + """ + InceptionBlock: This block comprises three branches, each consisting of normalization layers, activation layers, and 2D convolution layers. The convolution layers in each branch have kernel sizes of 1, 3, and 5, respectively. 
+ Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ + def __init__(self, input_channel, out_channel, activation, normalization): + super().__init__() self.conv_layer_1 = nn.Sequential( @@ -93,7 +106,7 @@ def __init__(self, input_channel, out_channels, activation, normalization): activation, nn.Conv2d( in_channels=input_channel, - out_channels=out_channels, + out_channels=out_channel, kernel_size=1, stride=1, bias=False, @@ -105,7 +118,7 @@ def __init__(self, input_channel, out_channels, activation, normalization): activation, nn.Conv2d( in_channels=input_channel, - out_channels=out_channels, + out_channels=out_channel, kernel_size=3, stride=1, padding="same", @@ -118,7 +131,7 @@ def __init__(self, input_channel, out_channels, activation, normalization): activation, nn.Conv2d( in_channels=input_channel, - out_channels=out_channels, + out_channels=out_channel, kernel_size=5, stride=1, padding="same", @@ -134,6 +147,17 @@ def forward(self, x): class Encoder(nn.Module): + """ + Encoder layer - Block included DownSample layer and InceptionBlock. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): The size of the kernel used in the DownSample layer. + stride (int, tuple): The stride used in the DownSample layer. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ def __init__( self, input_channel, @@ -144,15 +168,7 @@ def __init__( activation, normalization, ): - """ - Encoder layer - Block included DownSample layer and InceptionBlock. - Args: - scale - scale (kernel size) DownSample layer - stride - stride DownSample layer - padding - padding DownSample layer - activation - activation layer - normalization - normalization layer - """ + super().__init__() self.inception_layer = InceptionBlock( @@ -175,6 +191,17 @@ def forward(self, x): class Decoder(nn.Module): + """ + Decoder layer - Block included UpSample layer and InceptionBlock. + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + scale (int, tuple): The size of the kernel used in the UpSample layer. + stride (int, tuple): The stride used in the UpSample layer. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ def __init__( self, input_channel, @@ -185,15 +212,7 @@ def __init__( activation, normalization, ): - """ - Decoder layer - Block included UpSample layer and InceptionBlock. - Args: - scale - scale (kernel size) UpSample layer - stride - stride UpSample layer - padding - padding UpSample layer - activation - activation layer - normalization - normalization layer - """ + super().__init__() self.inception_layer = InceptionBlock( @@ -217,9 +236,14 @@ def forward(self, x): class BLSTM(nn.Module): """ - BiLSTM with same hidden units as input dim. - If `max_steps` is not None, input will be splitting in overlapping - chunks and the LSTM applied separately on each chunk. + A bidirectional LSTM (BiLSTM) module with the same number of hidden units as the input dimension. + This module can process inputs in overlapping chunks if `max_steps` is specified. 
+ In this case, the input will be split into chunks, and the LSTM will be applied to each chunk separately. + Args: + dim (int): The number of dimensions in the input and the hidden state of the LSTM. + max_steps (int, optional): The maximum number of steps (length of chunks) for processing the input. Defaults to None. + skip (bool, optional): Flag to enable skip connections. Defaults to False. + layers (int): Number of recurrent layers """ def __init__(self, dim, layers=1, max_steps=None, skip=False): @@ -286,6 +310,18 @@ def forward(self, x): class Bottleneck_v2(nn.Module): + """ + Bottleneck - bi-lstm bottleneck + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + layers (int): number of recurrent layers + skip (bool): include skip conncetion bi-lstm + stride (int, tuple): The stride used in the Conv1d layer. + padding (int, tuple or str): Padding added to all four sides of the input. + activation (object): Activation layer. + normalization (object): Normalization layer. + """ def __init__( self, input_channel, @@ -296,25 +332,16 @@ def __init__( max_steps=200, skip=True, stride=1, - padding=1, + padding="same", ): - """ - Bottleneck - bi-lstm bottleneck - Args: - activation - activation layer - normalization - normalization layer - layers - number of recurrent layers - skip - include skip conncetion bi-lstm - stride - stride Conv1d - padding - stride Conv1d - """ + super().__init__() self.conv_layer = nn.Sequential( normalization(input_channel, affine=True), activation, nn.Conv1d( - input_channel, out_channel, kernel_size=3, stride=stride, padding="same" + input_channel, out_channel, kernel_size=3, stride=stride, padding=padding ), ) @@ -337,10 +364,16 @@ def forward(self, x): class Bottleneck(nn.Module): + """ + Bottleneck - convolution bottleneck + Args: + input_channel (int): Number of input channels. + out_channel (int): Number of output channels. + activation (object): Activation layer. + normalization (object): Normalization layer. 
+ """ def __init__(self, input_channel, out_channels, normalization, activation): - """ - Bottleneck - convolution bottleneck - """ + super().__init__() self.conv_layer_1 = nn.Sequential( @@ -376,7 +409,7 @@ def __init__(self, input_channel, out_channels, normalization, activation): in_channels=out_channels, out_channels=out_channels, kernel_size=1, - stride=1, # padding='same', + stride=1, bias=False, ), ) diff --git a/separator/pl_model.py b/separator/pl_model.py index febcee1..e10328b 100644 --- a/separator/pl_model.py +++ b/separator/pl_model.py @@ -5,7 +5,6 @@ from train import augment from pathlib import Path -import os import torch import torch.nn as nn import pytorch_lightning as pl @@ -15,18 +14,6 @@ signal_distortion_ratio, ) - -def compute_uSDR( - predT: torch.Tensor, tgtT: torch.Tensor, delta: float = 1e-7 -) -> torch.Tensor: - num = torch.sum(torch.square(tgtT), dim=(1, 2)) - den = torch.sum(torch.square(tgtT - predT), dim=(1, 2)) - num += delta - den += delta - usdr = 10 * torch.log10(num / den) - return usdr.mean() - - class PM_model(pl.LightningModule): def __init__(self, config): super().__init__() @@ -45,6 +32,7 @@ def __init__(self, config): ) # loss + # Loss = (L_1 + L_{MRS} - L_{SISDR}) self.criterion_1 = nn.L1Loss() self.criterion_2 = MultiResSpecLoss( factor=config.factor, @@ -55,9 +43,9 @@ def __init__(self, config): self.criterion_3 = ScaleInvariantSignalDistortionRatio() # augment - self.augment = [augment.Shift(shift=config.shift, same=True)] + self.augment = [augment.Shift(proba=config.proba_shift, shift=config.shift, same=True)] self.augment += [ - augment.PitchShift_f( + augment.PitchShift( proba=config.pitchshift_proba, min_semitones=config.vocals_min_semitones, max_semitones=config.vocals_max_semitones, @@ -65,11 +53,11 @@ def __init__(self, config): max_semitones_other=config.other_max_semitones, flag_other=config.pitchshift_flag_other, ), - augment.TimeChange_f( + augment.TimeChange( factors_list=config.time_change_factors, proba=config.time_change_proba ), - augment.FlipChannels(), - augment.FlipSign(), + augment.FlipChannels(proba=config.proba_flip_channel), + augment.FlipSign(proba=config.proba_flip_sign), augment.Remix(proba=config.remix_proba, group_size=config.remix_group_size), augment.Scale( proba=config.scale_proba, min=config.scale_min, max=config.scale_max @@ -77,22 +65,41 @@ def __init__(self, config): augment.FadeMask(proba=config.fade_mask_proba), augment.Double(proba=config.double_proba), augment.Reverse(proba=config.reverse_proba), - augment.Remix_wave( + augment.RemixWave( proba=config.mushap_proba, group_size=config.mushap_depth ), ] self.augment = torch.nn.Sequential(*self.augment) + self.model.apply(self.__init_weights) + + def __init_weights(self, m): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + torch.nn.init.xavier_uniform(m.weight) + + def __usdr(self, predT, tgtT, delta=1e-7): + """ + latex: $usdr=10\log_{10} (\dfrac{\| tgtT\|^2 + \delta}{ \| predT - tgtT\| ^{2} + \delta})$ + """ + num = torch.sum(torch.square(tgtT), dim=(1, 2)) + den = torch.sum(torch.square(tgtT - predT), dim=(1, 2)) + num += delta + den += delta + usdr = 10 * torch.log10(num / den) + return usdr.mean() + + def forward(self, x): x = self.model(x) return x - def loss(self, y_true, y_pred): # L = L_1 + L_{MRS} - L_{SISDR} + def loss(self, y_true, y_pred): + # losses are averaged loss = ( self.criterion_1(y_pred, y_true) + self.criterion_2(y_pred, y_true) - self.criterion_3(y_pred, y_true) - ) + )/3 return loss def training_step(self, batch, 
batch_idx): @@ -102,15 +109,20 @@ def training_step(self, batch, batch_idx): source_predict = self.model(mix) - drums_loss = self.loss(source_predict[:, 0], source[:, 0]) / 3 + drums_pred, drums_target = source_predict[:, 0], source[:, 0] + bass_pred, bass_target = source_predict[:, 1], source[:, 1] + other_pred, other_target = source_predict[:, 2], source[:, 2] + vocals_pred, vocals_target = source_predict[:, 3], source[:, 3] - bass_loss = self.loss(source_predict[:, 1], source[:, 1]) / 3 + drums_loss = self.loss(drums_pred, drums_target) - other_loss = self.loss(source_predict[:, 2], source[:, 2]) / 3 + bass_loss = self.loss(bass_pred, bass_target) - vocals_loss = self.loss(source_predict[:, 3], source[:, 3]) / 3 + other_loss = self.loss(other_pred, other_target) - loss = 0.25 * (drums_loss + bass_loss + other_loss + vocals_loss) + vocals_loss = self.loss(vocals_pred, vocals_target) + + loss = 0.25 * (drums_loss + bass_loss + other_loss + vocals_loss) # losses averaged across sources self.log_dict( { @@ -128,16 +140,16 @@ def training_step(self, batch, batch_idx): self.log_dict( { "train_drums_sdr": signal_distortion_ratio( - source_predict[:, 0], source[:, 0] + drums_pred, drums_target ).mean(), "train_bass_sdr": signal_distortion_ratio( - source_predict[:, 1], source[:, 1] + bass_pred, bass_target ).mean(), "train_other_sdr": signal_distortion_ratio( - source_predict[:, 2], source[:, 2] + other_pred, other_target ).mean(), "train_vocals_sdr": signal_distortion_ratio( - source_predict[:, 3], source[:, 3] + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -148,16 +160,16 @@ def training_step(self, batch, batch_idx): self.log_dict( { "train_drums_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 0], source[:, 0] + drums_pred, drums_target ).mean(), "train_bass_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 1], source[:, 1] + bass_pred, bass_target ).mean(), "train_other_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 2], source[:, 2] + other_pred, other_target ).mean(), "train_vocals_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 3], source[:, 3] + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -167,17 +179,17 @@ def training_step(self, batch, batch_idx): self.log_dict( { - "train_drums_usdr": compute_uSDR( - source_predict[:, 0], source[:, 0] + "train_drums_usdr": self.__usdr( + drums_pred, drums_target ).mean(), - "train_bass_usdr": compute_uSDR( - source_predict[:, 1], source[:, 1] + "train_bass_usdr": self.__usdr( + bass_pred, bass_target ).mean(), - "train_other_usdr": compute_uSDR( - source_predict[:, 2], source[:, 2] + "train_other_usdr": self.__usdr( + other_pred, other_target ).mean(), - "train_vocals_usdr": compute_uSDR( - source_predict[:, 3], source[:, 3] + "train_vocals_usdr": self.__usdr( + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -192,14 +204,18 @@ def validation_step(self, batch, batch_idx): mix = source.sum(dim=1) source_predict = self.model(mix) + drums_pred, drums_target = source_predict[:, 0], source[:, 0] + bass_pred, bass_target = source_predict[:, 1], source[:, 1] + other_pred, other_target = source_predict[:, 2], source[:, 2] + vocals_pred, vocals_target = source_predict[:, 3], source[:, 3] - drums_loss = self.loss(source_predict[:, 0], source[:, 0]) / 3 + drums_loss = self.loss(drums_pred, drums_target) - bass_loss = self.loss(source_predict[:, 1], source[:, 1]) / 3 + bass_loss = self.loss(bass_pred, bass_target) - other_loss = 
self.loss(source_predict[:, 2], source[:, 2]) / 3 + other_loss = self.loss(other_pred, other_target) - vocals_loss = self.loss(source_predict[:, 3], source[:, 3]) / 3 + vocals_loss = self.loss(vocals_pred, vocals_target) loss = 0.25 * (drums_loss + bass_loss + other_loss + vocals_loss) @@ -219,16 +235,16 @@ def validation_step(self, batch, batch_idx): self.log_dict( { "valid_drums_sdr": signal_distortion_ratio( - source_predict[:, 0], source[:, 0] + drums_pred, drums_target ).mean(), "valid_bass_sdr": signal_distortion_ratio( - source_predict[:, 1], source[:, 1] + bass_pred, bass_target ).mean(), "valid_other_sdr": signal_distortion_ratio( - source_predict[:, 2], source[:, 2] + other_pred, other_target ).mean(), "valid_vocals_sdr": signal_distortion_ratio( - source_predict[:, 3], source[:, 3] + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -239,16 +255,16 @@ def validation_step(self, batch, batch_idx): self.log_dict( { "valid_drums_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 0], source[:, 0] + drums_pred, drums_target ).mean(), "valid_bass_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 1], source[:, 1] + bass_pred, bass_target ).mean(), "valid_other_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 2], source[:, 2] + other_pred, other_target ).mean(), "valid_vocals_sisdr": scale_invariant_signal_distortion_ratio( - source_predict[:, 3], source[:, 3] + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -256,19 +272,20 @@ def validation_step(self, batch, batch_idx): sync_dist=True, ) + self.log_dict( { - "valid_drums_usdr": compute_uSDR( - source_predict[:, 0], source[:, 0] + "valid_drums_usdr": self.__usdr( + drums_pred, drums_target ).mean(), - "valid_bass_usdr": compute_uSDR( - source_predict[:, 1], source[:, 1] + "valid_bass_usdr": self.__usdr( + bass_pred, bass_target ).mean(), - "valid_other_usdr": compute_uSDR( - source_predict[:, 2], source[:, 2] + "valid_other_usdr": self.__usdr( + other_pred, other_target ).mean(), - "valid_vocals_usdr": compute_uSDR( - source_predict[:, 3], source[:, 3] + "valid_vocals_usdr": self.__usdr( + vocals_pred, vocals_target ).mean(), }, on_epoch=True, @@ -287,12 +304,6 @@ def configure_optimizers(self): "monitor": "valid_loss", } - -def init_weights(m): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): - torch.nn.init.xavier_uniform(m.weight) - - def main(config): from data.dataset import get_musdb_wav_datasets from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor @@ -336,7 +347,7 @@ def main(config): lr_monitor = LearningRateMonitor(logging_interval="step") mp_model = PM_model(config) - mp_model.model.apply(init_weights) + trainer = pl.Trainer( accelerator="gpu" if config.device == "cuda" else "cpu", devices="auto", diff --git a/separator/train/augment.py b/separator/train/augment.py index 39ea37c..a7efabe 100644 --- a/separator/train/augment.py +++ b/separator/train/augment.py @@ -2,57 +2,80 @@ import torchaudio import torch as th from torch import nn -from torch_audiomentations import PitchShift +from torch_audiomentations import PitchShift as ps class Shift(nn.Module): """ - Randomly shift audio in time by up to `shift` samples. + Shifts audio in time for data augmentation during training. Applies a random shift up to 'shift' samples. + If 'same' is True, all sources in a batch are shifted by the same amount; otherwise, each is shifted differently. + + Args: + proba (float): Probability of applying the shift. 
+ shift (int): Maximum number of samples for the shift. Defaults to 8192. + same (bool): Apply the same shift to all sources in a batch. Defaults to False. """ - def __init__(self, shift=8192, same=False): + def __init__(self, proba=1, shift=8192, same=False): super().__init__() self.shift = shift self.same = same + self.proba = proba def forward(self, wav): + if self.shift < 1: + return wav + batch, sources, channels, time = wav.size() length = time - self.shift - if self.shift > 0: - if not self.training: - wav = wav[..., :length] - else: - srcs = 1 if self.same else sources - offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) - offsets = offsets.expand(-1, sources, channels, -1) - indexes = th.arange(length, device=wav.device) - wav = wav.gather(3, indexes + offsets) + + if random.random() < self.proba: + srcs = 1 if self.same else sources + offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) + offsets = offsets.expand(-1, sources, channels, -1) + indexes = th.arange(length, device=wav.device) + wav = wav.gather(3, indexes + offsets) return wav + class FlipChannels(nn.Module): """ Flip left-right channels. + Args: + proba (float): Probability of applying the flip left-right channels. """ - + def __init__(self, proba=1): + super().__init__() + self.proba = proba + + def forward(self, wav): batch, sources, channels, time = wav.size() - if self.training and wav.size(2) == 2: - left = th.randint(2, (batch, sources, 1, 1), device=wav.device) - left = left.expand(-1, -1, -1, time) - right = 1 - left - wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) + if wav.size(2) == 2: + if random.random() < self.proba: + left = th.randint(2, (batch, sources, 1, 1), device=wav.device) + left = left.expand(-1, -1, -1, time) + right = 1 - left + wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) return wav class FlipSign(nn.Module): """ Random sign flip. + Args: + proba (float): Probability of applying the sign flip. """ + def __init__(self, proba=1): + super().__init__() + self.proba = proba + + def forward(self, wav): batch, sources, channels, time = wav.size() - if self.training: + if random.random() < self.proba: signs = th.randint( 2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32 ) @@ -62,15 +85,13 @@ def forward(self, wav): class Remix(nn.Module): """ - Shuffle sources to make new mixes. + Randomly shuffles sources within each batch during training to create new mixes. Shuffling is performed within groups. + Args: + proba (float): Probability of applying the shuffle. + group_size (int): Size of groups within which shuffling occurs. """ - def __init__(self, proba=1, group_size=4): - """ - Shuffle sources within one batch. - Each batch is divided into groups of size `group_size` and shuffling is done within - each group separatly. - """ + super().__init__() self.proba = proba self.group_size = group_size @@ -96,11 +117,15 @@ def forward(self, wav): class Scale(nn.Module): + """ + Scales the amplitude of the audio waveform during training. The scaling factor is chosen randomly within a specified range. + Args: + proba (float): Probability of applying the scaling. + min (float): Minimum scaling factor. + max (float): Maximum scaling factor. 
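+    Example (illustrative): with the default min=0.25 and max=1.25, each source in a
+    batch is multiplied by an independent factor drawn uniformly from [0.25, 1.25).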
+ """ def __init__(self, proba=1.0, min=0.25, max=1.25): - """ - Args: - time_mask_param - maximum possible length seconds of the mask - """ + super().__init__() self.proba = proba self.min = min @@ -119,15 +144,15 @@ def forward(self, wav): class FadeMask(nn.Module): """ - Apply masking to a spectrogram in the time domain. - https://pytorch.org/audio/main/generated/torchaudio.transforms.TimeMasking.html + Applies time-domain masking to the spectrogram for data augmentation. + Args: + proba (float): Probability of applying the mask. + sample_rate (int): Sample rate of the audio. + time_mask_param (int): Maximum possible length in seconds of the mask. """ def __init__(self, proba=1, sample_rate=44100, time_mask_param=2): - """ - Args: - time_mask_param - maximum possible length seconds of the mask - """ + super().__init__() self.sample_rate = sample_rate self.time_mask = torchaudio.transforms.TimeMasking( @@ -146,10 +171,17 @@ def forward(self, wav): return wav # output -> tensor -class PitchShift_f(nn.Module): # input -> tensor +class PitchShift(nn.Module): # input -> tensor """ - Pitch shift the sound up or down without changing the tempo. - https://github.com/asteroid-team/torch-audiomentations/blob/main/torch_audiomentations/augmentations/pitch_shift.py + Applies pitch shifting to audio sources. The pitch is shifted up or down without changing the tempo. + Args: + proba (float): Probability of applying the pitch shift. + min_semitones (int): Min shift for vocal source. + max_semitones (int): Max shift for vocal source. + min_semitones_other (int): Min shift for other sources. + max_semitones_other (int): Max shift for other sources. + sample_rate (int): Sample rate of audio. + flag_other (bool): Apply augmentation to other sources. """ def __init__( @@ -162,16 +194,9 @@ def __init__( sample_rate=44100, flag_other=False, ): - """ - Args: - min_semitones - vocal source - max_semitones - vocals source - min_semitones_other - drums, bass, other source - max_semitones_other - drums, bass, other source - flag_other - apply augmentation other sources - """ + super().__init__() - self.pitch_vocals = PitchShift( + self.pitch_vocals = ps( p=proba, min_transpose_semitones=min_semitones, max_transpose_semitones=max_semitones, @@ -180,7 +205,7 @@ def __init__( self.flag_other = flag_other if flag_other: - self.pitch_other = PitchShift( + self.pitch_other = ps( p=proba, min_transpose_semitones=min_semitones_other, max_transpose_semitones=max_semitones_other, @@ -198,13 +223,16 @@ def forward(self, wav): return wav -class TimeChange_f(nn.Module): +class TimeChange(nn.Module): """ - Changes the speed or duration of the signal without changing the pitch. - https://pytorch.org/audio/stable/generated/torchaudio.transforms.SpeedPerturbation.html + Changes the speed or duration of the signal without affecting the pitch. + Args: + factors_list (list): List of factors to adjust speed. + proba (float): Probability of applying the time change. + sample_rate (int): Sample rate of audio. """ - def __init__(self, factors_list, proba=1, sample_rate=44100): + super().__init__() self.sample_rate = sample_rate self.proba = proba @@ -220,12 +248,12 @@ def forward(self, wav): return wav -# new augment - class Double(nn.Module): """ - With equal probability makes both channels the same to left/right original channel. + With equal probability, makes both channels the same as either the left or right original channel. + Args: + proba (float): Probability of applying the doubling. 
""" def __init__(self, proba=1): @@ -233,7 +261,6 @@ def __init__(self, proba=1): self.proba = proba def forward(self, wav): - num_samples = wav.shape[-1] if random.random() < self.proba: wav = wav.clone() @@ -254,15 +281,15 @@ def forward(self, wav): class Reverse(nn.Module): """ - Reverse (invert) the vocal source along the time axis - """ + Reverses a segment of the vocal source along the time axis. + Args: + proba (float): Probability of applying the reversal. + min_band_part (float): Minimum fraction of the track to be inverted. + max_band_part (float): Maximum fraction of the track to be inverted. +""" def __init__(self, proba=1, min_band_part=0.2, max_band_part=0.4): - """ - Args: - min_band_part - minimum track share inversion - max_band_part - maximum track share inversion - """ + super().__init__() self.proba = proba self.min_band_part = min_band_part @@ -286,17 +313,16 @@ def forward(self, wav): return wav -class Remix_wave(nn.Module): +class RemixWave(nn.Module): """ - Mashup track in group + Creates a mashup track within a batch. + Args: + proba (float): Probability of applying the mashup. + group_size (int): Group size for mashup. + mix_depth (int): Number of tracks to mix. """ - def __init__(self, proba=1, group_size=4, mix_depth=2): - """ - Args: - group_size - group size - mix_depth - number mashup track - """ + super().__init__() self.proba = proba self.remix = Remix(proba=1, group_size=group_size) @@ -304,8 +330,6 @@ def __init__(self, proba=1, group_size=4, mix_depth=2): def forward(self, wav): if random.random() < self.proba: - batch, streams, channels, time = wav.size() - device = wav.device mix = wav.clone() for i in range(self.mix_depth): mix += self.remix(wav) @@ -314,9 +338,11 @@ def forward(self, wav): return wav -class Remix_channel(nn.Module): +class RemixChannel(nn.Module): """ - Shuffle sources channels within one batch + Shuffles source channels within a batch. + Args: + proba (float): Probability of applying the channel shuffle. """ def __init__(self, proba=1): @@ -326,7 +352,6 @@ def __init__(self, proba=1): def forward(self, wav): batch, streams, channels, time = wav.size() - device = wav.device if self.training and random.random() < self.proba: drums = wav[:, 0].reshape(-1, time) bass = wav[:, 1].reshape(-1, time) diff --git a/separator/train/loss.py b/separator/train/loss.py index 1dc0b1d..af78ea1 100644 --- a/separator/train/loss.py +++ b/separator/train/loss.py @@ -1,6 +1,5 @@ import warnings -from collections import defaultdict -from typing import Dict, Final, Iterable, List, Literal, Optional, Tuple, Union +from typing import Dict, Final, Iterable, List, Optional, Union import torch import torch as th @@ -59,9 +58,9 @@ def forward(self, input: Tensor): class SpectralLoss(nn.Module): """ - L1 between target magnitude and predicted magnitude - L1 between target phase and predicted phase - L1(target magnitude, predicted magnitude) + L1(target phase, predicted phase) + Calculates the L1 loss between the target and predicted magnitudes, and between the target and predicted phases. + The total loss is the sum of L1 loss for magnitude and L1 loss for phase: + L1(target magnitude, predicted magnitude) + L1(target phase, predicted phase). 
""" def __init__(self, n_fft=4096): @@ -87,8 +86,9 @@ def forward(self, target, predict): class MultiResSpecLoss(nn.Module): """ - Deep-FilterNet loss - https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L95 + Determines the discrepancies between the anticipated and actual spectrogram based on Short-Time Fourier Transform (STFT) + with varying windows, utilizing the Mean Square Error (MSE) loss function for calculation. + We use Deep-FilterNet loss https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L95 """ gamma: Final[float] From f46b0e0b762d2516f771a704bc80022585bac2ff Mon Sep 17 00:00:00 2001 From: maks00170 Date: Fri, 29 Dec 2023 23:44:02 +0300 Subject: [PATCH 06/15] black and refac config --- separator/config/config.py | 167 ++++++++++++++++++------------------- separator/data/dataset.py | 32 ++++--- separator/inference.py | 14 +++- separator/model/STFT.py | 2 +- separator/model/modules.py | 24 +++--- separator/pl_model.py | 50 +++++------ separator/train/augment.py | 29 +++---- separator/train/loss.py | 4 +- streaming/config/config.py | 2 +- streaming/converter.py | 2 +- streaming/runner.py | 3 +- 11 files changed, 165 insertions(+), 164 deletions(-) diff --git a/separator/config/config.py b/separator/config/config.py index 6b27417..b9df4d4 100644 --- a/separator/config/config.py +++ b/separator/config/config.py @@ -3,100 +3,99 @@ from typing import Union +from dataclasses import dataclass +from pathlib import Path +from typing import Union + + @dataclass class TrainConfig: - device: str = "cuda" - - # datasets - musdb_path: str = "musdb18hq" - metadata_train_path: str = "metadata" - metadata_test_path: str = "metadata1" - segment: int = 5 - # dataloaders - batch_size: int = 6 - shuffle_train: bool = True - shuffle_valid: bool = False - drop_last: bool = True - num_workers: int = 2 + # DATA OPTIONS + musdb_path : str = "musdb18hq" # Directory path where the MUSDB18-HQ dataset is stored. + metadata_train_path : str = "metadata" # Directory path for saving training metadata, like track names and lengths. + metadata_test_path : str = "metadata1" # Directory path for saving testing metadata. + segment : int = 5 # Length (in seconds) of each audio segment used during training. - # checkpoint_callback - metric_monitor_mode: str = "min" - save_top_k_model_weights: int = 1 + # MODEL OPTIONS + model_source : tuple = ("drums", "bass", "other", "vocals") # Sources to target in source separation. + model_depth : int = 4 # The depth of the U-Net architecture. + model_channel : int = 28 # Number of initial channels in U-Net layers. + is_mono : bool = False # Indicates whether the input audio should be treated as mono (True) or stereo (False). + mask_mode : bool = False # Whether to utilize masking within the model. + skip_mode : str = "concat" # Mode of skip connections in U-Net ('concat' for concatenation, 'add' for summation). + nfft : int = 4096 # Number of bins used in STFT. + bottlneck_lstm : bool = True # Determines whether to use LSTM layers as bottleneck in the U-Net architecture. + layers : int = 2 # Number of LSTM layers if bottleneck. + stft_flag : bool = True # A flag to decide whether to apply the STFT is required for tflite. 
- # PM_Unet model - model_source: tuple = ("drums", "bass", "other", "vocals") - model_depth: int = 4 - model_channel: int = 28 - is_mono: bool = False - mask_mode: bool = False - skip_mode: str = "concat" - nfft: int = 4096 - bottlneck_lstm: bool = True - layers: int = 2 - stft_flag: bool = True - # augments - proba_shift: float = 0.5 - shift: int = 8192 - proba_flip_channel: float = 1 - proba_flip_sign: float = 1 - pitchshift_proba: float = 0.2 - vocals_min_semitones: int = -5 - vocals_max_semitones: int = 5 - other_min_semitones: int = -2 - other_max_semitones: int = 2 - pitchshift_flag_other: bool = False - time_change_proba: float = 0.2 - time_change_factors: tuple = (0.8, 0.85, 0.9, 0.95, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3) - remix_proba: float = 1 - remix_group_size: int = batch_size - scale_proba: float = 1 - scale_min: float = 0.25 - scale_max: float = 1.25 - fade_mask_proba: float = 0.1 - double_proba: float = 0.1 - reverse_proba: float = 0.2 - mushap_proba: float = 0.0 - mushap_depth: int = 2 + # TRAIN OPTIONS + device : str = "cuda" # The computing platform for training: 'cuda' for NVIDIA GPUs or 'cpu'. + batch_size : int = 6 # Batch size for training. + shuffle_train : bool = True # Whether to shuffle the training dataset. + shuffle_valid : bool = False # Whether to shuffle the valid dataset. + drop_last : bool = True # Whether to drop the last incomplete batch in train data. + num_workers : int = 2 # Number of worker processes used for loading data. + metric_monitor_mode : str = "min" # Strategy for monitoring metrics to save model checkpoints. + save_top_k_model_weights : int = 1 # Number of best-performing model weights to save based on the monitored metric. + + factor : int = 1 # Factors for different components of the loss function. + c_factor : int = 1 - # loss if there are artifacts while listening, then increase this params - factor: int = 1 - c_factor: int = 1 - loss_nfft: tuple = (4096,) - gamma: float = 0.3 - # lr - lr: float = 0.5 * 3e-3 - T_0: int = 40 + loss_nfft : tuple = (4096,) # Number of FFT bins for calculating loss. + gamma : float = 0.3 # Gamma parameter for adjusting the focus of the loss on certain aspects of the audio spectrum. + lr : float = 0.5 * 3e-3 # Learning rate for the optimizer. + T_0 : int = 40 # Period of the cosine annealing schedule in learning rate adjustment. + max_epochs : int = 100 # Maximum number of training epochs. + precision : str = 16 # Precision of training computations. + grad_clip : float = 0.5 # Gradient clipping value. - # lightning - max_epochs: int = 100 - precision: str = 16 # "bf16-mixed" - grad_clip: float = 0.5 + # AUGMENTATION OPTIONS + proba_shift : float = 0.5 # Probability of applying the shift. + shift : int = 8192 # Maximum number of samples for the shift. + proba_flip_channel : float = 1 # Probability of applying the flip left-right channels. + proba_flip_sign : float = 1 # Probability of applying the sign flip. + pitchshift_proba : float = 0.2 # Probability of applying pitch shift. + vocals_min_semitones : int = -5 # The lower limit of vocal semitones. + vocals_max_semitones : int = 5 # The upper limit of vocal semitones. + other_min_semitones : int = -2 # The lower limit of non-vocal semitones. + other_max_semitones : int = 2 # The upper limit of non-vocal semitones. + pitchshift_flag_other : bool = False # Flag to enable pitch shift augmentation on non-vocal sources. + time_change_proba : float = 0.2 # Probability of applying time stretching. 
+ time_change_factors : tuple = (0.8, 0.85, 0.9, 0.95, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3) # Factors for time stretching/compression, defining the range and intensity of this augmentation. + remix_proba : float = 1 # Probability of remixing audio tracks. + remix_group_size : int = batch_size # Size of groups within which shuffling occurs. + scale_proba : float = 1 # Probability of applying the scaling. + scale_min : float = 0.25 # Minimum scaling factor. + scale_max : float = 1.25 # Maximum scaling factor. + fade_mask_proba : float = 0.1 # Probability of applying a fade effect. + double_proba : float = 0.1 # Probability of doubling one channel's audio to both channels. + reverse_proba : float = 0.2 # Probability of reversing a segment of the audio track. + mushap_proba : float = 0.0 # Probability create mashups. + mushap_depth : int = 2 # Number of tracks to mix. @dataclass class InferenceConfig: - GDRIVE_PREFIX = "https://drive.google.com/uc?id=" - - device: str = "cpu" - - # weights - weights_dir: Path = Path("/app/separator/inference/weights") - weights_LSTM_filename: str = "weight_LSTM.pt" - weights_conv_filename: str = "weight_conv.pt" - gdrive_weights_LSTM: str = f"{GDRIVE_PREFIX}18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" - gdrive_weights_conv: str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" + GDRIVE_PREFIX = "https://drive.google.com/uc?id=" # Google Drive URL - # inference instance - segment: int = 7 - overlap: float = 0.2 - offset: Union[int, None] = None - duration: Union[int, None] = None + # MODEL OPTIONS + weights_dir : Path = Path("/app/separator/inference/weights") # file name where weights are saved + weights_LSTM_filename : str = "weight_LSTM.pt" # file name model with LSTM + weights_conv_filename : str = "weight_conv.pt" # file name model without LSTM + gdrive_weights_LSTM : str = f"{GDRIVE_PREFIX}1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # Google Drive URL that directs weights LSTM + gdrive_weights_conv : str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # Google Drive URL that directs weights without_LSTM + device : str = "cpu" # The computing platform for inference - # inference - sample_rate: int = 44100 - num_channels: int = 2 - default_result_dir: str = "/app/separator/inference/output" - default_input_dir: str = "/app/separator/inference/input" - # adele - gdrive_mix: str = f"{GDRIVE_PREFIX}1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" + # INFERENCE OPTIONS + segment : int = 7 # Length (in seconds) of each audio segment used during inference. + overlap : float = 0.2 # overlapping segments at the beginning of the track and at the end + offset : Union[int, None] = None # start of segment to split + duration : Union[int, None] = None # end of segment to split + sample_rate : int = 44100 # sample rate track + num_channels : int = 2 # Number of channels in the audio track + default_result_dir : str = "/app/separator/inference/output" # path file output tracks + default_input_dir : str = "/app/separator/inference/input" # path file input track + + # TEST TRACK + gdrive_mix : str = f"{GDRIVE_PREFIX}1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # Google Drive URL that directs test track diff --git a/separator/data/dataset.py b/separator/data/dataset.py index 0f3567c..01825fa 100644 --- a/separator/data/dataset.py +++ b/separator/data/dataset.py @@ -57,9 +57,9 @@ def get_musdb_wav_datasets( # Create a unique identifier for the dataset configuration. 
sig = hashlib.sha1(str(musdb).encode()).hexdigest()[:8] - + metadata_file = Path(metadata) / f"musdb_{sig}.json" - root = Path(musdb) / data_type + root = Path(musdb) / data_type # Build metadata if not already present. if not metadata_file.is_file(): @@ -72,7 +72,11 @@ def get_musdb_wav_datasets( # Filter tracks for training or validation based on the configuration. valid_tracks = _get_musdb_valid() # Retrieve a list of valid track names. - metadata_train = metadata if train_valid else {name: meta for name, meta in metadata.items() if name not in valid_tracks} + metadata_train = ( + metadata + if train_valid + else {name: meta for name, meta in metadata.items() if name not in valid_tracks} + ) # Configure and return the dataset instance. data_set = Wavset( @@ -144,7 +148,9 @@ def __init__( if segment is None or track_duration < segment: examples = 1 else: - examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1) + examples = int( + math.ceil((track_duration - self.segment) / self.shift) + 1 + ) self.num_examples.append(examples) def __len__(self): @@ -171,7 +177,7 @@ def __getitem__(self, index): # Access metadata for the current source meta = self.metadata[name] - + # Calculate offset and number of frames if segmenting is enabled num_frames, offset = -1, 0 if self.segment is not None: @@ -182,7 +188,9 @@ def __getitem__(self, index): wavs = [] for source in self.sources: file_path = self.get_file(name, source) - wav, _ = ta.load(str(file_path), frame_offset=offset, num_frames=num_frames) + wav, _ = ta.load( + str(file_path), frame_offset=offset, num_frames=num_frames + ) wav = self.__convert_audio_channels(wav, self.channels) wavs.append(wav) @@ -235,7 +243,9 @@ def __convert_audio_channels(self, wav, desired_channels=2): return wav.expand(*shape, desired_channels, length) else: # Invalid case: input has fewer channels than desired and is not mono - raise ValueError("Cannot upmix from fewer than 1 channel unless the source is mono.") + raise ValueError( + "Cannot upmix from fewer than 1 channel unless the source is mono." 
+ ) class MetaData: @@ -272,7 +282,9 @@ def __track_metadata(track, sources, normalize=True, ext=File.EXT): logging.error(f"{source_file} is invalid") raise - length, sample_rate = MetaData.__validate_track(info, track_length, track_samplerate, source_file) + length, sample_rate = MetaData.__validate_track( + info, track_length, track_samplerate, source_file + ) if track_length is None: track_length, track_samplerate = length, sample_rate @@ -345,9 +357,7 @@ def __create_mixture(track, sources, ext): would_clip = audio.abs().max() >= 1 if would_clip: - assert ( - ta.get_audio_backend() == "soundfile" - ), "use dset.backend=soundfile" + assert ta.get_audio_backend() == "soundfile", "use dset.backend=soundfile" return audio, sr diff --git a/separator/inference.py b/separator/inference.py index 62b2ea3..3c49363 100644 --- a/separator/inference.py +++ b/separator/inference.py @@ -32,10 +32,14 @@ def __init__(self, config, model_bottlneck_lstm=True): def resolve_weigths(self): if self.model_bottlneck_lstm: - self.weights_path = self.config.weights_dir / self.config.weights_LSTM_filename + self.weights_path = ( + self.config.weights_dir / self.config.weights_LSTM_filename + ) gdrive_url = self.config.gdrive_weights_LSTM else: - self.weights_path = self.config.weights_dir / self.config.weights_conv_filename + self.weights_path = ( + self.config.weights_dir / self.config.weights_conv_filename + ) gdrive_url = self.config.gdrive_weights_conv try: @@ -69,7 +73,7 @@ def track(self, sample_mixture_path): # Do separation sources = self.separate_sources(mixture[None], sample_rate=sr) - # Denormalize + # Denormalize sources = sources * ref.std() + ref.mean() sources_list = ["drums", "bass", "other", "vocals"] B, S, C, T = sources.shape @@ -119,7 +123,9 @@ def separate_sources(self, mix, sample_rate): final[:, :, :, start:end] += separated_sources # Adjust the start and end for the next chunk, and update fade parameters - start, end = self.__update_chunk_indices(start, end, chunk_len, overlap_frames, length, fade) + start, end = self.__update_chunk_indices( + start, end, chunk_len, overlap_frames, length, fade + ) return final diff --git a/separator/model/STFT.py b/separator/model/STFT.py index 666bccc..9e9f572 100644 --- a/separator/model/STFT.py +++ b/separator/model/STFT.py @@ -19,7 +19,7 @@ def __pad1d( value: float = 0.0, ): """ - Tiny wrapper around F.pad, designed to allow reflect padding on small inputs. + Tiny wrapper around F.pad, designed to allow reflect padding on small inputs. If the input is too small for reflect padding, we first add extra zero padding to the right before reflection occurs. """ x0 = x diff --git a/separator/model/modules.py b/separator/model/modules.py index 781d118..063bc10 100644 --- a/separator/model/modules.py +++ b/separator/model/modules.py @@ -16,6 +16,7 @@ class DownSample(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. """ + def __init__( self, input_channel, @@ -26,7 +27,6 @@ def __init__( activation, normalization, ): - super().__init__() self.conv_layer = nn.Sequential( @@ -58,6 +58,7 @@ class UpSample(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. """ + def __init__( self, input_channel, @@ -68,7 +69,6 @@ def __init__( activation, normalization, ): - super().__init__() self.convT_layer = nn.Sequential( @@ -97,8 +97,8 @@ class InceptionBlock(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. 
""" + def __init__(self, input_channel, out_channel, activation, normalization): - super().__init__() self.conv_layer_1 = nn.Sequential( @@ -158,6 +158,7 @@ class Encoder(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. """ + def __init__( self, input_channel, @@ -168,7 +169,6 @@ def __init__( activation, normalization, ): - super().__init__() self.inception_layer = InceptionBlock( @@ -202,6 +202,7 @@ class Decoder(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. """ + def __init__( self, input_channel, @@ -212,7 +213,6 @@ def __init__( activation, normalization, ): - super().__init__() self.inception_layer = InceptionBlock( @@ -236,8 +236,8 @@ def forward(self, x): class BLSTM(nn.Module): """ - A bidirectional LSTM (BiLSTM) module with the same number of hidden units as the input dimension. - This module can process inputs in overlapping chunks if `max_steps` is specified. + A bidirectional LSTM (BiLSTM) module with the same number of hidden units as the input dimension. + This module can process inputs in overlapping chunks if `max_steps` is specified. In this case, the input will be split into chunks, and the LSTM will be applied to each chunk separately. Args: dim (int): The number of dimensions in the input and the hidden state of the LSTM. @@ -322,6 +322,7 @@ class Bottleneck_v2(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. """ + def __init__( self, input_channel, @@ -334,14 +335,17 @@ def __init__( stride=1, padding="same", ): - super().__init__() self.conv_layer = nn.Sequential( normalization(input_channel, affine=True), activation, nn.Conv1d( - input_channel, out_channel, kernel_size=3, stride=stride, padding=padding + input_channel, + out_channel, + kernel_size=3, + stride=stride, + padding=padding, ), ) @@ -372,8 +376,8 @@ class Bottleneck(nn.Module): activation (object): Activation layer. normalization (object): Normalization layer. 
""" + def __init__(self, input_channel, out_channels, normalization, activation): - super().__init__() self.conv_layer_1 = nn.Sequential( diff --git a/separator/pl_model.py b/separator/pl_model.py index e10328b..cf6c9e9 100644 --- a/separator/pl_model.py +++ b/separator/pl_model.py @@ -14,6 +14,7 @@ signal_distortion_ratio, ) + class PM_model(pl.LightningModule): def __init__(self, config): super().__init__() @@ -32,7 +33,7 @@ def __init__(self, config): ) # loss - # Loss = (L_1 + L_{MRS} - L_{SISDR}) + # Loss = (L_1 + L_{MRS} - L_{SISDR}) self.criterion_1 = nn.L1Loss() self.criterion_2 = MultiResSpecLoss( factor=config.factor, @@ -43,7 +44,9 @@ def __init__(self, config): self.criterion_3 = ScaleInvariantSignalDistortionRatio() # augment - self.augment = [augment.Shift(proba=config.proba_shift, shift=config.shift, same=True)] + self.augment = [ + augment.Shift(proba=config.proba_shift, shift=config.shift, same=True) + ] self.augment += [ augment.PitchShift( proba=config.pitchshift_proba, @@ -76,7 +79,7 @@ def __init__(self, config): def __init_weights(self, m): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): torch.nn.init.xavier_uniform(m.weight) - + def __usdr(self, predT, tgtT, delta=1e-7): """ latex: $usdr=10\log_{10} (\dfrac{\| tgtT\|^2 + \delta}{ \| predT - tgtT\| ^{2} + \delta})$ @@ -88,7 +91,6 @@ def __usdr(self, predT, tgtT, delta=1e-7): usdr = 10 * torch.log10(num / den) return usdr.mean() - def forward(self, x): x = self.model(x) return x @@ -99,7 +101,7 @@ def loss(self, y_true, y_pred): self.criterion_1(y_pred, y_true) + self.criterion_2(y_pred, y_true) - self.criterion_3(y_pred, y_true) - )/3 + ) / 3 return loss def training_step(self, batch, batch_idx): @@ -122,7 +124,9 @@ def training_step(self, batch, batch_idx): vocals_loss = self.loss(vocals_pred, vocals_target) - loss = 0.25 * (drums_loss + bass_loss + other_loss + vocals_loss) # losses averaged across sources + loss = 0.25 * ( + drums_loss + bass_loss + other_loss + vocals_loss + ) # losses averaged across sources self.log_dict( { @@ -179,18 +183,10 @@ def training_step(self, batch, batch_idx): self.log_dict( { - "train_drums_usdr": self.__usdr( - drums_pred, drums_target - ).mean(), - "train_bass_usdr": self.__usdr( - bass_pred, bass_target - ).mean(), - "train_other_usdr": self.__usdr( - other_pred, other_target - ).mean(), - "train_vocals_usdr": self.__usdr( - vocals_pred, vocals_target - ).mean(), + "train_drums_usdr": self.__usdr(drums_pred, drums_target).mean(), + "train_bass_usdr": self.__usdr(bass_pred, bass_target).mean(), + "train_other_usdr": self.__usdr(other_pred, other_target).mean(), + "train_vocals_usdr": self.__usdr(vocals_pred, vocals_target).mean(), }, on_epoch=True, prog_bar=False, @@ -272,21 +268,12 @@ def validation_step(self, batch, batch_idx): sync_dist=True, ) - self.log_dict( { - "valid_drums_usdr": self.__usdr( - drums_pred, drums_target - ).mean(), - "valid_bass_usdr": self.__usdr( - bass_pred, bass_target - ).mean(), - "valid_other_usdr": self.__usdr( - other_pred, other_target - ).mean(), - "valid_vocals_usdr": self.__usdr( - vocals_pred, vocals_target - ).mean(), + "valid_drums_usdr": self.__usdr(drums_pred, drums_target).mean(), + "valid_bass_usdr": self.__usdr(bass_pred, bass_target).mean(), + "valid_other_usdr": self.__usdr(other_pred, other_target).mean(), + "valid_vocals_usdr": self.__usdr(vocals_pred, vocals_target).mean(), }, on_epoch=True, prog_bar=False, @@ -304,6 +291,7 @@ def configure_optimizers(self): "monitor": "valid_loss", } + def main(config): from 
data.dataset import get_musdb_wav_datasets from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor diff --git a/separator/train/augment.py b/separator/train/augment.py index a7efabe..88ad06d 100644 --- a/separator/train/augment.py +++ b/separator/train/augment.py @@ -25,10 +25,10 @@ def __init__(self, proba=1, shift=8192, same=False): def forward(self, wav): if self.shift < 1: return wav - + batch, sources, channels, time = wav.size() length = time - self.shift - + if random.random() < self.proba: srcs = 1 if self.same else sources offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) @@ -38,18 +38,17 @@ def forward(self, wav): return wav - class FlipChannels(nn.Module): """ Flip left-right channels. Args: proba (float): Probability of applying the flip left-right channels. """ + def __init__(self, proba=1): super().__init__() self.proba = proba - - + def forward(self, wav): batch, sources, channels, time = wav.size() if wav.size(2) == 2: @@ -67,12 +66,12 @@ class FlipSign(nn.Module): Args: proba (float): Probability of applying the sign flip. """ + def __init__(self, proba=1): super().__init__() self.proba = proba - - + def forward(self, wav): batch, sources, channels, time = wav.size() if random.random() < self.proba: @@ -90,8 +89,8 @@ class Remix(nn.Module): proba (float): Probability of applying the shuffle. group_size (int): Size of groups within which shuffling occurs. """ + def __init__(self, proba=1, group_size=4): - super().__init__() self.proba = proba self.group_size = group_size @@ -124,8 +123,8 @@ class Scale(nn.Module): min (float): Minimum scaling factor. max (float): Maximum scaling factor. """ + def __init__(self, proba=1.0, min=0.25, max=1.25): - super().__init__() self.proba = proba self.min = min @@ -152,7 +151,6 @@ class FadeMask(nn.Module): """ def __init__(self, proba=1, sample_rate=44100, time_mask_param=2): - super().__init__() self.sample_rate = sample_rate self.time_mask = torchaudio.transforms.TimeMasking( @@ -194,7 +192,6 @@ def __init__( sample_rate=44100, flag_other=False, ): - super().__init__() self.pitch_vocals = ps( p=proba, @@ -231,8 +228,8 @@ class TimeChange(nn.Module): proba (float): Probability of applying the time change. sample_rate (int): Sample rate of audio. """ + def __init__(self, factors_list, proba=1, sample_rate=44100): - super().__init__() self.sample_rate = sample_rate self.proba = proba @@ -248,7 +245,6 @@ def forward(self, wav): return wav - class Double(nn.Module): """ With equal probability, makes both channels the same as either the left or right original channel. @@ -261,7 +257,6 @@ def __init__(self, proba=1): self.proba = proba def forward(self, wav): - if random.random() < self.proba: wav = wav.clone() @@ -285,11 +280,9 @@ class Reverse(nn.Module): Args: proba (float): Probability of applying the reversal. min_band_part (float): Minimum fraction of the track to be inverted. - max_band_part (float): Maximum fraction of the track to be inverted. -""" + max_band_part (float): Maximum fraction of the track to be inverted.""" def __init__(self, proba=1, min_band_part=0.2, max_band_part=0.4): - super().__init__() self.proba = proba self.min_band_part = min_band_part @@ -321,8 +314,8 @@ class RemixWave(nn.Module): group_size (int): Group size for mashup. mix_depth (int): Number of tracks to mix. 
""" + def __init__(self, proba=1, group_size=4, mix_depth=2): - super().__init__() self.proba = proba self.remix = Remix(proba=1, group_size=group_size) diff --git a/separator/train/loss.py b/separator/train/loss.py index af78ea1..91f8ee2 100644 --- a/separator/train/loss.py +++ b/separator/train/loss.py @@ -86,9 +86,9 @@ def forward(self, target, predict): class MultiResSpecLoss(nn.Module): """ - Determines the discrepancies between the anticipated and actual spectrogram based on Short-Time Fourier Transform (STFT) + Determines the discrepancies between the anticipated and actual spectrogram based on Short-Time Fourier Transform (STFT) with varying windows, utilizing the Mean Square Error (MSE) loss function for calculation. - We use Deep-FilterNet loss https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L95 + We use loss from Deep-FilterNet https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L95 """ gamma: Final[float] diff --git a/streaming/config/config.py b/streaming/config/config.py index 789d9e1..787cf96 100644 --- a/streaming/config/config.py +++ b/streaming/config/config.py @@ -17,7 +17,7 @@ class ConverterConfig: tflite_model_dst: str = "tflite_model" sample_rate: int = 44100 - segment_duration: float = 1. + segment_duration: float = 1.0 @dataclass diff --git a/streaming/converter.py b/streaming/converter.py index 3bcdd11..ecc287c 100644 --- a/streaming/converter.py +++ b/streaming/converter.py @@ -236,7 +236,7 @@ def istft(self, z): ) model_filename = f"{args.class_name}_outer_stft_{config.segment_duration:.1f}" - model_path = args.out_dir + '/' + model_filename + model_path = args.out_dir + "/" + model_filename keras_model.save(model_path + ".h5") custom_objects = {"WeightLayer": WeightLayer} diff --git a/streaming/runner.py b/streaming/runner.py index b0992cb..28e1d4d 100644 --- a/streaming/runner.py +++ b/streaming/runner.py @@ -36,7 +36,8 @@ def main(args, config): if start_converter: subprocess.Popen( ["python3", config.StreamConfig.converter_script], - executable="/bin/bash", shell=True + executable="/bin/bash", + shell=True, ) converter_outputs = os.listdir(config.ConverterConfig.tflite_model_dst) From d8ea2f8e17a2c2a41b5daffa108ee423f2464428 Mon Sep 17 00:00:00 2001 From: maks00170 Date: Fri, 29 Dec 2023 23:49:15 +0300 Subject: [PATCH 07/15] one del space config --- separator/config/config.py | 74 +++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/separator/config/config.py b/separator/config/config.py index b9df4d4..0b5d0c7 100644 --- a/separator/config/config.py +++ b/separator/config/config.py @@ -51,28 +51,28 @@ class TrainConfig: grad_clip : float = 0.5 # Gradient clipping value. # AUGMENTATION OPTIONS - proba_shift : float = 0.5 # Probability of applying the shift. - shift : int = 8192 # Maximum number of samples for the shift. - proba_flip_channel : float = 1 # Probability of applying the flip left-right channels. - proba_flip_sign : float = 1 # Probability of applying the sign flip. - pitchshift_proba : float = 0.2 # Probability of applying pitch shift. - vocals_min_semitones : int = -5 # The lower limit of vocal semitones. - vocals_max_semitones : int = 5 # The upper limit of vocal semitones. - other_min_semitones : int = -2 # The lower limit of non-vocal semitones. - other_max_semitones : int = 2 # The upper limit of non-vocal semitones. - pitchshift_flag_other : bool = False # Flag to enable pitch shift augmentation on non-vocal sources. 
- time_change_proba : float = 0.2 # Probability of applying time stretching. - time_change_factors : tuple = (0.8, 0.85, 0.9, 0.95, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3) # Factors for time stretching/compression, defining the range and intensity of this augmentation. - remix_proba : float = 1 # Probability of remixing audio tracks. - remix_group_size : int = batch_size # Size of groups within which shuffling occurs. - scale_proba : float = 1 # Probability of applying the scaling. - scale_min : float = 0.25 # Minimum scaling factor. - scale_max : float = 1.25 # Maximum scaling factor. - fade_mask_proba : float = 0.1 # Probability of applying a fade effect. - double_proba : float = 0.1 # Probability of doubling one channel's audio to both channels. - reverse_proba : float = 0.2 # Probability of reversing a segment of the audio track. - mushap_proba : float = 0.0 # Probability create mashups. - mushap_depth : int = 2 # Number of tracks to mix. + proba_shift : float = 0.5 # Probability of applying the shift. + shift : int = 8192 # Maximum number of samples for the shift. + proba_flip_channel : float = 1 # Probability of applying the flip left-right channels. + proba_flip_sign : float = 1 # Probability of applying the sign flip. + pitchshift_proba : float = 0.2 # Probability of applying pitch shift. + vocals_min_semitones : int = -5 # The lower limit of vocal semitones. + vocals_max_semitones : int = 5 # The upper limit of vocal semitones. + other_min_semitones : int = -2 # The lower limit of non-vocal semitones. + other_max_semitones : int = 2 # The upper limit of non-vocal semitones. + pitchshift_flag_other : bool = False # Flag to enable pitch shift augmentation on non-vocal sources. + time_change_proba : float = 0.2 # Probability of applying time stretching. + time_change_factors : tuple = (0.8, 0.85, 0.9, 0.95, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3) # Factors for time stretching/compression, defining the range and intensity of this augmentation. + remix_proba : float = 1 # Probability of remixing audio tracks. + remix_group_size : int = batch_size # Size of groups within which shuffling occurs. + scale_proba : float = 1 # Probability of applying the scaling. + scale_min : float = 0.25 # Minimum scaling factor. + scale_max : float = 1.25 # Maximum scaling factor. + fade_mask_proba : float = 0.1 # Probability of applying a fade effect. + double_proba : float = 0.1 # Probability of doubling one channel's audio to both channels. + reverse_proba : float = 0.2 # Probability of reversing a segment of the audio track. + mushap_proba : float = 0.0 # Probability create mashups. + mushap_depth : int = 2 # Number of tracks to mix. 
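
A rough sketch of how these augmentation options are wired together on a training batch of shape (batch, sources, channels, time). The module names and constructor arguments come from separator/train/augment.py in this series; the import path, the selection of modules, and their order as used in pl_model.py are assumptions here.

    from config.config import TrainConfig
    from train import augment   # import path assumed; modules defined in separator/train/augment.py

    cfg = TrainConfig()
    augmentations = [
        augment.Shift(proba=cfg.proba_shift, shift=cfg.shift, same=True),
        augment.FlipChannels(proba=cfg.proba_flip_channel),
        augment.FlipSign(proba=cfg.proba_flip_sign),
        augment.PitchShift(
            proba=cfg.pitchshift_proba,
            min_semitones=cfg.vocals_min_semitones,
            max_semitones=cfg.vocals_max_semitones,
            min_semitones_other=cfg.other_min_semitones,
            max_semitones_other=cfg.other_max_semitones,
            flag_other=cfg.pitchshift_flag_other,
        ),
        augment.TimeChange(factors_list=cfg.time_change_factors, proba=cfg.time_change_proba),
        augment.Scale(proba=cfg.scale_proba, min=cfg.scale_min, max=cfg.scale_max),
        augment.FadeMask(proba=cfg.fade_mask_proba),
        augment.Double(proba=cfg.double_proba),
        augment.Reverse(proba=cfg.reverse_proba),
        augment.RemixWave(proba=cfg.mushap_proba, group_size=cfg.remix_group_size, mix_depth=cfg.mushap_depth),
    ]

    def apply_augment(wav):
        # wav: (batch, sources, channels, time); each module draws its own random decision.
        for module in augmentations:
            wav = module(wav)
        return wav
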
@dataclass @@ -80,22 +80,22 @@ class InferenceConfig: GDRIVE_PREFIX = "https://drive.google.com/uc?id=" # Google Drive URL # MODEL OPTIONS - weights_dir : Path = Path("/app/separator/inference/weights") # file name where weights are saved - weights_LSTM_filename : str = "weight_LSTM.pt" # file name model with LSTM - weights_conv_filename : str = "weight_conv.pt" # file name model without LSTM - gdrive_weights_LSTM : str = f"{GDRIVE_PREFIX}1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # Google Drive URL that directs weights LSTM - gdrive_weights_conv : str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # Google Drive URL that directs weights without_LSTM - device : str = "cpu" # The computing platform for inference + weights_dir : Path = Path("/app/separator/inference/weights") # file name where weights are saved + weights_LSTM_filename : str = "weight_LSTM.pt" # file name model with LSTM + weights_conv_filename : str = "weight_conv.pt" # file name model without LSTM + gdrive_weights_LSTM : str = f"{GDRIVE_PREFIX}1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # Google Drive URL that directs weights LSTM + gdrive_weights_conv : str = f"{GDRIVE_PREFIX}1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # Google Drive URL that directs weights without_LSTM + device : str = "cpu" # The computing platform for inference # INFERENCE OPTIONS - segment : int = 7 # Length (in seconds) of each audio segment used during inference. - overlap : float = 0.2 # overlapping segments at the beginning of the track and at the end - offset : Union[int, None] = None # start of segment to split - duration : Union[int, None] = None # end of segment to split - sample_rate : int = 44100 # sample rate track - num_channels : int = 2 # Number of channels in the audio track - default_result_dir : str = "/app/separator/inference/output" # path file output tracks - default_input_dir : str = "/app/separator/inference/input" # path file input track + segment : int = 7 # Length (in seconds) of each audio segment used during inference. 
+ overlap : float = 0.2 # overlapping segments at the beginning of the track and at the end + offset : Union[int, None] = None # start of segment to split + duration : Union[int, None] = None # end of segment to split + sample_rate : int = 44100 # sample rate track + num_channels : int = 2 # Number of channels in the audio track + default_result_dir : str = "/app/separator/inference/output" # path file output tracks + default_input_dir : str = "/app/separator/inference/input" # path file input track # TEST TRACK - gdrive_mix : str = f"{GDRIVE_PREFIX}1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # Google Drive URL that directs test track + gdrive_mix : str = f"{GDRIVE_PREFIX}1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # Google Drive URL that directs test track From d4c60de9805fa2207e791ac5c7937156a43d90b1 Mon Sep 17 00:00:00 2001 From: maks00170 Date: Sat, 30 Dec 2023 19:43:24 +0300 Subject: [PATCH 08/15] refac stream config --- streaming/config/config.py | 46 ++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/streaming/config/config.py b/streaming/config/config.py index 787cf96..de2fabd 100644 --- a/streaming/config/config.py +++ b/streaming/config/config.py @@ -4,29 +4,31 @@ @dataclass class ConverterConfig: - weights_dir: Path = Path("/app/streaming/weights") - weights_LSTM_filename: str = "weight_LSTM.pt" - weights_conv_filename: str = "weight_conv.pt" - gdrive_weights_LSTM_id: str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" - gdrive_weights_conv_id: str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" - - original_model_src: str = "/app/separator/model" - original_model_dst: str = "/app/streaming/model" - model_py_module: str = "model.PM_Unet" - model_class_name: str = "Model_Unet" - tflite_model_dst: str = "tflite_model" - - sample_rate: int = 44100 - segment_duration: float = 1.0 + # WEIGHTS LOAD + weights_dir : Path = Path("/app/streaming/weights") # Path to the directory where the model weight files are stored. + weights_LSTM_filename : str = "weight_LSTM.pt" # This is the filename for the LSTM weights file. + weights_conv_filename : str = "weight_conv.pt" # This is the filename for the without CNN weights file. + gdrive_weights_LSTM_id : str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" # This is the Google Drive ID for the LSTM weights file. + gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file. + + # MODEL OPTIONS + original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. + original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. + model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined + model_class_name : str = "Model_Unet" # The name of the model class. + tflite_model_dst : str = "tflite_model" # This is the destination directory for the TFLite model. + sample_rate : int = 44100 # Sample rate track + segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. 
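
In practice, sample_rate and segment_duration fix the input window on which the PyTorch model is traced before the TFLite export, which is why the streamed model only accepts chunks of exactly that length. A minimal sketch of that step, mirroring converter.py later in this series:

    import torch

    sample_rate = 44100       # ConverterConfig.sample_rate
    segment_duration = 1.0    # ConverterConfig.segment_duration
    segment_wave = int(sample_rate * segment_duration)   # 44100 samples per 1.0 s window
    dummy_wave = torch.rand(size=(1, 2, segment_wave))   # (batch, stereo channels, samples)
    # converter.py traces the model on dummy_wave (with the STFT factored out), saves a Keras .h5,
    # and writes "<class_name>_outer_stft_<segment_duration:.1f>.tflite" into tflite_model_dst.
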
@dataclass class StreamConfig: - converter_script: str = "/app/streaming/converter.py" - sample_rate: int = 44100 - nfft: int = 4096 - stft_py_module: str = "model.STFT" - default_input_path: str = "/app/streaming/input" - default_result_dir: str = "/app/streaming/streams" - gdrive_mix_id: str = "1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" - default_duration: int = 15 + # STREAM OPTIONS + converter_script : str = "/app/streaming/converter.py" # Path to the script used to convert the pytorch model to tflite. + sample_rate : int = 44100 # Sample rate track. + nfft : int = 4096 # Number of bins used in STFT. + stft_py_module : str = "model.STFT" # Path to the script STFT. + default_input_path : str = "/app/streaming/input" # Path to the directory where the input files are stored. + default_result_dir : str = "/app/streaming/streams" # Path directory in which processing results are saved. + gdrive_mix_id : str = "1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # The Google Drive ID for the mix file. + default_duration : int = 15 # Length of an audio stream, in seconds. From 08fefd17e6cf04d254ccb13805d2d88f3690536a Mon Sep 17 00:00:00 2001 From: maks00170 Date: Sat, 30 Dec 2023 19:45:23 +0300 Subject: [PATCH 09/15] ref comment config --- streaming/config/config.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/streaming/config/config.py b/streaming/config/config.py index de2fabd..971b2ab 100644 --- a/streaming/config/config.py +++ b/streaming/config/config.py @@ -5,20 +5,20 @@ @dataclass class ConverterConfig: # WEIGHTS LOAD - weights_dir : Path = Path("/app/streaming/weights") # Path to the directory where the model weight files are stored. - weights_LSTM_filename : str = "weight_LSTM.pt" # This is the filename for the LSTM weights file. - weights_conv_filename : str = "weight_conv.pt" # This is the filename for the without CNN weights file. - gdrive_weights_LSTM_id : str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" # This is the Google Drive ID for the LSTM weights file. - gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file. + weights_dir : Path = Path("/app/streaming/weights") # Path to the directory where the model weight files are stored. + weights_LSTM_filename : str = "weight_LSTM.pt" # This is the filename for the LSTM weights file. + weights_conv_filename : str = "weight_conv.pt" # This is the filename for the without CNN weights file. + gdrive_weights_LSTM_id : str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" # This is the Google Drive ID for the LSTM weights file. + gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file. # MODEL OPTIONS - original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. - original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. - model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined - model_class_name : str = "Model_Unet" # The name of the model class. - tflite_model_dst : str = "tflite_model" # This is the destination directory for the TFLite model. - sample_rate : int = 44100 # Sample rate track - segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. + original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. 
+ original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. + model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined + model_class_name : str = "Model_Unet" # The name of the model class. + tflite_model_dst : str = "tflite_model" # This is the destination directory for the TFLite model. + sample_rate : int = 44100 # Sample rate track + segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. @dataclass From ae7eeea2bf5794106aae2eff1d713dd12182519d Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Tue, 9 Jan 2024 07:10:39 +0300 Subject: [PATCH 10/15] refac: docker debugged 1 --- .dockerignore | 9 +++ .gitignore | 8 +++ Dockerfile | 19 ++---- requirements.txt | 124 +----------------------------------- streaming/config/config.py | 16 ++--- streaming/converter.py | 24 +++---- streaming/runner.py | 83 ++++++++++++++++++++---- streaming/tf_lite_stream.py | 16 ++--- 8 files changed, 122 insertions(+), 177 deletions(-) create mode 100644 .dockerignore create mode 100644 .gitignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..7332e35 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +.github +**/__pycache__/ +separator/inference/ +streaming/weights/ +streaming/input/ +streaming/streams/ +streaming/model/ +streaming/tflite_model/ + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e510821 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.vscode +**/__pycache__/ +separator/inference/ +streaming/weights +streaming/input/ +streaming/streams/ +streaming/model/ +streaming/tflite_model/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 942470a..d95aad3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,19 +1,10 @@ -FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 +FROM nvcr.io/nvidia/tensorrt:22.08-py3 -ENV NV_CUDNN_VERSION 8.6.0.163 -ENV NV_CUDNN_PACKAGE_NAME "libcudnn8" +ENV PYTHONUNBUFFERED=1 -ENV NV_CUDNN_PACKAGE "$NV_CUDNN_PACKAGE_NAME=$NV_CUDNN_VERSION-1+cuda11.8" -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get -y update && apt-get -y upgrade && apt-get install -y --no-install-recommends ffmpeg -RUN apt-get update && apt-get install -y --no-install-recommends \ - ${NV_CUDNN_PACKAGE} \ - unzip \ - && apt-mark hold ${NV_CUDNN_PACKAGE_NAME} \ - && rm -rf /var/lib/apt/lists/* -RUN apt-get update -y \ - && apt-get install -y python3-pip +RUN apt-get -y update && apt-get -y upgrade +RUN apt-get install -y --no-install-recommends ffmpeg +RUN apt-get install -y python3-pip RUN echo 'alias python=python3' >> ~/.bashrc RUN echo 'NCCL_SOCKET_IFNAME=lo' >> ~/.bashrc diff --git a/requirements.txt b/requirements.txt index 9ab2926..96902a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,139 +1,19 @@ -aiohttp==3.8.4 -aiosignal==1.3.1 -antlr4-python3-runtime==4.9.3 -appdirs==1.4.4 -asttokens -async-timeout==4.0.2 -attrs==23.1.0 -audioread==3.0.0 -backcall -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -cmake==3.26.4 -comm -contourpy -cycler -Cython==0.29.35 -debugpy -decorator -diffq==0.2.4 -einops==0.6.1 -executing -fast-bss-eval==0.1.4 ffmpeg-python==0.2.0 -filelock==3.12.0 -fonttools==4.25.0 -frozenlist==1.3.3 -fsspec==2023.6.0 -future==0.18.3 gdown -idna==3.4 -ipykernel -ipython -jedi -Jinja2==3.1.2 -joblib==1.3.1 -jsonschema==4.19.0 -jsonschema-specifications==2023.7.1 julius==0.2.7 -jupyter_client -jupyter_core -kiwisolver -lameenc==1.4.2 
-lazy_loader==0.3 -librosa==0.10.0.post2 -lightning-utilities==0.8.0 -lit==16.0.5.post0 -llvmlite==0.40.1 lpips==0.1.4 -MarkupSafe==2.1.3 -matplotlib -matplotlib-inline -mir-eval==0.7 -mkl-fft==1.3.6 -mkl-random -mkl-service==2.4.0 -mpmath==1.3.0 -msgpack==1.0.5 -multidict==6.0.4 -munkres==1.1.4 musdb==0.4.0 -museval==0.4.1 -nest-asyncio -networkx==3.1 -numba==0.57.1 -numpy #==1.24.4 nobuco -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 omegaconf==2.3.0 openunmix==1.2.1 -packaging -pandas==2.1.0 -parso -pexpect -pickleshare -Pillow==9.5.0 -platformdirs -ply==3.11 -pooch==1.6.0 -primePy==1.3 -prompt-toolkit -psutil -ptyprocess -pure-eval -pyaml==23.5.9 -pycparser==2.21 -pyee==10.0.1 -Pygments -pyparsing -PyQt5-sip==12.11.0 -PySoundFile==0.9.0.post1 -python-dateutil -python-ffmpeg==2.0.4 -pytorch-lightning==2.0.3 -pytz==2023.3 -PyYAML==6.0 -pyzmq -referencing==0.30.2 -requests==2.31.0 -rpds-py==0.10.0 -scikit-learn==1.3.0 -scipy==1.10.1 -simplejson==3.19.1 -sip -six soundfile==0.12.1 sox==1.4.1 -soxr==0.3.5 -stack-data stempeg==0.2.3 sympy==1.12 -tensorflow>=2.13.0 #.* -threadpoolctl==3.1.0 -toml +tensorflow>=2.13.0 torch==2.0.1 torch-audiomentations==0.11.0 -torch-pitch-shift==1.2.4 torchaudio==2.0.2 torchmetrics==0.11.4 -torchvision==0.15.2 -tornado +pytorch-lightning==2.0.3 tqdm==4.65.0 -traitlets -triton==2.0.0 -typing_extensions>=4.6.1 -tzdata==2023.3 -urllib3==2.0.3 -wcwidth -yarl==1.9.2 diff --git a/streaming/config/config.py b/streaming/config/config.py index 971b2ab..9605f08 100644 --- a/streaming/config/config.py +++ b/streaming/config/config.py @@ -12,13 +12,13 @@ class ConverterConfig: gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file. # MODEL OPTIONS - original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. - original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. - model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined - model_class_name : str = "Model_Unet" # The name of the model class. - tflite_model_dst : str = "tflite_model" # This is the destination directory for the TFLite model. - sample_rate : int = 44100 # Sample rate track - segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. + original_model_src : str = "/app/separator/model" # This parameter represents the source directory of the original model. + original_model_dst : str = "/app/streaming/model" # This parameter represents the destination directory of the original model. + model_py_module : str = "model.PM_Unet" # This is the python module where the model is defined + model_class_name : str = "Model_Unet" # The name of the model class. + tflite_model_dst : str = "/app/streaming/tflite_model" # This is the destination directory for the TFLite model. + sample_rate : int = 44100 # Sample rate track + segment_duration : float = 1.0 # This parameter represents the duration of the audio segments that the model will process. 
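
The absolute tflite_model_dst matters because runner.py discovers the newest exported model in that directory and parses the STFT window length back out of its filename, so the stream is chunked with the same segment the model was traced on. A small sketch of that naming round trip, using the path from this config and the "<class_name>_outer_stft_<segment>.tflite" pattern produced by converter.py:

    import os
    import re

    tflite_dir = "/app/streaming/tflite_model"   # ConverterConfig.tflite_model_dst
    candidates = [
        f"{tflite_dir}/{name}"
        for name in os.listdir(tflite_dir)
        if re.match(r".*_outer_stft_.*\.tflite$", name)
    ]
    candidates.sort(key=lambda path: os.stat(path).st_mtime, reverse=True)   # newest export first
    tflite_model_path = candidates[0]   # e.g. ".../Model_Unet_outer_stft_1.0.tflite"
    segment = float(re.findall(r"_outer_stft_(.*)\.tflite$", tflite_model_path)[0])   # -> 1.0 seconds
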
@dataclass @@ -28,7 +28,7 @@ class StreamConfig: sample_rate : int = 44100 # Sample rate track. nfft : int = 4096 # Number of bins used in STFT. stft_py_module : str = "model.STFT" # Path to the script STFT. - default_input_path : str = "/app/streaming/input" # Path to the directory where the input files are stored. + default_input_dir : str = "/app/streaming/input" # Path to the directory where the input files are stored. default_result_dir : str = "/app/streaming/streams" # Path directory in which processing results are saved. gdrive_mix_id : str = "1zJpyW1fYxHKXDcDH9s5DiBCYiRpraDB3" # The Google Drive ID for the mix file. default_duration : int = 15 # Length of an audio stream, in seconds. diff --git a/streaming/converter.py b/streaming/converter.py index ecc287c..8b5de6c 100644 --- a/streaming/converter.py +++ b/streaming/converter.py @@ -174,11 +174,9 @@ def tf_concat(tensors, dim): def main(args, config): - try: - Path(config.original_model_dst).mkdir(exist_ok=False) - except FileExistsError: - shutil.rmtree(config.original_model_dst) - shutil.copytree(config.original_model_src, config.original_model_dst) + shutil.copytree( + config.original_model_src, config.original_model_dst, dirs_exist_ok=True + ) py_module = importlib.import_module(args.model_py_module) cls_model = getattr(py_module, args.class_name) model = cls_model( @@ -200,7 +198,7 @@ def main(args, config): download_weights = True except FileExistsError: try: - Path(weights_path).touch(exist_ok=False) + weights_path.touch(exist_ok=False) download_weights = True except FileExistsError: download_weights = False @@ -224,7 +222,7 @@ def stft(self, wave): def istft(self, z): return self.model.stft.istft(z, self.length_wave) - SEGMENT_WAVE = config.sample_rate * config.segment_duration + SEGMENT_WAVE = int(config.sample_rate * config.segment_duration) dummy_wave = torch.rand(size=(1, 2, SEGMENT_WAVE)) dummy_spectr = OuterSTFT(SEGMENT_WAVE, model).stft(dummy_wave) @@ -236,13 +234,17 @@ def istft(self, z): ) model_filename = f"{args.class_name}_outer_stft_{config.segment_duration:.1f}" - model_path = args.out_dir + "/" + model_filename + model_path = f"{args.out_dir}/{model_filename}" + try: + Path(args.out_dir).mkdir(exist_ok=False) + except (OSError, FileExistsError): + pass - keras_model.save(model_path + ".h5") + keras_model.save(f"{model_path}.h5") custom_objects = {"WeightLayer": WeightLayer} converter = TFLiteConverter.from_keras_model_file( - model_path + ".h5", custom_objects=custom_objects + f"{model_path}.h5", custom_objects=custom_objects ) converter.target_ops = [ tf.lite.OpsSet.SELECT_TF_OPS, @@ -250,7 +252,7 @@ def istft(self, z): ] tflite_model = converter.convert() - with open(model_path + ".tflite", "wb") as f: + with open(f"{model_path}.tflite", "wb") as f: f.write(tflite_model) diff --git a/streaming/runner.py b/streaming/runner.py index 28e1d4d..e6c91da 100644 --- a/streaming/runner.py +++ b/streaming/runner.py @@ -3,19 +3,23 @@ import logging import os import re -import subprocess +import subprocess as sb +import sys from pathlib import Path from tf_lite_stream import TFLiteTorchStream +LOGGER = logging.getLogger(__name__) + + def resolve_default_sample(config): default_input_dir = config.StreamConfig.default_input_dir Path(default_input_dir).mkdir(parents=True, exist_ok=True) default_sample_path = f"{default_input_dir}/sample.wav" try: - Path(default_sample_path).touch() + Path(default_sample_path).touch(exist_ok=False) gdown.download(id=config.StreamConfig.gdrive_mix_id, output=default_sample_path) except 
FileExistsError: pass @@ -23,41 +27,69 @@ def resolve_default_sample(config): return default_sample_path -def main(args, config): +def resolve_tflite_model(config): try: Path(config.ConverterConfig.tflite_model_dst).mkdir(exist_ok=False) start_converter = True - except FileExistsError: + except (OSError, FileExistsError): if len(os.listdir(config.ConverterConfig.tflite_model_dst)) == 0: start_converter = True else: start_converter = False if start_converter: - subprocess.Popen( + with sb.Popen( ["python3", config.StreamConfig.converter_script], - executable="/bin/bash", - shell=True, + stdout=sb.PIPE, + stderr=sb.STDOUT, + ) as proc: + LOGGER.info(proc.stdout.read().decode()) + res = proc.wait() + LOGGER.info( + f"{config.StreamConfig.converter_script} finished with code : {res}" ) converter_outputs = os.listdir(config.ConverterConfig.tflite_model_dst) converter_outputs = list( filter(lambda x: re.match(r".*_outer_stft_.*\.tflite$", x), converter_outputs) ) + converter_outputs = [ + f"{config.ConverterConfig.tflite_model_dst}/{filename}" + for filename in converter_outputs + ] converter_outputs.sort(key=lambda x: os.stat(x).st_mtime, reverse=True) + tflite_model_path = converter_outputs[0] + parsed_segment = re.findall(r"_outer_stft_(.*)\.tflite$", tflite_model_path)[0] + + return tflite_model_path, parsed_segment - tflite_model_path = str( - config.ConverterConfig.tflite_model_dst + f"\{converter_outputs[0]}" + +def main(args, config): + is_tflite_model_path_default = ( + args.tflite_model_path == config.ConverterConfig.tflite_model_dst ) - parsed_segment = re.findall(r"_outer_stft_(.*)\.tflite$", tflite_model_path) + if not is_tflite_model_path_default and not args.tflite_model_segment: + raise ValueError( + "Specify segment [-s (0.5, 1, ...)] of STFT to outer tflite model" + ) + + if is_tflite_model_path_default: + tflite_model_path, parsed_segment = resolve_tflite_model(config) + else: + tflite_model_path, parsed_segment = ( + args.tflite_model_path, + args.tflite_model_segment, + ) track_path = args.mix_path if args.mix_path == config.StreamConfig.default_input_dir: track_path = resolve_default_sample(config) - stream_class = TFLiteTorchStream(tflite_model_path, segment=float(parsed_segment)) + stream_class = TFLiteTorchStream( + config, tflite_model_path, segment=float(parsed_segment) + ) out_paths = stream_class(track_path, args.out_dir, args.duration) - logging.info("Streams stored in : " + " ".join(out_paths)) + LOGGER.info("Streams stored in : " + " ".join(out_paths)) if __name__ == "__main__": @@ -85,7 +117,32 @@ def main(args, config): default=config.StreamConfig.default_duration, type=int, ) + parser.add_argument( + "-m", + dest="tflite_model_path", + help="path to tflite model", + default=config.ConverterConfig.tflite_model_dst, + type=str, + ) + parser.add_argument( + "-s", + dest="tflite_model_segment", + help="tflite model STFT window width (sample_rate * segment)", + required=False, + type=float, + ) + + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter( + logging.Formatter( + "[%(levelname)s](%(filename)s).%(funcName)s(%(lineno)d) - %(message)s" + ) + ) + logging.basicConfig( + level=logging.DEBUG, + handlers=[stdout_handler], + format="%(levelname)s : %(message)s", + ) - # TODO: specify .tflite model manually args = parser.parse_args() main(args, config) diff --git a/streaming/tf_lite_stream.py b/streaming/tf_lite_stream.py index 4b2962a..73a826b 100644 --- a/streaming/tf_lite_stream.py +++ b/streaming/tf_lite_stream.py @@ -25,10 +25,13 @@ 
def __init__(self, config, model_filename: str, segment: float = 1): self.segment = segment try: - Path(config.original_model_dst).mkdir(exist_ok=False) + Path(config.ConverterConfig.original_model_dst).mkdir(exist_ok=False) except FileExistsError: - shutil.rmtree(config.original_model_dst) - shutil.copytree(config.original_model_src, config.original_model_dst) + shutil.rmtree(config.ConverterConfig.original_model_dst) + shutil.copytree( + config.ConverterConfig.original_model_src, + config.ConverterConfig.original_model_dst, + ) py_module = importlib.import_module(config.StreamConfig.stft_py_module) cls_stft = getattr(py_module, "STFT") self.stft = cls_stft(self.nfft) @@ -69,12 +72,7 @@ def __call__( stream_vocals.add_audio_stream(sample_rate, TFLiteTorchStream.NUM_CHANNELS) chunk_count = int(sample_rate * duration // frames_per_chunk) if duration else 0 - with ( - stream_drums.open(), - stream_bass.open(), - stream_other.open(), - stream_vocals.open(), - ): + with stream_drums.open(), stream_bass.open(), stream_other.open(), stream_vocals.open(): for i, chunk in tqdm(enumerate(stream_mix.stream())): if duration and i > chunk_count: break From fd090ebb1acb28e00042cdf8d58565c823899041 Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Tue, 9 Jan 2024 11:54:47 +0300 Subject: [PATCH 11/15] refac: configuring CI 5 --- .github/workflows/main.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 2a8363a..c874dc1 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -15,4 +15,4 @@ jobs: run: | pip install -r requirements-dev.txt - name: "black" - run: black . --check --diff --color + run: black . --check --diff --color --exclude .*/config/ From 20d2466d7ad907938ffd6a03d253c645798effa9 Mon Sep 17 00:00:00 2001 From: d-a-yakovlev Date: Mon, 15 Jan 2024 21:30:28 +0300 Subject: [PATCH 12/15] refac: inference output 1 --- requirements.txt | 2 +- separator/inference.py | 28 +++++++++++++--------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index 96902a1..e2afdb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ ffmpeg-python==0.2.0 -gdown +gdown==4.6.3 julius==0.2.7 lpips==0.1.4 musdb==0.4.0 diff --git a/separator/inference.py b/separator/inference.py index 3c49363..a1c7646 100644 --- a/separator/inference.py +++ b/separator/inference.py @@ -54,9 +54,10 @@ def resolve_weigths(self): if download_weights: gdown.download(gdrive_url, str(self.weights_path)) - def track(self, sample_mixture_path): + def track(self, sample_mixture_path, output_dir): if sample_mixture_path == self.config.default_input_dir: sample_mixture_path = self.resolve_default_sample() + output_dir = f"{output_dir}/{sample_mixture_path.split('/')[-1]}" offset = self.config.offset duration = self.config.duration @@ -75,7 +76,10 @@ def track(self, sample_mixture_path): # Denormalize sources = sources * ref.std() + ref.mean() + sources_list = ["drums", "bass", "other", "vocals"] + sources_ouputs = {s: f"{output_dir}/{s}.wav" for s in sources_list} + B, S, C, T = sources.shape sources = ( sources.view(B, S * C, T) @@ -84,7 +88,9 @@ def track(self, sample_mixture_path): sources = list(sources) audios = dict(zip(sources_list, sources[0])) - audios["original"] = waveform[:, start:end] + for k, v in audios.items(): + audios[k] = {"source": v, "path": sources_ouputs[k]} + return audios def separate_sources(self, mix, sample_rate): @@ -172,20 +178,12 @@ def 
 def main(args, config):
     inf_model = InferenceModel(config)
-    audios = inf_model.track(args.mix_path)
-
-    out_dir = f"{args.out_dir}/{os.path.basename(args.mix_path)}/"
-    out_paths = (
-        f"{out_dir}drums.wav",
-        f"{out_dir}bass.wav",
-        f"{out_dir}other.wav",
-        f"{out_dir}vocals.wav",
-    )
+    audios = inf_model.track(args.mix_path, args.out_dir)
 
-    torchaudio.save(out_paths[0], audios["drums"], config.sample_rate)
-    torchaudio.save(out_paths[1], audios["bass"], config.sample_rate)
-    torchaudio.save(out_paths[2], audios["other"], config.sample_rate)
-    torchaudio.save(out_paths[3], audios["vocals"], config.sample_rate)
+    torchaudio.save(audios["drums"]["path"], audios["drums"]["source"], config.sample_rate)
+    torchaudio.save(audios["bass"]["path"], audios["bass"]["source"], config.sample_rate)
+    torchaudio.save(audios["other"]["path"], audios["other"]["source"], config.sample_rate)
+    torchaudio.save(audios["vocals"]["path"], audios["vocals"]["source"], config.sample_rate)
 
 
 if __name__ == "__main__":

From 0d7eb3bed7597b91077a6162898c8ddfffbfda06 Mon Sep 17 00:00:00 2001
From: d-a-yakovlev
Date: Mon, 15 Jan 2024 21:35:20 +0300
Subject: [PATCH 13/15] refac: inference output 2

---
 separator/inference.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/separator/inference.py b/separator/inference.py
index a1c7646..a281711 100644
--- a/separator/inference.py
+++ b/separator/inference.py
@@ -42,6 +42,7 @@ def resolve_weigths(self):
         )
         gdrive_url = self.config.gdrive_weights_conv
 
+        download_weights = True
         try:
             self.config.weights_dir.mkdir(parents=True)
             download_weights = True

From 50c93aab46b9bdf489f6cee24f5874f7505d2470 Mon Sep 17 00:00:00 2001
From: d-a-yakovlev
Date: Tue, 16 Jan 2024 22:57:28 +0300
Subject: [PATCH 14/15] refac: inference output 3

---
 separator/inference.py | 50 +++++++++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 15 deletions(-)

diff --git a/separator/inference.py b/separator/inference.py
index a281711..25255c6 100644
--- a/separator/inference.py
+++ b/separator/inference.py
@@ -11,10 +11,15 @@
 
 
 class InferenceModel:
-    def __init__(self, config, model_bottlneck_lstm=True):
+    def __init__(self, config, model_bottlneck_lstm=True, weights_path=""):
         self.config = config
         self.model_bottlneck_lstm = model_bottlneck_lstm
-        self.resolve_weigths()
+
+        weights_path = "" if weights_path is None else weights_path
+        if Path(weights_path).is_file():
+            self.weights_path = weights_path
+        else:
+            self.resolve_weigths()
 
         self.model = Model_Unet(
             source=["drums", "bass", "other", "vocals"],
@@ -42,13 +47,13 @@ def resolve_weigths(self):
         )
         gdrive_url = self.config.gdrive_weights_conv
 
-        download_weights = True
         try:
-            self.config.weights_dir.mkdir(parents=True)
+            self.config.weights_dir.mkdir(exist_ok=False, parents=True)
             download_weights = True
         except FileExistsError:
             try:
-                Path(self.weights_path).touch()
+                Path(self.weights_path).touch(exist_ok=False)
+                download_weights = True
             except FileExistsError:
                 download_weights = False
 
@@ -58,7 +63,8 @@ def resolve_weigths(self):
     def track(self, sample_mixture_path, output_dir):
         if sample_mixture_path == self.config.default_input_dir:
             sample_mixture_path = self.resolve_default_sample()
-        output_dir = f"{output_dir}/{sample_mixture_path.split('/')[-1]}"
+        output_path = Path(output_dir) / Path(sample_mixture_path).stem
+        output_path.mkdir(exist_ok=True, parents=True)
 
         offset = self.config.offset
         duration = self.config.duration
@@ -77,9 +83,9 @@
 
         # Denormalize
         sources = sources * ref.std() + ref.mean()
-
+
         sources_list = ["drums", "bass", "other", "vocals"]
-        sources_ouputs = {s: f"{output_dir}/{s}.wav" for s in sources_list}
+        sources_ouputs = {s: f"{str(output_path)}/{s}.wav" for s in sources_list}
 
         B, S, C, T = sources.shape
         sources = (
@@ -169,7 +175,7 @@ def resolve_default_sample(self):
         default_sample_path = f"{default_input_dir}/sample.wav"
 
         try:
-            Path(default_sample_path).touch()
+            Path(default_sample_path).touch(exist_ok=False)
             gdown.download(self.config.gdrive_mix, default_sample_path)
         except FileExistsError:
             pass
@@ -178,13 +184,21 @@
 
 
 def main(args, config):
-    inf_model = InferenceModel(config)
+    inf_model = InferenceModel(config, weights_path=args.weights_path)
     audios = inf_model.track(args.mix_path, args.out_dir)
 
-    torchaudio.save(audios["drums"]["path"], audios["drums"]["source"], config.sample_rate)
-    torchaudio.save(audios["bass"]["path"], audios["bass"]["source"], config.sample_rate)
-    torchaudio.save(audios["other"]["path"], audios["other"]["source"], config.sample_rate)
-    torchaudio.save(audios["vocals"]["path"], audios["vocals"]["source"], config.sample_rate)
+    torchaudio.save(
+        audios["drums"]["path"], audios["drums"]["source"], config.sample_rate
+    )
+    torchaudio.save(
+        audios["bass"]["path"], audios["bass"]["source"], config.sample_rate
+    )
+    torchaudio.save(
+        audios["other"]["path"], audios["other"]["source"], config.sample_rate
+    )
+    torchaudio.save(
+        audios["vocals"]["path"], audios["vocals"]["source"], config.sample_rate
+    )
 
 
 if __name__ == "__main__":
@@ -207,7 +221,13 @@ def main(args, config):
         default=config.default_result_dir,
         type=str,
     )
-    # TODO : argument for weigths
+    parser.add_argument(
+        "-w",
+        dest="weights_path",
+        help="specified path to weights",
+        required=False,
+        type=str,
+    )
 
     args = parser.parse_args()
     main(args, config)

From 26b6b2dc10c7ed26ffa487a6b1c48844aec3a682 Mon Sep 17 00:00:00 2001
From: maks00170
Date: Tue, 16 Jan 2024 23:40:30 +0300
Subject: [PATCH 15/15] refac: fix comment

---
 separator/config/config.py | 4 ++--
 streaming/config/config.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/separator/config/config.py b/separator/config/config.py
index 0b5d0c7..d9e25a1 100644
--- a/separator/config/config.py
+++ b/separator/config/config.py
@@ -90,8 +90,8 @@ class InferenceConfig:
     # INFERENCE OPTIONS
     segment : int = 7 # Length (in seconds) of each audio segment used during inference.
     overlap : float = 0.2 # overlapping segments at the beginning of the track and at the end
-    offset : Union[int, None] = None # start of segment to split
-    duration : Union[int, None] = None # end of segment to split
+    offset : Union[int, None] = None # start (in seconds) of segment to split
+    duration : Union[int, None] = None # duration (in seconds) of segment to split, use with `offset`
     sample_rate : int = 44100 # sample rate track
     num_channels : int = 2 # Number of channels in the audio track
     default_result_dir : str = "/app/separator/inference/output" # path file output tracks
diff --git a/streaming/config/config.py b/streaming/config/config.py
index 9605f08..6deead4 100644
--- a/streaming/config/config.py
+++ b/streaming/config/config.py
@@ -8,7 +8,7 @@ class ConverterConfig:
     weights_dir : Path = Path("/app/streaming/weights") # Path to the directory where the model weight files are stored.
     weights_LSTM_filename : str = "weight_LSTM.pt" # This is the filename for the LSTM weights file.
     weights_conv_filename : str = "weight_conv.pt" # This is the filename for the without CNN weights file.
-    gdrive_weights_LSTM_id : str = "18jT2TYffdRD1fL7wecAiM5nJPM_OKpNB" # This is the Google Drive ID for the LSTM weights file.
+    gdrive_weights_LSTM_id : str = "1uhAVMvW3x-KL2T2-VkjKjn9K7dTJnoyo" # This is the Google Drive ID for the LSTM weights file.
     gdrive_weights_conv_id : str = "1VO07OYbsnCuEJYRSuA8HhjlQnx6dbWX7" # This is the Google Drive ID for the CNN weights file.
 
     # MODEL OPTIONS