Skip to content

Commit

Permalink
Add tests serializing RMM headers
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Feb 28, 2025
1 parent 26f784a commit 5649f73
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion python/cudf_polars/tests/experimental/test_dask_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from polars.testing.asserts import assert_frame_equal

import pylibcudf as plc
import rmm

from cudf_polars.containers import DataFrame
from cudf_polars.experimental.dask_serialize import register
Expand All @@ -18,6 +19,16 @@
register()


def convert_to_rmm(frame):
"""Convert frame to RMM to simulate Dask UCX transfers."""
if hasattr(frame, "__cuda_array_interface__"):
buf = rmm.DeviceBuffer(size=frame.nbytes)
buf.copy_from_device(frame)
return buf
else:
return frame


@pytest.mark.parametrize(
"arrow_tbl",
[
Expand All @@ -29,12 +40,18 @@
pa.table({"a": [1, 2, None], "b": [None, 3, 4]}),
],
)
@pytest.mark.parametrize("protocol", ["cuda", "dask"])
@pytest.mark.parametrize("protocol", ["cuda", "cuda_rmm", "dask"])
def test_dask_serialization_roundtrip(arrow_tbl, protocol):
plc_tbl = plc.interop.from_arrow(arrow_tbl)
df = DataFrame.from_table(plc_tbl, names=arrow_tbl.column_names)

cuda_rmm = protocol == "cuda_rmm"
protocol = "cuda" if protocol == "cuda_rmm" else protocol

header, frames = serialize(df, on_error="raise", serializers=[protocol])
if cuda_rmm:
# Simulate Dask UCX transfers
frames = [convert_to_rmm(f) for f in frames]
res = deserialize(header, frames, deserializers=[protocol])

assert_frame_equal(df.to_polars(), res.to_polars())
Expand All @@ -44,6 +61,9 @@ def test_dask_serialization_roundtrip(arrow_tbl, protocol):
expect = DataFrame([column])

header, frames = serialize(column, on_error="raise", serializers=[protocol])
if cuda_rmm:
# Simulate Dask UCX transfers
frames = [convert_to_rmm(f) for f in frames]
res = deserialize(header, frames, deserializers=[protocol])

assert_frame_equal(expect.to_polars(), DataFrame([res]).to_polars())

0 comments on commit 5649f73

Please sign in to comment.