Initial commit
zhaojw1998 committed Oct 23, 2023
0 parents commit 9814119
Showing 125 changed files with 12,394 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Zhao Jingwei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 2 additions & 0 deletions README.md
@@ -0,0 +1,2 @@
# AccoMontage-3

192 changes: 192 additions & 0 deletions arrangement_utils.py
@@ -0,0 +1,192 @@
import os
import pretty_midi as pyd
import numpy as np
import torch
from torch.utils.data import DataLoader
from scipy.interpolate import interp1d
from tqdm import tqdm

from piano_arranger.acc_utils import split_phrases
import piano_arranger.format_converter as cvt
from piano_arranger.models import DisentangleVAE
from piano_arranger.AccoMontage import find_by_length, dp_search, re_harmonization, get_texture_filter, ref_spotlight

from orchestrator import Slakh_Dataset, collate_fn, compute_pr_feat, EMBED_PROGRAM_MAPPING, Prior
from orchestrator.dataset import SLAKH_CLASS_PROGRAMS
from orchestrator.utils import grid2pr, pr2grid, matrix2midi, midi2matrix


SLAKH_CLASS_MAPPING = {v: k for k, v in EMBED_PROGRAM_MAPPING.items()}
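# Value ranges used to build the total-length and absolute/relative position conditions consumed by orchestration() below.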
TOTAL_LEN_BIN = np.array([4, 7, 12, 15, 20, 23, 28, 31, 36, 39, 44, 47, 52, 55, 60, 63, 68, 71, 76, 79, 84, 87, 92, 95, 100, 103, 108, 111, 116, 119, 124, 127, 132])
ABS_POS_BIN = np.arange(129)
REL_POS_BIN = np.arange(128)


def load_premise(DATA_FILE_ROOT, DEVICE):
    """Load AccoMontage Search Space"""
    print('Loading AccoMontage piano texture search space. This may take 1 or 2 minutes ...')
    data = np.load(os.path.join(DATA_FILE_ROOT, 'phrase_data.npz'), allow_pickle=True)
    melody = data['melody']
    acc = data['acc']
    chord = data['chord']
    vel = data['velocity']
    cc = data['cc']
    acc_pool = {}
    for LEN in tqdm(range(2, 11)):
        (mel, acc_, chord_, vel_, cc_, song_reference) = find_by_length(melody, acc, chord, vel, cc, LEN)
        acc_pool[LEN] = (mel, acc_, chord_, vel_, cc_, song_reference)
    texture_filter = get_texture_filter(acc_pool)
    edge_weights = np.load(os.path.join(DATA_FILE_ROOT, 'edge_weights.npz'), allow_pickle=True)

    """Load Q&A Prompt Search Space"""
    print('loading orchestration prompt search space ...')
    slakh_dir = os.path.join(DATA_FILE_ROOT, 'Slakh2100_inference_set')
    dataset = Slakh_Dataset(slakh_dir, debug_mode=False, split='test', mode='train')
    loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda b: collate_fn(b, DEVICE, get_pr_gt=True))
    REF_PR = []
    REF_P = []
    REF_T = []
    REF_PROG = []
    FLTN = []
    for (pr, _, prog, func_pitch, func_time, _, _, _, _, _) in tqdm(loader):
        pr = pr[0]
        prog = prog[0, :]
        func_pitch = func_pitch[0, :]
        func_time = func_time[0, :]
        fltn = torch.cat([torch.sum(func_pitch, dim=-2), torch.sum(func_time, dim=-2)], dim=-1)
        REF_PR.append(pr)
        REF_P.append(func_pitch[0])
        REF_T.append(func_time[0])
        REF_PROG.append(prog)
        FLTN.append(fltn[0])
    FLTN = torch.stack(FLTN, dim=0)

    """Initialize orchestration model (Prior + Q&A)"""
    print('Initialize model ...')
    prior_model_path = os.path.join(DATA_FILE_ROOT, 'params_prior.pt')
    QaA_model_path = os.path.join(DATA_FILE_ROOT, 'params_qa.pt')
    orchestrator = Prior.init_inference_model(prior_model_path, QaA_model_path, DEVICE=DEVICE)
    orchestrator.to(DEVICE)
    orchestrator.eval()
    piano_arranger = DisentangleVAE.init_model(torch.device('cuda')).cuda()
    piano_arranger.load_state_dict(torch.load(os.path.join(DATA_FILE_ROOT, 'params_reharmonizer.pt')))
    print('Finished.')
    return piano_arranger, orchestrator, (acc_pool, edge_weights, texture_filter), (REF_PR, REF_P, REF_T, REF_PROG, FLTN)


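# read_lead_sheet: parse 'lead sheet.mid' into melody/chord rolls, trim the pick-up measure (NOTE_SHIFT beats), pad to whole bars, and split into phrase-level queries according to SEGMENTATION.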
def read_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track_ID=0):
    melody_roll, chord_roll = cvt.leadsheet2matrix(os.path.join(DEMO_ROOT, SONG_NAME, 'lead sheet.mid'), melody_track_ID)
    assert len(melody_roll) == len(chord_roll)
    if NOTE_SHIFT != 0:
        melody_roll = melody_roll[int(NOTE_SHIFT*4):, :]
        chord_roll = chord_roll[int(NOTE_SHIFT*4):, :]
    if len(melody_roll) % 16 != 0:
        pad_len = (len(melody_roll)//16+1)*16 - len(melody_roll)
        melody_roll = np.pad(melody_roll, ((0, pad_len), (0, 0)))
        melody_roll[-pad_len:, -1] = 1
        chord_roll = np.pad(chord_roll, ((0, pad_len), (0, 0)))
        chord_roll[-pad_len:, 0] = -1
        chord_roll[-pad_len:, -1] = -1

    CHORD_TABLE = np.stack([cvt.expand_chord(chord) for chord in chord_roll[::4]], axis=0)
    LEADSHEET = np.concatenate((melody_roll, chord_roll[:, 1: -1]), axis=-1)  # T*142, quantized at 16th notes
    query_phrases = split_phrases(SEGMENTATION)  # e.g., [('A', 8, 0), ('A', 8, 8), ('B', 8, 16), ('B', 8, 24)]

    midi_len = len(LEADSHEET)//16
    anno_len = sum([item[1] for item in query_phrases])
    if midi_len > anno_len:
        LEADSHEET = LEADSHEET[: anno_len*16]
        CHORD_TABLE = CHORD_TABLE[: anno_len*4]
        print(f'Mismatch warning: detected {midi_len} bars in the lead sheet (MIDI) but {anno_len} bars in the provided phrase annotation. The lead sheet is truncated to {anno_len} bars.')
    elif midi_len < anno_len:
        pad_len = (anno_len - midi_len)*16
        LEADSHEET = np.pad(LEADSHEET, ((0, pad_len), (0, 0)))
        LEADSHEET[-pad_len:, 129] = 1
        CHORD_TABLE = np.pad(CHORD_TABLE, ((0, pad_len//4), (0, 0)))
        CHORD_TABLE[-pad_len//4:, 11] = -1
        CHORD_TABLE[-pad_len//4:, -1] = -1
        print(f'Mismatch warning: detected {midi_len} bars in the lead sheet (MIDI) but {anno_len} bars in the provided phrase annotation. The lead sheet is padded to {anno_len} bars.')

    melody_queries = []
    for item in query_phrases:
        start_bar = item[-1]
        length = item[-2]
        segment = LEADSHEET[start_bar*16: (start_bar+length)*16]
        melody_queries.append(segment)  # melody queries: list of T16*142 matrices, segmented by phrases

    return (LEADSHEET, CHORD_TABLE, melody_queries, query_phrases)


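# piano_arrangement: select phrase-level piano-texture references via dynamic-programming search (dp_search), then re-harmonize them to the lead sheet's chords with the piano arranger (DisentangleVAE).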
def piano_arrangement(pianoRoll, chord_table, melody_queries, query_phrases, acc_pool, edge_weights, texture_filter, piano_arranger, PREFILTER):
    print('Phrasal Unit selection begins:\n\t', f'{len(query_phrases)} phrases in the lead sheet;\n\t', f'set note density filter: {PREFILTER}.')
    phrase_indice, chord_shift = dp_search(melody_queries,
                                           query_phrases,
                                           acc_pool,
                                           edge_weights,
                                           texture_filter,
                                           filter_id=PREFILTER)
    path = phrase_indice[0]
    shift = chord_shift[0]
    print('Re-harmonization begins ...')
    midi_recon, acc = re_harmonization(pianoRoll, chord_table, query_phrases, path, shift, acc_pool, model=piano_arranger, get_est=True, tempo=100)
    acc = np.array([grid2pr(matrix) for matrix in acc])
    print('Piano accompaniment generated!')
    return midi_recon, acc


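# prompt_sampling: retrieve a reference multi-track arrangement from the Slakh prompt set by cosine similarity (with added Gaussian noise) between the piano accompaniment's features and the reference features; its track programs and pitch/time functions serve as the prompt for the prior model.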
def prompt_sampling(acc_piano, REF_PR, REF_P, REF_T, REF_PROG, FLTN, DEVICE='cuda:0'):
    fltn = torch.from_numpy(np.concatenate(compute_pr_feat(acc_piano[0:1]), axis=-1)).to(DEVICE)
    sim_func = torch.nn.CosineSimilarity(dim=-1)
    distance = sim_func(fltn, FLTN)
    distance = distance + torch.normal(mean=torch.zeros(distance.shape), std=0.2*torch.ones(distance.shape)).to(distance.device)
    sim_values, anchor_points = torch.sort(distance, descending=True)
    IDX = 0
    sim_value = sim_values[IDX]
    anchor_point = anchor_points[IDX]
    func_pitch = REF_P[anchor_point]
    func_time = REF_T[anchor_point]
    prog = REF_PROG[anchor_point]
    prog_class = [SLAKH_CLASS_MAPPING[item.item()] for item in prog.cpu().detach().numpy()]
    program_name = [SLAKH_CLASS_PROGRAMS[item] for item in prog_class]
    refr = REF_PR[anchor_point]
    midi_ref = matrix2midi(
        pr_matrices=refr.detach().cpu().numpy(),
        programs=prog_class,
        init_tempo=100)
    print(f'Prior model initialized with {len(program_name)} tracks:\n\t{program_name}')
    return midi_ref, (prog, func_pitch, func_time)

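# orchestration: encode the piano accompaniment as a note grid, build the total-length and position conditions from the bins above, autoregressively sample the multi-track arrangement with nucleus sampling, and decode it back to MIDI.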
def orchestration(acc_piano, chord_track, prog, func_pitch, func_time, orchestrator, DEVICE='cuda:0', blur=.5, p=.1, t=4):
    print('Orchestration begins ...')
    if chord_track is not None:
        if len(acc_piano) > len(chord_track):
            chord_track = np.pad(chord_track, ((0, 0), (len(acc_piano)-len(chord_track))))
        else:
            chord_track = chord_track[:len(acc_piano)]
        acc_piano = np.max(np.stack([acc_piano, chord_track], axis=0), axis=0)

    mix = torch.from_numpy(np.array([pr2grid(matrix, max_note_count=32) for matrix in acc_piano])).to(DEVICE)
    r_pos = np.round(np.arange(0, len(mix), 1) / (len(mix)-1) * len(REL_POS_BIN))
    total_len = np.argmin(np.abs(TOTAL_LEN_BIN - len(mix))).repeat(len(mix))
    a_pos = np.append(ABS_POS_BIN[0: min(ABS_POS_BIN[-1], len(mix))], [ABS_POS_BIN[-1]] * (len(mix)-ABS_POS_BIN[-1]))
    r_pos = torch.from_numpy(r_pos).long().to(DEVICE)
    a_pos = torch.from_numpy(a_pos).long().to(DEVICE)
    total_len = torch.from_numpy(total_len).long().to(DEVICE)

    if func_pitch is not None:
        func_pitch = func_pitch.unsqueeze(0).unsqueeze(0)
    if func_time is not None:
        func_time = func_time.unsqueeze(0).unsqueeze(0)
    recon_pitch, recon_dur = orchestrator.run_autoregressive_nucleus(mix.unsqueeze(0), prog.unsqueeze(0), func_pitch, func_time, total_len.unsqueeze(0), a_pos.unsqueeze(0), r_pos.unsqueeze(0), blur, p, t)

    grid_recon = torch.cat([recon_pitch.max(-1)[-1].unsqueeze(-1), recon_dur.max(-1)[-1]], dim=-1)
    batch, track, _, max_simu_note, grid_dim = grid_recon.shape
    grid_recon = grid_recon.permute(1, 0, 2, 3, 4)
    grid_recon = grid_recon.reshape(track, -1, max_simu_note, grid_dim)

    pr_recon_ = np.array([grid2pr(matrix) for matrix in grid_recon.detach().cpu().numpy()])
    pr_recon = matrix2midi(pr_recon_, [SLAKH_CLASS_MAPPING[item.item()] for item in prog.cpu().detach().numpy()], 100)
    print('Full-band accompaniment generated!')
    return pr_recon

5 changes: 5 additions & 0 deletions data_file_dir.txt
@@ -0,0 +1,5 @@
If you wish to run our code,
please download our pre-trained checkpoints
and processed data at the following link:

https://we.tl/t-cc2nOC8dAA
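
Based on arrangement_utils.py and inference.ipynb, the decompressed folder
(DATA_FILE_ROOT, i.e., ./data_file_dir/) should contain:
phrase_data.npz, edge_weights.npz, params_prior.pt, params_qa.pt,
params_reharmonizer.pt, and the Slakh2100_inference_set folder.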
Binary file added demo/Castles in the Air/arrangement_band.mid
Binary file not shown.
Binary file added demo/Castles in the Air/arrangement_piano.mid
Binary file not shown.
Binary file added demo/Castles in the Air/lead sheet.mid
Binary file not shown.
Binary file added demo/Jingle Bells/arrangement_band.mid
Binary file not shown.
Binary file added demo/Jingle Bells/arrangement_piano.mid
Binary file not shown.
Binary file added demo/Jingle Bells/lead sheet.mid
Binary file not shown.
Binary file added demo/Sally Garden/arrangement_band.mid
Binary file not shown.
Binary file added demo/Sally Garden/arrangement_piano.mid
Binary file not shown.
Binary file added demo/Sally Garden/lead sheet.mid
Binary file not shown.
170 changes: 170 additions & 0 deletions inference.ipynb
@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']= '0'\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
"os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n",
"import numpy as np\n",
"import pretty_midi as pyd\n",
"from arrangement_utils import *\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"\"\"\"Download checkpoints and data at https://we.tl/t-cc2nOC8dAA (379MB) and decompress at the directory\"\"\"\n",
"DATA_FILE_ROOT = './data_file_dir/'\n",
"DEVICE = 'cuda:0'\n",
"piano_arranger, orchestrator, piano_texture, band_prompt = load_premise(DATA_FILE_ROOT, DEVICE)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize input and preference\n",
"We provide three sample lead sheets for a quick inference. You should be able to directly run the code blocks after downloading the pre-trained checkpoints.\n",
"\n",
"If you wish to test our model on your own lead sheet file, please initialize a sub-folder with its `SONG_NAME` in the `./demo` folder and put the file in, and name the file \"lead sheet.mid\". \n",
"\n",
"Please also specify `SEGMENTATION` (phrase structure) and `NOTE_SHIFT` (the duration of the pick-up measure if any)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Set input lead sheet\"\"\"\n",
"SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Castles in the Air', 'A8A8B8B8', 1 #1 beat in the pick-up measure\n",
"#SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Jingle Bells', 'A8B8A8', 0\n",
"#SONG_NAME, SEGMENTATION, NOTE_SHIFT = 'Sally Garden', 'A4A4B4A4', 0\n",
"\n",
"\"\"\"Set texture pre-filtering for piano arrangement (default random)\"\"\"\n",
"RHTHM_DENSITY = np.random.randint(3, 5)\n",
"VOICE_NUMBER = np.random.randint(3, 5)\n",
"PREFILTER = (RHTHM_DENSITY, VOICE_NUMBER)\n",
"\n",
"\"\"\"Set if use a 2-bar prompt for full-band arrangement (default False)\"\"\" \n",
"USE_PROMPT = False\n",
"\n",
"lead_sheet = read_lead_sheet('./demo', SONG_NAME, SEGMENTATION, NOTE_SHIFT)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Piano Accompaniment Arrangement"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Phrasal Unit selection begins:\n",
"\t 4 phrases in the lead sheet;\n",
"\t set note density filter: (4, 4).\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 3/3 [00:17<00:00, 5.82s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Re-harmonization begins ...\n",
"Piano accompaiment generated!\n"
]
}
],
"source": [
"midi_piano, acc_piano = piano_arrangement(*lead_sheet, *piano_texture, piano_arranger, PREFILTER)\n",
"midi_piano.write(f'./demo/{SONG_NAME}/arrangement_piano.mid')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Orchestration"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prior model initialized with 7 tracks:\n",
"\t['Acoustic Bass', 'Chromatic Percussion', 'Synth Pad', 'Acoustic Piano', 'Electric Piano', 'Acoustic Bass', 'Acoustic Guitar']\n",
"Orchestration begins ...\n",
"Full-band accompaiment generated!\n"
]
}
],
"source": [
"midi_prompt, func_prompt = prompt_sampling(acc_piano, *band_prompt, DEVICE)\n",
"if USE_PROMPT:\n",
" midi_band = orchestration(acc_piano, None, *func_prompt, orchestrator, DEVICE, blur=.5, p=.1, t=4)\n",
"else:\n",
" instruments, pitch_prompt, time_promt = func_prompt\n",
" midi_band = orchestration(acc_piano, None, instruments, None, None, orchestrator, DEVICE, blur=.5, p=.1, t=4)\n",
"mel_track = pyd.Instrument(program=72, is_drum=False, name='melody')\n",
"mel_track.notes = midi_piano.instruments[0].notes\n",
"midi_band.instruments.append(mel_track)\n",
"midi_band.write(f'./demo/{SONG_NAME}/arrangement_band.mid')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch1.10_conda11.3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}