Skip to content

Commit 8c74db4

Browse files
committed
more tests
1 parent 9e70f4a commit 8c74db4

File tree

5 files changed

+239
-22
lines changed

5 files changed

+239
-22
lines changed

skrub/_expressions/_choosing.py

+4
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ class BaseChoice:
476476
recognizable with ``isinstance`` checks.
477477
"""
478478

479+
__hash__ = None
480+
479481
def as_expr(self):
480482
"""Wrap the choice in an expression.
481483
@@ -731,6 +733,8 @@ class Match:
731733
choice: Choice
732734
outcome_mapping: dict
733735

736+
__hash__ = None
737+
734738
def match(self, outcome_mapping, default=Constants.NO_VALUE):
735739
"""Select a value depending on the result of this match.
736740

skrub/_expressions/_evaluation.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,15 @@ def handle_seq(self, seq):
169169
def handle_mapping(self, mapping):
170170
new_mapping = {}
171171
for k, v in mapping.items():
172-
new_mapping[(yield k)] = yield v
172+
# Note: evaluating the keys is not needed because expressions,
173+
# choices and matches are not hashable so we do not need to
174+
# (yield k).
175+
#
176+
# In theory, because scikit-learn estimators are (unfortunately)
177+
# hashable, an estimator containing a choice could be a key, but that
178+
# wouldn't make sense and evaluating couldn't help in any case
179+
# because estimators are hashed and compared by id.
180+
new_mapping[k] = yield v
173181
return type(mapping)(new_mapping)
174182

175183
def handle_slice(self, s):

skrub/_expressions/_expressions.py

+3-21
Original file line numberDiff line numberDiff line change
@@ -252,23 +252,6 @@ def _find_dataframe(expr, func_name):
252252
return None
253253

254254

255-
def _bad_match_key(expr):
256-
from ._evaluation import needs_eval
257-
258-
impl = expr._skrub_impl
259-
if not isinstance(impl, Match):
260-
return None
261-
contains_expr, found = needs_eval(list(impl.targets.keys()), return_node=True)
262-
if not contains_expr:
263-
return None
264-
return {
265-
"message": (
266-
"`.skb.match()` keys must be actual values, not expressions nor choices. "
267-
f"Found: {short_repr(found)}"
268-
)
269-
}
270-
271-
272255
def check_expr(f):
273256
"""Check an expression and evaluate the preview.
274257
@@ -296,9 +279,6 @@ def _checked_call(*args, **kwargs):
296279
if (found_df := _find_dataframe(expr, func_name)) is not None:
297280
raise TypeError(found_df["message"])
298281

299-
if (bad_key := _bad_match_key(expr)) is not None:
300-
raise TypeError(bad_key["message"])
301-
302282
# Note: if checking pickling for every step is expensive we could also
303283
# do it in `get_estimator()` only, i.e. before any cross-val or
304284
# grid-search. Or we could have some more elaborate check (possibly
@@ -354,6 +334,8 @@ def _check_call_return_value(*args, **kwargs):
354334

355335

356336
class Expr:
337+
__hash__ = None
338+
357339
def __init__(self, impl):
358340
self._skrub_impl = impl
359341

@@ -1092,7 +1074,7 @@ def get_func_name(self):
10921074
if isinstance(impl, GetAttr):
10931075
name = impl.attr_name
10941076
elif isinstance(impl, GetItem):
1095-
name = impl.key
1077+
name = f"{{ ... }}[{short_repr(impl.key)}]"
10961078
elif isinstance(impl, Var):
10971079
name = impl.name
10981080
else:

skrub/_expressions/tests/test_errors.py

+87
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pickle
2+
13
import numpy as np
24
import pandas as pd
35
import pytest
@@ -72,6 +74,38 @@ def test_preview_failure():
7274
skrub.X(1) / 0
7375

7476

77+
#
78+
# pickling
79+
#
80+
81+
82+
class NoPickle:
83+
def __deepcopy__(self, mem):
84+
return self
85+
86+
def __getstate__(self):
87+
raise pickle.PicklingError("cannot pickle NoPickle")
88+
89+
90+
def test_pickling_preview_failure():
91+
with pytest.raises(
92+
pickle.PicklingError,
93+
match="The check to verify that the pipeline can be serialized failed",
94+
):
95+
skrub.X([]) + [NoPickle()]
96+
97+
98+
def test_pickling_estimator_failure():
99+
a = []
100+
e = skrub.X([]) + a
101+
a.append(NoPickle())
102+
with pytest.raises(
103+
pickle.PicklingError,
104+
match="The check to verify that the pipeline can be serialized failed",
105+
):
106+
e.skb.get_estimator()
107+
108+
75109
#
76110
# Misc errors
77111
#
@@ -111,6 +145,14 @@ def test_apply_instead_of_skb_apply():
111145
a.apply(PassThrough())
112146

