From c0431872726e6eb191162761288c2dea0b29f9f7 Mon Sep 17 00:00:00 2001
From: skytnt
Date: Wed, 9 Oct 2024 21:58:39 +0800
Subject: [PATCH] Update app_onnx.py

---
 app_onnx.py | 55 ++++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/app_onnx.py b/app_onnx.py
index 681b7f0..f1230e7 100644
--- a/app_onnx.py
+++ b/app_onnx.py
@@ -222,10 +222,12 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
     if current_model != model_name:
         gr.Info("Loading model...")
         model_info = models_info[model_name]
-        model_config = model_info[0]
+        model_config, model_config_url = model_info[0]
         model_base_path, model_base_url = model_info[1]
         model_token_path, model_token_url = model_info[2]
         try:
+            if model_config.endswith(".json"):
+                download_if_not_exit(model_config_url, model_config)
             download_if_not_exit(model_base_url, model_base_path)
             download_if_not_exit(model_token_url, model_token_path)
         except Exception as e:
@@ -406,6 +408,7 @@ def undo_continuation(mid_seq, continuation_state):
 
 
 def download(url, output_file):
+    print(f"Downloading {output_file} from {url}")
     response = requests.get(url, stream=True)
     file_size = int(response.headers.get("Content-Length", 0))
     with tqdm.tqdm(total=file_size, unit="B", unit_scale=True, unit_divisor=1024,
@@ -448,16 +451,22 @@ def template_response(*args, **kwargs):
     gr.routes.templates.TemplateResponse = template_response
 
 
-def get_tokenizer(config_name):
-    tv, size = config_name.split("-")
-    tv = tv[1:]
-    if tv[-1] == "o":
-        o = True
-        tv = tv[:-1]
+def get_tokenizer(config_name_or_path):
+    if config_name_or_path.endswith(".json"):
+        with open(config_name_or_path, "r") as f:
+            config = json.load(f)
+        tv = config["tokenizer"]["version"]
+        o = config["tokenizer"]["optimise_midi"]
     else:
-        o = False
-    if tv not in ["v1", "v2"]:
-        raise ValueError(f"Unknown tokenizer version {tv}")
+        tv, size = config_name_or_path.split("-")
+        tv = tv[1:]
+        if tv[-1] == "o":
+            o = True
+            tv = tv[:-1]
+        else:
+            o = False
+        if tv not in ["v1", "v2"]:
+            raise ValueError(f"Unknown tokenizer version {tv}")
     tokenizer = MIDITokenizer(tv)
     tokenizer.set_optimise_midi(o)
     return tokenizer
@@ -492,7 +501,9 @@ def check_update(current_ver):
     parser.add_argument("--batch", type=int, default=8, help="batch size")
     parser.add_argument("--max-gen", type=int, default=4096, help="max")
    parser.add_argument("--soundfont-path", type=str, default="soundfont.sf2", help="soundfont")
-    parser.add_argument("--model-config", type=str, default="tv2o-medium", help="model config name")
+    parser.add_argument("--model-config", type=str,
+                        default="models/default/config.json",
+                        help="model config name or path")
     parser.add_argument("--model-base-path", type=str,
                         default="models/default/model_base.onnx", help="model path")
     parser.add_argument("--model-token-path", type=str,
@@ -500,6 +511,9 @@
     parser.add_argument("--soundfont-url", type=str,
                         default="https://huggingface.co/skytnt/midi-model/resolve/main/soundfont.sf2",
                         help="download soundfont to soundfont-path if file not exist")
+    parser.add_argument("--model-config-url", type=str,
+                        default="https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json",
+                        help="download config.json to model-config if file not exist")
     parser.add_argument("--model-base-url", type=str,
                         default="https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/onnx/model_base.onnx",
                         help="download model-base to model-base-path if file not exist")
@@ -511,40 +525,45 @@ def check_update(current_ver):
     OUTPUT_BATCH_SIZE = opt.batch
     models_info = {
         "generic pretrain model (tv2o-medium) by skytnt (default)": [
-            opt.model_config,
+            [opt.model_config, opt.model_config_url],
             [opt.model_base_path, opt.model_base_url],
             [opt.model_token_path, opt.model_token_url]
         ],
         "generic pretrain model (tv2o-medium) by skytnt with jpop lora": [
-            "tv2o-medium",
+            ["models/tv2om_skytnt_jpop_lora/config.json",
+             "https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json"],
             ["models/tv2om_skytnt_jpop_lora/model_base.onnx",
              "https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_base.onnx"],
             ["models/tv2om_skytnt_jpop_lora/model_token.onnx",
              "https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_token.onnx"]
         ],
         "generic pretrain model (tv2o-medium) by skytnt with touhou lora": [
-            "tv2o-medium",
+            ["models/tv2om_skytnt_touhou_lora/config.json",
+             "https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json"],
             ["models/tv2om_skytnt_touhou_lora/model_base.onnx",
              "https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_base.onnx"],
             ["models/tv2om_skytnt_touhou_lora/model_token.onnx",
              "https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_token.onnx"]
         ],
         "generic pretrain model (tv2o-large) by asigalov61": [
-            "tv2o-large",
+            ["models/tv2ol_asigalov61/config.json",
+             "https://huggingface.co/asigalov61/Music-Llama/resolve/main/config.json"],
             ["models/tv2ol_asigalov61/model_base.onnx",
              "https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_base.onnx"],
             ["models/tv2ol_asigalov61/model_token.onnx",
              "https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_token.onnx"]
         ],
         "generic pretrain model (tv2o-medium) by asigalov61": [
-            "tv2o-medium",
+            ["models/tv2om_asigalov61/config.json",
+             "https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/config.json"],
             ["models/tv2om_asigalov61/model_base.onnx",
              "https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_base.onnx"],
             ["models/tv2om_asigalov61/model_token.onnx",
              "https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_token.onnx"]
         ],
         "generic pretrain model (tv1-medium) by skytnt": [
-            "tv1-medium",
+            ["models/tv1m_skytnt/config.json",
+             "https://huggingface.co/skytnt/midi-model/resolve/main/config.json"],
             ["models/tv1m_skytnt/model_base.onnx",
              "https://huggingface.co/skytnt/midi-model/resolve/main/onnx/model_base.onnx"],
             ["models/tv1m_skytnt/model_token.onnx",
@@ -554,6 +573,8 @@ def check_update(current_ver):
     current_model = list(models_info.keys())[0]
     try:
         download_if_not_exit(opt.soundfont_url, opt.soundfont_path)
+        if opt.model_config.endswith(".json"):
+            download_if_not_exit(opt.model_config_url, opt.model_config)
         download_if_not_exit(opt.model_base_url, opt.model_base_path)
         download_if_not_exit(opt.model_token_url, opt.model_token_path)
     except Exception as e:
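
Note on the config.json this patch starts consuming: the diff only shows get_tokenizer reading config["tokenizer"]["version"] and config["tokenizer"]["optimise_midi"], so the sketch below is inferred from those two keys alone and uses the new --model-config default path; any additional fields in the config.json actually published with the models are not shown here.

# Minimal sketch (not part of the patch) of a config.json accepted by the
# new get_tokenizer(".json") branch. Only the two keys the diff reads are
# included; everything else is an assumption for illustration.
import json
import os

config = {
    "tokenizer": {
        "version": "v2",        # legacy names like "tv2o-medium" reduce to "v1"/"v2"
        "optimise_midi": True,  # passed to tokenizer.set_optimise_midi()
    }
}

os.makedirs("models/default", exist_ok=True)
with open("models/default/config.json", "w") as f:
    json.dump(config, f, indent=2)

# get_tokenizer("models/default/config.json") then takes the .json branch,
# while get_tokenizer("tv2o-medium") still goes through the old name parsing.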