Skip to content

Commit

Permalink
Merge pull request #94 from akrherz/gh93_atproto_thread_safety
Browse files Browse the repository at this point in the history
🐛 Take swing at thread safety for AT client
  • Loading branch information
akrherz authored Nov 15, 2024
2 parents 0418565 + a9101cb commit 0e2da09
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 67 deletions.
104 changes: 104 additions & 0 deletions src/iembot/atworker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""The ATmosphere Worker.
Some hackery is needed to work around thread-safety issues with the
atproto module, and around my general lack of understanding of how to
do thread safety in Python. So this is on me.
"""

import threading
from queue import Queue

import httpx
from atproto import Client
from atproto_client.utils import TextBuilder
from twisted.python import log


class ATWorkerThead(threading.Thread):
    """Worker thread that posts queued messages for a single AT handle.

    Messages are dicts pulled from ``queue``; each one is posted through
    this thread's own ``Client`` instance.  Putting ``None`` on the queue
    asks the thread to exit.
    """

    def __init__(self, queue: Queue, at_handle: str, at_password: str):
        """Store the work queue and credentials; the thread is not started.

        Args:
            queue: Queue of message dicts (or ``None`` to stop the worker).
            at_handle: Handle used to log in.
            at_password: Password passed to ``Client.login``.
        """
        super().__init__()
        self.queue = queue
        self.at_handle = at_handle
        self.at_password = at_password
        self.client = Client()
        self.daemon = True  # Don't block on shutdown

    def run(self):
        """Consume messages from the queue until a ``None`` sentinel arrives.

        Every item taken from the queue, including the sentinel, is
        acknowledged with ``task_done()`` so that ``queue.join()`` can
        return.  A failure while processing one message is logged and does
        not stop the loop.
        """
        while True:
            message = self.queue.get()
            try:
                if message is None:
                    break
                self.process_message(message)
            except Exception as exp:
                log.msg(f"AT worker {self.at_handle} failed on {message!r}")
                log.err(exp)
                self._invalidate_session()
            finally:
                # Runs on the sentinel ``break`` as well as on failures.
                self.queue.task_done()

    def _invalidate_session(self):
        """Drop the client's ``session`` attribute to force a fresh login."""
        # NOTE(review): relies on the atproto Client exposing a deletable
        # ``session`` attribute -- confirm against the installed version.
        if not hasattr(self.client, "session"):
            return
        try:
            delattr(self.client, "session")
        except AttributeError as exp:
            # An escaping exception here would kill the worker thread and
            # leave the queue undrained, so log it and carry on.
            log.err(exp)

    @staticmethod
    def _fetch_image(media):
        """Download ``media`` and return its bytes, or None if unusable.

        Returns None when ``media`` is None, when the download fails, or
        when the payload exceeds 1,000,000 bytes.
        """
        if media is None:
            return None
        try:
            resp = httpx.get(media, timeout=30)
            resp.raise_for_status()
            img = resp.content
        except Exception as exp:
            # Best effort: a failed download just means posting text only.
            log.err(exp)
            return None
        # AT has a size limit of 976.56KB
        if len(img) > 1_000_000:
            log.msg(f"{media} is too large({len(img)}) for AT")
            return None
        return img

    @staticmethod
    def _split_link(msg: str):
        """Split ``msg`` at the first "http" into ``(text, url)``.

        ``url`` is None when the message contains no "http".  Only the
        first occurrence splits the message, so a URL that itself contains
        "http" (for example in a query string) is kept intact.
        """
        text, marker, rest = msg.partition("http")
        if not marker:
            return msg, None
        return text, f"{marker}{rest}"

    def process_message(self, msgdict: dict):
        """Post one message, attaching its image when one can be fetched.

        Args:
            msgdict: Must contain ``"msg"`` (the text); may contain
                ``"twitter_media"`` (URL of an image to attach).

        Raises:
            KeyError: If ``msgdict`` has no ``"msg"`` entry.
            Exception: Whatever login or posting raises; ``run`` handles it.
        """
        img = self._fetch_image(msgdict.get("twitter_media"))

        # Do we need to login?
        if not hasattr(self.client, "session"):
            log.msg(f"Logging in as {self.at_handle}...")
            self.client.login(self.at_handle, self.at_password)

        text, url = self._split_link(msgdict["msg"])
        if url is None:
            msg = text
        else:
            # Show the leading text followed by a "Link" facet for the URL.
            msg = TextBuilder().text(text).link("Link", url)

        if img:
            res = self.client.send_image(
                msg, image=img, image_alt="IEMBot Image TBD"
            )
        else:
            res = self.client.send_post(msg)
        # for now
        log.msg(repr(res))


