Skip to content

Commit

Permalink
a few fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
seanmacavaney committed Aug 29, 2024
1 parent 2683be3 commit a2d3ee3
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyt_splade/_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Encodes the text field in the input DataFrame."""
- pta.validate.columns(df, include=[self.text_field])
+ pta.validate.columns(df, includes=[self.text_field])
it = iter(df[self.text_field])
if self.verbose:
it = pt.tqdm(it, total=len(df), unit=self.text_field)
Expand Down
2 changes: 1 addition & 1 deletion pyt_splade/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, splade, text_field, batch_size=100, verbose=False):

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Scores (re-ranks) the documents against the queries in the input DataFrame."""
- pta.validate.results_frame(df, ['query', self.text_field])
+ pta.validate.result_frame(df, ['query', self.text_field])
it = df.groupby('query')
if self.verbose:
it = pt.tqdm(it, unit='query')
Expand Down
2 changes: 1 addition & 1 deletion pyt_splade/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, mult: float = 100.):

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
"""Converts the toks field into a text field."""
- res = inp.assign(toks=inp['toks'].apply(self._dict_tf2text))
+ res = inp.assign(text=inp['toks'].apply(self._dict_tf2text))
res.drop(columns=['toks'], inplace=True)
return res

Expand Down
1 change: 0 additions & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ def setUp(self):
self.factory = pyt_splade.Splade(device='cpu')

def test_transformer_indexing(self):
- import pyt_splade
df = (self.factory.indexing() >> pyt_splade.toks2doc()).transform_iter([{'docno' : 'd1', 'text' : 'hello there'}])
self.assertTrue('there there' in df.iloc[0].text)
df = self.factory.indexing().transform_iter([
Expand Down
2 changes: 1 addition & 1 deletion tests/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_scorer(self):
{'qid': '1', 'query': 'hello', 'docno' : 'd1', 'text' : 'hello there'},
])
self.assertAlmostEqual(0., df['score'][0])
- self.assertAlmostEqual(11.133593, df['score'][1], places=5)
+ self.assertAlmostEqual(11.133593, df['score'][1], places=4)
self.assertAlmostEqual(17.566324, df['score'][2], places=3)
self.assertEqual('0', df['qid'][0])
self.assertEqual('0', df['qid'][1])
Expand Down

0 comments on commit a2d3ee3

Please sign in to comment.