113147

148+
def test_method_call_failure():
149+
with pytest.raises(
150+
Exception,
151+
match=r"(?s)Evaluation of '.upper\(\)' failed.*takes no arguments \(1 given\)",
152+
):
153+
skrub.var("a", "hello").upper(0)
154+
155+
114156
def test_duplicate_choice_name():
115157
with pytest.raises(
116158
ValueError, match=r".*The name 'b' was used for 2 different objects"
@@ -161,6 +203,8 @@ def test_bad_names():
161203
):
162204
# less likely to happen but for the sake of completeness
163205
(skrub.var("a") + 2).skb.set_name(0)
206+
with pytest.raises(ValueError, match="names starting with '_skrub_'"):
207+
skrub.var("_skrub_X")
164208

165209

166210
def test_pass_df_instead_of_expr():
@@ -202,6 +246,49 @@ def test_get_grid_search_with_continuous_ranges():
202246
).skb.get_grid_search()
203247

204248

249+
def test_expr_with_circular_ref():
250+
# expressions are not allowed to contain circular references as it would
251+
# complicate the implementation and there is probably no use case. We want
252+
# to get an understandable error and not an infinite loop or memory error.
253+
e = {}
254+
e["a"] = [0, {"b": e}]
255+
with pytest.raises(
256+
ValueError, match="expressions cannot contain circular references"
257+
):
258+
skrub.as_expr(e).skb.eval()
259+
260+
261+
@pytest.mark.parametrize(
262+
"attribute", ["__copy__", "__float__", "__int__", "__reversed__"]
263+
)
264+
def test_bad_attr(attribute):
265+
with pytest.raises(AttributeError):
266+
getattr(skrub.X(), attribute)
267+
268+
269+
def test_unhashable():
270+
with pytest.raises(TypeError, match="unhashable type"):
271+
{skrub.X()}
272+
with pytest.raises(TypeError, match="unhashable type"):
273+
{skrub.choose_bool(name="b")}
274+
with pytest.raises(TypeError, match="unhashable type"):
275+
{skrub.choose_bool(name="b").if_else(0, 1)}
276+
277+
278+
#
279+
# Bad arguments passed to eval()
280+
#
281+
282+
283+
def test_missing_var():
284+
e = skrub.var("a", 0) + skrub.var("b", 1)
285+
# we must provide either bindings for all vars or none
286+
assert e.skb.eval() == 1
287+
assert e.skb.eval({}) == 1
288+
with pytest.raises(KeyError, match="No value has been provided for 'b'"):
289+
e.skb.eval({"a": 10})
290+
291+
205292
#
206293
# warnings
207294
#

skrub/_expressions/tests/test_interactive_features.py

+136
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,33 @@ def test_doc(a):
2424
assert "Encode the string using the codec" in a.encode.__doc__
2525

2626

27+
class _A:
28+
pass
29+
30+
31+
def test_missing_doc():
32+
with pytest.raises(AttributeError):
33+
skrub.X().__doc__
34+
35+
with pytest.raises(AttributeError):
36+
skrub.X(_A()).__doc__
37+
38+
2739
@pytest.mark.parametrize("a", example_strings())
2840
def test_signature(a):
2941
assert "encoding" in inspect.signature(a.encode).parameters
3042

3143

44+
def test_missing_signature():
45+
with pytest.raises(AttributeError):
46+
skrub.X(0).__signature__
47+
48+
3249
def test_key_completions():
3350
a = skrub.var("a", {"one": 1}) | skrub.var("b", {"two": 2})
3451
assert a._ipython_key_completions_() == ["one", "two"]
52+
assert skrub.X()._ipython_key_completions_() == []
53+
assert skrub.X(0)._ipython_key_completions_() == []
3554

3655

3756
def test_repr_html():
@@ -41,6 +60,9 @@ def test_repr_html():
4160
a = skrub.var("thename", skrub.toy_orders().orders)
4261
r = a._repr_html_()
4362
assert "thename" in r and "table-report" in r
63+
assert "thename" in skrub.var("thename")._repr_html_()
64+
# example without a name
65+
assert "add" in (skrub.var("thename", 0) + 2)._repr_html_()
4466

