diff --git a/discord/gateway.py b/discord/gateway.py index b8936bf5708b..a6708af97376 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -21,6 +21,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import asyncio @@ -32,7 +33,6 @@ import time import threading import traceback -import zlib from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple @@ -325,8 +325,7 @@ def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.Abs # ws related stuff self.session_id: Optional[str] = None self.sequence: Optional[int] = None - self._zlib: zlib._Decompress = zlib.decompressobj() - self._buffer: bytearray = bytearray() + self._decompressor = utils._DecompressionContext() self._close_code: Optional[int] = None self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() @@ -355,7 +354,7 @@ async def from_client( sequence: Optional[int] = None, resume: bool = False, encoding: str = 'json', - zlib: bool = True, + compress: bool = True, ) -> Self: """Creates a main websocket for Discord from a :class:`Client`. @@ -366,10 +365,12 @@ async def from_client( gateway = gateway or cls.DEFAULT_GATEWAY - if zlib: - url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding, compress='zlib-stream') - else: + if not compress: url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding) + else: + url = gateway.with_query( + v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._DecompressionContext.STREAM_TYPE + ) socket = await client.http.ws_connect(str(url)) ws = cls(socket, loop=client.loop) @@ -488,13 +489,11 @@ async def resume(self) -> None: async def received_message(self, msg: Any, /) -> None: if type(msg) is bytes: - self._buffer.extend(msg) + msg = self._decompressor.decompress(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + # zlib and didn't end in Z_SYNC_FLUSH + if msg is None: return - msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') - self._buffer = bytearray() self.log_receive(msg) msg = utils._from_json(msg) diff --git a/discord/http.py b/discord/http.py index 6230f9b1da16..1bd29d7b1d82 100644 --- a/discord/http.py +++ b/discord/http.py @@ -2628,28 +2628,13 @@ def end_poll(self, channel_id: Snowflake, message_id: Snowflake) -> Response[mes # Misc - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: - try: - data = await self.request(Route('GET', '/gateway')) - except HTTPException as exc: - raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return value.format(data['url'], encoding, INTERNAL_API_VERSION) - - async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: + async def get_bot_gateway(self) -> Tuple[int, str]: try: data = await self.request(Route('GET', '/gateway/bot')) except HTTPException as exc: raise GatewayNotFound() from exc - if zlib: - value = '{0}?encoding={1}&v={2}&compress=zlib-stream' - else: - value = '{0}?encoding={1}&v={2}' - return data['shards'], value.format(data['url'], encoding, INTERNAL_API_VERSION) + return data['shards'], data['url'] def get_user(self, user_id: Snowflake) -> Response[user.User]: return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) diff --git a/discord/utils.py b/discord/utils.py index 89cc8bdebfce..75b7d40daffb 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -21,6 +21,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import array @@ -71,6 +72,7 @@ import typing import warnings import logging +import zlib import yarl @@ -81,6 +83,12 @@ else: HAS_ORJSON = True +try: + import zstandard # type: ignore +except ImportError: + _HAS_ZSTD = False +else: + _HAS_ZSTD = True __all__ = ( 'oauth_url', @@ -1406,3 +1414,43 @@ def _human_join(seq: Sequence[str], /, *, delimiter: str = ', ', final: str = 'o return f'{seq[0]} {final} {seq[1]}' return delimiter.join(seq[:-1]) + f' {final} {seq[-1]}' + + +if _HAS_ZSTD: + + class _ZstdDecompressionContext: + __slots__ = ('context',) + + STREAM_TYPE = 'zstd-stream' + + def __init__(self) -> None: + decompressor = zstandard.ZstdDecompressor() + self.context = decompressor.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + return self.context.decompress(data).decode('utf-8') + + _DecompressionContext = _ZstdDecompressionContext +else: + + class _ZlibDecompressionContext: + __slots__ = ('context', 'buffer') + + STREAM_TYPE = 'zlib-stream' + + def __init__(self) -> None: + self.buffer: bytearray = bytearray() + self.context = zlib.decompressobj() + + def decompress(self, data: bytes, /) -> str | None: + self.buffer.extend(data) + + if len(data) < 4 or data[-4:] != b'\x00\x00\xff\xff': + return + + msg = self.context.decompress(self.buffer) + self.buffer = bytearray() + + return msg.decode('utf-8') + + _DecompressionContext = _ZlibDecompressionContext diff --git a/pyproject.toml b/pyproject.toml index 596e6ef0874d..4ec7bc007de7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ speed = [ "aiodns>=1.1; sys_platform != 'win32'", "Brotli", "cchardet==2.1.7; python_version < '3.10'", + "zstandard>=0.23.0" ] test = [ "coverage[toml]",