gradio_demo.py
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Gradio chat demo for MedicalGPT with streaming inference
pip install gradio>=3.50.2
"""
import argparse
from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    BloomForCausalLM,
    BloomTokenizerFast,
    LlamaForCausalLM,
    GenerationConfig,
    TextIteratorStreamer,
)

from template import get_conv_template


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', default=None, type=str, required=True)
    parser.add_argument('--lora_model', default="", type=str, help="If empty, perform inference on the base model only")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--template_name', default="vicuna", type=str,
                        help="Prompt template name, e.g. alpaca, vicuna, baichuan2, chatglm2, etc.")
    parser.add_argument('--system_prompt', default="", type=str)
    parser.add_argument('--only_cpu', action='store_true', help='Only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--share', action='store_true', help='Share the gradio demo publicly')
    parser.add_argument('--port', default=8081, type=int, help='Port of the gradio demo')
    args = parser.parse_args()
    print(args)

    load_type = 'auto'
    if torch.cuda.is_available() and not args.only_cpu:
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map='auto',
        trust_remote_code=True,
    )
    try:
        base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
    except OSError:
        print("Failed to load generation config, using the default.")
    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resizing model embeddings to fit the tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)
    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
        print("Loaded LoRA model")
    else:
        model = base_model
    if device == torch.device('cpu'):
        model.float()
    model.eval()

    prompt_template = get_conv_template(args.template_name)
    system_prompt = args.system_prompt
    stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str

    def predict(message, history):
        """Generate an answer from the prompt and stream the output."""
        history_messages = history + [[message, ""]]
        prompt = prompt_template.get_prompt(messages=history_messages, system_prompt=system_prompt)
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        input_ids = tokenizer(prompt).input_ids
        # Truncate the prompt from the left so that prompt + generation fits in the context window
        context_len = 2048
        max_new_tokens = 512
        max_src_len = context_len - max_new_tokens - 8
        input_ids = input_ids[-max_src_len:]
        generation_kwargs = dict(
            input_ids=torch.as_tensor([input_ids]).to(device),
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            num_beams=1,
            repetition_penalty=1.0,
        )
        # Run generation in a background thread and yield partial text as tokens stream in
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        partial_message = ""
        for new_token in streamer:
            if new_token != stop_str:
                partial_message += new_token
                yield partial_message

    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(),
        textbox=gr.Textbox(placeholder="Ask me a question", lines=4, scale=9),
        title="MedicalGPT",
        description="To promote open research on large language models for healthcare, this project open-sources the [MedicalGPT](https://github.com/shibing624/MedicalGPT) medical LLM",
        theme="soft",
    ).queue().launch(share=args.share, inbrowser=True, server_name='0.0.0.0', server_port=args.port)


if __name__ == '__main__':
    main()