From 981411916f94d7552146309149d2a30d2899fb07 Mon Sep 17 00:00:00 2001 From: Zhao Jingwei Date: Mon, 23 Oct 2023 20:34:23 +0800 Subject: [PATCH] Initial commit --- .gitattributes | 2 + LICENSE | 21 + README.md | 2 + arrangement_utils.py | 192 ++++++ data_file_dir.txt | 5 + demo/Castles in the Air/arrangement_band.mid | Bin 0 -> 9570 bytes demo/Castles in the Air/arrangement_piano.mid | Bin 0 -> 5695 bytes demo/Castles in the Air/lead sheet.mid | Bin 0 -> 6537 bytes demo/Jingle Bells/arrangement_band.mid | Bin 0 -> 5322 bytes demo/Jingle Bells/arrangement_piano.mid | Bin 0 -> 3757 bytes demo/Jingle Bells/lead sheet.mid | Bin 0 -> 2279 bytes demo/Sally Garden/arrangement_band.mid | Bin 0 -> 5174 bytes demo/Sally Garden/arrangement_piano.mid | Bin 0 -> 2577 bytes demo/Sally Garden/lead sheet.mid | Bin 0 -> 1300 bytes inference.ipynb | 170 +++++ orchestrator/Prior.py | 587 +++++++++++++++++ orchestrator/QandA.py | 438 +++++++++++++ orchestrator/TransformerEncoderLayer.py | 114 ++++ orchestrator/__init__.py | 2 + orchestrator/dataset.py | 286 +++++++++ orchestrator/dl_modules/__init__.py | 5 + orchestrator/dl_modules/feat_decoder.py | 160 +++++ orchestrator/dl_modules/pianotree_dec.py | 364 +++++++++++ orchestrator/dl_modules/pianotree_enc.py | 97 +++ orchestrator/dl_modules/pr_mat_txt_enc.py | 48 ++ orchestrator/dl_modules/vqvae.py | 201 ++++++ orchestrator/scheduler.py | 120 ++++ .../scripts/data_preprocessing/converter.py | 167 +++++ .../lmd_midi_quantization.py | 153 +++++ .../pop909_process_4bin_data.py | 174 +++++ .../data_preprocessing/quantization_utils.py | 223 +++++++ .../data_preprocessing/slakh_quantization.py | 223 +++++++ .../objective_evaluation_arrangement.ipynb | 322 ++++++++++ .../objective_evaluation_orchestration.ipynb | 180 ++++++ orchestrator/train_Prior_DDP.py | 237 +++++++ orchestrator/train_QandA.py | 262 ++++++++ orchestrator/utils.py | 383 +++++++++++ orchestrator/vq_dataset.py | 149 +++++ piano_arranger/AccoMontage.py | 389 ++++++++++++ piano_arranger/__init__.py | 2 + piano_arranger/acc_utils.py | 186 ++++++ piano_arranger/chord_recognition/.gitignore | 6 + piano_arranger/chord_recognition/README.TXT | 2 + piano_arranger/chord_recognition/__init__.py | 5 + .../chord_recognition/air_structure.py | 375 +++++++++++ .../chord_recognition/chord_class.py | 129 ++++ .../chord_recognition/complex_chord.py | 320 ++++++++++ .../extractors/midi_utilities.py | 180 ++++++ .../extractors/rule_based_channel_reweight.py | 37 ++ .../chord_recognition/io_new/air_io.py | 18 + .../chord_recognition/io_new/beat_align_io.py | 77 +++ .../chord_recognition/io_new/beatlab_io.py | 42 ++ .../chord_recognition/io_new/chordlab_io.py | 44 ++ .../io_new/complex_chord_io.py | 56 ++ .../chord_recognition/io_new/downbeat_io.py | 45 ++ .../chord_recognition/io_new/jams_io.py | 17 + .../chord_recognition/io_new/jointbeat_io.py | 44 ++ .../chord_recognition/io_new/key_io.py | 42 ++ .../chord_recognition/io_new/list_io.py | 12 + .../chord_recognition/io_new/lyric_io.py | 44 ++ .../chord_recognition/io_new/madmom_io.py | 34 + .../chord_recognition/io_new/midilab_io.py | 50 ++ .../chord_recognition/io_new/osu_io.py | 54 ++ .../chord_recognition/io_new/salami_io.py | 43 ++ .../chord_recognition/io_new/tag_io.py | 44 ++ piano_arranger/chord_recognition/main.py | 76 +++ .../chord_recognition/midi_chord.py | 151 +++++ .../inspectionProfiles/Project_Default.xml | 14 + .../chord_recognition/mir/.idea/mir.iml | 12 + .../chord_recognition/mir/.idea/misc.xml | 4 + .../chord_recognition/mir/.idea/modules.xml | 8 + 
.../chord_recognition/mir/.idea/vcs.xml | 6 + .../chord_recognition/mir/.idea/workspace.xml | 129 ++++ .../chord_recognition/mir/README.MD | 296 +++++++++ .../chord_recognition/mir/__init__.py | 5 + piano_arranger/chord_recognition/mir/cache.py | 52 ++ .../chord_recognition/mir/common.py | 7 + .../chord_recognition/mir/data/bothchroma.n3 | 34 + .../chord_recognition/mir/data/chordino.n3 | 46 ++ .../chord_recognition/mir/data/chroma.n3 | 42 ++ .../mir/data/curve_template.svl | 13 + .../mir/data/midi_template.svl | 13 + .../mir/data/pitch_template.svl | 18 + .../mir/data/sparse_tag_template.svl | 13 + .../mir/data/spectrogram_template.svl | 13 + .../mir/data/tunedlogfreqspec.n3 | 42 ++ .../chord_recognition/mir/data/tuning.n3 | 14 + .../chord_recognition/mir/data_file.py | 512 +++++++++++++++ .../mir/extractors/__init__.py | 3 + .../mir/extractors/extractor_base.py | 107 ++++ .../mir/extractors/librosa_extractor.py | 59 ++ .../chord_recognition/mir/extractors/misc.py | 56 ++ .../mir/extractors/vamp_extractor.py | 136 ++++ .../chord_recognition/mir/io/__init__.py | 12 + .../mir/io/feature_io_base.py | 96 +++ .../mir/io/implement/__init__.py | 0 .../mir/io/implement/chroma_io.py | 48 ++ .../mir/io/implement/midi_io.py | 16 + .../mir/io/implement/music_io.py | 18 + .../io/implement/regional_spectrogram_io.py | 80 +++ .../mir/io/implement/scalar_io.py | 32 + .../mir/io/implement/spectrogram_io.py | 58 ++ .../mir/io/implement/unknown_io.py | 12 + .../chord_recognition/mir/music_base.py | 21 + .../chord_recognition/mir/requirements.txt | 8 + .../chord_recognition/mir/settings.py | 3 + .../chord_recognition/requirements.txt | 6 + piano_arranger/format_converter.py | 246 +++++++ piano_arranger/models/EC2VAE.py | 138 ++++ piano_arranger/models/Poly_Dis.py | 270 ++++++++ piano_arranger/models/__init__.py | 5 + piano_arranger/models/amc_dl/__init__.py | 0 piano_arranger/models/amc_dl/demo_maker.py | 38 ++ .../models/amc_dl/torch_plus/__init__.py | 7 + .../models/amc_dl/torch_plus/example.py | 13 + .../models/amc_dl/torch_plus/manager.py | 137 ++++ .../models/amc_dl/torch_plus/module.py | 220 +++++++ .../models/amc_dl/torch_plus/scheduler.py | 104 +++ .../models/amc_dl/torch_plus/train_utils.py | 49 ++ piano_arranger/models/ptvae.py | 598 ++++++++++++++++++ piano_arranger/models/transition_model.py | 31 + piano_arranger/scripts/build_phrase_data.py | 191 ++++++ .../scripts/edge_weights_inference.py | 141 +++++ .../scripts/transition_model_data_loader.py | 166 +++++ .../transition_model_train_contrastive.py | 165 +++++ 125 files changed, 12394 insertions(+) create mode 100644 .gitattributes create mode 100644 LICENSE create mode 100644 README.md create mode 100644 arrangement_utils.py create mode 100644 data_file_dir.txt create mode 100644 demo/Castles in the Air/arrangement_band.mid create mode 100644 demo/Castles in the Air/arrangement_piano.mid create mode 100644 demo/Castles in the Air/lead sheet.mid create mode 100644 demo/Jingle Bells/arrangement_band.mid create mode 100644 demo/Jingle Bells/arrangement_piano.mid create mode 100644 demo/Jingle Bells/lead sheet.mid create mode 100644 demo/Sally Garden/arrangement_band.mid create mode 100644 demo/Sally Garden/arrangement_piano.mid create mode 100644 demo/Sally Garden/lead sheet.mid create mode 100644 inference.ipynb create mode 100644 orchestrator/Prior.py create mode 100644 orchestrator/QandA.py create mode 100644 orchestrator/TransformerEncoderLayer.py create mode 100644 orchestrator/__init__.py create mode 100644 orchestrator/dataset.py create 
mode 100644 orchestrator/dl_modules/__init__.py create mode 100644 orchestrator/dl_modules/feat_decoder.py create mode 100644 orchestrator/dl_modules/pianotree_dec.py create mode 100644 orchestrator/dl_modules/pianotree_enc.py create mode 100644 orchestrator/dl_modules/pr_mat_txt_enc.py create mode 100644 orchestrator/dl_modules/vqvae.py create mode 100644 orchestrator/scheduler.py create mode 100644 orchestrator/scripts/data_preprocessing/converter.py create mode 100644 orchestrator/scripts/data_preprocessing/lmd_midi_quantization.py create mode 100644 orchestrator/scripts/data_preprocessing/pop909_process_4bin_data.py create mode 100644 orchestrator/scripts/data_preprocessing/quantization_utils.py create mode 100644 orchestrator/scripts/data_preprocessing/slakh_quantization.py create mode 100644 orchestrator/scripts/objective_evaluation_arrangement.ipynb create mode 100644 orchestrator/scripts/objective_evaluation_orchestration.ipynb create mode 100644 orchestrator/train_Prior_DDP.py create mode 100644 orchestrator/train_QandA.py create mode 100644 orchestrator/utils.py create mode 100644 orchestrator/vq_dataset.py create mode 100644 piano_arranger/AccoMontage.py create mode 100644 piano_arranger/__init__.py create mode 100644 piano_arranger/acc_utils.py create mode 100644 piano_arranger/chord_recognition/.gitignore create mode 100644 piano_arranger/chord_recognition/README.TXT create mode 100644 piano_arranger/chord_recognition/__init__.py create mode 100644 piano_arranger/chord_recognition/air_structure.py create mode 100644 piano_arranger/chord_recognition/chord_class.py create mode 100644 piano_arranger/chord_recognition/complex_chord.py create mode 100644 piano_arranger/chord_recognition/extractors/midi_utilities.py create mode 100644 piano_arranger/chord_recognition/extractors/rule_based_channel_reweight.py create mode 100644 piano_arranger/chord_recognition/io_new/air_io.py create mode 100644 piano_arranger/chord_recognition/io_new/beat_align_io.py create mode 100644 piano_arranger/chord_recognition/io_new/beatlab_io.py create mode 100644 piano_arranger/chord_recognition/io_new/chordlab_io.py create mode 100644 piano_arranger/chord_recognition/io_new/complex_chord_io.py create mode 100644 piano_arranger/chord_recognition/io_new/downbeat_io.py create mode 100644 piano_arranger/chord_recognition/io_new/jams_io.py create mode 100644 piano_arranger/chord_recognition/io_new/jointbeat_io.py create mode 100644 piano_arranger/chord_recognition/io_new/key_io.py create mode 100644 piano_arranger/chord_recognition/io_new/list_io.py create mode 100644 piano_arranger/chord_recognition/io_new/lyric_io.py create mode 100644 piano_arranger/chord_recognition/io_new/madmom_io.py create mode 100644 piano_arranger/chord_recognition/io_new/midilab_io.py create mode 100644 piano_arranger/chord_recognition/io_new/osu_io.py create mode 100644 piano_arranger/chord_recognition/io_new/salami_io.py create mode 100644 piano_arranger/chord_recognition/io_new/tag_io.py create mode 100644 piano_arranger/chord_recognition/main.py create mode 100644 piano_arranger/chord_recognition/midi_chord.py create mode 100644 piano_arranger/chord_recognition/mir/.idea/inspectionProfiles/Project_Default.xml create mode 100644 piano_arranger/chord_recognition/mir/.idea/mir.iml create mode 100644 piano_arranger/chord_recognition/mir/.idea/misc.xml create mode 100644 piano_arranger/chord_recognition/mir/.idea/modules.xml create mode 100644 piano_arranger/chord_recognition/mir/.idea/vcs.xml create mode 100644 
piano_arranger/chord_recognition/mir/.idea/workspace.xml create mode 100644 piano_arranger/chord_recognition/mir/README.MD create mode 100644 piano_arranger/chord_recognition/mir/__init__.py create mode 100644 piano_arranger/chord_recognition/mir/cache.py create mode 100644 piano_arranger/chord_recognition/mir/common.py create mode 100644 piano_arranger/chord_recognition/mir/data/bothchroma.n3 create mode 100644 piano_arranger/chord_recognition/mir/data/chordino.n3 create mode 100644 piano_arranger/chord_recognition/mir/data/chroma.n3 create mode 100644 piano_arranger/chord_recognition/mir/data/curve_template.svl create mode 100644 piano_arranger/chord_recognition/mir/data/midi_template.svl create mode 100644 piano_arranger/chord_recognition/mir/data/pitch_template.svl create mode 100644 piano_arranger/chord_recognition/mir/data/sparse_tag_template.svl create mode 100644 piano_arranger/chord_recognition/mir/data/spectrogram_template.svl create mode 100644 piano_arranger/chord_recognition/mir/data/tunedlogfreqspec.n3 create mode 100644 piano_arranger/chord_recognition/mir/data/tuning.n3 create mode 100644 piano_arranger/chord_recognition/mir/data_file.py create mode 100644 piano_arranger/chord_recognition/mir/extractors/__init__.py create mode 100644 piano_arranger/chord_recognition/mir/extractors/extractor_base.py create mode 100644 piano_arranger/chord_recognition/mir/extractors/librosa_extractor.py create mode 100644 piano_arranger/chord_recognition/mir/extractors/misc.py create mode 100644 piano_arranger/chord_recognition/mir/extractors/vamp_extractor.py create mode 100644 piano_arranger/chord_recognition/mir/io/__init__.py create mode 100644 piano_arranger/chord_recognition/mir/io/feature_io_base.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/__init__.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/chroma_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/midi_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/music_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/regional_spectrogram_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/scalar_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/spectrogram_io.py create mode 100644 piano_arranger/chord_recognition/mir/io/implement/unknown_io.py create mode 100644 piano_arranger/chord_recognition/mir/music_base.py create mode 100644 piano_arranger/chord_recognition/mir/requirements.txt create mode 100644 piano_arranger/chord_recognition/mir/settings.py create mode 100644 piano_arranger/chord_recognition/requirements.txt create mode 100644 piano_arranger/format_converter.py create mode 100644 piano_arranger/models/EC2VAE.py create mode 100644 piano_arranger/models/Poly_Dis.py create mode 100644 piano_arranger/models/__init__.py create mode 100644 piano_arranger/models/amc_dl/__init__.py create mode 100644 piano_arranger/models/amc_dl/demo_maker.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/__init__.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/example.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/manager.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/module.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/scheduler.py create mode 100644 piano_arranger/models/amc_dl/torch_plus/train_utils.py create mode 100644 piano_arranger/models/ptvae.py create mode 100644 piano_arranger/models/transition_model.py 
create mode 100644 piano_arranger/scripts/build_phrase_data.py create mode 100644 piano_arranger/scripts/edge_weights_inference.py create mode 100644 piano_arranger/scripts/transition_model_data_loader.py create mode 100644 piano_arranger/scripts/transition_model_train_contrastive.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3feef5e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Zhao Jingwei + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..f93a554 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# AccoMontage-3 + diff --git a/arrangement_utils.py b/arrangement_utils.py new file mode 100644 index 0000000..1747b3f --- /dev/null +++ b/arrangement_utils.py @@ -0,0 +1,192 @@ +import os +import pretty_midi as pyd +import numpy as np +import torch +from torch.utils.data import DataLoader +from scipy.interpolate import interp1d +from tqdm import tqdm + +from piano_arranger.acc_utils import split_phrases +import piano_arranger.format_converter as cvt +from piano_arranger.models import DisentangleVAE +from piano_arranger.AccoMontage import find_by_length, dp_search, re_harmonization, get_texture_filter, ref_spotlight + +from orchestrator import Slakh_Dataset, collate_fn, compute_pr_feat, EMBED_PROGRAM_MAPPING, Prior +from orchestrator.dataset import SLAKH_CLASS_PROGRAMS +from orchestrator.utils import grid2pr, pr2grid, matrix2midi, midi2matrix + + +SLAKH_CLASS_MAPPING = {v: k for k, v in EMBED_PROGRAM_MAPPING.items()} +TOTAL_LEN_BIN = np.array([4, 7, 12, 15, 20, 23, 28, 31, 36, 39, 44, 47, 52, 55, 60, 63, 68, 71, 76, 79, 84, 87, 92, 95, 100, 103, 108, 111, 116, 119, 124, 127, 132]) +ABS_POS_BIN = np.arange(129) +REL_POS_BIN = np.arange(128) + + +def load_premise(DATA_FILE_ROOT, DEVICE): + """Load AccoMontage Search Space""" + print('Loading AccoMontage piano texture search space. 
This may take 1 or 2 minutes ...')
+    data = np.load(os.path.join(DATA_FILE_ROOT, 'phrase_data.npz'), allow_pickle=True)
+    melody = data['melody']
+    acc = data['acc']
+    chord = data['chord']
+    vel = data['velocity']
+    cc = data['cc']
+    acc_pool = {}
+    for LEN in tqdm(range(2, 11)):
+        (mel, acc_, chord_, vel_, cc_, song_reference) = find_by_length(melody, acc, chord, vel, cc, LEN)
+        acc_pool[LEN] = (mel, acc_, chord_, vel_, cc_, song_reference)
+    texture_filter = get_texture_filter(acc_pool)
+    edge_weights = np.load(os.path.join(DATA_FILE_ROOT, 'edge_weights.npz'), allow_pickle=True)
+
+    """Load Q&A Prompt Search Space"""
+    print('Loading orchestration prompt search space ...')
+    slakh_dir = os.path.join(DATA_FILE_ROOT, 'Slakh2100_inference_set')
+    dataset = Slakh_Dataset(slakh_dir, debug_mode=False, split='test', mode='train')
+    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda b: collate_fn(b, DEVICE, get_pr_gt=True))
+    REF_PR = []
+    REF_P = []
+    REF_T = []
+    REF_PROG = []
+    FLTN = []
+    for (pr, _, prog, func_pitch, func_time, _, _, _, _, _) in tqdm(loader):
+        pr = pr[0]
+        prog = prog[0, :]
+        func_pitch = func_pitch[0, :]
+        func_time = func_time[0, :]
+        fltn = torch.cat([torch.sum(func_pitch, dim=-2), torch.sum(func_time, dim=-2)], dim=-1)
+        REF_PR.append(pr)
+        REF_P.append(func_pitch[0])
+        REF_T.append(func_time[0])
+        REF_PROG.append(prog)
+        FLTN.append(fltn[0])
+    FLTN = torch.stack(FLTN, dim=0)
+
+    """Initialize orchestration model (Prior + Q&A)"""
+    print('Initializing model ...')
+    prior_model_path = os.path.join(DATA_FILE_ROOT, 'params_prior.pt')
+    QaA_model_path = os.path.join(DATA_FILE_ROOT, 'params_qa.pt')
+    orchestrator = Prior.init_inference_model(prior_model_path, QaA_model_path, DEVICE=DEVICE)
+    orchestrator.to(DEVICE)
+    orchestrator.eval()
+    piano_arranger = DisentangleVAE.init_model(torch.device('cuda')).cuda()
+    piano_arranger.load_state_dict(torch.load(os.path.join(DATA_FILE_ROOT, 'params_reharmonizer.pt')))
+    print('Finished.')
+    return piano_arranger, orchestrator, (acc_pool, edge_weights, texture_filter), (REF_PR, REF_P, REF_T, REF_PROG, FLTN)
+
+
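+# Note on read_lead_sheet() below: it loads './demo/<SONG_NAME>/lead sheet.mid' and returns
+# (LEADSHEET, CHORD_TABLE, melody_queries, query_phrases). As used in this module, LEADSHEET
+# is a (T, 142) matrix quantized at the 16th note (a 130-dim melody roll of 128 pitches plus
+# hold and rest, concatenated with a 12-dim chroma), CHORD_TABLE holds one expanded chord per
+# beat, melody_queries is LEADSHEET split by phrase, and query_phrases comes from
+# split_phrases(SEGMENTATION).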
+def read_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track_ID=0):
+    melody_roll, chord_roll = cvt.leadsheet2matrix(os.path.join(DEMO_ROOT, SONG_NAME, 'lead sheet.mid'), melody_track_ID)
+    assert len(melody_roll) == len(chord_roll)
+    if NOTE_SHIFT != 0:
+        melody_roll = melody_roll[int(NOTE_SHIFT*4):, :]
+        chord_roll = chord_roll[int(NOTE_SHIFT*4):, :]
+    if len(melody_roll) % 16 != 0:
+        pad_len = (len(melody_roll)//16+1)*16-len(melody_roll)
+        melody_roll = np.pad(melody_roll, ((0, pad_len), (0, 0)))
+        melody_roll[-pad_len:, -1] = 1
+        chord_roll = np.pad(chord_roll, ((0, pad_len), (0, 0)))
+        chord_roll[-pad_len:, 0] = -1
+        chord_roll[-pad_len:, -1] = -1
+
+    CHORD_TABLE = np.stack([cvt.expand_chord(chord) for chord in chord_roll[::4]], axis=0)
+    LEADSHEET = np.concatenate((melody_roll, chord_roll[:, 1: -1]), axis=-1) #T*142, quantized at 16th
+    query_phrases = split_phrases(SEGMENTATION) #[('A', 8, 0), ('A', 8, 8), ('B', 8, 16), ('B', 8, 24)]
+
+    midi_len = len(LEADSHEET)//16
+    anno_len = sum([item[1] for item in query_phrases])
+    if midi_len > anno_len:
+        LEADSHEET = LEADSHEET[: anno_len*16]
+        CHORD_TABLE = CHORD_TABLE[: anno_len*4]
+        print(f'Mismatch warning: Detected {midi_len} bars in the lead sheet (MIDI) and {anno_len} bars in the provided phrase annotation. The lead sheet is truncated to {anno_len} bars.')
+    elif midi_len < anno_len:
+        pad_len = (anno_len - midi_len)*16
+        LEADSHEET = np.pad(LEADSHEET, ((0, pad_len), (0, 0)))
+        LEADSHEET[-pad_len:, 129] = 1
+        CHORD_TABLE = np.pad(CHORD_TABLE, ((0, pad_len//4), (0, 0)))
+        CHORD_TABLE[-pad_len//4:, 11] = -1
+        CHORD_TABLE[-pad_len//4:, -1] = -1
+        print(f'Mismatch warning: Detected {midi_len} bars in the lead sheet (MIDI) and {anno_len} bars in the provided phrase annotation. The lead sheet is padded to {anno_len} bars.')
+
+    melody_queries = []
+    for item in query_phrases:
+        start_bar = item[-1]
+        length = item[-2]
+        segment = LEADSHEET[start_bar*16: (start_bar+length)*16]
+        melody_queries.append(segment) #melody queries: list of T16*142, segmented by phrases
+
+    return (LEADSHEET, CHORD_TABLE, melody_queries, query_phrases)
+
+
+def piano_arrangement(pianoRoll, chord_table, melody_queries, query_phrases, acc_pool, edge_weights, texture_filter, piano_arranger, PREFILTER):
+    print('Phrasal Unit selection begins:\n\t', f'{len(query_phrases)} phrases in the lead sheet;\n\t', f'set note density filter: {PREFILTER}.')
+    phrase_indice, chord_shift = dp_search(melody_queries,
+                                           query_phrases,
+                                           acc_pool,
+                                           edge_weights,
+                                           texture_filter,
+                                           filter_id=PREFILTER)
+    path = phrase_indice[0]
+    shift = chord_shift[0]
+    print('Re-harmonization begins ...')
+    midi_recon, acc = re_harmonization(pianoRoll, chord_table, query_phrases, path, shift, acc_pool, model=piano_arranger, get_est=True, tempo=100)
+    acc = np.array([grid2pr(matrix) for matrix in acc])
+    print('Piano accompaniment generated!')
+    return midi_recon, acc
+
+
+def prompt_sampling(acc_piano, REF_PR, REF_P, REF_T, REF_PROG, FLTN, DEVICE='cuda:0'):
+    fltn = torch.from_numpy(np.concatenate(compute_pr_feat(acc_piano[0:1]), axis=-1)).to(DEVICE)
+    sim_func = torch.nn.CosineSimilarity(dim=-1)
+    distance = sim_func(fltn, FLTN)
+    distance = distance + torch.normal(mean=torch.zeros(distance.shape), std=0.2*torch.ones(distance.shape)).to(distance.device)
+    sim_values, anchor_points = torch.sort(distance, descending=True)
+    IDX = 0
+    sim_value = sim_values[IDX]
+    anchor_point = anchor_points[IDX]
+    func_pitch = REF_P[anchor_point]
+    func_time = REF_T[anchor_point]
+    prog = REF_PROG[anchor_point]
+    prog_class = [SLAKH_CLASS_MAPPING[item.item()] for item in prog.cpu().detach().numpy()]
+    program_name = [SLAKH_CLASS_PROGRAMS[item] for item in prog_class]
+    refr = REF_PR[anchor_point]
+    midi_ref = matrix2midi(
+        pr_matrices=refr.detach().cpu().numpy(),
+        programs=prog_class,
+        init_tempo=100)
+    print(f'Prior model initialized with {len(program_name)} tracks:\n\t{program_name}')
+    return midi_ref, (prog, func_pitch, func_time)
+
+
+def orchestration(acc_piano, chord_track, prog, func_pitch, func_time, orchestrator, DEVICE='cuda:0', blur=.5, p=.1, t=4):
+    print('Orchestration begins ...')
+    if chord_track is not None:
+        if len(acc_piano) > len(chord_track):
+            # pad chord_track along the leading (time) axis to match acc_piano
+            chord_track = np.pad(chord_track, ((0, len(acc_piano)-len(chord_track)),) + ((0, 0),)*(chord_track.ndim-1))
+        else:
+            chord_track = chord_track[:len(acc_piano)]
+        acc_piano = np.max(np.stack([acc_piano, chord_track], axis=0), axis=0)
+
+    mix = torch.from_numpy(np.array([pr2grid(matrix, max_note_count=32) for matrix in acc_piano])).to(DEVICE)
+    r_pos = np.round(np.arange(0, len(mix), 1) / (len(mix)-1) * len(REL_POS_BIN))
+    total_len = np.argmin(np.abs(TOTAL_LEN_BIN - len(mix))).repeat(len(mix))
+    a_pos = np.append(ABS_POS_BIN[0: min(ABS_POS_BIN[-1], len(mix))], [ABS_POS_BIN[-1]] * (len(mix)-ABS_POS_BIN[-1]))
+    r_pos = torch.from_numpy(r_pos).long().to(DEVICE)
+    a_pos = torch.from_numpy(a_pos).long().to(DEVICE)
+    total_len = torch.from_numpy(total_len).long().to(DEVICE)
+
+    if func_pitch is not None:
+        func_pitch = func_pitch.unsqueeze(0).unsqueeze(0)
+    if func_time is not None:
+        func_time = func_time.unsqueeze(0).unsqueeze(0)
+    recon_pitch, recon_dur = orchestrator.run_autoregressive_nucleus(mix.unsqueeze(0), prog.unsqueeze(0), func_pitch, func_time, total_len.unsqueeze(0), a_pos.unsqueeze(0), r_pos.unsqueeze(0), blur, p, t)
+
+    grid_recon = torch.cat([recon_pitch.max(-1)[-1].unsqueeze(-1), recon_dur.max(-1)[-1]], dim=-1)
+    bat_ch, track, _, max_simu_note, grid_dim = grid_recon.shape
+    grid_recon = grid_recon.permute(1, 0, 2, 3, 4)
+    grid_recon = grid_recon.reshape(track, -1, max_simu_note, grid_dim)
+
+    pr_recon_ = np.array([grid2pr(matrix) for matrix in grid_recon.detach().cpu().numpy()])
+    pr_recon = matrix2midi(pr_recon_, [SLAKH_CLASS_MAPPING[item.item()] for item in prog.cpu().detach().numpy()], 100)
+    print('Full-band accompaniment generated!')
+    return pr_recon
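+
+# A minimal end-to-end usage sketch (for reference only; it mirrors inference.ipynb, and the
+# song name, phrase segmentation, and pre-filter values below are just the demo settings):
+#
+#   piano_arranger, orchestrator, piano_texture, band_prompt = load_premise('./data_file_dir/', 'cuda:0')
+#   lead_sheet = read_lead_sheet('./demo', 'Sally Garden', 'A4A4B4A4', 0)
+#   midi_piano, acc_piano = piano_arrangement(*lead_sheet, *piano_texture, piano_arranger, (4, 4))
+#   midi_prompt, (prog, func_pitch, func_time) = prompt_sampling(acc_piano, *band_prompt, 'cuda:0')
+#   midi_band = orchestration(acc_piano, None, prog, func_pitch, func_time, orchestrator, 'cuda:0')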
+    
\ No newline at end of file
diff --git a/data_file_dir.txt b/data_file_dir.txt
new file mode 100644
index 0000000..09c93b6
--- /dev/null
+++ b/data_file_dir.txt
@@ -0,0 +1,5 @@
+If you wish to run our code,
+please download our pre-trained checkpoints
+and processed data at the following link:
+
+https://we.tl/t-cc2nOC8dAA
\ No newline at end of file
diff --git a/demo/Castles in the Air/arrangement_band.mid b/demo/Castles in the Air/arrangement_band.mid
new file mode 100644
index 0000000000000000000000000000000000000000..da114cbde87d7fc5fd7d7c70449138bdb2fdfe1c
GIT binary patch
literal 9570
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Castles in the Air/arrangement_piano.mid b/demo/Castles in the Air/arrangement_piano.mid
new file mode 100644
index 0000000000000000000000000000000000000000..05bda2d56121cc7950b8c31267067b0d53734e71
GIT binary patch
literal 5695
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Castles in the Air/lead sheet.mid b/demo/Castles in the Air/lead sheet.mid
new file mode 100644
GIT binary patch
literal 6537
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Jingle Bells/arrangement_band.mid b/demo/Jingle Bells/arrangement_band.mid
new file mode 100644
index 0000000000000000000000000000000000000000..6e80f14692714add3cab5f341a373009bc9c33b6
GIT binary patch
literal 5322
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Jingle Bells/arrangement_piano.mid b/demo/Jingle Bells/arrangement_piano.mid
new file mode 100644
GIT binary patch
literal 3757
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Jingle Bells/lead sheet.mid b/demo/Jingle Bells/lead sheet.mid
new file mode 100644
index 0000000000000000000000000000000000000000..7ec0b54710954d0f2b2f5ab097a8bedb5e30b49d
GIT binary patch
literal 2279
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Sally Garden/arrangement_band.mid b/demo/Sally Garden/arrangement_band.mid
new file mode 100644
GIT binary patch
literal 5174
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Sally Garden/arrangement_piano.mid b/demo/Sally Garden/arrangement_piano.mid
new file mode 100644
index 0000000000000000000000000000000000000000..cdafd0c802e235d8023d824ccea3211ab325161b
GIT binary patch
literal 2577
[base85-encoded binary MIDI data omitted]
diff --git a/demo/Sally Garden/lead sheet.mid b/demo/Sally Garden/lead sheet.mid
new file mode 100644
index 0000000000000000000000000000000000000000..5f6eedaa5d8ac32e263971f3116709f77691bda1
GIT binary patch
literal 1300
[base85-encoded binary MIDI data omitted]
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000..261b082
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,170 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
+    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
+    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
+    "import numpy as np\n",
+    "import pretty_midi as pyd\n",
+    "from arrangement_utils import *\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")\n",
+    "\n",
+    "\"\"\"Download checkpoints and data at https://we.tl/t-cc2nOC8dAA (379MB) and decompress in the project directory (the code reads from './data_file_dir/')\"\"\"\n",
+    "DATA_FILE_ROOT = './data_file_dir/'\n",
+    "DEVICE = 'cuda:0'\n",
+    "piano_arranger, orchestrator, piano_texture, band_prompt = load_premise(DATA_FILE_ROOT, DEVICE)"
+   ]
+  },
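+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"Optional sanity check: if load_premise above fails, list which of the files it expects under DATA_FILE_ROOT are missing (file names taken from load_premise in arrangement_utils.py)\"\"\"\n",
+    "for name in ['phrase_data.npz', 'edge_weights.npz', 'Slakh2100_inference_set', 'params_prior.pt', 'params_qa.pt', 'params_reharmonizer.pt']:\n",
+    "    print(name, 'found' if os.path.exists(os.path.join(DATA_FILE_ROOT, name)) else 'MISSING')"
+   ]
+  },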
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Initialize input and preference\n",
+    "We provide three sample lead sheets for quick inference. You should be able to directly run the code blocks after downloading the pre-trained checkpoints.\n",
+    "\n",
+    "If you wish to test our model on your own lead sheet file, please create a sub-folder named after its `SONG_NAME` in the `./demo` folder, put the file inside, and name it \"lead sheet.mid\".\n",
+    "\n",
+    "Please also specify `SEGMENTATION` (the phrase structure) and `NOTE_SHIFT` (the duration, in beats, of the pick-up measure, if any)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"Set input lead sheet\"\"\"\n",
+    "SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Castles in the Air', 'A8A8B8B8', 1 #1 beat in the pick-up measure\n",
+    "#SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Jingle Bells', 'A8B8A8', 0\n",
+    "#SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Sally Garden', 'A4A4B4A4', 0\n",
+    "\n",
+    "\"\"\"Set texture pre-filtering for piano arrangement (default random)\"\"\"\n",
+    "RHYTHM_DENSITY = np.random.randint(3, 5)\n",
+    "VOICE_NUMBER = np.random.randint(3, 5)\n",
+    "PREFILTER = (RHYTHM_DENSITY, VOICE_NUMBER)\n",
+    "\n",
+    "\"\"\"Set whether to use a 2-bar prompt for full-band arrangement (default False)\"\"\"\n",
+    "USE_PROMPT = False\n",
+    "\n",
+    "lead_sheet = read_lead_sheet('./demo', SONG_NAME, SEGMENTATION, NOTE_SHIFT)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Piano Accompaniment Arrangement"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Phrasal Unit selection begins:\n",
+      "\t 4 phrases in the lead sheet;\n",
+      "\t set note density filter: (4, 4).\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 3/3 [00:17<00:00, 5.82s/it]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Re-harmonization begins ...\n",
+      "Piano accompaniment generated!\n"
+     ]
+    }
+   ],
+   "source": [
+    "midi_piano, acc_piano = piano_arrangement(*lead_sheet, *piano_texture, piano_arranger, PREFILTER)\n",
+    "midi_piano.write(f'./demo/{SONG_NAME}/arrangement_piano.mid')"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Orchestration"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Prior model initialized with 7 tracks:\n",
+      "\t['Acoustic Bass', 'Chromatic Percussion', 'Synth Pad', 'Acoustic Piano', 'Electric Piano', 'Acoustic Bass', 'Acoustic Guitar']\n",
+      "Orchestration begins ...\n",
+      "Full-band accompaniment generated!\n"
+     ]
+    }
+   ],
+   "source": [
+    "midi_prompt, func_prompt = prompt_sampling(acc_piano, *band_prompt, DEVICE)\n",
+    "if USE_PROMPT:\n",
+    "    midi_band = orchestration(acc_piano, None, *func_prompt, orchestrator, DEVICE, blur=.5, p=.1, t=4)\n",
+    "else:\n",
+    "    instruments, pitch_prompt, time_prompt = func_prompt\n",
+    "    midi_band = orchestration(acc_piano, None, instruments, None, None, orchestrator, DEVICE, blur=.5, p=.1, t=4)\n",
+    "mel_track = pyd.Instrument(program=72, is_drum=False, name='melody')\n",
+    "mel_track.notes = midi_piano.instruments[0].notes\n",
+    "midi_band.instruments.append(mel_track)\n",
+    "midi_band.write(f'./demo/{SONG_NAME}/arrangement_band.mid')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": 
"torch1.10_conda11.3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/orchestrator/Prior.py b/orchestrator/Prior.py new file mode 100644 index 0000000..f585deb --- /dev/null +++ b/orchestrator/Prior.py @@ -0,0 +1,587 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F +from .QandA import QandA +from .TransformerEncoderLayer import TransformerEncoderLayer as TransformerEncoderLayerRPE +import numpy as np + +NUM_INSTR_CLASS = 34 +NUM_PITCH_CODE = 64 +NUM_TIME_CODE = 128 +TOTAL_LEN_BIN = np.array([4, 7, 12, 15, 20, 23, 28, 31, 36, 39, 44, 47, 52, 55, 60, 63, 68, 71, 76, 79, 84, 87, 92, 95, 100, 103, 108, 111, 116, 119, 124, 127, 132]) +ABS_POS_BIN = np.arange(129) +REL_POS_BIN = np.arange(128) + +class Prior(nn.Module): + def __init__(self, mixture_encoder=None, + pitch_function_encoder=None, + time_function_encoder=None, + context_enc_layer=2, + function_dec_layer=4, + d_model=256, + nhead=8, + dim_feedforward=1024, + dropout=.1, + ft_resolution=8, + inference=False, + QaA_model=None, + DEVICE='cuda:0'): + super(Prior, self).__init__() + + # embeddings + self.fp_embedding = nn.Embedding(num_embeddings=NUM_PITCH_CODE+1, embedding_dim=d_model, padding_idx=NUM_PITCH_CODE) + self.ft_embedding = nn.Embedding(num_embeddings=NUM_TIME_CODE+1, embedding_dim=d_model, padding_idx=NUM_TIME_CODE) + self.prog_embedding = nn.Embedding(num_embeddings=NUM_INSTR_CLASS+1, embedding_dim=d_model, padding_idx=NUM_INSTR_CLASS) + self.total_len_embedding = nn.Embedding(num_embeddings=len(TOTAL_LEN_BIN)+1, embedding_dim=d_model, padding_idx=len(TOTAL_LEN_BIN)) + self.abs_pos_embedding = nn.Embedding(num_embeddings=len(ABS_POS_BIN)+1, embedding_dim=d_model, padding_idx=len(ABS_POS_BIN)) + self.rel_pos_embedding = nn.Embedding(num_embeddings=len(REL_POS_BIN)+1, embedding_dim=d_model, padding_idx=len(REL_POS_BIN)) + + self.start_embedding = nn.Parameter(torch.empty(NUM_INSTR_CLASS+1, 9, d_model)) + nn.init.normal_(self.start_embedding) + with torch.no_grad(): + self.start_embedding[NUM_INSTR_CLASS].fill_(0) + + #pre-trained encoders + if not inference: + self.mixture_encoder = mixture_encoder + for param in self.mixture_encoder.parameters(): + param.requires_grad = False + self.pitch_function_encoder = pitch_function_encoder + for param in self.pitch_function_encoder.parameters(): + param.requires_grad = False + self.time_function_encoder = time_function_encoder + for param in self.time_function_encoder.parameters(): + param.requires_grad = False + else: + self.QaA_model = QaA_model + self.mixture_encoder = self.QaA_model.prmat_enc_fltn + self.pitch_function_encoder = self.QaA_model.func_pitch_enc + self.time_function_encoder = self.QaA_model.func_time_enc + + #multi-stream Transformer + self.context_enc = nn.TransformerEncoder( + nn.TransformerEncoderLayer(d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=F.gelu, + batch_first=True, + norm_first=True, + device=DEVICE), + num_layers=context_enc_layer) + self.ms_trf = nn.ModuleDict({}) + for layer in range(function_dec_layer): + """self.ms_trf[f'track_layer_{layer}'] = nn.TransformerEncoderLayer(d_model=d_model, + nhead=nhead, + 
dim_feedforward=dim_feedforward, + dropout=dropout, + activation=F.gelu, + batch_first=True, + norm_first=True, + device=DEVICE)""" + self.ms_trf[f'track_layer_{layer}'] = TransformerEncoderLayerRPE(d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + norm_first=True, + max_len=24).to(DEVICE) + self.ms_trf[f'time_layer_{layer}'] = nn.TransformerDecoderLayer(d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=F.gelu, + batch_first=True, + norm_first=True, + device=DEVICE) + + #positional encoding + self.max_len = 1000 + position = torch.arange(self.max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, self.max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + #pe = torch.flip(pe, dims=[1]) + pe = pe.to(DEVICE) + self.register_buffer('pe', pe) + + #decoder output module + self.fp_out_linear = nn.Linear(d_model, NUM_PITCH_CODE) + self.ft_out_linear = nn.Linear(d_model, NUM_TIME_CODE) + + #constants + self.d_model = d_model + self.function_dec_layer = function_dec_layer + self.ft_resolution = ft_resolution + + #loss function + self.criterion = nn.CrossEntropyLoss(reduction='mean') + + + def generate_square_subsequent_mask(self, sz=15): + return torch.triu(torch.ones(sz, sz), diagonal=1).repeat_interleave(9,dim=0).repeat_interleave(9,dim=1).bool() + + + def run(self, mix, prog, fp, ft, tm_mask, tk_mask, total_len, abs_pos, rel_pos, inference=False): + #mix: (batch, max_time, 256) + #prog: (batch, max_track) + #fp: (batch, max_time, max_track) + #ft: (batch, max_time, max_track, 8) + #tm_mask: (batch, max_time) + #tk_mask: (batch, max_track) + #total_len: (batch, max_time) + #abs_pos: (batch, max_time) + #rel_pos: (batch, max_time) + batch, max_time, _ = mix.shape + _, max_track = prog.shape + + #with torch.no_grad(): + #mix = mix.reshape(-1, time, max_simu_note, 6) + #mix = self.mixture_encoder(mix)[0].mean.reshape(batch, num_2bar, -1) #(batch, num_2bar, 256) + #fp = fp.reshape(-1, 128) + #fp = self.pitch_function_encoder.get_code_indices(fp).reshape(batch, num_2bar, max_track) + #ft = ft.reshape(-1, 32) + #ft = self.time_function_encoder.get_code_indices(ft).reshape(batch, num_2bar, max_track, self.ft_resolution) + + mix = mix + self.pe[:, :mix.shape[1], :] + #mix = mix + self.total_len_embedding(total_len) + #mix = mix + self.abs_pos_embedding(abs_pos) + #mix = mix + self.rel_pos_embedding(rel_pos) + mix = mix.unsqueeze(1) + self.prog_embedding(prog).unsqueeze(2) #(batch, max_track, max_time, 256) + mix = self.context_enc(mix.reshape(-1, max_time, self.d_model)) #(batch*max_track, max_time, 256) + + func = torch.cat([self.fp_embedding(fp[:, :-1].unsqueeze(-1)), + self.ft_embedding(ft[:, :-1])], + dim=-2) #batch, max_time-1, max_track, 9, d_model + + func = torch.cat([ + self.start_embedding[prog].unsqueeze(1), #(batch, 1, max_track, 9, d_model) + func], + dim=1) #batch, max_time, max_track, 9, d_model + + func = func.permute(0, 1, 3, 2, 4).reshape(batch, -1, max_track, self.d_model) #(batch, max_time*9, max_track, d_model) + + func = func + self.prog_embedding(prog).unsqueeze(1) + func = func + self.pe[:, :func.shape[1], :].unsqueeze(2) + func = func + self.total_len_embedding(total_len).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.abs_pos_embedding(abs_pos).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + 
self.rel_pos_embedding(rel_pos).repeat_interleave(9, dim=1).unsqueeze(2) + + for layer in range(self.function_dec_layer): + func = func.reshape(-1, max_track, self.d_model) + func = self.ms_trf[f'track_layer_{layer}'](src=func, + src_key_padding_mask=tk_mask.unsqueeze(1).repeat(1, max_time*9, 1).reshape(-1, max_track)) + func = func.reshape(batch, -1, max_track, self.d_model).permute(0, 2, 1, 3).reshape(-1, max_time*9, self.d_model) + func = self.ms_trf[f'time_layer_{layer}'](tgt=func, + tgt_mask=self.generate_square_subsequent_mask(max_time).to(func.device), + tgt_key_padding_mask=tm_mask.unsqueeze(1).repeat(1, max_track, 1).reshape(-1, max_time).repeat_interleave(9, dim=-1), + memory=mix) + func = func.reshape(batch, max_track, -1, self.d_model).permute(0, 2, 1, 3) #(batch, max_time*9, max_track, d_model) + + func = func.reshape(batch, max_time, 9, max_track, self.d_model) + fp_recon = self.fp_out_linear(func[:, :, 0]) + ft_recon = self.ft_out_linear(func[:, :, 1:].permute(0, 1, 3, 2, 4)) + + return fp_recon, ft_recon + + + def loss_function(self, fp_recon, ft_recon, fp_gt, ft_gt, tm_mask, tk_mask): + mask = torch.logical_or(tm_mask.unsqueeze(-1), tk_mask.unsqueeze(1)) + unmask = torch.logical_not(mask) + + fp_loss = self.criterion(fp_recon[unmask], + fp_gt[unmask]) + ft_loss = self.criterion(ft_recon[unmask].reshape(-1, NUM_TIME_CODE), + ft_gt[unmask].reshape(-1)) + + loss = 0.11*fp_loss + 0.89*ft_loss + return loss, fp_loss, ft_loss + + + def loss(self, mix, prog, fp, ft, tm_mask, tk_mask, total_len, abs_pos, rel_pos): + output = self.run(mix, prog, fp, ft, tm_mask, tk_mask, total_len, abs_pos, rel_pos, inference=False) + return self.loss_function(*output, fp, ft, tm_mask, tk_mask) + + + def forward(self, mode, *input, **kwargs): + if mode in ["run", 0]: + return self.run(*input, **kwargs) + elif mode in ['loss', 'train', 1]: + return self.loss(*input, **kwargs) + elif mode in ['inference', 'eval', 'val', 2]: + return self.inference(*input, **kwargs) + else: + raise NotImplementedError + + + def run_autoregressive_greedy(self, mix, prog, fp, ft, total_len, abs_pos, rel_pos, blur=.5): + #mix: (batch, num2bar, bar_resolution, max_simu_note, 6) + #prog: (batch, max_track) + #fp: (batch, 1, max_track, 128) + #ft: (batch, 1, max_track, 32) + #total_len: (batch, num2bar) + #abs_pos: (batch, num2bar) + #rel_pos: (batch, num2bar) + batch, num_2bar, time, max_simu_note, _ = mix.shape + _, max_track = prog.shape + + mix = mix.reshape(-1, time, max_simu_note, 6) + mix = self.mixture_encoder(mix)[0].mean.reshape(batch, num_2bar, -1) #(batch, num_2bar, 256) + mix_ = (1-blur)*mix.clone() + blur*torch.empty(mix.shape, device=mix.device).normal_(mean=0, std=1) + self.pe[:, :mix.shape[1], :] + #mix_ = mix_ + self.total_len_embedding(total_len) + #mix_ = mix_ + self.abs_pos_embedding(abs_pos) + #mix_ = mix_ + self.rel_pos_embedding(rel_pos) + mix_ = mix_.unsqueeze(1) + self.prog_embedding(prog).unsqueeze(2) #(batch, max_track, num2bar, 256) + mix_ = self.context_enc(mix_.reshape(-1, num_2bar, self.d_model)) + + func = self.start_embedding[prog].unsqueeze(1) #(batch, 1, max_track, 9, d_model) + for idx in range(num_2bar): + if idx == 0: + if (fp is not None) and (ft is not None): + fp = fp.reshape(-1, 128) + fp = self.pitch_function_encoder.get_code_indices(fp).reshape(batch, 1, max_track) + ft = ft.reshape(-1, 32) + ft = self.time_function_encoder.get_code_indices(ft).reshape(batch, 1, max_track, self.ft_resolution) + continue + else: + fp = torch.empty((batch, 0, max_track)).long().to(mix.device) + ft 
= torch.empty((batch, 0, max_track, self.ft_resolution)).long().to(mix.device) + elif idx > 0: + func = torch.cat([ + func, + torch.cat([self.fp_embedding(fp[:, idx-1: idx].unsqueeze(-1)), + self.ft_embedding(ft[:, idx-1: idx])], + dim=-2) #*batch, 1, max_track, 9, d_model + ], dim=1) #*batch, idx+1, max_track, 9, d_model + + func = func.permute(0, 1, 3, 2, 4).reshape(batch, -1, max_track, self.d_model) + + func = func + self.prog_embedding(prog).unsqueeze(1) + func = func + self.pe[:, :func.shape[1], :].unsqueeze(2) + func = func + self.total_len_embedding(total_len[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.abs_pos_embedding(abs_pos[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.rel_pos_embedding(rel_pos[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + + for layer in range(self.function_dec_layer): + + func = func.reshape(-1, max_track, self.d_model) + func = self.ms_trf[f'track_layer_{layer}'](src=func) + func = func.reshape(batch, -1, max_track, self.d_model).permute(0, 2, 1, 3).reshape(-1, (1+idx)*9, self.d_model) + func = self.ms_trf[f'time_layer_{layer}'](tgt=func, + tgt_mask=self.generate_square_subsequent_mask(sz=1+idx).to(func.device), + memory=mix_) + func = func.reshape(batch, max_track, -1, self.d_model).permute(0, 2, 1, 3) #(batch, num2bar-1, max_track, d_model) + #print('func output', func.shape) + + func = func.reshape(batch, 1+idx, 9, max_track, self.d_model).permute(0, 1, 3, 2, 4) + fp_pred = self.fp_out_linear(func[:, -1, :, 0]).unsqueeze(1).max(-1)[1] + ft_pred = self.ft_out_linear(func[:, -1, :, 1:]).unsqueeze(1).max(-1)[1] + + fp = torch.cat([fp, fp_pred], dim=1) + ft = torch.cat([ft, ft_pred], dim=1) + if fp.shape[1] == num_2bar: + break + + z_fp = self.pitch_function_encoder.infer_by_codes(fp) + z_ft = self.time_function_encoder.infer_by_codes(ft) + return self.QaA_model.infer_with_function_codes(mix[0], prog[0].repeat(num_2bar, 1), z_fp[0], z_ft[0]) + + + def run_autoregressive_nucleus(self, mix, prog, fp, ft, total_len, abs_pos, rel_pos, blur=.5, p=.1, t=1): + #mix: (batch, num2bar, bar_resolution, max_simu_note, 6) + #prog: (batch, max_track) + #fp: (batch, 1, max_track, 128) + #ft: (batch, 1, max_track, 32) + #total_len: (batch, num2bar) + #abs_pos: (batch, num2bar) + #rel_pos: (batch, num2bar) + batch, num_2bar, time, max_simu_note, _ = mix.shape + _, max_track = prog.shape + + mix = mix.reshape(-1, time, max_simu_note, 6) + mix = self.mixture_encoder(mix)[0].mean.reshape(batch, num_2bar, -1) #(batch, num_2bar, 256) + mix_ = (1-blur)*mix.clone() + blur*torch.empty(mix.shape, device=mix.device).normal_(mean=0, std=1) + self.pe[:, :mix.shape[1], :] + #mix_ = mix_ + self.total_len_embedding(total_len) + #mix_ = mix_ + self.abs_pos_embedding(abs_pos) + #mix_ = mix_ + self.rel_pos_embedding(rel_pos) + mix_ = mix_.unsqueeze(1) + self.prog_embedding(prog).unsqueeze(2) #(batch, max_track, num2bar, 256) + mix_ = self.context_enc(mix_.reshape(-1, num_2bar, self.d_model)) + + func = self.start_embedding[prog].unsqueeze(1) #(batch, 1, max_track, 9, d_model) + for idx in range(num_2bar): + if idx == 0: + if (fp is not None) and (ft is not None): + fp = fp.reshape(-1, 128) + fp = self.pitch_function_encoder.get_code_indices(fp).reshape(batch, 1, max_track) + ft = ft.reshape(-1, 32) + ft = self.time_function_encoder.get_code_indices(ft).reshape(batch, 1, max_track, self.ft_resolution) + continue + else: + fp = torch.empty((batch, 0, max_track)).long().to(mix.device) + ft = torch.empty((batch, 0, max_track, 
self.ft_resolution)).long().to(mix.device) + elif idx > 0: + func = torch.cat([ + func, + torch.cat([self.fp_embedding(fp[:, idx-1: idx].unsqueeze(-1)), + self.ft_embedding(ft[:, idx-1: idx])], + dim=-2) #*batch, 1, max_track, 9, d_model + ], dim=1) #*batch, idx+1, max_track, 9, d_model + + func = func.permute(0, 1, 3, 2, 4).reshape(batch, -1, max_track, self.d_model) + + func = func + self.prog_embedding(prog).unsqueeze(1) + func = func + self.pe[:, :func.shape[1], :].unsqueeze(2) + func = func + self.total_len_embedding(total_len[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.abs_pos_embedding(abs_pos[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.rel_pos_embedding(rel_pos[:, : 1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + + for layer in range(self.function_dec_layer): + + func = func.reshape(-1, max_track, self.d_model) + func = self.ms_trf[f'track_layer_{layer}'](src=func) + func = func.reshape(batch, -1, max_track, self.d_model).permute(0, 2, 1, 3).reshape(-1, (1+idx)*9, self.d_model) + func = self.ms_trf[f'time_layer_{layer}'](tgt=func, + tgt_mask=self.generate_square_subsequent_mask(sz=1+idx).to(func.device), + memory=mix_) + func = func.reshape(batch, max_track, -1, self.d_model).permute(0, 2, 1, 3) #(batch, num2bar-1, max_track, d_model) + #print('func output', func.shape) + + func = func.reshape(batch, 1+idx, 9, max_track, self.d_model).permute(0, 1, 3, 2, 4) + + + fp_logits = self.fp_out_linear(func[:, -1, :, 0]).unsqueeze(1) / t + if idx == 0: + filtered_fp_logits = self.nucleus_filter(fp_logits/2, 2*p) + else: + filtered_fp_logits = self.nucleus_filter(fp_logits, p) + fp_probability = F.softmax(filtered_fp_logits, dim=-1) + #print('fp_probability', fp_probability.shape) + fp_pred = torch.multinomial(fp_probability.reshape(-1, NUM_PITCH_CODE), 1).reshape(fp_probability.shape[:-1]) + + ft_logits = self.ft_out_linear(func[:, -1, :, 1:]).unsqueeze(1) / t + if idx == 0: + filtered_ft_logits = self.nucleus_filter(ft_logits/2, 2*p) + else: + filtered_ft_logits = self.nucleus_filter(ft_logits, p) + ft_probability = F.softmax(filtered_ft_logits, dim=-1) + ft_pred = torch.multinomial(ft_probability.reshape(-1, NUM_TIME_CODE), 1).reshape(ft_probability.shape[:-1]) + + fp = torch.cat([fp, fp_pred], dim=1) + ft = torch.cat([ft, ft_pred], dim=1) + if fp.shape[1] == num_2bar: + break + + z_fp = self.pitch_function_encoder.infer_by_codes(fp) + z_ft = self.time_function_encoder.infer_by_codes(ft) + return self.QaA_model.infer_with_function_codes(mix[0], prog[0].repeat(num_2bar, 1), z_fp[0], z_ft[0]) + + def nucleus_filter(self, logits, p): + #sorted_logits, sorted_indices = torch.sort(logits, descending=True) + sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True) + #cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + cum_sum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + # Remove tokens with cumulative probability above the threshold + nucleus = cum_sum_probs < p + # Shift the indices to the right to keep also the first token above the threshold + nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1) + nucleus = nucleus.gather(-1, sorted_indices.argsort(-1)) + logits[~nucleus] = float('-inf') + return logits + + + def run_autoregressive_nucleus_long_sample(self, mix, prog, fp, ft, total_len, abs_pos, rel_pos, blur=.5, p=.1, t=1): + #mix: (batch, num2bar, bar_resolution, max_simu_note, 6) + #prog: (batch, max_track) + #fp: (batch, 
1, max_track, 128) + #ft: (batch, 1, max_track, 32) + #total_len: (batch, num2bar) + #abs_pos: (batch, num2bar) + #rel_pos: (batch, num2bar) + batch, num_2bar, time, max_simu_note, _ = mix.shape + _, max_track = prog.shape + + MAX_LEN = 16 + HOP_LEN = 4 + START = 0 + + mix = mix.reshape(-1, time, max_simu_note, 6) + mix = self.mixture_encoder(mix)[0].mean.reshape(batch, num_2bar, -1) #(batch, num_2bar, 256) + + func = self.start_embedding[prog].unsqueeze(1) #(batch, 1, max_track, 9, d_model) + for START in range(0, num_2bar - MAX_LEN+1, HOP_LEN): + mix_ = (1-blur)*mix[:, START: START+MAX_LEN].clone() + mix_ = mix_ + blur*torch.empty(mix_.shape, device=mix_.device).normal_(mean=0, std=1) + self.pe[:, :MAX_LEN, :] + #mix_ = mix_ + self.total_len_embedding(total_len) + #mix_ = mix_ + self.abs_pos_embedding(abs_pos) + #mix_ = mix_ + self.rel_pos_embedding(rel_pos) + mix_ = mix_.unsqueeze(1) + self.prog_embedding(prog).unsqueeze(2) #(batch, max_track, num2bar, 256) + mix_ = self.context_enc(mix_.reshape(-1, MAX_LEN, self.d_model)) + + if START == 0: + init = 0 + else: + init = MAX_LEN-HOP_LEN + for idx in range(init, MAX_LEN): + if START == 0: + if idx == 0: + if (fp is not None) and (ft is not None): + fp = fp.reshape(-1, 128) + fp = self.pitch_function_encoder.get_code_indices(fp).reshape(batch, 1, max_track) + ft = ft.reshape(-1, 32) + ft = self.time_function_encoder.get_code_indices(ft).reshape(batch, 1, max_track, self.ft_resolution) + continue + else: + fp = torch.empty((batch, 0, max_track)).long().to(mix.device) + ft = torch.empty((batch, 0, max_track, self.ft_resolution)).long().to(mix.device) + elif idx > 0: + func = torch.cat([ + func, + torch.cat([self.fp_embedding(fp[:, idx-1: idx].unsqueeze(-1)), + self.ft_embedding(ft[:, idx-1: idx])], + dim=-2) #*batch, 1, max_track, 9, d_model + ], dim=1) #*batch, idx+1, max_track, 9, d_model + else: + func = torch.cat([ + func, + torch.cat([self.fp_embedding(fp[:, idx-1: idx].unsqueeze(-1)), + self.ft_embedding(ft[:, idx-1: idx])], + dim=-2) #*batch, 1, max_track, 9, d_model + ], dim=1) #*batch, idx+1, max_track, 9, d_model + if idx == init: + func = func[:, HOP_LEN:] + + + func = func.permute(0, 1, 3, 2, 4).reshape(batch, -1, max_track, self.d_model) + + func = func + self.prog_embedding(prog).unsqueeze(1) + func = func + self.pe[:, 9:9+func.shape[1], :].unsqueeze(2) + print(START) + print('func', func.shape), print('total_len', self.total_len_embedding(total_len[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2).shape) + func = func + self.total_len_embedding(total_len[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.abs_pos_embedding(abs_pos[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.rel_pos_embedding(rel_pos[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + + for layer in range(self.function_dec_layer): + + func = func.reshape(-1, max_track, self.d_model) + func = self.ms_trf[f'track_layer_{layer}'](src=func) + func = func.reshape(batch, -1, max_track, self.d_model).permute(0, 2, 1, 3).reshape(-1, (1+idx)*9, self.d_model) + func = self.ms_trf[f'time_layer_{layer}'](tgt=func, + tgt_mask=self.generate_square_subsequent_mask(sz=1+idx).to(func.device), + memory=mix_) + func = func.reshape(batch, max_track, -1, self.d_model).permute(0, 2, 1, 3) #(batch, num2bar-1, max_track, d_model) + #print('func output', func.shape) + func = func.reshape(batch, 1+idx, 9, max_track, self.d_model).permute(0, 1, 3, 2, 4) + + fp_logits = 
self.fp_out_linear(func[:, -1, :, 0]).unsqueeze(1) / t + if idx == 0: + filtered_fp_logits = self.nucleus_filter(fp_logits/2, 2*p) + else: + filtered_fp_logits = self.nucleus_filter(fp_logits, p) + fp_probability = F.softmax(filtered_fp_logits, dim=-1) + #print('fp_probability', fp_probability.shape) + fp_pred = torch.multinomial(fp_probability.reshape(-1, NUM_PITCH_CODE), 1).reshape(fp_probability.shape[:-1]) + + ft_logits = self.ft_out_linear(func[:, -1, :, 1:]).unsqueeze(1) / t + if idx == 0: + filtered_ft_logits = self.nucleus_filter(ft_logits/2, 2*p) + else: + filtered_ft_logits = self.nucleus_filter(ft_logits, p) + ft_probability = F.softmax(filtered_ft_logits, dim=-1) + ft_pred = torch.multinomial(ft_probability.reshape(-1, NUM_TIME_CODE), 1).reshape(ft_probability.shape[:-1]) + + fp = torch.cat([fp, fp_pred], dim=1) + ft = torch.cat([ft, ft_pred], dim=1) + if fp.shape[1] == num_2bar: + print('precise') + break + + + if START + MAX_LEN < num_2bar: + rest = num_2bar - (START + MAX_LEN) + START = num_2bar - MAX_LEN + mix_ = (1-blur)*mix[:, START:].clone() + mix_ = mix_ + blur*torch.empty(mix_.shape, device=mix_.device).normal_(mean=0, std=1) + self.pe[:, :MAX_LEN, :] + mix_ = mix_.unsqueeze(1) + self.prog_embedding(prog).unsqueeze(2) #(batch, max_track, num2bar, 256) + mix_ = self.context_enc(mix_.reshape(-1, MAX_LEN, self.d_model)) + + for idx in range(MAX_LEN - rest, MAX_LEN): + func = torch.cat([ + func, + torch.cat([self.fp_embedding(fp[:, idx-1: idx].unsqueeze(-1)), + self.ft_embedding(ft[:, idx-1: idx])], + dim=-2) #*batch, 1, max_track, 9, d_model + ], dim=1) #*batch, idx+1, max_track, 9, d_model + if idx == MAX_LEN - rest: + func = func[:, rest:] + + func = func.permute(0, 1, 3, 2, 4).reshape(batch, -1, max_track, self.d_model) + + func = func + self.prog_embedding(prog).unsqueeze(1) + func = func + self.pe[:, 9:9+func.shape[1], :].unsqueeze(2) + print(START) + print('func', func.shape), print('total_len', self.total_len_embedding(total_len[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2).shape) + func = func + self.total_len_embedding(total_len[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.abs_pos_embedding(abs_pos[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + func = func + self.rel_pos_embedding(rel_pos[:, START: START+1+idx]).repeat_interleave(9, dim=1).unsqueeze(2) + + for layer in range(self.function_dec_layer): + + func = func.reshape(-1, max_track, self.d_model) + func = self.ms_trf[f'track_layer_{layer}'](src=func) + func = func.reshape(batch, -1, max_track, self.d_model).permute(0, 2, 1, 3).reshape(-1, (1+idx)*9, self.d_model) + func = self.ms_trf[f'time_layer_{layer}'](tgt=func, + tgt_mask=self.generate_square_subsequent_mask(sz=1+idx).to(func.device), + memory=mix_) + func = func.reshape(batch, max_track, -1, self.d_model).permute(0, 2, 1, 3) #(batch, num2bar-1, max_track, d_model) + #print('func output', func.shape) + func = func.reshape(batch, 1+idx, 9, max_track, self.d_model).permute(0, 1, 3, 2, 4) + + fp_logits = self.fp_out_linear(func[:, -1, :, 0]).unsqueeze(1) / t + if idx == 0: + filtered_fp_logits = self.nucleus_filter(fp_logits/2, 2*p) + else: + filtered_fp_logits = self.nucleus_filter(fp_logits, p) + fp_probability = F.softmax(filtered_fp_logits, dim=-1) + #print('fp_probability', fp_probability.shape) + fp_pred = torch.multinomial(fp_probability.reshape(-1, NUM_PITCH_CODE), 1).reshape(fp_probability.shape[:-1]) + + ft_logits = self.ft_out_linear(func[:, -1, :, 1:]).unsqueeze(1) 
/ t + if idx == 0: + filtered_ft_logits = self.nucleus_filter(ft_logits/2, 2*p) + else: + filtered_ft_logits = self.nucleus_filter(ft_logits, p) + ft_probability = F.softmax(filtered_ft_logits, dim=-1) + ft_pred = torch.multinomial(ft_probability.reshape(-1, NUM_TIME_CODE), 1).reshape(ft_probability.shape[:-1]) + + fp = torch.cat([fp, fp_pred], dim=1) + ft = torch.cat([ft, ft_pred], dim=1) + if fp.shape[1] == num_2bar: + break + + z_fp = self.pitch_function_encoder.infer_by_codes(fp) + z_ft = self.time_function_encoder.infer_by_codes(ft) + return self.QaA_model.infer_with_function_codes(mix[0], prog[0].repeat(num_2bar, 1), z_fp[0], z_ft[0]) + + + @classmethod + def init_model(cls, pretrain_model_path=None, DEVICE='cuda:0'): + """Fast model initialization.""" + vqQaA = QandA(name='pretrain', trf_layers=2, device=DEVICE) + if pretrain_model_path is not None: + vqQaA.load_state_dict(torch.load(pretrain_model_path, map_location=torch.device('cpu'))) + vqQaA.eval() + model = cls(vqQaA.prmat_enc_fltn, vqQaA.func_pitch_enc, vqQaA.func_time_enc, DEVICE=DEVICE).to(DEVICE) + return model + + @classmethod + def init_inference_model(cls, prior_model_path, QaA_model_path, DEVICE='cuda:0'): + """Fast model initialization.""" + vqQaA = QandA(name='pretrain', trf_layers=2, device=DEVICE) + vqQaA.load_state_dict(torch.load(QaA_model_path, map_location=torch.device('cpu'))) + vqQaA.eval() + model = cls(inference=True, QaA_model=vqQaA, DEVICE=DEVICE).to(DEVICE) + model.load_state_dict(torch.load(prior_model_path), strict=False) + return model + diff --git a/orchestrator/QandA.py b/orchestrator/QandA.py new file mode 100644 index 0000000..f7186d5 --- /dev/null +++ b/orchestrator/QandA.py @@ -0,0 +1,438 @@ +import os +from torch import nn +from .utils import kl_with_normal +import torch +from .dl_modules import PtvaeEncoder, PianoTreeDecoder, TextureEncoder, AdaptFeatDecoder, VectorQuantizerEMA, VectorQuantizer + +from torch.nn import TransformerEncoderLayer +import torch.nn.functional as F +from torch.distributions import Normal + + +class FuncPitchEncoder(nn.Module): + def __init__(self, emb_size=256, z_dim=128, num_channel=10): + super(FuncPitchEncoder, self).__init__() + self.cnn = nn.Sequential(nn.Conv1d(1, num_channel, kernel_size=12, + stride=1, padding=0), + nn.ReLU(), + nn.MaxPool1d(kernel_size=4, stride=4)) + self.fc = nn.Linear(num_channel * 29, emb_size) + self.linear_mu = nn.Linear(emb_size, z_dim) + #self.linear_var = nn.Linear(emb_size, z_dim) + self.emb_size = emb_size + self.z_dim = z_dim + self.z2hid = nn.Linear(z_dim, emb_size) + self.hid2out = nn.Linear(emb_size, 128) + self.mse_func = nn.MSELoss() + self.vq_quantizer = VectorQuantizerEMA(embedding_dim=z_dim, num_embeddings=64, commitment_cost=.25, decay=.9, usage_threshold=1e-9, random_start=True) + #self.vq_quantizer = VectorQuantizer(embedding_dim=z_dim, num_embeddings=256, commitment_cost=.25, usage_threshold=1e-9, random_start=True) + self.batch_z = None + + def forward(self, pr, track_pad_mask): + # pr: (bs, 128) + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr).reshape(bs, -1) + pr = self.fc(pr) # (bs, emb_size) + mu = self.linear_mu(pr) + self.batch_z = mu.clone() + z, cmt_loss, perplexity = self.vq_quantizer(mu, track_pad_mask) + return z, cmt_loss, perplexity + + def get_code_indices(self, pr): + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr).reshape(bs, -1) + pr = self.fc(pr) # (bs, emb_size) + pr = self.linear_mu(pr) + pr = self.vq_quantizer.get_code_indices(pr) + return pr + + def infer_by_codes(self, 
encoding_indices): + z = self.vq_quantizer.infer_code(encoding_indices) + return z + + def decoder(self, z): + return self.hid2out(torch.relu(self.z2hid(z))) + + def recon_loss(self, pred, func_gt): + return self.mse_func(pred, func_gt) + + +class FuncTimeEncoder(nn.Module): + def __init__(self, emb_size=256, z_dim=128, num_channel=10): + super(FuncTimeEncoder, self).__init__() + self.cnn = nn.Sequential(nn.Conv1d(1, num_channel, kernel_size=4, + stride=4, padding=0), + nn.ReLU()) + self.fc = nn.Linear(num_channel * 8, emb_size) + + self.linear_mu = nn.Linear(emb_size , z_dim) + self.emb_size = emb_size + self.z_dim = z_dim + self.num_channel = num_channel + self.z2hid = nn.Linear(z_dim, emb_size) + self.hid2out = nn.Linear(emb_size, 32) + self.mse_func = nn.MSELoss() + self.vq_quantizer = VectorQuantizerEMA(embedding_dim=(self.num_channel*8)//8, num_embeddings=128, commitment_cost=.25, decay=.9, usage_threshold=1e-9, random_start=True) + #self.vq_quantizer = VectorQuantizer(embedding_dim=(self.num_channel*8)//4, num_embeddings=256, commitment_cost=.25, usage_threshold=1e-9, random_start=True) + self.batch_z = None + + def forward(self, pr, track_pad_mask): + # pr: (bs, 32) + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr)#.reshape(bs, -1) #(bs, channel, 8) + pr = pr.permute(0, 2, 1).reshape(bs, -1) + z = pr.reshape(bs, 8, (self.num_channel*8)//8) + self.batch_z = z.clone() + z, cmt_loss, perplexity = self.vq_quantizer(z, track_pad_mask.unsqueeze(1).repeat(1, 8, 1)) + z = z.reshape(bs, 8, self.num_channel).permute(0, 2, 1).reshape(bs, -1) + + z = self.fc(z) # (bs, emb_size) + z = self.linear_mu(z) + return z, cmt_loss, perplexity + + def get_code_indices(self, pr): + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr) + pr = pr.permute(0, 2, 1).reshape(bs, -1) + pr = pr.reshape(bs, 8, (self.num_channel*8)//8) + pr = self.vq_quantizer.get_code_indices(pr) + return pr.reshape(bs, 8) + + def infer_by_codes(self, encoding_indices): + #print('encoding_indices', encoding_indices.shape) + input_shape = encoding_indices.shape + encoding_indices = encoding_indices.reshape(-1, 8) + bs = encoding_indices.shape[0] + z = self.vq_quantizer.infer_code(encoding_indices) + #print('z', z.shape) + z = z.reshape(bs, 8, self.num_channel).permute(0, 2, 1).reshape(bs, -1) + z = self.fc(z) # (bs, emb_size) + z = self.linear_mu(z) + z = z.reshape(*list(input_shape[:-1]), z.shape[-1]) + #print(z.shape) + return z + + def decoder(self, z): + return self.hid2out(torch.relu(self.z2hid(z))) + + def recon_loss(self, pred, func_gt): + return self.mse_func(pred, func_gt) + + + +class QandA(nn.Module): + def __init__(self, name, device, + trf_layers=1, + stage=0): + super(QandA, self).__init__() + + self.name = name + self.device = device + + # symbolic encoder + self.prmat_enc_fltn = PtvaeEncoder(max_simu_note=32, device=self.device, z_size=256) + + # track function encoder + self.func_pitch_enc = FuncPitchEncoder(256, 128, 16) + self.func_time_enc = FuncTimeEncoder(256, 128, 16) + + # feat_dec + pianotree_dec = symbolic decoder + self.feat_dec = AdaptFeatDecoder(z_dim=256) # for symbolic feature recon + self.feat_emb_layer = nn.Linear(3, 64) + self.pianotree_dec = PianoTreeDecoder(z_size=256, feat_emb_dim=64, device=device) + + self.Transformer_layers = nn.ModuleDict({}) + self.trf_layers = trf_layers + for idx in range(self.trf_layers): + self.Transformer_layers[f'layer_{idx}'] = TransformerEncoderLayer(d_model=256, nhead=8, dim_feedforward=1024, dropout=.1, activation=F.gelu, batch_first=True) + + 
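+        # Program embedding covers the 34 Slakh instrument classes plus one
+        # padding class (index 34) used for empty track slots; padding_idx keeps
+        # that padding embedding fixed at zero.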
self.prog_embedding = nn.Embedding(num_embeddings=35, embedding_dim=256, padding_idx=34) + + self.eq_feat_head = nn.Linear(256, 4) + + self.trf_mu = nn.Linear(256, 256) + self.trf_var = nn.Linear(256, 256) + + self.stage = stage + + @property + def z_chd_dim(self): + return self.chord_enc.z_dim + + @property + def z_aud_dim(self): + return self.frame_enc.z_dim + + @property + def z_sym_dim(self): + return self.prmat_enc.z_dim + + + def run(self, pno_tree, pno_tree_fltn, prog, feat, track_pad_mask, func_pitch, func_time, tfr1=0, tfr2=0, tfr3=0, inference=False, sample_melody=False): + """ + Forward path of the model in training (w/o computing loss). + """ + #pno_tree: (batch, max_track, time, max_simu_note, 6) + #chd: (batch, time', 36) + #pr_fltn: (batch, max_track, time, 128) + #prog: (batch, 5, max_track) + #track_pad_mask: (batch, max_track) + #feat: (batch, max_track, time, 3) + #func_pitch: (batch, max_track, 128) + #func_time: (batch, max_track, 32) + + + if inference: + batch, track = track_pad_mask.shape + _, time, _, _ = pno_tree_fltn.shape + max_simu_note = 16 + else: + batch, track, time, max_simu_note, _ = pno_tree.shape + #print('pno_tree', pno_tree.shape) + + dist_sym, _, _ = self.prmat_enc_fltn(pno_tree_fltn) # + if inference: + z_sym = dist_sym.mean + else: + z_sym = dist_sym.rsample() + + + # compute symbolic-texture representation + if self.stage in [0, 1, 3]: + #print('pr_mat', pr_mat.shape) + func_pitch = func_pitch.reshape(-1, 128) + z_fp, cmt_loss_p, plty_p = self.func_pitch_enc(func_pitch, track_pad_mask) + + func_time = func_time.reshape(-1, 32) + z_ft, cmt_loss_t, plty_t = self.func_time_enc(func_time, track_pad_mask) + + fp_recon = self.func_pitch_enc.decoder(z_fp).reshape(batch, track, -1) + ft_recon = self.func_time_enc.decoder(z_ft).reshape(batch, track, -1) + + z_func = torch.cat([ + z_fp.reshape(batch, track, -1), + z_ft.reshape(batch, track, -1) + ], + dim=-1) #(batch, track, 256), + else: # self.stage == 2 (fine-tuning stage), dist_sym abandoned. + #TODO + pass + + #print('prog', prog.shape) + #print('prog embedding', self.prog_embedding(prog[:, 0]).shape) + + z = torch.cat([ + z_sym.unsqueeze(1), #(batch, 1, 256) + z_func + self.prog_embedding(prog)], + dim=1) #z: (batch, track+1, 256)""" + + + trf_mask = torch.cat([torch.zeros(batch, 1, device=z.device).bool(), track_pad_mask], dim=-1) #(batch, track+1) + for idx in range(self.trf_layers): + z = self.Transformer_layers[f'layer_{idx}'](src=z, src_key_padding_mask=trf_mask) + + # reconstruct symbolic feature using audio-texture repr. 
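+        # Drop the leading mixture token and flatten the per-track states; each
+        # 256-d track state is then mapped to a Gaussian (trf_mu / trf_var) from
+        # which the latent fed to the PianoTree decoder is drawn (mean at
+        # inference, a sample during training).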
+ z = z[:, 1:].reshape(-1, 256) + + #eq_feat = self.eq_feat_head(z).reshape(batch, track, 4) + + mu = self.trf_mu(z) + var = self.trf_var(z).exp_() + + #eq_feat = self.eq_feat_head(mu).reshape(batch, track, 4) + + dist_trf = Normal(mu, var) + if inference and (not sample_melody): + z = dist_trf.mean + elif inference and sample_melody: + z1 = dist_trf.mean.reshape(batch, track, 256) + z2 = dist_trf.rsample().reshape(batch, track, 256) + z = torch.cat([z2[:, 0: 1], z1[:, 1:]], dim=1).reshape(-1, 256) + else: + z = dist_trf.rsample() + #z = z.reshape(batch, track, 256) + + if not inference: + feat = feat.reshape(-1, time, 3) + recon_feat = self.feat_dec(z, inference, tfr1, feat) #(batch*track, time, 3) + # embed the reconstructed feature (without applying argmax) + feat_emb = self.feat_emb_layer(recon_feat) + + # prepare the teacher-forcing data for pianotree decoder + if inference: + embedded_pno_tree = None + pno_tree_lgths = None + else: + embedded_pno_tree, pno_tree_lgths = self.pianotree_dec.emb_x(pno_tree.reshape(-1, time, max_simu_note, 6)) + + # pianotree decoder + recon_pitch, recon_dur = \ + self.pianotree_dec(z, inference, embedded_pno_tree, pno_tree_lgths, tfr1, tfr2, feat_emb) + + recon_pitch = recon_pitch.reshape(batch, track, time, max_simu_note-1, 130) + recon_dur = recon_dur.reshape(batch, track, time, max_simu_note-1, 5, 2) + recon_feat = recon_feat.reshape(batch, track, time, 3) + + return recon_pitch, recon_dur, recon_feat, \ + fp_recon, ft_recon, \ + dist_sym, dist_trf, \ + cmt_loss_p, plty_p, \ + cmt_loss_t, plty_t + + + def loss_function(self, pno_tree, feat, func_pitch, func_time, recon_pitch, recon_dur, recon_feat, + fp_recon, ft_recon, + dist_sym, dist_trf, cmt_loss_p, plty_p, cmt_loss_t, plty_t, track_pad_mask, + beta, weights): + """ Compute the loss from ground truth and the output of self.run()""" + mask = torch.logical_not(track_pad_mask) + # pianotree recon loss + pno_tree_l, pitch_l, dur_l = \ + self.pianotree_dec.recon_loss(pno_tree[mask], + recon_pitch[mask], + recon_dur[mask], + weights, False) + + # feature prediction loss + feat_l, onset_feat_l, int_feat_l, center_feat_l = \ + self.feat_dec.recon_loss(feat[mask], recon_feat[mask]) + + fp_l = self.func_pitch_enc.recon_loss(fp_recon[mask], func_pitch[mask]) + ft_l = self.func_time_enc.recon_loss(ft_recon[mask], func_time[mask]) + func_l = (fp_l + cmt_loss_p) + (ft_l + cmt_loss_t) + + # kl losses + kl_sym = kl_with_normal(dist_sym) + kl_trf = kl_with_normal(dist_trf) + + if self.stage == 0: + # beta keeps annealing from 0 - 0.01 + kl_l = beta * (kl_sym + kl_trf) + else: # self.stage == 3 + # autoregressive fine-tuning + pass + + loss = pno_tree_l + feat_l + kl_l + func_l + + return loss, pno_tree_l, pitch_l, dur_l, \ + kl_l, kl_sym, kl_trf, \ + feat_l, onset_feat_l, int_feat_l, center_feat_l, \ + func_l, fp_l, ft_l, cmt_loss_p, cmt_loss_t, plty_p, plty_t + + + def loss(self, pno_tree, pno_tree_fltn, prog, feat, track_pad_mask, func_pitch_batch, func_time_batch, tfr1, tfr2, tfr3, + beta=0.01, weights=(1, 0.5)): + """ + Forward path during training with loss computation. + :param pno_tree: (B, track, 32, 16, 6) ground truth for teacher forcing + :param chd: (B, 8, 36) chord input + :param spec: (B, 229, 153) audio input. Log mel-spectrogram. (n_mels=229) + :param pr_mat: (B, track, 32, 128) (with proper corruption) symbolic input. 
+ :param prog: (B, 5, track), track program and feature for embedding + :param feat: (B, track, 32, 3) ground truth for teacher forcing + :param track_pad_mask: (B, track), pad mask for Transformer. BoolTensor, with True indicating mask + :param tfr1: teacher forcing ratio 1 (1st-hierarchy RNNs except chord) + :param tfr2: teacher forcing ratio 2 (2nd-hierarchy RNNs except chord) + :param tfr3: teacher forcing ratio 3 (for chord decoder) + :param beta: kl annealing parameter + :param weights: weighting parameter for pitch and dur in PianoTree. + :return: losses (first argument is the total loss.) + """ + + output = self.run(pno_tree, pno_tree_fltn, prog, feat, track_pad_mask, func_pitch_batch, func_time_batch, tfr1, tfr2, tfr3) + + return self.loss_function(pno_tree, feat, func_pitch_batch, func_time_batch, *output, track_pad_mask, beta, weights) + + + def forward(self, mode, *input, **kwargs): + if mode in ["run", 0]: + return self.run(*input, **kwargs) + elif mode in ['loss', 'train', 1]: + return self.loss(*input, **kwargs) + elif mode in ['inference', 'eval', 'val', 2]: + return self.inference(*input, **kwargs) + else: + raise NotImplementedError + + def load_model(self, model_path, map_location=None): + if map_location is None: + map_location = self.device + dic = torch.load(model_path, map_location=map_location) + for name in list(dic.keys()): + dic[name.replace('module.', '')] = dic.pop(name) + self.load_state_dict(dic) + self.to(self.device) + + def infer_with_function_codes(self, z_sym, prog, z_fp, z_ft): + #z_sym: (batch, 256) + #prog: (batch, track) + #z_fp: (batch, track, 128) + #z_fp: (batch, track, 128) + + z_func = torch.cat([z_fp, z_ft], dim=-1) + z = torch.cat([ z_sym.unsqueeze(1), #(batch, 1, 256) + z_func + self.prog_embedding(prog)], + dim=1) #z: (batch, track+1, 256)""" + + for idx in range(self.trf_layers): + z = self.Transformer_layers[f'layer_{idx}'](src=z) + + z = z[:, 1:].reshape(-1, 256) + + mu = self.trf_mu(z) + var = self.trf_var(z).exp_() + dist_trf = Normal(mu, var) + z = dist_trf.mean + + recon_feat = self.feat_dec(z, True, 0, None) + feat_emb = self.feat_emb_layer(recon_feat) + + # prepare the teacher-forcing data for pianotree decoder + embedded_pno_tree = None + pno_tree_lgths = None + + # pianotree decoder + recon_pitch, recon_dur = \ + self.pianotree_dec(z, True, embedded_pno_tree, pno_tree_lgths, 0, 0, feat_emb) + + recon_pitch = recon_pitch.reshape(*list(prog.shape), 32, 15, 130) + recon_dur = recon_dur.reshape(*list(prog.shape), 32, 15, 5, 2) + return recon_pitch, recon_dur + + + def inference(self, audio, chord, sym_prompt=None): + """ + Forward path during inference. By default, symbolic source is not used. + :param audio: (B, 229, 153) audio input. + Log mel-spectrogram. (n_mels=229) + :param chord: (B, 8, 36) chord input + :param sym_prompt: (B, 32, 128) symbolic prompt. + By default, None. + :return: pianotree prediction (B, 32, 15, 6) numpy array. + """ + + self.eval() + with torch.no_grad(): + z_chd = self.chord_enc(chord).mean + z_aud = self.audio_enc(audio).mean + + z_sym = \ + torch.zeros(z_aud.size(0), self.z_sym_dim, + dtype=z_aud.dtype, device=z_aud.device) \ + if sym_prompt is None else self.prmat_enc(sym_prompt).mean + + z = torch.cat([z_chd, z_aud, z_sym], -1) + + recon_feat = self.feat_dec(z_aud, True, 0., None) + feat_emb = self.feat_emb_layer(recon_feat) + recon_pitch, recon_dur = \ + self.pianotree_dec(z, True, None, None, 0., 0., feat_emb) + + # convert to (argmax) pianotree format, numpy array. 
+ pred = self.pianotree_dec.output_to_numpy(recon_pitch.cpu(), + recon_dur.cpu())[0] + return pred diff --git a/orchestrator/TransformerEncoderLayer.py b/orchestrator/TransformerEncoderLayer.py new file mode 100644 index 0000000..5803143 --- /dev/null +++ b/orchestrator/TransformerEncoderLayer.py @@ -0,0 +1,114 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.normalization import LayerNorm + + +class MultiheadSelfAttentionwithRelativePositionalEmbedding(nn.Module): + def __init__(self, dmodel, num_heads, dropout=0, max_len=1024): + super(MultiheadSelfAttentionwithRelativePositionalEmbedding, self).__init__() + self.max_len = max_len + self.num_heads = num_heads + self.head_dim = dmodel // num_heads + assert self.head_dim * num_heads == dmodel, "embed_dim must be divisible by num_heads" + + self.key = nn.Linear(dmodel, dmodel) + self.value = nn.Linear(dmodel, dmodel) + self.query = nn.Linear(dmodel, dmodel) + self.dropout = nn.Dropout(dropout) + self.Er = nn.Parameter(torch.randn(num_heads, (2*max_len-1) + 2, self.head_dim)) + + def forward(self, query, key, value, attn_mask=None, key_padding_mask=None): + #x: (batch, len, dmodel) + #Srel: (num_head, src_len, src_len) + #key_padding_mask: (batch, src_len), bool tensor + #attn_mask: (batch, num_head, src_len, src_len): float tensor + bs, src_len, d_model = query.shape + #_, src_len, _ = key.shape + + q = self.query(query).reshape(bs, src_len, self.num_heads, self.head_dim).transpose(1, 2) #(batch, num_head, src_len, head_dim) + k = self.key(key).reshape(bs, src_len, self.num_heads, self.head_dim).permute(0, 2, 3, 1) #(batch, num_head, head_dim, src_len) + v = self.value(value).reshape(bs, src_len, self.num_heads, self.head_dim).transpose(1, 2) #(batch, num_head, src_len, head_dim) + + Er_t = self.Er[:, max(0, self.max_len-src_len): min(2*self.max_len-1, self.max_len+src_len-1), :] + if src_len > self.max_len: + Er_t = torch.cat([ + self.Er[:, -2, :].unsqueeze(1).repeat(1, src_len-self.max_len, 1), + Er_t, + self.Er[:, -1, :].unsqueeze(1).repeat(1, src_len-self.max_len, 1) + ], dim=1) + Er_t = Er_t.transpose(-2, -1) #(num_head, head_dim, 2*src_len-1) + + QEr = torch.matmul(q, Er_t) #(batch, num_head, src_len, 2*src_len-1) + Srel = self.skew(QEr) #(batch, num_head, src_len, src_len) + + if key_padding_mask is not None: + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(key_padding_mask.reshape(bs, 1, 1, src_len), float("-inf")) + else: + attn_mask = torch.zeros(bs, 1, 1, src_len, dtype=torch.float).to(key_padding_mask.device) + attn_mask = attn_mask.masked_fill(key_padding_mask.reshape(bs, 1, 1, src_len), float("-inf")) + + attn = (torch.matmul(q, k) + Srel) / math.sqrt(self.head_dim) #(batch, num_head, tgt_len, src_len) + + if attn_mask is not None: + attn += attn_mask + attn = F.softmax(attn, dim=-1) + + out = torch.matmul(attn, v) #(batch, num_head, tgt_len, head_dim) + out = out.transpose(1, 2).reshape(bs, src_len, d_model) #(batch, tgt_len, d_model) + return self.dropout(out), attn + + + def skew(self, QEr): + #QEr: (batch, num_heads, src_len, 2*src_len-1) + bs, num_heads, src_len, L = QEr.shape + QEr = F.pad(QEr, (0, 1)) #(batch, num_heads, src_len, L+1) + QEr = QEr.reshape(bs, num_heads, -1) #(batch, num_heads, src_len*(L+1)) + QEr = F.pad(QEr, (0, L-src_len)) #(batch, num_heads, (src_len+1)*L) + QEr = QEr.reshape(bs, num_heads, src_len+1, L) + QEr = QEr[:, :, :src_len, -src_len:] #(batch, num_heads, src_len, src_len) + return QEr + + + +class 
TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5, norm_first=False, max_len=1024): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadSelfAttentionwithRelativePositionalEmbedding(d_model, nhead, dropout, max_len) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = F.gelu + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + #src: (batch, len, dmodel) + #key_padding_mask: (batch, src_len), bool tensor + #attn_mask: (batch, num_head, src_len, src_len): float tensor + x = src + if self.norm_first: + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + return x + + # self-attention block + def _sa_block(self, x, attn_mask=None, key_padding_mask=None): + x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) diff --git a/orchestrator/__init__.py b/orchestrator/__init__.py new file mode 100644 index 0000000..7f0c180 --- /dev/null +++ b/orchestrator/__init__.py @@ -0,0 +1,2 @@ +from .dataset import Slakh_Dataset, collate_fn, compute_pr_feat, EMBED_PROGRAM_MAPPING +from .Prior import Prior \ No newline at end of file diff --git a/orchestrator/dataset.py b/orchestrator/dataset.py new file mode 100644 index 0000000..90b1e71 --- /dev/null +++ b/orchestrator/dataset.py @@ -0,0 +1,286 @@ +import os +import numpy as np +import pretty_midi as pyd +from torch.utils.data import Dataset +from tqdm import tqdm +import torch +import pandas as pd +from .utils import retrieve_control + + +ACC = 16 +SAMPLE_LEN = 2 * 16 +BAR_HOP_LEN = 1 +AUG_P = np.array([2, 2, 5, 5, 3, 7, 7, 5, 7, 3, 5, 1]) +NUM_INSTR_CLASS = 34 +NUM_PITCH_CODE = 64 +NUM_TIME_CODE = 128 +TOTAL_LEN_BIN = np.array([4, 7, 12, 15, 20, 23, 28, 31, 36, 39, 44, 47, 52, 55, 60, 63, 68, 71, 76, 79, 84, 87, 92, 95, 100, 103, 108, 111, 116, 119, 124, 127, 132]) +ABS_POS_BIN = np.arange(129) +REL_POS_BIN = np.arange(128) + +SLAKH_CLASS_PROGRAMS = dict({ + 0: 'Acoustic Piano', #0 + 4: 'Electric Piano', #1 + 8: 'Chromatic Percussion',#2 + 16: 'Organ', #3 + 24: 'Acoustic Guitar', #4 + 26: 'Clean Electric Guitar', #5 + 29: 'Distorted Electric Guitar', #6 + 32: 'Acoustic Bass', #7 + 33: 'Electric Bass', #8 + 40: 'Violin', #9 + 41: 'Viola', #10 + 42: 'Cello', #11 + 43: 'Contrabass', #12 + 46: 'Orchestral Harp', #13 + 47: 'Timpani', #14 + 48: 'String Ensemble', #15 + 50: 'Synth Strings', #16 + 52: 'Choir and Voice', #17 + 55: 'Orchestral Hit', #18 + 56: 'Trumpet', #19 + 57: 'Trombone', #20 + 58: 'Tuba', #21 + 60: 'French Horn', #22 + 61: 'Brass Section', #23 + 64: 'Soprano/Alto Sax', #24 + 66: 'Tenor Sax', #25 + 67: 'Baritone Sax', #26 + 68: 'Oboe', #27 + 69: 'English Horn', #28 + 70: 'Bassoon', #29 + 71: 'Clarinet', #30 + 72: 'Pipe', #31 + 80: 'Synth Lead', #32 + 88: 'Synth Pad' #33 +}) + 
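+# The two mappings below reduce General MIDI programs to the 34 Slakh
+# instrument classes above and then to contiguous embedding indices. As an
+# illustrative sketch: GM program 30 (Distortion Guitar) -> class 29
+# ('Distorted Electric Guitar') -> embedding index 6, and GM program 48
+# (String Ensemble 1) -> class 48 -> embedding index 15.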
+SLAKH_PROGRAM_MAPPING = dict({0: 0, 1: 0, 2: 0, 3: 0, 4: 4, 5: 4, 6: 4, 7: 4,\ + 8: 8, 9: 8, 10: 8, 11: 8, 12: 8, 13: 8, 14: 8, 15: 8,\ + 16: 16, 17: 16, 18: 16, 19: 16, 20: 16, 21: 16, 22: 16, 23: 16,\ + 24: 24, 25: 24, 26: 26, 27: 26, 28: 26, 29: 29, 30: 29, 31: 29,\ + 32: 32, 33: 33, 34: 33, 35: 33, 36: 33, 37: 33, 38: 33, 39: 33,\ + 40: 40, 41: 41, 42: 42, 43: 43, 44: 43, 45: 43, 46: 46, 47: 47,\ + 48: 48, 49: 48, 50: 50, 51: 50, 52: 52, 53: 52, 54: 52, 55: 55,\ + 56: 56, 57: 57, 58: 58, 59: 58, 60: 60, 61: 61, 62: 61, 63: 61,\ + 64: 64, 65: 64, 66: 66, 67: 67, 68: 68, 69: 69, 70: 70, 71: 71,\ + 72: 72, 73: 72, 74: 72, 75: 72, 76: 72, 77: 72, 78: 72, 79: 72,\ + 80: 80, 81: 80, 82: 80, 83: 80, 84: 80, 85: 80, 86: 80, 87: 80,\ + 88: 88, 89: 88, 90: 88, 91: 88, 92: 88, 93: 88, 94: 88, 95: 88}) + +EMBED_PROGRAM_MAPPING = dict({ + 0: 0, 4: 1, 8: 2, 16: 3, 24: 4, 26: 5, 29: 6, 32: 7,\ + 33: 8, 40: 9, 41: 10, 42: 11, 43: 12, 46: 13, 47: 14, 48: 15,\ + 50: 16, 52: 17, 55: 18, 56: 19, 57: 20, 58: 21, 60: 22, 61: 23, + 64: 24, 66: 25, 67: 26, 68: 27, 69: 28, 70: 29, 71: 30, 72: 31,\ + 80: 32, 88: 33}) + + +class Slakh_Dataset(Dataset): + def __init__(self, slakh_dir, debug_mode=False, split='train', mode='train'): + super(Slakh_Dataset, self).__init__() + self.slakh_dir = slakh_dir + self.split = split + self.mode = mode + self.debug_mode = debug_mode + self.pr_list = [] + self.program_list = [] + self.anchor_list = [] + + self.load_slakh() + + def __len__(self): + return len(self.anchor_list) + + def __getitem__(self, idx): + song_id, start, total_len = self.anchor_list[idx] + pr = self.pr_list[song_id][:, start: min(total_len, start+SAMPLE_LEN)] + prog = self.program_list[song_id] + return pr, prog, (start, total_len) + + + def slakh_program_mapping(self, programs): + return np.array([SLAKH_PROGRAM_MAPPING[program] for program in programs]) + + + def load_slakh(self): + if self.split == 'inference': + slakh_list = [] + slakh_list += os.listdir(os.path.join(self.slakh_dir, 'validation')) + slakh_list += os.listdir(os.path.join(self.slakh_dir, 'test')) + else: + slakh_list = os.listdir(os.path.join(self.slakh_dir, self.split)) + if self.debug_mode: + slakh_list = slakh_list[: 10] + for song in slakh_list: + if self.split == 'inference': + if song in os.listdir(os.path.join(self.slakh_dir, 'validation')): + slakh_data = np.load(os.path.join(self.slakh_dir, 'validation', song)) + elif song in os.listdir(os.path.join(self.slakh_dir, 'test')): + slakh_data = np.load(os.path.join(self.slakh_dir, 'test', song)) + else: + slakh_data = np.load(os.path.join(self.slakh_dir, self.split, song)) + tracks = slakh_data['tracks'] #(n_track, time, 128) + programs = slakh_data['programs'] #(n_track, ) + db_indicator = slakh_data['db_indicator'] #(time, ) + """padding""" + num_bars = int(np.ceil(tracks.shape[1] / 16)) + if ((num_bars) % 2) == 1: #pad zero so that each piece has a even number of bars (four beats per bar) + pad_len = (num_bars + 1) * 16 - tracks.shape[1] + else: + pad_len = num_bars * 16 - tracks.shape[1] + if pad_len != 0: + tracks = np.pad(tracks, ((0, 0), (0, pad_len), (0, 0)), mode='constant', constant_values=(0,)) + + center_pitch = compute_center_pitch(tracks) + pitch_sort = np.argsort(center_pitch)[::-1] + tracks = tracks[pitch_sort] + programs = programs[pitch_sort] + + """clipping""" + db_indices = np.nonzero(db_indicator)[0] + if self.split == 'train': + for i in range(0, len(db_indices), BAR_HOP_LEN): + if db_indices[i] + SAMPLE_LEN >= tracks.shape[1]: + break + 
self.anchor_list.append((len(self.pr_list), db_indices[i], tracks.shape[1])) #(song_id, start, total_length) + else: + for i in range(0, tracks.shape[1], SAMPLE_LEN): + if i + SAMPLE_LEN >= tracks.shape[1]: + break + self.anchor_list.append((len(self.pr_list), i, tracks.shape[1])) #(song_id, start, total_length) + self.anchor_list.append((len(self.pr_list), max(0, (tracks.shape[1]-SAMPLE_LEN)), tracks.shape[1])) + + program_classes = self.slakh_program_mapping(programs) + prog_sample = np.array([EMBED_PROGRAM_MAPPING[prog] for prog in program_classes]) + self.program_list.append(prog_sample) + self.pr_list.append(tracks) + + +def collate_fn(batch, device, pitch_shift=True, get_pr_gt=False): + max_dur = max([item[0].shape[1]//32 for item in batch]) + max_tracks = max([len(item[1]) for item in batch]) + + grid_flatten_batch = [] + prog_batch = [] + time_mask = [] + track_mask = [] + func_pitch_batch = [] + func_time_batch = [] + total_length = [] + abs_pos = [] + rel_pos = [] + pr_batch = [] + + if pitch_shift: + aug_p = AUG_P / AUG_P.sum() + aug_shift = np.random.choice(np.arange(-6, 6), 1, p=aug_p)[0] + else: + aug_shift = 0 + + for pr, prog, (start, total_len) in batch: + time_mask.append([0]*(pr.shape[1]//32) + [1]*(max_dur-pr.shape[1]//32)) + track_mask.append([0]*len(prog) + [1]*(max_tracks-len(prog))) + + r_pos = np.round(np.arange(start//32, (start+pr.shape[1])//32, 1) / (total_len//32-1) * len(REL_POS_BIN)) + total_len = np.argmin(np.abs(TOTAL_LEN_BIN - total_len//32)).repeat(pr.shape[1]//32) + if start//32 <= ABS_POS_BIN[-2]: + a_pos = np.append(ABS_POS_BIN[start//32: min(ABS_POS_BIN[-1], (start+pr.shape[1])//32)], [ABS_POS_BIN[-1]] * ((start+pr.shape[1])//32-ABS_POS_BIN[-1])) + else: + a_pos = np.array([ABS_POS_BIN[-1]] * (pr.shape[1]//32)) + + + pr = pr_mat_pitch_shift(pr, aug_shift) + func_pitch, func_time = compute_pr_feat(pr.reshape(pr.shape[0], -1, 32, pr.shape[-1])) + func_pitch = func_pitch.transpose(1, 0, 2) + func_time = func_time.transpose(1, 0, 2) + if len(prog) < max_tracks: + pr = np.pad(pr, ((0, max_tracks-len(prog)), (0, 0), (0, 0)), mode='constant', constant_values=(0,)) + prog = np.pad(prog, ((0, max_tracks-len(prog))), mode='constant', constant_values=(NUM_INSTR_CLASS,)) + func_pitch = np.pad(func_pitch, ((0, 0), (0, max_tracks-func_pitch.shape[1]), (0, 0)), mode='constant', constant_values=(0,)) + func_time = np.pad(func_time, ((0, 0), (0, max_tracks-func_time.shape[1]), (0, 0)), mode='constant', constant_values=(0,)) + + if pr.shape[1]//32 < max_dur: + pr = np.pad(pr, ((0, 0), (0, max_dur*32-pr.shape[1]), (0, 0)), mode='constant', constant_values=(0,)) + total_len = np.pad(total_len, (0, max_dur-pr.shape[1]//32), mode='constant', constant_values=(len(TOTAL_LEN_BIN),)) + a_pos = np.pad(a_pos, (0, max_dur-len(a_pos)), mode='constant', constant_values=(len(ABS_POS_BIN),)) + r_pos = np.pad(r_pos, (0, max_dur-len(r_pos)), mode='constant', constant_values=(len(REL_POS_BIN),)) + func_pitch = np.pad(func_pitch, ((0, max_dur-len(func_pitch)), (0, 0)), mode='constant', constant_values=(0,)) + func_time = np.pad(func_time, ((0, max_dur-len(func_time)), (0, 0), (0, 0)), mode='constant', constant_values=(0,)) + + #print('pr', pr.shape, 'prog', prog.shape, 'fp', func_pitch.shape, 'ft', func_time.shape) + grid_flatten = pr2grid(np.max(pr, axis=0), max_note_count=32).reshape(-1, 32, 32, 6) + grid_flatten_batch.append(grid_flatten) + prog_batch.append(prog) + func_pitch_batch.append(func_pitch) + func_time_batch.append(func_time) + total_length.append(total_len) + 
abs_pos.append(a_pos) + rel_pos.append(r_pos) + pr_batch.append(pr) + + if get_pr_gt: + return torch.from_numpy(np.array(pr_batch)).long().to(device), \ + torch.from_numpy(np.array(grid_flatten_batch)).long().to(device), \ + torch.from_numpy(np.array(prog_batch)).to(device), \ + torch.from_numpy(np.array(func_pitch_batch)).float().to(device), \ + torch.from_numpy(np.array(func_time_batch)).float().to(device), \ + torch.BoolTensor(time_mask).to(device), \ + torch.BoolTensor(track_mask).to(device), \ + torch.from_numpy(np.array(total_length)).long().to(device),\ + torch.from_numpy(np.array(abs_pos)).long().to(device),\ + torch.from_numpy(np.array(rel_pos)).long().to(device) + else: + return torch.from_numpy(np.array(grid_flatten_batch)).long().to(device), \ + torch.from_numpy(np.array(prog_batch)).to(device), \ + torch.from_numpy(np.array(func_pitch_batch)).float().to(device), \ + torch.from_numpy(np.array(func_time_batch)).float().to(device), \ + torch.BoolTensor(time_mask).to(device), \ + torch.BoolTensor(track_mask).to(device), \ + torch.from_numpy(np.array(total_length)).long().to(device),\ + torch.from_numpy(np.array(abs_pos)).long().to(device),\ + torch.from_numpy(np.array(rel_pos)).long().to(device) + + + +def pr_mat_pitch_shift(pr_mat, shift): + pr_mat = pr_mat.copy() + pr_mat = np.roll(pr_mat, shift, -1) + return pr_mat + + +def pr2grid(pr_mat, max_note_count=16, max_pitch=127, min_pitch=0, + pitch_pad_ind=130, dur_pad_ind=2, + pitch_sos_ind=128, pitch_eos_ind=129): + grid = np.ones((SAMPLE_LEN, max_note_count, 6), dtype=int) * dur_pad_ind + grid[:, :, 0] = pitch_pad_ind + grid[:, 0, 0] = pitch_sos_ind + cur_idx = np.ones(SAMPLE_LEN, dtype=int) + for t, p in zip(*np.where(pr_mat != 0)): + if cur_idx[t] == max_note_count - 1: + continue + grid[t, cur_idx[t], 0] = p - min_pitch + binary = np.binary_repr(min(int(pr_mat[t, p]), 32) - 1, width=5) + grid[t, cur_idx[t], 1: 6] = \ + np.fromstring(' '.join(list(binary)), dtype=int, sep=' ') + cur_idx[t] += 1 + grid[np.arange(0, SAMPLE_LEN), cur_idx, 0] = pitch_eos_ind + return grid + + +def compute_pr_feat(pr): + #pr: (track, time, 128) + onset = (np.sum(pr, axis=-1) > 0) * 1. 
#(track, time) + func_time = np.clip(np.sum((pr > 0) * 1., axis=-1) / 14, a_min=None, a_max=1) #(track, time) + func_pitch = np.sum((pr > 0) * 1., axis=-2) / 32 + + return func_pitch, func_time + +def compute_center_pitch(pr): + #pr: (track, time, 128) + #pr[pr > 0] = 1 + weight = np.sum(pr, axis=(-2, -1)) + weight[weight == 0] = 1 + pitch_center = np.sum(np.arange(0, 128)[np.newaxis, np.newaxis, :] * pr, axis=(-2, -1)) / weight + return pitch_center #(track, ) \ No newline at end of file diff --git a/orchestrator/dl_modules/__init__.py b/orchestrator/dl_modules/__init__.py new file mode 100644 index 0000000..0b6a819 --- /dev/null +++ b/orchestrator/dl_modules/__init__.py @@ -0,0 +1,5 @@ +from .pianotree_dec import PianoTreeDecoder +from .pr_mat_txt_enc import TextureEncoder +from .feat_decoder import FeatDecoder, AdaptFeatDecoder +from .pianotree_enc import PtvaeEncoder +from .vqvae import VectorQuantizerEMA, VectorQuantizer diff --git a/orchestrator/dl_modules/feat_decoder.py b/orchestrator/dl_modules/feat_decoder.py new file mode 100644 index 0000000..07e35e3 --- /dev/null +++ b/orchestrator/dl_modules/feat_decoder.py @@ -0,0 +1,160 @@ +import torch +from torch import nn +import random + + +class FeatDecoder(nn.Module): + + def __init__(self, z_input_dim=128, + hidden_dim=1024, z_dim=512, n_step=32, output_dim=3): + super(FeatDecoder, self).__init__() + self.z2dec_hid = nn.Linear(z_dim, hidden_dim) + self.z2dec_in = nn.Linear(z_dim, z_input_dim) + self.gru = nn.GRU(output_dim + z_input_dim, hidden_dim, + batch_first=True, + bidirectional=False) + self.init_input = nn.Parameter(torch.rand(output_dim)) + self.hidden_dim = hidden_dim + self.z_dim = z_dim + + self.out = nn.Linear(hidden_dim, output_dim) + + self.sigmoid = nn.Sigmoid() + self.output_dim = output_dim + self.n_step = n_step + self.bce_func = nn.BCELoss() + self.mse_func = nn.MSELoss() + + def forward(self, z, inference, tfr, gt_feat=None): + + bs = z.size(0) + + z_hid = self.z2dec_hid(z).unsqueeze(0) + + z_in = self.z2dec_in(z).unsqueeze(1) + + if inference: + tfr = 0. 
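+        # Autoregressive GRU decode over n_step frames: each step predicts a 3-d
+        # feature (bass prob., rhythm intensity, rhythm density). With teacher
+        # forcing the ground-truth frame is fed back; otherwise the prediction
+        # (rhythm channel thresholded at 0.5) becomes the next input token.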
+ + token = self.init_input.repeat(bs, 1).unsqueeze(1) + + out_feats = [] + + for t in range(self.n_step): + y_t, z_hid = \ + self.gru(torch.cat([token, z_in], dim=-1), z_hid) + + out_feat = self.out(y_t) # (bs, 1, 3) + + bass_pred = self.sigmoid(out_feat[:, :, 0]) + rhy_pred = self.sigmoid(out_feat[:, :, 2]) + rhy_int = out_feat[:, :, 1] + + out_feats.append(torch.stack([bass_pred, rhy_int, rhy_pred], -1)) + + # prepare the input to the next step + if t == self.n_step - 1: + break + teacher_force = random.random() < tfr + if teacher_force and not inference: + token = gt_feat[:, t].unsqueeze(1) + else: + t_bass = bass_pred + t_rhy = rhy_pred > 0.5 + t_int = rhy_int + token = torch.stack([t_bass, t_int, t_rhy], -1) + + recon = torch.cat(out_feats, dim=1) + return recon + + def recon_loss(self, gt_feat, recon_feat): + recon_bass = recon_feat[:, :, 0] + recon_int = recon_feat[:, :, 1] + recon_rhy = recon_feat[:, :, 2] + + bass_loss = self.bce_func(recon_bass, gt_feat[:, :, 0]) + int_loss = self.mse_func(recon_int, gt_feat[:, :, 1]) + rhy_loss = self.bce_func(recon_rhy, gt_feat[:, :, 2]) + + loss = bass_loss + int_loss + rhy_loss + + return loss, bass_loss, int_loss, rhy_loss + + + +class AdaptFeatDecoder(nn.Module): + + def __init__(self, z_input_dim=128, + hidden_dim=1024, z_dim=512, n_step=32, output_dim=3): + super(AdaptFeatDecoder, self).__init__() + self.z2dec_hid = nn.Linear(z_dim, hidden_dim) + self.z2dec_in = nn.Linear(z_dim, z_input_dim) + self.gru = nn.GRU(output_dim + z_input_dim, hidden_dim, + batch_first=True, + bidirectional=False) + self.init_input = nn.Parameter(torch.rand(output_dim)) + self.hidden_dim = hidden_dim + self.z_dim = z_dim + + self.out = nn.Linear(hidden_dim, output_dim) + + self.sigmoid = nn.Sigmoid() + self.output_dim = output_dim + self.n_step = n_step + self.bce_func = nn.BCELoss() + self.mse_func = nn.MSELoss() + + def forward(self, z, inference, tfr, gt_feat=None): + + bs = z.size(0) + + z_hid = self.z2dec_hid(z).unsqueeze(0) + + z_in = self.z2dec_in(z).unsqueeze(1) + + if inference: + tfr = 0. 
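+        # Same autoregressive scheme as FeatDecoder, but the 3-d feature here is
+        # (onset prob., rhythm intensity, pitch center); when not teacher
+        # forcing, the onset channel is thresholded at 0.5 before being fed back.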
+ + token = self.init_input.repeat(bs, 1).unsqueeze(1) + + out_feats = [] + + for t in range(self.n_step): + y_t, z_hid = \ + self.gru(torch.cat([token, z_in], dim=-1), z_hid) + + out_feat = self.out(y_t) # (bs, 1, 3) + + onset_pred = self.sigmoid(out_feat[:, :, 0]) + rhy_int = out_feat[:, :, 1] + pitch_center = out_feat[:, :, 2] + + out_feats.append(torch.stack([onset_pred, rhy_int, pitch_center], -1)) + + # prepare the input to the next step + if t == self.n_step - 1: + break + teacher_force = random.random() < tfr + if teacher_force and not inference: + token = gt_feat[:, t].unsqueeze(1) + else: + t_onset = onset_pred > 0.5 + t_int = rhy_int + t_pitch = pitch_center + token = torch.stack([t_onset, t_int, t_pitch], -1) + + recon = torch.cat(out_feats, dim=1) + return recon + + def recon_loss(self, gt_feat, recon_feat): + recon_onset = recon_feat[:, :, 0] + recon_int = recon_feat[:, :, 1] + recon_pitch = recon_feat[:, :, 2] + + onset_loss = self.bce_func(recon_onset, gt_feat[:, :, 0]) + int_loss = self.mse_func(recon_int, gt_feat[:, :, 1]) + pitch_loss = self.mse_func(recon_pitch, gt_feat[:, :, 2]) + + loss = onset_loss + int_loss + pitch_loss + + return loss, onset_loss, int_loss, pitch_loss diff --git a/orchestrator/dl_modules/pianotree_dec.py b/orchestrator/dl_modules/pianotree_dec.py new file mode 100644 index 0000000..9b35902 --- /dev/null +++ b/orchestrator/dl_modules/pianotree_dec.py @@ -0,0 +1,364 @@ +from torch import nn +import torch +import random +from torch.nn.utils.rnn import pack_padded_sequence +import pretty_midi +import numpy as np + + +class PianoTreeDecoder(nn.Module): + + def __init__(self, device=None, note_embedding=None, + max_simu_note=16, max_pitch=127, min_pitch=0, + pitch_sos=128, pitch_eos=129, pitch_pad=130, + dur_pad=2, dur_width=5, num_step=32, + note_emb_size=128, z_size=512, + dec_emb_hid_size=128, + dec_time_hid_size=1024, dec_notes_hid_size=512, + dec_z_in_size=256, dec_dur_hid_size=16, feat_emb_dim=0): + """ + feat_emb_dim: additional dimension for symbolic features. + """ + super(PianoTreeDecoder, self).__init__() + + # Parameters + # note and time + self.max_pitch = max_pitch # the highest pitch in train/val set. + self.min_pitch = min_pitch # the lowest pitch in train/val set. + self.pitch_sos = pitch_sos + self.pitch_eos = pitch_eos + self.pitch_pad = pitch_pad + self.pitch_range = max_pitch - min_pitch + 3 # 88, not including pad. + self.dur_pad = dur_pad + self.dur_width = dur_width + self.note_size = self.pitch_range + dur_width + self.max_simu_note = max_simu_note # the max # of notes at each ts. 
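+        # A note token is a 130-way pitch one-hot (128 MIDI pitches plus
+        # SOS/EOS; PAD = 130 lies outside pitch_range) concatenated with a
+        # 5-bit binary duration code (durations of 1-32 sixteenth-note steps).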
+ self.num_step = num_step # 32 + + # device + self.device = device#\ + # torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.note_emb_size = note_emb_size + self.z_size = z_size + + # decoder + self.dec_z_in_size = dec_z_in_size + self.dec_emb_hid_size = dec_emb_hid_size + self.dec_time_hid_size = dec_time_hid_size + self.dec_init_input = \ + nn.Parameter(torch.rand(2 * self.dec_emb_hid_size)) + self.dec_notes_hid_size = dec_notes_hid_size + self.dur_sos_token = nn.Parameter(torch.rand(self.dur_width)) + self.dec_dur_hid_size = dec_dur_hid_size + + # Modules + # For both encoder and decoder + if note_embedding is None: + self.note_embedding = nn.Linear(self.note_size, note_emb_size) + else: + self.note_embedding = note_embedding + self.z2dec_hid_linear = nn.Linear(self.z_size, dec_time_hid_size) + self.z2dec_in_linear = nn.Linear(self.z_size, dec_z_in_size) + self.dec_notes_emb_gru = nn.GRU(note_emb_size, dec_emb_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.dec_time_gru = \ + nn.GRU(dec_z_in_size + 2 * dec_emb_hid_size + feat_emb_dim, + dec_time_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.dec_time_to_notes_hid = nn.Linear(dec_time_hid_size, + dec_notes_hid_size) + self.dec_notes_gru = nn.GRU(dec_time_hid_size + note_emb_size, + dec_notes_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.pitch_out_linear = nn.Linear(dec_notes_hid_size, self.pitch_range) + self.dec_dur_gru = nn.GRU(dur_width, dec_dur_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.dur_hid_linear = nn.Linear(self.pitch_range + dec_notes_hid_size, + dec_dur_hid_size) + self.dur_out_linear = nn.Linear(dec_notes_hid_size, self.dur_width * 2) + + def get_len_index_tensor(self, ind_x): + """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid.""" + with torch.no_grad(): + lengths = self.max_simu_note - \ + (ind_x[:, :, :, 0] - self.pitch_pad == 0).sum(dim=-1) + return lengths + + def index_tensor_to_multihot_tensor(self, ind_x): + """Transfer piano_grid to multi-hot piano_grid.""" + # ind_x: (B, 32, max_simu_note, 1 + dur_width) + with torch.no_grad(): + dur_part = ind_x[:, :, :, 1:].float() + out = torch.zeros( + [ind_x.size(0) * self.num_step * self.max_simu_note, + self.pitch_range + 1], + dtype=torch.float).to(ind_x.device) + + out[range(0, out.size(0)), ind_x[:, :, :, 0].view(-1)] = 1. + out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1) + out = torch.cat([out[:, :, :, 0: self.pitch_range], dur_part], + dim=-1) + return out + + def get_sos_token(self): + sos = torch.zeros(self.note_size) + sos[self.pitch_sos] = 1. + sos[self.pitch_range:] = 2. + sos = sos.to(self.device) + return sos + + def dur_ind_to_dur_token(self, inds, batch_size): + token = torch.zeros(batch_size, self.dur_width) + token[range(0, batch_size), inds] = 1. + token = token.to(inds.device) + return token + + def pitch_dur_ind_to_note_token(self, pitch_inds, dur_inds, batch_size): + token = torch.zeros(batch_size, self.note_size) + token[range(0, batch_size), pitch_inds] = 1. + token[:, self.pitch_range:] = dur_inds + token = token.to(pitch_inds.device) + token = self.note_embedding(token) + return token + + def decode_note(self, note_summary, batch_size): + # note_summary: (B, 1, dec_notes_hid_size) + # This function estimate pitch, and dur for a single pitch based on + # note_summary. 
+ # Returns: est_pitch (B, 1, pitch_range), est_durs (B, 1, dur_width, 2) + + # The estimated pitch is calculated by a linear layer. + est_pitch = self.pitch_out_linear(note_summary).squeeze(1) + # est_pitch: (B, pitch_range) + + # Unlike the original PianoTree implementation, the duration is + # computed simply by a linear layer. + est_durs = self.dur_out_linear(note_summary).reshape(batch_size, + self.dur_width, 2) + + return est_pitch, est_durs + + def decode_notes(self, notes_summary, batch_size, notes, inference, + teacher_forcing_ratio=0.5): + # notes_summary: (B, 1, dec_time_hid_size) + # notes: (B, max_simu_note, note_emb_size), ground_truth + notes_summary_hid = \ + self.dec_time_to_notes_hid(notes_summary.transpose(0, 1)) + if inference: + assert teacher_forcing_ratio == 0 + assert notes is None + sos = self.get_sos_token().to(notes_summary_hid.device) # (note_size,) + token = self.note_embedding(sos).repeat(batch_size, 1).unsqueeze(1) + # hid: (B, 1, note_emb_size) + else: + token = notes[:, 0].unsqueeze(1) + + predicted_notes = torch.zeros(batch_size, self.max_simu_note, + self.note_emb_size) + predicted_notes[:, :, self.pitch_range:] = 2. + predicted_notes[:, 0] = token.squeeze(1) # fill sos index + lengths = torch.zeros(batch_size) + predicted_notes = predicted_notes.to(notes_summary_hid.device) + + lengths = lengths.to(notes_summary_hid.device) + + pitch_outs = [] + dur_outs = [] + + for t in range(1, self.max_simu_note): + note_summary, notes_summary_hid = \ + self.dec_notes_gru(torch.cat([notes_summary, token], dim=-1), + notes_summary_hid) + # note_summary: (B, 1, dec_notes_hid_size) + # notes_summary_hid: (1, B, dec_time_hid_size) + + est_pitch, est_durs = self.decode_note(note_summary, batch_size) + # est_pitch: (B, pitch_range) + # est_durs: (B, dur_width, 2) + + pitch_outs.append(est_pitch.unsqueeze(1)) + dur_outs.append(est_durs.unsqueeze(1)) + pitch_inds = est_pitch.max(1)[1] + dur_inds = est_durs.max(2)[1] + predicted = self.pitch_dur_ind_to_note_token(pitch_inds, dur_inds, + batch_size) + # predicted: (B, note_size) + + predicted_notes[:, t] = predicted + eos_samp_inds = (pitch_inds == self.pitch_eos) + lengths[eos_samp_inds & (lengths == 0)] = t + + if t == self.max_simu_note - 1: + break + teacher_force = random.random() < teacher_forcing_ratio + if inference or not teacher_force: + token = predicted.unsqueeze(1) + else: + token = notes[:, t].unsqueeze(1) + lengths[lengths == 0] = t + pitch_outs = torch.cat(pitch_outs, dim=1) + dur_outs = torch.cat(dur_outs, dim=1) + return pitch_outs, dur_outs, predicted_notes, lengths + + def decoder(self, z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2, feat=None): + # z: (B, z_size) + # x: (B, num_step, max_simu_note, note_emb_size) + batch_size = z.size(0) + z_hid = self.z2dec_hid_linear(z).unsqueeze(0) + # z_hid: (1, B, dec_time_hid_size) + z_in = self.z2dec_in_linear(z).unsqueeze(1) + # z_in: (B, dec_z_in_size) + + if inference: + assert x is None + assert lengths is None + assert teacher_forcing_ratio1 == 0 + assert teacher_forcing_ratio2 == 0 + else: + x_summarized = x.view(-1, self.max_simu_note, self.note_emb_size) + x_summarized = pack_padded_sequence(x_summarized, + lengths.view(-1).cpu(), + batch_first=True, + enforce_sorted=False) + x_summarized = self.dec_notes_emb_gru(x_summarized)[-1].\ + transpose(0, 1).contiguous() + x_summarized = x_summarized.view(-1, self.num_step, + 2 * self.dec_emb_hid_size) + + pitch_outs = [] + dur_outs = [] + token = self.dec_init_input.repeat(batch_size, 
1).unsqueeze(1) + # (B, 2 * dec_emb_hid_size) + + for t in range(self.num_step): + if feat is not None: + notes_summary, z_hid = \ + self.dec_time_gru( + torch.cat([token, z_in, feat[:, t].unsqueeze(1)], dim=-1), z_hid) + else: + notes_summary, z_hid = \ + self.dec_time_gru( + torch.cat([token, z_in], dim=-1), z_hid) + if inference: + pitch_out, dur_out, predicted_notes, predicted_lengths = \ + self.decode_notes(notes_summary, batch_size, None, + inference, teacher_forcing_ratio2) + else: + pitch_out, dur_out, predicted_notes, predicted_lengths = \ + self.decode_notes(notes_summary, batch_size, x[:, t], + inference, teacher_forcing_ratio2) + pitch_outs.append(pitch_out.unsqueeze(1)) + dur_outs.append(dur_out.unsqueeze(1)) + if t == self.num_step - 1: + break + + teacher_force = random.random() < teacher_forcing_ratio1 + if teacher_force and not inference: + token = x_summarized[:, t].unsqueeze(1) + else: + token = pack_padded_sequence(predicted_notes, + predicted_lengths.cpu(), + batch_first=True, + enforce_sorted=False) + token = self.dec_notes_emb_gru(token)[-1].\ + transpose(0, 1).contiguous() + token = token.view(-1, 2 * self.dec_emb_hid_size).unsqueeze(1) + pitch_outs = torch.cat(pitch_outs, dim=1) + dur_outs = torch.cat(dur_outs, dim=1) + # print(pitch_outs.size()) + # print(dur_outs.size()) + return pitch_outs, dur_outs + + def forward(self, z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2, feat=None): + return self.decoder(z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2, feat) + + def recon_loss(self, x, recon_pitch, recon_dur, weights=(1, 0.5), + weighted_dur=False, reduction='mean'): + bs = x.size(0) + pitch_loss_func = \ + nn.CrossEntropyLoss(ignore_index=self.pitch_pad, reduction=reduction) + recon_pitch = recon_pitch.view(-1, recon_pitch.size(-1)) + gt_pitch = x[:, :, 1:, 0].contiguous().view(-1) + pitch_loss = pitch_loss_func(recon_pitch, gt_pitch) + + dur_loss_func = \ + nn.CrossEntropyLoss(ignore_index=self.dur_pad, reduction=reduction) + if not weighted_dur: + recon_dur = recon_dur.view(-1, 2) + gt_dur = x[:, :, 1:, 1:].contiguous().view(-1) + dur_loss = dur_loss_func(recon_dur, gt_dur) + else: + recon_dur = recon_dur.view(-1, self.dur_width, 2) + gt_dur = x[:, :, 1:, 1:].contiguous().view(-1, self.dur_width) + dur0 = dur_loss_func(recon_dur[:, 0, :], gt_dur[:, 0]) + dur1 = dur_loss_func(recon_dur[:, 1, :], gt_dur[:, 1]) + dur2 = dur_loss_func(recon_dur[:, 2, :], gt_dur[:, 2]) + dur3 = dur_loss_func(recon_dur[:, 3, :], gt_dur[:, 3]) + dur4 = dur_loss_func(recon_dur[:, 4, :], gt_dur[:, 4]) + w = torch.tensor([1, 0.6, 0.4, 0.3, 0.3], + device=recon_dur.device).float() + dur_loss = \ + w[0] * dur0 + \ + w[1] * dur1 + \ + w[2] * dur2 + \ + w[3] * dur3 + \ + w[4] * dur4 + + loss = weights[0] * pitch_loss + weights[1] * dur_loss + + return loss, pitch_loss, dur_loss + + def emb_x(self, x): + lengths = self.get_len_index_tensor(x) + x = self.index_tensor_to_multihot_tensor(x) + embedded = self.note_embedding(x) + return embedded, lengths + + def output_to_numpy(self, recon_pitch, recon_dur): + est_pitch = recon_pitch.max(-1)[1].unsqueeze(-1) # (B, 32, 11, 1) + est_dur = recon_dur.max(-1)[1] # (B, 32, 11, 5) + est_x = torch.cat([est_pitch, est_dur], dim=-1) # (B, 32, 11, 6) + est_x = est_x.cpu().numpy() + recon_pitch = recon_pitch.cpu().numpy() + recon_dur = recon_dur.cpu().numpy() + return est_x, recon_pitch, recon_dur + + def pr_to_notes(self, pr, bpm=80, start=0., one_hot=False): + pr_matrix = self.pr_to_pr_matrix(pr, 
one_hot) + alpha = 0.25 * 60 / bpm + notes = [] + for t in range(32): + for p in range(128): + if pr_matrix[t, p] >= 1: + s = alpha * t + start + e = alpha * (t + pr_matrix[t, p]) + start + notes.append(pretty_midi.Note(100, int(p), s, e)) + return notes + + def grid_to_pr_and_notes(self, grid, bpm=60., start=0., + truncate_dur=False): + if grid.shape[1] == self.max_simu_note: + grid = grid[:, 1:] + pr = np.zeros((32, 128), dtype=int) + alpha = 0.25 * 60 / bpm + notes = [] + for t in range(32): + for n in range(10): + note = grid[t, n] + if note[0] == self.pitch_eos: + break + pitch = note[0] + self.min_pitch + dur = int(''.join([str(_) for _ in note[1:]]), 2) + 1 + pr[t, pitch] = min(dur, 32 - t) if truncate_dur else dur + notes.append( + pretty_midi.Note(100, int(pitch), start + t * alpha, + start + (t + dur) * alpha)) + return pr, notes diff --git a/orchestrator/dl_modules/pianotree_enc.py b/orchestrator/dl_modules/pianotree_enc.py new file mode 100644 index 0000000..23dc066 --- /dev/null +++ b/orchestrator/dl_modules/pianotree_enc.py @@ -0,0 +1,97 @@ +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence +from torch.distributions import Normal + + +class PtvaeEncoder(nn.Module): + + def __init__(self, device, max_simu_note=16, max_pitch=127, min_pitch=0, + pitch_sos=128, pitch_eos=129, pitch_pad=130, + dur_pad=2, dur_width=5, num_step=32, + note_emb_size=128, + enc_notes_hid_size=256, + enc_time_hid_size=512, z_size=512): + super(PtvaeEncoder, self).__init__() + + # Parameters + # note and time + self.max_pitch = max_pitch # the highest pitch in train/val set. + self.min_pitch = min_pitch # the lowest pitch in train/val set. + self.pitch_sos = pitch_sos + self.pitch_eos = pitch_eos + self.pitch_pad = pitch_pad + self.pitch_range = max_pitch - min_pitch + 3 # not including pad. + self.dur_pad = dur_pad + self.dur_width = dur_width + self.note_size = self.pitch_range + dur_width + self.max_simu_note = max_simu_note # the max # of notes at each ts. + self.num_step = num_step # 32 + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.note_emb_size = note_emb_size + self.z_size = z_size + self.enc_notes_hid_size = enc_notes_hid_size + self.enc_time_hid_size = enc_time_hid_size + + self.note_embedding = nn.Linear(self.note_size, note_emb_size) + self.enc_notes_gru = nn.GRU(note_emb_size, enc_notes_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.enc_time_gru = nn.GRU(2 * enc_notes_hid_size, enc_time_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(2 * enc_time_hid_size, z_size) + self.linear_std = nn.Linear(2 * enc_time_hid_size, z_size) + + def get_len_index_tensor(self, ind_x): + """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid.""" + with torch.no_grad(): + lengths = self.max_simu_note - \ + (ind_x[:, :, :, 0] - self.pitch_pad == 0).sum(dim=-1) + return lengths + + def index_tensor_to_multihot_tensor(self, ind_x): + """Transfer piano_grid to multi-hot piano_grid.""" + # ind_x: (B, 32, max_simu_note, 1 + dur_width) + with torch.no_grad(): + dur_part = ind_x[:, :, :, 1:].float() + out = torch.zeros( + [ind_x.size(0) * self.num_step * self.max_simu_note, + self.pitch_range + 1], + dtype=torch.float).to(ind_x.device) + + out[range(0, out.size(0)), ind_x[:, :, :, 0].view(-1)] = 1. 
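+ # With the default token indices, the <pad> pitch (== pitch_range) is
+ # scattered into the extra last column, which the slice below discards,
+ # so padded note slots end up with an all-zero pitch part.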
+ out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1) + out = torch.cat([out[:, :, :, 0: self.pitch_range], dur_part], + dim=-1) + return out + + def encoder(self, x, lengths): + embedded = self.note_embedding(x) + # x: (B, num_step, max_simu_note, note_emb_size) + # now x are notes + x = embedded.view(-1, self.max_simu_note, self.note_emb_size) + x = pack_padded_sequence(x, lengths.view(-1), batch_first=True, + enforce_sorted=False) + x = self.enc_notes_gru(x)[-1].transpose(0, 1).contiguous() + x = x.view(-1, self.num_step, 2 * self.enc_notes_hid_size) + # now, x is simu_notes. + x = self.enc_time_gru(x)[-1].transpose(0, 1).contiguous() + # x: (B, 2, enc_time_hid_size) + x = x.view(x.size(0), -1) + mu = self.linear_mu(x) # (B, z_size) + std = self.linear_std(x).exp_() # (B, z_size) + dist = Normal(mu, std) + return dist, embedded + + def forward(self, x, return_iterators=False): + lengths = self.get_len_index_tensor(x).cpu() + x = self.index_tensor_to_multihot_tensor(x) + dist, embedded_x = self.encoder(x, lengths) + if return_iterators: + return dist.mean, dist.scale, embedded_x + else: + return dist, embedded_x, lengths \ No newline at end of file diff --git a/orchestrator/dl_modules/pr_mat_txt_enc.py b/orchestrator/dl_modules/pr_mat_txt_enc.py new file mode 100644 index 0000000..8820a2f --- /dev/null +++ b/orchestrator/dl_modules/pr_mat_txt_enc.py @@ -0,0 +1,48 @@ +from torch import nn +from torch.distributions import Normal +import torch + + +class TextureEncoder(nn.Module): + + def __init__(self, emb_size=256, hidden_dim=1024, z_dim=256, + num_channel=10, return_h=False): + super(TextureEncoder, self).__init__() + self.cnn = nn.Sequential(nn.Conv2d(1, num_channel, kernel_size=(4, 12), + stride=(4, 1), padding=0), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(1, 4), + stride=(1, 4))) + self.fc1 = nn.Linear(num_channel * 29, 1000) + self.fc2 = nn.Linear(1000, emb_size) + self.gru = nn.GRU(emb_size, hidden_dim, batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) + self.linear_var = nn.Linear(hidden_dim * 2, z_dim) + self.emb_size = emb_size + self.hidden_dim = hidden_dim + self.z_dim = z_dim + self.return_h = return_h + + def forward(self, pr): + # pr: (bs, 32, 128) + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr).permute(0, 2, 1, 3).reshape(bs, 8, -1) + pr_feat = self.fc2(self.fc1(pr)) # (bs, 8, emb_size) + + # hs, pr = self.gru(pr) + pr = self.gru(pr_feat)[-1] + + pr = pr.transpose_(0, 1).contiguous() + pr = pr.view(pr.size(0), -1) + + mu = self.linear_mu(pr) + var = self.linear_var(pr).exp_() + + dist = Normal(mu, var) + + if self.return_h: + return dist, pr_feat + else: + return dist diff --git a/orchestrator/dl_modules/vqvae.py b/orchestrator/dl_modules/vqvae.py new file mode 100644 index 0000000..6e37705 --- /dev/null +++ b/orchestrator/dl_modules/vqvae.py @@ -0,0 +1,201 @@ +import torch +from torch import nn +import torch.nn.functional as F + +class VectorQuantizerEMA(nn.Module): + """ + Discretization bottleneck of VQ-VAE using EMA with random restart. 
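+ Every forward pass adds 1 to the usage of the codes just selected and then
+ halves all usage counters; the codebook itself is EMA-updated in training
+ mode only. A minimal call sketch (variable names illustrative, not from the
+ original code):
+ z_q, commit_loss, perplexity = vq_layer(z, track_pad_mask)
+ # z: (batch * max_track, embedding_dim); track_pad_mask: (batch, max_track)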
+ After certain iterations, run: + random_restart() + reset_usage() + """ + def __init__(self, embedding_dim, num_embeddings, commitment_cost, decay, usage_threshold, epsilon=1e-5, random_start=False): + super(VectorQuantizerEMA, self).__init__() + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.decay = decay + self.commitment_cost = commitment_cost + self.usage_threshold = usage_threshold + self.epsilon = epsilon + self.random_start = random_start + + with torch.no_grad(): + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + #self.embedding.weight.data.normal_() + self.embedding.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings) + self.register_buffer('usage', torch.ones(self.num_embeddings), persistent=False) + self.register_buffer('ema_cluster_size', torch.zeros(self.num_embeddings), persistent=False) + self.register_buffer('ema_w', self.embedding.weight.data.clone(), persistent=False) + + self.perplexity = None + self.loss = None + + def update_usage(self, min_enc): + with torch.no_grad(): + self.usage[min_enc] = self.usage[min_enc] + 1 # if code is used add 1 to usage + self.usage /= 2 # decay all codes usage + + def reset_usage(self): + with torch.no_grad(): + self.usage.zero_() # reset usage between certain numbers of iterations + + def random_restart(self, batch_z=None): + # randomly restart all dead codes below threshold with random code from the codebook + with torch.no_grad(): + mean_usage = torch.mean(self.usage[self.usage >= self.usage_threshold]) + dead_codes = torch.nonzero(self.usage < self.usage_threshold).squeeze(1) + if self.random_start: + if batch_z is None: + rand_codes = torch.randperm(self.num_embeddings)[0:len(dead_codes)] + self.embedding.weight[dead_codes] = self.embedding.weight[rand_codes] + self.ema_w[dead_codes] = self.embedding.weight[rand_codes] + else: + LEN = min(len(dead_codes), len(batch_z)) + rand_codes = torch.randperm(len(batch_z))[0:LEN] + self.embedding.weight[dead_codes[0:LEN]] = batch_z[rand_codes] + self.ema_w[dead_codes[0:LEN]] = batch_z[rand_codes] + return mean_usage, len(dead_codes) + + def forward(self, z, track_pad_mask=None): + #z shape: (batch*max_track, embedding_dim) + #track_pad_mask: (batch, max_track) + assert(z.shape[-1] == self.embedding_dim) + input_shape = z.shape + z = z.reshape(-1, z.shape[-1]) + track_pad_mask = track_pad_mask.reshape(-1) + + distance = torch.sum(z ** 2, dim=1, keepdim=True) \ + + torch.sum(self.embedding.weight ** 2, dim=1) \ + - 2 * torch.matmul(z, self.embedding.weight.t()) #(batch*max_track, num_embeddings) + + min_encoding_indices = torch.argmin(distance, dim=1) #(batch*max_track,) + #print(min_encoding_indices) + min_encodings = torch.zeros(len(min_encoding_indices), self.num_embeddings, device=z.device) + min_encodings.scatter_(1, min_encoding_indices.unsqueeze(1), 1) #(batch*max_track, num_embeddings) + + z_q = torch.matmul(min_encodings, self.embedding.weight) #(batch*max_track, embedding_dim) + + self.update_usage(min_encoding_indices[torch.logical_not(track_pad_mask)]) + + if self.training: + with torch.no_grad(): + self.ema_cluster_size -= (1 - self.decay) * (self.ema_cluster_size - torch.sum(min_encodings[torch.logical_not(track_pad_mask)], dim=0)) + #laplacian smoothing + n = torch.sum(self.ema_cluster_size.data) + self.ema_cluster_size = (self.ema_cluster_size + self.epsilon) * n \ + / (n + self.num_embeddings * self.epsilon) + + dw = torch.matmul(min_encodings[torch.logical_not(track_pad_mask)].t(), 
z[torch.logical_not(track_pad_mask)]) #(num_embeddings, embed_dim) + self.ema_w -= (1-self.decay) * (self.ema_w - dw) + self.embedding.weight.data = self.ema_w / self.ema_cluster_size.unsqueeze(-1) + + e_latent_loss = F.mse_loss(z_q.detach(), z) + loss = self.commitment_cost * e_latent_loss + + quantized = (z + (z_q - z).detach()).reshape(input_shape) + avg_probs = torch.mean(min_encodings[torch.logical_not(track_pad_mask)], dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return quantized, loss, perplexity + + def get_code_indices(self, z): + assert(z.shape[-1] == self.embedding_dim) + input_shape = z.shape + z = z.reshape(-1, z.shape[-1]) + + distance = torch.sum(z ** 2, dim=1, keepdim=True) \ + + torch.sum(self.embedding.weight ** 2, dim=1) \ + - 2 * torch.matmul(z, self.embedding.weight.t()) #(batch*max_track, num_embeddings) + + min_encoding_indices = torch.argmin(distance, dim=1) #(batch*max_track,) + return min_encoding_indices.reshape(input_shape[:-1]) + + def infer_code(self, encoding_indices): + input_shape = encoding_indices.shape + encoding_indices = encoding_indices.reshape(-1) + encodings = torch.zeros(len(encoding_indices), self.num_embeddings, device=encoding_indices.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) #(batch*max_track, num_embeddings) + z_q = torch.matmul(encodings, self.embedding.weight) + return z_q.reshape(*list(input_shape), self.embedding_dim) + + + + +class VectorQuantizer(nn.Module): + """ + Discretization bottleneck of VQ-VAE with random restart. + After certain iterations, run: + random_restart() + reset_usage() + """ + def __init__(self, embedding_dim, num_embeddings, commitment_cost, usage_threshold, random_start=False): + super(VectorQuantizer, self).__init__() + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.commitment_cost = commitment_cost + self.usage_threshold = usage_threshold + self.random_start = random_start + + with torch.no_grad(): + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + #self.embedding.weight.data.normal_() + self.embedding.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings) + self.register_buffer('usage', torch.ones(self.num_embeddings), persistent=False) + + self.perplexity = None + self.loss = None + + def update_usage(self, min_enc): + with torch.no_grad(): + self.usage[min_enc] = self.usage[min_enc] + 1 # if code is used add 1 to usage + self.usage /= 2 # decay all codes usage + + def reset_usage(self): + with torch.no_grad(): + self.usage.zero_() # reset usage between certain numbers of iterations + + def random_restart(self): + # randomly restart all dead codes below threshold with random code from the codebook + with torch.no_grad(): + dead_codes = torch.nonzero(self.usage < self.usage_threshold).squeeze(1) + if self.random_start: + rand_codes = torch.randperm(self.num_embeddings)[0:len(dead_codes)] + self.embedding.weight[dead_codes] = self.embedding.weight[rand_codes] + return len(dead_codes) + + def forward(self, z, track_pad_mask=None): + #z shape: (batch, max_track, embedding_dim) + #track_pad_mask: (batch, max_track) + assert(z.shape[-1] == self.embedding_dim) + input_shape = z.shape + z = z.reshape(-1, z.shape[-1]) + track_pad_mask = track_pad_mask.reshape(-1) + + distance = torch.sum(z ** 2, dim=1, keepdim=True) \ + + torch.sum(self.embedding.weight ** 2, dim=1) \ + - 2 * torch.matmul(z, self.embedding.weight.t()) #(batch*max_track, num_embeddings) + + 
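+ # Nearest-codebook-entry assignment: argmin over the squared L2 distance,
+ # expanded above as ||z||^2 + ||e||^2 - 2 * z @ e^T.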
min_encoding_indices = torch.argmin(distance, dim=1) #(batch*max_track,) + #print(min_encoding_indices) + min_encodings = torch.zeros(len(min_encoding_indices), self.num_embeddings, device=z.device) + min_encodings.scatter_(1, min_encoding_indices.unsqueeze(1), 1) #(batch*max_track, num_embeddings) + + z_q = torch.matmul(min_encodings, self.embedding.weight) #(batch*max_track, embedding_dim) + + self.update_usage(min_encoding_indices[torch.logical_not(track_pad_mask)]) + + e_latent_loss = F.mse_loss(z_q.detach(), z) + q_latent_loss = F.mse_loss(z_q, z.detach()) + loss = q_latent_loss + self.commitment_cost * e_latent_loss + + quantized = (z + (z_q - z).detach()).reshape(input_shape) + avg_probs = torch.mean(min_encodings[torch.logical_not(track_pad_mask)], dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return quantized, loss, perplexity + + + + + + \ No newline at end of file diff --git a/orchestrator/scheduler.py b/orchestrator/scheduler.py new file mode 100644 index 0000000..d5c6522 --- /dev/null +++ b/orchestrator/scheduler.py @@ -0,0 +1,120 @@ +import numpy as np +from utils import scheduled_sampling +from torch.optim.lr_scheduler import ExponentialLR + + +class MinExponentialLR(ExponentialLR): + def __init__(self, optimizer, gamma, minimum, last_epoch=-1): + self.min = minimum + super(MinExponentialLR, self).__init__(optimizer, gamma, last_epoch=-1) + + def get_lr(self): + return [ + max(base_lr * self.gamma ** self.last_epoch, self.min) + for base_lr in self.base_lrs + ] + + +class _Scheduler: + + def __init__(self, step=0, mode='train'): + self._step = step + self._mode = mode + + def _update_step(self): + if self._mode == 'train': + self._step += 1 + elif self._mode == 'val': + pass + else: + raise NotImplementedError + + def step(self): + raise NotImplementedError + + def train(self): + self._mode = 'train' + + def eval(self): + self._mode = 'val' + + +class ConstantScheduler(_Scheduler): + + def __init__(self, param, step=0.): + super(ConstantScheduler, self).__init__(step) + self.param = param + + def step(self, scaler=None): + self._update_step() + return self.param + + +class TeacherForcingScheduler(_Scheduler): + + def __init__(self, high, low, scaler, f=scheduled_sampling, step=0): + super(TeacherForcingScheduler, self).__init__(step) + self.high = high + self.low = low + self._step = step + self.scaler = scaler + self.schedule_f = f + + def get_tfr(self): + return self.schedule_f(self._step/self.scaler, self.high, self.low) + + def step(self): + tfr = self.get_tfr() + self._update_step() + return tfr + + +class OptimizerScheduler(_Scheduler): + + def __init__(self, optimizer, scheduler, clip, step=0): + # optimizer and scheduler are pytorch class + super(OptimizerScheduler, self).__init__(step) + self.optimizer = optimizer + self.scheduler = scheduler + self.clip = clip + + def optimizer_zero_grad(self): + self.optimizer.zero_grad() + + def step(self, require_zero_grad=False): + self.optimizer.step() + if self.scheduler is not None: + self.scheduler.step() + if require_zero_grad: + self.optimizer_zero_grad() + self._update_step() + + +class ParameterScheduler(_Scheduler): + + def __init__(self, step=0, mode='train', **schedulers): + # optimizer and scheduler are pytorch class + super(ParameterScheduler, self).__init__(step) + self.schedulers = schedulers + self.mode = mode + + def train(self): + self.mode = 'train' + for scheduler in self.schedulers.values(): + scheduler.train() + + def eval(self): + self.mode = 'val' + for 
scheduler in self.schedulers.values(): + scheduler.eval() + + def step(self, require_zero_grad=False): + params_dic = {} + for key, scheduler in self.schedulers.items(): + params_dic[key] = scheduler.step() + return params_dic + + + + + diff --git a/orchestrator/scripts/data_preprocessing/converter.py b/orchestrator/scripts/data_preprocessing/converter.py new file mode 100644 index 0000000..58bf3a8 --- /dev/null +++ b/orchestrator/scripts/data_preprocessing/converter.py @@ -0,0 +1,167 @@ +import numpy as np +import pretty_midi as pm + + +def bpm_to_rate(bpm): + return 60 / bpm + + +def ext_nmat_to_nmat(ext_nmat): + nmat = np.zeros((ext_nmat.shape[0], 4)) + nmat[:, 0] = ext_nmat[:, 0] + ext_nmat[:, 1] / ext_nmat[:, 2] + nmat[:, 1] = ext_nmat[:, 3] + ext_nmat[:, 4] / ext_nmat[:, 5] + nmat[:, 2] = ext_nmat[:, 6] + nmat[:, 3] = ext_nmat[:, 7] + return nmat + + +# def nmat_to_pr(nmat, num_step=32): +# pr = np.zeros((num_step, 128)) +# for s, e, p, v in pr: +# pr[s, p] + +def nmat_to_notes(nmat, start, bpm): + notes = [] + for s, e, p, v in nmat: + assert s < e + assert 0 <= p < 128 + assert 0 <= v < 128 + s = start + s * bpm_to_rate(bpm) + e = start + e * bpm_to_rate(bpm) + notes.append(pm.Note(int(v), int(p), s, e)) + return notes + + +def ext_nmat_to_pr(ext_nmat, num_step=32): + # [start measure, no, deno, .., .., .., pitch, vel] + # This is not RIGHT in general. Only works for 2-bar 4/4 music for now. + pr = np.zeros((32, 128)) + if ext_nmat is not None: + for (sb, sq, sde, eb, eq, ede, p, v) in ext_nmat: + s_ind = int(sb * sde + sq) + e_ind = int(eb * ede + eq) + p = int(p) + pr[s_ind, p] = 2 + pr[s_ind + 1: e_ind, p] = 1 # note not including the last ind + return pr + + +def ext_nmat_to_mel_pr(ext_nmat, num_step=32): + # [start measure, no, deno, .., .., .., pitch, vel] + # This is not RIGHT in general. Only works for 2-bar 4/4 music for now. + pr = np.zeros((32, 130)) + pr[:, 129] = 1 + if ext_nmat is not None: + for (sb, sq, sde, eb, eq, ede, p, v) in ext_nmat: + s_ind = int(sb * sde + sq) + e_ind = int(eb * ede + eq) + p = int(p) + pr[s_ind, p] = 1 + pr[s_ind: e_ind, 129] = 0 + pr[s_ind + 1: e_ind, 128] = 1 # note not including the last ind + return pr + + +def augment_pr(pr, shift=0): + # it assures to work on single pr + # for an array of pr, should double-check + return np.roll(pr, shift, axis=-1) + + +def augment_mel_pr(pr, shift=0): + # it only works on single mel_pr. Not on array of it. + pitch_part = np.roll(pr[:, 0: 128], shift, axis=-1) + control_part = pr[:, 128:] + augmented_pr = np.concatenate([pitch_part, control_part], axis=-1) + return augmented_pr + +def pr_to_onehot_pr(pr): + onset_data = pr[:, :] == 2 + sustain_data = pr[:, :] == 1 + silence_data = pr[:, :] == 0 + pr = np.stack([onset_data, sustain_data, silence_data], + axis=-1).astype(np.int64) + return pr + + +def piano_roll_to_target(pr): + # pr: (32, 128, 3), dtype=bool + + # Assume that "not (first_layer or second layer) = third_layer" + pr[:, :, 1] = np.logical_not(np.logical_or(pr[:, :, 0], pr[:, :, 2])) + # To int dtype can make addition work + pr = pr.astype(int) + # Initialize a matrix to store the duration of a note on the (32, 128) grid + pr_matrix = np.zeros((32, 128)) + + for i in range(31, -1, -1): + # At each iteration + # 1. Assure that the second layer accumulates the note duration + # 2. collect the onset notes in time step i, and mark it on the matrix. 
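+ # The loop runs backwards (i = 31 .. 0) so sustain counts accumulate from
+ # the tail of each note towards its onset; when an onset frame is reached,
+ # its sustain entry already equals (duration - 1), hence the "+ 1" below.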
+ + # collect + onset_idx = np.where(pr[i, :, 0] == 1)[0] + pr_matrix[i, onset_idx] = pr[i, onset_idx, 1] + 1 + if i == 0: + break + # Accumulate + # pr[i - 1, :, 1] += pr[i, :, 1] + # pr[i - 1, onset_idx, 1] = 0 # the onset note should be set 0. + + pr[i, onset_idx, 1] = 0 # the onset note should be set 0. + pr[i - 1, :, 1] += pr[i, :, 1] + return pr_matrix + + +def target_to_3dtarget(pr_mat, max_note_count=11, max_pitch=107, min_pitch=22, + pitch_pad_ind=88, dur_pad_ind=2, + pitch_sos_ind=86, pitch_eos_ind=87): + """ + :param pr_mat: (32, 128) matrix. pr_mat[t, p] indicates a note of pitch p, + started at time step t, has a duration of pr_mat[t, p] time steps. + :param max_note_count: the maximum number of notes in a time step, + including and tokens. + :param max_pitch: the highest pitch in the dataset. + :param min_pitch: the lowest pitch in the dataset. + :param pitch_pad_ind: see return value. + :param dur_pad_ind: see return value. + :param pitch_sos_ind: sos token. + :param pitch_eos_ind: eos token. + :return: pr_mat3d is a (32, max_note_count, 6) matrix. In the last dim, + the 0th column is for pitch, 1: 6 is for duration in binary repr. Output is + padded with and tokens in the pitch column, but with pad token + for dur columns. + """ + pitch_range = max_pitch - min_pitch + 1 # including pad + pr_mat3d = np.ones((32, max_note_count, 6), dtype=int) * dur_pad_ind + pr_mat3d[:, :, 0] = pitch_pad_ind + pr_mat3d[:, 0, 0] = pitch_sos_ind + cur_idx = np.ones(32, dtype=int) + for t, p in zip(*np.where(pr_mat != 0)): + pr_mat3d[t, cur_idx[t], 0] = p - min_pitch + binary = np.binary_repr(min(int(pr_mat[t, p]), 32) - 1, width=5) + pr_mat3d[t, cur_idx[t], 1: 6] = \ + np.fromstring(' '.join(list(binary)), dtype=int, sep=' ') + if cur_idx[t] == max_note_count-1: + continue + cur_idx[t] += 1 + #print(cur_idx) + pr_mat3d[np.arange(0, 32), cur_idx, 0] = pitch_eos_ind + return pr_mat3d + + +def expand_chord(chord, shift, relative=False): + # chord = np.copy(chord) + root = (chord[0] + shift) % 12 + chroma = np.roll(chord[1: 13], shift) + bass = (chord[13] + shift) % 12 + root_onehot = np.zeros(12) + root_onehot[int(root)] = 1 + bass_onehot = np.zeros(12) + bass_onehot[int(bass)] = 1 + if not relative: + pass + # chroma = np.roll(chroma, int(root)) + # print(chroma) + # print('----------') + return np.concatenate([root_onehot, chroma, bass_onehot]) diff --git a/orchestrator/scripts/data_preprocessing/lmd_midi_quantization.py b/orchestrator/scripts/data_preprocessing/lmd_midi_quantization.py new file mode 100644 index 0000000..38083fb --- /dev/null +++ b/orchestrator/scripts/data_preprocessing/lmd_midi_quantization.py @@ -0,0 +1,153 @@ +import os +import numpy as np +import pretty_midi as pyd +import soundfile as sf +from tqdm import tqdm +from math import ceil +from scipy.interpolate import interp1d +from quantization_utils import Converter +import librosa +import yaml + +import sys +sys.path.append('../exported_midi_chord_recognition/') +from mir import DataEntry +from mir import io +from extractors.midi_utilities import get_valid_channel_count, is_percussive_channel, MidiBeatExtractor +from main import process_chord +import mir_eval + +from time_stretch import time_stretch + + +def midi2matrix(midi, quaver): + """ + Convert multi-track midi to a 3D matrix of shape (Track, Time, 128). + Each cell is a integer number representing quantized duration. 
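+ Durations are counted in steps of the `quaver` grid (ACC = 4 bins per
+ beat); drum tracks are skipped.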
+ """ + pr_matrices = [] + programs = [] + quantization_error = [] + for track in midi.instruments: + if track.is_drum: + continue + qt_error = [] # record quantization error + pr_matrix = np.zeros((len(quaver), 128)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) # guitar/bass plunk typically results in a very short note duration. These note should be quantized to 1 instead of 0. + pr_matrix[note_start, note.pitch] = note_end - note_start + + #compute quantization error. A song with very high error (e.g., triple-quaver songs) will be discriminated and therefore discarded. + if note_end == note_start: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_start] - quaver[note_start-1])) + else: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_end] - quaver[note_start])) + + pr_matrices.append(pr_matrix) + programs.append(track.program) + quantization_error.append(np.mean(qt_error)) + + return np.array(pr_matrices), np.array(programs), quantization_error + + +def extrac_chord_matrix(midi_path, quaver, extra_division=1): + ''' + Perform chord recognition on a midi + ''' + entry = DataEntry() + entry.append_file(midi_path, io.MidiIO, 'midi') + entry.append_extractor(MidiBeatExtractor, 'beat') + result = process_chord(entry, extra_division) + + beat_quaver = quaver[::4] + chord_matrix = np.zeros((len(beat_quaver), 14)) + chord_matrix[:, 0] = -1 + chord_matrix[:, -1] = -1 + + for chord in result: + chord_start = np.argmin(np.abs(beat_quaver - chord[0])) + chord_end = np.argmin(np.abs(beat_quaver - chord[1])) + root, bitmap, bass_rel = mir_eval.chord.encode(chord[2]) + chroma = mir_eval.chord.rotate_bitmap_to_root(bitmap, root) + chord = np.concatenate(([root], chroma, [bass_rel]), axis=-1) + chord_matrix[chord_start: chord_end, :] = chord + return chord_matrix + + +ACC = 4 + +slakh_root = '../Q&A/slakh2100_flac_redux/' +slakh_ids = [] +for split in ['train', 'validation', 'test', 'omitted']: + slakh_split = os.path.join(slakh_root, split) + for song in tqdm(os.listdir(slakh_split)): + track_id = yaml.safe_load(open(os.path.join(slakh_split, song, 'metadata.yaml'), 'r'))['UUID'] + slakh_ids.append(track_id) +print(len(slakh_ids)) + +lmd_root = '../LMD/lmd_full/' +lmd_midi = {} +slakh_midi = {} +for folder in os.listdir(lmd_root): + sub_folder = os.path.join(lmd_root, folder) + for piece in os.listdir(sub_folder): + midi_id = piece.split('.')[0] + if midi_id in slakh_ids: + slakh_midi[midi_id] = os.path.join(sub_folder, piece) + else: + lmd_midi[midi_id] = os.path.join(sub_folder, piece) +print(len(slakh_midi), len(lmd_midi)) + + +save_root = "../LMD/4_bin_quantization_chord/" +print(f'processing LMD ...') +for song_id in tqdm(lmd_midi): + break_flag = 0 + + try: + all_src_midi = pyd.PrettyMIDI(lmd_midi[song_id]) + except: + continue + for ts in all_src_midi.time_signature_changes: + if not (((ts.numerator == 2) or (ts.numerator == 4)) and (ts.denominator == 4)): + break_flag = 1 + break + if break_flag: + continue # process only 2/4 and 4/4 songs + + beats = all_src_midi.get_beats() + if len(beats) < 32: + continue #skip pieces shorter than 8 bars + downbeats = all_src_midi.get_downbeats() + + beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) + quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear') + quaver = quantize(np.array(range(0, (len(beats) - 1) * 
ACC))) + + #pr_matrices = [] + #programs = [] + + #break_flag = 0 + + #pr_matrices, programs, track_qt = midi2matrix(all_src_midi, quaver) + #for item in track_qt: + # if item > .2: + # break_flag = 1 + #if break_flag: + # continue #skip the pieces with very large quantization error. This pieces are possibly triple-quaver songs + + + try: + chord_matrix = extrac_chord_matrix(lmd_midi[song_id], quaver) + except: + continue + + #db_indicator = np.array([int(t in downbeats) for t in quaver]) + + + np.save(os.path.join(save_root, f'{song_id}.npy'), chord_matrix) + + \ No newline at end of file diff --git a/orchestrator/scripts/data_preprocessing/pop909_process_4bin_data.py b/orchestrator/scripts/data_preprocessing/pop909_process_4bin_data.py new file mode 100644 index 0000000..20c1fc9 --- /dev/null +++ b/orchestrator/scripts/data_preprocessing/pop909_process_4bin_data.py @@ -0,0 +1,174 @@ +import os +import numpy as np +from tqdm import tqdm +import pandas as pd +from scipy.interpolate import interp1d +import pretty_midi as pyd +from scipy import stats as st + + +def convert_pop909(melody, bridge, piano, beat): + pr = np.zeros((3, len(beat)*4, 128, 2)) + for (sb, sq, sde, eb, eq, ede, p, v) in melody: + assert sde==4 + assert ede==4 + s_ind = int(sb * sde + sq) + e_ind = int(eb * ede + eq) + p = int(p) + pr[0, s_ind, p, 0] = e_ind - s_ind + pr[0, s_ind, p, 1] = v + for (sb, sq, sde, eb, eq, ede, p, v) in bridge: + assert sde==4 + assert ede==4 + s_ind = int(sb * sde + sq) + e_ind = int(eb * ede + eq) + p = int(p) + pr[1, s_ind, p, 0] = e_ind - s_ind + pr[1, s_ind, p, 1] = v + for (sb, sq, sde, eb, eq, ede, p, v) in piano: + assert sde==4 + assert ede==4 + s_ind = int(sb * sde + sq) + e_ind = int(eb * ede + eq) + p = int(p) + pr[2, s_ind, p, 0] = e_ind - s_ind + pr[2, s_ind, p, 1] = v + return pr + + +def midi2matrix(midi, quaver): + pr_matrices = [] + programs = [] + quantization_error = [] + for track in midi.instruments: + qt_error = [] # record quantization error + pr_matrix = np.zeros((len(quaver), 128, 2)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) # guitar/bass plunk typically results in a very short note duration. These note should be quantized to 1 instead of 0. + pr_matrix[note_start, note.pitch, 0] = note_end - note_start + pr_matrix[note_start, note.pitch, 1] = note.velocity + + #compute quantization error. A song with very high error (e.g., triple-quaver songs) will be discriminated and therefore discarded. 
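+ # relative error = |raw onset - assigned grid line| / quantized note span
+ # (one grid step is used as the denominator for zero-length notes)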
+ if note_end == note_start: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_start] - quaver[note_start-1])) + else: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_end] - quaver[note_start])) + + control_matrix = np.ones((len(quaver), 128, 1)) * -1 + for control in track.control_changes: + #if control.time < time_end: + # if len(quaver) == 0: + # continue + control_time = np.argmin(np.abs(quaver - control.time)) + control_matrix[control_time, control.number, 0] = control.value + + pr_matrix = np.concatenate((pr_matrix, control_matrix), axis=-1) + pr_matrices.append(pr_matrix) + programs.append(track.program) + quantization_error.append(np.mean(qt_error)) + + return np.array(pr_matrices), programs, quantization_error + + +def retrieve_control(pop909_midi_dir, song, tracks): + src_dir = os.path.join(pop909_midi_dir, song.split('.')[0], song.replace('.npz', '.mid')) + src_midi = pyd.PrettyMIDI(src_dir) + beats = src_midi.get_beats() + beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) + ACC = 4 + quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear') + quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC))) + + pr_matrices, programs, _ = midi2matrix(src_midi, quaver) + + from_4_bin = np.nonzero(tracks[0, :, :, 0]) + from_midi = np.nonzero(pr_matrices[0, :, :, 0]) + + mid_length = min(from_midi[1].shape[0], from_4_bin[1].shape[0]) + diff = from_midi[1][: mid_length] - from_4_bin[1][: mid_length] + diff_avg = np.mean(diff) + diff_std = np.std(diff) + + if diff_std > 0: + diff_record = [] + for roll_idx in range(-32, 32): + roll_pitches = np.roll(from_midi[1], shift=roll_idx, axis=0) + diff = roll_pitches[abs(roll_idx): mid_length-abs(roll_idx)] - from_4_bin[1][abs(roll_idx): mid_length-abs(roll_idx)] + diff_avg = np.mean(diff) + diff_std = np.std(diff) + diff_record.append((roll_idx, diff_avg, diff_std)) + diff_record = sorted(diff_record, key=lambda x: x[2]) + + roll_idx_min = diff_record[0][0] + roll_times = np.roll(from_midi[0], shift=roll_idx_min, axis=0) + diff = roll_times[abs(roll_idx_min): mid_length-abs(roll_idx_min)] - from_4_bin[0][abs(roll_idx_min): mid_length-abs(roll_idx_min)] + else: + diff = from_midi[0][: mid_length] - from_4_bin[0][: mid_length] + return pr_matrices[:, :, :, 2: 3], st.mode(diff).mode[0] + + +pop909_4bin_dir = '../Q&A/POP909-Dataset/quantization/POP09-PIANOROLL-4-bin-quantization/' +pop909_midi_dir = '../Q&A/POP909-Dataset/POP909/' +meta_info = pd.read_excel(os.path.join(pop909_midi_dir, 'index.xlsx')) +save_root = '../Q&A/POP909-Dataset/quantization/4_bin_midi_quantization_with_dynamics_and_chord/' + +for split in ['train', 'validation', 'test']: + save_split = os.path.join(save_root, split) + if not os.path.exists(save_split): + os.makedirs(save_split) + print(f'processing {split} set ...') + + pop909_list = os.listdir(pop909_4bin_dir) + if split == 'train': + pop909_list = pop909_list[: int(len(pop909_list)*.8)] + elif split == 'validation': + pop909_list = pop909_list[int(len(pop909_list)*.8): int(len(pop909_list)*.9)] + elif split == 'test': + pop909_list = pop909_list[int(len(pop909_list)*.9): ] + + for song in tqdm(pop909_list): + song_meta = meta_info[meta_info.song_id == int(song.replace('.npz', ''))] + num_beats = song_meta.num_beats_per_measure.values[0] + num_quavers = song_meta.num_quavers_per_beat.values[0] + if int(num_beats) == 3 or int(num_quavers) == 3: + continue #neglect pieces with triplet meters + pop909_data = np.load(os.path.join(pop909_4bin_dir, 
song)) + beats = pop909_data['beat'] + melody= pop909_data['melody'] + bridge= pop909_data['bridge'] + piano= pop909_data['piano'] + + tracks = convert_pop909(melody, bridge, piano, beats) + + + + track_control = np.ones((tracks.shape[0], tracks.shape[1], 128, 1)) * -1 + cc, shift = retrieve_control(pop909_midi_dir, song, tracks) + if shift >= 0: + track_control[:, : min(cc.shape[1] - shift, track_control.shape[1])] = cc[:, shift: min(cc.shape[1], track_control.shape[1] + shift)] + else: + track_control[:, -shift: min(cc.shape[1] - shift, track_control.shape[1]) ] = cc[:, :min(cc.shape[1], track_control.shape[1] + shift)] + + pr_matrices = tracks[..., 0] + dynamic_matrices = np.concatenate([tracks[..., 1:], track_control], axis=-1) + chord_matrices = pop909_data['chord'] + downbeat_indicator = np.zeros(len(beats)*4) + for idx, beat in enumerate(beats): + if beat[3] == 0: + downbeat_indicator[idx*4] = 1 + + #print(pr_matrices.shape) + #print(dynamic_matrices.shape) + #print(chord_matrices.shape) + #print(downbeat_indicator.shape) + + np.savez(os.path.join(save_split, song),\ + tracks = pr_matrices,\ + db_indicator = downbeat_indicator,\ + dynamics = dynamic_matrices, \ + chord = chord_matrices) + + #break \ No newline at end of file diff --git a/orchestrator/scripts/data_preprocessing/quantization_utils.py b/orchestrator/scripts/data_preprocessing/quantization_utils.py new file mode 100644 index 0000000..0c8bad2 --- /dev/null +++ b/orchestrator/scripts/data_preprocessing/quantization_utils.py @@ -0,0 +1,223 @@ +import os +import numpy as np +import pretty_midi as pyd +from scipy.interpolate import interp1d + +import sys +sys.path.append('../exported_midi_chord_recognition/') +from mir import DataEntry +from mir import io +from extractors.midi_utilities import get_valid_channel_count,is_percussive_channel,MidiBeatExtractor +from main import process_chord +import mir_eval + +import librosa +from tqdm import tqdm + +from joblib import Parallel, delayed +from multiprocessing import Manager + + +class Converter: + def __init__(self, ACC): + self.ACC = ACC + self.max_note_count=11 + self.max_pitch=107 + self.min_pitch=22, + self.pitch_sos_ind=86 + self.pitch_eos_ind=87 + self.pitch_pad_ind=88 + self.dur_pad_ind=2 + + + def midi2matrix(self, midi, quaver): + """ + Convert multi-track midi to a 3D matrix of shape (Track, Time, 128). + Each cell is a integer number representing quantized duration. 
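+ Tracks that share the same MIDI program number are merged into a single
+ track before quantization.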
+ """ + tracks = [] + unique_programs = [] + for track in midi.instruments: + """ + merge tracks with the same program number into one single track + """ + if not (track.program in unique_programs): + unique_programs.append(track.program) + tracks.append(track) + else: + idx = unique_programs.index(track.program) + notes = tracks[idx].notes + track.notes + notes = sorted(notes, key=lambda x: x.start, reverse=False) + tracks[idx].notes = notes + + pr_matrices = [] + for track in tracks: + pr_matrix = np.zeros((len(quaver), 128)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + pr_matrix[note_start, note.pitch] = note_end - note_start + pr_matrices.append(pr_matrix) + return np.array(pr_matrices, dtype=np.uint8), np.array(unique_programs, dtype=np.uint8) + + def matrix_compress(self, pr_mat): + #pr_mat: (time, 128) + T = pr_mat.shape[0] + pr_mat3d = np.zeros((T, self.max_note_count, 2), dtype=np.uint8) + pr_mat3d[:, :, 0] = self.pitch_pad_ind + pr_mat3d[:, 0, 0] = self.pitch_sos_ind + cur_idx = np.ones(T, dtype=int) + for t, p in zip(*np.where(pr_mat != 0)): + if p < self.min_pitch or p > self.max_pitch: + continue + pr_mat3d[t, cur_idx[t], 0] = p - self.min_pitch + pr_mat3d[t, cur_idx[t], 1] = min(int(pr_mat[t, p]), 128) + if cur_idx[t] == self.max_note_count-1: + continue + cur_idx[t] += 1 + #print(cur_idx) + pr_mat3d[np.arange(0, T), cur_idx, 0] = self.pitch_eos_ind + return pr_mat3d + + def matrix_decompress(self, mat_compress): + #mat_compress: (time, max_simu_note, 2) + pr_mat = np.zeros((mat_compress.shape[0], 128)) + + for t, p in zip(*np.where(mat_compress[:, 1:, 0] < self.pitch_eos_ind)): + pitch_rel = mat_compress[t, p+1, 0] + pr_mat[t, pitch_rel + self.min_pitch] = mat_compress[t, p+1, 1] + + return pr_mat + + + def matrix2midi(self, pr_matrices, programs, init_tempo=120, time_start=0): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track. Time, 128). 
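+ Roughly the inverse of midi2matrix: every non-zero cell becomes a note
+ with a fixed velocity of 100, timed on the grid according to init_tempo.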
+ """ + tracks = [] + for program in programs: + track_recon = pyd.Instrument(program=int(program), is_drum=False, name=pyd.program_to_instrument_name(int(program))) + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices) + alpha = 1 / (self.ACC // 4) * 60 / init_tempo #timetep between each quntization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch] * alpha + velocity = 100 + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + + midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + midi_recon.instruments = tracks + return midi_recon + + + def extrac_chord_matrix(self, midi_path, quaver, extra_division=1): + ''' + Perform chord recognition on a midi + ''' + entry = DataEntry() + entry.append_file(midi_path,io.MidiIO, 'midi') + entry.append_extractor(MidiBeatExtractor, 'beat') + result = process_chord(entry, extra_division) + + beat_quaver = quaver[::4] + chord_matrix = np.zeros((len(beat_quaver), 14), dtype=np.uint8) + chord_matrix[:, 0] = -1 + chord_matrix[:, -1] = -1 + + for chord in result: + chord_start = np.argmin(np.abs(beat_quaver - chord[0])) + chord_end = np.argmin(np.abs(beat_quaver - chord[1])) + root, bitmap, bass_rel = mir_eval.chord.encode(chord[2]) + chroma = mir_eval.chord.rotate_bitmap_to_root(bitmap, root) + chord = np.concatenate(([root], chroma, [bass_rel]), axis=-1) + chord_matrix[chord_start: chord_end, :] = chord + return chord_matrix + + + def chord_matrix2midi(self, chord_matrix, init_tempo=120, time_start=0): + alpha = 1 / (self.ACC // 4) * 60 / init_tempo * 4 #timetep between each quntization bin + onset_or_rest = [i for i in range(1, len(chord_matrix)) if (chord_matrix[i] != chord_matrix[i-1]).any()] + onset_or_rest = [0] + onset_or_rest + onset_or_rest.append(len(chord_matrix)) + + chordTrack = pyd.Instrument(program=0, is_drum=False, name='Chord') + for idx, onset in enumerate(onset_or_rest[:-1]): + chordset = [int(i) for i in chord_matrix[onset]] + start = onset * alpha + end = onset_or_rest[idx+1] * alpha + root = chordset[0] + chroma = chordset[1: 13] + bass_rel = chordset[13] + bass_bitmap = np.roll(chroma, shift=-root-bass_rel) + bass = root + bass_rel + + if np.argmax(bass_bitmap) + bass >= 12: + register = 3 + else: + register = 4 + + for entry in np.nonzero(bass_bitmap)[0]: + pitch = register * 12 + bass + entry + note = pyd.Note(velocity=100, pitch=int(pitch), start=time_start + start, end=time_start + end) + chordTrack.notes.append(note) + return chordTrack + + + def pr2gird(self, compress_pr): + #compress_pr: (time, max_simu_note, 2) + pr_mat3d = np.ones((compress_pr.shape[0], self.max_note_count, 6), dtype=int) * self.dur_pad_ind + pr_mat3d[:, :, 0] = self.pitch_pad_ind + pr_mat3d[:, 0, 0] = self.pitch_sos_ind + cur_idx = np.ones(compress_pr.shape[0], dtype=int) + for t, p in zip(*np.where(compress_pr[:, 1:, 0] < self.pitch_eos_ind)): + pitch_rel = compress_pr[t, p+1, 0] + duration = compress_pr[t, p+1, 1] + pr_mat3d[t, cur_idx[t], 0] = pitch_rel + binary = np.binary_repr(min(duration, 32 - t) - 1, width=5) + pr_mat3d[t, cur_idx[t], 1: 6] = np.fromstring(' '.join(list(binary)), dtype=int, sep=' ') + if cur_idx[t] == self.max_note_count-1: + continue + cur_idx[t] += 1 + pr_mat3d[np.arange(0, compress_pr.shape[0]), cur_idx, 0] 
= self.pitch_eos_ind + return pr_mat3d + + def grid2pr(self, grid): + #grid: (time, max_simu_note, 6) + if grid.shape[1] == self.max_note_count: + grid = grid[:, 1:] + pr = np.zeros((grid.shape[0], 128), dtype=int) + for t in range(grid.shape[0]): + for n in range(grid.shape[1]): + note = grid[t, n] + if note[0] == self.pitch_eos_ind: + break + pitch = note[0] + self.min_pitch + dur = int(''.join([str(_) for _ in note[1:]]), 2) + 1 + pr[t, pitch] = dur + return pr + + def expand_chord(self, chord, shift, relative=False): + # chord = np.copy(chord) + root = (chord[0] + shift) % 12 + chroma = np.roll(chord[1: 13], shift) + bass = (chord[13] + shift) % 12 + root_onehot = np.zeros(12) + root_onehot[int(root)] = 1 + bass_onehot = np.zeros(12) + bass_onehot[int(bass)] = 1 + if not relative: + pass + # chroma = np.roll(chroma, int(root)) + # print(chroma) + # print('----------') + return np.concatenate([root_onehot, chroma, bass_onehot]) + + \ No newline at end of file diff --git a/orchestrator/scripts/data_preprocessing/slakh_quantization.py b/orchestrator/scripts/data_preprocessing/slakh_quantization.py new file mode 100644 index 0000000..21713b9 --- /dev/null +++ b/orchestrator/scripts/data_preprocessing/slakh_quantization.py @@ -0,0 +1,223 @@ +import os +import numpy as np +import pretty_midi as pyd +import soundfile as sf +from tqdm import tqdm +from math import ceil +from scipy.interpolate import interp1d +from quantization_utils import Converter +import librosa +import yaml + +import sys +sys.path.append('../exported_midi_chord_recognition') +from mir import DataEntry +from mir import io +from extractors.midi_utilities import get_valid_channel_count,is_percussive_channel,MidiBeatExtractor +from main import process_chord +import mir_eval + +from time_stretch import time_stretch + + +def midi2matrix(midi, quaver): + """ + Convert multi-track midi to a 3D matrix of shape (Track, Time, 128). + Each cell is a integer number representing quantized duration. + """ + pr_matrices = [] + programs = [] + quantization_error = [] + for track in midi.instruments: + qt_error = [] # record quantization error + pr_matrix = np.zeros((len(quaver), 128)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) # guitar/bass plunk typically results in a very short note duration. These note should be quantized to 1 instead of 0. + pr_matrix[note_start, note.pitch] = note_end - note_start + + #compute quantization error. A song with very high error (e.g., triple-quaver songs) will be discriminated and therefore discarded. 
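+ # In the main loop below, a song is skipped when any track's mean error
+ # exceeds 0.2, since such songs are likely on a triplet grid.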
+ if note_end == note_start: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_start] - quaver[note_start-1])) + else: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_end] - quaver[note_start])) + + pr_matrices.append(pr_matrix) + programs.append(track.program) + quantization_error.append(np.mean(qt_error)) + + return np.array(pr_matrices, dtype=np.uint8), programs, quantization_error + + +def extrac_chord_matrix(midi_path, quaver, extra_division=1): + ''' + Perform chord recognition on a midi + ''' + entry = DataEntry() + entry.append_file(midi_path, io.MidiIO, 'midi') + entry.append_extractor(MidiBeatExtractor, 'beat') + result = process_chord(entry, extra_division) + + beat_quaver = quaver[::4] + chord_matrix = np.zeros((len(beat_quaver), 14), dtype=np.uint8) + chord_matrix[:, 0] = -1 + chord_matrix[:, -1] = -1 + + for chord in result: + chord_start = np.argmin(np.abs(beat_quaver - chord[0])) + chord_end = np.argmin(np.abs(beat_quaver - chord[1])) + root, bitmap, bass_rel = mir_eval.chord.encode(chord[2]) + chroma = mir_eval.chord.rotate_bitmap_to_root(bitmap, root) + chord = np.concatenate(([root], chroma, [bass_rel]), axis=-1) + chord_matrix[chord_start: chord_end, :] = chord + return chord_matrix + + +TGT_SR = 22050 +HOP_LGTH = 512 +STRETCH_BPM = 100 +AUD_DIR = '../slakh2100_flac_redux/' + +def pad_audio_npy(audio, beat_secs, exceed_frames=1000): + """ + This operation generates a copy of the wav and ensures + len(copy) >= frame of beat_secs[-1] * TGT_SR + exceed_frames + """ + last_beat_frame = beat_secs[-1] * TGT_SR + last_audio_frame = len(audio) - 1 + if last_audio_frame < last_beat_frame + exceed_frames: + pad_data = np.zeros(ceil(last_beat_frame + exceed_frames), + dtype=np.float32) + pad_data[0: len(audio)] = audio + else: + pad_data = audio.copy() + return pad_data + + +def stretch_a_song(beat_secs, audio, tgt_bpm=100, exceed_frames=1000): + """Stretch the audio to constant bpm=tgt_bpm.""" + data = pad_audio_npy(audio, beat_secs, exceed_frames=exceed_frames) + pad_start = 0 + if beat_secs[0] > HOP_LGTH / TGT_SR: + critical_beats = np.insert(beat_secs, 0, 0) + beat_dict = dict(zip(beat_secs, + np.arange(0, len(beat_secs)) + 1)) + pad_start = 1 + else: + critical_beats = beat_secs + beat_dict = dict(zip(beat_secs, + np.arange(0, len(beat_secs)))) + + critical_frames = critical_beats * TGT_SR + critical_frames = np.append(critical_frames, len(data)) + + frame_intervals = np.diff(critical_frames) + tgt_interval = (60 / tgt_bpm) * TGT_SR + rates = frame_intervals / tgt_interval + + steps = [np.arange(critical_frames[i] / HOP_LGTH, + critical_frames[i + 1] / HOP_LGTH, + rates[i]) + for i in range(len(frame_intervals))] + + time_steps = np.concatenate(steps, dtype=float) + + fpb = np.ceil((tgt_interval / HOP_LGTH)) * HOP_LGTH + len_stretch = int(fpb * len(steps)) + + stretched_song = time_stretch(data, time_steps, len_stretch, + center=False) + beat_steps = [int(i * fpb) for i in range(len(steps))] + if pad_start: + beat_steps = beat_steps[1:] + return stretched_song, beat_steps, int(fpb), rates + + + +ACC = 4 + +slakh_root = '../slakh2100_flac_redux' +save_root = '../slakh2100_flac_redux/4_bin_quantization/' +for split in ['train', 'validation', 'test']: + slakh_split = os.path.join(slakh_root, split) + save_split = os.path.join(save_root, split) + if not os.path.exists(save_split): + os.mkdir(save_split) + print(f'processing {split} set ...') + for song in tqdm(os.listdir(slakh_split)): + break_flag = 0 + + all_src_midi = 
pyd.PrettyMIDI(os.path.join(slakh_split, song, 'all_src.mid')) + for ts in all_src_midi.time_signature_changes: + if not (((ts.numerator == 2) or (ts.numerator == 4)) and (ts.denominator == 4)): + break_flag = 1 + break + if break_flag: + continue # process only 2/4 and 4/4 songs + + tracks = os.path.join(slakh_split, song, 'MIDI') + track_names = os.listdir(tracks) + track_midi = [pyd.PrettyMIDI(os.path.join(tracks, track)) for track in track_names] + track_meta = yaml.safe_load(open(os.path.join(slakh_split, song, 'metadata.yaml'), 'r'))['stems'] + + if len(all_src_midi.get_beats()) >= max([len(midi.get_beats()) for midi in track_midi]): + beats = all_src_midi.get_beats() + downbeats = all_src_midi.get_downbeats() + else: + beats = track_midi[np.argmax([len(midi.get_beats()) for midi in track_midi])].get_beats() + downbeats = track_midi[np.argmax([len(midi.get_beats()) for midi in track_midi])].get_downbeats() + + beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) + quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear') + quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC))) + + pr_matrices = [] + programs = [] + + break_flag = 0 + for idx, midi in enumerate(track_midi): + meta = track_meta[track_names[idx].replace('.mid', '')] + if meta['is_drum']: + continue #let's skip drum for now + pr_matrix, _, track_qt = midi2matrix(midi, quaver) + if track_qt[0] > .2: + break_flag = 1 + break + pr_matrices.append(pr_matrix) + programs.append(meta['program_num']) + if break_flag: + continue #skip the pieces with very large quantization error. This pieces are possibly triple-quaver songs + + pr_matrices = np.concatenate(pr_matrices, axis=0, dtype=np.uint8) + programs = np.array(programs, dtype=np.uint8) + + chord_matrix = extrac_chord_matrix(os.path.join(slakh_split, song, 'all_src.mid'), quaver) + + audio, _ = librosa.load(os.path.join(slakh_split, song, 'drum_detach', 'drum_detach_22050.wav'), sr=TGT_SR) + audio_strech, beat_steps, fpb, _ = stretch_a_song(beats[:-1], audio, tgt_bpm=STRETCH_BPM) + sf.write(os.path.join(slakh_split, song, 'drum_detach', f'drum_detach_22050_{STRETCH_BPM}.wav'), audio_strech, TGT_SR, 'PCM_16') + + + db_indicator = np.array([int(t in downbeats) for t in quaver], dtype=np.uint8) + db_frame = [beat_steps[i] for i in range(len(beats[:-1])) if beats[i] in downbeats] + + assert(len(np.nonzero(db_indicator)[0]) == len(db_frame)) + + #print(db_frame) + #print(len(db_frame), len(np.nonzero(db_indicator)[0])) + + #print(beat_steps) + #print(len(beat_steps), len(downbeats)) + #print(fpb) + + np.savez(os.path.join(save_split, f'{song}.npz'),\ + tracks = pr_matrices,\ + programs = programs,\ + chord = chord_matrix,\ + db_indicator = db_indicator,\ + db_frame = db_frame,\ + fpb = fpb) + + \ No newline at end of file diff --git a/orchestrator/scripts/objective_evaluation_arrangement.ipynb b/orchestrator/scripts/objective_evaluation_arrangement.ipynb new file mode 100644 index 0000000..291ff89 --- /dev/null +++ b/orchestrator/scripts/objective_evaluation_arrangement.ipynb @@ -0,0 +1,322 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 66/66 [01:47<00:00, 1.63s/it]\n", + "100%|██████████| 66/66 [01:48<00:00, 1.64s/it]\n", + "100%|██████████| 66/66 [01:50<00:00, 1.67s/it]\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import scipy\n", + "from utils import midi2matrix\n", + 
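+ "# midi2matrix (imported from utils above) is used below to quantize each track onto the 4-bins-per-beat grid\n",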
"import pretty_midi as pyd\n", + "from tqdm import tqdm\n", + "import sys\n", + "sys.path.append('../exported_midi_chord_recognition')\n", + "from main import transcribe_cb1000_midi\n", + "import scipy.stats as st\n", + "ACC = 4\n", + "\n", + "def load_mixture_acc(path, melody_id=None):\n", + " multi_track = pyd.PrettyMIDI(path)\n", + " beats = multi_track.get_beats()\n", + " beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2]))\n", + " quantize = scipy.interpolate.interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear')\n", + " quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC)))\n", + " multi_track, _ = midi2matrix(multi_track, quaver)\n", + " multi_track = multi_track[:, :multi_track.shape[1]//16*16]\n", + " if melody_id is not None:\n", + " multi_track = np.delete(multi_track, melody_id, axis=0)\n", + " mixture = np.max(multi_track, axis=0)\n", + "\n", + " pitch_hist = np.sum(mixture[:, :120].reshape(-1, 10, 12), axis=-2)\n", + " pitch_hist = np.sum(pitch_hist.reshape(-1, 16, 12), axis=-2)\n", + " grooves = np.sum(mixture[:, :] > 0, axis=-1).reshape(-1, 16)\n", + " grooves[grooves>0] = 1\n", + " return pitch_hist, grooves\n", + "\n", + "def load_multi_track_acc(path, melody_id=None):\n", + " orchestration = pyd.PrettyMIDI(path)\n", + " beats = orchestration.get_beats()\n", + " beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2]))\n", + " quantize = scipy.interpolate.interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear')\n", + " quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC)))\n", + " orchestration, _ = midi2matrix(orchestration, quaver)\n", + " orchestration = orchestration[:, :orchestration.shape[1]//16*16]\n", + " if melody_id is not None:\n", + " orchestration = np.delete(orchestration, melody_id, axis=0)\n", + " \n", + " pitch_hist = np.sum(orchestration[:, :, :120].reshape(len(orchestration), -1, 10, 12), axis=-2)\n", + " pitch_hist = np.sum(pitch_hist.reshape(len(orchestration), -1, 16, 12), axis=-2)\n", + " grooves = np.sum(orchestration > 0, axis=-1).reshape(len(orchestration), -1, 16)\n", + " grooves[grooves>0] = 1\n", + " return pitch_hist, grooves\n", + "\n", + "\n", + "def pitch_historgam_entropy(histo_mix):\n", + " #histo_mix: num_bar x 12\n", + " empty = np.nonzero(np.sum(histo_mix, axis=-1) == 0)[0]\n", + " if len(empty) > 0:\n", + " histo_mix = np.delete(histo_mix, empty, axis=0)\n", + " return np.mean([scipy.stats.entropy(bar) for bar in histo_mix])\n", + "\n", + "def groove_consistency(grooves_mix):\n", + " #grooves_mix: num_bar x 16\n", + " empty = np.nonzero(np.sum(grooves_mix, axis=-1) == 0)[0]\n", + " if len(empty) > 0:\n", + " grooves_mix = np.delete(grooves_mix, empty, axis=0)\n", + " results = []\n", + " for i in range(len(grooves_mix)):\n", + " for j in range(len(grooves_mix)):\n", + " results.append(1 - np.sum((grooves_mix[i] * grooves_mix[j]) == 0) / 16)\n", + " return np.mean(results)\n", + "\n", + "\n", + "def structure_dynamics(grooves_mix, phrase_seg=[8, 8, 8, 8]):\n", + " #grooves_mix: num_bar x 16\n", + " #print(len(grooves_mix), np.sum(phrase_seg))\n", + " if len(grooves_mix) > np.sum(phrase_seg):\n", + " grooves_mix = grooves_mix[:np.sum(phrase_seg)]\n", + " assert(len(grooves_mix) == np.sum(phrase_seg))\n", + " results = []\n", + " start = 0\n", + " for p_len in phrase_seg:\n", + " in_phrase_result = []\n", + " for i in range(start, start+p_len):\n", + " for j in range(start, start+p_len):\n", + " in_phrase_result.append(1 - np.sum((grooves_mix[i] * grooves_mix[j]) == 0) / 16)\n", + " 
out_phrase_result = []\n", + " for i in range(start, start+p_len):\n", + " for j in range(len(grooves_mix)):\n", + " if (j < start) or (j >= start+p_len):\n", + " out_phrase_result.append(1 - np.sum((grooves_mix[i] * grooves_mix[j]) == 0) / 16)\n", + " start += p_len\n", + " results.append(np.mean(in_phrase_result) / np.mean(out_phrase_result))\n", + " return np.mean(results)\n", + "\n", + "\n", + "def track_wise_entropy(multi_track):\n", + " #multi_track: n_track x num_bar x 12\n", + " return np.mean([pitch_historgam_entropy(track) for track in multi_track])\n", + "def track_wise_consistency(multi_track):\n", + " #multi_track: n_track x num_bar x 12\n", + " return np.mean([groove_consistency(track) for track in multi_track])\n", + "\n", + "def chord_comparator(path, name):\n", + " chord_1 = []\n", + " chord = transcribe_cb1000_midi(os.path.join(path, 'lead_sheet.mid'), output_path=None)\n", + " INCRE = 60 / pyd.PrettyMIDI(os.path.join(path, 'lead_sheet.mid')).get_tempo_changes()[1][0]\n", + " for item in chord:\n", + " chord_1 += [item[-1].split('/')[0]]*int(round((item[1]-item[0]) / INCRE))\n", + " chord_2 = []\n", + " chord = transcribe_cb1000_midi(os.path.join(path, f'{name}.mid'), output_path=None)\n", + " INCRE = 60 / pyd.PrettyMIDI(os.path.join(path, f'{name}.mid')).get_tempo_changes()[1][0]\n", + " for item in chord:\n", + " chord_2 += [item[-1].split('/')[0]]*int(round((item[1]-item[0]) / INCRE))\n", + " if not (len(chord_1) == len(chord_2)):\n", + " #print('chord', len(chord_1), len(chord_2))\n", + " #print(chord_1)\n", + " #print(chord_2)\n", + " chord_len = min(len(chord_1), len(chord_2))\n", + " chord_1 = chord_1[:chord_len]\n", + " chord_2 = chord_2[:chord_len]\n", + " #assert(len(chord_1) == len(chord_2))\n", + " result = 0\n", + " for i in range(len(chord_1)):\n", + " if chord_1[i] == chord_2[i]:\n", + " result += 1\n", + " #else:\n", + " # print(chord_1[i], chord_2[i])\n", + " result = result / len(chord_1)\n", + " return result\n", + "\n", + "with open(\"../nottingham_database/phrase_cleaned.txt\", 'r') as f:\n", + " phrases = f.readlines()\n", + "phrase_dict = {}\n", + "for item in phrases:\n", + " p_len =[]\n", + " for i in item.split('\\t')[1][1::2]:\n", + " p_len.append(int(i))\n", + " phrase_dict[item.split('\\t')[0]] = p_len\n", + "#print(phrase_dict)\n", + "\n", + "results = {'AccoMontage3': {'p_etr': [], 'g_cst': [], 'dyn': [], 'trk_p_etr': [], 'trk_g_etr': [], 'chd_acc': []}, \\\n", + " 'Jianianhua': {'p_etr': [], 'g_cst': [], 'dyn': [], 'trk_p_etr': [], 'trk_g_etr': [], 'chd_acc': []}, \\\n", + " 'PopMAG': {'p_etr': [], 'g_cst': [], 'dyn': [], 'trk_p_etr': [], 'trk_g_etr': [], 'chd_acc': []}\\\n", + " }\n", + "melody_id_dict = {'AccoMontage3':-1, 'Jianianhua':0, 'PopMAG':0}\n", + "for demo in [1, 2, 3]:\n", + " demo_root = f\"arrangement/demo_{demo}\"\n", + " for song in tqdm(os.listdir(demo_root)):\n", + " phrase = phrase_dict[song]\n", + " for model in ['AccoMontage3', 'Jianianhua', 'PopMAG']:\n", + " histo_mix, grooves_mix = load_mixture_acc(os.path.join(demo_root, song, f'{model}.mid'), melody_id=melody_id_dict[model])\n", + " histo_track, grooves_track = load_multi_track_acc(os.path.join(demo_root, song, f'{model}.mid'), melody_id=melody_id_dict[model])\n", + " results[model]['p_etr'].append(pitch_historgam_entropy(histo_mix))\n", + " results[model]['g_cst'].append(groove_consistency(grooves_mix))\n", + " results[model]['dyn'].append(structure_dynamics(grooves_mix, phrase_seg=phrase))\n", + " 
results[model]['trk_p_etr'].append(track_wise_entropy(histo_track))\n", + " results[model]['trk_g_etr'].append(track_wise_consistency(grooves_track))\n", + " results[model]['chd_acc'].append(chord_comparator(os.path.join(demo_root, song), model))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AccoMontage3\t p-Etr=1.29535 + 0.01522\t g-Cst=0.52693 + 0.01979\t s-dyn=1.08076 + 0.00778\t trk-p=0.76001 + 0.01542\t trk-g=0.14174 + 0.00485\t chd=0.72168 + 0.02024\n", + "Jianianhua\t p-Etr=1.32391 + 0.02046\t g-Cst=0.55623 + 0.03252\t s-dyn=1.05569 + 0.01150\t trk-p=0.91413 + 0.02130\t trk-g=0.26791 + 0.01198\t chd=0.70607 + 0.02170\n", + "PopMAG\t p-Etr=1.32200 + 0.01927\t g-Cst=0.46709 + 0.02251\t s-dyn=1.05669 + 0.00581\t trk-p=0.76200 + 0.02252\t trk-g=0.10504 + 0.00455\t chd=0.60882 + 0.02045\n" + ] + } + ], + "source": [ + "for key in results:\n", + " print(f\"{key}\\t p-Etr={np.mean(results[key]['p_etr']):.5f} + {st.sem(results[key]['p_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_etr'])-1):.5f}\\t\\\n", + " g-Cst={np.mean(results[key]['g_cst']):.5f} + {st.sem(results[key]['g_cst']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['g_cst'])-1):.5f}\\t\\\n", + " s-dyn={np.mean(results[key]['dyn']):.5f} + {st.sem(results[key]['dyn']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['dyn'])-1):.5f}\\t\\\n", + " trk-p={np.mean(results[key]['trk_p_etr']):.5f} + {st.sem(results[key]['trk_p_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['trk_p_etr'])-1):.5f}\\t\\\n", + " trk-g={np.mean(results[key]['trk_g_etr']):.5f} + {st.sem(results[key]['trk_g_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['trk_g_etr'])-1):.5f}\\t\\\n", + " chd={np.mean(results[key]['chd_acc']):.5f} + {st.sem(results[key]['chd_acc']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['chd_acc'])-1):.5f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "demo_root = \"../lmd\"\n", + "results = {'Real': {'p_etr': [], 'g_cst': [], 'trk_p_etr': [], 'trk_g_etr': []}}\n", + "melody_id_dict = {'AccoMontage3':-1, 'Jianianhua':0, 'PopMAG':0}\n", + "for song in tqdm(os.listdir(demo_root)):\n", + " try:\n", + " histo_mix, grooves_mix = load_mixture_acc(os.path.join(demo_root, song))\n", + " histo_track, grooves_track = load_multi_track_acc(os.path.join(demo_root, song))\n", + " results['Real']['p_etr'].append(pitch_historgam_entropy(histo_mix))\n", + " results['Real']['g_cst'].append(groove_consistency(grooves_mix))\n", + " results['Real']['trk_p_etr'].append(track_wise_entropy(histo_track))\n", + " results['Real']['trk_g_etr'].append(track_wise_consistency(grooves_track))\n", + " except:\n", + " continue\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Real\t p-Etr=1.55543 + 0.01577\t g-Cst=0.52725 + 0.01340\t trk-p=nan + nan\t trk-g=nan + nan\t\n" + ] + } + ], + "source": [ + "for key in results:\n", + " print(f\"{key}\\t p-Etr={np.mean(results[key]['p_etr']):.5f} + {st.sem(results[key]['p_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_etr'])-1):.5f}\\t\\\n", + " g-Cst={np.mean(results[key]['g_cst']):.5f} + {st.sem(results[key]['g_cst']) * scipy.stats.t.ppf((1 + 
0.95) / 2., len(results[key]['g_cst'])-1):.5f}\\t\\\n", + " trk-p={np.mean(results[key]['trk_p_etr']):.5f} + {st.sem(results[key]['trk_p_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['trk_p_etr'])-1):.5f}\\t\\\n", + " trk-g={np.mean(results[key]['trk_g_etr']):.5f} + {st.sem(results[key]['trk_g_etr']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['trk_g_etr'])-1):.5f}\\t\")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.90548 + 0.00016\n" + ] + } + ], + "source": [ + "new_list = []\n", + "for item in results[key]['trk_p_etr']:\n", + " if np.isnan(item):\n", + " continue\n", + " new_list.append(item)\n", + "print(f\"{np.mean(new_list):.5f} + {st.sem(new_list) * scipy.stats.t.ppf((1 + 0.95) / 2., len(new_list)-1):.5f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.23359 + 0.00747\n" + ] + } + ], + "source": [ + "new_list = []\n", + "for item in results[key]['trk_g_etr']:\n", + " if np.isnan(item):\n", + " continue\n", + " new_list.append(item)\n", + "print(f\"{np.mean(new_list):.5f} + {st.sem(new_list) * scipy.stats.t.ppf((1 + 0.95) / 2., len(new_list)-1):.5f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch1.9", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/orchestrator/scripts/objective_evaluation_orchestration.ipynb b/orchestrator/scripts/objective_evaluation_orchestration.ipynb new file mode 100644 index 0000000..cf4abaf --- /dev/null +++ b/orchestrator/scripts/objective_evaluation_orchestration.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 55/55 [00:18<00:00, 3.04it/s]\n", + "100%|██████████| 55/55 [00:17<00:00, 3.12it/s]\n", + "100%|██████████| 55/55 [00:17<00:00, 3.17it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Q&A-XL_blur_0\t p-Sim=0.92118 + 0.00809\t p-Etr=1.63749 + 0.04753\t g-Sim=0.86462 + 0.00964\t g-Etr=1.62582 + 0.04782\n", + "Q&A-XL_blur_0.5\t p-Sim=0.91967 + 0.00775\t p-Etr=1.90828 + 0.03603\t g-Sim=0.80587 + 0.01162\t g-Etr=1.88433 + 0.03582\n", + "Q&A-XL_blur_1\t p-Sim=0.91661 + 0.00792\t p-Etr=2.02799 + 0.03684\t g-Sim=0.78441 + 0.01286\t g-Etr=2.00000 + 0.03614\n", + "Q&A\t p-Sim=0.88574 + 0.00988\t p-Etr=1.63581 + 0.03543\t g-Sim=0.73230 + 0.01491\t g-Etr=1.60909 + 0.03615\n", + "Arranger-2\t p-Sim=0.98862 + 0.00635\t p-Etr=0.34147 + 0.04460\t g-Sim=0.98900 + 0.00637\t g-Etr=0.33764 + 0.04355\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import scipy\n", + "from utils import midi2matrix\n", + "import pretty_midi as pyd\n", + "from tqdm import tqdm\n", + "import scipy.stats as st\n", + "ACC = 4\n", + "\n", + "def piano_mixture_similarity(piano, 
mixture):\n", + " num_bar = min(len(piano), len(mixture))\n", + " piano = piano[:num_bar]\n", + " mixture = mixture[:num_bar]\n", + " sim = np.array([np.dot(piano[i], mixture[i]) for i in range(piano.shape[0])]) / (np.linalg.norm(piano, axis=-1) * np.linalg.norm(mixture, axis=-1) + 1e-5)\n", + " return np.mean(sim, axis=0) #scalar\n", + "\n", + "\n", + "def degree_of_orchestration(piano, multi_track):\n", + " #multi_track: num_bar x 12\n", + " #multi_track: n_track x num_bar x 12\n", + " num_bar = min(len(piano), multi_track.shape[1])\n", + " piano = piano[:num_bar]\n", + " multi_track = multi_track[:, :num_bar]\n", + " histogram = np.array([(np.dot(multi_track[:, idx], piano[idx]) + 1e-10) for idx in range(len(piano))]) / (np.linalg.norm(piano, axis=-1)[:, np.newaxis] * np.linalg.norm(multi_track.transpose(1, 0, 2), axis=-1) + 1e-5) #(num_bar, n_track)\n", + " #print(np.sum(histogram, axis=-1))\n", + " return np.mean([scipy.stats.entropy(bar) for bar in histogram])\n", + "\n", + "def load_piano_acc(path):\n", + " piano = pyd.PrettyMIDI(path)\n", + " beats = piano.get_beats()\n", + " beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2]))\n", + " quantize = scipy.interpolate.interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear')\n", + " quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC)))\n", + " piano, prog = midi2matrix(piano, quaver)\n", + " piano = piano[:, :piano.shape[1]//16*16]\n", + "\n", + " pitch_hist = np.sum(piano[1:, :, :120].reshape(len(prog)-1, -1, 10, 12), axis=(0, -2))\n", + " pitch_hist = np.sum(pitch_hist.reshape(-1, 16, 12), axis=-2)\n", + " grooves = np.sum(piano[1:, :, :] > 0, axis=(0, -1)).reshape(-1, 16)\n", + " return pitch_hist, grooves\n", + "\n", + "def load_multi_track_acc(path, melody_id):\n", + " orchestration = pyd.PrettyMIDI(path)\n", + " beats = orchestration.get_beats()\n", + " beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2]))\n", + " quantize = scipy.interpolate.interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear')\n", + " quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC)))\n", + " orchestration, prog = midi2matrix(orchestration, quaver)\n", + " orchestration = orchestration[:, :orchestration.shape[1]//16*16]\n", + " orchestration = np.delete(orchestration, melody_id, axis=0)\n", + " \n", + " pitch_hist = np.sum(orchestration[:, :, :120].reshape(len(prog)-1, -1, 10, 12), axis=-2)\n", + " pitch_hist = np.sum(pitch_hist.reshape(len(prog)-1, -1, 16, 12), axis=-2)\n", + " grooves = np.sum(orchestration > 0, axis=-1).reshape(len(prog)-1, -1, 16)\n", + " return pitch_hist, grooves\n", + "\n", + "results = {'Q&A-XL_blur_0': {'p_sim': [], 'p_entro': [], 'g_sim': [], 'g_entro': []}, \\\n", + " 'Q&A-XL_blur_0.5': {'p_sim': [], 'p_entro': [], 'g_sim': [], 'g_entro': []}, \\\n", + " 'Q&A-XL_blur_1': {'p_sim': [], 'p_entro': [], 'g_sim': [], 'g_entro': []}, \\\n", + " 'Q&A': {'p_sim': [], 'p_entro': [], 'g_sim': [], 'g_entro': []}, \\\n", + " 'Arranger-2': {'p_sim': [], 'p_entro': [], 'g_sim': [], 'g_entro': []}\\\n", + " }\n", + "melody_id_dict = {'Q&A-XL_blur_0':-1, 'Q&A-XL_blur_0.5':-1, 'Q&A-XL_blur_1':-1, 'Q&A':-1, 'Arranger-2':0}\n", + "count = 0\n", + "for demo in [1, 2, 3]:\n", + " demo_root = f'../orchestration_with_ablation/demo_{demo}'\n", + " for song in tqdm(os.listdir(demo_root)):\n", + " histo_pno, grooves_pno = load_piano_acc(os.path.join(demo_root, song, 'piano_recon.mid'))\n", + " for model in ['Q&A-XL_blur_0', 'Q&A-XL_blur_0.5', 'Q&A-XL_blur_1', 'Q&A', 'Arranger-2']:\n", + " \n", + " 
histo_orch, grooves_orch = load_multi_track_acc(os.path.join(demo_root, song, f'{model}.mid'), melody_id=melody_id_dict[model])\n", + "\n", + " p_sim = piano_mixture_similarity(histo_pno, np.sum(histo_orch, axis=0))\n", + " p_entro = degree_of_orchestration(histo_pno, histo_orch)\n", + " g_sim = piano_mixture_similarity(grooves_pno, np.sum(grooves_orch, axis=0))\n", + " g_entro = degree_of_orchestration(grooves_pno, grooves_orch)\n", + " results[model]['p_sim'].append(p_sim)\n", + " results[model]['p_entro'].append(p_entro)\n", + " results[model]['g_sim'].append(g_sim)\n", + " results[model]['g_entro'].append(g_entro)\n", + "\n", + "for key in results:\n", + " print(f\"{key}\\t p-Sim={np.mean(results[key]['p_sim']):.5f} + {st.sem(results[key]['p_sim']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_sim'])-1):.5f}\\t\\\n", + " p-Etr={np.mean(results[key]['p_entro']):.5f} + {st.sem(results[key]['p_entro']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_entro'])-1):.5f}\\t\\\n", + " g-Sim={np.mean(results[key]['g_sim']):.5f} + {st.sem(results[key]['g_sim']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['g_sim'])-1):.5f}\\t\\\n", + " g-Etr={np.mean(results[key]['g_entro']):.5f} + {st.sem(results[key]['g_entro']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['g_entro'])-1):.5f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Q&A-XL_blur_0\t p-Sim=0.92118 + 0.00809\t p-Etr=1.63749 + 0.04753\t g-Sim=0.86462 + 0.00964\t g-Etr=1.62582 + 0.04782\n", + "Q&A-XL_blur_0.5\t p-Sim=0.91967 + 0.00775\t p-Etr=1.90828 + 0.03603\t g-Sim=0.80587 + 0.01162\t g-Etr=1.88433 + 0.03582\n", + "Q&A-XL_blur_1\t p-Sim=0.91661 + 0.00792\t p-Etr=2.02799 + 0.03684\t g-Sim=0.78441 + 0.01286\t g-Etr=2.00000 + 0.03614\n", + "Q&A\t p-Sim=0.88574 + 0.00988\t p-Etr=1.63581 + 0.03543\t g-Sim=0.73230 + 0.01491\t g-Etr=1.60909 + 0.03615\n", + "Arranger-2\t p-Sim=0.98862 + 0.00635\t p-Etr=0.34147 + 0.04460\t g-Sim=0.98900 + 0.00637\t g-Etr=0.33764 + 0.04355\n" + ] + } + ], + "source": [ + "for key in results:\n", + " print(f\"{key}\\t p-Sim={np.mean(results[key]['p_sim']):.5f} + {st.sem(results[key]['p_sim']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_sim'])-1):.5f}\\t\\\n", + " p-Etr={np.mean(results[key]['p_entro']):.5f} + {st.sem(results[key]['p_entro']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['p_entro'])-1):.5f}\\t\\\n", + " g-Sim={np.mean(results[key]['g_sim']):.5f} + {st.sem(results[key]['g_sim']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['g_sim'])-1):.5f}\\t\\\n", + " g-Etr={np.mean(results[key]['g_entro']):.5f} + {st.sem(results[key]['g_entro']) * scipy.stats.t.ppf((1 + 0.95) / 2., len(results[key]['g_entro'])-1):.5f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch1.9", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/orchestrator/train_Prior_DDP.py b/orchestrator/train_Prior_DDP.py new file mode 100644 index 0000000..2a22bf0 --- /dev/null +++ 
b/orchestrator/train_Prior_DDP.py @@ -0,0 +1,237 @@ +import os +import time +import torch +from torch import optim +from Prior import Prior +from vq_dataset import VQ_LMD_Dataset, collate_fn +from torch.utils.data import DataLoader +from scheduler import MinExponentialLR, OptimizerScheduler, TeacherForcingScheduler, ConstantScheduler, ParameterScheduler +from utils import SummaryWriters, LogPathManager, epoch_time +from tqdm import tqdm + +import torch.multiprocessing as mp +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed import init_process_group, destroy_process_group + + +def ddp_setup(rank, world_size): + """ + Args: + rank: Unique identifier of each process + world_size: Total number of processes + """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + init_process_group(backend="nccl", rank=rank, world_size=world_size) + + +def main(rank, world_size, log_path_mng, VERBOSE, MODEL_NAME): + #print('rank:', rank) + ddp_setup(rank, world_size) + + PRETRAIN_PATH = "data_file_dir/params_qa.pt" + BATCH_SIZE = 32 + N_EPOCH = 30 + CLIP = 3 + LR = 1e-3 + + if VERBOSE: + N_EPOCH=10 + + model = Prior.init_model(pretrain_model_path=PRETRAIN_PATH, DEVICE=rank) + model = DDP(model, device_ids=[rank], find_unused_parameters=False) + + lmd_dir = "/data1/LMD/vector_quantization_029/" + train_set = VQ_LMD_Dataset(lmd_dir, debug_mode=VERBOSE, split='train', mode='train') + train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda b: collate_fn(b, rank), sampler=DistributedSampler(train_set)) + val_set = VQ_LMD_Dataset(lmd_dir, debug_mode=VERBOSE, split='validation', mode='train') + val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda b: collate_fn(b, rank), sampler=DistributedSampler(val_set)) + print(f'Dataset loaded. 
{len(train_loader)} samples for train and {len(val_loader)} samples for validation.') + + + optimizer = optim.Adam(model.parameters(), lr=LR) + scheduler = MinExponentialLR(optimizer, gamma=0.99996, minimum=1e-5) + #scheduler = None + optimizer_scheduler = OptimizerScheduler(optimizer, scheduler, CLIP) + #tfr_scheduler = TeacherForcingScheduler(*TFR, scaler=N_EPOCH*len(train_loader)) + #params_dic = dict(tfr=tfr_scheduler) + param_scheduler = None #ParameterScheduler(**params_dic) + + writer_names = ['loss', 'fp_l', 'ft_l'] + scheduler_writer_names = ['lr'] + + if rank == 0: + tags = {'loss': None} + loss_writers = SummaryWriters(writer_names, tags, log_path_mng.writer_path) + tags = {'scheduler': None} + scheduler_writers = SummaryWriters(scheduler_writer_names, tags, log_path_mng.writer_path) + else: + loss_writers = None + scheduler_writers = None + + + #best_valid_loss = float('inf') + for n_epoch in range(N_EPOCH): + start_time = time.time() + train_loader.sampler.set_epoch(n_epoch) + print(f'Training epoch {n_epoch}') + train_loss = train(model, train_loader, param_scheduler, optimizer_scheduler, writer_names, loss_writers, scheduler_writers, n_epoch=n_epoch, VERBOSE=VERBOSE)['loss'] + print(f'Validating epoch {n_epoch}') + val_loss = val(model, val_loader, param_scheduler, writer_names, loss_writers, n_epoch=n_epoch, VERBOSE=VERBOSE)['loss'] + end_time = time.time() + + if rank == 0: + torch.save(model.module.state_dict(), log_path_mng.epoch_model_path(f'{MODEL_NAME}_{str(n_epoch).zfill(3)}')) + + #if val_loss < best_valid_loss: + # best_valid_loss = val_loss + # if rank == 0: + # torch.save(model.module.state_dict(), log_path_mng.valid_model_path(MODEL_NAME)) + + epoch_report(start_time, end_time, train_loss, val_loss, n_epoch) + #if rank == 0: + # torch.save(model.module.state_dict(), log_path_mng.final_model_path(MODEL_NAME)) + + destroy_process_group() + + + +def accumulate_loss_dic(writer_names, loss_dic, loss_items): + assert len(writer_names) == len(loss_items) + for key, val in zip(writer_names, loss_items): + loss_dic[key] += val.item() + return loss_dic + +def write_loss_to_dic(writer_names, loss_items): + loss_dic = {} + assert len(writer_names) == len(loss_items) + for key, val in zip(writer_names, loss_items): + loss_dic[key] = val.item() + return loss_dic + +def init_loss_dic(writer_names): + loss_dic = {} + for key in writer_names: + loss_dic[key] = 0. 
+ return loss_dic + +def average_epoch_loss(epoch_loss_dict, num_batch): + for key in epoch_loss_dict: + epoch_loss_dict[key] /= num_batch + return epoch_loss_dict + + +def batch_report(loss, n_epoch, idx, num_batch, mode='training', verbose=False): + if verbose: + print(f'------------{mode}------------') + print('Epoch: [{0}][{1}/{2}]'.format(n_epoch, idx, num_batch)) + print(f"\t Total loss: {loss['loss']}") + print(f"\t pitch func loss: {loss['fp_l']:.3f}") + print(f"\t time func loss: {loss['ft_l']:.3f}") + + +def scheduler_show(param_scheduler, optimizer_scheduler, verbose=False): + schedule_params = {} + #schedule_params['tfr'] = param_scheduler.schedulers['tfr'].get_tfr() + schedule_params['lr'] = optimizer_scheduler.optimizer.param_groups[0]['lr'] + if verbose: + print(schedule_params) + return schedule_params + + +def train(model, dataloader, param_scheduler, optimizer_scheduler, writer_names, loss_writers, scheduler_writers, n_epoch, VERBOSE): + model.train() + #param_scheduler.train() + epoch_loss_dic = init_loss_dic(writer_names) + + for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + try: + optimizer_scheduler.optimizer_zero_grad() + + #input_params = param_scheduler.step() + outputs = model('loss', *batch)#, **input_params) + loss = outputs[0] + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), optimizer_scheduler.clip) + optimizer_scheduler.step() + + epoch_loss_dic = accumulate_loss_dic(writer_names, epoch_loss_dic, outputs) + batch_loss_dic = write_loss_to_dic(writer_names, outputs) + train_step = n_epoch * len(dataloader) + idx + if loss_writers is not None: + loss_writers.write_task('train', batch_loss_dic, train_step) + batch_report(batch_loss_dic, n_epoch, idx, len(dataloader), mode='train', verbose=VERBOSE) + + scheduler_dic = scheduler_show(param_scheduler, optimizer_scheduler, verbose=VERBOSE) + if scheduler_writers is not None: + scheduler_writers.write_task('train', scheduler_dic, train_step) + except Exception as exc: + print(exc) + print(batch[0].shape, batch[1].shape) + continue + + scheduler_show(param_scheduler, optimizer_scheduler, verbose=True) + epoch_loss_dic = average_epoch_loss(epoch_loss_dic, len(dataloader)) + return epoch_loss_dic + + +def val(model, dataloader, param_scheduler, writer_names, summary_writers, n_epoch, VERBOSE): + model.eval() + #param_scheduler.eval() + epoch_loss_dic = init_loss_dic(writer_names) + + for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + try: + #input_params = param_scheduler.step() + with torch.no_grad(): + outputs = model('loss', *batch)#, **input_params) + epoch_loss_dic = accumulate_loss_dic(writer_names, epoch_loss_dic, outputs) + batch_loss_dic = write_loss_to_dic(writer_names, outputs) + if summary_writers is not None: + batch_report(batch_loss_dic, n_epoch, idx, len(dataloader), mode='validation', verbose=VERBOSE) + #val_step = n_epoch * len(dataloader) + idx + #summary_writers.write_task('val', batch_loss_dic, val_step) + except Exception as exc: + print(exc) + print(batch[0].shape, batch[1].shape) + continue + epoch_loss_dic = average_epoch_loss(epoch_loss_dic, len(dataloader)) + if summary_writers is not None: + summary_writers.write_task('val', epoch_loss_dic, n_epoch) + return epoch_loss_dic + + +def epoch_report(start_time, end_time, train_loss, valid_loss, n_epoch): + epoch_mins, epoch_secs = epoch_time(start_time, end_time) + print(f'Epoch: {n_epoch + 1:02} | ' + f'Time: {epoch_mins}m {epoch_secs}s', + flush=True) + print(f'\tTrain Loss: 
{train_loss:.3f}', flush=True) + print(f'\t Valid. Loss: {valid_loss:.3f}', flush=True) + + + + + +if __name__ == '__main__': + os.environ['CUDA_VISIBLE_DEVICES']= '2, 3' + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + + MODEL_NAME = 'Prior Model' + DEBUG = 0 + + if DEBUG: + save_root = './save' + log_path_name = 'debug' + else: + save_root = '/data1/AccoMontage3/' + log_path_name = MODEL_NAME + + + readme_fn = 'orchestrator/train_Prior_DDP.py' + log_path_mng = LogPathManager(readme_fn, save_root=save_root, log_path_name=log_path_name) + + world_size = torch.cuda.device_count() + #print(world_size) + mp.spawn(main, args=(world_size, log_path_mng, DEBUG, MODEL_NAME), nprocs=world_size) diff --git a/orchestrator/train_QandA.py b/orchestrator/train_QandA.py new file mode 100644 index 0000000..69edfce --- /dev/null +++ b/orchestrator/train_QandA.py @@ -0,0 +1,262 @@ +import os +os.environ['CUDA_VISIBLE_DEVICES']= '3' +os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +import time +import torch +from torch import optim +from QandA import QandA +from dataset import Slakh_Pop909_Dataset, collate_fn +from torch.utils.data import DataLoader +from scheduler import MinExponentialLR, OptimizerScheduler, TeacherForcingScheduler, ConstantScheduler, ParameterScheduler +from utils import kl_anealing, SummaryWriters, LogPathManager, epoch_time +from tqdm import tqdm + + +DEVICE = 'cuda:0' +PARALLEL = False +BATCH_SIZE = 128 +TRF_LAYERS = 2 +N_EPOCH = 30 +CLIP = 3 +WEIGHTS = [1, 1] +BETA = 1e-2 +TFR = [(0.6, 0), (0.5, 0), (0.5, 0)] +LR = 1e-3 + +MODEL_NAME = 'VQ-Q&A' +DEBUG = 0 + + +model = QandA(name=MODEL_NAME, trf_layers=TRF_LAYERS, device=DEVICE) +if PARALLEL: + model = torch.nn.DataParallel(model, device_ids=[0, 1]) +model.to(DEVICE) + + +if DEBUG: + save_root = './save' + log_path_name = 'debug' + VERBOSE = True +else: + save_root = "/data1/AccoMontage3/" + log_path_name = MODEL_NAME + VERBOSE = False + + +slakh_dir = '/data1/Q&A/slakh2100_flac_redux/4_bin_quantization/' +pop909_dir = '/data1/Q&A/POP909-Dataset/quantization/POP09-PIANOROLL-4-bin-quantization/' +train_set = Slakh_Pop909_Dataset(slakh_dir, pop909_dir, hop_len=1, debug_mode=DEBUG, split='train', mode='train') +train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda b: collate_fn(b, DEVICE)) +val_set = Slakh_Pop909_Dataset(slakh_dir, pop909_dir, hop_len=2, debug_mode=DEBUG, split='validation', mode='train') +val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda b: collate_fn(b, DEVICE, pitch_shift=False)) + +print(f'Dataset loaded. 
{len(train_loader)} samples for train and {len(val_loader)} samples for validation.') + +optimizer = optim.Adam(model.parameters(), lr=LR) +scheduler = MinExponentialLR(optimizer, gamma=0.9999, minimum=1e-5) +#scheduler = None +optimizer_scheduler = OptimizerScheduler(optimizer, scheduler, CLIP) +tfr1_scheduler = TeacherForcingScheduler(*TFR[0], scaler=N_EPOCH*len(train_loader)) +tfr2_scheduler = TeacherForcingScheduler(*TFR[1], scaler=N_EPOCH*len(train_loader)) +tfr3_scheduler = TeacherForcingScheduler(*TFR[2], scaler=N_EPOCH*len(train_loader)) +weights_scheduler = ConstantScheduler(WEIGHTS) +beta_scheduler = TeacherForcingScheduler(BETA, 0, scaler=N_EPOCH*len(train_loader), f=kl_anealing) + +params_dic = dict(tfr1=tfr1_scheduler, tfr2=tfr2_scheduler, + tfr3=tfr3_scheduler, + beta=beta_scheduler, weights=weights_scheduler) +param_scheduler = ParameterScheduler(**params_dic) + + +readme_fn = 'orchestrator/train_QandA.py' +log_path_mng = LogPathManager(readme_fn, save_root=save_root, log_path_name=log_path_name) + + +writer_names = ['loss', 'pno_tree_l', 'pl', 'dl', \ + 'kl_l', 'kl_sym', 'kl_trf', \ + 'feat_l', 'onset_l', 'intensity_l', 'center_l', \ + 'func_l', 'fp_l', 'ft_l', 'cmt_p', 'cmt_t', 'plty_p', 'plty_t'] + +tags = {'loss': None} +loss_writers = SummaryWriters(writer_names, tags, log_path_mng.writer_path) + +scheduler_writer_names = ['tfr1', 'tfr2', 'tfr3', 'beta', 'lr'] +tags = {'scheduler': None} +scheduler_writers = SummaryWriters(scheduler_writer_names, tags, log_path_mng.writer_path) +deadcode_writer_names = ['fp_usage', 'ft_usage', 'fp_dead', 'ft_dead'] +tags = {'deadcode': None} +deadcode_writers = SummaryWriters(deadcode_writer_names, tags, log_path_mng.writer_path) + + +def accumulate_loss_dic(writer_names, loss_dic, loss_items): + assert len(writer_names) == len(loss_items) + for key, val in zip(writer_names, loss_items): + loss_dic[key] += val.item() + return loss_dic + +def write_loss_to_dic(writer_names, loss_items): + loss_dic = {} + assert len(writer_names) == len(loss_items) + for key, val in zip(writer_names, loss_items): + loss_dic[key] = val.item() + return loss_dic + +def init_loss_dic(writer_names): + loss_dic = {} + for key in writer_names: + loss_dic[key] = 0. 
+ return loss_dic + +def average_epoch_loss(epoch_loss_dict, num_batch): + for key in epoch_loss_dict: + epoch_loss_dict[key] /= num_batch + return epoch_loss_dict + + +def batch_report(loss, n_epoch, idx, num_batch, mode='training', verbose=False): + if verbose: + print(f'------------{mode}------------') + print('Epoch: [{0}][{1}/{2}]'.format(n_epoch, idx, num_batch)) + print(f"\t Total loss: {loss['loss']}") + print(f"\t Pitch loss: {loss['pl']:.3f}") + print(f"\t Duration loss: {loss['dl']:.3f}") + print(f"\t Feature loss [onset/intensity/center]: {loss['onset_l']:.3f}/{loss['intensity_l']:.3f}/{loss['center_l']:.3f}") + print(f"\t KL loss [sym/trf]: {loss['kl_sym']:.3f}/{loss['kl_trf']:.3f}") + print(f"\t Function loss [pitch/time]: {loss['fp_l']:.6f}/{loss['ft_l']:.6f}") + print(f"\t Commitment loss [pitch/time]: {loss['cmt_p']:.6f}/{loss['cmt_t']:.6f}") + print(f"\t Perplexity [pitch/time]: {loss['plty_p']:.6f}/{loss['plty_t']:.6f}") + + +def scheduler_show(param_scheduler, optimizer_scheduler, verbose=False): + schedule_params = {} + schedule_params['tfr1'] = param_scheduler.schedulers['tfr1'].get_tfr() + schedule_params['tfr2'] = param_scheduler.schedulers['tfr2'].get_tfr() + schedule_params['tfr3'] = param_scheduler.schedulers['tfr3'].get_tfr() + schedule_params['beta'] = param_scheduler.schedulers['beta'].get_tfr() + schedule_params['lr'] = optimizer_scheduler.optimizer.param_groups[0]['lr'] + if verbose: + print(schedule_params) + return schedule_params + + +def train(model, dataloader, param_scheduler, device, optimizer_scheduler, writer_names, loss_writers, scheduler_writers, n_epoch): + model.train() + param_scheduler.train() + epoch_loss_dic = init_loss_dic(writer_names) + + for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + #try: + optimizer_scheduler.optimizer_zero_grad() + + input_params = param_scheduler.step() + outputs = model('loss', *batch, **input_params) + if PARALLEL: + outputs = tuple([x.mean() for x in outputs]) + loss = outputs[0] + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), optimizer_scheduler.clip) + optimizer_scheduler.step() + + epoch_loss_dic = accumulate_loss_dic(writer_names, epoch_loss_dic, outputs) + batch_loss_dic = batch_loss_dic = write_loss_to_dic(writer_names, outputs) + train_step = n_epoch * len(dataloader) + idx + loss_writers.write_task('train', batch_loss_dic, train_step) + + scheduler_dic = scheduler_show(param_scheduler, optimizer_scheduler, verbose=VERBOSE) + scheduler_writers.write_task('train', scheduler_dic, train_step) + + batch_report(batch_loss_dic, n_epoch, idx, len(dataloader), mode='train', verbose=VERBOSE) + + + if (idx!=0) and (idx % (len(dataloader) // 10)) == 0: + batch_z = model.func_time_enc.batch_z[torch.logical_not(batch[4].reshape(-1))] + batch_z = batch_z.reshape(-1, batch_z.shape[-1]) + ft_usage, ft_deadcode = model.func_time_enc.vq_quantizer.random_restart(batch_z) + model.func_time_enc.vq_quantizer.reset_usage() + #if (idx!=0) and (idx % (len(dataloader) // 10)) == 0: + batch_z = model.func_pitch_enc.batch_z[torch.logical_not(batch[4].reshape(-1))] + fp_usage, fp_deadcode = model.func_pitch_enc.vq_quantizer.random_restart(batch_z) + model.func_pitch_enc.vq_quantizer.reset_usage() + deadcode_writers.write_task('val', dict({'fp_usage': fp_usage, 'ft_usage': ft_usage, 'fp_dead': fp_deadcode, 'ft_dead': ft_deadcode}), train_step) + if VERBOSE: + print(f'\t code usage [fp/ft]: {fp_usage}/{ft_usage}', flush=True) + print(f'\t dead code [fp/ft]: {fp_deadcode}/{ft_deadcode}', 
flush=True) + + + #except Exception as exc: + # print(exc) + # continue + + scheduler_show(param_scheduler, optimizer_scheduler, verbose=True) + epoch_loss_dic = average_epoch_loss(epoch_loss_dic, len(dataloader)) + + + return epoch_loss_dic + + +def val(model, dataloader, param_scheduler, device, writer_names, summary_writers, deadcode_writers, n_epoch): + model.eval() + param_scheduler.eval() + epoch_loss_dic = init_loss_dic(writer_names) + + for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + try: + input_params = param_scheduler.step() + with torch.no_grad(): + outputs = model('loss', *batch, **input_params) + if PARALLEL: + outputs = tuple([x.mean() for x in outputs]) + #loss = outputs[0] + + epoch_loss_dic = accumulate_loss_dic(writer_names, epoch_loss_dic, outputs) + batch_loss_dic = write_loss_to_dic(writer_names, outputs) + batch_report(batch_loss_dic, n_epoch, idx, len(dataloader), mode='validation', verbose=VERBOSE) + #val_step = n_epoch * len(dataloader) + idx + #summary_writers.write_task('val', batch_loss_dic, val_step) + except Exception as exc: + print(exc) + continue + + epoch_loss_dic = average_epoch_loss(epoch_loss_dic, len(dataloader)) + summary_writers.write_task('val', epoch_loss_dic, n_epoch) + + return epoch_loss_dic + + +def epoch_report(start_time, end_time, train_loss, valid_loss, n_epoch): + epoch_mins, epoch_secs = epoch_time(start_time, end_time) + print(f'Epoch: {n_epoch + 1:02} | ' + f'Time: {epoch_mins}m {epoch_secs}s', + flush=True) + print(f'\tTrain Loss: {train_loss:.3f}', flush=True) + print(f'\t Valid. Loss: {valid_loss:.3f}', flush=True) + + +best_valid_loss = float('inf') +for n_epoch in range(N_EPOCH): + start_time = time.time() + print(f'Training epoch {n_epoch}') + train_loss = train(model, train_loader, param_scheduler, DEVICE, optimizer_scheduler, writer_names, loss_writers, scheduler_writers, n_epoch)['loss'] + print(f'Validating epoch {n_epoch}') + val_loss = val(model, val_loader, param_scheduler, DEVICE, writer_names, loss_writers, deadcode_writers, n_epoch)['loss'] + end_time = time.time() + + if PARALLEL: + torch.save(model.module.state_dict(), log_path_mng.epoch_model_path(f'{MODEL_NAME}_{str(n_epoch).zfill(3)}')) + else: + torch.save(model.state_dict(), log_path_mng.epoch_model_path(f'{MODEL_NAME}_{str(n_epoch).zfill(3)}')) + + if val_loss < best_valid_loss: + best_valid_loss = val_loss + if PARALLEL: + torch.save(model.module.state_dict(), log_path_mng.valid_model_path(MODEL_NAME)) + else: + torch.save(model.state_dict(), log_path_mng.valid_model_path(MODEL_NAME)) + + epoch_report(start_time, end_time, train_loss, val_loss, n_epoch) +if PARALLEL: + torch.save(model.module.state_dict(), log_path_mng.final_model_path(MODEL_NAME)) +else: + torch.save(model.state_dict(), log_path_mng.final_model_path(MODEL_NAME)) + + diff --git a/orchestrator/utils.py b/orchestrator/utils.py new file mode 100644 index 0000000..56d05e4 --- /dev/null +++ b/orchestrator/utils.py @@ -0,0 +1,383 @@ +import os +import datetime +import shutil +import torch +import numpy as np +from torch.distributions import Normal, kl_divergence +from torch.utils.tensorboard import SummaryWriter +import pretty_midi as pyd +from scipy.interpolate import interp1d +from scipy import stats as st + + + + +def get_zs_from_dists(dists, sample=False): + return [dist.rsample() if sample else dist.mean for dist in dists] + + +def scheduled_sampling(i, high=0.7, low=0.05, scaler=1e5): + x = 10 * (i - 0.5) + z = 1 / (1 + np.exp(x)) + y = (high - low) * z + low + return y 
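+ +#Note: scheduled_sampling above and kl_anealing below are plain sigmoid ramps over a normalized progress value i in [0, 1]. +#With the defaults, scheduled_sampling(0.) returns roughly 0.70 (near `high`) and scheduled_sampling(1.) decays to roughly 0.05 (near `low`); the `scaler` argument is not used inside the function body. +#kl_anealing inverts the same curve, ramping the KL weight from `low` up to `high` as training progresses.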
+ + +def kl_anealing(i, high=0.1, low=0., scaler=None): + hh = 1 - low + ll = 1 - high + x = 10 * (i - 0.5) + z = 1 / (1 + np.exp(x)) + y = (hh - ll) * z + ll + return 1 - y + +def standard_normal(shape, device): + N = Normal(torch.zeros(shape), torch.ones(shape)) + #if torch.cuda.is_available(): + N.loc = N.loc.to(device) + N.scale = N.scale.to(device) + return N + +def kl_with_normal(dist): + shape = dist.mean.size(-1) + normal = standard_normal(shape, dist.mean.device) + kl = kl_divergence(dist, normal).mean() + return kl + + +class SummaryWriters: + + def __init__(self, writer_names, tags, log_path, tasks=('train', 'val')): + # writer_names example: ['loss', 'kl_loss', 'recon_loss'] + # tags example: {'name1': None, 'name2': (0, 1)} + self.log_path = log_path + #assert 'loss' == writer_names[0] + self.writer_names = writer_names + self.tags = tags + self._regularize_tags() + + writer_dic = {} + for name in writer_names: + writer_dic[name] = SummaryWriter(os.path.join(log_path, name)) + self.writers = writer_dic + + all_tags = {} + for task in tasks: + task_dic = {} + for key, val in self.tags.items(): + task_dic['_'.join([task, key])] = val + all_tags[task] = task_dic + self.all_tags = all_tags + + def _init_summary_writer(self): + tags = {'batch_train': (0, 1, 2, 3, 4)} + self.summary_writers = SummaryWriters(self.writer_names, tags, + self.writer_path) + + def _regularize_tags(self): + for key, val in self.tags.items(): + if val is None: + self.tags[key] = tuple(range(len(self.writer_names))) + + def single_write(self, name, tag, val, step): + self.writers[name].add_scalar(tag, val, step) + + def write_tag(self, task, tag, vals, step): + assert len(vals) == len(self.all_tags[task][tag]) + for name_id, val in zip(self.all_tags[task][tag], vals): + name = self.writer_names[name_id] + self.single_write(name, tag, val, step) + + def write_task(self, task, vals_dic, step): + for tag, name_ids in self.all_tags[task].items(): + vals = [vals_dic[self.writer_names[i]] for i in name_ids] + self.write_tag(task, tag, vals, step) + + +def join_fn(*items, ext='pt'): + return '.'.join(['_'.join(items), ext]) + + +class LogPathManager: + + def __init__(self, readme_fn=None, save_root='.', log_path_name='result', + with_date=True, with_time=True, + writer_folder='writers', model_folder='models'): + date = str(datetime.date.today()) if with_date else '' + ctime = datetime.datetime.now().time().strftime("%H%M%S") \ + if with_time else '' + log_folder = '_'.join([date, ctime, log_path_name]) + log_path = os.path.join(save_root, log_folder) + writer_path = os.path.join(log_path, writer_folder) + model_path = os.path.join(log_path, model_folder) + self.log_path = log_path + self.writer_path = writer_path + self.model_path = model_path + LogPathManager.create_path(log_path) + LogPathManager.create_path(writer_path) + LogPathManager.create_path(model_path) + if readme_fn is not None: + shutil.copyfile(readme_fn, os.path.join(log_path, 'readme.txt')) + + @staticmethod + def create_path(path): + if not os.path.exists(path): + os.makedirs(path) + + def epoch_model_path(self, model_name): + model_fn = join_fn(model_name, 'epoch', ext='pt') + return os.path.join(self.model_path, model_fn) + + def valid_model_path(self, model_name): + model_fn = join_fn(model_name, 'valid', ext='pt') + return os.path.join(self.model_path, model_fn) + + def final_model_path(self, model_name): + model_fn = join_fn(model_name, 'final', ext='pt') + return os.path.join(self.model_path, model_fn) + +def epoch_time(start_time, 
end_time): + elapsed_time = end_time - start_time + elapsed_mins = int(elapsed_time / 60) + elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) + return elapsed_mins, elapsed_secs + + +def matrix2midi(pr_matrices, programs, init_tempo=120, time_start=0): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track, Time, 128). + """ + ACC = 16 + tracks = [] + for program in programs: + track_recon = pyd.Instrument(program=int(program), is_drum=False, name=pyd.program_to_instrument_name(int(program))) + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices) + alpha = 1 / (ACC // 4) * 60 / init_tempo #timestep between each quantization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch] * alpha + velocity = 100 + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + + midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + midi_recon.instruments = tracks + return midi_recon + +def grid2pr(grid, max_note_count=16, min_pitch=0, pitch_eos_ind=129): + #grid: (time, max_simu_note, 6) + if grid.shape[1] == max_note_count: + grid = grid[:, 1:] + pr = np.zeros((grid.shape[0], 128), dtype=int) + for t in range(grid.shape[0]): + for n in range(grid.shape[1]): + note = grid[t, n] + if note[0] == pitch_eos_ind: + break + pitch = note[0] + min_pitch + dur = int(''.join([str(_) for _ in note[1:]]), 2) + 1 + pr[t, pitch] = dur + return pr + +def midi2matrix(midi, quaver): + pr_matrices = [] + programs = [] + for track in midi.instruments: + programs.append(track.program) + pr_matrix = np.zeros((len(quaver), 128)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) + pr_matrix[note_start, note.pitch] = note_end - note_start + pr_matrices.append(pr_matrix) + return np.array(pr_matrices), np.array(programs) + + +def midi2matrix_with_dynamics(midi, quaver): + """ + Convert multi-track midi to a matrix of shape (Track, Time, 128, 3). + Each cell is an integer number representing quantized duration. + """ + pr_matrices = [] + programs = [] + quantization_error = [] + for track in midi.instruments: + qt_error = [] # record quantization error + pr_matrix = np.zeros((len(quaver), 128, 2)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) # guitar/bass pluck typically results in a very short note duration. These notes should be quantized to 1 instead of 0. + pr_matrix[note_start, note.pitch, 0] = note_end - note_start + pr_matrix[note_start, note.pitch, 1] = note.velocity + + #compute quantization error. A song with very high error (e.g., triple-quaver songs) will be identified and therefore discarded. 
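+ #the error below is the onset's offset from its snapped grid point, normalized by one grid step when the note collapses to a single bin, and by the quantized note span otherwise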
+ if note_end == note_start: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_start] - quaver[note_start-1])) + else: + qt_error.append(np.abs(quaver[note_start] - note.start) / (quaver[note_end] - quaver[note_start])) + + control_matrix = np.ones((len(quaver), 128, 1)) * -1 + for control in track.control_changes: + #if control.time < time_end: + # if len(quaver) == 0: + # continue + control_time = np.argmin(np.abs(quaver - control.time)) + control_matrix[control_time, control.number, 0] = control.value + + pr_matrix = np.concatenate((pr_matrix, control_matrix), axis=-1) + pr_matrices.append(pr_matrix) + programs.append(track.program) + quantization_error.append(np.mean(qt_error)) + + return np.array(pr_matrices), programs, quantization_error + + +def pr2grid(pr_mat, max_note_count=16, max_pitch=127, min_pitch=0, + pitch_pad_ind=130, dur_pad_ind=2, + pitch_sos_ind=128, pitch_eos_ind=129): + pr_mat3d = np.ones((len(pr_mat), max_note_count, 6), dtype=int) * dur_pad_ind + pr_mat3d[:, :, 0] = pitch_pad_ind + pr_mat3d[:, 0, 0] = pitch_sos_ind + cur_idx = np.ones(len(pr_mat), dtype=int) + for t, p in zip(*np.where(pr_mat != 0)): + pr_mat3d[t, cur_idx[t], 0] = p - min_pitch + binary = np.binary_repr(min(int(pr_mat[t, p]), 32) - 1, width=5) + pr_mat3d[t, cur_idx[t], 1: 6] = \ + np.fromstring(' '.join(list(binary)), dtype=int, sep=' ') + if cur_idx[t] == max_note_count-1: + continue + cur_idx[t] += 1 + #print(cur_idx) + pr_mat3d[np.arange(0, len(pr_mat)), cur_idx, 0] = pitch_eos_ind + return pr_mat3d + + + + +def matrix2midi_with_dynamics(pr_matrices, programs, init_tempo=120, time_start=0, ACC=16): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track. Time, 128, 3). + """ + tracks = [] + for program in programs: + track_recon = pyd.Instrument(program=int(program), is_drum=False, name=pyd.program_to_instrument_name(int(program))) + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices[:, :, :, 0]) + alpha = 1 / (ACC // 4) * 60 / init_tempo #timetep between each quntization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch, 0] * alpha + velocity = pr_matrices[track_id, onset, pitch, 1] + #if (int(velocity) > 127) or (int(velocity) < 0): + # print(velocity) + #if int(pitch) > 127 or (int(pitch) < 0): + # print(pitch) + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + + for idx in range(len(pr_matrices)): + cc = [] + control_matrix = pr_matrices[idx, :, :, 2] + for t, n in zip(*np.nonzero(control_matrix >= 0)): + start = alpha * t + #if int(n) > 127 or (int(n) < 0): + # print(n) + #if int(control_matrix[t, n]) > 127 or (int(control_matrix[t, n]) < 0): + # print(int(control_matrix[t, n])) + cc.append(pyd.ControlChange(int(n), int(control_matrix[t, n]), start)) + tracks[idx].control_changes = cc + + midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + midi_recon.instruments = tracks + return midi_recon + + +def matrix2midi_drum(pr_matrices, programs, init_tempo=120, time_start=0, ACC=64): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track. Time, 128, 3). 
+ """ + tracks = [] + for pogram in range(len(programs)): + track_recon = pyd.Instrument(program=int(pogram), is_drum=True, name='drums') + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices[:, :, :, 0]) + alpha = 1 / (ACC // 4) * 60 / init_tempo #timetep between each quntization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch, 0] * alpha + velocity = pr_matrices[track_id, onset, pitch, 1] + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + + for idx in range(len(pr_matrices)): + cc = [] + control_matrix = pr_matrices[idx, :, :, 2] + for t, n in zip(*np.nonzero(control_matrix > -1)): + start = alpha * t + cc.append(pyd.ControlChange(int(n), int(control_matrix[t, n]), start)) + tracks[idx].control_changes = cc + + #midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + #midi_recon.instruments = tracks + return tracks + +def retrieve_control(pop909_dir, song, tracks): + src_dir = os.path.join(pop909_dir.replace('quantization/POP09-PIANOROLL-4-bin-quantization/', 'POP909/'), song.split('.')[0], song.replace('.npz', '.mid')) + src_midi = pyd.PrettyMIDI(src_dir) + beats = src_midi.get_beats() + beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) + ACC = 4 + quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear') + quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC))) + + pr_matrices, programs, _ = midi2matrix_with_dynamics(src_midi, quaver) + + from_4_bin = np.nonzero(tracks[0, :, :, 0]) + from_midi = np.nonzero(pr_matrices[0, :, :, 0]) + + mid_length = min(from_midi[1].shape[0], from_4_bin[1].shape[0]) + diff = from_midi[1][: mid_length] - from_4_bin[1][: mid_length] + diff_avg = np.mean(diff) + diff_std = np.std(diff) + + if diff_std > 0: + diff_record = [] + for roll_idx in range(-32, 32): + roll_pitches = np.roll(from_midi[1], shift=roll_idx, axis=0) + diff = roll_pitches[abs(roll_idx): mid_length-abs(roll_idx)] - from_4_bin[1][abs(roll_idx): mid_length-abs(roll_idx)] + diff_avg = np.mean(diff) + diff_std = np.std(diff) + diff_record.append((roll_idx, diff_avg, diff_std)) + diff_record = sorted(diff_record, key=lambda x: x[2]) + + roll_idx_min = diff_record[0][0] + roll_times = np.roll(from_midi[0], shift=roll_idx_min, axis=0) + diff = roll_times[abs(roll_idx_min): mid_length-abs(roll_idx_min)] - from_4_bin[0][abs(roll_idx_min): mid_length-abs(roll_idx_min)] + #print(st.mode(diff).mode[0], st.mode(diff).count[0]/diff.shape[0]) + else: + diff = from_midi[0][: mid_length] - from_4_bin[0][: mid_length] + #print(st.mode(diff).mode[0], st.mode(diff).count[0]/diff.shape[0]) + return pr_matrices[:, :, :, 2: 3], st.mode(diff).mode[0] \ No newline at end of file diff --git a/orchestrator/vq_dataset.py b/orchestrator/vq_dataset.py new file mode 100644 index 0000000..b6c3d55 --- /dev/null +++ b/orchestrator/vq_dataset.py @@ -0,0 +1,149 @@ +import os +import numpy as np +import pretty_midi as pyd +from torch.utils.data import Dataset +from tqdm import tqdm +import torch + +SAMPLE_LEN = 16 #16 codes per sample sequence, where each codes represents 2-bar info +HOP_LEN = 4 +NUM_INSTR_CLASS = 34 +NUM_PITCH_CODE = 64 +NUM_TIME_CODE = 128 +TOTAL_LEN_BIN = np.array([4, 7, 12, 15, 20, 23, 28, 31, 36, 39, 44, 47, 52, 55, 60, 63, 68, 71, 76, 
79, 84, 87, 92, 95, 100, 103, 108, 111, 116, 119, 124, 127, 132]) +ABS_POS_BIN = np.arange(129) +REL_POS_BIN = np.arange(128) + + +class VQ_LMD_Dataset(Dataset): + def __init__(self, lmd_dir, debug_mode=False, split='train', mode='train'): + super(VQ_LMD_Dataset, self).__init__() + self.lmd_dir = lmd_dir + self.split = split + self.mode = mode + self.debug_mode = debug_mode + self.mixture_list = [] + self.program_list = [] + self.func_pitch_list = [] + self.func_time_list = [] + self.anchor_list = [] + + print('loading LMD Dataset ...') + self.load_lmd() + + def __len__(self): + return len(self.anchor_list) + + def __getitem__(self, idx): + song_id, start, total_len = self.anchor_list[idx] + mix = self.mixture_list[song_id][start: min(total_len, start+SAMPLE_LEN)] + prog = self.program_list[song_id] + fp = self.func_pitch_list[song_id][start: min(total_len, start+SAMPLE_LEN)] + ft = self.func_time_list[song_id][start: min(total_len, start+SAMPLE_LEN)] + return mix, prog, fp, ft, (start, total_len) + + + + def load_lmd(self): + lmd_list = os.listdir(self.lmd_dir) + if self.split == 'train': + lmd_list = lmd_list[: int(len(lmd_list)*.95)] + elif self.split == 'validation': + lmd_list = lmd_list[int(len(lmd_list)*.95): ] + if self.debug_mode: + lmd_list = lmd_list[: 1000] + for song in tqdm(lmd_list): + lmd_data = np.load(os.path.join(self.lmd_dir, song)) + mix = lmd_data['mixture'] #(num2bar, 256) 3 for duration, velocity, and control + prog = lmd_data['programs'] #(track) + if len(prog) > 25: + continue #for sake of computing memory + fp = lmd_data['func_pitch'] #(num2bar, track) + ft = lmd_data['func_time'] #(num2bar, track, 8) + + if self.split == 'train': + for i in range(0, len(mix), HOP_LEN): + if i + SAMPLE_LEN >= len(mix): + break + self.anchor_list.append((len(self.mixture_list), i, len(mix))) #(song_id, start, total_length) + else: + for i in range(0, len(mix), SAMPLE_LEN): + if i + SAMPLE_LEN >= len(mix): + break + self.anchor_list.append((len(self.mixture_list), i, len(mix))) #(song_id, start, total_length) + self.anchor_list.append((len(self.mixture_list), max(0, len(mix)-SAMPLE_LEN), len(mix))) + + self.mixture_list.append(mix) + self.program_list.append(prog) + self.func_pitch_list.append(fp) + self.func_time_list.append(ft) + + + +def collate_fn(batch, device): + max_dur = max([len(item[0]) for item in batch]) + max_tracks = max([len(item[1]) for item in batch]) + + mixture = [] + programs = [] + func_pitch = [] + func_time = [] + time_mask = [] + track_mask = [] + total_length = [] + abs_pos = [] + rel_pos = [] + + for mix, prog, fp, ft, (start, total_len) in batch: + #print(mix.shape, prog.shape, fp.shape, ft.shape, start, total_len) + time_mask.append([0]*len(mix) + [1]*(max_dur-len(mix))) + track_mask.append([0]*len(prog) + [1]*(max_tracks-len(prog))) + + r_pos = np.round(np.arange(start, start+len(mix), 1) / (total_len-1) * len(REL_POS_BIN)) + total_len = np.argmin(np.abs(TOTAL_LEN_BIN - total_len)).repeat(len(mix)) + if start <= ABS_POS_BIN[-2]: + a_pos = np.append(ABS_POS_BIN[start: min(ABS_POS_BIN[-1], start+len(mix))], [ABS_POS_BIN[-1]] * (start+len(mix)-ABS_POS_BIN[-1])) + else: + a_pos = np.array([ABS_POS_BIN[-1]] * len(mix)) + + a = np.random.rand() + if a < 0.3: + blur_ratio = 0 + elif a < 0.7: + blur_ratio = (np.random.rand() * 2 + 1) / 4 #range in [.25, .75) + else: + blur_ratio = 1 + mix = (1 - blur_ratio) * mix + blur_ratio * np.random.normal(loc=0, scale=1, size=mix.shape) + + #print('prog', len(prog), 'max_track', max_tracks) + if len(prog) < max_tracks: + 
prog = np.pad(prog, ((0, max_tracks-len(prog))), mode='constant', constant_values=(NUM_INSTR_CLASS,)) + fp = np.pad(fp, ((0, 0), (0, max_tracks-fp.shape[1])), mode='constant', constant_values=(NUM_PITCH_CODE,)) + ft = np.pad(ft, ((0, 0), (0, max_tracks-ft.shape[1]), (0, 0)), mode='constant', constant_values=(NUM_TIME_CODE,)) + #print('fp pad', fp.shape) + if len(mix) < max_dur: + mix = np.pad(mix, ((0, max_dur-len(mix)), (0, 0)), mode='constant', constant_values=(0,)) + total_len = np.pad(total_len, (0, max_dur-len(total_len)), mode='constant', constant_values=(len(TOTAL_LEN_BIN),)) + a_pos = np.pad(a_pos, (0, max_dur-len(a_pos)), mode='constant', constant_values=(len(ABS_POS_BIN),)) + r_pos = np.pad(r_pos, (0, max_dur-len(r_pos)), mode='constant', constant_values=(len(REL_POS_BIN),)) + fp = np.pad(fp, ((0, max_dur-len(fp)), (0, 0)), mode='constant', constant_values=(NUM_PITCH_CODE,)) + ft = np.pad(ft, ((0, max_dur-len(ft)), (0, 0), (0, 0)), mode='constant', constant_values=(NUM_TIME_CODE,)) + + #print('fp', fp.shape, max_tracks) + mixture.append(mix) + programs.append(prog) + func_pitch.append(fp) + func_time.append(ft) + total_length.append(total_len) + abs_pos.append(a_pos) + rel_pos.append(r_pos) + + return torch.from_numpy(np.array(mixture)).float().to(device), \ + torch.from_numpy(np.array(programs)).long().to(device), \ + torch.from_numpy(np.array(func_pitch)).long().to(device), \ + torch.from_numpy(np.array(func_time)).long().to(device), \ + torch.BoolTensor(time_mask).to(device), \ + torch.BoolTensor(track_mask).to(device), \ + torch.from_numpy(np.array(total_length)).long().to(device), \ + torch.from_numpy(np.array(abs_pos)).long().to(device), \ + torch.from_numpy(np.array(rel_pos)).long().to(device) \ No newline at end of file diff --git a/piano_arranger/AccoMontage.py b/piano_arranger/AccoMontage.py new file mode 100644 index 0000000..a6852ea --- /dev/null +++ b/piano_arranger/AccoMontage.py @@ -0,0 +1,389 @@ +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +import numpy as np +import pandas as pd +import torch +from .acc_utils import split_phrases, computeTIV, chord_shift, cosine, cosine_rhy +from .models import DisentangleVAE +from . 
import format_converter as cvt +from scipy.interpolate import interp1d +from tqdm import tqdm +import gc + + +def set_premises(phrase_data_dir, edge_weights_dir, checkpoint_dir, reference_meta_dir, phrase_len=range(1, 17)): + #load POP909 phrase data + data = np.load(phrase_data_dir, allow_pickle=True) + MELODY = data['melody'] + ACC = data['acc'] + CHORD = data['chord'] + VELOCITY = data['velocity'] + CC = data['cc'] + acc_pool = {} + for length in tqdm(phrase_len): + (_mel, _acc, _chord, _vel, _cc, _song_reference) = find_by_length(MELODY, ACC, CHORD, VELOCITY, CC, length) + acc_pool[length] = (_mel, _acc, _chord, _vel, _cc, _song_reference) + del data, MELODY, ACC, CHORD, VELOCITY, CC + gc.collect() #for optimizing RAM usage + texture_filter = get_texture_filter(acc_pool) + #load pre-computed transition score + edge_weights=np.load(edge_weights_dir, allow_pickle=True) + #load re-harmonization model + model = DisentangleVAE.init_model(torch.device('cuda')).cuda() + checkpoint = torch.load(checkpoint_dir) + model.load_state_dict(checkpoint) + model.eval() + #load pop909 meta + reference_check = pd.read_excel(reference_meta_dir) + return model, acc_pool, reference_check, (edge_weights, texture_filter) + + +def load_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track_ID=0): + melody_roll, chord_roll = cvt.leadsheet2matrix(os.path.join(DEMO_ROOT, SONG_NAME, 'lead sheet.mid'), melody_track_ID) + assert(len(melody_roll) == len(chord_roll)) + if NOTE_SHIFT != 0: + melody_roll = melody_roll[int(NOTE_SHIFT*4):, :] + chord_roll = chord_roll[int(NOTE_SHIFT*4):, :] + if len(melody_roll) % 16 != 0: + pad_len = (len(melody_roll)//16+1)*16-len(melody_roll) + melody_roll = np.pad(melody_roll, ((0, pad_len), (0, 0))) + melody_roll[-pad_len:, -1] = 1 + chord_roll = np.pad(chord_roll, ((0, pad_len), (0, 0))) + chord_roll[-pad_len:, 0] = -1 + chord_roll[-pad_len:, -1] = -1 + + CHORD_TABLE = np.stack([cvt.expand_chord(chord) for chord in chord_roll[::4]], axis=0) + LEADSHEET = np.concatenate((melody_roll, chord_roll[:, 1: -1]), axis=-1) #T*142, quantized at 16th + query_phrases = split_phrases(SEGMENTATION) #[('A', 8, 0), ('A', 8, 8), ('B', 8, 16), ('B', 8, 24)] + + midi_len = len(LEADSHEET)//16 + anno_len = sum([item[1] for item in query_phrases]) + if midi_len > anno_len: + LEADSHEET = LEADSHEET[: anno_len*16] + CHORD_TABLE = CHORD_TABLE[: anno_len*4] + print(f'Mismatch warning: Detected {midi_len} bars in the lead sheet (MIDI) and {anno_len} bars in the provided phrase annotation. The lead sheet is truncated to {anno_len} bars.') + elif midi_len < anno_len: + pad_len = (anno_len - midi_len)*16 + LEADSHEET = np.pad(LEADSHEET, ((0, pad_len), (0, 0))) + LEADSHEET[-pad_len:, 129] = 1 + CHORD_TABLE = np.pad(CHORD_TABLE, ((0, pad_len//4), (0, 0))) + CHORD_TABLE[-pad_len//4:, 11] = -1 + CHORD_TABLE[-pad_len//4:, -1] = -1 + print(f'Mismatch warning: Detected {midi_len} bars in the lead sheet (MIDI) and {anno_len} bars in the provided phrase annotation. 
The lead sheet is padded to {anno_len} bars.') + + return LEADSHEET, CHORD_TABLE, query_phrases + + +def phrase_selection(LEADSHEET, query_phrases, reference_check, acc_pool, edge_weights, texture_filter=None, PREFILTER=None, SPOTLIGHT=None, randomness=0): + melody_queries = [] + for item in query_phrases: + start_bar = item[-1] + length = item[-2] + segment = LEADSHEET[start_bar*16: (start_bar+length)*16] + melody_queries.append(segment) #melody queries: list of T16*142, segmented by phrases + print(f'Phrase selection begins: {len(query_phrases)} phrases in total. \n\t Set note density filter: {PREFILTER}.') + if SPOTLIGHT is not None: + print(f'\t Refer to {SPOTLIGHT} as much as possible') + phrase_indice, chord_shift = dp_search(melody_queries, + query_phrases, + acc_pool, + edge_weights, + texture_filter, + filter_id = PREFILTER, + spotlights = ref_spotlight(SPOTLIGHT, reference_check), + randomness = randomness) + path = phrase_indice[0] + shift = chord_shift[0] + reference_set = [] + for idx_phrase, phrase in enumerate(query_phrases): + phrase_len = phrase[1] + song_ref = acc_pool[phrase_len][-1] + idx_song = song_ref[path[idx_phrase][0]][0] + pop909_idx = reference_check.iloc[idx_song][0] + song_name = reference_check.iloc[idx_song][1] + reference_set.append(f'{idx_phrase}: {str(pop909_idx).zfill(3)}_{song_name}') + print('Reference pieces:', reference_set) + return (path, shift) + + +def find_by_length(melody_data, acc_data, chord_data, velocity_data, cc_data, length): + """Search from POP909 phrase data for a certain phrase length.""" + melody_record = [] + acc_record = [] + chord_record = [] + velocity_record = [] + cc_record = [] + song_reference = [] + for song_idx in range(acc_data.shape[0]): + for phrase_idx in range(len(acc_data[song_idx])): + melody = melody_data[song_idx][phrase_idx] + if not melody.shape[0] == length * 16: + continue + if np.sum(melody[:, :128]) <= 2: + continue + melody_record.append(melody) + acc = acc_data[song_idx][phrase_idx] + acc_record.append(acc) + chord = chord_data[song_idx][phrase_idx] + chord_record.append(chord) + velocity = velocity_data[song_idx][phrase_idx] + velocity_record.append(velocity) + cc = cc_data[song_idx][phrase_idx] + cc_record.append(cc) + song_reference.append((song_idx, phrase_idx)) + return np.array(melody_record), np.array(acc_record), np.array(chord_record), np.array(velocity_record), np.array(cc_record), song_reference + + +def dp_search(query_phrases, seg_query, acc_pool, edge_weights, texture_filter=None, filter_id=None, spotlights=None, randomness=0): + """Search for texture donors based on dynamic programming. + * query_phrases: lead sheet in segmented phrases. Shape of each phrase: (T, 142), quantized at 1/4-beat level. This format is defined in R. Yang et al., "Deep music analogy via latent representation disentanglement," ISMIR 2019. + * seg_query: phrase annotation for the lead sheet. Format of each phrase: (label, length, start). For example, seg_query=[('A', 8, 0), ('A', 8, 8), ('B', 4, 16)]. + * acc_pool: search space for piano texture donors. + * edge_weights: pre-computed transition scores for texture donor i to i+1. + * texture_filter: filter on voice number (VN) and rhythmic density (RD). + * filter_id: specified VN and RD to filter for the first phrase. + * spotlights: a specified preference for certain songs and/or artists for the search process. + * randomness: degree of randomness to be introduced to the search process. 
+ """ + seg_query = [item[0] + str(item[1]) for item in seg_query] #['A8', 'A8', 'B8', 'B8'] + #Searching for phrase 1 + query_length = [query_phrases[i].shape[0]//16 for i in range(len(query_phrases))] + mel, acc, chord, _, _, song_ref = acc_pool[query_length[0]] + mel_set = mel + rhy_set = np.concatenate((np.sum(mel_set[:, :, :128], axis=-1, keepdims=True), mel_set[:, :, 128: 130]), axis=-1) + query_rhy = np.concatenate((np.sum(query_phrases[0][:, : 128], axis=-1, keepdims=True), query_phrases[0][:, 128: 130]), axis=-1)[np.newaxis, :, :] + rhythm_result = cosine_rhy(query_rhy+1e-5, rhy_set+1e-5) + + chord_set = chord + chord_set, num_total, shift_const = chord_shift(chord_set) + chord_set_TIV = computeTIV(chord_set) + query_chord = query_phrases[0][:, 130:][::4] + query_chord_TIV = computeTIV(query_chord)[np.newaxis, :, :] + chord_score, arg_chord = cosine(query_chord_TIV, chord_set_TIV) + + score = .5*rhythm_result + .5*chord_score + score += randomness * np.random.normal(0, 1, size=len(score)) #to introduce some randomness + if spotlights is not None: + for spot_idx in spotlights: + for ref_idx, ref_item in enumerate(song_ref): + if ref_item[0] == spot_idx: + score[ref_idx] += 1 + if filter_id is not None: + mask = texture_filter[query_length[0]][0][filter_id[0]] * texture_filter[query_length[0]][1][filter_id[1]] - 1 + score += mask + + path = [[(i, score[i])] for i in range(acc.shape[0])] + shift = [[shift_const[i]] for i in arg_chord] + melody_record = np.argmax(mel_set, axis=-1) + record = [] + + #Searching for phrase 2, 3, ... + for i in tqdm(range(1, len(query_length))): + mel, acc, chord, _, _, song_ref = acc_pool[query_length[i]] + weight_key = f"l_{str(query_length[i-1]).zfill(2)}_{str(query_length[i]).zfill(2)}" + contras_result = edge_weights[weight_key] + if query_length[i-1] == query_length[i]: + for j in range(contras_result.shape[0]): + contras_result[j, j] = -1 #the ith phrase does not transition to itself at i+1 + for k in range(j-1, -1, -1): + if song_ref[k][0] != song_ref[j][0]: + break + contras_result[j, k] = -1 #ith phrase does not transition to its ancestors in the same song. 
+ if i > 1: + contras_result = contras_result[[item[-1][1] for item in record]] + if spotlights is not None: + for spot_idx in spotlights: + for ref_idx, ref_item in enumerate(song_ref): + if ref_item[0] == spot_idx: + contras_result[:, ref_idx] += 1 + mel_set = mel + rhy_set = np.concatenate((np.sum(mel_set[:, :, :128], axis=-1, keepdims=True), mel_set[:, :, 128: 130]), axis=-1) + query_rhy = np.concatenate((np.sum(query_phrases[i][:, : 128], axis=-1, keepdims=True), query_phrases[i][:, 128: 130]), axis=-1)[np.newaxis, :, :] + rhythm_result = cosine_rhy(query_rhy, rhy_set) + chord_set = chord + chord_set, num_total, shift_const = chord_shift(chord_set) + chord_set_TIV = computeTIV(chord_set) + query_chord = query_phrases[i][:, 130:][::4] + query_chord_TIV = computeTIV(query_chord)[np.newaxis, :, :] + chord_score, arg_chord = cosine(query_chord_TIV, chord_set_TIV) + sim_this_layer = .5*rhythm_result + .5*chord_score + sim_this_layer += randomness * np.random.normal(0, 1, size=len(sim_this_layer)) + if spotlights is not None: + for spot_idx in spotlights: + for ref_idx, ref_item in enumerate(song_ref): + if ref_item[0] == spot_idx: + sim_this_layer[ref_idx] += 1 + score_this_layer = .7*contras_result + .3*np.tile(sim_this_layer[np.newaxis, :], (contras_result.shape[0], 1)) + np.tile(score[:, np.newaxis], (1, contras_result.shape[1])) + melody_flat = np.argmax(mel_set, axis=-1) + if seg_query[i] == seg_query[i-1]: + melody_pre = melody_record + matrix = np.matmul(melody_pre, np.transpose(melody_flat, (1, 0))) / (np.linalg.norm(melody_pre, axis=-1)[:, np.newaxis]*(np.linalg.norm(np.transpose(melody_flat, (1, 0)), axis=0))[np.newaxis, :]) + if i == 1: + for k in range(matrix.shape[1]): + matrix[k, :k] = -1 + else: + for k in range(len(record)): + matrix[k, :record[k][-1][1]] = -1 + matrix = (matrix > 0.99) * 1. + score_this_layer += matrix + topk = 1 + args = np.argsort(score_this_layer, axis=0)[::-1, :][:topk, :] + record = [] + for j in range(args.shape[-1]): + for k in range(args.shape[0]): + record.append((score_this_layer[args[k, j], j], (args[k, j], j))) + shift_this_layer = [[shift_const[k]] for k in arg_chord] + new_path = [path[item[-1][0]] + [(item[-1][1], sim_this_layer[item[-1][1]])] for item in record] + new_shift = [shift[item[-1][0]] + shift_this_layer[item[-1][1]] for item in record] + melody_record = melody_flat[[item[-1][1] for item in record]] + path = new_path + shift = new_shift + score = np.array([item[0] for item in record]) + + arg = score.argsort()[::-1] + return [path[arg[i]] for i in range(topk)], [shift[arg[i]] for i in range(topk)] + + +def re_harmonization(lead_sheet, chord_table, query_phrases, indices, shifts, acc_pool, model, get_est=True, tempo=120): + """Re-harmonize the accompaniment texture donors and save in MIDI. + * lead_sheet: the conditional lead sheet. Its melody track will be taken. Shape: (T, 142), quantized at 1-beat level. This format is defined in R. Yang et al., "Deep music analogy via latent representation disentanglement," ISMIR 2019. + * chord_table: the conditional chord progression from the lead sheet. Shape: (T', 36), quantized at 1-beat level. This format is defined in Z. Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + * seg_query: phrase annotation for the lead sheet. Format of each phrase: (label, length, start). For example, seg_query=[('A', 8, 0), ('A', 8, 8), ('B', 4, 16)]. + * indices: the indices of selected texture donor phrases in the acc_pool. 
+ * shifts: pitch transposition of each selected phrase. + * acc_pool: search space for piano texture donors. + * tempo: the tempo to render the piece. + """ + acc_roll = np.empty((0, 128)) + vel_roll = [] + phrase_mean_vel = [] + cc_roll = np.empty((0, 128)) + #retrieve texture donor data of the corresponding indices from the acc_pool + for i, idx in enumerate(indices): + length = query_phrases[i][-2] + shift = shifts[i] + # notes + acc_matrix = np.roll(acc_pool[length][1][idx[0]], shift, axis=-1) + acc_roll = np.concatenate((acc_roll, acc_matrix), axis=0) + #MIDI velocity + vel_matrix = np.roll(acc_pool[length][3][idx[0]], shift, axis=-1) + phrase_mean_vel.append(np.mean(np.ma.masked_equal(vel_matrix, value=0))) + vel_roll.append(vel_matrix) + #MIDI control messages (mainly for pedals) + cc_matrix = acc_pool[length][4][idx[0]] + cc_roll = np.concatenate((cc_roll, cc_matrix), axis=0) + # normalize the scale of velocity across different retrieved phrases + global_mean_vel = np.mean(np.ma.masked_equal(np.concatenate(vel_roll, axis=0), value=0)) + for i in range(len(vel_roll)): + vel_roll[i][vel_roll[i] > 0] += (global_mean_vel - phrase_mean_vel[i]) + vel_roll = np.concatenate(vel_roll, axis=0) + #re-harmonization + if len(acc_roll) % 32 != 0: + pad_len = (len(acc_roll)//32+1)*32 - len(acc_roll) + acc_roll = np.pad(acc_roll, ((0, pad_len), (0, 0))) + vel_roll = np.pad(vel_roll, ((0, pad_len), (0, 0))) + cc_roll = np.pad(cc_roll, ((0, pad_len), (0, 0)), mode='constant', constant_values=-1) + chord_table = np.pad(chord_table, ((0, pad_len//4), (0, 0))) + chord_table[-pad_len//4:, 0] = -1 + chord_table[-pad_len//4:, -1] = -1 + acc_roll = acc_roll.reshape(-1, 32, 128) + chord_table = chord_table.reshape(-1, 8, 36) + acc_roll = torch.from_numpy(acc_roll).float().cuda() + acc_roll = torch.clip(acc_roll, min=0, max=31) + gt_chord = torch.from_numpy(chord_table).float().cuda() + est_x = model.inference(acc_roll, gt_chord, sample=False) + acc_roll = cvt.grid2pr(est_x.reshape(-1, 15, 6)) + #interpolate MIDI velocity + adapt_vel_roll = np.zeros(vel_roll.shape) + masked_dyn_matrix = np.ma.masked_equal(vel_roll, value=0) + mean = np.mean(masked_dyn_matrix, axis=-1) + onsets = np.nonzero(mean.data) + dynamic = mean.data[onsets] + onsets = onsets[0].tolist() + dynamic = dynamic.tolist() + if 0 not in onsets: + onsets = [0] + onsets + dynamic = [dynamic[0]] + dynamic + if len(vel_roll)-1 not in onsets: + onsets = onsets + [len(vel_roll)-1] + dynamic = dynamic + [dynamic[-1]] + dyn_curve = interp1d(onsets, dynamic) + for t, p in zip(*np.nonzero(acc_roll)): + adapt_vel_roll[t, p] = dyn_curve(t) + adapt_vel_roll = np.clip(adapt_vel_roll, a_min=0, a_max=127) + #reconstruct MIDI + accompaniment = np.stack([acc_roll, adapt_vel_roll, cc_roll], axis=-1)[np.newaxis, :, :, :] + midi_recon = cvt.matrix2midi_with_dynamics(accompaniment, programs=[0], init_tempo=tempo) + melody_track = cvt.melody_matrix2data(melody_matrix=lead_sheet[:, :130], tempo=tempo) + midi_recon.instruments = [melody_track] + midi_recon.instruments + if get_est: + return midi_recon, est_x + else: + return midi_recon + +def ref_spotlight(ref_name_list, reference_check): + """Convert spotlight song/artist names into the indices of corresponding pieces in the dataset.""" + if ref_name_list is None: + return None + check_idx = [] + #POP909 song_id + for name in ref_name_list: + line = reference_check[reference_check.song_id == name] + if not line.empty: + check_idx.append(line.index)#read by pd, neglect first row, index starts from 0. 
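+    #the same lookup is repeated over the song_id, name, and artist columns of the POP909 meta sheet; all matching row indices are accumulated into check_idx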
+ #song name + for name in ref_name_list: + line = reference_check[reference_check.name == name] + if not line.empty: + check_idx.append(line.index)#read by pd, neglect first row, index starts from 0. + #artist name + for name in ref_name_list: + line = reference_check[reference_check.artist == name] + if not line.empty: + check_idx += list(line.index)#read by pd, neglect first row, index starts from 0 + return check_idx + + +def get_texture_filter(acc_pool): + """Divide accompaniment texture donors into fifths in terms of voice number (VN) and rhythmic density (RD).""" + texture_filter = {} + for key in acc_pool: + acc_track = acc_pool[key][1] + # CALCULATE HORIZONTAL DENSITY (rhythmic density) + onset_positions = (np.sum(acc_track, axis=-1) > 0) * 1. + HD = np.sum(onset_positions, axis=-1) / acc_track.shape[1] #(N) + # CALCULATE VERTICAL DENSITY (voice number) + beat_positions = acc_track[:, ::4, :] + downbeat_positions = acc_track[:, ::16, :] + upbeat_positions = acc_track[:, 2::4, :] + + simu_notes_on_beats = np.sum((beat_positions > 0) * 1., axis=-1) #N*T + simu_notes_on_downbeats = np.sum((downbeat_positions > 0) * 1., axis=-1) + simu_notes_on_upbeats = np.sum((upbeat_positions > 0) * 1., axis=-1) + + VD_beat = np.sum(simu_notes_on_beats, axis=-1) / (np.sum((simu_notes_on_beats > 0) * 1., axis=-1) + 1e-10) + VD_upbeat = np.sum(simu_notes_on_upbeats, axis=-1) / (np.sum((simu_notes_on_upbeats > 0) * 1., axis=-1) + 1e-10) + + VD = np.max(np.stack((VD_beat, VD_upbeat), axis=-1), axis=-1) + #get five-equal-divident-points of HD + dst = np.sort(HD) + HD_anchors = [dst[len(dst) // 5], dst[len(dst) // 5 * 2], dst[len(dst) // 5 * 3], dst[len(dst) // 5 * 4]] + HD_Bins = [ + HD < HD_anchors[0], + (HD >= HD_anchors[0]) * (HD < HD_anchors[1]), + (HD >= HD_anchors[1]) * (HD < HD_anchors[2]), + (HD >= HD_anchors[2]) * (HD < HD_anchors[3]), + HD >= HD_anchors[3] + ] + #get five-equal-divident-points of VD + dst = np.sort(VD) + VD_anchors = [dst[len(dst) // 5], dst[len(dst) // 5 * 2], dst[len(dst) // 5 * 3], dst[len(dst) // 5 * 4]] + VD_Bins = [ + VD < VD_anchors[0], + (VD >= VD_anchors[0]) * (VD < VD_anchors[1]), + (VD >= VD_anchors[1]) * (VD < VD_anchors[2]), + (VD >= VD_anchors[2]) * (VD < VD_anchors[3]), + VD >= VD_anchors[3] + ] + texture_filter[key] = (HD_Bins, VD_Bins) #((5, N), (5, N)) + return texture_filter \ No newline at end of file diff --git a/piano_arranger/__init__.py b/piano_arranger/__init__.py new file mode 100644 index 0000000..4450f02 --- /dev/null +++ b/piano_arranger/__init__.py @@ -0,0 +1,2 @@ +from .AccoMontage import set_premises, load_lead_sheet, phrase_selection, re_harmonization +from .format_converter import matrix2leadsheet \ No newline at end of file diff --git a/piano_arranger/acc_utils.py b/piano_arranger/acc_utils.py new file mode 100644 index 0000000..0690634 --- /dev/null +++ b/piano_arranger/acc_utils.py @@ -0,0 +1,186 @@ +import sys +import numpy as np + + +def melodySplit(matrix, WINDOWSIZE=32, HOPSIZE=16, VECTORSIZE=142): + """Clip a (melody) sequence into short WINDOWSIZE-step snippets under a hop size of HOPSIZE. 
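The output has shape (num_windows, WINDOWSIZE, VECTORSIZE).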
This function is fitted for 1/4-beat quantized sequence.""" + start_downbeat = 0 + end_downbeat = matrix.shape[0]//16 + assert(end_downbeat - start_downbeat >= 2) + splittedMatrix = np.empty((0, WINDOWSIZE, VECTORSIZE)) + for idx_T in range(start_downbeat*16, (end_downbeat-(WINDOWSIZE//16 -1))*16, HOPSIZE): + if idx_T > matrix.shape[0]-32: + break + sample = matrix[idx_T:idx_T+WINDOWSIZE, :VECTORSIZE][np.newaxis, :, :] + #print(sample.shape) + splittedMatrix = np.concatenate((splittedMatrix, sample), axis=0) + return splittedMatrix + + +def chordSplit(chord, WINDOWSIZE=8, HOPSIZE=8): + """Clip a chord sequence into short WINDOWSIZE-step snippets under a hop size of HOPSIZE. This function is fitted for 1-beat quantized sequence.""" + start_downbeat = 0 + end_downbeat = chord.shape[0]//4 + splittedChord = np.empty((0, WINDOWSIZE, 36)) + for idx_T in range(start_downbeat*4, (end_downbeat-(WINDOWSIZE//4 -1))*4, HOPSIZE): + if idx_T > chord.shape[0]-8: + break + sample = chord[idx_T:idx_T+WINDOWSIZE, :][np.newaxis, :, :] + splittedChord = np.concatenate((splittedChord, sample), axis=0) + return splittedChord + + +def split_phrases(segmentation): + """Split a phrase label string into individual phrase meta info""" + if '\n' not in segmentation: + segmentation += '\n' + phrases = [] + lengths = [] + current = 0 + while segmentation[current] != '\n': + if segmentation[current].isalpha(): + j = 1 + while not (segmentation[current + j].isalpha() or segmentation[current + j] == '\n'): + j += 1 + phrases.append(segmentation[current]) + lengths.append(int(segmentation[current+1: current+j])) + current += j + return [(phrases[i], lengths[i], sum(lengths[:i])) for i in range(len(phrases))] + + +def chord_shift(prChordSet): + """Transpose a chord sequence to all 12keys (batch processing).""" + #prChordSet: (batch, time, feature_dim) + if prChordSet.shape[-1] == 14: + prChordSet = prChordSet[:, :, 1: -1] + elif prChordSet.shape[-1] == 12: + pass + else: + print('Chord Dimention Error') + sys.exit() + num_total = prChordSet.shape[0] + shift_const = [-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] + shifted_ensemble = [] + for i in shift_const: + shifted_term = np.roll(prChordSet, i, axis=-1) + shifted_ensemble.append(shifted_term) + shifted_ensemble = np.array(shifted_ensemble) #num_pitches * num_pieces * duration * size #.reshape((-1, prChordSet.shape[1], prChordSet.shape[2])) + return shifted_ensemble, num_total, shift_const + + +def computeTIV(chroma): + """Compute TIV for a sequence of chords. TIV credit to G. Bernardes et al., "A multi-level tonal interval space for modelling pitch relatedness and musical consonance," JNMR, 2016.""" + #inpute size: Time*12 + if (len(chroma.shape)) == 4: + num_pitch = chroma.shape[0] + num_pieces = chroma.shape[1] + chroma = chroma.reshape((-1, 12)) + chroma = chroma / (np.sum(chroma, axis=-1)[:, np.newaxis] + 1e-10) #Time * 12 + TIV = np.fft.fft(chroma, axis=-1)[:, 1: 7] #Time * (6*2) + TIV = np.concatenate((np.abs(TIV), np.angle(TIV)), axis=-1) #Time * 12 + TIV = TIV.reshape((num_pitch, num_pieces, -1, 12)) + else: + chroma = chroma / (np.sum(chroma, axis=-1)[:, np.newaxis] + 1e-10) #Time * 12 + TIV = np.fft.fft(chroma, axis=-1)[:, 1: 7] #Time * (6*2) + TIV = np.concatenate((np.abs(TIV), np.angle(TIV)), axis=-1) #Time * 12 + return TIV #Time * 12 + + +def cosine(query, instance_space): + """Calculate cosine similarity cos(query, reference). 
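The query is compared against all 12 chromatic transpositions of every reference, and the best-matching transposition index is returned alongside the similarity score.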
This function is fitted for the chord format.""" + #query: batch_Q * T * 12 + #instance_space: 12 * batch_R * T * 12 + batch_Q, _, _ = query.shape + shift, batch_R, time, chroma = instance_space.shape + query = query.reshape((batch_Q, -1))[np.newaxis, :, :] + instance_space = instance_space.reshape((shift, batch_R, -1)) + #result: 12 * Batch_Q * Batch_R + result = np.matmul(query, np.transpose(instance_space, (0, 2, 1))) / (np.linalg.norm(query, axis=-1, keepdims=True) * np.transpose(np.linalg.norm(instance_space, axis=-1, keepdims=True), (0, 2, 1)) + 1e-10) + #result: Batch_Q * Batch_R + chord_result = np.max(result, axis=0) + arg_result = np.argmax(result, axis=0) + return chord_result[0], arg_result[0] + + +def cosine_rhy(query, instance_space): + """Calculate cosine similarity cos(query, reference). This function is fitted for the rhythm format.""" + #query: 1 * T * 3 + #instance_space: batch * T * 3 + batch_Q, _, _ = query.shape + batch_R, _, _ = instance_space.shape + query = query.reshape((batch_Q, -1)) + instance_space = instance_space.reshape((batch_R, -1)) + #result: 12 * Batch_Q * Batch_R + result = np.matmul(query, np.transpose(instance_space, (1, 0))) / (np.linalg.norm(query, axis=-1, keepdims=True) * np.transpose(np.linalg.norm(instance_space, axis=-1, keepdims=True), (1, 0)) + 1e-10) + return result[0] + + +def cosine_mel(query, instance_space): + """Calculate cosine similarity cos(query, reference). This function is fitted for the melody format.""" + #query: 1 * m + #instance_space: batch * m + #result: 12 * Batch_Q * Batch_R + result = np.matmul(query, instance_space) / (np.linalg.norm(query, axis=-1, keepdims=True) * np.linalg.norm(instance_space, axis=-1, keepdims=True) + 1e-10) + return result[0] + + +def cosine_1d(query, instance_space, segmentation): + """Calculate cosine similarity cos(query, reference). This function for general 1-D sequence in batch processing.""" + #query: T + #instance space: Batch * T + #instance_space: batch * vectorLength + final_result = np.ones((instance_space.shape[0])) + recorder = [] + start = 0 + for i in segmentation: + if i.isdigit(): + end = start + int(i) * 16 + result = np.abs(np.dot(instance_space[:, start: end], query[start: end])/(np.linalg.norm(instance_space[:, start: end], axis=-1) * np.linalg.norm(query[start: end]) + 1e-10)) + recorder.append(result) + final_result = np.multiply(final_result, result) #element-wise product + start = end + candidates = final_result.argsort()[::-1] + scores = final_result[candidates] + return candidates, scores, recorder + + +def cosine_2d(query, instance_space, segmentation, record_chord=None): + """Calculate cosine similarity cos(query, reference). 
This function is for general 2-D sequences in batch processing.""" + final_result = np.ones((instance_space.shape[0])) + recorder = [] + start = 0 + for i in segmentation: + if i.isdigit(): + end = start + int(i) * 4 + result = np.dot(np.transpose(instance_space[:, start: end, :], (0, 2, 1)), query[start: end, :])/(np.linalg.norm(np.transpose(instance_space[:, start: end, :], (0, 2, 1)), axis=-1, keepdims=True) * np.linalg.norm(query[start: end, :], axis=0, keepdims=True) + 1e-10) + result = np.trace(result, axis1=-2, axis2=-1) /2 + recorder.append(result) + + final_result = np.multiply(final_result, result) + start = end + if record_chord is not None: + record_chord = np.array(record_chord) + recorder = np.array(recorder) + assert np.shape(record_chord) == np.shape(recorder) + final_result = np.array([(np.prod(recorder[:, i]) * np.prod(record_chord[:, i])) * (2 *recorder.shape[0]) for i in range(recorder.shape[1])]) + + candidates = final_result.argsort()[::-1] + scores = final_result[candidates] + return candidates, scores, recorder + + +def piano_roll_shift(prpiano_rollSet): + """Transpose a piano_roll format described in R. Yang et al., "Deep music analogy via latent representation disentanglement," ISMIR 2019.""" + num_total, timeRes, piano_shape = prpiano_rollSet.shape + shift_const = [-6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5] + shifted_ensemble = [] + for i in shift_const: + piano = prpiano_rollSet[:, :, :128] + rhythm = prpiano_rollSet[:, :, 128: 130] + chord = prpiano_rollSet[:, :, 130:] + shifted_piano = np.roll(piano, i, axis=-1) + shifted_chord = np.roll(chord, i, axis=-1) + shifted_piano_roll_set = np.concatenate((shifted_piano, rhythm, shifted_chord), axis=-1) + shifted_ensemble.append(shifted_piano_roll_set) + shifted_ensemble = np.array(shifted_ensemble).reshape((-1, timeRes, piano_shape)) + return shifted_ensemble, num_total, shift_const diff --git a/piano_arranger/chord_recognition/.gitignore b/piano_arranger/chord_recognition/.gitignore new file mode 100644 index 0000000..f063320 --- /dev/null +++ b/piano_arranger/chord_recognition/.gitignore @@ -0,0 +1,6 @@ +/cache_data +/temp +/.idea +/__pycache__ +*.pyc +/output diff --git a/piano_arranger/chord_recognition/README.TXT b/piano_arranger/chord_recognition/README.TXT new file mode 100644 index 0000000..9be3814 --- /dev/null +++ b/piano_arranger/chord_recognition/README.TXT @@ -0,0 +1,2 @@ +Credit to J. Jiang et al., "Large-vocabulary chord transcription via chord structure decomposition," ISMIR 2019. +https://github.com/music-x-lab/ISMIR2019-Large-Vocabulary-Chord-Recognition \ No newline at end of file diff --git a/piano_arranger/chord_recognition/__init__.py b/piano_arranger/chord_recognition/__init__.py new file mode 100644 index 0000000..73ca104 --- /dev/null +++ b/piano_arranger/chord_recognition/__init__.py @@ -0,0 +1,5 @@ +from . 
import mir +import extractors +import io_new +import chord_class +from .main import transcribe_cb1000_midi \ No newline at end of file diff --git a/piano_arranger/chord_recognition/air_structure.py b/piano_arranger/chord_recognition/air_structure.py new file mode 100644 index 0000000..5169c91 --- /dev/null +++ b/piano_arranger/chord_recognition/air_structure.py @@ -0,0 +1,375 @@ +import numpy as np +import librosa +import mir_eval.chord +from mir.music_base import get_scale_and_suffix + + +# for cb dataset, please use +# melody_beat_offset=+0.033 +# lyric_beat_offset=+0.033 +OVERFLOW_WIDTH_LIMIT=32 +TABS_WIDTH=1 + + +class AIRStructure: + ''' + AIR structure is a class to store, import and export symbolic music data by + jjy. + + AIR is (1) the abbr. of Audio-Informed Representation; (2) the abbr. of the + name AIRA (in memory of the heroine in the anime <>). + ''' + def __init__(self,beat,num_beat_division,verbose_level=1,disallow_error=False): + self.disallow_error=disallow_error + self.error_log='' + self.verbose_level=verbose_level + self.length=(beat.shape[0]-1)*num_beat_division+1 + self.input_beat=beat + self.num_beat_division=num_beat_division + self.timing,self.offset,self.is_downbeat,self.is_beat=self.__init_timing_and_offset() + self.divider=(self.timing[1:]+self.timing[:-1])/2 + self.lyric=np.full(self.length,'',dtype='=1): + print(message) + self.error_log+=message+'\n' + + def __overwrite_warning(self,pos,error_item,time): + if(pos==0): + self.log_error('Warning: Too early %s event happened at time %.2f'%(error_item,time)) + return True # allow overwrite + elif(pos==self.length-1): + self.log_error('Warning: Too late %s event happened at time %.2f'%(error_item,time)) + return False # disallow overwrite + else: + self.log_error('Warning: Overwrite %s event at time %.2f. 
Music is probably too fast.'%(error_item,time)) + return False # disallow overwrite + + def append_lyric(self,lyric,lyric_beat_offset=0.0): + if(len(lyric)==0 or len(lyric[0])!=4): + raise Exception('Unsupported lyric format') + if(self.phrase_count>0): + raise Exception('Lyric already appended') + self.phrase_count=0 + self.input_lyric_sentence=[] + current_sentence=[] + for token in lyric: + start_pos=self.__locate(token[0]-lyric_beat_offset) + end_pos=self.__locate(token[1]-lyric_beat_offset) + if(end_pos==start_pos): + end_pos+=1 + if(token[3]>0): + self.phrase_count+=1 + supress_warning=False + allow_write=True + for pos in range(start_pos,end_pos): + if(self.lyric[pos]!='' and not supress_warning): + supress_warning=True + allow_write&=self.__overwrite_warning(pos,'lyric',token[0]) + if(allow_write): + self.lyric[pos]='-' + self.phrase[pos]=self.phrase_count-1 + if(allow_write): + self.lyric[start_pos]=token[2] + p=-1 + p_start=0 + p_end=0 + for pos in range(self.length): + if(self.phrase[pos]>p): + self.phrase[p_start:p_end]=p + p=self.phrase[pos] + p_start=pos + elif(self.phrase[pos]==p): + p_end=pos + self.phrase[p_start:p_end]=p + + def append_melody(self,midilab,melody_beat_offset=0.0): + if(len(midilab)==0 or len(midilab[0])!=3): + raise Exception('Unsupported melody format') + for token in midilab: + start_pos=self.__locate(token[0]-melody_beat_offset) + end_pos=self.__locate(token[1]-melody_beat_offset) + if(end_pos==start_pos): + end_pos+=1 + supress_warning=False + allow_write=True + for pos in range(start_pos,end_pos): + if(self.melody[pos]!=-1 and not supress_warning): + supress_warning=True + allow_write&=self.__overwrite_warning(pos,'melody',token[0]) + if(allow_write): + self.melody[pos]=int(np.round(token[2])) + if(allow_write): + self.melody_onset[start_pos]=True + + def append_chord(self,chordlab): + if(len(chordlab)==0 or len(chordlab[0])!=3): + raise Exception('Unsupported chord format') + for token in chordlab: + start_pos=self.__locate(token[0]) + end_pos=self.__locate(token[1]) + if(end_pos==start_pos): + end_pos+=1 + if(len(token[2])>31): + raise Exception('Too long chord: %s'%token[2]) + # self.chord[start_pos]=token[2] + for pos in range(start_pos,end_pos): + self.chord[pos]=token[2] + # previous_chord_pos=-1 + # for i in range(self.length): + # if(self.chord[i]!=''): + # previous_chord_pos=i + # self.previous_chord_pos[i]=previous_chord_pos + + def append_key(self,keylab,format='mode7'): + if(format!='mode7'): + raise NotImplementedError() + if(len(keylab)==0 or len(keylab[0])!=3): + raise Exception('Unsupported key format') + MODE7_TO_STR=['X','major','dorian','phrygian','lydian','mixolydian','minor','locrian'] + + for token in keylab: + start_pos=self.__locate(token[0]) + end_pos=self.__locate(token[1]) + if(end_pos==start_pos): + end_pos+=1 + key_id,mode_str=get_scale_and_suffix(token[2]) + mode_str=mode_str[1:] + if(mode_str=='maj'): + mode_str='major' + elif(mode_str=='min'): + mode_str='minor' + mode7=MODE7_TO_STR.index(mode_str) + # self.chord[start_pos]=token[2] + for pos in range(start_pos,end_pos): + self.key[pos]=mode7*12+key_id + for i in range(1,self.length): + if(self.key[i]==-1): + self.key[i]=self.key[i-1] + for i in range(self.length-2,-1,-1): + if(self.key[i]==-1): + self.key[i]=self.key[i+1] + + def export_to_array(self,export_all=False): + # convert chord labels + original_length=self.length + valid=np.zeros(original_length,dtype=np.bool) + bars=[] + if(self.phrase_count>0 and not export_all): + for phrase_id in range(self.phrase_count): 
+ # todo: dangerous!! + bars+=PhraseRenderer(self,phrase_id).bars + for bar in bars: + valid[bar[0]:bar[1]]=True + else: + valid[:]=True + valid[-1]=False + export_length=int(np.sum(valid)) + export_melody=self.melody[valid].reshape((export_length,1)) + export_melody_onset=self.melody_onset[valid].reshape((export_length,1)) + #todo: chord N+ + #todo: blank chord? + chord=self.chord[valid] + if('' in chord): + #print('Warning: blank chord encountered, regard as N') + chord=[text if text!='' else 'N' for text in chord] + export_chord_root,export_chord_chroma,export_chord_bass=mir_eval.chord.encode_many(chord) + export_chord_chroma=mir_eval.chord.rotate_bitmaps_to_roots(export_chord_chroma,export_chord_root) + export_chord_root=export_chord_root.reshape((export_length,1)) + export_chord_bass=export_chord_bass.reshape((export_length,1)) + export_downbeat_pos=(self.offset[valid]-1).reshape((export_length,1)) + export_beat_pos=self.is_beat[valid].astype(np.int32) + # todo: incomplete downbeat + last_beat_pos=0 + for i in range(export_length): + if(export_beat_pos[i]): + last_beat_pos=0 + else: + last_beat_pos+=1 + export_beat_pos[i]=last_beat_pos + export_beat_pos=export_beat_pos.reshape((export_length,1)) + export_start=self.timing[valid].reshape((export_length,1)) + export_end=self.timing[1:][valid[:-1]].reshape((export_length,1)) + export_phrase=self.phrase[valid].reshape((export_length,1)) + # ensure compatibility + export_key=self.key[valid].reshape((export_length,1)) if 'key' in self.__dict__ else np.ones((export_length,1),dtype=np.int32)*-1 + return np.hstack((export_start,export_end)),\ + np.hstack((export_melody,export_melody_onset,export_chord_root,export_chord_chroma,export_chord_bass,export_downbeat_pos,export_beat_pos,export_phrase,export_key)) + + def __init_timing_and_offset(self): + timing=np.zeros((self.length)) + offset=np.zeros((self.length),dtype=np.int32) + is_downbeat=np.zeros((self.length),dtype=np.bool) + is_beat=np.zeros((self.length),dtype=np.bool) + for i in range(self.input_beat.shape[0]): + if(np.round(self.input_beat[i,1])==1.0): + is_downbeat[i*self.num_beat_division]=True + is_beat[i*self.num_beat_division]=True + for i in range(self.input_beat.shape[0]-1): + cur_time=self.input_beat[i,0] + next_time=self.input_beat[i+1,0] + timing[i*self.num_beat_division:(i+1)*self.num_beat_division+1]=np.linspace(cur_time,next_time,self.num_beat_division+1) + offset[i*self.num_beat_division:(i+1)*self.num_beat_division]=np.round(self.input_beat[i,1]) + offset[-1]=np.round(self.input_beat[-1,1]) + for i in range(self.length-1): + assert(timing[i]=0 and air.phrase[index]!=phrase_id): + break + if(air.is_downbeat[index]): + prev_bar_pos=index + break + for index in valid_index: + if(air.is_downbeat[index]): + if(prev_bar_pos>=0): + bars.append((prev_bar_pos,index,True)) + prev_bar_pos=index + if(prev_bar_pos!=-1): # no valid complete bar? 
+ if(not air.is_downbeat[valid_index[-1]]): + for index in range(valid_index[-1],air.length): + if(air.is_downbeat[index]): + bars.append((prev_bar_pos,index,True)) + break + # self.min_bar_pos=np.min([b[0] for b in bars]) + # self.max_bar_pos=np.max([b[0] for b in bars]) + self.bars=bars + + + def render(self,unit_width,bar_group): + + data=np.full((len(self.bars),bar_group,OVERFLOW_WIDTH_LIMIT,3),'',dtype='0][i>0][k>0]+'═─'[k>0]*(unit_width//TABS_WIDTH) + str+='╗╢'[k>0]+'\n' + for i in range(len(data)): + budget=0 + for j in range(len(data[i])): + if(budget==0): + str+='║│'[j>0] + budget+=unit_width + else: # extra budget for the missing line + budget+=unit_width+TABS_WIDTH + skip=0 + for c in data[i][j][:,k]: + if(c!='' and ord(c)>256): + if(budget>=2): + skip+=1 + str+=c + budget-=2 + continue + c='?' + if(skip==0 and (c!='' or budget>0)): + str+=c if c!='' else ' ' + budget-=1 + else: + skip-=1 + if(budget>0): + str+=' '*budget + budget=0 + str+='║\n' + for i in range(len(data)): + for j in range(len(data[i])): + str+=['╚╩','╧╧'][j>0][i>0]+'═'*(unit_width//TABS_WIDTH) + str+='╝\n' + return str + + +class AIRStructureRenderer(): + + def to_string(self,air:AIRStructure,unit_width,bar_group): + result='' + if(air.phrase_count==0): + result+=PhraseRenderer(air,-1).to_string(unit_width,bar_group)+'\n' + else: + for phrase_id in range(air.phrase_count): + result+=PhraseRenderer(air,phrase_id).to_string(unit_width,bar_group)+'\n' + result=air.error_log+'\n'+result + return result diff --git a/piano_arranger/chord_recognition/chord_class.py b/piano_arranger/chord_recognition/chord_class.py new file mode 100644 index 0000000..7c2ba22 --- /dev/null +++ b/piano_arranger/chord_recognition/chord_class.py @@ -0,0 +1,129 @@ +import numpy as np +# print('WARNING: DECODING CHORD MAJ MIN ONLY') +# Copied from mir_eval.chord +QUALITIES = { + # 1 2 3 4 5 6 7 + 'maj': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0], + 'min': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], + 'aug': [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0], + 'dim': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0], + 'sus4': [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], + 'sus4(b7)':[1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0], + 'sus4(b7,9)':[1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0], + 'sus2': [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + '7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0], + 'maj7': [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1], + 'min7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0], + 'minmaj7': [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1], + 'maj6': [1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0], + 'min6': [1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0], + '9': [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0], + 'maj9': [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1], + 'min9': [1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0], + '7(#9)': [1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0], + 'maj6(9)': [1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0], + 'min6(9)': [1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0], + 'maj(9)': [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0], + 'min(9)': [1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0], + 'maj(11)': [1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1], + 'min(11)': [1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1], + '11': [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0], + 'maj9(11)':[1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1], + 'min11': [1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0], + '13': [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0], + 'maj13': [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1], + 'min13': [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0], + 'dim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0], + 'hdim7': [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0], + #'5': [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0] + } + +INVERSIONS={ + 'maj':['3','5'], + 'min':['b3','5'], + 
'7':['3','5','b7'], + 'maj7':['3','5','7'], + 'min7':['5','b7'], + #'maj(9)':['2'], + #'maj(11)':['4'], + #'min(9)':['2'], + #'min(11)':['4'], +} + +NUM_TO_ABS_SCALE=['C','C#','D','Eb','E','F','F#','G','Ab','A','Bb','B'] +NUM_TO_INVERSION=['1','b2','2','b3','3','4','b5','5','#5','6','b7','7'] + +class ChordClass: + def __init__(self): + BASS_TEMPLATE=np.array([1,0,0,0,0,0,0,0,0,0,0,0]) + EMPTY_TEMPLATE=np.array([0,0,0,0,0,0,0,0,0,0,0,0]) + self.chord_list=['N'] + self.chroma_templates=[EMPTY_TEMPLATE] + self.bass_templates=[EMPTY_TEMPLATE] + for i in range(12): + for q in QUALITIES: + original_template=np.array(QUALITIES[q]) + name='%s:%s'%(NUM_TO_ABS_SCALE[i],q) + self.chord_list.append(name) + self.chroma_templates.append(np.roll(original_template,i)) + self.bass_templates.append(np.roll(BASS_TEMPLATE,i)) + if(q in INVERSIONS): + for inv in INVERSIONS[q]: + delta_scale=NUM_TO_INVERSION.index(inv) + name='%s:%s/%s'%(NUM_TO_ABS_SCALE[i],q,inv) + self.chord_list.append(name) + self.chroma_templates.append(np.roll(original_template,i)) + self.bass_templates.append(np.roll(BASS_TEMPLATE,i+delta_scale)) + self.chroma_templates=np.array([list(entry) for entry in self.chroma_templates]) + self.bass_templates=np.array([list(entry) for entry in self.bass_templates]) + + def get_length(self): + return len(self.chord_list) + + def score(self,chroma,basschroma): + ''' + Scoring a midi segment based on the chroma & basschroma feature + :param chroma: treble chroma + :param basschroma: bass chroma + :return: A score with range (-inf, +inf) + ''' + result=np.zeros((self.get_length()),dtype=np.float64) + for i,c in enumerate(self.chord_list): + if(c=='N'): + result[i]=0.2 + else: + ref_chroma=self.chroma_templates[i] + ref_bass_chroma=self.bass_templates[i] + score=(chroma[ref_chroma>0].sum()-chroma[ref_chroma==0].sum())/(ref_chroma>0).sum()\ + +0.5*basschroma[ref_bass_chroma>0].sum()-(ref_chroma>0).sum()*0.1-('/' in c)*0.05 + result[i]=score + return result + + def batch_score(self,chromas,basschromas): + ''' + Scoring a midi segment based on the chroma & basschroma feature + :param chroma: treble chroma + :param basschroma: bass chroma + :return: A score with range (-inf, +inf) + ''' + n_batch=chromas.shape[0] + result=np.zeros((n_batch,self.get_length()),dtype=np.float64) + for i,c in enumerate(self.chord_list): + if(c=='N'): + result[:,i]=0.2 + else: + ref_chroma=self.chroma_templates[i] + ref_bass_chroma=self.bass_templates[i] + score=(chromas[:,ref_chroma>0].sum(axis=1)-chromas[:,ref_chroma==0].sum(axis=1))/(ref_chroma>0).sum()\ + +0.5*basschromas[:,ref_bass_chroma>0].sum(axis=1)-(ref_chroma>0).sum()*0.1-('/' in c)*0.05 + result[:,i]=score + return result +if __name__ == '__main__': + # perform some sanity checks + chord_class=ChordClass() + #for i,c in enumerate(chord_class.chord_list): + # print(c,chord_class.chroma_templates[i],chord_class.bass_templates[i]) + print(list(zip(chord_class.chord_list,chord_class.score( + np.array([1,0,0,0,1,0,0,1,0,0,0,0]), + np.array([0,0,0,0,1,0,0,0,0,0,0,0]), + )))) diff --git a/piano_arranger/chord_recognition/complex_chord.py b/piano_arranger/chord_recognition/complex_chord.py new file mode 100644 index 0000000..bc43d5a --- /dev/null +++ b/piano_arranger/chord_recognition/complex_chord.py @@ -0,0 +1,320 @@ +from __future__ import print_function +import numpy as np + +def get_scale_and_suffix(name): + result="C*D*EF*G*A*B".index(name[0]) + prefix_length=1 + if (len(name) > 1): + if (name[1] == 'b'): + result = result - 1 + if (result<0): + result+=12 + 
prefix_length=2 + if (name[1] == '#'): + result = result + 1 + if (result>=12): + result-=12 + prefix_length=2 + return result,name[prefix_length:] + +def scale_name_to_value(name): + result="1*2*34*5*6*78*9".index(name[-1]) # 8 and 9 are for weird tagging in some mirex chords + return (result-name.count('b')+name.count('#')+12)%12 + +NUM_TO_ABS_SCALE=['C','C#','D','Eb','E','F','F#','G','Ab','A','Bb','B'] + +def enum_to_list(cls,valid_only): + items=[item for item in cls.__dict__.items() if not item[0].startswith('_') and (not valid_only or item[1]>=0)] + return sorted(items, key=lambda items:items[1]) + +def enum_to_dict(cls): + return {item[1]:item[0] for item in cls.__dict__.items() if not item[0].startswith('_')} + +class TriadTypes: + x=-2 + none=0 + maj=1 + min=2 + sus4=3 + sus2=4 + dim=5 + aug=6 + power=7 + one=8 + # warning: before adding new chord, consider change the data type to + # int16 instead of int8 +class SeventhTypes: + unknown=-2 + not_care=-1 + none=0 + add_7=1 + add_b7=2 + add_bb7=3 + +class NinthTypes: + unknown=-2 + not_care=-1 + none=0 + add_9=1 + add_s9=2 + add_b9=3 + +class EleventhTypes: + unknown=-2 + not_care=-1 + none=0 + add_11=1 + add_s11=2 + +class ThirteenthTypes: + unknown=-2 + not_care=-1 + none=0 + add_13=1 + add_b13=2 + add_bb13=3 + +class SuffixDecoder: + BASIC_TYPES=['.','maj','min','sus4','sus2','dim','aug','5','1'] + EXTENDED_TYPES={ + 'maj6':[TriadTypes.maj,0,0,0,ThirteenthTypes.add_13], + 'min6':[TriadTypes.min,0,0,0,ThirteenthTypes.add_13], + '7':[TriadTypes.maj,SeventhTypes.add_b7,0,0,0], + 'maj7':[TriadTypes.maj,SeventhTypes.add_7,0,0,0], + 'min7':[TriadTypes.min,SeventhTypes.add_b7,0,0,0], + 'minmaj7':[TriadTypes.min,SeventhTypes.add_7,0,0,0], + 'dim7':[TriadTypes.dim,SeventhTypes.add_bb7,0,0,0], + 'hdim7':[TriadTypes.dim,SeventhTypes.add_b7,0,0,0], + '9':[TriadTypes.maj,SeventhTypes.add_b7,NinthTypes.add_9,0,0], + '#9':[TriadTypes.maj,SeventhTypes.add_b7,NinthTypes.add_s9,0,0], + 'maj9':[TriadTypes.maj,SeventhTypes.add_7,NinthTypes.add_9,0,0], + 'min9':[TriadTypes.min,SeventhTypes.add_b7,NinthTypes.add_9,0,0], + '11':[TriadTypes.maj,SeventhTypes.add_b7,NinthTypes.add_9,EleventhTypes.add_11,0], + 'min11':[TriadTypes.min,SeventhTypes.add_b7,NinthTypes.add_9,EleventhTypes.add_11,0], + '13':[TriadTypes.maj,SeventhTypes.add_b7,NinthTypes.add_9,EleventhTypes.add_11,ThirteenthTypes.add_13], + 'maj13':[TriadTypes.maj,SeventhTypes.add_7,NinthTypes.add_9,EleventhTypes.add_11,ThirteenthTypes.add_13], + 'min13':[TriadTypes.min,SeventhTypes.add_b7,NinthTypes.add_9,EleventhTypes.add_11,ThirteenthTypes.add_13], + '':[TriadTypes.one,0,0,0,0], + 'N':[TriadTypes.none,-2,-2,-2,-2], + 'X':[-2,-2,-2,-2,-2] + } + + ADD_NOTES={ + '7':[7,SeventhTypes.add_7], + 'b7':[7,SeventhTypes.add_b7], + 'bb7':[7,SeventhTypes.add_bb7], + '2':[9,NinthTypes.add_9], + '9':[9,NinthTypes.add_9], + '#9':[9,NinthTypes.add_s9], + 'b9':[9,NinthTypes.add_b9], + '4':[11,EleventhTypes.add_11], + '11':[11,EleventhTypes.add_11], + '#11':[11,EleventhTypes.add_s11], + '13':[13,ThirteenthTypes.add_13], + 'b13':[13,ThirteenthTypes.add_b13], + '6':[6,ThirteenthTypes.add_13], + 'b6':[6,ThirteenthTypes.add_b13], + 'bb6':[6,ThirteenthTypes.add_bb13], + '#4':[5,TriadTypes.x], + 'b5':[5,TriadTypes.x], + '5':[5,TriadTypes.x], + '#5':[5,TriadTypes.x], + 'b3':[3,TriadTypes.x], + 'b2':[3,TriadTypes.x], + '3':[3,TriadTypes.x] + } + @staticmethod + def parse_chord_type(str): + if(str in __class__.BASIC_TYPES): + return [__class__.BASIC_TYPES.index(str), + SeventhTypes.none, + NinthTypes.none, + 
EleventhTypes.none, + ThirteenthTypes.none] + elif(str in __class__.EXTENDED_TYPES): + return __class__.EXTENDED_TYPES[str].copy() + else: + raise Exception("Unknown chord type "+str) + + @staticmethod + def decode(str): + if('(' in str): + assert(str[-1]==')') + bracket_pos=str.index('(') + chord_type_str=str[:bracket_pos] + add_omit_notes=str[bracket_pos+1:-1].split(',') + omit_notes=[str[1:] for str in add_omit_notes if str.startswith('*')] + add_notes=[str for str in add_omit_notes if not str.startswith('*')] + else: + chord_type_str=str + add_notes=[] + omit_notes=[] + result=__class__.parse_chord_type(chord_type_str) + + if(len(omit_notes)>0): + valid_omit_types=['1','b3','3','b5','5','b7','7'] + omit_found=[False]*len(valid_omit_types) + for omit_note in omit_notes: + if(omit_note not in valid_omit_types): + raise Exception('Invalid omit type %s in %s'%(omit_note,str)) + omit_found[valid_omit_types.index(omit_note)]=True + if(result[0]==TriadTypes.maj and omit_found[2]): + result[0]=TriadTypes.power + omit_found[2]=False + elif(result[0]==TriadTypes.min and omit_found[1]): + result[0]=TriadTypes.power + omit_found[1]=False + if(result[0]==TriadTypes.power and omit_found[4]): + result[0]=TriadTypes.one + omit_found[4]=False + if(omit_found[0] or omit_found[1] or omit_found[2] or omit_found[3] or omit_found[4]): + result[0]=TriadTypes.x + if(result[1]==SeventhTypes.add_b7 and omit_found[5]): + result[1]=SeventhTypes.none + omit_found[5]=False + elif(result[1]==SeventhTypes.add_7 and omit_found[6]): + result[1]=SeventhTypes.none + omit_found[6]=False + if(omit_found[5] or omit_found[6]): + result[1]=SeventhTypes.unknown + + for note in add_notes: + if(note=='1'): + continue + elif(note=='5' and result[0]==TriadTypes.one): + result[0]=TriadTypes.power + elif(note in __class__.ADD_NOTES): + [dec_class,dec_type]=__class__.ADD_NOTES[note] + dec_index=[-1,-1,-1,0,-1,0,4,1,-1,2,-1,3,-1,4][dec_class] + if(result[dec_index]>0 or result[dec_index]==-2): + result[dec_index]=-2 + result[dec_index]=dec_type + else: + raise Exception('Unknown decoration '+note+' @ '+str) + return result + +class ChordTypeLimit: + def __init__(self,triad_limit,seventh_limit,ninth_limit,eleventh_limit,thirteenth_limit): + self.triad_limit=triad_limit + self.seventh_limit=seventh_limit + self.ninth_limit=ninth_limit + self.eleventh_limit=eleventh_limit + self.thirteenth_limit=thirteenth_limit + + self.bass_slice_begin=self.triad_limit*12+1 + self.seventh_slice_begin=self.bass_slice_begin+12 + self.ninth_slice_begin=self.seventh_slice_begin+12*(self.seventh_limit+1) + self.eleventh_slice_begin=self.ninth_slice_begin+12*(self.ninth_limit+1) + self.thirteenth_slice_begin=self.eleventh_slice_begin+12*(self.eleventh_limit+1) + self.output_dim=self.thirteenth_slice_begin+12*(self.thirteenth_limit+1) + + def to_string(self): + return '[%d %d %d %d %d]'%\ + (self.triad_limit,self.seventh_limit,self.ninth_limit,self.eleventh_limit,self.thirteenth_limit) + + +class Chord: + def __init__(self,name): + if(':' in name): + self.root,suffix=get_scale_and_suffix(name) + assert(suffix[0]==':') + suffix=suffix[1:] + self.bass=self.root + if('/' in suffix): + slash_pos=suffix.index('/') + bass_str=suffix[slash_pos+1:] + self.bass=(scale_name_to_value(bass_str)+self.root)%12 + suffix=suffix[:slash_pos] + [self.triad,self.seventh,self.ninth,self.eleventh,self.thirteenth]=SuffixDecoder.decode(suffix) + elif(name=='N'): + self.root=-1 + self.bass=-1 + 
[self.triad,self.seventh,self.ninth,self.eleventh,self.thirteenth]=SuffixDecoder.decode('N') + elif(name=='X'): + self.root=-2 + self.bass=-2 + [self.triad,self.seventh,self.ninth,self.eleventh,self.thirteenth]=SuffixDecoder.decode('X') + else: + raise Exception("Unknown chord name "+name) + # print(name,self.root,self.bass,[self.triad,self.seventh,self.ninth,self.eleventh,self.thirteenth]) + + def to_numpy(self): + if(self.triad<=0): + triad=self.triad + else: + triad=(self.triad-1)*12+1+self.root + return np.array([triad,self.bass,self.seventh,self.ninth,self.eleventh,self.thirteenth],dtype=np.int8) + +def complex_chord_chop(id,limit): + new_id=id.copy() + if(new_id[0]>=limit.triad_limit*12+1): + new_id[0]=-2 + if(new_id[2]>limit.seventh_limit): + new_id[2]=-2 + if(new_id[3]>limit.ninth_limit): + new_id[3]=-2 + if(new_id[4]>limit.eleventh_limit): + new_id[4]=-2 + if(new_id[5]>limit.thirteenth_limit): + new_id[5]=-2 + return new_id + +def complex_chord_chop_list(ids,limit): + new_ids=ids.copy() + new_ids[new_ids[:,0]>=limit.triad_limit*12+1]=-2 + new_ids[new_ids[:,2]>limit.seventh_limit]=-2 + new_ids[new_ids[:,3]>limit.ninth_limit]=-2 + new_ids[new_ids[:,4]>limit.eleventh_limit]=-2 + new_ids[new_ids[:,5]>limit.thirteenth_limit]=-2 + return new_ids + +def shift_complex_chord_array(array,shift): + new_array=array.copy() + if(new_array[0]>0): + base=(new_array[0]-1)//12 + root=((new_array[0]-1+shift)%12+12)%12 + new_array[0]=base*12+root+1 + if(new_array[1]>=0): + new_array[1]=((new_array[1]+shift)%12+12)%12 + return new_array + +def shift_complex_chord_array_list(array,shift): + new_array=np.array(array).copy() + root_shift_indices=new_array[:,0]>0 + new_bases=(new_array[root_shift_indices,0]-1)//12 + new_roots=((new_array[root_shift_indices,0]-1+shift)%12+12)%12 + new_array[root_shift_indices,0]=new_bases*12+new_roots+1 + + bass_valid_indices=new_array[:,1]>=0 + new_array[bass_valid_indices,1]=((new_array[bass_valid_indices,1]+shift)%12+12)%12 + return new_array + +def create_tag_list(chord_limit): + result=['N'] + triad_dict=enum_to_dict(TriadTypes) + for i in range(1,chord_limit.triad_limit+1): + result+=['%s:%s'%(NUM_TO_ABS_SCALE[j],triad_dict[i]) for j in range(12)] + result+=['bass %s'%(NUM_TO_ABS_SCALE[j]) for j in range(12)] + seventh_dict=enum_to_dict(SeventhTypes) + result+=['%s'%(seventh_dict[i]) for i in range(chord_limit.seventh_limit+1)] + ninth_dict=enum_to_dict(NinthTypes) + result+=['%s'%(ninth_dict[i]) for i in range(chord_limit.ninth_limit+1)] + eleventh_dict=enum_to_dict(EleventhTypes) + result+=['%s'%(eleventh_dict[i]) for i in range(chord_limit.eleventh_limit+1)] + thirteenth_dict=enum_to_dict(ThirteenthTypes) + result+=['%s'%(thirteenth_dict[i]) for i in range(chord_limit.thirteenth_limit+1)] + return result + +if __name__ == '__main__': + # perform some tests + x=[[0,1,0,0,0,0],[12,0,0,0,0,0],[13,-1,2,0,0,0],[1,11,0,1,0,1]] + print(shift_complex_chord_array_list(x,2)) + f=open('data/full_chord_list.txt','r') + test_chord_names=f.readlines() + limit=ChordTypeLimit(triad_limit=3,seventh_limit=2,ninth_limit=1,eleventh_limit=1,thirteenth_limit=1) + for chord_name in test_chord_names: + chord_name=chord_name.strip() + if(chord_name!=''): + c=Chord(chord_name) + print(chord_name,c.to_numpy()) + f.close() \ No newline at end of file diff --git a/piano_arranger/chord_recognition/extractors/midi_utilities.py b/piano_arranger/chord_recognition/extractors/midi_utilities.py new file mode 100644 index 0000000..c80add0 --- /dev/null +++ 
b/piano_arranger/chord_recognition/extractors/midi_utilities.py @@ -0,0 +1,180 @@ +from mir.extractors import ExtractorBase +from io_new.downbeat_io import DownbeatIO +from mir import io +import numpy as np +from pretty_midi import PitchBend,pitch_bend_to_semitones,PrettyMIDI + +class MidiBeatExtractor(ExtractorBase): + + def get_feature_class(self): + return DownbeatIO + + def extract(self,entry,**kwargs): + extra_division=kwargs['div'] if 'div' in kwargs else 1 + midi=entry.midi + beats=midi.get_beats() + if(extra_division>1): + beat_interp=np.linspace(beats[:-1],beats[1:],extra_division+1).T + last_beat=beat_interp[-1,-1] + beats=np.append(beat_interp[:,:-1].reshape((-1)),last_beat) + downbeats=midi.get_downbeats() + j=0 + beat_pos=-2 + result=[] + for i in range(len(beats)): + if(j0) + result.append([beats[i],beat_pos]) + assert(j==len(downbeats)) + return np.array(result) + +def get_pretty_midi_energy_roll(midi, fs=100, times=None): + """Compute a piano roll matrix of the MIDI data. + + Parameters + ---------- + fs : int + Sampling frequency of the columns, i.e. each column is spaced apart + by ``1./fs`` seconds. + times : np.ndarray + Times of the start of each column in the piano roll. + Default ``None`` which is ``np.arange(0, get_end_time(), 1./fs)``. + + Returns + ------- + piano_roll : np.ndarray, shape=(128,times.shape[0]) + Piano roll of MIDI data, flattened across instruments. + + """ + + # If there are no instruments, return an empty array + if len(midi.instruments) == 0: + return np.zeros((128, 0)) + + # Get piano rolls for each instrument + piano_rolls = [get_energy_roll(i, fs=fs, times=times) + for i in midi.instruments] + # Allocate piano roll, + # number of columns is max of # of columns in all piano rolls + piano_roll = np.zeros((128, np.max([p.shape[1] for p in piano_rolls]))) + # Sum each piano roll into the aggregate piano roll + for roll in piano_rolls: + piano_roll[:, :roll.shape[1]] += roll ** 2 + return np.sqrt(piano_roll) + +def get_energy_roll(self, fs=100, times=None): + """Compute a piano roll matrix of this instrument. + + Parameters + ---------- + fs : int + Sampling frequency of the columns, i.e. each column is spaced apart + by ``1./fs`` seconds. + times : np.ndarray + Times of the start of each column in the piano roll. + Default ``None`` which is ``np.arange(0, get_end_time(), 1./fs)``. + + Returns + ------- + piano_roll : np.ndarray, shape=(128,times.shape[0]) + Piano roll of this instrument. 
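+        Note velocities (scaled by 1/100) are combined in a root-sum-square sense when notes overlap, rather than by overwriting.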
+ + """ + # If there are no notes, return an empty matrix + if self.notes == []: + return np.array([[]]*128) + # Get the end time of the last event + end_time = self.get_end_time() + # Extend end time if one was provided + if times is not None and times[-1] > end_time: + end_time = times[-1] + # Allocate a matrix of zeros - we will add in as we go + piano_roll = np.zeros((128, int(fs*end_time))) + # Drum tracks don't have pitch, so return a matrix of zeros + if is_percussive_channel(self): + if times is None: + return piano_roll + else: + return np.zeros((128, times.shape[0])) + # Add up piano roll matrix, note-by-note + for note in self.notes: + # Should interpolate + piano_roll[note.pitch, + int(note.start*fs):int(note.end*fs)] += (note.velocity/100.0)**2 + piano_roll=np.sqrt(piano_roll) + # Process pitch changes + # Need to sort the pitch bend list for the following to work + ordered_bends = sorted(self.pitch_bends, key=lambda bend: bend.time) + # Add in a bend of 0 at the end of time + end_bend = PitchBend(0, end_time) + for start_bend, end_bend in zip(ordered_bends, + ordered_bends[1:] + [end_bend]): + # Piano roll is already generated with everything bend = 0 + if np.abs(start_bend.pitch) < 1: + continue + # Get integer and decimal part of bend amount + start_pitch = pitch_bend_to_semitones(start_bend.pitch) + bend_int = int(np.sign(start_pitch)*np.floor(np.abs(start_pitch))) + bend_decimal = np.abs(start_pitch - bend_int) + # Column indices effected by the bend + bend_range = np.r_[int(start_bend.time*fs):int(end_bend.time*fs)] + # Construct the bent part of the piano roll + bent_roll = np.zeros(piano_roll[:, bend_range].shape) + # Easiest to process differently depending on bend sign + if start_bend.pitch >= 0: + # First, pitch shift by the int amount + if bend_int is not 0: + bent_roll[bend_int:] = piano_roll[:-bend_int, bend_range] + else: + bent_roll = piano_roll[:, bend_range] + # Now, linear interpolate by the decimal place + bent_roll[1:] = ((1 - bend_decimal)*bent_roll[1:] + + bend_decimal*bent_roll[:-1]) + else: + # Same procedure as for positive bends + if bend_int is not 0: + bent_roll[:bend_int] = piano_roll[-bend_int:, bend_range] + else: + bent_roll = piano_roll[:, bend_range] + bent_roll[:-1] = ((1 - bend_decimal)*bent_roll[:-1] + + bend_decimal*bent_roll[1:]) + # Store bent portion back in piano roll + piano_roll[:, bend_range] = bent_roll + + if times is None: + return piano_roll + piano_roll_integrated = np.zeros((128, times.shape[0])) + # Convert to column indices + times = np.array(times*fs, dtype=np.int) + for n, (start, end) in enumerate(zip(times[:-1], times[1:])): + # Each column is the mean of the columns in piano_roll + piano_roll_integrated[:, n] = np.mean(piano_roll[:, start:end], + axis=1) + return piano_roll_integrated + + + +class EnergyPianoRoll(ExtractorBase): + + def get_feature_class(self): + return io.SpectrogramIO + + def extract(self,entry,**kwargs): + dt=entry.prop.sr/entry.prop.hop_length + piano_roll=get_pretty_midi_energy_roll(entry.midi,dt) + return piano_roll.T + + + + + + +def is_percussive_channel(instrument): + return instrument.is_drum or instrument.program>112 # todo: are >112 instruments all percussive? 
+ +def get_valid_channel_count(midi): + return len([ins for ins in midi.instruments if not is_percussive_channel(ins)]) diff --git a/piano_arranger/chord_recognition/extractors/rule_based_channel_reweight.py b/piano_arranger/chord_recognition/extractors/rule_based_channel_reweight.py new file mode 100644 index 0000000..2b008bc --- /dev/null +++ b/piano_arranger/chord_recognition/extractors/rule_based_channel_reweight.py @@ -0,0 +1,37 @@ +import pretty_midi +from extractors.midi_utilities import is_percussive_channel +import numpy as np + +def get_channel_thickness(piano_roll): + chroma_matrix = np.zeros((piano_roll.shape[0],12)) + for note in range(12): + chroma_matrix[:, note] = np.sum(piano_roll[:,note::12], axis=1) + thickness_array=(chroma_matrix>0).sum(axis=1) + if(thickness_array.sum()==0): + return 0 + return thickness_array[thickness_array>0].mean() + +def get_channel_bass_property(piano_roll): + result=np.argwhere(piano_roll>0)[:,1] + if(len(result)==0): + return 0.0,1.0 + return result.mean(),min(1.,len(result)/len(piano_roll)) + +def midi_to_thickness_weights(midi): + thickness=np.array([get_channel_thickness(ins.get_piano_roll().T) for ins in midi.instruments if not is_percussive_channel(ins)]) + result=1-np.exp(-(thickness-0.95)) + result/=result.max() + return result + +def midi_to_thickness_and_bass_weights(midi): + rolls=[ins.get_piano_roll().T for ins in midi.instruments if not is_percussive_channel(ins)] + thickness=np.array([get_channel_thickness(roll) for roll in rolls]) + bass=np.array([get_channel_bass_property(roll) for roll in rolls]) + bass[bass[:,1]<0.2,0]=128 + result=1-np.exp(-(thickness-0.95)) + result/=result.max() + result[np.argmin(bass[:,0])]=1.0 + + return result + + diff --git a/piano_arranger/chord_recognition/io_new/air_io.py b/piano_arranger/chord_recognition/io_new/air_io.py new file mode 100644 index 0000000..50a65f0 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/air_io.py @@ -0,0 +1,18 @@ +from mir.io.feature_io_base import * +import numpy as np +import librosa + +class AirIO(FeatureIO): + def read(self, filename, entry): + return pickle_read(self, filename) + + def write(self, data, filename, entry): + return pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + arr=data.export_to_array() + from mir.io.implement.regional_spectrogram_io import RegionalSpectrogramIO + return RegionalSpectrogramIO().visualize(arr,filename,entry,override_sr=override_sr) + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/beat_align_io.py b/piano_arranger/chord_recognition/io_new/beat_align_io.py new file mode 100644 index 0000000..68a3767 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/beat_align_io.py @@ -0,0 +1,77 @@ +from mir.io.feature_io_base import * +from mir import PACKAGE_PATH +import numpy as np + +class BeatAlignCQTIO(FeatureIO): + + def read(self, filename, entry): + return pickle_read(self, filename) + + def write(self, data, filename, entry): + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + sr=entry.prop.sr + win_shift=entry.prop.hop_length + beat=entry.beat + assert(len(beat)-1==data.shape[0]) + n_frame=int(beat[-1]*sr/win_shift)+data.shape[1]+1 + new_data=np.ones((n_frame,data.shape[2]))*-1 + for i in range(len(beat)-1): + time=int(np.round(beat[i]*sr/win_shift)) + for j in range(data.shape[1]): + time_j=time+j + if(time_j>=0 and 
time_j'%(int(np.round(item[0]*sr)),item[2])) + content=content.replace('[__DATA__]','\n'.join(results)) + f=open(filename,'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/complex_chord_io.py b/piano_arranger/chord_recognition/io_new/complex_chord_io.py new file mode 100644 index 0000000..b9f9c2a --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/complex_chord_io.py @@ -0,0 +1,56 @@ +from mir.io.feature_io_base import * +import complex_chord +from mir.music_base import NUM_TO_ABS_SCALE +from mir.common import PACKAGE_PATH +import numpy as np + +class ComplexChordIO(FeatureIO): + + def read(self, filename, entry): + n_frame=entry.n_frame + f = open(filename, 'r') + line_list = f.readlines() + tags = np.ones((n_frame,6))*-2 + for line in line_list: + line=line.strip() + if(line==''): + continue + if ('\t' in line): + tokens = line.split('\t') + else: + tokens = line.split(' ') + sr=entry.prop.sr + win_shift=entry.prop.hop_length + begin=int(round(float(tokens[0])*sr/win_shift)) + end = int(round(float(tokens[1])*sr/win_shift)) + if (end > n_frame): + end = n_frame + if(begin<0): + begin=0 + tags[begin:end,:]=complex_chord.Chord(tokens[2]).to_numpy().reshape((1,6)) + f.close() + return tags + + def write(self, data, filename, entry): + raise NotImplementedError() + + + def visualize(self, data, filename, entry, override_sr): + f = open(os.path.join(PACKAGE_PATH,'data/spectrogram_template.svl'), 'r') + sr=entry.prop.sr + win_shift=entry.prop.hop_length + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__WIN_SHIFT__]', str(win_shift)) + content = content.replace('[__SHAPE_1__]', str(data.shape[1])) + content = content.replace('[__COLOR__]', str(1)) + labels = [str(i) for i in range(data.shape[1])] + content = content.replace('[__DATA__]',create_svl_3d_data(labels,data)) + f=open(filename,'w') + f.write(content) + f.close() + + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/downbeat_io.py b/piano_arranger/chord_recognition/io_new/downbeat_io.py new file mode 100644 index 0000000..60116ae --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/downbeat_io.py @@ -0,0 +1,45 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np +import librosa + +class DownbeatIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r') + lines=f.readlines() + lines=[line.strip('\n\r') for line in lines] + lines=[line for line in lines if line!=''] + f.close() + result=np.zeros((len(lines),2)) + for i in range(len(lines)): + line=lines[i] + tokens=line.split('\t') + assert(len(tokens)==2) + result[i,0]=float(tokens[0]) + result[i,1]=float(tokens[1]) + return result + + def write(self, data, filename, entry): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i,:]])) + f.write('\n') + f.close() + + def visualize(self, data, filename, entry, override_sr): + f = open(os.path.join(PACKAGE_PATH, 'data/sparse_tag_template.svl'), 'r') + sr = override_sr + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__STYLE__]', str(1)) + output_text='' + for beat_info in data: + output_text+='\n'%(int(beat_info[0]*sr),int(beat_info[1])) + content = 
content.replace('[__DATA__]', output_text) + f = open(filename, 'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/jams_io.py b/piano_arranger/chord_recognition/io_new/jams_io.py new file mode 100644 index 0000000..ee51e75 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/jams_io.py @@ -0,0 +1,17 @@ +from mir.io import FeatureIO +import jams + + +class JamsIO(FeatureIO): + def read(self, filename, entry): + return jams.load(filename) + + def write(self, data : jams.JAMS, filename, entry): + data.save(filename) + + def visualize(self, data : jams.JAMS, filename, entry, override_sr): + f=open(filename,'w') + for annotation in data.annotations: + for obs in annotation.data: + f.write('%f\t%f\t%s\n'%(obs.time,obs.time+obs.duration,str(obs.value))) + f.close() diff --git a/piano_arranger/chord_recognition/io_new/jointbeat_io.py b/piano_arranger/chord_recognition/io_new/jointbeat_io.py new file mode 100644 index 0000000..737901a --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/jointbeat_io.py @@ -0,0 +1,44 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np + +class JointBeatIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r') + content = f.read() + lines=content.split('\n') + f.close() + result=[] + for i in range(len(lines)): + line=lines[i].strip() + if(line==''): + continue + tokens=line.split('\t') + assert(len(tokens)==3) + result.append([float(tokens[0]),int(tokens[1]),int(tokens[2])]) + return np.array(result) + + def write(self, data, filename, entry): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i]])) + f.write('\n') + f.close() + + def visualize(self, data, filename, entry, override_sr): + sr = override_sr + f = open(os.path.join(PACKAGE_PATH,'data/sparse_tag_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__STYLE__]', str(1)) + results=[] + for item in data: + results.append(''%(int(np.round(item[0]*sr)),int(item[1]),int(item[2]))) + content=content.replace('[__DATA__]','\n'.join(results)) + f=open(filename,'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/key_io.py b/piano_arranger/chord_recognition/io_new/key_io.py new file mode 100644 index 0000000..a36be68 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/key_io.py @@ -0,0 +1,42 @@ +from mir.io import FeatureIO + +class KeyIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r') + content = f.read() + lines=content.split('\n') + f.close() + result=[] + for i in range(len(lines)): + line=lines[i].strip() + if(line==''): + continue + tokens=line.split('\t') + assert(len(tokens)==3) + result.append([float(tokens[0]),float(tokens[1]),tokens[2]]) + return result + + def write(self, data, filename, entry): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item).replace('\t',' ') for item in data[i]])) + f.write('\n') + f.close() + + def visualize(self, data, filename, entry, override_sr): + sr = override_sr + f = open(os.path.join(PACKAGE_PATH,'data/sparse_tag_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = 
content.replace('[__STYLE__]', str(1)) + results=[] + for item in data: + results.append(''%(int(np.round(item[0]*sr)),item[2])) + content=content.replace('[__DATA__]','\n'.join(results)) + f=open(filename,'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/list_io.py b/piano_arranger/chord_recognition/io_new/list_io.py new file mode 100644 index 0000000..11b23db --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/list_io.py @@ -0,0 +1,12 @@ +from mir.io.feature_io_base import * + +class ListIO(FeatureIO): + + def read(self, filename, entry): + return pickle_read(self, filename) + + def write(self, data, filename, entry): + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + return NotImplementedError() diff --git a/piano_arranger/chord_recognition/io_new/lyric_io.py b/piano_arranger/chord_recognition/io_new/lyric_io.py new file mode 100644 index 0000000..468bb66 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/lyric_io.py @@ -0,0 +1,44 @@ +from mir.io.feature_io_base import * +import numpy as np +import librosa +import codecs + +class LyricIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r', encoding='utf-16-le') + content = f.read() + if(content.startswith('\ufeff')): + content=content[1:] + lines=content.split('\n') + f.close() + result=[] + for i in range(len(lines)): + line=lines[i].strip() + if(line==''): + continue + tokens=line.split('\t') + if(len(tokens)==3): + result.append([float(tokens[0]),float(tokens[1]),tokens[2]]) + elif(len(tokens)==4): # Contains sentence information + result.append([float(tokens[0]),float(tokens[1]),tokens[2],int(tokens[3])]) + else: + raise Exception('Not supported format') + return result + + def write(self, data, filename, entry): + f = open(filename, 'wb') + f.write(codecs.BOM_UTF16_LE) + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i]]).encode('utf-16-le')) + f.write('\n'.encode('utf-16-le')) + f.close() + + def visualize(self, data, filename, entry, override_sr): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i]])) + f.write('\n') + f.close() + + def get_visualize_extention_name(self): + return "txt" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/madmom_io.py b/piano_arranger/chord_recognition/io_new/madmom_io.py new file mode 100644 index 0000000..bfaa839 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/madmom_io.py @@ -0,0 +1,34 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np + + +class MadmomBeatProbIO(FeatureIO): + def read(self, filename, entry): + return pickle_read(self, filename) + + def write(self, data, filename, entry): + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + f = open(os.path.join(PACKAGE_PATH,'data/spectrogram_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(100)) + content = content.replace('[__WIN_SHIFT__]', str(1)) + content = content.replace('[__SHAPE_1__]', str(data.shape[1])) + content = content.replace('[__COLOR__]', str(1)) + labels = [str(i) for i in range(data.shape[1])] + content = content.replace('[__DATA__]',create_svl_3d_data(labels,data)) + f=open(filename,'w') + f.write(content) + f.close() + + def pre_assign(self, entry, 
proxy): + entry.prop.set('n_frame', LoadingPlaceholder(proxy, entry)) + + def post_load(self, data, entry): + entry.prop.set('n_frame', data.shape[0]) + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/midilab_io.py b/piano_arranger/chord_recognition/io_new/midilab_io.py new file mode 100644 index 0000000..e1a3dfa --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/midilab_io.py @@ -0,0 +1,50 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np +import librosa + +class MidiLabIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r') + lines=f.readlines() + lines=[line.strip('\n\r') for line in lines] + lines=[line for line in lines if line!=''] + f.close() + result=np.zeros((len(lines),3)) + for i in range(len(lines)): + line=lines[i] + tokens=line.split('\t') + assert(len(tokens)==3) + result[i,0]=float(tokens[0]) + result[i,1]=float(tokens[1]) + result[i,2]=float(tokens[2]) + return result + + def write(self, data, filename, entry): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i]])) + f.write('\n') + f.close() + + def visualize(self, data, filename, entry, override_sr): + f = open(os.path.join(PACKAGE_PATH, 'data/midi_template.svl'), 'r') + sr = override_sr + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__WIN_SHIFT__]', '1') + output_text='' + for note_info in data: + output_text+=self.__get_midi_note_text(note_info[0]*sr,note_info[1]*sr-1,note_info[2]) + content = content.replace('[__DATA__]', output_text) + f = open(filename, 'w') + f.write(content) + f.close() + + def __get_midi_note_text(self,start_frame,end_frame,midi_height,level=0.78125): + return '\n'\ + %(int(round(start_frame)),midi_height,int(round(end_frame-start_frame)),level) + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/osu_io.py b/piano_arranger/chord_recognition/io_new/osu_io.py new file mode 100644 index 0000000..7e6a3c9 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/osu_io.py @@ -0,0 +1,54 @@ +from mir.io.feature_io_base import * + +class MetaDict: + def __init__(self): + self.dict={} + + def set(self,key,value): + self.dict[key]=value + + def __getattr__(self, item): + return self.dict[item.lower()] + +class OsuMapIO(FeatureIO): + def read(self, filename, entry): + f=open(filename,'r',encoding='UTF-8') + result=MetaDict() + lines=f.readlines() + current_state=0 + current_dict=None + for line in lines: + line=line.strip() + if(line==''): + continue + if(line.startswith('[')): + assert(line.endswith(']')) + namespace=line[1:-1].lower() + if(namespace in ['general','editor','metadata','difficulty','colours']): + current_state=1 + current_dict=MetaDict() + result.set(namespace,current_dict) + elif(namespace in ['hitobjects','events','timingpoints']): + current_state=2 + current_dict=[] + result.set(namespace,current_dict) + else: + raise Exception('Unknown namespace %s in %s'%(namespace,filename)) + else: + if(current_state==1): + split_index=line.index(':') + key=line[:split_index].strip() + value=line[split_index+1:].strip() + current_dict.set(key.lower(),value) + elif(current_state==2): + current_dict.append(line) + return result + + def write(self, data, filename, entry): + raise NotImplementedError() + + def visualize(self, data, 
filename, entry, override_sr): + raise NotImplementedError() + + def get_visualize_extention_name(self): + raise NotImplementedError() \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/salami_io.py b/piano_arranger/chord_recognition/io_new/salami_io.py new file mode 100644 index 0000000..78df5fe --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/salami_io.py @@ -0,0 +1,43 @@ +from mir.io.feature_io_base import * +from mir.music_base import get_scale_and_suffix + +class SalamiIO(FeatureIO): + def read(self, filename, entry): + f=open(filename,'r') + data=f.read() + lines=data.split('\n') + result=[] + metre_up=-1 + metre_down=-1 + tonic=-1 + for line in lines: + if(line==''): + continue + if(line.startswith('#')): + if(':' in line): + seperator_index=line.index(':') + keyword=line[1:seperator_index].strip() + if(keyword=='metre'): + slash_index=line.index('/') + metre_up=int(line[seperator_index+1:slash_index].strip()) + metre_down=int(line[slash_index+1:].strip()) + # print('metre changed to %d/%d'%(metre_up,metre_down)) + if(keyword=='tonic'): + tonic=int(get_scale_and_suffix(line[seperator_index+1:].strip())[0]) + + else: + tokens=line.split('\t') + assert(len(tokens)==2) + start_time=float(tokens[0]) + result.append((start_time,tokens[1],metre_up,metre_down,tonic)) + f.close() + return result + + def write(self, data, filename, entry): + raise NotImplementedError() + + def visualize(self, data, filename, entry, override_sr): + f=open(filename,'w') + for (time,token,_,_,_) in data: + f.write('%f\t%s\n'%(time,token)) + f.close() \ No newline at end of file diff --git a/piano_arranger/chord_recognition/io_new/tag_io.py b/piano_arranger/chord_recognition/io_new/tag_io.py new file mode 100644 index 0000000..1640483 --- /dev/null +++ b/piano_arranger/chord_recognition/io_new/tag_io.py @@ -0,0 +1,44 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np + +class TimedTagIO(FeatureIO): + def read(self, filename, entry): + f = open(filename, 'r') + content = f.read() + lines=content.split('\n') + f.close() + result=[] + for i in range(len(lines)): + line=lines[i].strip() + if(line==''): + continue + tokens=line.split('\t') + assert(len(tokens)==2) + result.append([float(tokens[0]),tokens[1]]) + return result + + def write(self, data, filename, entry): + f = open(filename, 'w') + for i in range(0, len(data)): + f.write('\t'.join([str(item) for item in data[i]])) + f.write('\n') + f.close() + + def visualize(self, data, filename, entry, override_sr): + sr = override_sr + f = open(os.path.join(PACKAGE_PATH,'data/sparse_tag_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__STYLE__]', str(1)) + results=[] + for item in data: + results.append(''%(int(np.round(item[0]*sr)),item[1])) + content=content.replace('[__DATA__]','\n'.join(results)) + f=open(filename,'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/main.py b/piano_arranger/chord_recognition/main.py new file mode 100644 index 0000000..ff902ee --- /dev/null +++ b/piano_arranger/chord_recognition/main.py @@ -0,0 +1,76 @@ +from mir import DataEntry +from mir import io +from extractors.midi_utilities import get_valid_channel_count,is_percussive_channel,MidiBeatExtractor +from extractors.rule_based_channel_reweight import midi_to_thickness_and_bass_weights +from 
midi_chord import ChordRecognition +from chord_class import ChordClass +import numpy as np +from io_new.chordlab_io import ChordLabIO +from io_new.downbeat_io import DownbeatIO + +def process_chord(entry, extra_division): + ''' + + Parameters + ---------- + entry: the song to be processed. Properties required: + entry.midi: the pretry midi object + entry.beat: extracted beat and downbeat + extra_division: extra divisions to each beat. + For chord recognition on beat-level, use extra_division=1 + For chord recognition on half-beat-level, use extra_division=2 + + Returns + ------- + Extracted chord sequence + ''' + + midi=entry.midi + beats=midi.get_beats() + if(extra_division>1): + beat_interp=np.linspace(beats[:-1],beats[1:],extra_division+1).T + last_beat=beat_interp[-1,-1] + beats=np.append(beat_interp[:,:-1].reshape((-1)),last_beat) + downbeats=midi.get_downbeats() + j=0 + beat_pos=-2 + beat=[] + for i in range(len(beats)): + if(j0) + beat.append([beats[i],beat_pos]) + rec=ChordRecognition(entry,ChordClass()) + weights=midi_to_thickness_and_bass_weights(entry.midi) + rec.process_feature(weights) + chord=rec.decode() + return chord + +def transcribe_cb1000_midi(midi_path,output_path=None): + ''' + Perform chord recognition on a midi + :param midi_path: the path to the midi file + :param output_path: the path to the output file + ''' + entry=DataEntry() + entry.append_file(midi_path,io.MidiIO,'midi') + entry.append_extractor(MidiBeatExtractor,'beat') + result=process_chord(entry,extra_division=1) + if output_path is not None: + entry.append_data(result,ChordLabIO,'pred') + entry.save('pred',output_path) + else: + return result + + +if __name__ == '__main__': + import sys + if(len(sys.argv)!=2): + print('Usage: main.py midi_path') + exit(0) + output_path = "{}/chord_midi.txt".format( + "/".join(sys.argv[1].split("/")[:-1])) + transcribe_cb1000_midi(sys.argv[1],output_path) diff --git a/piano_arranger/chord_recognition/midi_chord.py b/piano_arranger/chord_recognition/midi_chord.py new file mode 100644 index 0000000..bb73ece --- /dev/null +++ b/piano_arranger/chord_recognition/midi_chord.py @@ -0,0 +1,151 @@ +import numpy as np +from mir import io +from chord_class import ChordClass +from extractors.midi_utilities import is_percussive_channel + +class ChordRecognition: + + def __init__(self,entry,decode_chord_class:ChordClass,half_beat_switch=True): + ''' + Initialize a chord recognizer for an entry + :param entry: an instance of DataEntry with these proxies + - midi (IO type: MidiIO) + - beat (IO type: DownbeatIO): the corresponding downbeats & beats of the midi. + :param decode_chord_class: An instance of ChordClass + ''' + self.entry=entry + self.chord_class=decode_chord_class + self.half_beat_switch=half_beat_switch + + def process_feature(self,channel_weights): + ''' + First step of chord recognition + :param channel_weights: weights for each channel. If uniform, input [1, ..., 1]. + :return: Nothing. Calculated features are stored in the class. 
+ ''' + SUBBEAT_COUNT=8 + + entry=self.entry + midi=entry.midi + beat=np.array(entry.beat) + n_frame=len(beat) + qt_beat_onset=np.zeros(n_frame) + qt_beat_offset=np.zeros(n_frame) + qt_beat_length=np.zeros(n_frame) + beat_chroma=np.zeros((n_frame,12)) + beat_bass=np.zeros((n_frame,12)) + min_subbeat_bass=np.full((n_frame*SUBBEAT_COUNT,),259,dtype=np.int) + notes=[] + for i in range(n_frame): + qt_beat_onset[i]=beat[i,0] + qt_beat_offset[i]=beat[i,0]+(beat[i,0]-beat[i-1,0]) if i==n_frame-1 else beat[i+1,0] + qt_beat_length[i]=beat[i+1,0]-beat[i,0] if i=qt_beat_offset[-1]): + return n_frame+0.0 + beat_id=np.searchsorted(qt_beat_onset,time,side='right')-1 + return beat_id+(time-qt_beat_onset[beat_id])/qt_beat_length[beat_id] + i=0 + def clamp(qstart,qend,bstart,bend): + return min(bend,qend)-max(qstart,bstart) + for instrument in midi.instruments: + if(is_percussive_channel(instrument)): + continue + raw_notes=instrument.notes + for note in raw_notes: + beat_start=quantize(note.start) + beat_end=quantize(note.end) + left_beat=int(np.floor(beat_start+0.2)) + right_beat=int(np.ceil(beat_end-0.2)) + left_subbeat=int(np.floor(beat_start*SUBBEAT_COUNT+0.2)) + right_subbeat=int(np.floor(beat_end*SUBBEAT_COUNT+0.2)) + if(right_beat0 and self.is_downbeat[i-j+1]): + break + current_i=n_frame-1 + result=[] + while(current_i>=0): + prev_i=prei[current_i] + prev_c=prec[current_i] + start=prev_i+1 if self.half_beat_switch or self.is_even_beat[prev_i+1] else prev_i+2 + end=current_i if self.half_beat_switch or current_i==n_frame-1 or self.is_even_beat[current_i+1] else current_i+1 + result.append([self.qt_beat_onset[start],self.qt_beat_offset[end],self.chord_class.chord_list[prev_c]]) + current_i=prev_i + result=result[::-1] + #print(dp) + #entry.visualize(['chord_hmm','beat','chroma','bass']) + return result diff --git a/piano_arranger/chord_recognition/mir/.idea/inspectionProfiles/Project_Default.xml b/piano_arranger/chord_recognition/mir/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..25bde2c --- /dev/null +++ b/piano_arranger/chord_recognition/mir/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,14 @@ + + + + \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/.idea/mir.iml b/piano_arranger/chord_recognition/mir/.idea/mir.iml new file mode 100644 index 0000000..84264bf --- /dev/null +++ b/piano_arranger/chord_recognition/mir/.idea/mir.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/.idea/misc.xml b/piano_arranger/chord_recognition/mir/.idea/misc.xml new file mode 100644 index 0000000..65531ca --- /dev/null +++ b/piano_arranger/chord_recognition/mir/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/.idea/modules.xml b/piano_arranger/chord_recognition/mir/.idea/modules.xml new file mode 100644 index 0000000..d1c765b --- /dev/null +++ b/piano_arranger/chord_recognition/mir/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/.idea/vcs.xml b/piano_arranger/chord_recognition/mir/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/.idea/workspace.xml b/piano_arranger/chord_recognition/mir/.idea/workspace.xml new file mode 100644 index 0000000..2f1163b --- 
/dev/null
+++ b/piano_arranger/chord_recognition/mir/.idea/workspace.xml
@@ -0,0 +1,129 @@
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1520852567261 + + + 1584832575286 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file
diff --git a/piano_arranger/chord_recognition/mir/README.MD b/piano_arranger/chord_recognition/mir/README.MD
new file mode 100644
index 0000000..287dc25
--- /dev/null
+++ b/piano_arranger/chord_recognition/mir/README.MD
@@ -0,0 +1,296 @@
+
+
+# MIR Helper
+
+## Description
+
+MIR Helper is a framework for manipulating, processing and visualizing Music Information Retrieval (MIR) data. Specifically, MIR Helper provides facilities including:
+
+1. Structured management of feature extractors, including caching and data parallelism
+2. Data visualization with Sonic Visualizer
+3. Type-explicit data IO management with support for custom extensions
+
+## Installation
+
+1. Install Python 3.5+ and copy the package to the Python lib folder
+2. Install the dependencies listed in requirements.txt
+3. (Optional, recommended) Install Sonic Visualizer and fill its installation path into settings.py if you want to use the data visualization functions.
+4. (Optional, recommended) Change DEFAULT_DATA_STORAGE_PATH in settings.py to a folder in which you want to store training data, if you want to use the features in nn.data_storage
+5. (Optional) Install Sonic Annotator and the corresponding Vamp plugins, and fill the installation path into settings.py if you want to use the pre-defined Vamp extractors
+
+## Introduction
+
+In the following, we use 'mir' to denote the package and 'io' to denote its IO manager.
+```
+import mir
+import mir.io as io
+```
+
+### Data Management
+
+#### Feature IO
+
+We define everything representable in the computer that relates to a certain music piece as a feature.
+The waveform, the spectrogram and symbolic data can all be features.
+
+A feature has its own type, denoted by an IO class. The IO class controls (1) what the feature format is and how to read it from and write it to files, and (2) how the feature can be visualized with Sonic Visualizer.
+All IO classes must inherit the base class FeatureIO.
+
+Some examples:
+1. io.MusicIO is the feature type for wave data. It uses a mono-channel numpy array as the feature representation.
+It can read .wav, .mp3, .ogg, etc. files and write .wav files;
+2. io.MidiIO is the feature type for MIDI data. It uses the PrettyMIDI package as the feature representation. It can read and write .mid files;
+3. io.SpectrogramIO is the feature type for 2D spectrograms. It uses python pickle for fast IO and can easily be visualized in Sonic Visualizer.
+
+Any other feature with no pre-defined IO class should be assigned the type 'UnknownIO'. Such a feature cannot be read, written or visualized. To enable these operations, write your own IO class for the feature (see the short sketch below).
+
+#### DataEntry
+
+DataEntry is a container that stores the wave data of a song together with all features linked to that song.
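What such a custom IO class can look like is shown below. This is only a minimal sketch: the class name MyPickleIO is made up for illustration, and it simply follows the pickle-based pattern of the IO classes shipped with this repository (e.g. ListIO), reusing the pickle_read/pickle_write helpers from mir.io.feature_io_base:

```
from mir.io.feature_io_base import *

class MyPickleIO(FeatureIO):
    # Hypothetical custom IO class: the feature is stored as a pickle file
    def read(self, filename, entry):
        return pickle_read(self, filename)

    def write(self, data, filename, entry):
        pickle_write(self, data, filename)

    def visualize(self, data, filename, entry, override_sr):
        # no visualization defined for this feature type
        raise NotImplementedError()
```

A class like this can then be passed wherever an IO type is expected, e.g. entry.append_data(my_data, MyPickleIO, 'my_feature'). The full DataEntry workflow looks like this: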
+```
+entry=mir.DataEntry()
+
+# Set properties for IO
+entry.prop.set('sr',22050)
+entry.prop.set('hop_length',512)
+
+# Append the raw audio from the file 'wave.mp3'
+entry.append_file('wave.mp3',io.MusicIO,'music')
+
+# Append the data named 'cqt' with the pre-defined type 'SpectrogramIO'
+entry.append_data(cqt_data,io.SpectrogramIO,'cqt')
+
+# Append the data named 'chord' with your custom type 'CustomChordLabIO'
+entry.append_data(chord_estimated,CustomChordLabIO,'chord')
+
+# Visualize all the above features in a single Sonic Visualizer window
+entry.visualize(['music','cqt','chord'])
+
+# Access a feature directly via entry.feature_name to get its content;
+# entry.cqt is equivalent to entry.dict['cqt'].get(entry)
+print(entry.cqt.shape)
+```
+
+#### DataProxy
+
+A DataProxy is a container that holds a single feature of a song. It also works as a proxy that improves the space and time efficiency of the program:
+
+(1) Automatic lazy loading: load the feature only when it is needed.
+(2) Automatic cache management: cache the feature to avoid repeated computation in the future.
+
+There are three kinds of proxies, all sharing the same interface:
+
+(1) FileProxy: the feature is loaded directly from a file. Lazy loading is enabled for this kind of feature.
+(2) ExtractorProxy: the feature is computed by some program (e.g., a neural network or a preprocessor). Lazy loading and cache management are enabled for this kind of feature.
+(3) DataProxy: the feature is loaded from memory. Neither lazy loading nor cache management is enabled for this kind of feature.
+
+Use entry.append_file, entry.append_extractor and entry.append_data to append the corresponding kind of proxy to a data entry.
+
+#### DataPool
+
+DataPool is a container that holds multiple songs to form a data-set. Songs in one data-set often share the same properties or require the same operations (e.g., the same preprocessing). The helper functions in DataPool make these operations fast and easy via parallel computing.
+
+```
+dataset=mir.DataPool(name='my_dataset')
+
+# Set common properties for all songs in the dataset
+dataset.prop.set('sr',22050)
+dataset.prop.set('hop_length',512)
+
+# Append all .mp3 files in a folder and create data entries for all songs
+dataset.append_folder('dataset/mp3_file_folder/','.mp3',io.MusicIO,'music')
+
+# Append all .lab files in a folder to the data entries sharing the same file names
+dataset.append_folder('dataset/chord_annotation_folder/','.lab',CustomChordLabIO,'chord')
+
+# Append the same extractor to all entries in the DataPool
+dataset.append_extractor(CQTExtractor,'cqt',cache_enabled=True,cqt_dim=256)
+
+# Perform feature extraction in parallel
+dataset.activate_proxy('cqt',thread_number=8)
+
+# Visualize some songs in the data-set
+for entry in dataset.entries:
+    entry.visualize(['music','cqt','chord'])
+```
+
+#### ExtractorProxy
+
+You can write your own extractor if you want automatic caching for it. To do this, create a subclass of the base class ExtractorBase and implement two functions:
+
+1. get_feature_class(self): returns the IO type of the feature it extracts
+2. extract(self,entry,**kwargs): performs the extraction algorithm; kwargs may contain additional parameters for the extractor (see the skeleton and the full CQT example below)
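Before the full CQT example, a minimal skeleton shows how little is needed and how the extractor is then attached and read back. This is a sketch only: the feature (frame-level RMS energy computed with librosa) and the names EnergyExtractor and 'energy' are made up for illustration, while entry, entry.music, entry.prop.hop_length and append_extractor follow the conventions used elsewhere in this README:

```
import librosa
import mir
import mir.io as io
import mir.extractors

class EnergyExtractor(mir.extractors.ExtractorBase):
    # Hypothetical extractor: frame-level RMS energy of the raw audio
    def get_feature_class(self):
        return io.SpectrogramIO  # the result is a 2-D (n_frame, 1) array

    def extract(self, entry, **kwargs):
        hop_length = entry.prop.hop_length
        return librosa.feature.rms(y=entry.music, hop_length=hop_length).T

# Attach it to an entry and read the feature back; with cache_enabled=True the
# result can be cached after the first computation.
entry.append_extractor(EnergyExtractor, 'energy', cache_enabled=True)
print(entry.energy.shape)
```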
+
+Here is an example of a CQT extractor, which wraps the librosa package:
+
+```
+class CQTExtractor(mir.extractors.ExtractorBase):
+    def get_feature_class(self):
+        return io.SpectrogramIO
+
+    def extract(self,entry,**kwargs):
+        n_bins=kwargs['n_bins']
+        hop_length=entry.prop.hop_length
+        logspec=librosa.core.cqt(entry.music,hop_length=hop_length, bins_per_octave=36, n_bins=n_bins,
+                                 filter_scale=1.5).T
+        return np.abs(logspec)
+```
+
+### Integration with PyTorch
+
+Another function of the package is to make PyTorch-related work easier through a network trainer/inference framework.
+
+```
+import mir.nn
+```
+
+#### DataStorage
+
+The DataStorage class and its subclasses provide ways to store different kinds of data. Different data types and data scales call for different storage solutions, so two are provided for typical MIR tasks:
+
+(1) FramedRAMDataStorage is a data storage for framed data that fits in your RAM;
+(2) FramedH5DataStorage is a data storage for framed data that is too large to fit in your RAM. HDF5 is used to read it from disk on the fly instead.
+
+Both share the same interface and can be consumed by a DataProvider and a NetworkInterface.
+
+To save to a data storage, you are encouraged to use the data management methods in this package:
+
+```
+dataset.append_extractor(mir.extractors.misc.FrameCount,'n_frame',source='cqt')
+storage_cqt=mir.nn.FramedH5DataStorage('/my/storage/path/500songs_cqt',dtype=np.float16)
+if(not storage_cqt.created):
+    storage_cqt.create_and_cache(dataset.entries,'cqt')
+```
+
+If framed data is written, be sure to pre-calculate (or pre-cache) the 'n_frame' feature (the frame count of the data) for every data entry in the data-set. Otherwise, writing will be very slow and memory-consuming.
+
+To load from a data storage, you can do something like:
+
+```
+storage_cqt=mir.nn.FramedRAMDataStorage('/my/storage/path/500songs_cqt')
+storage_cqt.load_meta()
+print('We have %d songs!'%storage_cqt.get_length())
+```
+
+#### DataProvider
+
+The data provider combines multiple DataStorage instances into one with the 'link' function. When feeding the network, data pieces are sampled from the same position of all linked DataStorage instances and form a tuple.
+
+How you create the provider also decides whether data augmentation is performed.
+
+**Notice: this class is a subclass of torch.utils.data.Dataset, so you can also use it in pure PyTorch programs.**
+
+```
+train_provider=mir.nn.FramedDataProvider(train_sample_length=LSTM_TRAIN_LENGTH, # how many frames per sample
+    shift_low=0, # lower bound of pitch shift, inclusive
+    shift_high=11, # upper bound of pitch shift, inclusive
+    num_workers=4) # how many extra threads are used to fetch data in parallel
+
+# Link the feature storage to the data provider
+train_provider.link(storage_x,CQTPitchShifter(SPEC_DIM,SHIFT_LOW,SHIFT_HIGH),subrange=train_indices)
+
+# Link the label storage to the data provider
+train_provider.link(storage_y,ChordPitchShifter(),subrange=train_indices)
+
+# Get the total number of samples
+length=train_provider.get_length()
+
+# Produce a random sample pair (x,y) where x comes from storage_x and y from storage_y
+print(train_provider.get_sample(np.random.randint(length)))
+```
+
+#### DataDecorator
+
+Data decorators help you perform final data preprocessing and/or data augmentation before training.
To write your own data decorator, create a subclass of AbstractPitchShifter and implement your own pitch_shift function. Then pass it to the DataProvider when you link a DataStorage to it.
+
+An example is the CQT pitch shifter, which shifts pitch by scrolling along the frequency axis:
+
+```
+class CQTPitchShifter(AbstractPitchShifter):
+
+    def __init__(self,spec_dim,shift_low,shift_high,shift_step=3):
+        self.shift_low=shift_low
+        self.shift_high=shift_high
+        self.spec_dim=spec_dim
+        self.shift_step=shift_step
+        self.min_input_dim=(-self.shift_low+self.shift_high)*self.shift_step+self.spec_dim
+
+    def pitch_shift(self,data,shift):
+        if(data.shape[1]<self.min_input_dim):
+            raise Exception('Input dimension too small, expected >= %d, got %d'%
+                            (self.min_input_dim,data.shape[1]))
+        start_dim=(-shift+self.shift_high)*self.shift_step
+        return data[:,start_dim:start_dim+self.spec_dim]
+```
+
+#### NetworkInterface
+
+NetworkInterface provides an interface for training and testing PyTorch models. To use it, first define a model structure in a subclass of mir.nn.NetworkBehavior.
+
+You need to implement these functions in your subclass:
+
+(1) \_\_init\_\_(self): initialize whatever needs initializing.
+(2) forward(self, x): the same as torch.nn.Module.forward. PyTorch hooks work here.
+(3) loss(self, *args): how to calculate the loss. args are the tuples from the data provider, with each element converted to PyTorch format.
+(4) inference(self, x): how you plan to perform inference.
+
+After that, wrap the class you defined in a NetworkInterface and train the network.
+
+Here is an example of training a network with cross-validation:
+
+```
+import sys
+import numpy as np
+TOTAL_FOLD_COUNT=5
+slice_id=int(sys.argv[1])
+if(slice_id>=5 or slice_id<0):
+    raise Exception('Invalid input')
+print('Train on slice %d'%slice_id)
+storage_x=mir.nn.FramedH5DataStorage('/path/to/storage/cqt')
+storage_y=mir.nn.FramedH5DataStorage('/path/to/storage/chord')
+storage_x.load_meta()
+song_count=storage_x.get_length()
+is_training=np.ones(song_count,dtype=np.bool)
+is_validation=np.zeros(song_count,dtype=np.bool)
+is_testing=np.zeros(song_count,dtype=np.bool)
+for i in range(song_count):
+    if(i%TOTAL_FOLD_COUNT==slice_id):
+        is_training[i]=False
+        is_testing[i]=True
+    if((i+1)%TOTAL_FOLD_COUNT==slice_id):
+        is_training[i]=False
+        is_validation[i]=True
+train_indices=np.arange(song_count)[is_training]
+val_indices=np.arange(song_count)[is_validation]
+print('Using %d samples to train'%len(train_indices))
+print('Using %d samples to validate'%len(val_indices))
+train_provider=mir.nn.FramedDataProvider(train_sample_length=LSTM_TRAIN_LENGTH,shift_low=SHIFT_LOW,shift_high=SHIFT_HIGH,num_workers=4)
+train_provider.link(storage_x,CQTPitchShifter(SPEC_DIM,SHIFT_LOW,SHIFT_HIGH),subrange=train_indices)
+train_provider.link(storage_y,ChordPitchShifter(),subrange=train_indices)
+
+val_provider=mir.nn.FramedDataProvider(train_sample_length=LSTM_TRAIN_LENGTH,shift_low=SHIFT_LOW,shift_high=SHIFT_HIGH,num_workers=4)
+val_provider.link(storage_x,CQTPitchShifter(SPEC_DIM,SHIFT_LOW,SHIFT_HIGH),subrange=val_indices)
+val_provider.link(storage_y,ChordPitchShifter(),subrange=val_indices)
+
+# Create an instance of NetworkInterface
+trainer=mir.nn.NetworkInterface(
+    MyNetworkModel(), # this is your model
+    'model_fold=%d(p)'%slice_id, # model cache name
+    load_checkpoint=True # load model state from checkpoint?
+) +# Train the network +trainer.train_supervised( + train_provider, # training set data provider + val_provider, # validation set data provider + batch_size=96, # batch size + learning_rates_dict={1e-3:6,1e-4:3,1e-5:3}, # learning rate change after certain epochs (decay function is currently not supported) + round_per_print=10, + round_per_val=50, + round_per_save=500 +) +``` + +The marked '\(p\)' in the model indicates that it will perform parallel training. Otherwise, it will only use 1 gpu/cpu. + +After training, you can call NetworkInterface.inference(input) to calculate the model output given the input. \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/__init__.py b/piano_arranger/chord_recognition/mir/__init__.py new file mode 100644 index 0000000..37c95dd --- /dev/null +++ b/piano_arranger/chord_recognition/mir/__init__.py @@ -0,0 +1,5 @@ +from mir.common import WORKING_PATH, PACKAGE_PATH +from mir.data_file import TextureBuilder, DataEntry, DataPool + + +__all__ = ['TextureBuilder','DataEntry','WORKING_PATH','PACKAGE_PATH','DataPool','io'] \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/cache.py b/piano_arranger/chord_recognition/mir/cache.py new file mode 100644 index 0000000..9792945 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/cache.py @@ -0,0 +1,52 @@ +import pickle +import os +from mir.common import WORKING_PATH +import hashlib + +__all__=['load','save'] + +def mkdir_for_file(path): + folder_path=os.path.dirname(path) + if(not os.path.isdir(folder_path)): + os.makedirs(folder_path) + return path + +def dumptofile(obj,filename,protocol): + f=open(filename, 'wb') + # If you are awared of the compatibility issues + # Well, you use cache only on your own computer, right? + pickle.dump(obj,f,protocol=protocol) + f.close() + +def loadfromfile(filename): + if(os.path.isfile(filename)): + f=open(filename,'rb') + obj=pickle.load(f) + f.close() + return obj + else: + raise Exception('No cache of %s'%filename) + + +def load(*names): + if(len(names)==1): + return loadfromfile(os.path.join(WORKING_PATH,'cache_data/%s.cache'%names[0])) + result=[None]*len(names) + for i in range(len(names)): + result[i]=loadfromfile(os.path.join(WORKING_PATH,'cache_data/%s.cache'%names[i])) + return result + +def save(obj,name,protocol=pickle.HIGHEST_PROTOCOL): + path=os.path.join(WORKING_PATH,'cache_data/%s.cache'%name) + mkdir_for_file(path) + dumptofile(obj,path,protocol) + +def hasher(obj): + if(isinstance(obj,list)): + m=hashlib.md5() + for item in obj: + m.update(item) + return m.hexdigest() + if(isinstance(obj,str)): + return hashlib.md5(obj.encode("utf8")).hexdigest() + return hashlib.md5(obj).hexdigest() \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/common.py b/piano_arranger/chord_recognition/mir/common.py new file mode 100644 index 0000000..af001fb --- /dev/null +++ b/piano_arranger/chord_recognition/mir/common.py @@ -0,0 +1,7 @@ +import os +from mir.settings import * + +WORKING_PATH=os.getcwd() +PACKAGE_PATH=os.path.dirname(os.path.abspath(__file__)) + +DEFAULT_DATA_STORAGE_PATH=DEFAULT_DATA_STORAGE_PATH.replace('$project_name$',os.path.basename(os.getcwd())) diff --git a/piano_arranger/chord_recognition/mir/data/bothchroma.n3 b/piano_arranger/chord_recognition/mir/data/bothchroma.n3 new file mode 100644 index 0000000..7c83791 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/bothchroma.n3 @@ -0,0 +1,34 @@ +@prefix xsd: . +@prefix vamp: . +@prefix : <#> . 
+ +:transform a vamp:Transform ; + vamp:plugin ; + vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; + vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; + vamp:plugin_version """5""" ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "chromanormalize" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "rollon" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "s" ] ; + vamp:value "0.7"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "tuningmode" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "useNNLS" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "whitening" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:output [ vamp:identifier "bothchroma" ] . diff --git a/piano_arranger/chord_recognition/mir/data/chordino.n3 b/piano_arranger/chord_recognition/mir/data/chordino.n3 new file mode 100644 index 0000000..8ef0e0c --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/chordino.n3 @@ -0,0 +1,46 @@ +@prefix xsd: . +@prefix vamp: . +@prefix : <#> . + +:transform_plugin a vamp:Plugin ; + vamp:identifier "chordino" . + +:transform_library a vamp:PluginLibrary ; + vamp:identifier "nnls-chroma" ; + vamp:available_plugin :transform_plugin . + +:transform a vamp:Transform ; + vamp:plugin ; + vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; + vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; + vamp:plugin_version """5""" ; + vamp:sample_rate "[__SR__]"^^xsd:int ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "boostn" ] ; + vamp:value "0.1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "rollon" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "s" ] ; + vamp:value "0.7"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "tuningmode" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "useHMM" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "useNNLS" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "whitening" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:output [ vamp:identifier "simplechord" ] . diff --git a/piano_arranger/chord_recognition/mir/data/chroma.n3 b/piano_arranger/chord_recognition/mir/data/chroma.n3 new file mode 100644 index 0000000..83eeb9e --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/chroma.n3 @@ -0,0 +1,42 @@ +@prefix xsd: . +@prefix vamp: . +@prefix : <#> . + +:transform_plugin a vamp:Plugin ; + vamp:identifier "nnls-chroma" . + +:transform_library a vamp:PluginLibrary ; + vamp:identifier "nnls-chroma" ; + vamp:available_plugin :transform_plugin . 
+ +:transform a vamp:Transform ; + vamp:plugin :transform_plugin ; + vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; + vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; + vamp:plugin_version """3""" ; + vamp:sample_rate "[__SR__]"^^xsd:int ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "chromanormalize" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "rollon" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "s" ] ; + vamp:value "0.7"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "tuningmode" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "useNNLS" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "whitening" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:output [ vamp:identifier "chroma" ] . diff --git a/piano_arranger/chord_recognition/mir/data/curve_template.svl b/piano_arranger/chord_recognition/mir/data/curve_template.svl new file mode 100644 index 0000000..0e17971 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/curve_template.svl @@ -0,0 +1,13 @@ + + + + + + + [__DATA__] + + + + + + diff --git a/piano_arranger/chord_recognition/mir/data/midi_template.svl b/piano_arranger/chord_recognition/mir/data/midi_template.svl new file mode 100644 index 0000000..6a855df --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/midi_template.svl @@ -0,0 +1,13 @@ + + + + + + + [__DATA__] + + + + + + diff --git a/piano_arranger/chord_recognition/mir/data/pitch_template.svl b/piano_arranger/chord_recognition/mir/data/pitch_template.svl new file mode 100644 index 0000000..1ac5610 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/pitch_template.svl @@ -0,0 +1,18 @@ + + + + + + + + [__DATA_FREQ__] + + + [__DATA_ENERGY__] + + + + + + + diff --git a/piano_arranger/chord_recognition/mir/data/sparse_tag_template.svl b/piano_arranger/chord_recognition/mir/data/sparse_tag_template.svl new file mode 100644 index 0000000..cf93030 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/sparse_tag_template.svl @@ -0,0 +1,13 @@ + + + + + + + [__DATA__] + + + + + + diff --git a/piano_arranger/chord_recognition/mir/data/spectrogram_template.svl b/piano_arranger/chord_recognition/mir/data/spectrogram_template.svl new file mode 100644 index 0000000..99217d7 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/spectrogram_template.svl @@ -0,0 +1,13 @@ + + + + + + + [__DATA__] + + + + + + diff --git a/piano_arranger/chord_recognition/mir/data/tunedlogfreqspec.n3 b/piano_arranger/chord_recognition/mir/data/tunedlogfreqspec.n3 new file mode 100644 index 0000000..ff2afe3 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/tunedlogfreqspec.n3 @@ -0,0 +1,42 @@ +@prefix xsd: . +@prefix vamp: . +@prefix : <#> . + +:transform_plugin a vamp:Plugin ; + vamp:identifier "nnls-chroma" . + +:transform_library a vamp:PluginLibrary ; + vamp:identifier "nnls-chroma" ; + vamp:available_plugin :transform_plugin . 
+ +:transform a vamp:Transform ; + vamp:plugin ; + vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; + vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; + vamp:plugin_version """5""" ; + vamp:sample_rate "[__SR__]"^^xsd:int ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "chromanormalize" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "rollon" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "s" ] ; + vamp:value "0.7"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "tuningmode" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "useNNLS" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "whitening" ] ; + vamp:value "1"^^xsd:float ; + ] ; + vamp:output [ vamp:identifier "tunedlogfreqspec" ] . diff --git a/piano_arranger/chord_recognition/mir/data/tuning.n3 b/piano_arranger/chord_recognition/mir/data/tuning.n3 new file mode 100644 index 0000000..ae5d95f --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data/tuning.n3 @@ -0,0 +1,14 @@ +@prefix xsd: . +@prefix vamp: . +@prefix : <#> . + +:transform a vamp:Transform ; + vamp:plugin ; + vamp:step_size "[__WIN_SHIFT__]"^^xsd:int ; + vamp:block_size "[__WIN_SIZE__]"^^xsd:int ; + vamp:plugin_version """5""" ; + vamp:parameter_binding [ + vamp:parameter [ vamp:identifier "rollon" ] ; + vamp:value "0"^^xsd:float ; + ] ; + vamp:output . diff --git a/piano_arranger/chord_recognition/mir/data_file.py b/piano_arranger/chord_recognition/mir/data_file.py new file mode 100644 index 0000000..edf1359 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/data_file.py @@ -0,0 +1,512 @@ +from abc import ABC,abstractmethod +from mir.common import SONIC_VISUALIZER_PATH,WORKING_PATH +import subprocess +import os +import gc +import mir.io +from joblib import Parallel,delayed +import random +from pydub.utils import mediainfo +import time +import datetime +import mir.cache + + + +class ProxyBase(ABC): + def __init__(self,feature_class): + self.loaded=False + self.loaded_data=None + self.feature_class=feature_class + + def pre_assign(self, entry): + io = self.feature_class() + io.pre_assign(entry, self) + + def get(self,requester): + if(not self.loaded): + self.loaded_data=self.load(requester) + self.post_load(requester) + self.loaded=True + return self.loaded_data + + @abstractmethod + def load(self, requester): + pass + + def post_load(self, requester): + io = self.feature_class() + io.post_load(self.loaded_data, requester) + + def unload(self,gc_collect=True): + if(self.loaded): + self.loaded=False + del self.loaded_data + if(gc_collect): + gc.collect() + + def save_visualize_temp_file(self,requestor,savepath,auralize,override_sr,beats=None): + io = self.feature_class() + if(auralize): + if(beats is not None): + io.auralize_with_beat(self.get(requestor),savepath,requestor,beats) + else: + io.auralize(self.get(requestor), savepath, requestor) + else: + io.visualize(self.get(requestor),savepath,requestor,override_sr) + + +class FileProxy(ProxyBase): + def __init__(self,file_path,feature_class,file_exist_check=True): + super().__init__(feature_class) + # if file_path is absolute, the statement will have no effect + file_path=os.path.join(WORKING_PATH,file_path) + if (file_exist_check and (not os.path.isfile(file_path))): + raise Exception("File not found: %s"%file_path) + self.filepath=file_path + + def 
load(self,requester): + io = self.feature_class() + return io.safe_read(self.filepath, requester) + + +class ExtractorProxy(ProxyBase): + def __init__(self,extractor_class,cache_enabled=True,io_override=None,**kwargs): + self.extractor=extractor_class() + if(io_override is not None): + feature_class=io_override + else: + feature_class=self.extractor.get_feature_class() + super().__init__(feature_class) + self.kwargs=kwargs + self.cache_enabled=cache_enabled + + def load(self, requester): + return self.extractor.extract_and_cache(requester,self.cache_enabled,**self.kwargs) + + +class DataProxy(ProxyBase): + def __init__(self,data,feature_class): + super().__init__(feature_class) + self.loaded_data=data + self.loaded=True + + def load(self,requester): + raise Exception('Shouldn\'t be here!') + + def unload(self,gc_collect=True): + pass + +class ProxyArray(): + def __init__(self, name, entry): + self.name=name + self.entry=entry + + def __getitem__(self, item): + return self.entry.__getattr__(self.name+'['+str(item)+']') + + +class TextureBuilder(): + def __init__(self, texture_class, chords_item, beats_item): + self.texture_class=texture_class + self.chords_item=chords_item + self.beats_item=beats_item + + +class DataEntryProperties(): + def __init__(self): + self.dict={} + self.recorded_set_stack=[] + self.recording=True + + def __getattr__(self, item): + return self.get(item) + + def remove(self, item): + del self.dict[item] + + def set(self, item, value): + if('dict' not in self.__dict__): + raise AttributeError('you are not initialized!') + if(isinstance(value,mir.io.LoadingPlaceholder)): + # Set a place holder + if (item not in self.dict): + self.dict[item] = value + elif(item in self.dict and not isinstance(self.dict[item],mir.io.LoadingPlaceholder)): + # Old value found + if(self.dict[item]!=value): + print('Warning: Inconsistant property in %s: old value'%item,self.dict[item],'new value',value) + else: + # No old value, set a new value + self.dict[item]=value + + def get(self,item): + if('dict' not in self.__dict__): + raise AttributeError('you are not initialized!') + if(item in self.dict): + if(len(self.recorded_set_stack)>0): # Recording + self.recorded_set_stack[-1].add(item) + obj=self.dict[item] + if(isinstance(obj,mir.io.LoadingPlaceholder)): + obj.fire() + obj=self.dict[item] + assert(not isinstance(obj,mir.io.LoadingPlaceholder)) + return obj + else: + raise AttributeError("Property %s not appended!"%item) + + def get_unrecorded(self,item): + if('dict' not in self.__dict__): + raise AttributeError('you are not initialized!') + if(item in self.dict): + obj=self.dict[item] + if(isinstance(obj,mir.io.LoadingPlaceholder)): + obj.fire() + obj=self.dict[item] + assert(not isinstance(obj,mir.io.LoadingPlaceholder)) + return obj + + def start_record_reading(self): + self.recorded_set_stack.append(set()) + + def end_record_reading(self): + result=self.recorded_set_stack.pop() + return list(result) + + +class DataEntry(): + # Warning: use empty name will disable all extractors' cache + def __init__(self,name=''): + self.dict={} + self.name=name + self.prop=DataEntryProperties() + self.proxy_array=set() + + def __getattr__(self, item): + if('dict' not in self.__dict__): + raise AttributeError('you are not initialized!') + # I wonder, why someone would intentionally create a DataEntry + # whose __dict__ is {} to run something multi-threadly + if(item in self.dict): + return self.dict[item].get(self) + elif(item in self.proxy_array): + return ProxyArray(item,entry=self) + else: + raise 
AttributeError("Datatype %s not appended!"%item) + + def has(self,item): + return item in self.dict + + def rename(self,old_name,new_name): + if(old_name!=new_name): + self.dict[new_name]=self.dict[old_name] + del self.dict[old_name] + + def swap(self,item1,item2): + temp_item=self.dict[item1] + self.dict[item1]=self.dict[item2] + self.dict[item2]=temp_item + + def remove(self,item): + del self.dict[item] + + def free(self,item='',gc_collect=True): + if(item==''): + for del_item in self.dict: + self.dict[del_item].unload(gc_collect) + else: + self.dict[item].unload(gc_collect) + + def append_file(self,filename,feature_class,output_name,file_exist_check=True): + file_proxy=FileProxy(filename,feature_class,file_exist_check=file_exist_check) + file_proxy.pre_assign(self) + self.dict[output_name]=file_proxy + + def apply_extractor(self,extractor_class,cache_enabled=True,io_override=None,**kwargs): + extractor_proxy=ExtractorProxy(extractor_class,cache_enabled,io_override,**kwargs) + extractor_proxy.pre_assign(self) + return extractor_proxy.get(self) + + def append_extractor(self,extractor_class,output_name,cache_enabled=True,io_override=None,**kwargs): + extractor_proxy=ExtractorProxy(extractor_class,cache_enabled,io_override,**kwargs) + extractor_proxy.pre_assign(self) + self.dict[output_name]=extractor_proxy + + def append_data(self,data,feature_class,output_name): + data_proxy=DataProxy(data,feature_class) + data_proxy.pre_assign(self) + self.dict[output_name]=data_proxy + + def declare_proxy_array(self,name): + if(name in self.proxy_array): + return + self.proxy_array.add(name) + + def activate_proxy(self, item, free=False, verbose_id=0, verbose_all=0, start_time=None): + if(verbose_all>0): + if(verbose_id > 0 and start_time is not None): + current_time=time.time() + print('[%d/%d]Activating %s, passed:'%(verbose_id,verbose_all,item), + str(datetime.timedelta(seconds=current_time-start_time)),'remaining:', + str(datetime.timedelta(seconds=(current_time-start_time)/verbose_id*(verbose_all-verbose_id))),flush=True) + else: + print('[%d/%d]Activating %s'%(verbose_id,verbose_all,item),flush=True) + self.dict[item].get(self) + if(free): + self.free('') + + def save(self, item, filename, create_dir=False): + if(create_dir): + mir.cache.mkdir_for_file(filename) + self.dict[item].feature_class().write(self.dict[item].get(self),filename,self) + + def visualize(self,items=None,use_raw_music_file=True,music='music',midi_texture_builder: TextureBuilder=None): + if(items==None): + items=[] + elif(not isinstance(items,list)): + items=[items] + temp_path=os.path.join(WORKING_PATH,'temp') + if(not os.path.isdir(temp_path)): + os.makedirs(temp_path) + result_string='"'+SONIC_VISUALIZER_PATH+'" ' + if(use_raw_music_file): + if(isinstance(music,list)): + music_list=music + else: + music_list=[music] + for music_item in music_list: + if (not isinstance(self.dict[music_item],FileProxy)): + # Well, there is no raw music file at all + items.insert(0,music_item) + override_sr=self.prop.get_unrecorded('sr') + else: + filepath=os.path.join(WORKING_PATH, self.dict[music_item].filepath) + result_string += '"' + filepath + '" ' + info = mediainfo(filepath) + override_sr=int(info['sample_rate']) + else: + override_sr=self.prop.get_unrecorded('sr') + temp_file_list=[] + for item in items: + if(self.has(item)): + abbr='_visualize.' 
+ self.dict[item].feature_class().get_visualize_extention_name() + temp_file_name=os.path.join(temp_path,item+abbr) + temp_file_list.append(temp_file_name) + result_string+='"'+temp_file_name+'" ' + self.dict[item].save_visualize_temp_file(self,temp_file_name,auralize=False,override_sr=override_sr) + else: + raise Exception('No such feature to visualize: %s'%item) + + if(midi_texture_builder is not None): + chords_item=midi_texture_builder.chords_item + beats_item=midi_texture_builder.beats_item + if(self.has(chords_item)): + abbr='_auralize.svl' + temp_file_name=os.path.join(temp_path,chords_item+abbr) + temp_file_list.append(temp_file_name) + result_string+='"'+temp_file_name+'" ' + if(beats_item!=None): + beats=self.dict[beats_item].get(self) + else: + beats=None + generator=midi_texture_builder.texture_class() + + sr = self.prop.get_unrecorded('sr') + win_shift = self.prop.get_unrecorded('hop_length') + generator.auralize(filename=temp_file_name,chords=self.dict[chords_item].get(self),beats=beats, + sr=sr,win_shift=win_shift) + else: + raise Exception('No such feature to auralize: %s'%chords_item) + + return_code=subprocess.call(result_string.replace('\\','/')) + + # Delete temp files + for path in temp_file_list: + try: + os.unlink(path) + except: + print('[Warning] Temp file delete failed:',path) + return return_code + + +class DataPool: + def __init__(self,name,**default_properties): + self.entries=[] + self.dict={} # collections.OrderedDict() + self.name=name + self.antidict=[] + self.default_prop={} + for (k,v) in default_properties: + self.default_prop[k]=v + + def __getitem__(self, key): + if(isinstance(key,slice)): + sub_indices=range(len(self.entries))[key] + sub_pool=DataPool(self.name) + for i in sub_indices: + sub_pool.__append_entry(self.entries[i],self.antidict[i]) + return sub_pool + else: + if(isinstance(key,int)): + raise Exception('Use dataset.entries to iterate over its entries') + raise Exception('Unsupported slicing type:',key) + + def __append_entry(self,entry,entry_name): + lower_entry_name=entry_name.lower() + if(lower_entry_name in self.dict): + print('Warning: entry `%s` overriding %s'%(entry.name,entry_name)) + self.dict[lower_entry_name]=entry + self.antidict.append(lower_entry_name) + self.entries.append(entry) + + def remove_entry(self,entry): + # todo: more situtations + if('/' in entry.name): + entry_name=entry.name[entry.name.index('/')+1:] + else: + entry_name=entry.name + lower_entry_name=entry_name.lower() + del self.dict[lower_entry_name] + self.antidict.remove(lower_entry_name) + self.entries.remove(entry) + + def add_entry(self,entry): + filename=entry.name.split('/')[-1] + if(filename==''): + raise Exception('Cannot add entry whose name is empty') + if('&' in self.name): + print('Warning: You are adding an entry to a joint dataset. Don\'t do that!') + elif(entry.name.split('/')[0]!=self.name): + print('Warning: Inconsistent dataset name, %s expected, %s found'%(self.name,entry.name.split('/')[0])) + self.__append_entry(entry,filename) + + def set_property(self,key,value): + self.default_prop[key]=value + + def new_entry(self,filename): + if('&' in self.name): + print('Warning: You are creating an entry in a joint dataset. 
Don\'t do that!') + entry = DataEntry(self.name+'/'+filename) + lower_filename=filename.lower() + if(lower_filename in self.dict): + print('Warning: Entry name overwrite: %s'%filename) + for k in self.default_prop: + entry.prop.set(k, self.default_prop[k]) + self.__append_entry(entry,filename) + return entry + + def append_folder(self,folder_path,suffix,typename,output_name,recursive=False): + if(recursive): + files=[os.path.join(dp, f).replace('\\','/') for dp, dn, fn in os.walk(folder_path) for f in fn] + else: + files=[os.path.join(folder_path, f) for f in os.listdir(folder_path)] + # if it's run for the first time, create the dict + if(len(self.dict)==0): + for file in files: + if file.endswith(suffix): + filename = os.path.basename(file) + filename = filename[:len(filename) - len(suffix)] + entry = DataEntry(self.name+'/'+filename) + entry.append_file(file, typename, output_name=output_name, file_exist_check=False) + for k in self.default_prop: + entry.prop.set(k, self.default_prop[k]) + self.__append_entry(entry,filename) + # sorted order + # for (k,entry) in self.dict.items(): + # self.entries.append(entry) + if(len(self.dict)==0): + print('Warning: No data entry was created in "%s"'%folder_path) + else: # check the dict + mark={} + for file in files: + if file.endswith(suffix): + filename=os.path.basename(file) + filename=filename[:len(filename)-len(suffix)] + lower_filename=filename.lower() + if(lower_filename in self.dict): + entry=self.dict[lower_filename] + entry.append_file(file,typename,output_name=output_name,file_exist_check=False) + mark[lower_filename]=True + delta=len(self.dict)-len(mark) + if(delta!=0): + print('Warning: %d entries not appended in "%s"'%(delta,folder_path)) + if(delta>10): + print('Some of them are:') + delta=10 + else: + print('They are:') + for (k,v) in self.dict.items(): + if(k not in mark): + print(k) + delta-=1 + if(delta==0): + break + + def append_extractor(self,extractor_class,output_name,cache_enabled=True,io_override=None,**kwargs): + for entry in self.entries: + entry.append_extractor(extractor_class,output_name,cache_enabled=cache_enabled,io_override=io_override,**kwargs) + + def activate_proxy(self,item,thread_number=1,timing=True,free=False): + entries_needs = [entry for entry in self.entries if not entry.dict[item].loaded] + total=len(entries_needs) + print('Total %s: %d entries to activate' % (item,total)) + start_time=time.time() if timing else None + if(thread_number!=1): + random.shuffle(entries_needs) + Parallel(n_jobs=thread_number)(delayed(DataEntry.activate_proxy)(entries_needs[i],item,free,i,total,start_time) for i in range(len(entries_needs))) + else: + for i in range(len(entries_needs)): + entries_needs[i].activate_proxy(item,free,i,total,start_time) + + def free(self, item='', gc_collect=True): + if(item==''): + for e in self.entries: + for del_item in e.dict: + e.dict[del_item].unload(gc_collect=False) + else: + for e in self.entries: + e.dict[item].unload(gc_collect=False) + if(gc_collect): + gc.collect() + + def subrange(self,*args): + subpool=DataPool(self.name) + for i in range(*args): + subpool.__append_entry(self.entries[i],self.antidict[i]) + return subpool + + def sublist(self,arg): + subpool=DataPool(self.name) + for i in arg: + subpool.__append_entry(self.entries[i],self.antidict[i]) + return subpool + + + def find(self,name): + for entry in self.entries: + if(name.lower() in entry.name.lower()): + return entry + raise Exception('Cannot find %s in %s'%(name,self.name)) + + def where(self,name): + 
subpool=DataPool(self.name) + for i in range(len(self.entries)): + entry=self.entries[i] + if(name.lower() in entry.name.lower()): + subpool.__append_entry(entry,self.antidict[i]) + return subpool + + def random_choice(self,count=1): + import random + return self.sublist(random.sample(range(len(self.entries)),count)) + + def join(*args): + result_pool=DataPool(' & '.join([dataset.name for dataset in args])) + for dataset in args: + for i in range(len(dataset.entries)): + result_pool.__append_entry(dataset.entries[i],dataset.entries[i].name) + # For joint dataset, the keys in the look-up dictionary self.dict + # will be formatted as `original_dataset/entry_file_name` instead of + # `entry_file_name` + return result_pool + diff --git a/piano_arranger/chord_recognition/mir/extractors/__init__.py b/piano_arranger/chord_recognition/mir/extractors/__init__.py new file mode 100644 index 0000000..00ecc06 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/extractors/__init__.py @@ -0,0 +1,3 @@ +from mir.extractors.extractor_base import ExtractorBase + +__all__ =['ExtractorBase'] \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/extractors/extractor_base.py b/piano_arranger/chord_recognition/mir/extractors/extractor_base.py new file mode 100644 index 0000000..5eb1725 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/extractors/extractor_base.py @@ -0,0 +1,107 @@ +from abc import ABC,abstractmethod +from mir.common import WORKING_PATH +from mir import io +import os +import pickle + +def pickle_read(filename): + f = open(filename, 'rb') + obj = pickle.load(f) + f.close() + return obj + +def pickle_write(data, filename): + f = open(filename, 'wb') + pickle.dump(data, f) + f.close() + +def try_mkdir(filename): + folder=os.path.dirname(filename) + if(not os.path.isdir(folder)): + os.makedirs(folder) + +class ExtractorBase(ABC): + + def require(self, *args): + pass + + def get_feature_class(self): + return io.UnknownIO + + @abstractmethod + def extract(self,entry,**kwargs): + pass + + def __create_cache_path(self,entry,cached_prop_record,input_kwargs): + + items={} + items_entry={} + for k in input_kwargs: + items[k]=input_kwargs[k] + for prop_name in cached_prop_record: + if(prop_name not in items): + items_entry[prop_name]=entry.prop.get_unrecorded(prop_name) + + if(len(items)==0): + folder_name=self.__class__.__name__ + else: + folder_name=self.__class__.__name__+'/'+','.join([k+'='+str(items[k]) for k in sorted(items.keys())]) + + if(len(items_entry)==0): + entry_name=entry.name+'.cache' + else: + entry_name=entry.name+'.'+','.join([k+'='+str(items_entry[k]) for k in sorted(items_entry.keys())])+'.cache' + + return os.path.join(WORKING_PATH, 'cache_data', folder_name, entry_name) + + + def extract_and_cache(self,entry,cache_enabled=True,**kwargs): + folder_name=os.path.join(WORKING_PATH, 'cache_data',self.__class__.__name__) + prop_cache_filename=os.path.join(folder_name,'_prop_records.cache') + if('cached_prop_record' in self.__dict__): + cached_prop_record=self.__dict__['cached_prop_record'] + else: + if(os.path.exists(prop_cache_filename)): + cached_prop_record=pickle_read(prop_cache_filename) + else: + cached_prop_record=None + + if(cache_enabled and entry.name!='' and self.get_feature_class()!=io.UnknownIO): + # Need cache + need_io_create=False + if(cached_prop_record is None): + entry.prop.start_record_reading() + feature=self.extract(entry,**kwargs) + cached_prop_record=sorted(entry.prop.end_record_reading()) + try_mkdir(prop_cache_filename) + 
pickle_write(cached_prop_record,prop_cache_filename) + cache_file_name=self.__create_cache_path(entry,cached_prop_record,kwargs) + need_io_create=True + else: + cache_file_name=self.__create_cache_path(entry,cached_prop_record,kwargs) + if(not os.path.isfile(cache_file_name)): + entry.prop.start_record_reading() + feature=self.extract(entry,**kwargs) + new_prop_record=sorted(entry.prop.end_record_reading()) + if(cached_prop_record!=new_prop_record): + print('[Warning] Inconsistent cached properity requirement in %s, overrode:'%self.__class__.__name__) + print('Old:',cached_prop_record) + print('New:',new_prop_record) + cached_prop_record=new_prop_record + pickle_write(cached_prop_record,prop_cache_filename) + cache_file_name = self.__create_cache_path(entry, cached_prop_record, kwargs) + need_io_create=True + else: + io_obj=self.get_feature_class()() + feature=io_obj.safe_read(cache_file_name,entry) + if(need_io_create): + io_obj=self.get_feature_class()() + io_obj.create(feature,cache_file_name,entry) + else: + feature = self.extract(entry, **kwargs) + return feature + + + + + diff --git a/piano_arranger/chord_recognition/mir/extractors/librosa_extractor.py b/piano_arranger/chord_recognition/mir/extractors/librosa_extractor.py new file mode 100644 index 0000000..fb269cc --- /dev/null +++ b/piano_arranger/chord_recognition/mir/extractors/librosa_extractor.py @@ -0,0 +1,59 @@ +from mir.extractors.extractor_base import * +import librosa +import numpy as np + +class HPSS(ExtractorBase): + + def get_feature_class(self): + return io.MusicIO + + def extract(self,entry,**kwargs): + if('source' in kwargs): + y=entry.dict[kwargs['source']].get(entry) + else: + y=entry.music + y_h=librosa.effects.harmonic(y,margin=kwargs['margin']) + #y_h, y_p = librosa.effects.hpss(y, margin=(1.0, 5.0)) + return y_h + +class CQT(ExtractorBase): + def get_feature_class(self): + return io.SpectrogramIO + + # Warning this spectrum has a 1/3 half note stepping + def extract(self,entry,**kwargs): + n_bins = 262 + y = entry.music + logspec = librosa.core.cqt(y, sr=kwargs['sr'], hop_length=kwargs['hop_length'], bins_per_octave=36, n_bins=n_bins, + filter_scale=1.5).T + logspec = np.abs(logspec) + return logspec + +class STFT(ExtractorBase): + def get_feature_class(self): + return io.SpectrogramIO + + # Warning this spectrum has a 1/3 half note stepping + def extract(self,entry,**kwargs): + y = entry.music + logspec = librosa.core.stft(y, win_length=kwargs['win_size'], hop_length=kwargs['hop_length']).T + logspec = np.abs(logspec) + return logspec + + +class Onset(ExtractorBase): + def get_feature_class(self): + return io.SpectrogramIO + + def extract(self,entry,**kwargs): + onset=librosa.onset.onset_strength(entry.music,sr=kwargs['sr'], hop_length=kwargs['hop_length']).reshape((-1,1)) + return onset + +class Energy(ExtractorBase): + def get_feature_class(self): + return io.SpectrogramIO + + def extract(self,entry,**kwargs): + energy=librosa.feature.rmse(y=entry.dict[kwargs['source']].get(entry), hop_length=kwargs['hop_length'], + frame_length=kwargs['win_size'],center=True).reshape((-1,1)) + return energy diff --git a/piano_arranger/chord_recognition/mir/extractors/misc.py b/piano_arranger/chord_recognition/mir/extractors/misc.py new file mode 100644 index 0000000..3a37ee9 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/extractors/misc.py @@ -0,0 +1,56 @@ +from mir.extractors.extractor_base import * +import librosa +import numpy as np + +class BlankMusic(ExtractorBase): + def get_feature_class(self): + return 
io.MusicIO + + def extract(self,entry,**kwargs): + time=60.0 # seconds + if('time' in kwargs): + time=kwargs['time'] + return np.zeros((int(np.ceil(time*entry.prop.sr)))) + + +class FrameCount(ExtractorBase): + def get_feature_class(self): + return io.IntegerIO + + def extract(self,entry,**kwargs): + # self.require(entry.prop.hop_length) + return entry.dict[kwargs['source']].get(entry).shape[0] + +class Evaluate(): + + def __init__(self, io): + self.__io=io + + def __call__(self, *args, **kwargs): + inner_instance=Evaluate.InnerEvaluate() + inner_instance.io=self.__io + return inner_instance + + class InnerEvaluate(ExtractorBase): + def __init__(self): + self.io=None + + def get_feature_class(self): + return self.io + + class __ProxyReflector(): + + def __init__(self,entry): + self.__entry=entry + + def __getattr__(self, item): + if(item in self.__entry.dict): + print('Getting %s'%item) + return self.__entry.dict[item].get(self.__entry) + else: + raise AttributeError('No key \'%s\' found in entry %s'%(item,self.__entry.name)) + + def extract(self,entry,**kwargs): + eval_proxy_ref__=__class__.__ProxyReflector(entry) + expr=kwargs['expr'].replace('$','eval_proxy_ref__.') + return eval(expr) diff --git a/piano_arranger/chord_recognition/mir/extractors/vamp_extractor.py b/piano_arranger/chord_recognition/mir/extractors/vamp_extractor.py new file mode 100644 index 0000000..99015a8 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/extractors/vamp_extractor.py @@ -0,0 +1,136 @@ +from mir.extractors.extractor_base import * +from mir.common import WORKING_PATH,SONIC_ANNOTATOR_PATH,PACKAGE_PATH +from mir.cache import hasher +import numpy as np +import subprocess + +def rewrite_extract_n3(entry,inputfilename,outputfilename): + f=open(inputfilename,'r') + content=f.read() + f.close() + content=content.replace('[__SR__]',str(entry.prop.sr)) + content=content.replace('[__WIN_SHIFT__]',str(entry.prop.hop_length)) + content=content.replace('[__WIN_SIZE__]',str(entry.prop.win_size)) + if(not os.path.isdir(os.path.dirname(outputfilename))): + os.makedirs(os.path.dirname(outputfilename)) + f=open(outputfilename,'w') + f.write(content) + f.close() + + +class NNLSChroma(ExtractorBase): + + def get_feature_class(self): + return io.ChromaIO + + def extract(self,entry,**kwargs): + print('NNLSChroma working on entry '+entry.name) + if('margin' in kwargs): + if(kwargs['margin']>0): + music=entry.music_h + else: + raise Exception('Error margin') + + else: + music=entry.music + music_io=io.MusicIO() + temp_path=os.path.join(WORKING_PATH,'temp/nnlschroma_extractor_%s.wav'%hasher(entry.name)) + temp_n3_path=temp_path+'.n3' + rewrite_extract_n3(entry,os.path.join(PACKAGE_PATH,'data/bothchroma.n3'),temp_n3_path) + music_io.write(music,temp_path,entry) + proc=subprocess.Popen([SONIC_ANNOTATOR_PATH, + '-t',temp_n3_path, + temp_path, + '-w','lab','--lab-stdout' + ],stdout=subprocess.PIPE,stderr=subprocess.DEVNULL) + # print('Begin processing') + result=np.zeros((0,24)) + for line in proc.stdout: + # the real code does filtering here + line=bytes.decode(line) + if(line.endswith('\r\n')): + line=line[:len(line)-2] + if (line.endswith('\r')): + line=line[:len(line)-1] + arr=np.array(list(map(float,line.split('\t')))[1:]) + arr=arr.reshape((2,12))[::-1].T + arr=np.roll(arr,-3,axis=0).reshape((1,24)) + result=np.append(result,arr,axis=0) + try: + os.unlink(temp_path) + os.unlink(temp_n3_path) + except: + pass + if(result.shape[0]==0): + raise Exception('Empty response') + return result + + +class 
TunedLogSpectrogram(ExtractorBase): + + def get_feature_class(self): + return io.SpectrogramIO + + def extract(self,entry,**kwargs): + print('TunedLogSpectrogram working on entry '+entry.name) + music_io = io.MusicIO() + temp_path=os.path.join(WORKING_PATH,'temp/tunedlogspectrogram_extractor_%s.wav'%hasher(entry.name)) + temp_n3_path=temp_path+'.n3' + rewrite_extract_n3(entry,os.path.join(PACKAGE_PATH,'data/tunedlogfreqspec.n3'),temp_n3_path) + music_io.write(entry.music,temp_path,entry) + proc=subprocess.Popen([SONIC_ANNOTATOR_PATH, + '-t',temp_n3_path, + temp_path, + '-w','lab','--lab-stdout' + ],stdout=subprocess.PIPE,stderr=subprocess.DEVNULL) + # print('Begin processing') + result=np.zeros((0,256)) + for line in proc.stdout: + # the real code does filtering here + line=bytes.decode(line) + if(line.endswith('\r\n')): + line=line[:len(line)-2] + if (line.endswith('\r')): + line=line[:len(line)-1] + arr=np.array(list(map(float,line.split('\t')))[1:]) + arr=arr.reshape((1,-1)) + result=np.append(result,arr,axis=0) + try: + os.unlink(temp_path) + os.unlink(temp_n3_path) + except: + pass + if(result.shape[0]==0): + raise Exception('Empty response') + return result + +class GlobalTuning(ExtractorBase): + + def get_feature_class(self): + return io.FloatIO + + def extract(self,entry,**kwargs): + music_io = io.MusicIO() + temp_path=os.path.join(WORKING_PATH,'temp/tuning_%s.wav'%hasher(entry.name)) + temp_n3_path=temp_path+'.n3' + rewrite_extract_n3(entry,os.path.join(PACKAGE_PATH,'data/tuning.n3'),temp_n3_path) + if('source' in kwargs): + music=entry.dict[kwargs['source']].get(entry) + else: + music=entry.music + music_io.write(music,temp_path,entry) + proc=subprocess.Popen([SONIC_ANNOTATOR_PATH, + '-t',temp_n3_path, + temp_path, + '-w','lab','--lab-stdout' + ],stdout=subprocess.PIPE,stderr=subprocess.DEVNULL) + # print('Begin processing') + output=proc.stdout.readlines() + result=(np.log2(np.float64(output[0].decode().split('\t')[2]))-np.log2(440))*12 + try: + os.unlink(temp_path) + os.unlink(temp_n3_path) + except: + pass + return result + diff --git a/piano_arranger/chord_recognition/mir/io/__init__.py b/piano_arranger/chord_recognition/mir/io/__init__.py new file mode 100644 index 0000000..f490fbf --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/__init__.py @@ -0,0 +1,12 @@ +from .feature_io_base import FeatureIO,LoadingPlaceholder +from .implement.chroma_io import ChromaIO +from .implement.midi_io import MidiIO +from .implement.music_io import MusicIO +from .implement.spectrogram_io import SpectrogramIO +from .implement.scalar_io import IntegerIO,FloatIO +from .implement.unknown_io import UnknownIO +from .implement.regional_spectrogram_io import RegionalSpectrogramIO + +__all__ =['FeatureIO','LoadingPlaceholder', + 'ChromaIO','MidiIO','MusicIO','SpectrogramIO','IntegerIO','FloatIO', + 'RegionalSpectrogramIO','UnknownIO'] \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/feature_io_base.py b/piano_arranger/chord_recognition/mir/io/feature_io_base.py new file mode 100644 index 0000000..bc938bc --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/feature_io_base.py @@ -0,0 +1,96 @@ +from abc import ABC,abstractmethod +import pickle +import os + +class LoadingPlaceholder(): + def __init__(self,proxy,entry): + self.proxy=proxy + self.entry=entry + pass + + def fire(self): + self.proxy.get(self.entry) + +class FeatureIO(ABC): + + @abstractmethod + def read(self, filename, entry): + pass + + def safe_read(self, filename, entry): + 
entry.prop.start_record_reading() + try: + result=self.read(filename,entry) + except Exception: + entry.prop.end_record_reading() + raise + entry.prop.end_record_reading() + return result + + def try_mkdir(self, filename): + folder=os.path.dirname(filename) + if(not os.path.isdir(folder)): + os.makedirs(folder) + + def create(self, data, filename, entry): + self.try_mkdir(filename) + self.write(data,filename,entry) + + @abstractmethod + def write(self, data, filename, entry): + pass + + # override iif writing and visualizing use different methods + # (i.e. compressed vs uncompressed) + def visualize(self, data, filename, entry, override_sr): + self.write(data, filename, entry) + + # override iff entry properties will be updated upon loading + def pre_assign(self, entry, proxy): + pass + + # override iff entry properties need updated upon loading + def post_load(self, data, entry): + pass + + # override iif it will save as other formats (e.g. wav) + def get_visualize_extention_name(self): + return "txt" + + def file_to_evaluation_format(self, filename, entry): + raise Exception('Not supported by the io class') + + def data_to_evaluation_format(self, data, entry): + raise Exception('Not supported by the io class') + + +def pickle_read(self, filename): + f = open(filename, 'rb') + obj = pickle.load(f) + f.close() + return obj + +def pickle_write(self, data, filename): + f = open(filename, 'wb') + pickle.dump(data, f) + f.close() + + +def create_svl_3d_data(labels, data): + assert (len(labels) == data.shape[1]) + results_part1 = ['' % (i, str(labels[i])) for i in range(len(labels))] + results_part2 = ['%s' % (i, ' '.join([ + str(s) for s in data[i] + ])) for i in range(data.shape[0])] + return '\n'.join(results_part1) + '\n' + '\n'.join(results_part2) + + +def framed_2d_feature_visualizer(entry,features, filename): + f = open(filename, 'w') + for i in range(0, features.shape[0]): + time = entry.prop.hop_length * i / entry.prop.sr + f.write(str(time)) + for j in range(0, features.shape[1]): + f.write('\t' + str(features[i][j])) + f.write('\n') + f.close() \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/implement/__init__.py b/piano_arranger/chord_recognition/mir/io/implement/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/piano_arranger/chord_recognition/mir/io/implement/chroma_io.py b/piano_arranger/chord_recognition/mir/io/implement/chroma_io.py new file mode 100644 index 0000000..7de745e --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/chroma_io.py @@ -0,0 +1,48 @@ +from mir.io.feature_io_base import * +import numpy as np + +class ChromaIO(FeatureIO): + def read(self, filename, entry): + if(filename.endswith('.csv')): + f=open(filename,'r') + lines=f.readlines() + result=[] + for line in lines: + line=line.strip() + if(line==''): + continue + arr=np.array(list(map(float,line.split(',')[2:]))) + arr=arr.reshape((2,12))[::-1].T + arr=np.roll(arr,-3,axis=0).reshape((24)) + result.append(arr) + data=np.array(result) + else: + data=pickle_read(self, filename) + return data + + def write(self, data, filename, entry): + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + sr=entry.prop.sr + win_shift=entry.prop.hop_length + feature_tuple_size=entry.prop.chroma_tuple_size + # if(FEATURETUPLESIZE==2): + features=data + f = open(filename, 'w') + for i in range(0, features.shape[0]): + time = win_shift * i / sr + f.write(str(time)) + for j in range(0,feature_tuple_size): + if(j>0): + 
f.write('\t0') + for k in range(0, 12): + f.write('\t' + str(features[i][k*feature_tuple_size+j])) + f.write('\n') + f.close() + + def pre_assign(self, entry, proxy): + entry.prop.set('n_frame', LoadingPlaceholder(proxy, entry)) + + def post_load(self, data, entry): + entry.prop.set('n_frame', data.shape[0]) \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/implement/midi_io.py b/piano_arranger/chord_recognition/mir/io/implement/midi_io.py new file mode 100644 index 0000000..8cb0c54 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/midi_io.py @@ -0,0 +1,16 @@ +from mir.io.feature_io_base import * +import pretty_midi + +class MidiIO(FeatureIO): + def read(self, filename, entry): + midi_data = pretty_midi.PrettyMIDI(filename) + return midi_data + + def write(self, data, filename, entry): + data.write(filename) + + def visualize(self, data, filename, entry, override_sr): + data.write(filename) + + def get_visualize_extention_name(self): + return "mid" diff --git a/piano_arranger/chord_recognition/mir/io/implement/music_io.py b/piano_arranger/chord_recognition/mir/io/implement/music_io.py new file mode 100644 index 0000000..469888f --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/music_io.py @@ -0,0 +1,18 @@ +from mir.io.feature_io_base import * +import librosa + +class MusicIO(FeatureIO): + def read(self, filename, entry): + y, sr = librosa.load(filename, sr=entry.prop.sr, mono=True) + return y #(y-np.mean(y))/np.std(y) + + def write(self, data, filename, entry): + sr=entry.prop.sr + librosa.output.write_wav(filename, y=data, sr=sr, norm=False) + + def visualize(self, data, filename, entry, override_sr): + sr=entry.prop.sr + librosa.output.write_wav(filename, y=data, sr=sr, norm=True) # otherwise I would be deaf + + def get_visualize_extention_name(self): + return "wav" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/implement/regional_spectrogram_io.py b/piano_arranger/chord_recognition/mir/io/implement/regional_spectrogram_io.py new file mode 100644 index 0000000..97d5d81 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/regional_spectrogram_io.py @@ -0,0 +1,80 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np + +class RegionalSpectrogramIO(FeatureIO): + def read(self, filename, entry): + data=pickle_read(self, filename) + assert(len(data)==3 or len(data)==2) + return data + + def write(self, data, filename, entry): + assert(len(data)==3 or len(data)==2) + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + if(len(data)==2): + timing,data=data + labels=None + elif(len(data)==3): + labels,timing,data=data + else: + raise Exception("Format error") + data=np.array(data) + if(len(data.shape)==1): + data=data.reshape((-1,1)) + sr = entry.prop.sr + win_shift=entry.prop.hop_length + timing=np.array(timing).reshape((len(timing),-1)) + n_frame=max(1,int(np.round(np.max(timing*sr/win_shift)))) + data_indices=(-1)*np.ones(n_frame,dtype=np.int32) + timing_start=timing[:len(data),0] + if(timing.shape[1]==1): + assert(len(timing)==len(data) or len(timing)==len(data)+1) + if(len(timing)==len(data)+1): + timing_end=timing[1:,0] + else: + timing_end=np.append(timing[1:,0],timing[-1,0]*2-timing[-2,0] if(len(timing)>1) else 1.0) + else: + timing_end=timing[:,1] + for i in range(len(data)): + frame_start=max(0,int(np.round(timing_start[i]*sr/win_shift))) + 
frame_end=max(0,int(np.round(timing_end[i]*sr/win_shift))) + data_indices[frame_start:frame_end]=i + if(data.shape[1]>=1): + f = open(os.path.join(PACKAGE_PATH,'data/spectrogram_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__WIN_SHIFT__]', str(win_shift)) + content = content.replace('[__SHAPE_1__]', str(data.shape[1])) + content = content.replace('[__COLOR__]', str(1)) + if(labels is None): + labels = [str(i) for i in range(data.shape[1])] + assert(len(labels)==len(data[0])) + result='\n'.join(['' % (i, str(labels[i])) for i in range(len(labels))])+'\n' + for i in range(n_frame): + if(data_indices[i]>=0): + result +='%s\n' % (i, ' '.join([ + str(s) for s in data[data_indices[i]] + ])) + content = content.replace('[__DATA__]',result) + else: + f = open(os.path.join(PACKAGE_PATH,'data/curve_template.svl'), 'r') + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__STYLE__]', str(1)) + results=[] + raise NotImplementedError() + # for i in range(0, len(data)): + # results.append(''%(int(override_sr/sr*i*win_shift),data[i,0])) + # content = content.replace('[__DATA__]','\n'.join(results)) + # content = content.replace('[__NAME__]', 'curve') + + f=open(filename,'w') + f.write(content) + f.close() + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/implement/scalar_io.py b/piano_arranger/chord_recognition/mir/io/implement/scalar_io.py new file mode 100644 index 0000000..0ce53b7 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/scalar_io.py @@ -0,0 +1,32 @@ +from mir.io.feature_io_base import * + + +class FloatIO(FeatureIO): + def read(self, filename, entry): + f=open(filename,'r') + result=float(f.readline().strip()) + f.close() + return result + + def write(self, data, filename, entry): + f=open(filename,'w') + f.write(str(float(data))) + f.close() + + def visualize(self, data, filename, entry, override_sr): + raise Exception('Cannot visualize a scalar') + +class IntegerIO(FeatureIO): + def read(self, filename, entry): + f=open(filename,'r') + result=int(f.readline().strip()) + f.close() + return result + + def write(self, data, filename, entry): + f=open(filename,'w') + f.write(str(int(data))) + f.close() + + def visualize(self, data, filename, entry, override_sr): + raise Exception('Cannot visualize a scalar') diff --git a/piano_arranger/chord_recognition/mir/io/implement/spectrogram_io.py b/piano_arranger/chord_recognition/mir/io/implement/spectrogram_io.py new file mode 100644 index 0000000..f2cad27 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/spectrogram_io.py @@ -0,0 +1,58 @@ +from mir.io.feature_io_base import * +from mir.common import PACKAGE_PATH +import numpy as np + +class SpectrogramIO(FeatureIO): + def read(self, filename, entry): + return pickle_read(self, filename) + + def write(self, data, filename, entry): + pickle_write(self, data, filename) + + def visualize(self, data, filename, entry, override_sr): + if(type(data) is tuple): + labels=data[0] + data=data[1] + else: + labels=None + if(len(data.shape)==1): + data=data.reshape((-1,1)) + if(data.shape[1]>1): + f = open(os.path.join(PACKAGE_PATH,'data/spectrogram_template.svl'), 'r') + sr=entry.prop.sr + win_shift=entry.prop.hop_length + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = 
content.replace('[__WIN_SHIFT__]', str(win_shift)) + content = content.replace('[__SHAPE_1__]', str(data.shape[1])) + content = content.replace('[__COLOR__]', str(1)) + if(labels is None): + labels = [str(i) for i in range(data.shape[1])] + content = content.replace('[__DATA__]',create_svl_3d_data(labels,data)) + else: + f = open(os.path.join(PACKAGE_PATH,'data/curve_template.svl'), 'r') + sr = entry.prop.sr + win_shift=entry.prop.hop_length + content = f.read() + f.close() + content = content.replace('[__SR__]', str(sr)) + content = content.replace('[__STYLE__]', str(1)) + results=[] + for i in range(0, len(data)): + results.append(''%(int(override_sr/sr*i*win_shift),data[i,0])) + content = content.replace('[__DATA__]','\n'.join(results)) + content = content.replace('[__NAME__]', 'curve') + + f=open(filename,'w') + f.write(content) + f.close() + + def pre_assign(self, entry, proxy): + entry.prop.set('n_frame', LoadingPlaceholder(proxy, entry)) + + def post_load(self, data, entry): + entry.prop.set('n_frame', data.shape[0]) + + def get_visualize_extention_name(self): + return "svl" \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/io/implement/unknown_io.py b/piano_arranger/chord_recognition/mir/io/implement/unknown_io.py new file mode 100644 index 0000000..f8bbb24 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/io/implement/unknown_io.py @@ -0,0 +1,12 @@ +from mir.io.feature_io_base import * + +class UnknownIO(FeatureIO): + + def read(self, filename, entry): + raise Exception('Unknown type cannot be read') + + def write(self, data, filename, entry): + raise Exception('Unknown type cannot be written') + + def visualize(self, data, filename, entry, override_sr): + raise Exception('Unknown type cannot be visualized') \ No newline at end of file diff --git a/piano_arranger/chord_recognition/mir/music_base.py b/piano_arranger/chord_recognition/mir/music_base.py new file mode 100644 index 0000000..8538d3f --- /dev/null +++ b/piano_arranger/chord_recognition/mir/music_base.py @@ -0,0 +1,21 @@ +NUM_TO_ABS_SCALE=['C','C#','D','Eb','E','F','F#','G','Ab','A','Bb','B'] + +def get_scale_and_suffix(name): + result="C*D*EF*G*A*B".index(name[0]) + prefix_length=1 + if (len(name) > 1): + if (name[1] == 'b'): + result = result - 1 + if (result<0): + result+=12 + prefix_length=2 + if (name[1] == '#'): + result = result + 1 + if (result>=12): + result-=12 + prefix_length=2 + return result,name[prefix_length:] + +def scale_name_to_value(name): + result="1*2*34*5*6*78*9".index(name[-1]) # 8 and 9 are for weird tagging in some mirex chords + return (result-name.count('b')+name.count('#')+12)%12 diff --git a/piano_arranger/chord_recognition/mir/requirements.txt b/piano_arranger/chord_recognition/mir/requirements.txt new file mode 100644 index 0000000..4d56fe5 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/requirements.txt @@ -0,0 +1,8 @@ +librosa>=0.6.1 +joblib>=0.12.2 +pydub>=0.22.1 +numpy>=1.15.4 +h5py>=2.8.0 +torch>=0.4.1 +pretty_midi>=0.2.8 + diff --git a/piano_arranger/chord_recognition/mir/settings.py b/piano_arranger/chord_recognition/mir/settings.py new file mode 100644 index 0000000..394afe2 --- /dev/null +++ b/piano_arranger/chord_recognition/mir/settings.py @@ -0,0 +1,3 @@ +SONIC_VISUALIZER_PATH="C:/Program Files (x86)/Sonic Visualiser/Sonic Visualiser.exe" +SONIC_ANNOTATOR_PATH="C:/Program Files (x86)/Sonic Visualiser/annotator/sonic-annotator.exe" +DEFAULT_DATA_STORAGE_PATH="E:/dataset/" \ No newline at end of file diff --git 
a/piano_arranger/chord_recognition/requirements.txt b/piano_arranger/chord_recognition/requirements.txt new file mode 100644 index 0000000..39eb9cb --- /dev/null +++ b/piano_arranger/chord_recognition/requirements.txt @@ -0,0 +1,6 @@ +pydub>=0.23.1 +pretty_midi>=0.2.9 +joblib>=0.13.2 +librosa>=0.7.2 +mir_eval>=0.5 +numpy>=1.16 \ No newline at end of file diff --git a/piano_arranger/format_converter.py b/piano_arranger/format_converter.py new file mode 100644 index 0000000..a402ae6 --- /dev/null +++ b/piano_arranger/format_converter.py @@ -0,0 +1,246 @@ +import pretty_midi as pyd +import numpy as np +import sys +sys.path.append('piano_arranger/chord_recognition') +from main import transcribe_cb1000_midi +from scipy.interpolate import interp1d +import mir_eval + + +def expand_chord(chord, shift=0, relative=False): + """ + expand 14-D chord feature to 36-D + For detail, see Z. Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + """ + # chord = np.copy(chord) + root = (chord[0] + shift) % 12 + chroma = np.roll(chord[1: 13], shift) + bass = (chord[13] + shift) % 12 + root_onehot = np.zeros(12) + root_onehot[int(root)] = 1 + bass_onehot = np.zeros(12) + bass_onehot[int(bass)] = 1 + return np.concatenate([root_onehot, chroma, bass_onehot]) + + +def midi2matrix(track, quaver): + """ + quantize a PrettyMIDI track based on specified quavers. + The quantized result is a (T, 128) format defined in defined in Z. Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + """ + #program = track.program + pr_matrix = np.zeros((len(quaver), 128)) + for note in track.notes: + note_start = np.argmin(np.abs(quaver - note.start)) + note_end = np.argmin(np.abs(quaver - note.end)) + if note_end == note_start: + note_end = min(note_start + 1, len(quaver) - 1) + pr_matrix[note_start, note.pitch] = note_end - note_start + return pr_matrix + +def ec2vae_mel_format(pr_matrix): + """ + convert (T, 128) melody format to (T, 130) format. + (T, 128) format defined in Z. Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + (T, 130) format defined in R. Yang et al., "Deep music analogy via latent representation disentanglement," ISMIR 2019. + """ + hold_pitch = 128 + rest_pitch = 129 + melody_roll = np.zeros((len(pr_matrix), 130)) + for t, p in zip(*np.nonzero(pr_matrix)): + dur = int(pr_matrix[t, p]) + melody_roll[t, p] = 1 + melody_roll[t+1:t+dur, hold_pitch] = 1 + melody_roll[np.nonzero(1 - np.sum(melody_roll, axis=1))[0], rest_pitch] = 1 + return melody_roll + + +def leadsheet2matrix(path, melody_track_ID=0): + """ + Tokenize and quantize a lead sheet (a melody track with a chord track). + The input can also be an arbiturary MIDI file with multiple accompaniment tracks. + The first track is by default taken as melody. 
Otherwise, specify melody_track_ID (counting from zero) + """ + ACC = 4 #quantize at 1/16 beat + midi = pyd.PrettyMIDI(path) + beats = midi.get_beats() + beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2])) + quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear') + quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC))) + melody_roll = ec2vae_mel_format(midi2matrix(midi.instruments[melody_track_ID], quaver)) + + chord_detection = transcribe_cb1000_midi(path) + chord_roll = np.zeros((len(melody_roll), 14)) + for chord in chord_detection: + chord_start, chord_end, chord_symbol = chord + chord_start = np.argmin(np.abs(quaver - chord_start)) + chord_end = np.argmin(np.abs(quaver - chord_end)) + chord_root, bit_map, bass = mir_eval.chord.encode(chord_symbol) + chord = np.concatenate([np.array([chord_root]), np.roll(bit_map, shift=int(chord_root)), np.array([bass])]) + chord_roll[chord_start: chord_end] = chord + chord_roll[np.sum(chord_roll, axis=1)==0, 0]=-1 + chord_roll[np.sum(chord_roll, axis=1)==0, -1]=-1 + + return melody_roll, chord_roll + + +def melody_matrix2data(melody_matrix, tempo=120, start_time=0.0): + """reconstruct melody from matrix to MIDI""" + ROLL_SIZE =130 + HOLD_PITCH = 128 + REST_PITCH = 129 + melodyMatrix = melody_matrix[:, :ROLL_SIZE] + melodySequence = [np.argmax(melodyMatrix[i]) for i in range(melodyMatrix.shape[0])] + + melody_notes = [] + minStep = 60 / tempo / 4 + onset_or_rest = [i for i in range(len(melodySequence)) if not melodySequence[i]==HOLD_PITCH] + onset_or_rest.append(len(melodySequence)) + for idx, onset in enumerate(onset_or_rest[:-1]): + if melodySequence[onset] == REST_PITCH: + continue + else: + pitch = melodySequence[onset] + start = onset * minStep + end = onset_or_rest[idx+1] * minStep + noteRecon = pyd.Note(velocity=100, pitch=pitch, start=start_time+start, end=start_time+end) + melody_notes.append(noteRecon) + melody = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano')) + melody.notes = melody_notes + return melody + + +def chord_matrix2data(chordMatrix, tempo=120, start_time=0.0, get_list=False): + """reconstruct chord from matrix to MIDI""" + chordSequence = [] + for i in range(chordMatrix.shape[0]): + chordSequence.append(''.join([str(int(j)) for j in chordMatrix[i]])) + minStep = 60 / tempo / 4 #16th quantization + chord_notes = [] + onset_or_rest = [0] + onset_or_rest_ = [i for i in range(1, len(chordSequence)) if chordSequence[i] != chordSequence[i-1] ] + onset_or_rest = onset_or_rest + onset_or_rest_ + onset_or_rest.append(len(chordSequence)) + for idx, onset in enumerate(onset_or_rest[:-1]): + chordset = [int(i) for i in chordSequence[onset]] + start = onset * minStep + end = onset_or_rest[idx+1] * minStep + for note, value in enumerate(chordset): + if value == 1: + noteRecon = pyd.Note(velocity=100, pitch=note+4*12, start=start_time+start, end=start_time+end) + chord_notes.append(noteRecon) + chord = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano')) + chord.notes = chord_notes + return chord + + +def matrix2leadsheet(leadsheet, tempo=120, start_time=0.0): + """reconstruct lead sheet from matrix to MIDI""" + #leadsheet: (T, 142) + midi = pyd.PrettyMIDI(initial_tempo=tempo) + midi.instruments.append(melody_matrix2data(leadsheet[:, :130], tempo, start_time)) + midi.instruments.append(chord_matrix2data(leadsheet[:, 130:], tempo, start_time)) + return midi + + +def accompany_data2matrix(accompany_track, downbeats): + """ + quantize a PrettyMIDI 
track into a (T, 128) format as defined in Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + This function has the same purpose as midi2matrix(). + """ + time_stamp_sixteenth_reso = [] + delta_set = [] + downbeats = list(downbeats) + downbeats.append(downbeats[-1] + (downbeats[-1] - downbeats[-2])) + for i in range(len(downbeats)-1): + s_curr = round(downbeats[i] * 16) / 16 + s_next = round(downbeats[i+1] * 16) / 16 + delta = (s_next - s_curr) / 16 + for i in range(16): + time_stamp_sixteenth_reso.append(s_curr + delta * i) + delta_set.append(delta) + time_stamp_sixteenth_reso = np.array(time_stamp_sixteenth_reso) + + pr_matrix = np.zeros((time_stamp_sixteenth_reso.shape[0], 128)) + for note in accompany_track.notes: + onset = note.start + t = np.argmin(np.abs(time_stamp_sixteenth_reso - onset)) + p = note.pitch + duration = int(round((note.end - onset) / delta_set[t])) + pr_matrix[t, p] = duration + return pr_matrix + +def accompany_matrix2data(pr_matrix, tempo=120, start_time=0.0, get_list=False): + """reconstruct a (T, 128) polyphony from magtrix to MIDI.""" + alpha = 0.25 * 60 / tempo + notes = [] + for t in range(pr_matrix.shape[0]): + for p in range(128): + if pr_matrix[t, p] >= 1: + s = alpha * t + start_time + e = alpha * (t + pr_matrix[t, p]) + start_time + notes.append(pyd.Note(100, int(p), s, e)) + if get_list: + return notes + else: + acc = pyd.Instrument(program=pyd.instrument_name_to_program('Acoustic Grand Piano')) + acc.notes = notes + return acc + + +def grid2pr(grid, max_note_count=16, min_pitch=0, pitch_eos_ind=129): + """ + convert a (T, max_simu_note, 6) format grid into (T, 128 polyphony). + The (T, max_simu_note, 6) format is defined in Wang et al., "PIANOTREE VAE: Structured Representation Learning for Polyphonic Music," ISMIR 2020. + The (T, 128 polyphony) format is defined in Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. + """ + #grid: (time, max_simu_note, 6) + if grid.shape[1] == max_note_count: + grid = grid[:, 1:] + pr = np.zeros((grid.shape[0], 128), dtype=int) + for t in range(grid.shape[0]): + for n in range(grid.shape[1]): + note = grid[t, n] + if note[0] == pitch_eos_ind: + break + pitch = note[0] + min_pitch + dur = int(''.join([str(_) for _ in note[1:]]), 2) + 1 + pr[t, pitch] = dur + return pr + + +def matrix2midi_with_dynamics(pr_matrices, programs, init_tempo=120, time_start=0, ACC=16): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track. Time, 128, 3). + The last dimension each encoders MIDI pitch, velocity, and control message. 
+ """ + tracks = [] + for program in programs: + track_recon = pyd.Instrument(program=int(program), is_drum=False, name=pyd.program_to_instrument_name(int(program))) + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices[:, :, :, 0]) + alpha = 1 / (ACC // 4) * 60 / init_tempo #timetep between each quntization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch, 0] * alpha + velocity = pr_matrices[track_id, onset, pitch, 1] + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + + for idx in range(len(pr_matrices)): + cc = [] + control_matrix = pr_matrices[idx, :, :, 2] + for t, n in zip(*np.nonzero(control_matrix >= 0)): + start = alpha * t + cc.append(pyd.ControlChange(int(n), int(control_matrix[t, n]), start)) + tracks[idx].control_changes = cc + + midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + midi_recon.instruments = tracks + return midi_recon diff --git a/piano_arranger/models/EC2VAE.py b/piano_arranger/models/EC2VAE.py new file mode 100644 index 0000000..5548045 --- /dev/null +++ b/piano_arranger/models/EC2VAE.py @@ -0,0 +1,138 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.distributions import Normal + +""" + Credit to R. Yang et al., "Deep music analogy via latent representation disentanglement," ISMIR 2019 + https://github.com/buggyyang/Deep-Music-Analogy-Demos +""" + +class VAE(nn.Module): + def __init__(self, + roll_dims, + hidden_dims, + rhythm_dims, + condition_dims, + z1_dims, + z2_dims, + n_step, + k=1000): + super(VAE, self).__init__() + self.gru_0 = nn.GRU( + roll_dims + condition_dims, + hidden_dims, + batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(hidden_dims * 2, z1_dims + z2_dims) + self.linear_var = nn.Linear(hidden_dims * 2, z1_dims + z2_dims) + self.grucell_0 = nn.GRUCell(z2_dims + rhythm_dims, + hidden_dims) + self.grucell_1 = nn.GRUCell( + z1_dims + roll_dims + rhythm_dims + condition_dims, hidden_dims) + self.grucell_2 = nn.GRUCell(hidden_dims, hidden_dims) + self.linear_init_0 = nn.Linear(z2_dims, hidden_dims) + self.linear_out_0 = nn.Linear(hidden_dims, rhythm_dims) + self.linear_init_1 = nn.Linear(z1_dims, hidden_dims) + self.linear_out_1 = nn.Linear(hidden_dims, roll_dims) + self.n_step = n_step + self.roll_dims = roll_dims + self.hidden_dims = hidden_dims + self.eps = 1 + self.rhythm_dims = rhythm_dims + self.sample = None + self.rhythm_sample = None + self.iteration = 0 + self.z1_dims = z1_dims + self.z2_dims = z2_dims + self.k = torch.FloatTensor([k]) + + def _sampling(self, x): + idx = x.max(1)[1] + x = torch.zeros_like(x) + arange = torch.arange(x.size(0)).long() + if torch.cuda.is_available(): + arange = arange.cuda() + x[arange, idx] = 1 + return x #a batched one-hot vector + + def encoder(self, x, condition): + # self.gru_0.flatten_parameters() + x = torch.cat((x, condition), -1) + x = self.gru_0(x)[-1] #(numLayer*numDirection)* batch* hidden_size + x = x.transpose_(0, 1).contiguous() #batch* (numLayer*numDirection)* hidden_size + x = x.view(x.size(0), -1) #batch* (numLayer*numDirection*hidden_size), where numLayer=1, numDirection=2 + mu = self.linear_mu(x) #batch* (z1_dims + z2_dims) + var = self.linear_var(x).exp_() #batch* (z1_dims + z2_dims) + distribution_1 = 
Normal(mu[:, :self.z1_dims], var[:, :self.z1_dims]) #distribution for pitch + distribution_2 = Normal(mu[:, self.z1_dims:], var[:, self.z1_dims:]) #distribution for rhythm + return distribution_1, distribution_2 + + def rhythm_decoder(self, z): + out = torch.zeros((z.size(0), self.rhythm_dims)) #batch* rhythm_dims + out[:, -1] = 1. + x = [] + t = torch.tanh(self.linear_init_0(z)) #batch* hidden_dims + hx = t + if torch.cuda.is_available(): + out = out.cuda() + for i in range(self.n_step): + out = torch.cat([out, z], 1) #batch* (rhythm_dims+z2_dims) + hx = self.grucell_0(out, hx) #batch* hidden_dims + out = F.log_softmax(self.linear_out_0(hx), 1) #batch* rhythm_dims + x.append(out) + if self.training: + p = torch.rand(1).item() + if p < self.eps: + out = self.rhythm_sample[:, i, :] + else: + out = self._sampling(out) + else: + out = self._sampling(out) + return torch.stack(x, 1) #batch* n_step* rhythm_dims + + def final_decoder(self, z, rhythm, condition): + out = torch.zeros((z.size(0), self.roll_dims)) #batch* roll_dims + out[:, -1] = 1. + x, hx = [], [None, None] + t = torch.tanh(self.linear_init_1(z)) #batch* hidden_dims + hx[0] = t + if torch.cuda.is_available(): + out = out.cuda() + for i in range(self.n_step): + out = torch.cat([out, rhythm[:, i, :], z, condition[:, i, :]], 1) #batch* roll_dims+rhythm_dims+z1_dims+condition_dims + hx[0] = self.grucell_1(out, hx[0]) #batch* hidden_dims + if i == 0: + hx[1] = hx[0] + hx[1] = self.grucell_2(hx[0], hx[1]) #batch* hidden_dims + out = F.log_softmax(self.linear_out_1(hx[1]), 1) #batch* roll_dims + x.append(out) + if self.training: + p = torch.rand(1).item() + if p < self.eps: + out = self.sample[:, i, :] + else: + out = self._sampling(out) + self.eps = self.k / (self.k + torch.exp(self.iteration / self.k)) + else: + out = self._sampling(out) + return torch.stack(x, 1) #batch* n_step* roll_dims + + def decoder(self, z1, z2, condition=None): + rhythm = self.rhythm_decoder(z2) + return self.final_decoder(z1, rhythm, condition) + + def forward(self, x, condition): + if self.training: + self.sample = x + self.rhythm_sample = x[:, :, :-2].sum(-1).unsqueeze(-1) #batch* n_step* 1 + self.rhythm_sample = torch.cat((self.rhythm_sample, x[:, :, -2:]), -1) #batch* n_step* 3 + self.iteration += 1 + dis1, dis2 = self.encoder(x, condition) + z1 = dis1.rsample() + z2 = dis2.rsample() + recon_rhythm = self.rhythm_decoder(z2) + recon = self.final_decoder(z1, recon_rhythm, condition) + output = (recon, recon_rhythm, dis1.mean, dis1.stddev, dis2.mean, + dis2.stddev) + return output diff --git a/piano_arranger/models/Poly_Dis.py b/piano_arranger/models/Poly_Dis.py new file mode 100644 index 0000000..9346d59 --- /dev/null +++ b/piano_arranger/models/Poly_Dis.py @@ -0,0 +1,270 @@ +from .amc_dl.torch_plus import PytorchModel +from .amc_dl.torch_plus.train_utils import get_zs_from_dists, kl_with_normal +import torch +from torch import nn +from torch.distributions import Normal +import numpy as np +from .ptvae import RnnEncoder, RnnDecoder, PtvaeDecoder, TextureEncoder + +""" + Credit to Z. Wang et al., "Learning interpretable representation for controllable polyphonic music generation," ISMIR 2020. 
+ https://github.com/ZZWaang/polyphonic-chord-texture-disentanglement +""" + +class DisentangleVAE(PytorchModel): + + def __init__(self, name, device, chd_encoder, rhy_encoder, decoder, + chd_decoder): + super(DisentangleVAE, self).__init__(name, device) + self.chd_encoder = chd_encoder + self.rhy_encoder = rhy_encoder + self.decoder = decoder + self.num_step = self.decoder.num_step + self.chd_decoder = chd_decoder + + def confuse_prmat(self, pr_mat): + non_zero_ent = torch.nonzero(pr_mat.long()) + eps = torch.randint(0, 2, (non_zero_ent.size(0),)) + eps = ((2 * eps) - 1).long() + confuse_ent = torch.clamp(non_zero_ent[:, 2] + eps, min=0, max=127) + pr_mat[non_zero_ent[:, 0], non_zero_ent[:, 1], confuse_ent] = \ + pr_mat[non_zero_ent[:, 0], non_zero_ent[:, 1], non_zero_ent[:, 2]] + return pr_mat + + def get_chroma(self, pr_mat): + bs = pr_mat.size(0) + pad = torch.zeros(bs, 32, 4).to(self.device) + pr_mat = torch.cat([pr_mat, pad], dim=-1) + c = pr_mat.view(bs, 32, -1, 12).contiguous() + c = c.sum(dim=-2) # (bs, 32, 12) + c = c.view(bs, 8, 4, 12) + c = c.sum(dim=-2).float() + c = torch.log(c + 1) + return c.to(self.device) + + def run(self, x, c, pr_mat, tfr1, tfr2, tfr3, confuse=True): + embedded_x, lengths = self.decoder.emb_x(x) + # cc = self.get_chroma(pr_mat) + dist_chd = self.chd_encoder(c) + # pr_mat = self.confuse_prmat(pr_mat) + dist_rhy = self.rhy_encoder(pr_mat) + z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], True) + dec_z = torch.cat([z_chd, z_rhy], dim=-1) + pitch_outs, dur_outs = self.decoder(dec_z, False, embedded_x, + lengths, tfr1, tfr2) + recon_root, recon_chroma, recon_bass = self.chd_decoder(z_chd, False, + tfr3, c) + return pitch_outs, dur_outs, dist_chd, dist_rhy, recon_root, \ + recon_chroma, recon_bass + + def loss_function(self, x, c, recon_pitch, recon_dur, dist_chd, + dist_rhy, recon_root, recon_chroma, recon_bass, + beta, weights, weighted_dur=False): + recon_loss, pl, dl = self.decoder.recon_loss(x, recon_pitch, recon_dur, + weights, weighted_dur) + kl_loss, kl_chd, kl_rhy = self.kl_loss(dist_chd, dist_rhy) + chord_loss, root, chroma, bass = self.chord_loss(c, recon_root, + recon_chroma, + recon_bass) + loss = recon_loss + beta * kl_loss + chord_loss + return loss, recon_loss, pl, dl, kl_loss, kl_chd, kl_rhy, chord_loss, \ + root, chroma, bass + + def chord_loss(self, c, recon_root, recon_chroma, recon_bass): + loss_fun = nn.CrossEntropyLoss() + root = c[:, :, 0: 12].max(-1)[-1].view(-1).contiguous() + chroma = c[:, :, 12: 24].long().view(-1).contiguous() + bass = c[:, :, 24:].max(-1)[-1].view(-1).contiguous() + + recon_root = recon_root.view(-1, 12).contiguous() + recon_chroma = recon_chroma.view(-1, 2).contiguous() + recon_bass = recon_bass.view(-1, 12).contiguous() + root_loss = loss_fun(recon_root, root) + chroma_loss = loss_fun(recon_chroma, chroma) + bass_loss = loss_fun(recon_bass, bass) + chord_loss = root_loss + chroma_loss + bass_loss + return chord_loss, root_loss, chroma_loss, bass_loss + + def kl_loss(self, *dists): + # kl = kl_with_normal(dists[0]) + kl_chd = kl_with_normal(dists[0]) + kl_rhy = kl_with_normal(dists[1]) + kl_loss = kl_chd + kl_rhy + return kl_loss, kl_chd, kl_rhy + + def loss(self, x, c, pr_mat, dt_x, tfr1=0., tfr2=0., tfr3=0., beta=0.1, weights=(1, 0.5)): + #print(pr_mat.shape, dt_x.shape) + outputs = self.run(x, c, pr_mat, tfr1, tfr2, tfr3) + loss = self.loss_function(x, c, *outputs, beta, weights) + return loss + + # def inference(self, c, pr_mat): + # self.eval() + # with torch.no_grad(): + # dist_chd = 
self.chd_encoder(c) + # # pr_mat = self.confuse_prmat(pr_mat) + # dist_rhy = self.rhy_encoder(pr_mat) + # z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], True) + # dec_z = torch.cat([z_chd, z_rhy], dim=-1) + # pitch_outs, dur_outs = self.decoder(dec_z, True, None, + # None, 0., 0.) + # est_x, _, _ = self.decoder.output_to_numpy(pitch_outs, dur_outs) + # return est_x + # + # def swap(self, c1, c2, pr_mat1, pr_mat2, fix_rhy, fix_chd): + # pr_mat = pr_mat1 if fix_rhy else pr_mat2 + # c = c1 if fix_chd else c2 + # est_x = self.inference(c, pr_mat) + # return est_x + + def inference_encode(self, pr_mat, c): + self.eval() + with torch.no_grad(): + dist_chd = self.chd_encoder(c) + dist_rhy = self.rhy_encoder(pr_mat) + return dist_chd, dist_rhy + + def inference_decode(self, z_chd, z_rhy): + self.eval() + with torch.no_grad(): + dec_z = torch.cat([z_chd, z_rhy], dim=-1) + pitch_outs, dur_outs = self.decoder(dec_z, True, None, + None, 0., 0.) + est_x, _, _ = self.decoder.output_to_numpy(pitch_outs, dur_outs) + return est_x + + def inference(self, pr_mat, c, sample): + self.eval() + with torch.no_grad(): + dist_chd = self.chd_encoder(c) + dist_rhy = self.rhy_encoder(pr_mat) + z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], sample) + dec_z = torch.cat([z_chd, z_rhy], dim=-1) + pitch_outs, dur_outs = self.decoder(dec_z, True, None, + None, 0., 0.) + est_x, _, _ = self.decoder.output_to_numpy(pitch_outs, dur_outs) + return est_x + + def swap(self, pr_mat1, pr_mat2, c1, c2, fix_rhy, fix_chd): + pr_mat = pr_mat1 if fix_rhy else pr_mat2 + c = c1 if fix_chd else c2 + est_x = self.inference(pr_mat, c, sample=False) + return est_x + + def posterior_sample(self, pr_mat, c, scale=None, sample_chd=True, + sample_txt=True): + if scale is None and sample_chd and sample_txt: + est_x = self.inference(pr_mat, c, sample=True) + else: + dist_chd, dist_rhy = self.inference_encode(pr_mat, c) + if scale is not None: + mean_chd = dist_chd.mean + mean_rhy = dist_rhy.mean + # std_chd = torch.ones_like(dist_chd.mean) * scale + # std_rhy = torch.ones_like(dist_rhy.mean) * scale + std_chd = dist_chd.scale * scale + std_rhy = dist_rhy.scale * scale + dist_rhy = Normal(mean_rhy, std_rhy) + dist_chd = Normal(mean_chd, std_chd) + z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], True) + if not sample_chd: + z_chd = dist_chd.mean + if not sample_txt: + z_rhy = dist_rhy.mean + est_x = self.inference_decode(z_chd, z_rhy) + return est_x + + def prior_sample(self, x, c, sample_chd=False, sample_rhy=False, + scale=1.): + dist_chd, dist_rhy = self.inference_encode(x, c) + mean = torch.zeros_like(dist_rhy.mean) + loc = torch.ones_like(dist_rhy.mean) * scale + if sample_chd: + dist_chd = Normal(mean, loc) + if sample_rhy: + dist_rhy = Normal(mean, loc) + z_chd, z_rhy = get_zs_from_dists([dist_chd, dist_rhy], True) + return self.inference_decode(z_chd, z_rhy) + + def gt_sample(self, x): + out = x[:, :, 1:].numpy() + return out + + def interp(self, pr_mat1, c1, pr_mat2, c2, interp_chd=False, + interp_rhy=False, int_count=10): + dist_chd1, dist_rhy1 = self.inference_encode(pr_mat1, c1) + dist_chd2, dist_rhy2 = self.inference_encode(pr_mat2, c2) + [z_chd1, z_rhy1, z_chd2, z_rhy2] = \ + get_zs_from_dists([dist_chd1, dist_rhy1, dist_chd2, dist_rhy2], + False) + if interp_chd: + z_chds = self.interp_z(z_chd1, z_chd2, int_count) + else: + z_chds = z_chd1.unsqueeze(1).repeat(1, int_count, 1) + if interp_rhy: + z_rhys = self.interp_z(z_rhy1, z_rhy2, int_count) + else: + z_rhys = z_rhy1.unsqueeze(1).repeat(1, int_count, 1) + bs = 
z_chds.size(0) + z_chds = z_chds.view(bs * int_count, -1).contiguous() + z_rhys = z_rhys.view(bs * int_count, -1).contiguous() + estxs = self.inference_decode(z_chds, z_rhys) + return estxs.reshape((bs, int_count, 32, 15, -1)) + + def interp_z(self, z1, z2, int_count=10): + z1 = z1.numpy() + z2 = z2.numpy() + zs = torch.stack([self.interp_path(zz1, zz2, int_count) + for zz1, zz2 in zip(z1, z2)], dim=0) + return zs + + def interp_path(self, z1, z2, interpolation_count=10): + result_shape = z1.shape + z1 = z1.reshape(-1) + z2 = z2.reshape(-1) + + def slerp2(p0, p1, t): + omega = np.arccos( + np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))) + so = np.sin(omega) + return np.sin((1.0 - t) * omega)[:, None] / so * p0[ + None] + np.sin( + t * omega)[:, None] / so * p1[None] + + percentages = np.linspace(0.0, 1.0, interpolation_count) + + normalized_z1 = z1 / np.linalg.norm(z1) + normalized_z2 = z2 / np.linalg.norm(z2) + dirs = slerp2(normalized_z1, normalized_z2, percentages) + length = np.linspace(np.log(np.linalg.norm(z1)), + np.log(np.linalg.norm(z2)), + interpolation_count) + out = (dirs * np.exp(length[:, None])).reshape( + [interpolation_count] + list(result_shape)) + # out = np.array([(1 - t) * z1 + t * z2 for t in percentages]) + return torch.from_numpy(out).to(self.device).float() + + @staticmethod + def init_model(device=None, chd_size=256, txt_size=256, num_channel=10): + name = 'disvae' + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() + else 'cpu') + # chd_encoder = RnnEncoder(36, 1024, 256) + chd_encoder = RnnEncoder(36, 1024, chd_size) + # rhy_encoder = TextureEncoder(256, 1024, 256) + rhy_encoder = TextureEncoder(256, 1024, txt_size, num_channel) + # pt_encoder = PtvaeEncoder(device=device, z_size=152) + # chd_decoder = RnnDecoder(z_dim=256) + chd_decoder = RnnDecoder(z_dim=chd_size) + # pt_decoder = PtvaeDecoder(note_embedding=None, + # dec_dur_hid_size=64, z_size=512) + pt_decoder = PtvaeDecoder(note_embedding=None, + dec_dur_hid_size=64, + z_size=chd_size + txt_size) + + model = DisentangleVAE(name, device, chd_encoder, + rhy_encoder, pt_decoder, chd_decoder) + return model + + diff --git a/piano_arranger/models/__init__.py b/piano_arranger/models/__init__.py new file mode 100644 index 0000000..2c3bdc2 --- /dev/null +++ b/piano_arranger/models/__init__.py @@ -0,0 +1,5 @@ +from .Poly_Dis import DisentangleVAE +from .ptvae import PtvaeDecoder +from .EC2VAE import VAE +from .ptvae import TextureEncoder +from .transition_model import contrastive_model \ No newline at end of file diff --git a/piano_arranger/models/amc_dl/__init__.py b/piano_arranger/models/amc_dl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/piano_arranger/models/amc_dl/demo_maker.py b/piano_arranger/models/amc_dl/demo_maker.py new file mode 100644 index 0000000..053f02e --- /dev/null +++ b/piano_arranger/models/amc_dl/demo_maker.py @@ -0,0 +1,38 @@ +import pretty_midi + + +def demo_format_convert(data, f, *inputs): + return [[f(x, *inputs) for x in track] for track in data] + + +def bpm_to_alpha(bpm): + return 60 / bpm + + +def add_notes(track, shift_second, alpha): + notes = [] + ss = 0 + for i, seg in enumerate(track): + notes += [pretty_midi.Note(n.velocity, n.pitch, + n.start + ss, n.end + ss) + for n in seg] + ss += shift_second + return notes + + +def demo_to_midi(data, names, bpm=90., shift_second=None, shift_beat=None): + alpha = bpm_to_alpha(bpm) + if shift_second is None: + shift_second = alpha * shift_beat + midi = 
pretty_midi.PrettyMIDI(initial_tempo=bpm) + for track, name in zip(data, names): + ins = pretty_midi.Instrument(0, name=name) + ins.notes = add_notes(track, shift_second, alpha) + midi.instruments.append(ins) + return midi + + +def write_demo(fn, data, names, bpm=90., shift_second=None, shift_beat=None): + midi = demo_to_midi(data, names, bpm, shift_second, shift_beat) + midi.write(fn) + diff --git a/piano_arranger/models/amc_dl/torch_plus/__init__.py b/piano_arranger/models/amc_dl/torch_plus/__init__.py new file mode 100644 index 0000000..bbc980b --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/__init__.py @@ -0,0 +1,7 @@ +from .module import PytorchModel, TrainingInterface +from .scheduler import ConstantScheduler, TeacherForcingScheduler, \ + OptimizerScheduler, ParameterScheduler +from .manager import LogPathManager, DataLoaders, SummaryWriters +from .example import MinExponentialLR + + diff --git a/piano_arranger/models/amc_dl/torch_plus/example.py b/piano_arranger/models/amc_dl/torch_plus/example.py new file mode 100644 index 0000000..e459825 --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/example.py @@ -0,0 +1,13 @@ +from torch.optim.lr_scheduler import ExponentialLR + + +class MinExponentialLR(ExponentialLR): + def __init__(self, optimizer, gamma, minimum, last_epoch=-1): + self.min = minimum + super(MinExponentialLR, self).__init__(optimizer, gamma, last_epoch=-1) + + def get_lr(self): + return [ + max(base_lr * self.gamma ** self.last_epoch, self.min) + for base_lr in self.base_lrs + ] \ No newline at end of file diff --git a/piano_arranger/models/amc_dl/torch_plus/manager.py b/piano_arranger/models/amc_dl/torch_plus/manager.py new file mode 100644 index 0000000..50ca4fe --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/manager.py @@ -0,0 +1,137 @@ +import datetime +import os +import shutil +#from tensorboardX import SummaryWriter +from .train_utils import join_fn +import torch +from torch.utils.tensorboard import SummaryWriter + + + +#todo copy every import file as readme + + +class LogPathManager: + + def __init__(self, readme_fn=None, log_path_name='result', + with_date=True, with_time=True, + writer_folder='writers', model_folder='models'): + date = str(datetime.date.today()) if with_date else '' + ctime = datetime.datetime.now().time().strftime("%H%M%S") \ + if with_time else '' + log_folder = '_'.join([log_path_name, date, ctime]) + log_path = os.path.join('.', log_folder) + writer_path = os.path.join(log_path, writer_folder) + model_path = os.path.join(log_path, model_folder) + self.log_path = log_path + self.writer_path = writer_path + self.model_path = model_path + LogPathManager.create_path(log_path) + LogPathManager.create_path(writer_path) + LogPathManager.create_path(model_path) + if readme_fn is not None: + shutil.copyfile(readme_fn, os.path.join(log_path, 'readme.txt')) + + @staticmethod + def create_path(path): + if not os.path.exists(path): + os.mkdir(path) + + def epoch_model_path(self, model_name): + model_fn = join_fn(model_name, 'epoch', ext='pt') + return os.path.join(self.model_path, model_fn) + + def valid_model_path(self, model_name): + model_fn = join_fn(model_name, 'valid', ext='pt') + return os.path.join(self.model_path, model_fn) + + def final_model_path(self, model_name): + model_fn = join_fn(model_name, 'final', ext='pt') + return os.path.join(self.model_path, model_fn) + + +class DataLoaders: + + def __init__(self, train_loader, val_loader, bs_train, bs_val, device=None): + self.train_loader = train_loader + 
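# Illustrative sketch of the torch_plus helpers defined above (hypothetical model and
# file names, not taken verbatim from this repository).  MinExponentialLR decays the
# learning rate exponentially but clamps it at `minimum`; note that this copy of
# __init__ forwards last_epoch=-1 to the parent class rather than the last_epoch
# argument it receives (the copy in transition_model_train_contrastive.py forwards it).
# LogPathManager creates `result_<date>_<time>/{writers,models}` and copies `readme_fn`
# into it as readme.txt, so a typical entry point might look like:
#
#   import torch
#   opt = torch.optim.Adam(my_model.parameters(), lr=1e-3)      # my_model: hypothetical
#   sched = MinExponentialLR(opt, gamma=0.99991, minimum=1e-5)
#   paths = LogPathManager(readme_fn='train_config.py')         # hypothetical readme file
#   for step in range(10000):
#       opt.step()
#       sched.step()        # lr_t = max(1e-3 * 0.99991 ** t, 1e-5)
#   torch.save(my_model.state_dict(), paths.final_model_path('my_model'))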
self.val_loader = val_loader + self.num_train_batch = len(train_loader) + self.num_val_batch = len(val_loader) + self.bs_train = bs_train + self.bs_val = bs_val + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() + else 'cpu') + self.device = device + + @staticmethod + def get_loaders(seed, bs_train, bs_val, + portion=8, shift_low=-6, shift_high=5, num_bar=2, + contain_chord=True): + raise NotImplementedError + + def batch_to_inputs(self, *input): + raise NotImplementedError + + @staticmethod + def _get_ith_batch(i, loader): + assert 0 <= 0 < len(loader) + for ind, batch in enumerate(loader): + if i == ind: + break + return batch + + def get_ith_train_batch(self, i): + return DataLoaders._get_ith_batch(i, self.train_loader) + + def get_ith_val_batch(self, i): + return DataLoaders._get_ith_batch(i, self.val_loader) + + +class SummaryWriters: + + def __init__(self, writer_names, tags, log_path, tasks=('train', 'val')): + # writer_names example: ['loss', 'kl_loss', 'recon_loss'] + # tags example: {'name1': None, 'name2': (0, 1)} + self.log_path = log_path + assert 'loss' == writer_names[0] + self.writer_names = writer_names + self.tags = tags + self._regularize_tags() + + writer_dic = {} + for name in writer_names: + writer_dic[name] = SummaryWriter(os.path.join(log_path, name)) + self.writers = writer_dic + + all_tags = {} + for task in tasks: + task_dic = {} + for key, val in self.tags.items(): + task_dic['_'.join([task, key])] = val + all_tags[task] = task_dic + self.all_tags = all_tags + + def _init_summary_writer(self): + tags = {'batch_train': (0, 1, 2, 3, 4)} + self.summary_writers = SummaryWriters(self.writer_names, tags, + self.writer_path) + + def _regularize_tags(self): + for key, val in self.tags.items(): + if val is None: + self.tags[key] = tuple(range(len(self.writer_names))) + + def single_write(self, name, tag, val, step): + self.writers[name].add_scalar(tag, val, step) + + def write_tag(self, task, tag, vals, step): + assert len(vals) == len(self.all_tags[task][tag]) + for name_id, val in zip(self.all_tags[task][tag], vals): + name = self.writer_names[name_id] + self.single_write(name, tag, val, step) + + def write_task(self, task, vals_dic, step): + for tag, name_ids in self.all_tags[task].items(): + vals = [vals_dic[self.writer_names[i]] for i in name_ids] + self.write_tag(task, tag, vals, step) \ No newline at end of file diff --git a/piano_arranger/models/amc_dl/torch_plus/module.py b/piano_arranger/models/amc_dl/torch_plus/module.py new file mode 100644 index 0000000..a74a199 --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/module.py @@ -0,0 +1,220 @@ +import time +import os +import torch +from torch import nn +from .train_utils import epoch_time + + +class PytorchModel(nn.Module): + + def __init__(self, name, device): + self.name = name + super(PytorchModel, self).__init__() + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() + else 'cpu') + self.device = device + + def run(self, *input): + """A general way to run the model. + Usually tensor input -> tensor output""" + raise NotImplementedError + + def loss(self, *input, **kwargs): + """Call it during training. The output is loss and possibly others to + display on tensorboard.""" + raise NotImplementedError + + def inference(self, *input): + """Call it during inference. 
+ The output is usually numpy after argmax.""" + raise NotImplementedError + + def loss_function(self, *input): + raise NotImplementedError + + def forward(self, mode, *input, **kwargs): + if mode in ["run", 0]: + return self.run(*input, **kwargs) + elif mode in ['loss', 'train', 1]: + return self.loss(*input, **kwargs) + elif mode in ['inference', 'eval', 'val', 2]: + return self.inference(*input, **kwargs) + else: + raise NotImplementedError + + def load_model(self, model_path, map_location=None): + if map_location is None: + map_location = self.device + dic = torch.load(model_path, map_location=map_location) + for name in list(dic.keys()): + dic[name.replace('module.', '')] = dic.pop(name) + self.load_state_dict(dic) + self.to(self.device) + + @staticmethod + def init_model(*inputs): + raise NotImplementedError + + +class TrainingInterface: + + def __init__(self, device, model, parallel, log_path_mng, data_loaders, + summary_writers, + opt_scheduler, param_scheduler, n_epoch, **kwargs): + self.model = model + self.model.device = device + if parallel: + self.model = nn.DataParallel(self.model) + self.model.to(device) + self.path_mng = log_path_mng + self.summary_writers = summary_writers + self.data_loaders = data_loaders + self.opt_scheduler = opt_scheduler + self.param_scheduler = param_scheduler + self.device = device + self.n_epoch = n_epoch + self.epoch = 0 + self.train_step = 0 + self.val_step = 0 + self.parallel = parallel + for key, val in kwargs.items(): + setattr(self, key, val) + + @property + def name(self): + if self.parallel: + return self.model.module.name + else: + return self.model.name + + @property + def log_path(self): + return self.path_mng.log_path + + @property + def model_path(self): + return self.path_mng.model_path + + @property + def writer_path(self): + return self.path_mng.writer_path + + @property + def writer_names(self): + return self.summary_writers.writer_names + + def _init_loss_dic(self): + loss_dic = {} + for key in self.writer_names: + loss_dic[key] = 0. 
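# The classes above are driven through PytorchModel.forward's mode dispatch: calling
# model('train', *inputs, **params) routes to model.loss(...), while 'inference'/'eval'
# routes to model.inference(...).  TrainingInterface relies on that convention in its
# train()/eval() loops below.  A rough sketch of one training step (hypothetical batch
# and scheduler parameter names, not a literal excerpt):
#
#   inputs = self._batch_to_inputs(batch)               # subclass-defined unpacking
#   params = self.param_scheduler.step()                # e.g. {'tfr1': 0.6, 'beta': 0.1}
#   outputs = self.model('train', *inputs, **params)    # outputs[0] is the total loss
#   outputs[0].backward()
#   self.opt_scheduler.step()
#
# At inference time a trained DisentangleVAE (see Poly_Dis above) is typically used
# through swap(pr_mat1, pr_mat2, c1, c2, fix_rhy, fix_chd), which re-decodes one
# piece's texture latent together with another piece's chord latent.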
+ return loss_dic + + def _accumulate_loss_dic(self, loss_dic, loss_items): + assert len(self.writer_names) == len(loss_items) + for key, val in zip(self.writer_names, loss_items): + loss_dic[key] += val.item() + return loss_dic + + def _write_loss_to_dic(self, loss_items): + loss_dic = {} + assert len(self.writer_names) == len(loss_items) + for key, val in zip(self.writer_names, loss_items): + loss_dic[key] = val.item() + return loss_dic + + def _batch_to_inputs(self, batch): + raise NotImplementedError + + def train(self, **kwargs): + self.model.train() + self.param_scheduler.train() + epoch_loss_dic = self._init_loss_dic() + + for i, batch in enumerate(self.data_loaders.train_loader): + #print(len(batch)) + inputs = self._batch_to_inputs(batch) + #print(type(input)) + self.opt_scheduler.optimizer_zero_grad() + input_params = self.param_scheduler.step() + #print(input_params.keys()) + outputs = self.model('train', *inputs, **input_params) + outputs = self._sum_parallel_loss(outputs) + loss = outputs[0] + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), + self.opt_scheduler.clip) + self.opt_scheduler.step() + self._accumulate_loss_dic(epoch_loss_dic, outputs) + batch_loss_dic = self._write_loss_to_dic(outputs) + self.summary_writers.write_task('train', batch_loss_dic, + self.train_step) + self.train_step += 1 + return epoch_loss_dic + + def _sum_parallel_loss(self, loss): + if self.parallel: + if isinstance(loss, tuple): + return tuple([x.mean() for x in loss]) + else: + return loss.mean() + else: + return loss + + def eval(self): + self.model.eval() + self.param_scheduler.eval() + epoch_loss_dic = self._init_loss_dic() + + for i, batch in enumerate(self.data_loaders.val_loader): + inputs = self._batch_to_inputs(batch) + input_params = self.param_scheduler.step() + with torch.no_grad(): + outputs = self.model('train', *inputs, **input_params) + outputs = self._sum_parallel_loss(outputs) + self._accumulate_loss_dic(epoch_loss_dic, outputs) + batch_loss_dic = self._write_loss_to_dic(outputs) + self.summary_writers.write_task('val', batch_loss_dic, + self.val_step) + self.val_step += 1 + return epoch_loss_dic + + def save_model(self, fn): + if self.parallel: + torch.save(self.model.module.state_dict(), fn) + else: + torch.save(self.model.state_dict(), fn) + + def epoch_report(self, start_time, end_time, train_loss, valid_loss): + epoch_mins, epoch_secs = epoch_time(start_time, end_time) + print(f'Epoch: {self.epoch + 1:02} | ' + f'Time: {epoch_mins}m {epoch_secs}s', + flush=True) + print( + f'\tTrain Loss: {train_loss:.3f}', flush=True) + print( + f'\t Valid. 
Loss: {valid_loss:.3f}', flush=True) + + def run(self, start_epoch=0, start_train_step=0, start_val_step=0): + self.epoch = start_epoch + self.train_step = start_train_step + self.val_step = start_val_step + best_valid_loss = float('inf') + + for i in range(self.n_epoch): + start_time = time.time() + train_loss = self.train()['loss'] + val_loss = self.eval()['loss'] + end_time = time.time() + self.save_model(self.path_mng.epoch_model_path(self.name)) + if val_loss < best_valid_loss: + best_valid_loss = val_loss + self.save_model(self.path_mng.valid_model_path(self.name)) + self.epoch_report(start_time, end_time, train_loss, val_loss) + self.epoch += 1 + self.save_model(self.path_mng.final_model_path(self.name)) + print('Model saved.') + + + + diff --git a/piano_arranger/models/amc_dl/torch_plus/scheduler.py b/piano_arranger/models/amc_dl/torch_plus/scheduler.py new file mode 100644 index 0000000..27e156b --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/scheduler.py @@ -0,0 +1,104 @@ +import numpy as np +from .train_utils import scheduled_sampling + +class _Scheduler: + + def __init__(self, step=0, mode='train'): + self._step = step + self._mode = mode + + def _update_step(self): + if self._mode == 'train': + self._step += 1 + elif self._mode == 'val': + pass + else: + raise NotImplementedError + + def step(self): + raise NotImplementedError + + def train(self): + self._mode = 'train' + + def eval(self): + self._mode = 'val' + + +class ConstantScheduler(_Scheduler): + + def __init__(self, param, step=0.): + super(ConstantScheduler, self).__init__(step) + self.param = param + + def step(self): + self._update_step() + return self.param + + +class TeacherForcingScheduler(_Scheduler): + + def __init__(self, high, low, f=scheduled_sampling, step=0): + super(TeacherForcingScheduler, self).__init__(step) + self.high = high + self.low = low + self._step = step + self.schedule_f = f + + def get_tfr(self): + return self.schedule_f(self._step, self.high, self.low) + + def step(self): + tfr = self.get_tfr() + self._update_step() + return tfr + + +class OptimizerScheduler(_Scheduler): + + def __init__(self, optimizer, scheduler, clip, step=0): + # optimizer and scheduler are pytorch class + super(OptimizerScheduler, self).__init__(step) + self.optimizer = optimizer + self.scheduler = scheduler + self.clip = clip + + def optimizer_zero_grad(self): + self.optimizer.zero_grad() + + def step(self, require_zero_grad=False): + self.optimizer.step() + self.scheduler.step() + if require_zero_grad: + self.optimizer_zero_grad() + self._update_step() + + +class ParameterScheduler(_Scheduler): + + def __init__(self, step=0, mode='train', **schedulers): + # optimizer and scheduler are pytorch class + super(ParameterScheduler, self).__init__(step) + self.schedulers = schedulers + self.mode = mode + + def train(self): + self.mode = 'train' + for scheduler in self.schedulers.values(): + scheduler.train() + + def eval(self): + self.mode = 'val' + for scheduler in self.schedulers.values(): + scheduler.eval() + + def step(self, require_zero_grad=False): + params_dic = {} + for key, scheduler in self.schedulers.items(): + params_dic[key] = scheduler.step() + return params_dic + + + + + diff --git a/piano_arranger/models/amc_dl/torch_plus/train_utils.py b/piano_arranger/models/amc_dl/torch_plus/train_utils.py new file mode 100644 index 0000000..1849442 --- /dev/null +++ b/piano_arranger/models/amc_dl/torch_plus/train_utils.py @@ -0,0 +1,49 @@ +import numpy as np +from torch.distributions import Normal, 
kl_divergence +import torch + + +def epoch_time(start_time, end_time): + elapsed_time = end_time - start_time + elapsed_mins = int(elapsed_time / 60) + elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) + return elapsed_mins, elapsed_secs + + +def join_fn(*items, ext='pt'): + return '.'.join(['_'.join(items), ext]) + + +def scheduled_sampling(i, high=0.7, low=0.05): + x = 10 * (i - 0.5) + z = 1 / (1 + np.exp(x)) + y = (high - low) * z + low + return y + + +def kl_anealing(i, high=0.1, low=0.): + hh = 1 - low + ll = 1 - high + x = 10 * (i - 0.5) + z = 1 / (1 + np.exp(x)) + y = (hh - ll) * z + ll + return 1 - y + + +def get_zs_from_dists(dists, sample=False): + return [dist.rsample() if sample else dist.mean for dist in dists] + + +def standard_normal(shape): + N = Normal(torch.zeros(shape), torch.ones(shape)) + if torch.cuda.is_available(): + N.loc = N.loc.cuda() + N.scale = N.scale.cuda() + return N + + +def kl_with_normal(dist): + shape = dist.mean.size(-1) + normal = standard_normal(shape) + kl = kl_divergence(dist, normal).mean() + return kl diff --git a/piano_arranger/models/ptvae.py b/piano_arranger/models/ptvae.py new file mode 100644 index 0000000..ed07d4d --- /dev/null +++ b/piano_arranger/models/ptvae.py @@ -0,0 +1,598 @@ +from .amc_dl.torch_plus import PytorchModel +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence +from torch.distributions import Normal +import random +import pretty_midi +import numpy as np + +""" + Credit to Wang et al., "PIANOTREE VAE: Structured Representation Learning for Polyphonic Music," ISMIR 2020. + https://github.com/ZZWaang/PianoTree-VAE +""" + +class RnnEncoder(nn.Module): + def __init__(self, input_dim, hidden_dim, z_dim): + super(RnnEncoder, self).__init__() + self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) + self.linear_var = nn.Linear(hidden_dim * 2, z_dim) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.z_dim = z_dim + + def forward(self, x): + x = self.gru(x)[-1] + x = x.transpose_(0, 1).contiguous() + x = x.view(x.size(0), -1) + mu = self.linear_mu(x) + var = self.linear_var(x).exp_() + dist = Normal(mu, var) + return dist + + +class RnnDecoder(nn.Module): + + def __init__(self, input_dim=36, z_input_dim=256, + hidden_dim=512, z_dim=256, num_step=32): + super(RnnDecoder, self).__init__() + self.z2dec_hid = nn.Linear(z_dim, hidden_dim) + self.z2dec_in = nn.Linear(z_dim, z_input_dim) + self.gru = nn.GRU(input_dim + z_input_dim, hidden_dim, + batch_first=True, + bidirectional=False) + self.init_input = nn.Parameter(torch.rand(36)) + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.z_dim = z_dim + self.root_out = nn.Linear(hidden_dim, 12) + self.chroma_out = nn.Linear(hidden_dim, 24) + self.bass_out = nn.Linear(hidden_dim, 12) + self.num_step = num_step + + def forward(self, z_chd, inference, tfr, c=None): + # z_chd: (B, z_chd_size) + bs = z_chd.size(0) + z_chd_hid = self.z2dec_hid(z_chd).unsqueeze(0) + z_chd_in = self.z2dec_in(z_chd).unsqueeze(1) + if inference: + tfr = 0. 
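# The chord tokens consumed and produced by this decoder are 36-dimensional, one per
# beat: a 12-dim one-hot root, a 12-dim 0/1 chroma (the chroma head predicts two classes
# per pitch class, hence the (12, 2) reshape below) and a 12-dim one-hot bass, matching
# the root_out/chroma_out/bass_out heads above.  During training, tfr is usually supplied
# by scheduled_sampling() from train_utils, which decays sigmoidally from `high` to `low`
# as the normalised step i grows (values rounded, high=0.7, low=0.05):
#
#   scheduled_sampling(0.0)  ~= 0.70
#   scheduled_sampling(0.5)  ~= 0.38
#   scheduled_sampling(1.0)  ~= 0.05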
+ token = self.init_input.repeat(bs, 1).unsqueeze(1) + recon_root = [] + recon_chroma = [] + recon_bass = [] + + for t in range(int(self.num_step / 4)): + chd, z_chd_hid = \ + self.gru(torch.cat([token, z_chd_in], dim=-1), z_chd_hid) + r_root = self.root_out(chd) # (bs, 1, 12) + r_chroma = self.chroma_out(chd).view(bs, 1, 12, 2).contiguous() + r_bass = self.bass_out(chd) # (bs, 1, 12) + recon_root.append(r_root) + recon_chroma.append(r_chroma) + recon_bass.append(r_bass) + + t_root = torch.zeros(bs, 1, 12).to(z_chd.device).float() + t_root[torch.arange(0, bs), 0, r_root.max(-1)[-1]] = 1. + t_chroma = r_chroma.max(-1)[-1].float() + t_bass = torch.zeros(bs, 1, 12).to(z_chd.device).float() + t_bass[torch.arange(0, bs), 0, r_bass.max(-1)[-1]] = 1. + token = torch.cat([t_root, t_chroma, t_bass], dim=-1) + if t == self.num_step - 1: + break + teacher_force = random.random() < tfr + if teacher_force and not inference: + token = c[:, t].unsqueeze(1) + recon_root = torch.cat(recon_root, dim=1) + recon_chroma = torch.cat(recon_chroma, dim=1) + recon_bass = torch.cat(recon_bass, dim=1) + return recon_root, recon_chroma, recon_bass + + +class TextureEncoder(nn.Module): + + def __init__(self, emb_size, hidden_dim, z_dim, num_channel=10, for_contrastive=False): + '''input must be piano_mat: (B, 32, 128)''' + super(TextureEncoder, self).__init__() + self.cnn = nn.Sequential(nn.Conv2d(1, num_channel, kernel_size=(4, 12), + stride=(4, 1), padding=0), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(1, 4), + stride=(1, 4))) + self.fc1 = nn.Linear(num_channel * 29, 1000) + self.fc2 = nn.Linear(1000, emb_size) + self.gru = nn.GRU(emb_size, hidden_dim, batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(hidden_dim * 2, z_dim) + self.linear_var = nn.Linear(hidden_dim * 2, z_dim) + self.emb_size = emb_size + self.hidden_dim = hidden_dim + self.z_dim = z_dim + self.for_contrastive = for_contrastive + + def forward(self, pr): + # pr: (bs, 32, 128) + bs = pr.size(0) + pr = pr.unsqueeze(1) + pr = self.cnn(pr).view(bs, 8, -1) + pr = self.fc2(self.fc1(pr)) # (bs, 8, emb_size) + pr = self.gru(pr)[-1] + pr = pr.transpose_(0, 1).contiguous() + pr = pr.view(pr.size(0), -1) + mu = self.linear_mu(pr) + var = self.linear_var(pr).exp_() + dist = Normal(mu, var) + if self.for_contrastive: + return mu, pr + else: + return dist + + +class PtvaeEncoder(nn.Module): + + def __init__(self, device, max_simu_note=16, max_pitch=127, min_pitch=0, + pitch_sos=128, pitch_eos=129, pitch_pad=130, + dur_pad=2, dur_width=5, num_step=32, + note_emb_size=128, + enc_notes_hid_size=256, + enc_time_hid_size=512, z_size=512): + super(PtvaeEncoder, self).__init__() + + # Parameters + # note and time + self.max_pitch = max_pitch # the highest pitch in train/val set. + self.min_pitch = min_pitch # the lowest pitch in train/val set. + self.pitch_sos = pitch_sos + self.pitch_eos = pitch_eos + self.pitch_pad = pitch_pad + self.pitch_range = max_pitch - min_pitch + 3 # not including pad. + self.dur_pad = dur_pad + self.dur_width = dur_width + self.note_size = self.pitch_range + dur_width + self.max_simu_note = max_simu_note # the max # of notes at each ts. 
+ self.num_step = num_step # 32 + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + self.note_emb_size = note_emb_size + self.z_size = z_size + self.enc_notes_hid_size = enc_notes_hid_size + self.enc_time_hid_size = enc_time_hid_size + + self.note_embedding = nn.Linear(self.note_size, note_emb_size) + self.enc_notes_gru = nn.GRU(note_emb_size, enc_notes_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.enc_time_gru = nn.GRU(2 * enc_notes_hid_size, enc_time_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.linear_mu = nn.Linear(2 * enc_time_hid_size, z_size) + self.linear_std = nn.Linear(2 * enc_time_hid_size, z_size) + + def get_len_index_tensor(self, ind_x): + """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid.""" + with torch.no_grad(): + lengths = self.max_simu_note - \ + (ind_x[:, :, :, 0] - self.pitch_pad == 0).sum(dim=-1) + return lengths + + def index_tensor_to_multihot_tensor(self, ind_x): + """Transfer piano_grid to multi-hot piano_grid.""" + # ind_x: (B, 32, max_simu_note, 1 + dur_width) + with torch.no_grad(): + dur_part = ind_x[:, :, :, 1:].float() + out = torch.zeros( + [ind_x.size(0) * self.num_step * self.max_simu_note, + self.pitch_range + 1], + dtype=torch.float).to(self.device) + + out[range(0, out.size(0)), ind_x[:, :, :, 0].view(-1)] = 1. + out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1) + out = torch.cat([out[:, :, :, 0: self.pitch_range], dur_part], + dim=-1) + return out + + def encoder(self, x, lengths): + embedded = self.note_embedding(x) + # x: (B, num_step, max_simu_note, note_emb_size) + # now x are notes + x = embedded.view(-1, self.max_simu_note, self.note_emb_size) + x = pack_padded_sequence(x, lengths.view(-1), batch_first=True, + enforce_sorted=False) + x = self.enc_notes_gru(x)[-1].transpose(0, 1).contiguous() + x = x.view(-1, self.num_step, 2 * self.enc_notes_hid_size) + # now, x is simu_notes. + x = self.enc_time_gru(x)[-1].transpose(0, 1).contiguous() + # x: (B, 2, enc_time_hid_size) + x = x.view(x.size(0), -1) + mu = self.linear_mu(x) # (B, z_size) + std = self.linear_std(x).exp_() # (B, z_size) + dist = Normal(mu, std) + return dist, embedded + + def forward(self, x, return_iterators=False): + lengths = self.get_len_index_tensor(x) + x = self.index_tensor_to_multihot_tensor(x) + dist, embedded_x = self.encoder(x, lengths) + if return_iterators: + return dist.mean, dist.scale, embedded_x + else: + return dist, embedded_x, lengths + + +class PtvaeDecoder(nn.Module): + + def __init__(self, device=None, note_embedding=None, + max_simu_note=16, max_pitch=127, min_pitch=0, + pitch_sos=128, pitch_eos=129, pitch_pad=130, + dur_pad=2, dur_width=5, num_step=32, + note_emb_size=128, z_size=512, + dec_emb_hid_size=128, + dec_time_hid_size=1024, dec_notes_hid_size=512, + dec_z_in_size=256, dec_dur_hid_size=16): + super(PtvaeDecoder, self).__init__() + # Parameters + # note and time + self.max_pitch = max_pitch # the highest pitch in train/val set. + self.min_pitch = min_pitch # the lowest pitch in train/val set. + self.pitch_sos = pitch_sos + self.pitch_eos = pitch_eos + self.pitch_pad = pitch_pad + self.pitch_range = max_pitch - min_pitch + 3 # 88, not including pad. + self.dur_pad = dur_pad + self.dur_width = dur_width + self.note_size = self.pitch_range + dur_width + self.max_simu_note = max_simu_note # the max # of notes at each ts. 
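# PianoTree grid format handled by this decoder (as used in output_to_numpy and
# grid_to_pr_and_notes below): each of the num_step (32) sixteenth-note steps holds up
# to max_simu_note entries of the form [pitch_index, d4, d3, d2, d1, d0], where pitch
# indices 128/129/130 are the SOS/EOS/PAD tokens and the five binary digits encode
# duration-1 in sixteenth notes, i.e. duration = int(''.join(bits), 2) + 1.
# For example (illustration only), a C4 quarter note is
#   [60 - min_pitch, 0, 0, 0, 1, 1]   ->  pitch 60, duration 0b00011 + 1 = 4 steps.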
+ self.num_step = num_step # 32 + + # device + if device is None: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + self.device = device + + self.note_emb_size = note_emb_size + self.z_size = z_size + + # decoder + self.dec_z_in_size = dec_z_in_size + self.dec_emb_hid_size = dec_emb_hid_size + self.dec_time_hid_size = dec_time_hid_size + self.dec_init_input = \ + nn.Parameter(torch.rand(2 * self.dec_emb_hid_size)) + self.dec_notes_hid_size = dec_notes_hid_size + self.dur_sos_token = nn.Parameter(torch.rand(self.dur_width)) + self.dec_dur_hid_size = dec_dur_hid_size + + # Modules + # For both encoder and decoder + if note_embedding is None: + self.note_embedding = nn.Linear(self.note_size, note_emb_size) + else: + self.note_embedding = note_embedding + self.z2dec_hid_linear = nn.Linear(self.z_size, dec_time_hid_size) + self.z2dec_in_linear = nn.Linear(self.z_size, dec_z_in_size) + self.dec_notes_emb_gru = nn.GRU(note_emb_size, dec_emb_hid_size, + num_layers=1, batch_first=True, + bidirectional=True) + self.dec_time_gru = \ + nn.GRU(dec_z_in_size + 2 * dec_emb_hid_size, + dec_time_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.dec_time_to_notes_hid = nn.Linear(dec_time_hid_size, + dec_notes_hid_size) + self.dec_notes_gru = nn.GRU(dec_time_hid_size + note_emb_size, + dec_notes_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.pitch_out_linear = nn.Linear(dec_notes_hid_size, self.pitch_range) + self.dec_dur_gru = nn.GRU(dur_width, dec_dur_hid_size, + num_layers=1, batch_first=True, + bidirectional=False) + self.dur_hid_linear = nn.Linear(self.pitch_range + dec_notes_hid_size, + dec_dur_hid_size) + self.dur_out_linear = nn.Linear(dec_dur_hid_size, 2) + + def get_len_index_tensor(self, ind_x): + """Calculate the lengths ((B, 32), torch.LongTensor) of pgrid.""" + with torch.no_grad(): + lengths = self.max_simu_note - \ + (ind_x[:, :, :, 0] - self.pitch_pad == 0).sum(dim=-1) + return lengths + + def index_tensor_to_multihot_tensor(self, ind_x): + """Transfer piano_grid to multi-hot piano_grid.""" + # ind_x: (B, 32, max_simu_note, 1 + dur_width) + with torch.no_grad(): + dur_part = ind_x[:, :, :, 1:].float() + out = torch.zeros( + [ind_x.size(0) * self.num_step * self.max_simu_note, + self.pitch_range + 1], + dtype=torch.float).to(self.device) + + out[range(0, out.size(0)), ind_x[:, :, :, 0].view(-1)] = 1. + out = out.view(-1, 32, self.max_simu_note, self.pitch_range + 1) + out = torch.cat([out[:, :, :, 0: self.pitch_range], dur_part], + dim=-1) + return out + + def get_sos_token(self): + sos = torch.zeros(self.note_size) + sos[self.pitch_sos] = 1. + sos[self.pitch_range:] = 2. + sos = sos.to(self.device) + return sos + + def dur_ind_to_dur_token(self, inds, batch_size): + token = torch.zeros(batch_size, self.dur_width) + token[range(0, batch_size), inds] = 1. + token = token.to(self.device) + return token + + def pitch_dur_ind_to_note_token(self, pitch_inds, dur_inds, batch_size): + token = torch.zeros(batch_size, self.note_size) + token[range(0, batch_size), pitch_inds] = 1. + token[:, self.pitch_range:] = dur_inds + token = token.to(self.device) + token = self.note_embedding(token) + return token + + def decode_note(self, note_summary, batch_size): + # note_summary: (B, 1, dec_notes_hid_size) + # This function estimate pitch, and dur for a single pitch based on + # note_summary. + # Returns: est_pitch (B, 1, pitch_range), est_durs (B, 1, dur_width, 2) + + # The estimated pitch is calculated by a linear layer. 
+ est_pitch = self.pitch_out_linear(note_summary).squeeze(1) + # est_pitch: (B, pitch_range) + + # The estimated dur is calculated by a 5-step gru. + dur_hid = note_summary.transpose(0, 1) + # dur_hid: (1, B, dec_notes_hid_size) + dur_hid = \ + self.dur_hid_linear(torch.cat([dur_hid, + est_pitch.unsqueeze(0)], + dim=-1)) + token = self.dur_sos_token.repeat(batch_size, 1).unsqueeze(1) + # token: (B, 1, dur_width) + + est_durs = torch.zeros(batch_size, self.dur_width, 2) + est_durs = est_durs.to(self.device) + + for t in range(self.dur_width): + token, dur_hid = self.dec_dur_gru(token, dur_hid) + est_dur = self.dur_out_linear(token).squeeze(1) + est_durs[:, t] = est_dur + if t == self.dur_width - 1: + break + token_inds = est_dur.max(1)[1] + token = self.dur_ind_to_dur_token(token_inds, + batch_size).unsqueeze(1) + return est_pitch, est_durs + + def decode_notes(self, notes_summary, batch_size, notes, inference, + teacher_forcing_ratio=0.5): + # notes_summary: (B, 1, dec_time_hid_size) + # notes: (B, max_simu_note, note_emb_size), ground_truth + notes_summary_hid = \ + self.dec_time_to_notes_hid(notes_summary.transpose(0, 1)) + if inference: + assert teacher_forcing_ratio == 0 + assert notes is None + sos = self.get_sos_token() # (note_size,) + token = self.note_embedding(sos).repeat(batch_size, 1).unsqueeze(1) + # hid: (B, 1, note_emb_size) + else: + token = notes[:, 0].unsqueeze(1) + + predicted_notes = torch.zeros(batch_size, self.max_simu_note, + self.note_emb_size) + predicted_notes[:, :, self.pitch_range:] = 2. + predicted_notes[:, 0] = token.squeeze(1) # fill sos index + lengths = torch.zeros(batch_size) + predicted_notes = predicted_notes.to(self.device) + lengths = lengths.to(self.device) + pitch_outs = [] + dur_outs = [] + + for t in range(1, self.max_simu_note): + note_summary, notes_summary_hid = \ + self.dec_notes_gru(torch.cat([notes_summary, token], dim=-1), + notes_summary_hid) + # note_summary: (B, 1, dec_notes_hid_size) + # notes_summary_hid: (1, B, dec_time_hid_size) + + est_pitch, est_durs = self.decode_note(note_summary, batch_size) + # est_pitch: (B, pitch_range) + # est_durs: (B, dur_width, 2) + + pitch_outs.append(est_pitch.unsqueeze(1)) + dur_outs.append(est_durs.unsqueeze(1)) + pitch_inds = est_pitch.max(1)[1] + dur_inds = est_durs.max(2)[1] + predicted = self.pitch_dur_ind_to_note_token(pitch_inds, dur_inds, + batch_size) + # predicted: (B, note_size) + + predicted_notes[:, t] = predicted + eos_samp_inds = (pitch_inds == self.pitch_eos) + lengths[eos_samp_inds & (lengths == 0)] = t + + if t == self.max_simu_note - 1: + break + teacher_force = random.random() < teacher_forcing_ratio + if inference or not teacher_force: + token = predicted.unsqueeze(1) + else: + token = notes[:, t].unsqueeze(1) + lengths[lengths == 0] = t + pitch_outs = torch.cat(pitch_outs, dim=1) + dur_outs = torch.cat(dur_outs, dim=1) + return pitch_outs, dur_outs, predicted_notes, lengths + + def decoder(self, z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2): + # z: (B, z_size) + # x: (B, num_step, max_simu_note, note_emb_size) + batch_size = z.size(0) + z_hid = self.z2dec_hid_linear(z).unsqueeze(0) + # z_hid: (1, B, dec_time_hid_size) + z_in = self.z2dec_in_linear(z).unsqueeze(1) + # z_in: (B, dec_z_in_size) + + if inference: + assert x is None + assert lengths is None + assert teacher_forcing_ratio1 == 0 + assert teacher_forcing_ratio2 == 0 + else: + x_summarized = x.view(-1, self.max_simu_note, self.note_emb_size) + x_summarized = pack_padded_sequence(x_summarized, 
lengths.view(-1), + batch_first=True, + enforce_sorted=False) + x_summarized = self.dec_notes_emb_gru(x_summarized)[-1].\ + transpose(0, 1).contiguous() + x_summarized = x_summarized.view(-1, self.num_step, + 2 * self.dec_emb_hid_size) + + pitch_outs = [] + dur_outs = [] + token = self.dec_init_input.repeat(batch_size, 1).unsqueeze(1) + # (B, 2 * dec_emb_hid_size) + + for t in range(self.num_step): + notes_summary, z_hid = \ + self.dec_time_gru(torch.cat([token, z_in], dim=-1), z_hid) + if inference: + pitch_out, dur_out, predicted_notes, predicted_lengths = \ + self.decode_notes(notes_summary, batch_size, None, + inference, teacher_forcing_ratio2) + else: + pitch_out, dur_out, predicted_notes, predicted_lengths = \ + self.decode_notes(notes_summary, batch_size, x[:, t], + inference, teacher_forcing_ratio2) + pitch_outs.append(pitch_out.unsqueeze(1)) + dur_outs.append(dur_out.unsqueeze(1)) + if t == self.num_step - 1: + break + + teacher_force = random.random() < teacher_forcing_ratio1 + if teacher_force and not inference: + token = x_summarized[:, t].unsqueeze(1) + else: + token = pack_padded_sequence(predicted_notes, + predicted_lengths.cpu(), + batch_first=True, + enforce_sorted=False) + token = self.dec_notes_emb_gru(token)[-1].\ + transpose(0, 1).contiguous() + token = token.view(-1, 2 * self.dec_emb_hid_size).unsqueeze(1) + pitch_outs = torch.cat(pitch_outs, dim=1) + dur_outs = torch.cat(dur_outs, dim=1) + # print(pitch_outs.size()) + # print(dur_outs.size()) + return pitch_outs, dur_outs + + def forward(self, z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2): + return self.decoder(z, inference, x, lengths, teacher_forcing_ratio1, + teacher_forcing_ratio2) + + def recon_loss(self, x, recon_pitch, recon_dur, weights=(1, 0.5), + weighted_dur=False): + pitch_loss_func = \ + nn.CrossEntropyLoss(ignore_index=self.pitch_pad) + recon_pitch = recon_pitch.view(-1, recon_pitch.size(-1)) + #print(recon_pitch.shape) + + gt_pitch = x[:, :, 1:, 0].contiguous().view(-1) + #print(gt_pitch.shape) + pitch_loss = pitch_loss_func(recon_pitch, gt_pitch) + + dur_loss_func = \ + nn.CrossEntropyLoss(ignore_index=self.dur_pad) + if not weighted_dur: + recon_dur = recon_dur.view(-1, 2) + gt_dur = x[:, :, 1:, 1:].contiguous().view(-1) + dur_loss = dur_loss_func(recon_dur, gt_dur) + else: + recon_dur = recon_dur.view(-1, self.dur_width, 2) + gt_dur = x[:, :, 1:, 1:].contiguous().view(-1, self.dur_width) + dur0 = dur_loss_func(recon_dur[:, 0, :], gt_dur[:, 0]) + dur1 = dur_loss_func(recon_dur[:, 1, :], gt_dur[:, 1]) + dur2 = dur_loss_func(recon_dur[:, 2, :], gt_dur[:, 2]) + dur3 = dur_loss_func(recon_dur[:, 3, :], gt_dur[:, 3]) + dur4 = dur_loss_func(recon_dur[:, 4, :], gt_dur[:, 4]) + w = torch.tensor([1, 0.6, 0.4, 0.3, 0.3], + device=recon_dur.device).float() + dur_loss = \ + w[0] * dur0 + \ + w[1] * dur1 + \ + w[2] * dur2 + \ + w[3] * dur3 + \ + w[4] * dur4 + loss = weights[0] * pitch_loss + weights[1] * dur_loss + return loss, pitch_loss, dur_loss + + def emb_x(self, x): + lengths = self.get_len_index_tensor(x) + x = self.index_tensor_to_multihot_tensor(x) + embedded = self.note_embedding(x) + return embedded, lengths + + def output_to_numpy(self, recon_pitch, recon_dur): + est_pitch = recon_pitch.max(-1)[1].unsqueeze(-1) # (B, 32, 11, 1) + est_dur = recon_dur.max(-1)[1] # (B, 32, 11, 5) + est_x = torch.cat([est_pitch, est_dur], dim=-1) # (B, 32, 11, 6) + est_x = est_x.cpu().numpy() + recon_pitch = recon_pitch.cpu().numpy() + recon_dur = recon_dur.cpu().numpy() + return est_x, 
recon_pitch, recon_dur + + def pr_to_notes(self, pr, bpm=80, start=0., one_hot=False): + pr_matrix = self.pr_to_pr_matrix(pr, one_hot) + alpha = 0.25 * 60 / bpm + notes = [] + for t in range(32): + for p in range(128): + if pr_matrix[t, p] >= 1: + s = alpha * t + start + e = alpha * (t + pr_matrix[t, p]) + start + notes.append(pretty_midi.Note(100, int(p), s, e)) + return notes + + def pr_matrix_to_note(self, pr_matrix, bpm=120, start=0): + alpha = 0.25 * 60 / bpm + notes = [] + for t in range(32): + for p in range(128): + if pr_matrix[t, p] >= 1: + s = alpha * t + start + e = alpha * (t + pr_matrix[t, p]) + start + notes.append(pretty_midi.Note(100, int(p), s, e)) + return notes + + def grid_to_pr_and_notes(self, grid, bpm=60., start=0.): + if grid.shape[1] == self.max_simu_note: + grid = grid[:, 1:] + pr = np.zeros((32, 128), dtype=int) + alpha = 0.25 * 60 / bpm + notes = [] + for t in range(32): + for n in range(10): + note = grid[t, n] + if note[0] == self.pitch_eos: + break + pitch = note[0] + self.min_pitch + dur = int(''.join([str(_) for _ in note[1:]]), 2) + 1 + pr[t, pitch] = min(dur, 32 - t) + notes.append( + pretty_midi.Note(100, int(pitch), start + t * alpha, + start + (t + dur) * alpha)) + return pr, notes + diff --git a/piano_arranger/models/transition_model.py b/piano_arranger/models/transition_model.py new file mode 100644 index 0000000..88106fa --- /dev/null +++ b/piano_arranger/models/transition_model.py @@ -0,0 +1,31 @@ +import torch +from torch import nn + +class contrastive_model(nn.Module): + def __init__(self, emb_size=256, hidden_dim=1024): + """input: ((batch * 6) * (1024*2))""" + super(contrastive_model, self).__init__() + #self.in_linear = nn.Linear(1024*2, emb_size) + self.out_linear_left = nn.Linear(hidden_dim * 2, emb_size) + self.out_linear_right = nn.Linear(hidden_dim * 2, emb_size) + + self.emb_size = emb_size + self.hidden_dim = hidden_dim + + self.dropout = nn.Dropout(p=0) + + self.cosine = nn.CosineSimilarity(dim=-1) + self.loss = nn.Softmax(dim=-1) + + def contrastive_loss(self, similarity): + return 1 - torch.mean(self.loss(similarity)[:, 0]) + + def forward(self, batch): + """input: (batch * 6 * (1024*2))""" + batch_size, pos_neg, feature_dim = batch.shape + #batch = self.in_linear(batch) #(batch_size * pos_neg_size) * phrase_length * emb_size + left = self.dropout(self.out_linear_left(batch[:, 0: 1, :])) #batch * 1 * emb_size + right = self.dropout(self.out_linear_right(batch[:, 1:, :])) #batch * 5 * emb_size + similarity = self.cosine(left.expand(right.shape), right) #batch * 5 + return similarity + diff --git a/piano_arranger/scripts/build_phrase_data.py b/piano_arranger/scripts/build_phrase_data.py new file mode 100644 index 0000000..1a90d15 --- /dev/null +++ b/piano_arranger/scripts/build_phrase_data.py @@ -0,0 +1,191 @@ +import os +import numpy as np +import pretty_midi as pyd +from dtw import * +from scipy import interpolate +from tqdm import tqdm +import mir_eval +from scipy import stats +import pandas as pd + +def split_phrases(segmentation): + phrases = [] + lengths = [] + current = 0 + while segmentation[current] != '\n': + if segmentation[current].isalpha(): + j = 1 + while not (segmentation[current + j].isalpha() or segmentation[current + j] == '\n'): + j += 1 + phrases.append(segmentation[current]) + lengths.append(int(segmentation[current+1: current+j])) + current += j + return [(phrases[i], lengths[i], sum(lengths[:i])) for i in range(len(phrases))] + + +def matrix2midi_with_dynamics(pr_matrices, programs, init_tempo=120, time_start=0, 
ACC=16): + """ + Reconstruct a multi-track midi from a 3D matrix of shape (Track. Time, 128, 3). + """ + tracks = [] + for program in programs: + track_recon = pyd.Instrument(program=int(program), is_drum=False, name=pyd.program_to_instrument_name(int(program))) + tracks.append(track_recon) + + indices_track, indices_onset, indices_pitch = np.nonzero(pr_matrices[:, :, :, 0]) + alpha = 1 / (ACC // 4) * 60 / init_tempo #timetep between each quntization bin + for idx in range(len(indices_track)): + track_id = indices_track[idx] + onset = indices_onset[idx] + pitch = indices_pitch[idx] + + start = onset * alpha + duration = pr_matrices[track_id, onset, pitch, 0] * alpha + velocity = pr_matrices[track_id, onset, pitch, 1] + + note_recon = pyd.Note(velocity=int(velocity), pitch=int(pitch), start=time_start + start, end=time_start + start + duration) + tracks[track_id].notes.append(note_recon) + for idx in range(len(pr_matrices)): + cc = [] + control_matrix = pr_matrices[idx, :, :, 2] + for t, n in zip(*np.nonzero(control_matrix >= 0)): + start = alpha * t + cc.append(pyd.ControlChange(int(n), int(control_matrix[t, n]), start)) + tracks[idx].control_changes = cc + + midi_recon = pyd.PrettyMIDI(initial_tempo=init_tempo) + midi_recon.instruments = tracks + return midi_recon + +def get_ec2_melody(pr_matrix): + hold_pitch = 128 + rest_pitch = 129 + + piano_roll = np.zeros((len(pr_matrix), 130)) + for t, p in zip(*np.nonzero(pr_matrix)): + dur = int(pr_matrix[t, p]) + piano_roll[t, p] = 1 + piano_roll[t+1:t+dur, hold_pitch] = 1 + piano_roll[np.nonzero(1 - np.sum(piano_roll, axis=1))[0], rest_pitch] = 1 + return piano_roll + + + +SEGMENTATION_ROOT = 'hierarchical-structure-analysis/POP909' # https://github.com/Dsqvival/hierarchical-structure-analysis +POP909_MIDI_ROOT = '../POP909 Dataset (MIDI)' # https://github.com/music-x-lab/POP909-Dataset + +df = pd.read_excel(f"{POP909_MIDI_ROOT}/index.xlsx") +phrase_melody = [] +phrase_acc = [] +phrase_chord = [] +phrase_velocity = [] +phrase_cc = [] +for song in os.listdir(SEGMENTATION_ROOT): + meta_data = df[df.song_id == int(song)] + num_beats = meta_data.num_beats_per_measure.values[0] + num_quavers = meta_data.num_quavers_per_beat.values[0] + if int(num_beats) == 3 or int(num_quavers) == 3: + continue + try: + melody = np.loadtxt(os.path.join(SEGMENTATION_ROOT, song, 'melody.txt')) + except OSError: + continue + melody[:, 1] = np.cumsum(melody[:, 1]) + melody[1:, 1] = melody[:-1, 1] + melody = melody[1:] + #print(melody[:, 1]) + melody_notes = [] + for note in melody: + if note[0] > 0: + melody_notes.append(note) + + midi = pyd.PrettyMIDI(os.path.join(POP909_MIDI_ROOT, song, f'{song}.mid')) + time_record = [] + midi_notes = [] + for note in midi.instruments[0].notes: + if not note.start in time_record: + midi_notes.append(note) + time_record.append(note.start) + + alignment = dtw([int(note[0]) for note in melody_notes], [note.pitch for note in midi_notes], keep_internals=True) + + melody_note_indices = alignment.index1 + midi_note_indices = alignment.index2 + quaver = [] + time = [] + for idx in range(1, len(melody_note_indices)-1): + if (melody_note_indices[idx] == melody_note_indices[idx-1]) \ + or (melody_note_indices[idx] == melody_note_indices[idx+1]) \ + or (midi_note_indices[idx] == midi_note_indices[idx-1]) \ + or (midi_note_indices[idx] == midi_note_indices[idx+1]): + continue + quaver.append(melody_notes[melody_note_indices[idx]][1]) + time.append(midi_notes[midi_note_indices[idx]].start) + + f = interpolate.interp1d(time, quaver, 
bounds_error=False, fill_value='extrapolate') + + #import matplotlib.pyplot as plt + #plt.plot(time, quaver, 'o', time, quaver, '-') + #plt.show() + + with open(os.path.join(SEGMENTATION_ROOT, song, 'human_label1.txt'), 'r') as file: + segmentation = file.readlines()[0] + print(song, segmentation) + if not '\n' in segmentation: + segmentation += '\n' + segmentation = split_phrases(segmentation) + + tracks = np.concatenate([np.zeros((3, (segmentation[-1][-1] + segmentation[-1][-2]) * 16, 128, 2)), \ + -1 * np.ones((3, (segmentation[-1][-1] + segmentation[-1][-2]) * 16, 128, 2))], \ + axis=-1 \ + ) + for idx, track in enumerate(midi.instruments): + for note in track.notes: + start = int(np.round(f(note.start))) + if start >= tracks.shape[1]: + break + end = int(np.round(f(note.end))) + tracks[idx, start, note.pitch, 0] = max(end - start, 1) + tracks[idx, start, note.pitch, 1] = note.velocity + for control in track.control_changes: + start = int(np.round(f(control.time))) + if start >= tracks.shape[1]: + break + tracks[idx, start, control.number, 2] = control.value + + #midi_recon = matrix2midi_with_dynamics(tracks, [0, 0, 0], init_tempo=90) + #midi_recon.write(f"seg_recon/{song}.mid") + + with open(os.path.join(POP909_MIDI_ROOT, song, f'chord_midi.txt'), 'r') as file: + chord_annotation = file.readlines() + + chord_matrix = np.zeros(((segmentation[-1][-1] + segmentation[-1][-2]) * 16, 14)) + for chord in chord_annotation: + start, end, chord = chord.replace('\n', '').split('\t') + start = int(np.round(f(start))) + end = int(np.round(f(end))) + chord_root, bit_map, bass = mir_eval.chord.encode(chord) + chord = np.concatenate([np.array([chord_root]), np.roll(bit_map, shift=int(chord_root)), np.array([bass])]) + chord_matrix[start: end] = chord + chord_matrix = chord_matrix[::4] + + song_melody = [] + song_acc = [] + song_chord = [] + song_velocity = [] + song_cc = [] + + for (_, length, start) in segmentation: + song_melody.append(get_ec2_melody(tracks[0, start*16: (start+length)*16, :, 0])) + song_acc.append(np.max(tracks[1:, start*16: (start+length)*16, :, 0], axis=0)) + song_chord.append(chord_matrix[start*4: (start+length)*4]) + song_velocity.append(np.max(tracks[1:, start*16: (start+length)*16, :, 1], axis=0)) + song_cc.append(tracks[2, start*16: (start+length)*16, :, 2]) + + phrase_melody.append(song_melody) + phrase_acc.append(song_acc) + phrase_chord.append(song_chord) + phrase_velocity.append(song_velocity) + phrase_cc.append(song_cc) + +np.savez_compressed('./phrase_data.npz', melody=phrase_melody, acc=phrase_acc, chord=phrase_chord, velocity=phrase_velocity, cc=phrase_cc) \ No newline at end of file diff --git a/piano_arranger/scripts/edge_weights_inference.py b/piano_arranger/scripts/edge_weights_inference.py new file mode 100644 index 0000000..9a7063d --- /dev/null +++ b/piano_arranger/scripts/edge_weights_inference.py @@ -0,0 +1,141 @@ +import numpy as np +import os +from tqdm import tqdm +import torch +torch.cuda.current_device() + +import sys +sys.path.append('AccoMontage') +from models import TextureEncoder, contrastive_model +import warnings +warnings.filterwarnings("ignore") + + +def find_by_length(melody_data, acc_data, chord_data, length): + melody_record = [] + acc_record = [] + chord_record = [] + for song_idx in tqdm(range(acc_data.shape[0])): + for phrase_idx in range(len(acc_data[song_idx])): + melody = melody_data[song_idx][phrase_idx] + if not melody.shape[0] == length * 16: + continue + if np.sum(melody[:, :128]) <= 2: + continue + melody_record.append(melody) + 
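# phrase_data.npz (written by build_phrase_data.py above) stores, per song and phrase:
#   'melody'   - a (16*length, 130) EC2-VAE style roll (128 pitches + hold + rest bins)
#   'acc'      - a (16*length, 128) accompaniment roll of note durations in sixteenths
#   'chord'    - a (4*length, 14) per-beat chord matrix [root, 12-dim rolled chroma, bass]
# plus 'velocity' and 'cc'.  A minimal loading sketch (path as used in this script):
#
#   data = np.load('checkpoints/phrase_data.npz', allow_pickle=True)
#   first_phrase_acc = data['acc'][0][0]     # song 0, phrase 0, shape (16*length, 128)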
acc = acc_data[song_idx][phrase_idx] + acc_record.append(acc) + chord = chord_data[song_idx][phrase_idx] + chord_record.append(chord) + return np.array(melody_record), np.array(acc_record), np.array(chord_record) + + +def contrastive_match(left, rights, texture_model, contras_model, num_candidates): + #left: 1 * time * 128 + #rights: batch * time * 128 + NEG = 6 + batch_size, time, roll_size = rights.shape + #print(batch_size) + count = (batch_size // NEG) * NEG + rights_ = rights[:count].reshape((batch_size // NEG, NEG, time, roll_size)) + left = left[np.newaxis, :, :, :] + batch_input = np.concatenate((np.repeat(left, rights_.shape[0], axis=0), rights_), axis=1) + + texture_model.eval() + contras_model.eval() + consequence = []#np.empty((0, NEG)) + mini_batch = 2 + for i in range(0, batch_input.shape[0] - mini_batch, mini_batch): + batch = batch_input[i: (i+mini_batch)] + #lengths = contras_model.get_len_index_tensor(batch) #8 * 6 + batch = torch.from_numpy(batch).float().cuda() + #lengths = torch.from_numpy(lengths) + bs, pos_neg, time, roll = batch.shape + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + similarity = contras_model(batch) + consequence.append(similarity.cpu().detach().numpy()) + #print(consequence.shape) + consequence = np.array(consequence).reshape(-1) + #print(consequence.shape) + + if (i+mini_batch) < batch_input.shape[0]: + batch = batch_input[(i + mini_batch): ] + #lengths = contras_model.get_len_index_tensor(batch) #8 * 6 + batch = torch.from_numpy(batch).float().cuda() + #lengths = torch.from_numpy(lengths) + bs, pos_neg, time, roll = batch.shape + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + similarity = contras_model(batch).cpu().detach().numpy().reshape(-1) + consequence = np.concatenate((consequence, similarity)) + + if count < batch_size: + rest = rights[count:].reshape((1, -1, time, roll_size)) + batch = np.concatenate((np.repeat(left, rest.shape[0], axis=0), rest), axis=1) + #lengths = contras_model.get_len_index_tensor(batch) #8 * 6 + batch = torch.from_numpy(batch).float().cuda() + #lengths = torch.from_numpy(lengths) + bs, pos_neg, time, roll = batch.shape + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + similarity = contras_model(batch).cpu().detach().numpy().reshape(-1) + consequence = np.concatenate((consequence, similarity)) + #print(batch_size, consequence.shape) + if num_candidates == -1: + #argmax = np.argsort(consequence)[::-1] + return consequence#, argmax + else: + argmax = np.argsort(consequence)[::-1] + #result = [consequence[i] for i in argmax] + #print(result, argmax[:num_candidates]) + return consequence, argmax[:num_candidates] + + +def inference_edge_weights(contras_model, texture_model, length, last_length, melody_data, acc_data, chord_data, acc_pool): + if not length in acc_pool: + (mel, acc, chord) = find_by_length(melody_data, acc_data, chord_data, length) + acc_pool[length] = (mel, acc, chord) + if not last_length in acc_pool: + (mel, acc, chord) = find_by_length(melody_data, acc_data, chord_data, last_length) + acc_pool[last_length] = (mel, acc, chord) + + # melody_set = acc_pool[length][0] + acc_set = acc_pool[length][1] + #chord_set = acc_pool[length][2] + + edge_dict = [] + last_acc_set = acc_pool[last_length][1] + for item in tqdm(last_acc_set): + if len(item) < 32: + item = np.pad(item, ((32-len(item), 0), (0, 0))) + if acc_set.shape[1] < 32: + acc_set = np.pad(acc_set, ((0, 0), (0, 
32-acc_set.shape[1]), (0, 0))) + contras_values = contrastive_match(item[np.newaxis, -32:, :], acc_set[:, :32, :], texture_model, contras_model, -1) + edge_dict.append(contras_values) + return np.array(edge_dict) + + +data = np.load('checkpoints/phrase_data.npz', allow_pickle=True) +melody = data['melody'] +acc = data['acc'] +chord = data['chord'] + +texture_model = TextureEncoder(emb_size=256, hidden_dim=1024, z_dim=256, num_channel=10, for_contrastive=True) +checkpoint = torch.load("checkpoints/texture_model_params049.pt") +texture_model.load_state_dict(checkpoint) +texture_model.cuda() + +contras_model = contrastive_model(emb_size=256, hidden_dim=1024) +contras_model.load_state_dict(torch.load('checkpoints/contrastive_model_params049.pt')) +contras_model.cuda() + +for l1 in range(1, 17): + for l2 in range(1, 17): + length = l2 + last_length = l1 + edge_weights = inference_edge_weights(contras_model, texture_model, length, last_length, melody, acc, chord, {}) + + if not os.path.exists('./tmp'): + os.makedirs('./tmp') + np.savez_compressed('./tmp/edge_weights' + '_' + str(last_length) + '_' + str(length) + '.npz', edge_weights) diff --git a/piano_arranger/scripts/transition_model_data_loader.py b/piano_arranger/scripts/transition_model_data_loader.py new file mode 100644 index 0000000..da17170 --- /dev/null +++ b/piano_arranger/scripts/transition_model_data_loader.py @@ -0,0 +1,166 @@ +import numpy as np +from tqdm import tqdm +import random +import platform + +class dataset(object): + def __init__(self, data_path='./', batch_size=8, time_res=32): + song_data = np.load(data_path, allow_pickle=True)['acc'] + self.batch_size = batch_size + self.time_res = time_res + + self.train_pairs, self.val_pairs, self.snippet_pool = self.find_all_pairs_and_snippet_pool(song_data) + + self.train_batch = None + self.val_batch = None + self.train_batch_anchor = None + self.val_batch_anchor = None + self.num_epoch = -1 + + def song_split(self, matrix): + """matrix must be quantizded in sixteenth note""" + window_size = self.time_res #two bars + hop_size = self.time_res #one bar + vector_size = matrix.shape[1] + start_downbeat = 0 + end_downbeat = matrix.shape[0]//16 + assert(end_downbeat - start_downbeat >= 2) + splittedMatrix = [] + for idx_T in range(start_downbeat*16, (end_downbeat-1)*16, hop_size): + sample = matrix[idx_T:idx_T+window_size, :] + if np.sum(sample) == 0: + continue #skip possible blank regions at the head and tail of each song + splittedMatrix.append(sample) + return np.array(splittedMatrix) + + def find_all_pairs_and_snippet_pool(self, song_data): + np.random.seed(0) + np.random.shuffle(song_data) + train_data = song_data[: int(len(song_data)*0.95)] + val_data = song_data[int(len(song_data)*0.95): ] + train_pairs = [] + val_pairs = [] + snippet_pool = [] + for song in tqdm(train_data): + splittedMatrix = self.song_split(song) + for i in range(splittedMatrix.shape[0] - 1): + train_pairs.append([splittedMatrix[i], splittedMatrix[i+1]]) + snippet_pool.append(splittedMatrix[i]) + snippet_pool.append(splittedMatrix[-1]) + for song in tqdm(val_data): + splittedMatrix = self.song_split(song) + for i in range(splittedMatrix.shape[0] - 1): + val_pairs.append([splittedMatrix[i], splittedMatrix[i+1]]) + snippet_pool.append(splittedMatrix[i]) + snippet_pool.append(splittedMatrix[-1]) + return train_pairs, val_pairs, snippet_pool + + def make_batch(self, batch_size): + print('shuffle dataset') + random.shuffle(self.train_pairs) + random.shuffle(self.snippet_pool) + + self.train_batch = [] + 
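# Each batch assembled below has shape (batch_size, 6, 32, 128): slot 0 is a 2-bar
# accompaniment snippet, slot 1 is its true successor (the positive pair), and slots
# 2-5 are 4 snippets drawn at random from snippet_pool (the negatives).  contrastive_model
# then scores slot 0 against slots 1-5 by cosine similarity over TextureEncoder features,
# and its loss 1 - mean(softmax(similarity)[:, 0]) pushes the true successor to win.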
self.val_batch = [] + self.train_batch_anchor = 0 + self.val_batch_anchor = 0 + self.num_epoch += 1 + + for i in tqdm(range(0, len(self.train_pairs)-batch_size, batch_size)): + batch_pair = np.array(self.train_pairs[i: i+batch_size]) + random_items = np.array(random.sample(self.snippet_pool, batch_size*4)).reshape((batch_size, 4, 32, 128)) + one_batch = np.concatenate((batch_pair, random_items), axis=1) + #one_batch: batch_size * 6 * 32 * 128 + self.train_batch.append(one_batch) + if i + batch_size < len(self.train_pairs): + rest = len(self.train_pairs) - (i + batch_size) + batch_pair = np.array(self.train_pairs[-rest:]) + random_items = np.array(random.sample(self.snippet_pool, rest*4)).reshape((rest, 4, 32, 128)) + one_batch = np.concatenate((batch_pair, random_items), axis=1) + self.train_batch.append(one_batch) + + for i in tqdm(range(0, len(self.val_pairs)-batch_size, batch_size)): + batch_pair = np.array(self.val_pairs[i: i+batch_size]) + random_items = np.array(random.sample(self.snippet_pool, batch_size*4)).reshape((batch_size, 4, 32, 128)) + one_batch = np.concatenate((batch_pair, random_items), axis=1) + self.val_batch.append(one_batch) + if i + batch_size < len(self.val_pairs): + rest = len(self.val_pairs) - (i + batch_size) + batch_pair = np.array(self.val_pairs[-rest:]) + random_items = np.array(random.sample(self.snippet_pool, rest*4)).reshape((rest, 4, 32, 128)) + one_batch = np.concatenate((batch_pair, random_items), axis=1) + self.val_batch.append(one_batch) + print('num_epoch:', self.num_epoch) + print('shuffle finished') + print('size of train_batch:', len(self.train_batch)) + print('size of val_batch:', len(self.val_batch)) + + def get_batch(self, stage='train'): + if stage == 'train': + idx = self.train_batch_anchor + self.train_batch_anchor += 1 + if self.train_batch_anchor == len(self.train_batch): + self.make_batch(self.batch_size) + return self.train_batch[idx] + elif stage == 'val': + idx = self.val_batch_anchor + self.val_batch_anchor += 1 + if self.val_batch_anchor == len(self.val_batch): + self.val_batch_anchor = 0 + return self.val_batch[idx] + + def get_batch_volumn(self, stage='train'): + if stage == 'train': + return len(self.train_batch) + elif stage == 'val': + return len(self.val_batch) + + def get_epoch(self): + return self.num_epoch + + +if __name__ == '__main__': + import torch + torch.cuda.current_device() + import sys + sys.path.append('AccoMontage') + from models import contrastive_model, TextureEncoder + + data_Set = dataset('checkpoints/song_data.npz', 1, 32) + data_Set.make_batch(1) + init_epoch = 0 + + texture_model = TextureEncoder(emb_size=256, hidden_dim=1024, z_dim=256, num_channel=10, for_contrastive=True) + checkpoint = torch.load("checkpoints/texture_model_params049.pt") + texture_model.load_state_dict(checkpoint) + texture_model.cuda() + texture_model.eval() + + contras_model = contrastive_model(emb_size=256, hidden_dim=1024) + contras_model.load_state_dict(torch.load('checkpoints/contrastive_model_params049.pt')) + contras_model.cuda() + contras_model.eval() + + """ while data_Set.get_epoch() <= 3: + print(data_Set.get_epoch()) + batch = data_Set.get_batch('train') + print('train', batch.shape) + if data_Set.train_batch_anchor == len(data_Set.train_batch): + #validate + for i in range(data_Set.get_batch_volumn('val')): + batch = data_Set.get_batch('val') + print('validating', batch.shape)""" + #print(data_Set.get_epoch()) + record = [] + for i in range(data_Set.get_batch_volumn('val')): + batch = data_Set.get_batch('val') + 
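# Shape flow for the scoring below: batch (bs, 6, 32, 128) is flattened to (bs*6, 32, 128)
# for TextureEncoder (built with for_contrastive=True, so it returns (mu, flattened GRU
# state of size 2*hidden_dim)); the second output is reshaped back to (bs, 6, 2048) and
# fed to contrastive_model, which returns a (bs, 5) similarity matrix.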
#print(batch.shape) + batch = torch.from_numpy(batch).cuda().float() + bs, pos_neg, time, roll = batch.shape + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + similarity = contras_model(batch) + model_loss = contras_model.contrastive_loss(similarity).cpu().detach().numpy() + record.append(model_loss) + record.sort() + print(record[-100:]) \ No newline at end of file diff --git a/piano_arranger/scripts/transition_model_train_contrastive.py b/piano_arranger/scripts/transition_model_train_contrastive.py new file mode 100644 index 0000000..9cba34e --- /dev/null +++ b/piano_arranger/scripts/transition_model_train_contrastive.py @@ -0,0 +1,165 @@ + +import os +import time +import torch +from torch.utils.tensorboard import SummaryWriter +from torch.optim.lr_scheduler import ExponentialLR +from torch import optim + +from transition_model_data_loader import dataset + +import sys +sys.path.append('AccoMontage') +from models import contrastive_model, TextureEncoder + +args={ + "batch_size": 8, + "data_path": "checkpoints/song_data.npz", + 'weight_path': "checkpoints/model_master_final.pt", + "embed_size": 256, + "hidden_dim": 1024, + "time_step": 32, + "n_epochs": 100, + "lr": 1e-3, + "decay": 0.99991, + "log_save": "demo/demo_generate/log", +} +# contrastive optimizer stabalizes at around 10 epochs + +class MinExponentialLR(ExponentialLR): + def __init__(self, optimizer, gamma, minimum, last_epoch=-1): + self.min = minimum + super(MinExponentialLR, self).__init__(optimizer, gamma, last_epoch=last_epoch) + + def get_lr(self): + return [ + max(base_lr * self.gamma**self.last_epoch, self.min) + for base_lr in self.base_lrs + ] + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def train(contrastive_model, texture_model, dataset, optimizer, scheduler, loss_recorder, writer): + batch = dataset.get_batch('train') #8 * 6 * 8 * 32 * 128 + batch = torch.from_numpy(batch).float().cuda() + bs, pos_neg, time, roll = batch.shape + optimizer_2.zero_grad() + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + + + optimizer.zero_grad() + similarity = contrastive_model(batch) + model_loss = contrastive_model.contrastive_loss(similarity) + model_loss.backward() + torch.nn.utils.clip_grad_norm_(contrastive_model.parameters(), 1) + optimizer_2.step() + optimizer.step() + loss_recorder.update(model_loss.item()) + scheduler_2.step() + scheduler.step() + + n_epoch = dataset.get_epoch() + total_batch = dataset.get_batch_volumn('train') + current_batch = dataset.train_batch_anchor + step = current_batch + n_epoch * total_batch + + print('---------------------------Training VAE----------------------------') + for param in optimizer.param_groups: + print('lr1: ', param['lr']) + print('Epoch: [{0}][{1}/{2}]'.format(n_epoch, current_batch, total_batch)) + print('loss: {loss:.5f}'.format(loss=model_loss.item())) + writer.add_scalar('train_vae/1-loss_total-epoch', loss_recorder.avg, step) + writer.add_scalar('train_vae/5-learning-rate', param['lr'], step) + +def val(contrastive_model, texture_model, dataset, writer, val_loss_recoder): + loss = val_loss_recoder + step = 1 + for i in range(dataset.get_batch_volumn('val')): + batch = dataset.get_batch('val') + batch = 
torch.from_numpy(batch).float().cuda() + bs, pos_neg, time, roll = batch.shape + _, batch = texture_model(batch.view(-1, time, roll)) + batch = batch.view(bs, pos_neg, -1) + with torch.no_grad(): + similarity = contrastive_model(batch) + model_loss = contrastive_model.contrastive_loss(similarity) + loss.update(model_loss.item()) + n_epoch = dataset.get_epoch() + total_batch = dataset.get_batch_volumn('val') + print('----validation----') + print('Epoch: [{0}][{1}/{2}]'.format(n_epoch, step, total_batch)) + print('loss: {loss:.5f}'.format(loss=model_loss.item())) + step += 1 + writer.add_scalar('val/loss_total-epoch', loss.avg, n_epoch) + +embed_size = args["embed_size"] +hidden_dim = args["hidden_dim"] +weight_path = args["weight_path"] + +contrastive_model = contrastive_model(emb_size=embed_size, hidden_dim=hidden_dim).cuda() + +texture_model = TextureEncoder(emb_size=256, hidden_dim=1024, z_dim=256, num_channel=10, for_contrastive=True) +checkpoint = torch.load(weight_path) +from collections import OrderedDict +rhy_checkpoint = OrderedDict() +for k, v in checkpoint.items(): + part = k.split('.')[0] + name = '.'.join(k.split('.')[1:]) + if part == 'rhy_encoder': + rhy_checkpoint[name] = v +texture_model.load_state_dict(rhy_checkpoint) +texture_model.cuda() + +run_time = time.asctime(time.localtime(time.time())).replace(':', '-') +logdir = 'log/' + run_time[4:] +save_dir = 'params/' + run_time[4:] +logdir = os.path.join(args["log_save"], logdir) +save_dir = os.path.join(args["log_save"], save_dir) +batch_size = args['batch_size'] +if not os.path.exists(logdir): + os.makedirs(logdir) +if not os.path.exists(save_dir): + os.makedirs(save_dir) + +writer = SummaryWriter(logdir) +training_loss_recoder = AverageMeter() +val_loss_recoder = AverageMeter() +dataset = dataset(args['data_path'], batch_size, args['time_step']) +dataset.make_batch(batch_size) + +optimizer = optim.Adam(contrastive_model.parameters(), lr=args['lr']) +optimizer_2 = optim.Adam(texture_model.parameters(), lr=1e-4) +scheduler = MinExponentialLR(optimizer, gamma=args['decay'], minimum=1e-5,) +scheduler_2 = MinExponentialLR(optimizer_2, gamma=0.999995, minimum=5e-6,) + +while dataset.get_epoch() < args['n_epochs']: + if dataset.train_batch_anchor == 0: + contrastive_model.eval() + val(contrastive_model, texture_model, dataset, writer, val_loss_recoder) + if (dataset.get_epoch()) % 1 == 0: + checkpoint = save_dir + '/contrastive_model_params' + str(dataset.get_epoch()).zfill(3) + '.pt' + torch.save(contrastive_model.cpu().state_dict(), checkpoint) + contrastive_model.cuda() + checkpoint = save_dir + '/texture_model_params' + str(dataset.get_epoch()).zfill(3) + '.pt' + torch.save(texture_model.cpu().state_dict(), checkpoint) + texture_model.cuda() + print('Model saved!') + contrastive_model.train() + train(contrastive_model, texture_model, dataset, optimizer, scheduler, training_loss_recoder, writer) + \ No newline at end of file
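As far as these scripts show, the transition-model pipeline runs in this order: build_phrase_data.py writes phrase_data.npz; transition_model_train_contrastive.py reads checkpoints/song_data.npz, initialises TextureEncoder from the rhy_encoder weights of checkpoints/model_master_final.pt, trains it jointly with contrastive_model, and saves *_paramsNNN.pt checkpoints; edge_weights_inference.py then loads those checkpoints together with phrase_data.npz and writes ./tmp/edge_weights_<last_length>_<length>.npz for every phrase-length pair from 1 to 16. A minimal sketch of re-loading the trained pair (checkpoint paths as hard-coded in edge_weights_inference.py; the `models` import assumes the same sys.path setup as these scripts):

import torch
from models import TextureEncoder, contrastive_model

texture_model = TextureEncoder(emb_size=256, hidden_dim=1024, z_dim=256,
                               num_channel=10, for_contrastive=True)
texture_model.load_state_dict(torch.load('checkpoints/texture_model_params049.pt'))
contras_model = contrastive_model(emb_size=256, hidden_dim=1024)
contras_model.load_state_dict(torch.load('checkpoints/contrastive_model_params049.pt'))
texture_model.cuda().eval()
contras_model.cuda().eval()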