-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathapp.py
791 lines (710 loc) · 27.7 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
import random
import time
import asyncio
import dashscope
import openai
import pandas as pd
import plotly.express as px
import streamlit as st
from chat_models.spark_model import SparkClient, SparkChatConfig, SparkMsgInfo
from collect import TokenCounter
from dialog import message
from image import conversation2png
from juicy import clickable_select
from prompt import (
PROMPTS,
get_description_by_preset_id,
get_prompt_by_preset_id,
get_suggestion_by_preset_id,
)
from share import generate_share_link, restore_from_share_link
from utils.common_resource import get_tokenizer
openai.api_key = st.secrets["OPENAI_API_KEY"]
dashscope.api_key = st.secrets["Qwen"]["DASHSCOPE_API_KEY"]
st.set_page_config(
page_title="GPT-3 Playground",
layout="wide",
initial_sidebar_state="auto",
)
st_tiltle_slot = st.empty()
st_tiltle_slot.title("GPT-3 你问我答")
st.markdown(
"""[![GitHub][github_badge]][github_link]\n\n[github_badge]: https://badgen.net/badge/icon/GitHub?icon=github&color=black&label\n[github_link]: https://github.com/switchball/streamlit-gpt3"""
)
st_desc_solt = st.empty()
st_desc_solt.text('在下方文本框输入你的对话 ✨ 支持多轮对话 😉 \n现支持星火大模型V2.0,该服务已内嵌联网搜索、日期查询、天气查询、股票查询、诗词查询、字词理解等功能\n已支持星火大模型V3.0,在数学、代码、医疗、教育等场景进行了专项优化\n已支持阿里通义千问千亿级别超大规模语言模型')
# st.success('GPT-3 非常擅长与人对话,甚至是与自己对话。只需要几行的指示,就可以让 AI 模仿客服聊天机器人的语气进行对话。\n关键在于,需要描述 AI 应该表现成什么样,并且举几个例子。', icon="✅")
# st.success('看起来很简单,但也有些需要额外注意的地方:\n1. 在开头描述意图,一句话概括 AI 的个性,通常还需要 1~2 个例子,模仿对话的内容。\n2. 给 AI 一个身份(identity),如果是个在实验室研究的科学家身份,那可能就会得到更有智慧的话。以下是一些可参考的例子', icon="✅")
st.write(
"""<style>
[data-testid="column"] {
min-width: 1rem !important;
}
</style>""",
unsafe_allow_html=True,
)
@st.cache_resource
def get_token_counter():
    """Return the app-wide TokenCounter singleton (cached across reruns)."""
    # st.cache_resource keeps this instance alive for the process lifetime,
    # so if the definition of TokenCounter changes, the app needs a reboot.
    return TokenCounter(interval=900)
def wait(delay, reason=""):
    """Block with a spinner for roughly `delay` seconds.

    Delays of 30 seconds or less are skipped entirely so short queues do
    not throttle the user. The countdown sleeps in random 2-7 s slices
    and stops as soon as the deadline passes.
    """
    if delay <= 30:
        return
    deadline = time.time() + delay
    for _ in range(int(delay)):
        remaining = round(deadline - time.time())
        with st.spinner(text=f"{reason} 预计等待时间 {remaining} 秒"):
            time.sleep(random.uniform(2, 7))
        if time.time() > deadline:
            break
@st.cache_data(ttl=3600)
def completion(
    prompt,
    model="text-davinci-003",
    temperature=0.9,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0.6,
    stop=(" Human:", " AI:"),
):
    """Plain text completion via the OpenAI Completions API.

    Args:
        prompt: full prompt text (system prompt + transcript so far).
        model: OpenAI completion model name.
        temperature, max_tokens, top_p, frequency_penalty, presence_penalty:
            forwarded verbatim to the API.
        stop: stop sequences. An immutable tuple default (the original used a
            mutable list default, a Python anti-pattern); converted back to a
            list before the API call so request payloads are unchanged.

    Returns:
        The raw OpenAI response dict; the generated text is at
        response["choices"][0]["text"].
    """
    print("completion", prompt)
    with st.spinner(text=random.choice(HINT_TEXTS)):
        response = openai.Completion.create(
            model=model,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=list(stop),
        )
    print("completion finished.")
    print(response["choices"][0]["text"])
    return response
