From 750c2fe84bab7197101ce72333f97416e9c8094a Mon Sep 17 00:00:00 2001
From: suyunsen <suyunsen2023@gmail.com>
Date: Tue, 31 Oct 2023 17:00:07 +0800
Subject: [PATCH] Re-encapsulate the chatGLM custom model and update the
 gradio interface for government Q&A
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Web_UI/goverment_UI/gradio_webui.py | 213 ++++++++++++++++++++++++++++
 Web_UI/goverment_UI/options.py      |  16 +++
 2 files changed, 229 insertions(+)
 create mode 100644 Web_UI/goverment_UI/gradio_webui.py
 create mode 100644 Web_UI/goverment_UI/options.py

diff --git a/Web_UI/goverment_UI/gradio_webui.py b/Web_UI/goverment_UI/gradio_webui.py
new file mode 100644
index 0000000..30533c5
--- /dev/null
+++ b/Web_UI/goverment_UI/gradio_webui.py
@@ -0,0 +1,213 @@
+"""
+ -*- coding: utf-8 -*-
+time: 2023/10/26 16:15
+author: suyunsen
+email: suyunsen2023@gmail.com
+"""
+
+import json
+import os
+import time
+import sys
+
+import gradio as gr
+
+from options import parser
+from typing import Optional
+
+sys.path.append("../../")
+from QA.govermentQA import GoQa
+from Custom.ChatGLM import ChatGlm26b
+from Custom.Custom_SparkLLM import Spark
+
+history = []
+readable_history = []
+cmd_opts = parser.parse_args()
+
+# Chinese prompt template for the retrieval QA chain: it instructs the model
+# to answer strictly from the retrieved {context} and never invent answers.
+templet_prompt = """
+    假如你是一名语文专家,你需要根据以下给出的内容找出问题的正确答案。
+    答案只存在给出的内容中,你知道就回答,不要自己编造答案。
+    这是给出的内容:{context}
+    问题是:{question}
+    因为你是语文专家,你需要仔细分析问题和给出的内容,不要给出多余的答案。
+    按给出的内容作答,你不需要自己总结。
+    """
+llm = ChatGlm26b()
+sparkllm: Optional[Spark] = None
+qa_chain = GoQa(llm=llm, templet_prompt=templet_prompt)
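+
+# The chain keeps a reference to the active LLM; load_chatGlm() and
+# load_spark() below swap it at runtime via qa_chain.set_llm_modle().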
+
+chat_h = """<div style="text-align: center;">
+<h1>ChatGLM WebUI</h1>
+</div>
+"""
+spark_h = """<div style="text-align: center;">
+<h1>Spark WebUI</h1>
+</div>
+"""
+# Banner shown at start-up; replaced when the user switches backends.
+head_h = chat_h
+
+_css = """
+#del-btn {
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+    margin: 1.5em 0;
+}
+"""
+
+
+def prepare_model():
+    # Expects a global `model` to be defined before it is called; the call
+    # below is currently disabled.
+    global model
+    if cmd_opts.cpu:
+        model = model.float()
+    else:
+        if cmd_opts.precision == "fp16":
+            model = model.half().cuda()
+        elif cmd_opts.precision == "int4":
+            model = model.half().quantize(4).cuda()
+        elif cmd_opts.precision == "int8":
+            model = model.half().quantize(8).cuda()
+
+    model = model.eval()
+
+
+# prepare_model()
+
+
+def parse_codeblock(text):
+    # Turn fenced ``` blocks in the model output into <pre><code> HTML and
+    # escape angle brackets everywhere else so the chatbot renders safely.
+    lines = text.split("\n")
+    for i, line in enumerate(lines):
+        if "```" in line:
+            if line != "```":
+                lines[i] = f'<pre><code class="{lines[i][3:]}">'
+            else:
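+                # a line that is exactly ``` closes the code block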
+                lines[i] = '</code></pre>'
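+            # e.g. parse_codeblock("```python\nprint(1)\n```") returns
+            # '<pre><code class="python"><br/>print(1)</code></pre>'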
+        else:
+            if i > 0:
+                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
+    return "".join(lines)
+
+
+def predict(query, max_length, top_p, temperature):
+    global history
+    llm.set_llm_temperature(temperature)
+    # top_p is the number of passages the retriever returns (see the slider
+    # label below); max_length is accepted for UI compatibility but unused.
+    output = qa_chain.ask_question(query, int(top_p))
+    history.append((query, output))
+    readable_history.append((query, parse_codeblock(output)))
+    # print(output)
+    return readable_history
+
+
+def save_history():
+    if not os.path.exists("outputs"):
+        os.mkdir("outputs")
+
+    s = [{"q": i[0], "o": i[1]} for i in history]
+    filename = f"save-{int(time.time())}.json"
+    with open(os.path.join("outputs", filename), "w", encoding="utf-8") as f:
+        f.write(json.dumps(s, ensure_ascii=False))
+
+
+def load_history(file):
+    global history, readable_history
+    try:
+        with open(file.name, "r", encoding='utf-8') as f:
+            j = json.load(f)
+            _hist = [(i["q"], i["o"]) for i in j]
+            _readable_hist = [(i["q"], parse_codeblock(i["o"])) for i in j]
+    except Exception as e:
+        print(e)
+        return readable_history
+    history = _hist.copy()
+    readable_history = _readable_hist.copy()
+    return readable_history
+
+
+def clear_history():
+    history.clear()
+    readable_history.clear()
+    return gr.update(value=[])
+
+
+def load_chatGlm():
+    qa_chain.set_llm_modle(llm)
+    La = gr.HTML(chat_h)
+    return La, clear_history()
+
+
+def load_spark():
+    global sparkllm
+    # Create the Spark client lazily, on first load.
+    if sparkllm is None:
+        sparkllm = Spark()
+    qa_chain.set_llm_modle(sparkllm)
+    La = gr.HTML(spark_h)
+    return La, clear_history()
+
+
+def create_ui():
+    with gr.Blocks(css=_css) as demo:
+        prompt = "输入你的内容..."
+        with gr.Row():
+            with gr.Column(scale=3):
+                La = gr.HTML(head_h)
+                with gr.Row():
+                    with gr.Column(variant="panel"):
+                        with gr.Row():
+                            max_length = gr.Slider(minimum=4, maximum=4096, step=4, label='Max Length', value=2048)
+                            top_p = gr.Slider(minimum=1, maximum=15, step=1, label='检索返回Top P', value=5)
+                        with gr.Row():
+                            temperature = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label='Temperature', value=0.01)
+
+                # with gr.Row():
+                #     max_rounds = gr.Slider(minimum=1, maximum=50, step=1, label="最大对话轮数(调小可以显著改善爆显存,但是会丢失上下文)", value=20)
+
+                with gr.Row():
+                    with gr.Column(variant="panel"):
+                        with gr.Row():
+                            clear = gr.Button("清空对话(上下文)")
+
+                        with gr.Row():
+                            save_his_btn = gr.Button("保存对话")
+                            load_his_btn = gr.UploadButton("读取对话", file_types=['file'], file_count='single')
+
+                        with gr.Row():
+                            chatGLM_load = gr.Button("加载chatGLM模型")
+                            spark_load = gr.Button("加载星火模型")
+            with gr.Column(scale=7):
+                chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=500)
+                with gr.Row():
+                    message = gr.Textbox(placeholder=prompt, show_label=False, lines=2)
+                    clear_input = gr.Button("🗑️", elem_id="del-btn")
+
+                with gr.Row():
+                    submit = gr.Button("发送")
+
+        submit.click(predict, inputs=[
+            message,
+            max_length,
+            top_p,
+            temperature
+        ], outputs=[
+            chatbot
+        ])
+
+        clear.click(clear_history, outputs=[chatbot])
+        clear_input.click(lambda x: "", inputs=[message], outputs=[message])
+
+        chatGLM_load.click(load_chatGlm, outputs=[La, chatbot])
+        spark_load.click(load_spark, outputs=[La, chatbot])
+
+        save_his_btn.click(save_history)
+
+        load_his_btn.upload(load_history, inputs=[
+            load_his_btn,
+        ], outputs=[
+            chatbot
+        ])
+
+    return demo
+
+
+ui = create_ui()
+ui.queue().launch(
+    # bind to 0.0.0.0 only when --listen is given
+    server_name="0.0.0.0" if cmd_opts.listen else None,
+    server_port=cmd_opts.port,
+    share=cmd_opts.share
+)
diff --git a/Web_UI/goverment_UI/options.py b/Web_UI/goverment_UI/options.py
new file mode 100644
index 0000000..3bfd343
--- /dev/null
+++ b/Web_UI/goverment_UI/options.py
@@ -0,0 +1,16 @@
+"""
+ -*- coding: utf-8 -*-
+time: 2023/10/26 16:41
+author: suyunsen
+email: suyunsen2023@gmail.com
+"""
+import argparse
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--port", type=int, default=7800)
+parser.add_argument("--model-path", type=str, default="/gemini/data-1")
+parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp16", "int4", "int8"], default="fp16")
+parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+parser.add_argument("--cpu", action='store_true', help="use cpu")
+parser.add_argument("--share", action='store_true', help="use gradio share")
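+
+# Consumed by gradio_webui.py as `from options import parser`; an example
+# invocation (hypothetical flag values) from Web_UI/goverment_UI:
+#   python gradio_webui.py --listen --port 7800 --share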