4567

4668
def test_repr():
@@ -60,4 +82,118 @@ def test_repr():
6082
Result:
6183
―――――――
6284
'one two'
85+
>>> skrub.as_expr({'a': 0})
86+
<Value dict>
87+
Result:
88+
―――――――
89+
{'a': 0}
90+
>>> skrub.var('a', 1).skb.match({1: 10, 2: 20})
91+
<Match <Var 'a'>>
92+
>>> from sklearn.preprocessing import StandardScaler, RobustScaler
93+
>>> skrub.X().skb.apply(StandardScaler())
94+
<Apply StandardScaler>
95+
>>> skrub.X().skb.apply('passthrough')
96+
<Apply 'passthrough'>
97+
>>> skrub.X().skb.apply(None)
98+
<Apply passthrough>
99+
>>> skrub.X().skb.apply(skrub.optional(StandardScaler(), name='scale'))
100+
<Apply StandardScaler>
101+
>>> skrub.X().skb.apply(
102+
... skrub.choose_from([RobustScaler(), StandardScaler()], name='scale'))
103+
<Apply RobustScaler>
104+
>>> skrub.as_expr({'a': 0})['a']
105+
<GetItem 'a'>
106+
Result:
107+
―――――――
108+
0
109+
>>> skrub.as_expr({'a': 0, 'b': 1})[skrub.choose_from(['a', 'b'], name='c')]
110+
<GetItem choose_from(['a', 'b'], name='c')>
111+
Result:
112+
―――――――
113+
0
114+
>>> skrub.as_expr({'a': 0, 'b': 1})[skrub.var('key', 'b')]
115+
<GetItem <Var 'key'>>
116+
Result:
117+
―――――――
118+
1
119+
>>> skrub.as_expr('hello').upper()
120+
<CallMethod 'upper'>
121+
Result:
122+
―――――――
123+
'HELLO'
124+
>>> a = skrub.var('a', 'hello')
125+
>>> b = skrub.var('b', 1)
126+
>>> skrub.as_expr({0: a.upper, 1: a.title})[b]()
127+
<Call "{ ... }[<Var 'b'>]">
128+
Result:
129+
―――――――
130+
'Hello'
131+
>>> skrub.var('f', str.upper)('abc')
132+
<Call 'f'>
133+
Result:
134+
―――――――
135+
'ABC'
136+
137+
Weird (unnecessary) use of deferred to trigger a case where calling a
138+
method has not been translated to a CallMethod
139+
140+
>>> skrub.deferred(skrub.var('a', 'hello').upper)()
141+
<Call 'upper'>
142+
Result:
143+
―――――――
144+
'HELLO'
145+
146+
In cases that are hard to figure out we fall back on a less informative
147+
default
148+
149+
>>> skrub.choose_from([str.upper, str.title], name='f').as_expr()('abc')
150+
<Call 'Value'>
151+
Result:
152+
―――――――
153+
'ABC'
154+
>>> skrub.as_expr(str.upper)('abc')
155+
<Call 'Value'>
156+
Result:
157+
―――――――
158+
'ABC'
159+
160+
>>> a = skrub.var('a')
161+
>>> b = skrub.var('b')
162+
>>> c = skrub.var('c', 0)
163+
>>> a + b
164+
<BinOp: add>
165+
>>> - a
166+
<UnaryOp: neg>
167+
>>> 2 + a
168+
<BinOp: add>
169+
>>> c + c
170+
<BinOp: add>
171+
Result:
172+
―――――――
173+
0
174+
>>> - c
175+
<UnaryOp: neg>
176+
Result:
177+
―――――――
178+
0
179+
>>> 2 - c
180+
<BinOp: sub>
181+
Result:
182+
―――――――
183+
2
184+
185+
>>> X = skrub.X()
186+
>>> X.skb.concat_horizontal([X, X])
187+
<ConcatHorizontal: 3 dataframes>
188+
189+
When we do not know the length of the list of dataframes to concatenate
190+
191+
>>> X.skb.concat_horizontal(skrub.as_expr([X, X]))
192+
<ConcatHorizontal>
63193
"""
194+
195+
196+
def test_format():
197+
assert f"{skrub.X()}" == "<Var 'X'>"
198+
with pytest.raises(ValueError, match="Invalid format specifier"):
199+
f"{skrub.X(0.2):.2f}"

0 commit comments

Comments
 (0)