def _chat_completion_gpt(
    message_list,
    model="gpt-3.5-turbo",
    temperature=0.9,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0.6,
    stream=False,
):
    """Chat completion via the OpenAI ChatCompletion API.

    When `stream` is true, content deltas are rendered incrementally into a
    placeholder and an OpenAI-shaped response dict is rebuilt locally;
    token usage is then estimated with the local tokenizer, because
    streaming responses carry no usage field.
    """
    with st.spinner(text=random.choice(HINT_TEXTS)):
        api_response = openai.ChatCompletion.create(
            model=model,
            messages=message_list,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stream=stream,
        )
        if not stream:
            return api_response
        # Streaming: accumulate deltas and live-update the placeholder.
        answer = ""
        finish_reason = ""
        placeholder = st.empty()
        for chunk in api_response:
            choice = chunk["choices"][0]
            answer += choice.get("delta", {}).get("content", "")
            finish_reason = choice.get("finish_reason", "")
            placeholder.markdown(answer)
        placeholder.markdown("")
        # Estimate token usage locally (streaming responses omit "usage").
        prompt_text = "".join(m["content"] for m in message_list)
        prompt_tokens = len(get_tokenizer().tokenize(prompt_text))
        answer_tokens = len(get_tokenizer().tokenize(answer))
        # Rebuild a non-streaming-shaped response so callers are uniform.
        return {
            "choices": [
                {
                    "message": {"content": answer, "role": "assistant"},
                    "finish_reason": finish_reason,
                }
            ],
            "usage": {"total_tokens": prompt_tokens + answer_tokens},
        }
# cannot pickle 'coroutine' object
# @st.cache_data(ttl=3600)
async def _chat_completion_spark(
    message_list,
    model="星火V2.0",
    temperature=0.9,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0.6,
    stream=True,
):
    """Chat completion via iFlytek Spark, streamed into a placeholder.

    Returns an OpenAI-style response dict so callers can treat every
    backend uniformly. top_p / frequency_penalty / presence_penalty /
    stream are accepted for signature parity only — they are not
    forwarded to the Spark client below.
    """
    # Map the UI model label ("星火V1.0/V2.0/V3.0") to Spark's API domain id.
    domain = (
        "generalv1"
        if "1.0" in model
        else "generalv2"
        if "2.0" in model
        else "generalv3"
        if "3.0" in model
        else None
    )
    if domain is None:
        # Unknown label: surface the error and abort this script run.
        st.error(f"Unknown model: {model}")
        st.stop()
    chat_conf = SparkChatConfig(
        domain=domain,
        temperature=temperature,
        max_tokens=max_tokens,
        # top_k: use default value
    )
    # Credentials come from Streamlit's secrets store.
    client = SparkClient(
        app_id=st.secrets["Spark"]["APP_ID"],
        api_secret=st.secrets["Spark"]["API_SECRET"],
        api_key=st.secrets["Spark"]["API_KEY"],
        chat_conf=chat_conf,
    )
    with st.spinner(text=f"[星火-{domain}]" + random.choice(HINT_TEXTS)):
        slot = st.empty()
        msg_info: SparkMsgInfo = None
        # Re-render the accumulated answer as each chunk arrives.
        async for msg_info in client.aiohttp_chat(message_list):
            slot.markdown(client.answer_full_content)
        # NOTE(review): if aiohttp_chat yields nothing, msg_info stays None and
        # the attribute accesses below raise — confirm it always yields.
        answer = msg_info.msg_content
        st.write(msg_info.usage_info)
    # Repackage into an OpenAI-compatible response dict.
    response = {
        "choices": [
            {
                "message": {"content": answer, "role": "assistant"},
                "finish_reason": "",
            }
        ],
        "usage": {
            "prompt_tokens": msg_info.usage_info["prompt_tokens"],
            "completion_tokens": msg_info.usage_info["completion_tokens"],
            "total_tokens": msg_info.usage_info["total_tokens"]
        },
    }
    return response
