Skip to content

Commit 90de5c6

Browse files
committed
seen greatest name column
1 parent 4589b88 commit 90de5c6

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

spells/extension.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Callable
2+
13
import math
24

35
import polars as pl
@@ -13,6 +15,19 @@ def print_ext(ext: dict[str, ColSpec]) -> None:
1315
print("\t" + key)
1416

1517

18+
def seen_greatest_name_fn(attr: str) -> Callable:
19+
def inner(names: list[str]) -> pl.Expr:
20+
expr = pl.lit(None)
21+
for name in names:
22+
expr = (
23+
pl.when(pl.col(f"seen_{attr}_is_greatest_{name}"))
24+
.then(pl.lit(name))
25+
.otherwise(expr)
26+
)
27+
return expr
28+
return inner
29+
30+
1631
def context_cols(attr, silent: bool = False) -> dict[str, ColSpec]:
1732
ext = {
1833
f"seen_{attr}": ColSpec(
@@ -41,6 +56,9 @@ def context_cols(attr, silent: bool = False) -> dict[str, ColSpec]:
4156
expr=lambda name: pl.col(f"seen_{attr}_{name}")
4257
== pl.col(f"greatest_{attr}_seen"),
4358
),
59+
f"seen_greatest_{attr}_name": ColSpec(
60+
col_type=ColType.GROUP_BY, expr=seen_greatest_name_fn(attr)
61+
),
4462
f"seen_{attr}_greater": ColSpec(
4563
col_type=ColType.NAME_SUM,
4664
expr=lambda name: pl.col(f"seen_{attr}_{name}")

0 commit comments

Comments
 (0)