Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations in aiokafka/codec.py #984

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SCALA_VERSION?=2.13
KAFKA_VERSION?=2.8.1
DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION)
DIFF_BRANCH=origin/master
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py
FORMATTED_AREAS=aiokafka/util.py aiokafka/structs.py aiokafka/codec.py tests/test_codec.py

.PHONY: setup
setup:
Expand Down
44 changes: 24 additions & 20 deletions aiokafka/codec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import gzip
import io
import struct

from typing_extensions import Buffer

_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1)
_XERIAL_V1_FORMAT = "bccccccBii"
ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024
Expand All @@ -12,23 +16,23 @@
cramjam = None


def has_gzip():
def has_gzip() -> bool:
return True


def has_snappy():
def has_snappy() -> bool:
return cramjam is not None


def has_zstd():
def has_zstd() -> bool:
return cramjam is not None


def has_lz4():
def has_lz4() -> bool:
return cramjam is not None


def gzip_encode(payload, compresslevel=None):
def gzip_encode(payload: Buffer, compresslevel: int | None = None) -> bytes:
if not compresslevel:
compresslevel = 9

Expand All @@ -45,7 +49,7 @@ def gzip_encode(payload, compresslevel=None):
return buf.getvalue()


def gzip_decode(payload):
def gzip_decode(payload: Buffer) -> bytes:
buf = io.BytesIO(payload)

# Gzip context manager introduced in python 2.7
Expand All @@ -57,7 +61,9 @@ def gzip_decode(payload):
gzipper.close()


def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
def snappy_encode(
payload: Buffer, xerial_compatible: bool = True, xerial_blocksize: int = 32 * 1024
) -> bytes:
"""Encodes the given data with snappy compression.

If xerial_compatible is set then the stream is encoded in a fashion
Expand Down Expand Up @@ -93,12 +99,9 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024):
for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER):
out.write(struct.pack("!" + fmt, dat))

# Chunk through buffers to avoid creating intermediate slice copies
def chunker(payload, i, size):
return memoryview(payload)[i : size + i]

payload = memoryview(payload)
for chunk in (
chunker(payload, i, xerial_blocksize)
payload[i : i + xerial_blocksize]
for i in range(0, len(payload), xerial_blocksize)
):
block = cramjam.snappy.compress_raw(chunk)
Expand All @@ -109,7 +112,7 @@ def chunker(payload, i, size):
return out.getvalue()


def _detect_xerial_stream(payload):
def _detect_xerial_stream(payload: Buffer) -> bool:
"""Detects if the data given might have been encoded with the blocking mode
of the xerial snappy library.

Expand All @@ -131,20 +134,21 @@ def _detect_xerial_stream(payload):
1.
"""

payload = memoryview(payload)
if len(payload) > 16:
header = struct.unpack("!" + _XERIAL_V1_FORMAT, memoryview(payload)[:16])
header = struct.unpack("!" + _XERIAL_V1_FORMAT, payload[:16])
return header == _XERIAL_V1_HEADER
return False


def snappy_decode(payload):
def snappy_decode(payload: Buffer) -> bytes:
if not has_snappy():
raise NotImplementedError("Snappy codec is not available")

if _detect_xerial_stream(payload):
# TODO ? Should become a fileobj ?
out = io.BytesIO()
byt = payload[16:]
byt = memoryview(payload)[16:]
length = len(byt)
cursor = 0

Expand All @@ -162,7 +166,7 @@ def snappy_decode(payload):
return bytes(cramjam.snappy.decompress_raw(payload))


def lz4_encode(payload, level=9):
def lz4_encode(payload: Buffer, level: int = 9) -> bytes:
# level=9 is used by default by broker itself
# https://cwiki.apache.org/confluence/display/KAFKA/KIP-390%3A+Support+Compression+Level
if not has_lz4():
Expand All @@ -177,14 +181,14 @@ def lz4_encode(payload, level=9):
return bytes(compressor.finish())


def lz4_decode(payload):
def lz4_decode(payload: Buffer) -> bytes:
if not has_lz4():
raise NotImplementedError("LZ4 codec is not available")

return bytes(cramjam.lz4.decompress(payload))


def zstd_encode(payload, level=None):
def zstd_encode(payload: Buffer, level: int | None = None) -> bytes:
if not has_zstd():
raise NotImplementedError("Zstd codec is not available")

Expand All @@ -196,7 +200,7 @@ def zstd_encode(payload, level=None):
return bytes(cramjam.zstd.compress(payload, level=level))


def zstd_decode(payload):
def zstd_decode(payload: Buffer) -> bytes:
if not has_zstd():
raise NotImplementedError("Zstd codec is not available")

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dynamic = ["version"]
dependencies = [
"async-timeout",
"packaging",
"typing_extensions >=4.6.0",
]

[project.optional-dependencies]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@
from ._testutil import random_string


def test_gzip():
def test_gzip() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = gzip_decode(gzip_encode(b1))
assert b1 == b2


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy():
def test_snappy() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = snappy_decode(snappy_encode(b1))
assert b1 == b2


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_detect_xerial():
def test_snappy_detect_xerial() -> None:
_detect_xerial_stream = codecs._detect_xerial_stream

header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes"
Expand All @@ -55,7 +55,7 @@ def test_snappy_detect_xerial():


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_decode_xerial():
def test_snappy_decode_xerial() -> None:
header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False)
block_len = len(random_snappy)
Expand All @@ -73,7 +73,7 @@ def test_snappy_decode_xerial():


@pytest.mark.skipif(not has_snappy(), reason="Snappy not available")
def test_snappy_encode_xerial():
def test_snappy_encode_xerial() -> None:
to_ensure = (
b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01"
b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00"
Expand All @@ -88,7 +88,7 @@ def test_snappy_encode_xerial():


@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
def test_lz4():
def test_lz4() -> None:
for i in range(1000):
b1 = random_string(100)
b2 = lz4_decode(lz4_encode(b1))
Expand All @@ -97,7 +97,7 @@ def test_lz4():


@pytest.mark.skipif(not has_lz4(), reason="LZ4 not available")
def test_lz4_incremental():
def test_lz4_incremental() -> None:
for i in range(1000):
# lz4 max single block size is 4MB
# make sure we test with multiple-blocks
Expand All @@ -108,7 +108,7 @@ def test_lz4_incremental():


@pytest.mark.skipif(not has_zstd(), reason="Zstd not available")
def test_zstd():
def test_zstd() -> None:
for _ in range(1000):
b1 = random_string(100)
b2 = zstd_decode(zstd_encode(b1))
Expand Down
Loading