Skip to content

Commit

Permalink
Add a slice expression to polars IR (#18050)
Browse files Browse the repository at this point in the history
Closes #18051

Authors:
  - Matthew Murray (https://github.com/Matt711)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #18050
  • Loading branch information
Matt711 authored Feb 24, 2025
1 parent 2b6dcb0 commit bcff1f7
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def filter(self, mask: Column) -> Self:
table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
return type(self).from_table(table, self.column_names).sorted_like(self)

def slice(self, zlice: tuple[int, int] | None) -> Self:
def slice(self, zlice: tuple[int, int | None] | None) -> Self:
"""
Slice a dataframe.
Expand All @@ -312,6 +312,8 @@ def slice(self, zlice: tuple[int, int] | None) -> Self:
if zlice is None:
return self
start, length = zlice
if length is None:
length = self.num_rows
if start < 0:
start += self.num_rows
# Polars implementation wraps negative start by num_rows, then
Expand Down
4 changes: 3 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
Expand Down Expand Up @@ -30,6 +30,7 @@
from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
from cudf_polars.dsl.expressions.rolling import GroupedRollingWindow, RollingWindow
from cudf_polars.dsl.expressions.selection import Filter, Gather
from cudf_polars.dsl.expressions.slicing import Slice
from cudf_polars.dsl.expressions.sorting import Sort, SortBy
from cudf_polars.dsl.expressions.string import StringFunction
from cudf_polars.dsl.expressions.ternary import Ternary
Expand All @@ -53,6 +54,7 @@
"LiteralColumn",
"NamedExpr",
"RollingWindow",
"Slice",
"Sort",
"SortBy",
"StringFunction",
Expand Down
51 changes: 51 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
"""Slicing DSL nodes."""

from __future__ import annotations

from typing import TYPE_CHECKING

from cudf_polars.dsl.expressions.base import (
ExecutionContext,
Expr,
)

if TYPE_CHECKING:
from collections.abc import Mapping

import pylibcudf as plc

from cudf_polars.containers import Column, DataFrame


__all__ = ["Slice"]


class Slice(Expr):
__slots__ = ("length", "offset")
_non_child = ("dtype", "offset", "length")

def __init__(
self,
dtype: plc.DataType,
offset: int,
length: int,
column: Expr,
) -> None:
self.dtype = dtype
self.offset = offset
self.length = length
self.children = (column,)

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
return df.slice((self.offset, self.length)).columns[0]
14 changes: 14 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,20 @@ def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr
)


@_translate_expr.register
def _(node: pl_expr.Slice, translator: Translator, dtype: plc.DataType) -> expr.Expr:
offset = translator.translate_expr(n=node.offset)
length = translator.translate_expr(n=node.length)
assert isinstance(offset, expr.Literal)
assert isinstance(length, expr.Literal)
return expr.Slice(
dtype,
offset.value.as_py(),
length.value.as_py(),
translator.translate_expr(n=node.input),
)


@_translate_expr.register
def _(node: pl_expr.Gather, translator: Translator, dtype: plc.DataType) -> expr.Expr:
return expr.Gather(
Expand Down
24 changes: 24 additions & 0 deletions python/cudf_polars/tests/expressions/test_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize(
"zlice",
[
(1,),
(1, 3),
(-1,),
],
)
def test_slice(zlice):
df = pl.LazyFrame({"a": [0, 1, 2, 3], "b": [1, 2, 3, 4]})
q = df.select(pl.col("a").slice(*zlice))

assert_gpu_result_equal(q)

0 comments on commit bcff1f7

Please sign in to comment.