class ATManager:
    """Registry of per-handle worker threads, guarded by a lock.

    Each AT handle gets one ``ATWorkerThead`` with its own queue; messages
    are handed to a worker by putting them on that queue.
    """

    def __init__(self):
        """Start with no workers registered."""
        # Maps at_handle -> the worker thread created for that handle.
        self.at_clients = {}
        self.lock = threading.Lock()

    def add_client(self, at_handle: str, at_password: str):
        """Create and start a worker for ``at_handle`` unless one exists.

        The membership check happens while holding the lock, so concurrent
        callers cannot both create a worker for the same handle.

        Args:
            at_handle: Handle the worker will post as.
            at_password: Password handed to the worker for logging in.
        """
        with self.lock:
            if at_handle in self.at_clients:
                return
            worker = ATWorkerThead(Queue(), at_handle, at_password)
            self.at_clients[at_handle] = worker
            worker.start()

    def submit(self, at_handle: str, message: dict):
        """Queue ``message`` for the worker registered under ``at_handle``.

        Raises:
            KeyError: If no worker was added for ``at_handle``.
        """
        self.at_clients[at_handle].queue.put(message)
17 changes: 4 additions & 13 deletions src/iembot/basicbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from twisted.words.xish import domish, xpath

import iembot.util as botutil
from iembot.atworker import ATManager

DATADIR = os.sep.join([os.path.dirname(__file__), "data"])
ROOM_LOG_ENTRY = namedtuple(
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(
self.chatlog = {}
self.seqnum = 0
self.routingtable = {}
self.at_clients = {} # handle -> Client
self.at_manager = ATManager()
self.tw_users = {} # Storage by user_id => {screen_name: ..., oauth:}
self.tw_routingtable = {} # Storage by channel => [user_id, ]
# Storage by user_id => {access_token: ..., api_base_url: ...}
Expand Down Expand Up @@ -364,18 +365,8 @@ def tweet(self, user_id, twttxt, **kwargs):
Tweet a message
"""
twttxt = botutil.safe_twitter_text(twttxt)
adf = threads.deferToThread(
botutil.at_send_message,
self,
user_id,
twttxt,
**kwargs,
)
adf.addErrback(
botutil.email_error,
self,
f"User: {user_id}, Text: {twttxt} Hit double exception",
)
botutil.at_send_message(self, user_id, twttxt, **kwargs)

df = threads.deferToThread(
botutil.tweet,
self,
Expand Down
60 changes: 6 additions & 54 deletions src/iembot/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from io import BytesIO
from zoneinfo import ZoneInfo

import atproto
import httpx
import mastodon
import requests
import twitter
Expand Down Expand Up @@ -46,57 +44,10 @@ def at_send_message(bot, user_id, msg: str, **kwargs):
"""Send a message to the ATmosphere."""
at_handle = bot.tw_users.get(user_id, {}).get("at_handle")
if at_handle is None:
return None
media = kwargs.get("twitter_media")
img = None
if media is not None:
try:
resp = httpx.get(media, timeout=30)
resp.raise_for_status()
img = resp.content
# AT has a size limit of 976.56KB
if len(img) > 1_000_000:
img = None
except Exception as exp:
log.err(exp)

if at_handle not in bot.at_clients:
# This is racey, so we need to not add the client until we are
# sure it is logged in
client = atproto.Client()
client.login(
at_handle,
bot.tw_users[user_id]["at_app_pass"],
)
# Again, racey, so we need to check again
if at_handle not in bot.at_clients:
bot.at_clients[at_handle] = client

if msg.find("http") > -1:
parts = msg.split("http")
msg = (
atproto.client_utils.TextBuilder()
.text(parts[0])
.link("link", f"http{parts[1]}")
)

for attempt in range(1, 4):
try:
if img:
res = bot.at_clients[at_handle].send_image(
msg, image=img, image_alt="IEMBot Image TBD"
)
else:
res = bot.at_clients[at_handle].send_post(msg)
break
except Exception as exp:
log.err(exp)
time.sleep(attempt * 5)
if attempt == 3:
raise exp
# for now
log.msg(repr(res))
return res
return
message = {"msg": msg}
message.update(kwargs)
bot.at_manager.submit(at_handle, message)


def tweet(bot, user_id, twttxt, **kwargs):
Expand Down Expand Up @@ -733,8 +684,9 @@ def load_twitter_from_db(txn, bot):
"access_token_secret": row["access_token_secret"],
"iem_owned": row["iem_owned"],
"at_handle": row["at_handle"],
"at_app_pass": row["at_app_pass"],
}
if row["at_handle"]:
bot.at_manager.add_client(row["at_handle"], row["at_app_pass"])
bot.tw_users = twusers
log.msg(f"load_twitter_from_db(): {txn.rowcount} oauth tokens found")

Expand Down

0 comments on commit 0e2da09

Please sign in to comment.