-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
205 lines (162 loc) · 6.68 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
import asyncio
import traceback
from typing import Optional, Union
from amiyabot import Message,Chain
from core import log,Requirement
from core import bot as main_bot
from .src.core.trpg_storage import AmiyaBotChatGPTParamHistory,AmiyaBotChatGPTTRPGSpeechLog,AmiyaBotChatGPTExecutionLog
from .src.core.chatgpt_plugin_instance import ChatGPTPluginInstance
from .src.core.developer_types import BLMAdapter
from .src.deep_cosplay import DeepCosplay
from .src.assistant_amiya import AssistantAmiya
from .src.trpg import TRPGMode
from .src.online_troll import OnlineTrollMode
from .src.ask_amiya import AskAmiya
from .src.server.trpg_server import TRPGAPI # 导入Server类从而启动服务器
from .src.util.complex_math import frequency_controller
# Directory that contains this module; used to build paths to the files
# bundled with the plugin (README, default config JSON).
curr_dir = os.path.dirname(__file__)

# Placeholder binding so the schema callbacks can test `bot` for truthiness
# before the real plugin instance is assigned further down in this module.
bot : ChatGPTPluginInstance = None
def dynamic_get_global_config_schema_data():
    """Return the global config schema for the plugin.

    When the module-level ``bot`` is truthy the schema is produced by
    ``bot.generate_global_schema()``; otherwise a path string to a JSON file
    next to this module is returned as a fallback.
    """
    if not bot:
        # NOTE(review): this fallback has no ``accessories/`` segment, unlike
        # the default-config paths given to the plugin constructor — confirm
        # the file really lives directly beside this module.
        return f'{curr_dir}/global_config_default.json'
    return bot.generate_global_schema()
def dynamic_get_channel_config_schema_data():
    """Return the per-channel config schema for the plugin.

    When the module-level ``bot`` is truthy the schema is produced by
    ``bot.generate_channel_schema()``; otherwise a path string to a JSON file
    next to this module is returned as a fallback.
    """
    if not bot:
        # NOTE(review): the fallback names the *global* default file and has
        # no ``accessories/`` segment; this looks like a copy-paste from the
        # global variant — confirm which file is intended.
        return f'{curr_dir}/global_config_default.json'
    return bot.generate_channel_schema()
# Create the plugin instance that the rest of this module registers its
# lifecycle hook and message handler on.
bot = ChatGPTPluginInstance(
    name='AI智能回复',
    version='4.2.2',
    plugin_id='amiyabot-hsyhhssyy-chatgpt',
    plugin_type='',
    description='调用"大语言模型库"插件智能回复普通对话',
    document=f'{curr_dir}/README.md',
    # The plugin declares a dependency on the BLM library plugin, which the
    # message handler later looks up by this same id.
    requirements=[
        Requirement("amiyabot-blm-library")
    ],
    # Defaults are given as file paths; schemas are given as callables so
    # they can be produced by the plugin instance itself once it exists.
    channel_config_default=f'{curr_dir}/accessories/channel_config_default.json',
    channel_config_schema=dynamic_get_channel_config_schema_data,
    global_config_default=f'{curr_dir}/accessories/global_config_default.json',
    global_config_schema=dynamic_get_global_config_schema_data,
)
def load():
    """Create the storage tables used by the plugin.

    ``create_table(safe=True)`` is called on each storage model in turn.
    NOTE(review): ``safe=True`` presumably makes the call a no-op when the
    table already exists (peewee-style) — confirm against the model base.
    """
    storage_models = (
        AmiyaBotChatGPTParamHistory,
        AmiyaBotChatGPTTRPGSpeechLog,
        AmiyaBotChatGPTExecutionLog,
    )
    for storage_model in storage_models:
        storage_model.create_table(safe=True)


# Attach the function to the plugin instance as its ``load`` attribute, then
# drop the module-level name so only the instance keeps a reference to it.
bot.load = load
del load
# Cache of the active mode handler per conversation. Keys are the channel id,
# or 'User:<user_id>' when the message has no channel; values are the handler
# objects (DeepCosplay, AssistantAmiya, OnlineTrollMode, TRPGMode, AskAmiya)
# created by the message handler in this module.
channel_hander_context = {}
async def check_talk(data: Message):
    """Decide whether an incoming message should be handled by this plugin.

    Used as the ``verify`` callback of the plugin's message handler.

    Args:
        data: The incoming message. ``channel_id`` may be ``None`` (the
            handler is registered with ``allow_direct=True``) and ``text``
            may be empty.

    Returns:
        A ``(matched, level)`` tuple. ``(False, 0)`` rejects the message;
        ``(True, 10)`` is returned when the text contains "chat"; every other
        accepted message yields ``(True, -99999)``.
        NOTE(review): the second element is presumably the handler priority
        used by the framework — confirm.
    """
    channel_id = data.channel_id

    enabled = bot.get_config('enable_in_this_channel', channel_id)
    # `!s` converts to str before applying the width spec: formatting None
    # with `:<10` directly raises TypeError, which broke direct messages.
    bot.debug_log(f'[{channel_id!s:<10}]在本频道启用: {enabled}')
    # Strict comparison kept on purpose: only a value equal to True enables
    # the plugin, other truthy config values (e.g. strings) do not.
    if enabled != True:
        return False, 0

    # Treat a missing text as empty so image-only messages do not crash here.
    text = data.text or ''

    # Messages consisting only of digits were excluded as a temporary
    # workaround for an upstream issue; that issue has since been fixed, but
    # the check is kept just in case.
    if text.isdigit():
        return False, 0

    # Blacklist: drop messages from users listed in the channel config.
    black_list = bot.get_config('black_list', channel_id)
    if black_list and str(data.user_id) in black_list:
        bot.debug_log(f'[{channel_id!s:<10}]用户被黑名单屏蔽: {data.user_id}')
        return False, 0

    if 'chat' in text.lower():
        return True, 10

    # NOTE(review): this branch cannot be reached — any text whose upper-case
    # form starts with "CHATGPT请问" already contains "chat" and returned
    # above, so `frequency_controller` is never advanced. Kept unchanged
    # until the intended rate-limiting behaviour is clarified.
    if text.upper().startswith("CHATGPT请问"):
        if next(frequency_controller):
            return True, 10

    return True, -99999
