add net_messages benchtest #46

Open · wants to merge 11 commits into base: dev
24 changes: 20 additions & 4 deletions src/lifeblood/net_messages/impl/message_protocol.py
@@ -66,7 +66,21 @@ async def connection_callback(self, reader: asyncio.StreamReader, writer: asynci
potentially_pending_tasks = []
# first what's sent is return address
try:
other_stream_source = await self.read_string(reader)
try:
init_message_waiter = asyncio.create_task(self.read_string(reader))
done, _ = await asyncio.wait([init_message_waiter, stop_waiter], return_when=asyncio.FIRST_COMPLETED)
if stop_waiter in done:
self.__logger.debug('explicitly asked to stop')
if init_message_waiter.done():
init_message_waiter.result() # this will re-raise exceptions if any
init_message_waiter.cancel()
stop_waiter = None
return
# otherwise it must be init_message_waiter that is done
assert init_message_waiter in done
other_stream_source = await init_message_waiter
except EOFError as e:
raise MessageReceivingError("failed to initialize message read", wrapped_exception=e) from None

message_stream = MessageReceiveStream(reader, writer,
this_address=DirectAddress(':'.join(str(x) for x in self.__listening_address)),
@@ -100,13 +114,15 @@ async def connection_callback(self, reader: asyncio.StreamReader, writer: asynci
await message_stream.acknowledge_received_message(success)
except MessageTransferError as mre:
e = mre.wrapped_exception()
msg = mre.args[0] if mre.args else None
msg_suffix = (f': ({msg})' if msg else '')
if isinstance(e, asyncio.exceptions.IncompleteReadError):
if len(e.partial) == 0:
self.__logger.debug('read 0 bytes, connection closed')
self.__logger.debug('read 0 bytes, connection closed %s', msg_suffix)
else:
self.__logger.error(f'read incomplete {len(e.partial)} bytes')
self.__logger.error(f'read incomplete {len(e.partial)} bytes %s', msg_suffix)
elif isinstance(e, asyncio.exceptions.TimeoutError):
self.__logger.warning(f'connection timeout happened')
self.__logger.warning('connection timeout happened')
elif isinstance(e, ConnectionResetError):
self.__logger.error('connection was reset. disconnected %s', e)
elif isinstance(e, ConnectionError):
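For reference, the change above swaps a bare await of the initial read for a race between that read and the protocol's stop signal, so shutdown no longer hangs on a peer that never sends its return address. Below is a minimal standalone sketch of the same asyncio.wait(..., return_when=FIRST_COMPLETED) pattern (illustrative only, not part of this PR; read_or_stop, the readline() stand-in, and stop_event are made-up names):

import asyncio
from typing import Optional


async def read_or_stop(reader: asyncio.StreamReader, stop_event: asyncio.Event) -> Optional[bytes]:
    """Race the first read against a stop signal; return None if asked to stop."""
    read_task = asyncio.create_task(reader.readline())  # stand-in for read_string()
    stop_task = asyncio.create_task(stop_event.wait())
    done, _ = await asyncio.wait({read_task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
    if stop_task in done:
        if read_task.done():
            read_task.result()  # re-raise any read error before bailing out
        read_task.cancel()
        return None
    stop_task.cancel()
    return await read_task  # read finished first; propagate its result or exception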
40 changes: 30 additions & 10 deletions src/lifeblood/net_messages/impl/tcp_message_stream_factory.py
@@ -3,7 +3,7 @@
import logging
from lifeblood.logging import get_logger
from datetime import datetime
from dataclasses import dataclass
from dataclasses import dataclass, field
from ..exceptions import MessageTransferError, MessageTransferTimeoutError
from ..interfaces import MessageStreamFactory
from ..stream_wrappers import MessageSendStream, MessageSendStreamBase
@@ -30,9 +30,18 @@ class ConnectionPoolEntry:
writer: asyncio.StreamWriter
last_used: datetime
users_count: int
last_ping_time: datetime = field(default_factory=lambda: datetime.now())
close_when_user_count_zero: bool = False
bad: bool = False

def check_bad(self) -> bool:
"""
check and update "bad" status of the entry, return updated bad value
"""
if not self.bad and (self.reader.at_eof() or self.writer.is_closing()):
self.bad = True
return self.bad


class ReusableMessageSendStream(MessageSendStream):
"""
@@ -59,7 +68,8 @@ def __init__(self,
async def send_raw_message(self, message: Message, *, message_delivery_timeout_override: Optional[float] = ...):
try:
return await super().send_raw_message(message, message_delivery_timeout_override=message_delivery_timeout_override)
except MessageTransferTimeoutError:
except MessageTransferError:
# even for MessageTransferTimeoutError:
# we cannot be sure some crap won't arrive after timeout,
# in that case future uses of this connection will be at risk of getting it,
# so it's safer to mark it for closure
@@ -91,15 +101,21 @@ class TcpMessageStreamPooledFactory(MessageStreamFactory):
_logger: Optional[logging.Logger] = None

def __init__(self,
pooled_connection_life: int = 0,
pooled_connection_life: float = 0,
connection_open_function: Optional[Callable[[DirectAddress, DirectAddress], Awaitable[Tuple[asyncio.StreamReader, asyncio.StreamWriter]]]] = None,
timeout: float = default_stream_timeout):
timeout: float = default_stream_timeout,
minimal_reping_interval: Optional[float] = None):
self.__pooled_connection_life = pooled_connection_life
self.__pool: Dict[Tuple[str, int], List[ConnectionPoolEntry]] = {}
self.__connection_open_func: Callable[[DirectAddress, DirectAddress], Awaitable[Tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = connection_open_function or _initialize_connection
self.__open_connection_calls_count = 0
self.__pool_closed = asyncio.Event()
self.__timeout = timeout
# below is an arbitrary heuristic
if minimal_reping_interval is None:
self.__minimal_reping_interval = max(1, int(timeout/3))
else:
self.__minimal_reping_interval = minimal_reping_interval
if self._logger is None:
TcpMessageStreamPooledFactory._logger = get_logger('TcpMessageStreamPooledFactory')

@@ -115,7 +131,7 @@ async def close_unused_connections_older_than(self, older_than_this_seconds: flo
for key, entry_list in self.__pool.items():
new_entries = []
for entry in entry_list:
if not entry.bad \
if not entry.check_bad() \
and (entry.users_count > 0
or ((now - entry.last_used).total_seconds() < older_than_this_seconds
and not entry.close_when_user_count_zero)
@@ -151,9 +167,7 @@ def _get_cached_entry(self, host: str, port: int) -> Optional[ConnectionPoolEntr
for entry in entry_list:
if entry.users_count > 0 or entry.close_when_user_count_zero:
continue
if entry.bad or entry.reader.at_eof() or entry.writer.is_closing():
# to_remove.append(entry)
entry.bad = True
if entry.check_bad():
continue
selected = entry
# for entry in to_remove: # not the most optimal way......
@@ -173,7 +187,12 @@ async def open_sending_stream(self, destination: DirectAddress, source: DirectAd
stream_timeout=self.__timeout,
confirmation_timeout=self.__timeout)
try:
await stream.send_ping()
# this is a heuristic based on connection "freshness"
# "fresh" connections will most likely work, so extra ping will only slow things down
ping_now = datetime.now()
if (ping_now - entry.last_ping_time).total_seconds() >= self.__minimal_reping_interval:
await stream.send_ping()
entry.last_ping_time = ping_now
except MessageTransferError as e:
self._logger.debug('ping failed due to %s', e)
entry.bad = True
@@ -187,7 +206,8 @@ async def open_sending_stream(self, destination: DirectAddress, source: DirectAd
entry = ConnectionPoolEntry(reader,
writer,
datetime.now(),
0)
0,
datetime.now())
self.__pool.setdefault(key, []).append(entry)
assert entry is not None

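Two ideas from the pool changes above, restated as a standalone sketch: check_bad() folds the at_eof()/is_closing() probes into the entry itself, and last_ping_time plus minimal_reping_interval let a recently used connection be reused without an extra ping round-trip. Entry, send_ping, and the OSError catch below are simplified stand-ins for the project's classes and MessageTransferError, not the actual implementation:

import asyncio
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class Entry:
    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    last_ping_time: datetime = field(default_factory=datetime.now)
    bad: bool = False

    def check_bad(self) -> bool:
        # once the stream hits EOF or starts closing, the entry is permanently bad
        if not self.bad and (self.reader.at_eof() or self.writer.is_closing()):
            self.bad = True
        return self.bad


async def ensure_alive(entry: Entry, send_ping, min_reping_interval: float) -> bool:
    """Ping a pooled connection only if it has not been pinged recently."""
    if entry.check_bad():
        return False
    now = datetime.now()
    if (now - entry.last_ping_time).total_seconds() < min_reping_interval:
        return True  # fresh enough: skip the ping round-trip
    try:
        await send_ping()  # stand-in for MessageSendStream.send_ping()
    except OSError:        # stand-in for MessageTransferError
        entry.bad = True   # any transfer error poisons the entry for future reuse
        return False
    entry.last_ping_time = now
    return True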
178 changes: 178 additions & 0 deletions tests/tests_net_messages/test_benchmark.py
@@ -0,0 +1,178 @@
import os
import asyncio
import random
import string
import time
import threading
import multiprocessing
from unittest import IsolatedAsyncioTestCase, skip
from lifeblood.logging import get_logger, set_default_loglevel
from lifeblood.nethelpers import get_localhost
from lifeblood.net_messages.address import AddressChain, DirectAddress
from lifeblood.net_messages.messages import Message
from lifeblood.net_messages.client import MessageClient
from lifeblood.net_messages.exceptions import MessageSendingError, MessageTransferTimeoutError

from lifeblood.net_messages.impl.tcp_message_processor import TcpMessageProcessor, TcpMessageProxyProcessor

from typing import Callable, List, Type, Awaitable

set_default_loglevel('DEBUG')
logger = get_logger('message_test')


class NoopMessageServer(TcpMessageProcessor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_messages_count = 0

# async def new_message_received(self, message: Message) -> bool:
# self.test_messages_count += 1
# return True

async def process_message(self, message: Message, client: MessageClient):
self.test_messages_count += 1


class FooRunner:
def start(self):
raise NotImplementedError()

def stop(self):
raise NotImplementedError()

def join(self):
raise NotImplementedError()

def get_message_count(self) -> int:
raise NotImplementedError()


class ThreadedFoo(threading.Thread, FooRunner):
def __init__(self, server: NoopMessageServer):
super().__init__()
self.__stop = threading.Event()
self.__ready = threading.Event()
self.__server = server

def run(self):
try:
asyncio.run(self.async_run())
except:
logger.exception("runner had an exception:")
raise

def start(self):
super().start()
self.__ready.wait()

async def async_run(self):
await self.__server.start()
self.__ready.set()
while True:
await asyncio.sleep(1)
if self.__stop.is_set():
break

self.__server.stop()
await self.__server.wait_till_stops()

def stop(self):
# crude crude crude
self.__stop.set()

def get_message_count(self) -> int:
return self.__server.test_messages_count


class ProcessedFoo(FooRunner):
def __init__(self, server: NoopMessageServer):
super().__init__()
self.__server = server
ctx = multiprocessing.get_context('spawn')
self.__stop = ctx.Event()
self.__value = ctx.Value('i', -1)
self.__ready = ctx.Event()
self.__proc = ctx.Process(target=self.body)

def start(self):
self.__proc.start()
self.__ready.wait()

def body(self):
asyncio.run(self.async_run())

async def async_run(self):
print('another process started')
await self.__server.start()
self.__ready.set()
print('another process server started')
while True:
await asyncio.sleep(1)
if self.__stop.is_set():
break

self.__server.stop()
await self.__server.wait_till_stops()
self.__value.value = self.__server.test_messages_count

def stop(self):
self.__stop.set()

def join(self):
self.__proc.join()

def get_message_count(self) -> int:
return self.__value.value


class TestBenchmarkSendReceive(IsolatedAsyncioTestCase):
@skip("no reason to benchmark on slow machines")
async def test_threaded(self):
await self.helper_test(ThreadedFoo)

@skip("no reason to benchmark on slow machines")
async def test_proc(self):
await self.helper_test(ProcessedFoo)

async def helper_test(self, foo_factory: Callable[[NoopMessageServer], FooRunner]):
"""
runs 2 servers
starts X client asyncio coroutines against one of the servers; each sends Y messages.
average per-message time is then calculated
"""
data = ''.join(random.choice(string.ascii_letters) for _ in range(16000)).encode('latin1')
server1 = NoopMessageServer((get_localhost(), 28385))
server2 = NoopMessageServer((get_localhost(), 28386))
server1_runner = foo_factory(server1) # TODO: server1 is created in one loop, passed to another, this may cause problems with events
server1_runner.start()
await server2.start()
pure_send_time = 0.0

messages_per_client = 10
total_clients = 100

async def test_foo():
nonlocal pure_send_time
with server2.message_client(AddressChain(f'{get_localhost()}:28385')) as client: # type: MessageClient
beforesend = time.perf_counter()
for _ in range(messages_per_client):
await client.send_message(data)
pure_send_time += time.perf_counter() - beforesend

tasks = []
for _ in range(total_clients):
tasks.append(asyncio.create_task(test_foo()))

timestamp = time.perf_counter()
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
total_time = time.perf_counter() - timestamp
pure_send_time /= total_clients

server2.stop()
server1_runner.stop()
await server2.wait_till_stops()
server1_runner.join()
s1_message_count = server1_runner.get_message_count()
print(f'total got {s1_message_count} messages in {total_time}s (pure send: {pure_send_time}s), avg {s1_message_count/total_time} msg/s (pure: {s1_message_count/pure_send_time} msg/s)')
self.assertEqual(total_clients * messages_per_client, s1_message_count)
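The ThreadedFoo runner above boils down to a reusable pattern: run an asyncio service on its own event loop in a background thread, block start() until the service is up, and poll a stop flag once a second. A generic sketch of that pattern follows, with start_service/stop_service as hypothetical callables standing in for server.start() and server.stop()/wait_till_stops():

import asyncio
import threading
from typing import Awaitable, Callable


class LoopThread(threading.Thread):
    """Run an asyncio-based service in its own thread with ready/stop signalling."""

    def __init__(self,
                 start_service: Callable[[], Awaitable[None]],
                 stop_service: Callable[[], Awaitable[None]]):
        super().__init__()
        self._start_service = start_service
        self._stop_service = stop_service
        self._ready = threading.Event()
        self._stop = threading.Event()

    def run(self):
        asyncio.run(self._main())

    async def _main(self):
        await self._start_service()
        self._ready.set()
        while not self._stop.is_set():  # crude 1-second poll, same as the test runner
            await asyncio.sleep(1)
        await self._stop_service()

    def start(self):
        super().start()
        self._ready.wait()  # block until the service is actually listening

    def stop(self):
        self._stop.set()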
14 changes: 11 additions & 3 deletions tests/tests_net_messages/test_connection_pool_integration.py
@@ -120,15 +120,22 @@ def _create_protocol():
while srv is None:
await asyncio.sleep(0.1)

stream_stash = []
timeout = 2
pooled_factory = TcpMessageStreamPooledFactory(
timeout,
_fake_conn_opener_factory(stream_stash, _initialize_connection),
timeout=5,
minimal_reping_interval=0,
)

try:
stream_stash = []
timeout = 2
pooled_factory = TcpMessageStreamPooledFactory(timeout, _fake_conn_opener_factory(stream_stash, _initialize_connection), timeout=5)

addr = DirectAddress(f'{get_localhost()}:29361'), DirectAddress(f'{get_localhost()}:29360')

print('attempting connection')
for _ in range(3):
print('sending message')
stream0 = await pooled_factory.open_sending_stream(*addr)
stream0.close()
await stream0.wait_closed()
@@ -141,6 +148,7 @@
await prt.last_writer.wait_closed()

for _ in range(3):
print('sending message')
stream0 = await pooled_factory.open_sending_stream(*addr)
await stream0.send_data_message(b'foo', addr[0], session=uuid.uuid4())
stream0.close()
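The test above constructs the pooled factory with minimal_reping_interval=0 so every reuse still goes through a ping, and injects _fake_conn_opener_factory so opened streams end up in stream_stash for inspection. Assuming that helper simply wraps the real opener and records each (reader, writer) pair, the sketch below shows one way such a stashing opener could look; it is a guess at the helper's shape, not its actual code:

import asyncio
from typing import Awaitable, Callable, List, Tuple

StreamPair = Tuple[asyncio.StreamReader, asyncio.StreamWriter]
Opener = Callable[..., Awaitable[StreamPair]]


def stashing_opener_factory(stash: List[StreamPair], real_opener: Opener) -> Opener:
    """Wrap a connection-open coroutine so every opened stream pair is recorded."""
    async def _open(*args, **kwargs) -> StreamPair:
        reader, writer = await real_opener(*args, **kwargs)
        stash.append((reader, writer))  # lets the test count and later close connections
        return reader, writer
    return _open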
4 changes: 2 additions & 2 deletions tests/tests_net_messages/test_messageprocessor_integration.py
@@ -19,7 +19,7 @@


class TestReceiver(TcpMessageProcessor):
def __init__(self, listening_host: str, listening_port: int, *, backlog=None, artificial_delay: float = 0, stream_timeout=None, default_client_retry_attempts=None):
def __init__(self, listening_host: str, listening_port: int, *, backlog=None, artificial_delay: float = 0, stream_timeout=120, default_client_retry_attempts=None):
super().__init__((listening_host, listening_port), backlog=backlog, stream_timeout=stream_timeout, default_client_retry_attempts=default_client_retry_attempts)
self.messages_received: List[Message] = []
self.__artificial_delay: float = artificial_delay
@@ -54,7 +54,7 @@ async def process_message(self, message: Message, client: MessageClient):
class DummyReceiverWithReply(TestReceiver):
_counter = 0

def __init__(self, listening_host: str, listening_port: int, *, backlog=None, artificial_delay: float = 0, stream_timeout=None, default_client_retry_attempts=None):
def __init__(self, listening_host: str, listening_port: int, *, backlog=None, artificial_delay: float = 0, stream_timeout=120, default_client_retry_attempts=None):
super().__init__(listening_host=listening_host,
listening_port=listening_port,
backlog=backlog,