-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathinterview_modal.py
executable file
·64 lines (51 loc) · 2.31 KB
/
interview_modal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
import fire
import json
import sys
from jinja2 import Template
import subprocess
# Format: <gpu>-<memory>x<count>
# Examples: T4, A10Gx2, A100-40x4
def parse_gpu_string(gstr):
    """Translate a GPU spec string into a ``modal.gpu.*`` constructor expression.

    The spec has the form ``<gpu>[-<memory>][x<count>]``, for example ``T4``,
    ``A10Gx2`` or ``A100-40x4``. The returned text is Python source that
    ``main`` renders into the script template, so every piece is validated
    before it is emitted.

    Args:
        gstr: GPU spec string.

    Returns:
        A string such as ``"modal.gpu.A100(count=4, size='40')"``. ``count``
        defaults to 1 and ``size`` is omitted when no memory part is given.

    Raises:
        ValueError: if the GPU name is not a valid identifier or the count is
            not a positive integer.
    """
    expected = "expected <gpu>[-<memory>][x<count>]"

    spec, has_count, count_text = gstr.partition('x')
    count = 1
    if has_count:
        # Parse to int so that inputs like "x", "x02" or "x2x3" cannot leak
        # into the generated script as a malformed `count=` argument.
        try:
            count = int(count_text)
        except ValueError:
            raise ValueError(
                f"Invalid GPU count {count_text!r} in {gstr!r}; {expected}"
            ) from None
        if count < 1:
            raise ValueError(
                f"GPU count must be at least 1 in {gstr!r}; {expected}"
            )

    name, _, memory = spec.partition('-')
    # The name becomes an attribute access (modal.gpu.<name>), so it has to
    # be a plain identifier.
    if not name.isidentifier():
        raise ValueError(f"Invalid GPU name {name!r} in {gstr!r}; {expected}")

    request = f"modal.gpu.{name}(count={count}"
    if memory:
        # repr() yields the same 'value' quoting for ordinary strings while
        # safely escaping any quote or backslash characters.
        request += f", size={memory!r}"
    return request + ")"
def main(
    model: str,
    runtime: str,
    gpu: str = "A10G",
    input: str = "",
    interview: str = "senior",
    prompt: str = "",
    params: str = "",
    templateout: str = "",
    revision: str = "",
    info: str = "{}",
    quant: str = "fp16",
    context: int = 2048,
):
    """Render a Modal runner script for ``model`` and execute it via ``modal run``.

    The script is rendered from ``interview_modal.tpl.py`` (resolved relative
    to the current working directory) and written to
    ``modal_run_<slug>_<runtime>_<gpu>.py`` before being executed.

    Args:
        model: Model name; passed to the template as ``MODELNAME``. A slug
            with ``/``, ``_``, ``.`` and ``-`` normalised to ``_`` is passed
            as ``MODELSLUG`` and used in the output file name.
        runtime: Passed to the template as ``RUNTIME``. The value ``"vllm"``
            additionally appends ``--batch`` to the run arguments.
        gpu: GPU spec understood by ``parse_gpu_string``.
        input: Forwarded as ``--input`` when non-empty.
        interview: Forwarded as ``--interview`` when non-empty.
        prompt: Forwarded as ``--prompt`` when non-empty.
        params: Forwarded as ``--params`` when non-empty.
        templateout: Forwarded as ``--templateout`` when non-empty.
        revision: Stored under ``revision`` in the model arguments when set.
        info: Model info as a JSON string, or an already-parsed dict.
        quant: Stored as ``info['quant']``.
        context: Stored as ``info['context_size']``.

    Raises:
        Exception: if ``revision`` arrives as an int, or if both ``input``
            and ``interview`` are empty.
    """
    # Only decode strings; a non-string value is used as given. Copy it so
    # the caller's object is not mutated by the additions below.
    model_info = json.loads(info) if isinstance(info, str) else dict(info)

    # Reject a numeric revision before it is stored anywhere.
    if isinstance(revision, int):
        raise Exception("Please escape --revision with \\' to avoid Fire parsing issues.")

    model_info['quant'] = quant
    model_info['context_size'] = context
    model_args = {'info': model_info}
    if revision:
        model_args['revision'] = revision

    model_clean = model.replace('/', '-').replace('_', '-').replace('.', '-')
    model_clean_py = model_clean.replace('-', '_')

    if input == "" and interview == "":
        raise Exception("Please provide either --input or --interview")

    input_template = "interview_modal.tpl.py"
    # Use a context manager so the template handle is closed promptly.
    with open(input_template, encoding="utf-8") as template_file:
        tpl = Template(template_file.read())

    modal_params = {
        'MODELSLUG': model_clean_py,
        'MODELARGS': str(model_args),
        'MODELNAME': model,
        'RUNTIME': runtime,
        'GPUREQUEST': parse_gpu_string(gpu),
    }
    output = tpl.render(modal_params)

    output_script = f"modal_run_{model_clean_py}_{runtime}_{gpu}.py"
    # Python source files are UTF-8 by default, so write the script as UTF-8
    # regardless of the platform's locale encoding.
    with open(output_script, 'w', encoding="utf-8") as f:
        f.write(output)

    # Forward only the options that were actually supplied.
    args = []
    if input:
        args += ["--input", input]
    if interview:
        args += ["--interview", interview]
    if params:
        args += ["--params", params]
    if templateout:
        args += ["--templateout", templateout]
    if prompt:
        args += ["--prompt", prompt]
    if runtime == "vllm":
        args += ["--batch"]

    print(f"Rendered {output_script} with {modal_params}, executing via modal run with {args}")
    # Argument list form (no shell); the exit status of `modal run` is not checked.
    subprocess.run(["modal", "run", output_script] + args)
if __name__ == "__main__":
    # Expose main() as a command-line interface. `fire` is already imported
    # at the top of the file, so it is not re-imported here.
    fire.Fire(main)