Skip to content

Commit 5c0a110

Browse files
committed
pool and pick sum extensions
1 parent d4b9723 commit 5c0a110

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

spells/extension.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def inner(names: list[str]) -> pl.Expr:
2525
.otherwise(expr)
2626
)
2727
return expr
28+
2829
return inner
2930

3031

@@ -49,7 +50,26 @@ def context_cols(attr, silent: bool = True) -> dict[str, ColSpec]:
4950
else card_context[name][attr],
5051
),
5152
f"pick_{attr}": ColSpec(
52-
col_type=ColType.AGG, expr=pl.col(f"pick_{attr}_sum") / pl.col("num_taken")
53+
col_type=ColType.GROUP_BY, expr=pl.col(f"pick_{attr}_sum")
54+
),
55+
f"pool_{attr}": ColSpec(
56+
col_type=ColType.NAME_SUM,
57+
expr=(
58+
lambda name, card_context: pl.lit(None)
59+
if card_context[name].get(attr) is None
60+
or math.isnan(card_context[name][attr])
61+
else card_context[name][attr] * pl.col(f"pool_{name}")
62+
),
63+
),
64+
f"pool_{attr}_sum": ColSpec(
65+
col_type=ColType.PICK_SUM,
66+
expr=lambda names: pl.sum_horizontal(
67+
[pl.col(f"pool_{attr}_{name}") for name in names]
68+
),
69+
),
70+
f"pool_pick_{attr}_sum": ColSpec(
71+
col_type=ColType.PICK_SUM,
72+
expr=pl.col(f"pick_{attr}_sum") + pl.col(f"pool_{attr}_sum"),
5373
),
5474
f"seen_{attr}_is_greatest": ColSpec(
5575
col_type=ColType.NAME_SUM,
@@ -79,7 +99,7 @@ def context_cols(attr, silent: bool = True) -> dict[str, ColSpec]:
7999
col_type=ColType.PICK_SUM,
80100
expr=lambda names: pl.sum_horizontal(
81101
[pl.col(f"seen_{attr}_{name}") for name in names]
82-
)
102+
),
83103
),
84104
f"least_{attr}_seen": ColSpec(
85105
col_type=ColType.PICK_SUM,

0 commit comments

Comments
 (0)