# Message prefixes (compared upper-cased) that force the one-shot question
# mode regardless of the configured mode.
_ASK_PREFIXES = ("CHATGPT请问", "文心一言请问")


def _select_mode_handler(mode):
    """Map a configured mode name to ``(handler class, needs_instance)``.

    ``needs_instance`` tells the caller whether the handler constructor takes
    ``data.instance`` as an extra fourth argument. Unknown, missing or
    non-string modes select the classic ``AskAmiya`` handler.
    """
    if not isinstance(mode, str):
        return AskAmiya, False
    if mode == "角色扮演":
        return DeepCosplay, True
    if mode == "助手模式":
        return AssistantAmiya, False
    if mode == "典孝急模式":
        return OnlineTrollMode, True
    if mode.startswith("跑团模式"):
        return TRPGMode, True
    # Classic mode is the default.
    return AskAmiya, False


async def _answer_question(data, blm_lib, channel):
    """Send the message to the BLM library once and wrap the answer in a reply."""
    model = bot.get_model_in_config('high_cost_model_name', channel)
    content_to_send = [{"type": "text", "text": data.text}]

    vision = bot.get_config('vision_enabled', channel)
    # Strict `== True` is kept so only a True-equal config value enables
    # vision. Images are attached, and the vision model selected, only when
    # the message actually carries images.
    if vision == True and data.image:
        content_to_send = content_to_send + [
            {"type": "image_url", "url": img_path} for img_path in data.image
        ]
        bot.debug_log(content_to_send)
        model = bot.get_model_in_config('vision_model_name', channel)

    raw_answer = await blm_lib.chat_flow(
        prompt=content_to_send,
        channel_id=channel,
        model=model
    )
    return Chain(data).text(raw_answer)


@bot.on_message(verify=check_talk, check_prefix=False, allow_direct=True)
async def _(data: Message):
    """Route an accepted message to the handler for the channel's mode.

    The mode is read from the channel config, but a message starting with one
    of the question prefixes always uses the one-shot question mode. For all
    other modes a handler object is cached per channel and replaced when the
    configured mode no longer matches the cached handler's type.

    Args:
        data: The incoming message.

    Returns:
        A ``Chain`` reply in question mode; ``None`` otherwise (the mode
        handlers receive the message through ``on_message``).
    """
    # Indexing raises KeyError for a missing plugin, so the lookup is guarded
    # to let the "library not loaded" path below actually run.
    try:
        blm_lib: Optional[BLMAdapter] = main_bot.plugins['amiyabot-blm-library']
    except KeyError:
        blm_lib = None
    if blm_lib is None:
        bot.debug_log("未加载blm库,无法使用ChatGPT")
        return

    if not data.text and not data.image:
        return

    # Messages without a channel are keyed by their sender instead.
    channel = data.channel_id
    if channel is None:
        channel = f'User:{data.user_id}'

    # Pre-bind `mode` so a failing config read falls back to the default
    # handler instead of raising UnboundLocalError further down.
    mode = None
    try:
        mode = bot.get_config('mode', channel)
    except Exception as e:
        bot.debug_log(
            f'Unknown Error {e} \n {traceback.format_exc()}')

    original_text = data.text_original or ''
    if original_text.upper().startswith(_ASK_PREFIXES):
        mode = "请问模式"

    bot.debug_log(f'[{channel:<10}] 模式:{mode} 消息:{data.text_original}')

    if mode == "请问模式":
        return await _answer_question(data, blm_lib, channel)

    handler_cls, needs_instance = _select_mode_handler(mode)
    try:
        context = channel_hander_context.get(channel)
        # A missing entry or a handler left over from another mode is
        # replaced by a fresh handler of the required type.
        if not isinstance(context, handler_cls):
            if needs_instance:
                context = handler_cls(bot, blm_lib, channel, data.instance)
            else:
                context = handler_cls(bot, blm_lib, channel)
            channel_hander_context[channel] = context
    except Exception as e:
        log.error(e)
        return

    await context.on_message(data)