-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbot.py
305 lines (249 loc) · 10.2 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from functools import cache
from discord import (
Activity, ActivityType, AllowedMentions, ChannelType, Message, Intents, User
)
import discord
from discord.ext import commands
from discord.ext import tasks
import sqlite3
import asyncio_atexit
import config
import os
import logging
import aiohttp
from async_lru import alru_cache
from utils import ParrotMarkov, regex, tag
from database.corpus_manager import CorpusManager
from database.avatar_manager import AvatarManager
class Parrot(commands.AutoShardedBot):
    """
    Auto-sharded Discord bot backed by a SQLite database.

    The database holds users, channels, messages and per-guild imitation
    settings. Parrot learns from messages posted in channels where learning
    is enabled and builds a Markov model per user from that user's corpus.
    """

    def __init__(
        self, *,
        prefix: str,
        db_path: str,
        admin_user_ids: list[int],
        admin_role_ids: list[int] | None = None,
    ):
        """
        Configure the bot, open the database and create any missing tables.

        prefix: command prefix; also shown in the bot's activity text.
        db_path: path handed to sqlite3.connect().
        admin_user_ids: passed to discord.py as the bot's owner_ids.
        admin_role_ids: stored on self.admin_role_ids; an empty list when
            omitted.
        """
        # Both destructors (__del__ and _async__del__) read this flag.
        self.destructor_called = False
        logging.info(f"discord.py v{discord.__version__}")

        intents = Intents.default()
        intents.message_content = True  # For learning
        # The members intent follows config.ENABLE_IMITATE_SOMEONE.
        intents.members = config.ENABLE_IMITATE_SOMEONE

        super().__init__(
            command_prefix=prefix,
            owner_ids=admin_user_ids,
            case_insensitive=True,
            allowed_mentions=AllowedMentions.none(),
            activity=Activity(
                name=f"everyone ({prefix}help)",
                type=ActivityType.listening,
            ),
            intents=intents,
        )

        self.admin_role_ids = admin_role_ids or []
        # Flipped to True by the first on_ready event.
        self.finished_initializing = False

        self.con = sqlite3.connect(db_path)
        self.db = self.con.cursor()
        # String defaults are written with single quotes: SQLite accepts
        # double-quoted string literals only as a legacy quirk that can be
        # disabled at build time. Existing tables are unaffected because of
        # IF NOT EXISTS.
        self.db.executescript(
            """
            BEGIN;
            CREATE TABLE IF NOT EXISTS users (
                id                         INTEGER PRIMARY KEY,
                is_registered              INTEGER NOT NULL DEFAULT 0,
                original_avatar_url        TEXT,
                modified_avatar_url        TEXT,
                modified_avatar_message_id INTEGER
            );
            CREATE TABLE IF NOT EXISTS channels (
                id             INTEGER PRIMARY KEY,
                can_speak_here INTEGER NOT NULL DEFAULT 0,
                can_learn_here INTEGER NOT NULL DEFAULT 0,
                webhook_id     INTEGER
            );
            CREATE TABLE IF NOT EXISTS messages (
                id        INTEGER PRIMARY KEY,
                user_id   INTEGER NOT NULL REFERENCES users(id),
                timestamp INTEGER NOT NULL,
                content   TEXT    NOT NULL
            );
            CREATE TABLE IF NOT EXISTS guilds (
                id               INTEGER PRIMARY KEY,
                imitation_prefix TEXT NOT NULL DEFAULT 'Not ',
                imitation_suffix TEXT NOT NULL DEFAULT ''
            );
            COMMIT;
            """
        )

        # Prime the in-memory caches that validate_message() and friends read.
        self.update_learning_channels()
        self.update_speaking_channels()
        self.update_registered_users()

    async def _async__del__(self) -> None:
        """
        Asynchronous shutdown routine; runs at most once.

        Stops the periodic autosave, closes the bot, commits the database one
        last time and closes the HTTP session created in setup_hook().
        """
        if self.destructor_called:
            return
        self.destructor_called = True

        logging.info("Parrot shutting down...")
        self.autosave.cancel()
        await self.close()
        # One final save now that the periodic task has been cancelled.
        await self.autosave()

        logging.info("Closing HTTP session...")
        await self.http_session.close()
        logging.info("HTTP session closed.")

    def __del__(self):
        """
        Synchronous fallback that drives _async__del__() on the bot's loop
        if the asynchronous shutdown has not already run.
        """
        if self.destructor_called:
            return
        self.loop.run_until_complete(self._async__del__())

    # NOTE(review): Cog.listener() is intended for methods of a Cog; this
    # class is a Bot subclass, so the decorator looks redundant here. Kept
    # as-is -- confirm before removing.
    @commands.Cog.listener()
    async def on_ready(self) -> None:
        """ Log the login, distinguishing the first one from reconnects. """
        # on_ready also fires when the bot regains connection.
        if self.finished_initializing:
            logging.info("Logged back in.")
        else:
            logging.info(f"Logged in as {tag(self.user)}")
            self.finished_initializing = True

    async def setup_hook(self) -> None:
        """ Constructor Part 2: Enter Async """
        self.http_session = aiohttp.ClientSession(loop=self.loop)

        # Parrot has to do async stuff as part of its destructor, so it can't
        # actually use __del__, which is strictly synchronous. So we have to
        # reinvent a little bit of the wheel and manually set a function to run
        # when Parrot is about to be destroyed -- except instead we'll do it
        # when the event loop is about to be closed.
        asyncio_atexit.register(self._async__del__, loop=self.loop)

        self.corpora = CorpusManager(
            db=self.db,
            get_registered_users=self.get_registered_users,
            command_prefix=self.command_prefix,
        )
        self.avatars = AvatarManager(
            loop=self.loop,
            db=self.db,
            http_session=self.http_session,
            fetch_channel=self.fetch_channel,
        )

        self.autosave.start()

        await self.load_extension("jishaku")
        await self.load_folder("events")
        await self.load_folder("commands")

    async def load_folder(self, folder_name: str) -> None:
        """
        Load every file directly inside folder_name as an extension.

        Each regular file (subdirectories are skipped) is loaded as
        "<folder_name>.<file name without extension>". A failure to load one
        extension is logged and does not stop the remaining ones.
        """
        module_names = [
            os.path.splitext(filename)[0]
            for filename in os.listdir(folder_name)
            if os.path.isfile(os.path.join(folder_name, filename))
        ]

        for module in module_names:
            path = f"{folder_name}.{module}"
            try:
                logging.info(f"Loading {path}... ")
                await self.load_extension(path)
                logging.info("✅")
            except Exception as error:
                # Deliberately broad: one broken extension must not keep the
                # rest from loading.
                logging.info("❌")
                logging.error(f"{error}\n")

    @tasks.loop(seconds=config.AUTOSAVE_INTERVAL_SECONDS)
    async def autosave(self) -> None:
        """ Commit pending database changes; runs on a fixed interval. """
        logging.info("Saving database...")
        self.con.commit()
        logging.info("Save complete.")

    @alru_cache(maxsize=int(config.MODEL_CACHE_SIZE))
    async def get_model(self, user: User) -> ParrotMarkov:
        """
        Get the Markov model for a user, built from that user's corpus.
        Results are kept in an LRU cache of config.MODEL_CACHE_SIZE entries.
        """
        corpus = self.corpora.get(user)
        return await ParrotMarkov.new(corpus)

    def validate_message(self, message: Message) -> bool:
        """
        A message must pass all of these checks before Parrot can learn from it.
        """
        content = message.content
        # bool() because the and/or chain would otherwise hand back whatever
        # operand ended it (e.g. None from a failed regex match).
        return bool(
            # Text content not empty.
            content and
            # Not a Parrot command.
            not content.startswith(self.command_prefix) and
            # Only learn in text channels, not DMs.
            message.channel.type == ChannelType.text and
            # Most bots' commands start with non-alphanumeric characters, so if
            # a message starts with one other than a known Markdown character or
            # special Discord character, Parrot should just avoid it because
            # it's probably a command.
            (
                content[0].isalnum() or
                regex.discord_string_start.match(content[0]) or
                regex.markdown.match(content[0])
            ) and
            # Don't learn from self.
            message.author.id != self.user.id and
            # Don't learn from Webhooks.
            not message.webhook_id and
            # Parrot must be allowed to learn in this channel.
            message.channel.id in self.learning_channels and
            # People will often say "v" or "z" on accident while spamming,
            # and it doesn't really make for good learning material.
            content not in ("v", "z")
        )

    def learn_from(self, messages: Message | list[Message]) -> int:
        """
        Add a Message or list of Messages to a user's corpus.
        Every Message in the list must be from the same user.

        Returns the number of messages added, as reported by the corpus
        manager, or 0 when nothing was given or nothing passed
        validate_message(). Raises ValueError if the messages do not all
        share one author.
        """
        # Ensure that messages is a list.
        # If it's not, make it a list with one value.
        if not isinstance(messages, list):
            messages = [messages]

        # Nothing to learn from; also avoids indexing into an empty list.
        if not messages:
            return 0

        user = messages[0].author

        # Every message in the list must have the same author, because the
        # Corpus Manager adds every message passed to it to the same user.
        if any(message.author != user for message in messages):
            raise ValueError(
                "Too many authors; every message in a list passed to "
                "learn_from() must have the same author."
            )

        # Only keep messages that pass all of validate_message()'s checks.
        learnable = [
            message for message in messages if self.validate_message(message)
        ]
        if not learnable:
            return 0

        # Add these messages to this user's corpus and return the number of
        # messages that were added.
        return self.corpora.add(user, learnable)

    def _select_id_set(self, query: str) -> set[int]:
        """ Run a query and return the first column of every row as a set. """
        rows = self.db.execute(query).fetchall()
        return {row[0] for row in rows}

    def update_learning_channels(self) -> None:
        """ Fetch and cache the set of channels that Parrot can learn from. """
        self.learning_channels = self._select_id_set(
            "SELECT id FROM channels WHERE can_learn_here = 1"
        )

    def update_speaking_channels(self) -> None:
        """ Fetch and cache the set of channels that Parrot can speak in. """
        self.speaking_channels = self._select_id_set(
            "SELECT id FROM channels WHERE can_speak_here = 1"
        )

    def update_registered_users(self) -> None:
        """ Fetch and cache the set of users who are registered. """
        self.registered_users = self._select_id_set(
            "SELECT id FROM users WHERE is_registered = 1"
        )

    def get_registered_users(self) -> set[int]:
        """
        Return the cached set of registered user IDs.
        Handed to CorpusManager as a callback in setup_hook().
        """
        return self.registered_users

    # NOTE(review): functools.cache on a method keys on `self` and never
    # evicts; a changed guild row is not seen until the cache is cleared.
    # Left in place because callers outside this class may depend on the
    # wrapper's cache_clear() -- confirm before replacing.
    @cache
    def get_guild_prefix_suffix(self, guild_id: int) -> tuple[str, str]:
        """
        Return the (imitation_prefix, imitation_suffix) pair for a guild.
        Falls back to ("Not ", "") when the guild has no row.
        """
        row = self.db.execute(
            """
            SELECT imitation_prefix, imitation_suffix
            FROM guilds
            WHERE id = ?
            """,
            (guild_id,)
        ).fetchone()
        if row is None:
            return "Not ", ""
        return row

    def find_text(self, message: Message) -> str:
        """
        Search for text within a message.
        Return an empty string if no text is found.

        The message's own content is used unless it is empty or starts with
        the command prefix; every non-empty embed description is appended.
        The pieces are joined with single spaces.
        """
        pieces = []
        content = message.content
        if content and not content.startswith(self.command_prefix):
            pieces.append(content)

        pieces.extend(
            embed.description
            for embed in message.embeds
            if isinstance(embed.description, str) and embed.description
        )
        return " ".join(pieces)