Skip to content

Commit

Permalink
Update app_onnx.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SkyTNT committed Oct 9, 2024
1 parent ad1241a commit c043187
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions app_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,12 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
if current_model != model_name:
gr.Info("Loading model...")
model_info = models_info[model_name]
model_config = model_info[0]
model_config, model_config_url = model_info[0]
model_base_path, model_base_url = model_info[1]
model_token_path, model_token_url = model_info[2]
try:
if model_config.endswith(".json"):
download_if_not_exit(model_config_url, model_config)
download_if_not_exit(model_base_url, model_base_path)
download_if_not_exit(model_token_url, model_token_path)
except Exception as e:
Expand Down Expand Up @@ -406,6 +408,7 @@ def undo_continuation(mid_seq, continuation_state):


def download(url, output_file):
print(f"Downloading {output_file} from {url}")
response = requests.get(url, stream=True)
file_size = int(response.headers.get("Content-Length", 0))
with tqdm.tqdm(total=file_size, unit="B", unit_scale=True, unit_divisor=1024,
Expand Down Expand Up @@ -448,16 +451,22 @@ def template_response(*args, **kwargs):

gr.routes.templates.TemplateResponse = template_response

def get_tokenizer(config_name):
tv, size = config_name.split("-")
tv = tv[1:]
if tv[-1] == "o":
o = True
tv = tv[:-1]
def get_tokenizer(config_name_or_path):
if config_name_or_path.endswith(".json"):
with open(config_name_or_path, "r") as f:
config = json.load(f)
tv = config["tokenizer"]["version"]
o = config["tokenizer"]["optimise_midi"]
else:
o = False
if tv not in ["v1", "v2"]:
raise ValueError(f"Unknown tokenizer version {tv}")
tv, size = config_name_or_path.split("-")
tv = tv[1:]
if tv[-1] == "o":
o = True
tv = tv[:-1]
else:
o = False
if tv not in ["v1", "v2"]:
raise ValueError(f"Unknown tokenizer version {tv}")
tokenizer = MIDITokenizer(tv)
tokenizer.set_optimise_midi(o)
return tokenizer
Expand Down Expand Up @@ -492,14 +501,19 @@ def check_update(current_ver):
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--max-gen", type=int, default=4096, help="max")
parser.add_argument("--soundfont-path", type=str, default="soundfont.sf2", help="soundfont")
parser.add_argument("--model-config", type=str, default="tv2o-medium", help="model config name")
parser.add_argument("--model-config", type=str,
default="models/default/config.json",
help="model config name or path")
parser.add_argument("--model-base-path", type=str,
default="models/default/model_base.onnx", help="model path")
parser.add_argument("--model-token-path", type=str,
default="models/default/model_token.onnx", help="model path")
parser.add_argument("--soundfont-url", type=str,
default="https://huggingface.co/skytnt/midi-model/resolve/main/soundfont.sf2",
help="download soundfont to soundfont-path if file not exist")
parser.add_argument("--model-config-url", type=str,
default="https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json",
help="download config.json to model-config if file not exist")
parser.add_argument("--model-base-url", type=str,
default="https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/onnx/model_base.onnx",
help="download model-base to model-base-path if file not exist")
Expand All @@ -511,40 +525,45 @@ def check_update(current_ver):
OUTPUT_BATCH_SIZE = opt.batch
models_info = {
"generic pretrain model (tv2o-medium) by skytnt (default)": [
opt.model_config,
[opt.model_config, opt.model_config_url],
[opt.model_base_path, opt.model_base_url],
[opt.model_token_path, opt.model_token_url]
],
"generic pretrain model (tv2o-medium) by skytnt with jpop lora": [
"tv2o-medium",
["models/tv2om_skytnt_jpop_lora/config.json",
"https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json"],
["models/tv2om_skytnt_jpop_lora/model_base.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_base.onnx"],
["models/tv2om_skytnt_jpop_lora/model_token.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-jpop-lora/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-medium) by skytnt with touhou lora": [
"tv2o-medium",
["models/tv2om_skytnt_touhou_lora/config.json",
"https://huggingface.co/skytnt/midi-model-tv2o-medium/resolve/main/config.json"],
["models/tv2om_skytnt_touhou_lora/model_base.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_base.onnx"],
["models/tv2om_skytnt_touhou_lora/model_token.onnx",
"https://huggingface.co/skytnt/midi-model-tv2om-touhou-lora/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-large) by asigalov61": [
"tv2o-large",
["models/tv2ol_asigalov61/config.json",
"https://huggingface.co/asigalov61/Music-Llama/resolve/main/config.json"],
["models/tv2ol_asigalov61/model_base.onnx",
"https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_base.onnx"],
["models/tv2ol_asigalov61/model_token.onnx",
"https://huggingface.co/asigalov61/Music-Llama/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv2o-medium) by asigalov61": [
"tv2o-medium",
["models/tv2om_asigalov61/config.json",
"https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/config.json"],
["models/tv2om_asigalov61/model_base.onnx",
"https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_base.onnx"],
["models/tv2om_asigalov61/model_token.onnx",
"https://huggingface.co/asigalov61/Music-Llama-Medium/resolve/main/onnx/model_token.onnx"]
],
"generic pretrain model (tv1-medium) by skytnt": [
"tv1-medium",
["models/tv1m_skytnt/config.json",
"https://huggingface.co/skytnt/midi-model/resolve/main/config.json"],
["models/tv1m_skytnt/model_base.onnx",
"https://huggingface.co/skytnt/midi-model/resolve/main/onnx/model_base.onnx"],
["models/tv1m_skytnt/model_token.onnx",
Expand All @@ -554,6 +573,8 @@ def check_update(current_ver):
current_model = list(models_info.keys())[0]
try:
download_if_not_exit(opt.soundfont_url, opt.soundfont_path)
if opt.model_config.endswith(".json"):
download_if_not_exit(opt.model_config_url, opt.model_config)
download_if_not_exit(opt.model_base_url, opt.model_base_path)
download_if_not_exit(opt.model_token_url, opt.model_token_path)
except Exception as e:
Expand Down

0 comments on commit c043187

Please sign in to comment.