Skip to content

Commit 52a2309

Browse files
committed
wavg with tests
1 parent 4138f87 commit 52a2309

File tree

2 files changed

+188
-5
lines changed

2 files changed

+188
-5
lines changed

spells/utils.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,26 @@ def wavg(
3030
name_list = list(new_names)
3131

3232
assert len(name_list) == len(col_list), f"{len(name_list)} names provided for {len(col_list)} columns"
33-
assert len(weight_list) == len(col_list), f"{len(weight_list)} weights provided for {len(col_list)} columns"
33+
assert len(name_list) == len(set(name_list)), "Output names must be unique"
34+
assert len(weight_list) == len(col_list) or len(weight_list) == 1, f"{len(weight_list)} weights provided for {len(col_list)} columns"
35+
36+
enum_wl = weight_list * int(len(col_list) / len(weight_list))
37+
wl_names = [w.meta.output_name() for w in weight_list]
38+
assert len(wl_names) == len(set(wl_names)), "Weights must have unique names. Send one weight column or n uniquely named ones"
3439

3540
to_group = df.select(gbs + weight_list + [
36-
(c * weight_list[i]) for i, c in enumerate(col_list)
41+
(c * enum_wl[i]).alias(name_list[i]) for i, c in enumerate(col_list)
3742
])
3843

3944
grouped = to_group if not gbs else to_group.group_by(gbs)
4045

41-
return grouped.sum().select(
46+
ret_df = grouped.sum().select(
4247
gbs +
43-
[pl.col(c.meta.output_name()).alias(name_list[i]) for i, c in enumerate(col_list)]
48+
wl_names +
49+
[(pl.col(name) / pl.col(enum_wl[i].meta.output_name())) for i, name in enumerate(name_list)]
4450
)
51+
52+
if gbs:
53+
ret_df = ret_df.sort(by=gbs)
4554

46-
55+
return ret_df

tests/utils_test.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
2+
"""
3+
Test behavior of wavg utility for Polars DataFrames
4+
"""
5+
6+
import pytest
7+
import polars as pl
8+
9+
import spells.utils as utils
10+
11+
def format_test_string(test_string: str) -> str:
    """Normalize a pasted DataFrame rendering for comparison.

    Each line of ``test_string`` is stripped of leading and trailing
    whitespace, lines that are empty after stripping are dropped, and the
    remaining lines are joined with a single newline.

    Args:
        test_string: Multi-line text, typically an indented triple-quoted
            copy of a printed DataFrame.

    Returns:
        The non-blank lines, stripped, joined by ``"\n"``. An input with no
        non-blank lines yields the empty string.
    """
    # Strip every line exactly once; filter(None, ...) then discards the
    # lines that became empty.
    stripped_lines = map(str.strip, test_string.splitlines())
    return "\n".join(filter(None, stripped_lines))
18+
19+
# Shared fixture for the wavg tests: one grouping column ('cat'), two value
# columns ('va1', 'va2') and two integer weight columns ('wt1', 'wt2').
test_df = pl.DataFrame(
    {
        "cat": ["a", "a", "b", "b", "b", "c"],
        "va1": [1.0, -1.0, 0.2, 0.4, 0.0, 10.0],
        "va2": [4.0, 3.0, 1.0, -2.0, 2.0, 1.0],
        "wt1": [1, 2, 0, 2, 3, 1],
        "wt2": [2, 4, 1, 1, 1, 2],
    }
)
26+
27+
28+
# test wavg with default args
29+
@pytest.mark.parametrize(
30+
"cols, weights, expected",
31+
[
32+
(
33+
'va1',
34+
'wt1',
35+
"""
36+
shape: (1, 2)
37+
┌─────┬──────────┐
38+
│ wt1 ┆ va1 │
39+
│ --- ┆ --- │
40+
│ i64 ┆ f64 │
41+
╞═════╪══════════╡
42+
│ 9 ┆ 1.088889 │
43+
└─────┴──────────┘
44+
"""
45+
),
46+
(
47+
['va1', 'va2'],
48+
'wt1',
49+
"""
50+
shape: (1, 3)
51+
┌─────┬──────────┬──────────┐
52+
│ wt1 ┆ va1 ┆ va2 │
53+
│ --- ┆ --- ┆ --- │
54+
│ i64 ┆ f64 ┆ f64 │
55+
╞═════╪══════════╪══════════╡
56+
│ 9 ┆ 1.088889 ┆ 1.444444 │
57+
└─────┴──────────┴──────────┘
58+
"""
59+
),
60+
(
61+
['va1', 'va2'],
62+
['wt1', 'wt2'],
63+
"""
64+
shape: (1, 4)
65+
┌─────┬─────┬──────────┬──────────┐
66+
│ wt1 ┆ wt2 ┆ va1 ┆ va2 │
67+
│ --- ┆ --- ┆ --- ┆ --- │
68+
│ i64 ┆ i64 ┆ f64 ┆ f64 │
69+
╞═════╪═════╪══════════╪══════════╡
70+
│ 9 ┆ 11 ┆ 1.088889 ┆ 2.090909 │
71+
└─────┴─────┴──────────┴──────────┘
72+
"""
73+
),
74+
(
75+
[pl.col('va1') + 1, 'va2'],
76+
['wt1', pl.col('wt2') + 1],
77+
"""
78+
shape: (1, 4)
79+
┌─────┬─────┬──────────┬──────────┐
80+
│ wt1 ┆ wt2 ┆ va1 ┆ va2 │
81+
│ --- ┆ --- ┆ --- ┆ --- │
82+
│ i64 ┆ i64 ┆ f64 ┆ f64 │
83+
╞═════╪═════╪══════════╪══════════╡
84+
│ 9 ┆ 17 ┆ 2.088889 ┆ 1.882353 │
85+
└─────┴─────┴──────────┴──────────┘
86+
"""
87+
),
88+
]
89+
)
90+
def test_wavg_defaults(
    cols: str | pl.Expr | list[str | pl.Expr],
    weights: str | pl.Expr | list[str | pl.Expr],
    expected: str,
):
    """Check ``utils.wavg`` called with only columns and weights.

    The result of ``utils.wavg(test_df, cols, weights)`` is rendered with
    ``str`` and compared against ``expected`` after ``expected`` has been
    normalized by ``format_test_string``.
    """
    rendered = str(utils.wavg(test_df, cols, weights))

    # Echo the rendered frame so it shows up in pytest's captured output.
    print(rendered)
    assert rendered == format_test_string(expected)
96+
97+
98+
# test wavg with named args
99+
@pytest.mark.parametrize(
100+
"cols, weights, group_by, new_names, expected",
101+
[
102+
(
103+
"va1",
104+
"wt1",
105+
[],
106+
"v1",
107+
"""
108+
shape: (1, 2)
109+
┌─────┬──────────┐
110+
│ wt1 ┆ v1 │
111+
│ --- ┆ --- │
112+
│ i64 ┆ f64 │
113+
╞═════╪══════════╡
114+
│ 9 ┆ 1.088889 │
115+
└─────┴──────────┘
116+
"""
117+
),
118+
(
119+
"va1",
120+
"wt1",
121+
"cat",
122+
"va1",
123+
"""
124+
shape: (3, 3)
125+
┌─────┬─────┬───────────┐
126+
│ cat ┆ wt1 ┆ va1 │
127+
│ --- ┆ --- ┆ --- │
128+
│ str ┆ i64 ┆ f64 │
129+
╞═════╪═════╪═══════════╡
130+
│ a ┆ 3 ┆ -0.333333 │
131+
│ b ┆ 5 ┆ 0.16 │
132+
│ c ┆ 1 ┆ 10.0 │
133+
└─────┴─────┴───────────┘
134+
"""
135+
),
136+
(
137+
["va1", "va1"],
138+
["wt1", "wt2"],
139+
["cat"],
140+
["v@1", "v@2"],
141+
"""
142+
shape: (3, 5)
143+
┌─────┬─────┬─────┬───────────┬───────────┐
144+
│ cat ┆ wt1 ┆ wt2 ┆ v@1 ┆ v@2 │
145+
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
146+
│ str ┆ i64 ┆ i64 ┆ f64 ┆ f64 │
147+
╞═════╪═════╪═════╪═══════════╪═══════════╡
148+
│ a ┆ 3 ┆ 6 ┆ -0.333333 ┆ -0.333333 │
149+
│ b ┆ 5 ┆ 3 ┆ 0.16 ┆ 0.2 │
150+
│ c ┆ 1 ┆ 2 ┆ 10.0 ┆ 10.0 │
151+
└─────┴─────┴─────┴───────────┴───────────┘
152+
"""
153+
)
154+
]
155+
)
156+
def test_wavg(
    cols: str | pl.Expr | list[str | pl.Expr],
    weights: str | pl.Expr | list[str | pl.Expr],
    group_by: str | pl.Expr | list[str | pl.Expr],
    new_names: str | list[str],
    expected: str,
):
    """Check ``utils.wavg`` with explicit ``group_by`` and ``new_names``.

    The returned frame is rendered with ``str`` and compared against
    ``expected`` after ``expected`` has been normalized by
    ``format_test_string``.
    """
    frame = utils.wavg(test_df, cols, weights, group_by=group_by, new_names=new_names)
    rendered = str(frame)

    # Echo the rendered frame so it shows up in pytest's captured output.
    print(rendered)
    assert rendered == format_test_string(expected)
174+

0 commit comments

Comments
 (0)