async def _chat_completion_qwen(
    message_list,
    model="通义千问-max",
    temperature=0.9,
    max_tokens=2048,
    top_p=0.8,
    repetition_penalty=1.1,
):
    """Chat completion via Alibaba DashScope (Qwen), streamed incrementally.

    Returns an OpenAI-style response dict so callers can treat every
    backend uniformly. Aborts the script run (st.stop) for unknown model
    labels, mirroring the Spark backend.
    """
    M = dashscope.Generation.Models
    # NOTE(review): the "-max-30k" UI label maps to qwen_plus — confirm this
    # is intentional rather than M.qwen_max.
    model_name = M.qwen_plus if model == "通义千问-max-30k" else None
    if model_name is None:
        # Fail fast instead of passing None into dashscope.Generation.call.
        st.error(f"无效的模型输入:{model}")
        st.stop()
    if model_name in (M.qwen_max, M.qwen_plus, M.qwen_turbo, 'qwen-max-longcontext'):
        # These endpoints cap the generation length; clamp to stay in range.
        max_tokens = min(1500, max_tokens)
    response = dashscope.Generation.call(
        model_name,
        messages=message_list,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        result_format='message',  # set the result to be "message" format.
        stream=True,
        incremental_output=True
    )
    # Streaming chat rendered into a live-updating placeholder.
    with st.spinner(text=f"[通义千问-{model_name}]" + random.choice(HINT_TEXTS)):
        answer = ""
        finish_reason = ""
        usage = {}
        reply_slot = st.empty()
        for chunk in response:
            c = chunk["output"]["choices"][0]
            delta = c.get("message", {}).get("content", "")
            finish_reason = c.get("finish_reason", "")
            answer += delta
            # Keep the latest usage snapshot; the final chunk has the totals.
            usage = chunk["usage"]
            reply_slot.markdown(answer)
        reply_slot.markdown("")
    # The original read `chunk`/`finish_reason` after the loop, raising
    # NameError when the stream yielded nothing; the accumulators fix that.
    response = {
        "choices": [
            {
                "message": {"content": answer, "role": "assistant"},
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("input_tokens", 0),
            "completion_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0)
        },
    }
    return response
@st.cache_data(ttl=3600)
def chat_completion(
    message_list,
    model="gpt-3.5-turbo",
    temperature=0.9,
    max_tokens=256,
    top_p=1,
    frequency_penalty=0,
    presence_penalty=0.6,
    stream=False,
):
    """Route a chat request to the backend matching the model name prefix.

    "gpt*" goes to OpenAI; "星火*" and "通义千问*" are async backends and
    run via asyncio.run. Unknown names abort the script run.
    """
    if model.startswith("gpt"):
        return _chat_completion_gpt(
            message_list,
            model,
            temperature,
            max_tokens,
            top_p,
            frequency_penalty,
            presence_penalty,
            stream,
        )
    if model.startswith("星火"):
        # Spark client is async-only; only the shared params are forwarded.
        return asyncio.run(
            _chat_completion_spark(
                message_list=message_list,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        )
    if model.startswith("通义千问"):
        return asyncio.run(
            _chat_completion_qwen(
                message_list=message_list,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        )
    st.error(f"无效的模型输入:{model}")
    st.stop()
# Available Models
LANGUAGE_MODELS = ["星火V3.0", "星火V2.0", "gpt-3.5-turbo-16k", "通义千问-max-30k"]
CODEX_MODELS = ["code-davinci-002", "code-cushman-001"]
# Context-window size (tokens) per model; drives the queueing heuristic and
# the token-usage progress bar.
MAX_TOKEN_CONFIG = {
    "星火V3.0": 8192,
    "星火V2.0": 8192,
    "gpt-3.5-turbo-16k": 16384,
    "通义千问-max-30k": 30000
}
# Random status lines shown in the spinner while waiting for a reply.
HINT_TEXTS = [
    "正在接通电源,请稍等 ...",
    "正在思考怎么回答,不要着急",
    "正在努力查询字典内容 ...",
    "等待对方回复中 ...",
    "正在激活神经网络 ...",
    "请稍等",
]
# Above this transcript token count, suggest enabling conversation compression.
TOKEN_SAVING_HINT_THRESHOLD = 6000
# store chat as session state
DEFAULT_CHAT_TEXT = "以下是与AI助手的对话。助手乐于助人、有创意、聪明而且非常友好。\n\n"
if "input_text_state" not in st.session_state:
    # Raw transcript text (system prompt + alternating turns).
    st.session_state.input_text_state = DEFAULT_CHAT_TEXT
if "conv_user" not in st.session_state:
    # Parallel lists: conv_user[i] / conv_robot[i] form turn i.
    st.session_state.conv_user = []
if "conv_robot" not in st.session_state:
    st.session_state.conv_robot = []
if "user" not in st.session_state:
    # First run in this session: count one page view.
    st.session_state["user"] = "new user"
    get_token_counter().page_view()
if "seed" not in st.session_state:
    # Per-session random seed used for deterministic avatar generation.
    st.session_state["seed"] = random.randint(0, 1000)
seed = st.session_state["seed"]
class ConversationCompressConfig:
    """Controls how much chat history is sent to the model as context.

    When `enabled`, only the most recent N user/assistant turns (and,
    optionally, the whole first turn) are kept; dropped halves of a turn
    are replaced by empty strings so the role alternation is preserved.
    Reads the history from st.session_state ("conv_user"/"conv_robot"
    parallel lists and the "prompt_system" system prompt).
    """

    def __init__(
        self,
        *,
        enabled,
        max_human_conv_reserve_count=None,
        max_robot_conv_reserve_count=None,
        enable_first_conv=None,
    ) -> None:
        # When disabled, the reserve counts are unused and stay None.
        self.enabled = enabled
        self.max_human_conv_reserve_count = max_human_conv_reserve_count
        self.max_robot_conv_reserve_count = max_robot_conv_reserve_count
        self.enable_first_conv = enable_first_conv

    def get_message_list(self):
        """Return the message list to send, honoring the compression flag."""
        if self.enabled:
            return self._get_compressed_message_list()
        return self._get_full_message_list()

    @property
    def message_tokens(self):
        """Token count of whichever message list would actually be sent."""
        if self.enabled:
            return self.compressed_message_tokens
        return self.full_message_tokens

    @property
    def full_message_tokens(self):
        """Token count of the uncompressed message list."""
        return self._count_tokens(self._get_full_message_list())

    @property
    def compressed_message_tokens(self):
        """Token count of the compressed message list."""
        return self._count_tokens(self._get_compressed_message_list())

    @staticmethod
    def _count_tokens(message_list):
        # Shared tokenization (was duplicated in both *_tokens properties).
        txt = "".join(m["content"] for m in message_list)
        return len(get_tokenizer().tokenize(txt))

    @staticmethod
    def _system_prompt_messages():
        # Common prefix for both list builders: the system prompt, if set.
        if st.session_state["prompt_system"]:
            return [{"role": "system", "content": st.session_state["prompt_system"]}]
        return []

    def _get_full_message_list(self):
        """Get full message list (for Chat Completion)."""
        message_list = self._system_prompt_messages()
        # Add every history turn verbatim.
        for conv_user, conv_robot in zip(
            st.session_state["conv_user"], st.session_state["conv_robot"]
        ):
            message_list.append({"role": "user", "content": conv_user})
            message_list.append({"role": "assistant", "content": conv_robot})
        return message_list

    def _get_compressed_message_list(self):
        """Get compressed message list (for Chat Completion)."""
        message_list = self._system_prompt_messages()
        turns_count = min(
            len(st.session_state["conv_user"]), len(st.session_state["conv_robot"])
        )
        for turn_idx in range(turns_count):
            should_keep_human = False  # keep the user half of this turn?
            should_keep_robot = False  # keep the assistant half of this turn?
            # Optionally always keep the first turn (often holds the setup).
            if turn_idx == 0 and self.enable_first_conv:
                should_keep_human, should_keep_robot = True, True
            # Keep the last `max_*_conv_reserve_count` turns of each role.
            if turn_idx + self.max_human_conv_reserve_count >= turns_count:
                should_keep_human = True
            if turn_idx + self.max_robot_conv_reserve_count >= turns_count:
                should_keep_robot = True
            if should_keep_human or should_keep_robot:
                # Dropped halves become "" so roles still strictly alternate.
                conv_user = (
                    st.session_state["conv_user"][turn_idx] if should_keep_human else ""
                )
                conv_robot = (
                    st.session_state["conv_robot"][turn_idx]
                    if should_keep_robot
                    else ""
                )
                message_list.append({"role": "user", "content": conv_user})
                message_list.append({"role": "assistant", "content": conv_robot})
        return message_list
def after_submit(
    current_input,
    model,
    temperature,
    max_tokens,
    cc_config: ConversationCompressConfig,
    stream=False,
):
    """Handle one submitted user message end-to-end.

    Appends the input to the raw transcript, applies a soft queueing delay
    proportional to context size, calls the chat or plain-completion API
    depending on the model, records token usage, and returns
    (response, answer).
    """
    # Append current_input to input_text_state
    st.session_state.input_text_state += current_input
    # Queue by prompt length and max_tokens
    if model in LANGUAGE_MODELS:
        token_number = cc_config.message_tokens
    else:
        token_number = len(get_tokenizer().tokenize(st.session_state.input_text_state))
    # Heuristic throttle: delay grows quadratically with the context fill
    # ratio; wait() ignores anything <= 30 s.
    x = token_number / MAX_TOKEN_CONFIG[model] * 3
    delay = 2 * x * x - 3
    delay += 2 * x * (max_tokens / 1024 - 1)
    wait(delay, "前方排队中...")
    # Send text and waiting for respond
    if model in LANGUAGE_MODELS:
        # Get system prompt + history conversations
        message_list = cc_config.get_message_list()
        # Add current user input
        message_list.append({"role": "user", "content": current_input})
        response = chat_completion(
            message_list=message_list,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0.6,
            stream=stream,
        )
        answer = response["choices"][0]["message"]["content"]
        st.session_state.input_text_state += answer
    else:
        # Legacy completion models consume the whole raw transcript.
        response = completion(
            model=model,
            prompt=st.session_state.input_text_state,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0.6,
            stop=[" Human:", " AI:"],
        )
        # TODO: non chat model also need stream function
        answer = response["choices"][0]["text"]
        # TODO: should check if answer starts with '\nAI:'
        st.session_state.input_text_state += answer
        st.session_state.input_text_state += "\nHuman: "
    print(answer)
    # Collect usage
    tc = get_token_counter()
    tc.collect(tokens=response["usage"]["total_tokens"])
    return response, answer
def load_preset_id_from_url_link():
    """Return the preset id from the ?preset= query parameter, if valid.

    Only ids that exactly match a known preset in PROMPTS are accepted;
    anything else (missing, unknown) yields None.
    """
    raw = st.query_params.get("preset", "")
    # st.query_params maps keys to plain strings; the legacy
    # experimental_get_query_params API returned lists. The old
    # `.get(...)[0]` indexed the FIRST CHARACTER of a string value, so a
    # multi-character preset id could never match. Support both shapes.
    if isinstance(raw, list):
        raw = raw[0] if raw else ""
    preset_id = raw
    if preset_id and any(p["preset"] == preset_id for p in PROMPTS):
        return preset_id
    return None
def load_preset_qa(candidate=None):
    """Load the selected preset's canned Q&A into the conversation state.

    Args:
        candidate: fallback preset id when st.session_state has no
            "preset" yet (e.g. first load from a URL parameter).

    Clears the current history, then copies the preset's seed messages
    into conv_user / conv_robot and its suggested input into "input".
    """
    preset = st.session_state.get("preset", candidate or "GPT-3 你问我答 (ChatBot)")
    st.session_state["conv_user"].clear()
    st.session_state["conv_robot"].clear()
    # load prompt message into conversations
    for p in PROMPTS:
        if preset != p["preset"]:
            continue
        st.session_state["input"] = p.get("input", "")
        # `msg`, not `message`: the original loop variable shadowed the
        # imported `message` dialog widget used elsewhere in this module.
        for msg in p["message"]:
            if msg["role"] == "user":
                st.session_state["conv_user"].append(msg["content"])
            elif msg["role"] == "assistant":
                st.session_state["conv_robot"].append(msg["content"])
def append_to_input_text():
    """Rebuild input_text_state (the raw transcript) from the chat history.

    Appends "\\nHuman: " + each user/assistant pair so the transcript
    matches the legacy completion-prompt format.
    """
    if st.session_state.conv_robot:
        st.session_state.input_text_state += "\nHuman: "
        # zip over the parallel lists instead of indexing with range(len(...));
        # this is the idiomatic form and cannot IndexError if the lists ever
        # get out of sync (zip stops at the shorter one).
        for user_msg, robot_msg in zip(
            st.session_state["conv_user"], st.session_state["conv_robot"]
        ):
            st.session_state.input_text_state += user_msg
            st.session_state.input_text_state += robot_msg
            st.session_state.input_text_state += "\nHuman: "
def show_conversation_dialog(slot_list, rollback_fn, reverse_order=False):
    """Render the conversation dialogs.

    Args:
        slot_list: pre-allocated placeholders to draw into when rendering
            in normal order; falsy forces reverse order (dynamic slots).
        rollback_fn: click handler attached only to the newest message,
            letting the user withdraw the latest turn.
        reverse_order: newest message first when True.
    """
    just_loaded_from_share = False
    if (
        "loaded_from_share" in st.session_state
        and st.session_state["loaded_from_share"]
    ):
        just_loaded_from_share = True
        # Consume the one-shot flag set by the share-link restore.
        st.session_state["loaded_from_share"] = False
    if not slot_list:
        # Without pre-allocated slots we can only append, i.e. newest-first.
        reverse_order = True
    if st.session_state.conv_robot:
        num = len(st.session_state.conv_robot)
        # From user0, robot0, ..., user_{n-1}, robot_{n-1} in normal order
        order_indexes = reversed(range(2 * num)) if reverse_order else range(2 * num)
        for j in order_indexes:
            slot = st.empty() if reverse_order else slot_list[j]
            with slot:
                # Even j = user message, odd j = assistant message.
                is_user = j % 2 == 0
                text = (
                    st.session_state["conv_user"][j // 2]
                    if is_user
                    else st.session_state["conv_robot"][j // 2]
                )
                message(
                    text,
                    is_user=is_user,
                    key=str(j),
                    seed=seed,
                    # Only the very last message is clickable (rollback).
                    on_click=(rollback_fn if j == 2 * num - 1 else None),
                )
    if just_loaded_from_share:
        # Brief pause after restoring from a share link — presumably to let
        # avatars/images finish loading; TODO confirm why this is needed.
        time.sleep(1)
def show_edit_dialog(slot):
    """Render a form inside `slot` for manually editing the last AI reply."""
    with slot:
        if not st.session_state["conv_robot"]:
            # Nothing to edit yet.
            st.warning("无法编辑!对话不存在")
            return
        with st.expander("⭐ 手动编辑上一次AI回复的内容", expanded=True):
            with st.form("edit_form"):
                # Seed the editor widget with the most recent AI reply.
                st.session_state["edit_answer"] = st.session_state["conv_robot"][-1]
                st.form_submit_button("📝 确认修改", on_click=edit_answer)
                st.text_area("对话内容", key="edit_answer", height=800)
def edit_answer():
    """Persist the manually edited text over the last AI reply."""
    # Guard clause: silently do nothing when there is no reply to edit.
    if not st.session_state["conv_robot"]:
        return
    st.session_state["conv_robot"][-1] = st.session_state["edit_answer"]
    st.success("对话内容已修改")
def rollback():
    """Undo the latest conversation turn.

    Drops the newest AI reply and puts the withdrawn user question back
    into the input box so it can be edited and resubmitted.
    """
    st.session_state["conv_robot"].pop()
    st.session_state["input"] = st.session_state["conv_user"].pop()
# Restore / save: re-load conversation state if a share link is present.
restore_from_share_link()
st.sidebar.title("✨ GPT Interface")
with st.sidebar.expander("🎈 预设身份的提示词 (Preset Prompts)", expanded=False):
    preset_id_options = [p["preset"] for p in PROMPTS]
    preset_id_options.append("自定义")
    # First run: seed the conversation from the URL's ?preset= id (if any).
    if "preset" not in st.session_state:
        load_preset_qa(candidate=load_preset_id_from_url_link())
    prompt_id = st.selectbox(
        "预设身份的提示词",
        options=preset_id_options,
        index=0,
        on_change=load_preset_qa,
        key="preset",
        label_visibility="collapsed",
    )
    # Dynamically swap the page title and description to match the preset.
    st_tiltle_slot.title(prompt_id)
    if get_description_by_preset_id(prompt_id) is not None:
        st_desc_solt.text(get_description_by_preset_id(prompt_id))
    _prompt_text = get_prompt_by_preset_id(prompt_id)
    # System-prompt editor; locked when the preset supplies its own prompt.
    prompt_text = st.text_area(
        "Enter Prompt",
        value=_prompt_text,
        placeholder="预设的Prompt",
        label_visibility="collapsed",
        key="prompt_system",
        disabled=(_prompt_text != ""),
    )
    _suggestion = get_suggestion_by_preset_id(prompt_id)
    if _suggestion:
        st.warning(_suggestion, icon="⚠️")
    # Rebuild the raw transcript from the system prompt + chat history.
    st.session_state.input_text_state = prompt_text
    append_to_input_text()
edit_answer_slot = st.empty()
# Conversation retention / compression settings.
with st.sidebar.expander("⭐ 对话设置"):
    # NOTE(review): "enbale" is a typo kept as-is; it is referenced below.
    enbale_conv_reserve = st.checkbox(
        "开启对话压缩", value=False, help="若开启,仅会发送对话中的特定部分作为上下文\n\n若关闭,所有聊天内容都会作为上下文发送"
    )
    if enbale_conv_reserve:
        max_robot_conv_reserve_count = st.number_input(
            ":hash: 仅保留最近AI回复对话数", 0, None, 3, help="设定最多保留多少次 AI 最近的回复内容"
        )
        max_human_conv_reserve_count = st.number_input(
            ":hash: 仅保留最近输入对话数", 0, None, 10, help="设定最多保留多少次最近输入的提问内容"
        )
        enable_first_conv = st.checkbox("必定保留第一轮对话", help="推荐在第一轮对话包含特殊设定时开启")
        cc_config = ConversationCompressConfig(
            enabled=True,
            max_human_conv_reserve_count=max_human_conv_reserve_count,
            max_robot_conv_reserve_count=max_robot_conv_reserve_count,
            enable_first_conv=enable_first_conv,
        )
        # Preview the token savings of compression.
        full_tokens = cc_config.full_message_tokens
        active_tokens = cc_config.compressed_message_tokens
        st.caption(f"预估压缩前/后: `{active_tokens}`/ `{full_tokens}` tokens")
    else:
        cc_config = ConversationCompressConfig(enabled=False)
    enable_reverse_order = st.checkbox(
        "对话倒序显示", value=False, help="开启后,输入框在上方,最近的对话在最上方\n\n关闭后,输入框在下方,最早的对话在上方"
    )
    enable_stream_chat = st.checkbox("对话流式显示", value=True, help="开启后,以流式方式传输对话,无需等待")
# Suggest compression when the full transcript grows past the threshold.
if st.session_state["input_text_state"] and not enbale_conv_reserve:
    tokens = get_tokenizer().tokenize(st.session_state["input_text_state"])
    if len(tokens) > TOKEN_SAVING_HINT_THRESHOLD:
        st.sidebar.info(f"👆 全文 Token 数 >= {TOKEN_SAVING_HINT_THRESHOLD},可考虑开启对话压缩功能")
model_val = clickable_select(LANGUAGE_MODELS, label="<b><i>模型选择:</i></b>", index=3)
if st.button("🗑️ 清除所有对话"):
    # Reset the transcript, pending input and both history lists.
    st.session_state["input_text_state"] = ""
    st.session_state["input"] = ""
    st.session_state.conv_user.clear()
    st.session_state.conv_robot.clear()
with st.form("my_form"):
    # Pre-allocate dialog slots so history can render above the input box in
    # normal order; in reverse order slots are created dynamically instead.
    dialog_slot_list = (
        None
        if enable_reverse_order
        else [st.empty() for _ in range(2 + 2 * len(st.session_state["conv_user"]))]
    )
    col_icon, col_text, col_btn = st.columns((1, 10, 2))
    # User avatar generated deterministically from the per-session seed.
    col_icon.markdown(
        f"""<img src="https://api.dicebear.com/5.x/{"lorelei"}/svg?seed={seed}" alt="avatar" />""",
        unsafe_allow_html=True,
    )
    input_text = col_text.text_area(
        "You: ", "", key="input", label_visibility="collapsed", height=150
    )
    # Model parameter controls (rendered into the sidebar).
    with st.sidebar.expander("🧩 模型参数 (Model Parameters)"):
        # moved to upper
        # model_val = st.selectbox("Model", options=LANGUAGE_MODELS, index=0)
        temperature_val = st.slider("Temperature", 0.0, 2.0, 0.8, step=0.05)
        max_tokens_val = st.select_slider(
            "Max Tokens", options=(256, 512, 1024, 2048), value=2048
        )
    # Every form must have a submit button.
    submitted = col_btn.form_submit_button("💬")
    if submitted:
        response, answer = after_submit(
            input_text,
            model_val,
            temperature_val,
            max_tokens_val,
            cc_config,
            stream=enable_stream_chat,
        )
        # Record the new turn in the session history.
        st.session_state.conv_user.append(input_text)
        st.session_state.conv_robot.append(answer)
        finish_reason = response["choices"][0].get("finish_reason", "")
        if finish_reason == "length":
            # The reply was truncated by max_tokens.
            st.sidebar.info("👆 上次输入因长度被截断,可考虑撤回该消息,并调大该参数后重试")
        st.session_state['usage_total_tokens'] = response["usage"]["total_tokens"]
        st.json(response["usage"])
show_conversation_dialog(
    dialog_slot_list, rollback_fn=rollback, reverse_order=enable_reverse_order
)
# When the input_text_state is bind to widget, its content cannot be modified by session api.
with st.expander(""):
    # Debug view: raw history lists plus the editable full transcript.
    st.json(st.session_state.conv_robot, expanded=False)
    st.json(st.session_state.conv_user, expanded=False)
    txt = st.text_area("对话内容", key="input_text_state", height=800)
    # Before the first request, estimate usage from the transcript itself.
    if 'usage_total_tokens' not in st.session_state:
        tokens = get_tokenizer().tokenize(txt)
        token_number = len(tokens)
        st.session_state['usage_total_tokens'] = token_number
    # Show how full the selected model's context window is.
    max_token_limit = MAX_TOKEN_CONFIG[model_val]
    percent = st.session_state['usage_total_tokens'] / max_token_limit
    st.progress(percent,
                text="Total Tokens %: {:.0f}%".format(percent * 100))
    st.write("全文的 Token 数:", st.session_state['usage_total_tokens'], f" (最大 Token 数:`{max_token_limit}`)")
if submitted:
    # Full raw response for the turn that was just submitted.
    st.json(response, expanded=False)
# st.write("temperature", temperature_val)
need_edit_answer = st.sidebar.button("🔬 编辑AI的回答(高级功能)")
if need_edit_answer:
    show_edit_dialog(slot=edit_answer_slot)
# Restore / save: share the conversation as a link or as an image.
if st.sidebar.button("🔗 生成分享链接"):
    share_link = generate_share_link()
    st.sidebar.success(f"链接已生成 [右键复制]({share_link}) 有效期7天")
is_generate_image = st.sidebar.button("🖼️ 生成分享图片", key="image_button")
if is_generate_image:
    image = conversation2png(
        st.session_state["preset"],
        st.session_state["conv_user"],
        st.session_state["conv_robot"],
        seed=seed,
    )
    st.image(image, caption="已生成图片,长按或右键保存")
# Streamlit "magic": a bare string literal renders as markdown (a rule here).
"""---"""
with st.expander("访问统计"):
    # Usage dashboards built from the TokenCounter's aggregated stats.
    pv_stats, call_stats, token_stats = get_token_counter().summary()
    tab1, tab2, tab3 = st.tabs(["Session View", "Request Count", "Token Count"])
    df = pd.DataFrame(pv_stats.items(), columns=["time", "pv"])
    fig = px.line(df, x="time", y="pv", title="Page View")
    tab1.subheader(f"Total PV: {df['pv'].sum()}")
    tab1.plotly_chart(fig)
    tab2.subheader(f"Total Requests: {sum(call_stats.values())}")
    tab2.write(call_stats)
    tab3.subheader(f"Total Tokens: {sum(token_stats.values())}")
    tab3.write(token_stats)