Skip to content

Commit

Permalink
Add zstd gateway compression to speed profile
Browse files Browse the repository at this point in the history
  • Loading branch information
LostLuma committed Sep 25, 2024
1 parent 59f877f commit d5db927
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 29 deletions.
23 changes: 11 additions & 12 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations

import asyncio
Expand All @@ -32,7 +33,6 @@
import time
import threading
import traceback
import zlib

from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple

Expand Down Expand Up @@ -325,8 +325,7 @@ def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.Abs
# ws related stuff
self.session_id: Optional[str] = None
self.sequence: Optional[int] = None
self._zlib: zlib._Decompress = zlib.decompressobj()
self._buffer: bytearray = bytearray()
self._decompressor = utils._DecompressionContext()
self._close_code: Optional[int] = None
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()

Expand Down Expand Up @@ -355,7 +354,7 @@ async def from_client(
sequence: Optional[int] = None,
resume: bool = False,
encoding: str = 'json',
zlib: bool = True,
compress: bool = True,
) -> Self:
"""Creates a main websocket for Discord from a :class:`Client`.
Expand All @@ -366,10 +365,12 @@ async def from_client(

gateway = gateway or cls.DEFAULT_GATEWAY

if zlib:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream')
else:
if not compress:
url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
else:
url = gateway.with_query(
v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._DecompressionContext.STREAM_TYPE
)

socket = await client.http.ws_connect(str(url))
ws = cls(socket, loop=client.loop)
Expand Down Expand Up @@ -488,13 +489,11 @@ async def resume(self) -> None:

async def received_message(self, msg: Any, /) -> None:
if type(msg) is bytes:
self._buffer.extend(msg)
msg = self._decompressor.decompress(msg)

if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff':
# zlib and didn't end in Z_SYNC_FLUSH
if msg is None:
return
msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8')
self._buffer = bytearray()

self.log_receive(msg)
msg = utils._from_json(msg)
Expand Down
19 changes: 2 additions & 17 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2628,28 +2628,13 @@ def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[mes

# Misc

async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str:
try:
data = await self.request(Route('GET', '/gateway'))
except HTTPException as exc:
raise GatewayNotFound() from exc
if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return value.format(data['url'], encoding, INTERNAL_API_VERSION)

async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]:
async def get_bot_gateway(self) -> Tuple[int, str]:
try:
data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc:
raise GatewayNotFound() from exc

if zlib:
value = '{0}?encoding={1}&v={2}&compress=zlib-stream'
else:
value = '{0}?encoding={1}&v={2}'
return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION)
return data['shards'], data['url']

def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
48 changes: 48 additions & 0 deletions discord/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations

import array
Expand Down Expand Up @@ -71,6 +72,7 @@
import typing
import warnings
import logging
import zlib

import yarl

Expand All @@ -81,6 +83,12 @@
else:
HAS_ORJSON = True

try:
import zstandard # type: ignore
except ImportError:
HAS_ZSTD = False
else:
HAS_ZSTD = True

__all__ = (
'oauth_url',
Expand Down Expand Up @@ -1406,3 +1414,43 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o
return f'{seq[0]} {final} {seq[1]}'

return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}'


if HAS_ZSTD:

class _ZstdDecompressionContext:
__slots__ = ('context',)

STREAM_TYPE = 'zstd-stream'

def __init__(self) -> None:
decompressor = zstandard.ZstdDecompressor()
self.context = decompressor.decompressobj()

def decompress(self, data: bytes, /) -> str | None:
return self.context.decompress(data).decode('utf-8')

_DecompressionContext = _ZstdDecompressionContext
else:

class _ZlibDecompressionContext:
__slots__ = ('context', 'buffer')

STREAM_TYPE = 'zlib-stream'

def __init__(self) -> None:
self.buffer: bytearray = bytearray()
self.context = zlib.decompressobj()

def decompress(self, data: bytes, /) -> str | None:
self.buffer.extend(data)

if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff':
return

msg = self.context.decompress(self.buffer)
msg = msg.decode('utf-8')
self.buffer = bytearray()
return msg

_DecompressionContext = _ZlibDecompressionContext
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ speed = [
"aiodns>=1.1; sys_platform != 'win32'",
"Brotli",
"cchardet==2.1.7; python_version < '3.10'",
"zstandard>=0.23.0"
]
test = [
"coverage[toml]",
Expand Down

0 comments on commit d5db927

Please sign in to comment.