diff --git a/README.md b/README.md index 7b477de..1b59c47 100644 --- a/README.md +++ b/README.md @@ -34,11 +34,11 @@ index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128) # Retrieval Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights. -We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems. +We apply this as a query encoding transformer. ```python -splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf') +splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf') ``` diff --git a/pyt_splade/__init__.py b/pyt_splade/__init__.py index e20a4ce..edcdfc7 100644 --- a/pyt_splade/__init__.py +++ b/pyt_splade/__init__.py @@ -3,9 +3,9 @@ from pyt_splade._model import Splade from pyt_splade._encoder import SpladeEncoder from pyt_splade._scorer import SpladeScorer -from pyt_splade._utils import Toks2Doc, MatchOp +from pyt_splade._utils import Toks2Doc SpladeFactory = Splade # backward compatible name toks2doc = Toks2Doc # backward compatible name -__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc', 'MatchOp'] +__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc'] diff --git a/pyt_splade/_model.py b/pyt_splade/_model.py index 3a5d3e9..0d85f28 100644 --- a/pyt_splade/_model.py +++ b/pyt_splade/_model.py @@ -58,27 +58,22 @@ def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=Fa out_field = 'toks' if sparse else 'doc_vec' return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale) - indexing = doc_encoder # backward compatible name + indexing = doc_encoder # backward compatible name - def query_encoder(self, batch_size=100, sparse=True, verbose=False, matchop=False, scale=100) -> pt.Transformer: + def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer: """Returns a transformer that encodes a query field into a query representation. Args: batch_size: the batch size to use when encoding sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector verbose: if True, show a progress bar - matchop: if True, convert the output to MatchOp syntax scale: the scale to apply to the term frequencies """ out_field = 'query_toks' if sparse else 'query_vec' res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale) - if matchop: - res = res >> pyt_splade.MatchOp() return res - def query(self, batch_size=100, sparse=True, verbose=False, matchop=True, scale=100) -> pt.Transformer: - # backward compatible name w/ default matchop=True - return self.query_encoder(batch_size, sparse, verbose, matchop, scale) + query = query_encoder # backward compatible name def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer: """Returns a transformer that scores documents against queries. diff --git a/pyt_splade/_utils.py b/pyt_splade/_utils.py index 635c16a..856c898 100644 --- a/pyt_splade/_utils.py +++ b/pyt_splade/_utils.py @@ -1,32 +1,7 @@ -import base64 -import string import pandas as pd -import pyterrier_alpha as pta import pyterrier as pt -class MatchOp(pt.Transformer): - """Converts a query_toks field into a query field, using the MatchOp syntax.""" - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """Converts the query_toks field into a query field.""" - pta.validate.query_frame(df, ['query_toks']) - rtr = pt.model.push_queries(df) - rtr = rtr.assign(query=df.query_toks.apply(lambda toks: ' '.join(_matchop(k, v) for k, v in toks.items()))) - rtr = rtr.drop(columns=['query_toks']) - return rtr - - -def _matchop(t, w): - """Converts a term and its weight into MatchOp syntax.""" - if not all(a in string.ascii_letters + string.digits for a in t): - encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8") - t = f'#base64({encoded})' - if w != 1: - t = f'#combine:0={w}({t})' - return t - - class Toks2Doc(pt.Transformer): """Converts a toks field into a text field, by scaling the weights by ``mult`` and repeating them.""" def __init__(self, mult: float = 100.): diff --git a/pyt_splade/pt_docs/api.rst b/pyt_splade/pt_docs/api.rst index 072fd9a..785d9a8 100644 --- a/pyt_splade/pt_docs/api.rst +++ b/pyt_splade/pt_docs/api.rst @@ -6,22 +6,12 @@ API Documentation .. autoclass:: pyt_splade.Splade :members: -Utils +Utils / Internals ------------------------------------------ -These utility transformers allow you to convert between sparse representation formats. - .. autoclass:: pyt_splade.Toks2Doc :members: -.. autoclass:: pyt_splade.MatchOp - :members: - -Internals ------------------------------------------- - -These transformers are returned by :class:`~pyt_splade.Splade` to perform encoding and scoring. - .. autoclass:: pyt_splade.SpladeEncoder :members: diff --git a/pyt_splade/pt_docs/index.rst b/pyt_splade/pt_docs/index.rst index 77cf523..63eafe0 100644 --- a/pyt_splade/pt_docs/index.rst +++ b/pyt_splade/pt_docs/index.rst @@ -44,11 +44,11 @@ Retrieval --------------------------------------------- Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights. -We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems. +We apply this as a query encoding transformer. .. code-block:: python - splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf') + splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf') Scoring diff --git a/setup.py b/setup.py index c5dcc6e..6a748b3 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,6 @@ license="Creative Commons Attribution-NonCommercial-ShareAlike", long_description=readme, install_requires=[ - 'splade', 'python-terrier', 'pyterrier_alpha', + 'splade', 'python-terrier>=0.11.0', 'pyterrier_alpha', ], ) \ No newline at end of file diff --git a/tests/test_matchop.py b/tests/test_matchop.py deleted file mode 100644 index 5428cb4..0000000 --- a/tests/test_matchop.py +++ /dev/null @@ -1,13 +0,0 @@ -import unittest -from pyt_splade._utils import _matchop - -class TestMatchop(unittest.TestCase): - - def test_it(self): - - self.assertEqual(_matchop('a', 1), 'a') - self.assertEqual(_matchop('a', 1.1), '#combine:0=1.1(a)') - self.assertEqual(_matchop('##a', 1.1), '#combine:0=1.1(#base64(IyNh))') - - self.assertTrue("#base64" in _matchop('"', 1)) -