Skip to content

Commit

Permalink
Prepare PyPI release 0.1.0 (#11)
Browse files Browse the repository at this point in the history
* Update PyGaggle version and clean up

* Update docs
  • Loading branch information
edwinzhng authored Mar 8, 2021
1 parent a6b1180 commit 514d752
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 52 deletions.
52 changes: 26 additions & 26 deletions chatty_goose/pipeline/retrieval_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,63 @@ class RetrievalPipeline:
Parameters:
searcher (SimpleSearcher): Pyserini searcher for Lucene index
retrievers (List[CQR]): List of CQR retrievers to use for first-stage retrieval
reformulators (List[CQR]): List of CQR methods to use for first-stage retrieval
searcher_num_hits (int): number of hits returned by searcher - default 10
early_fusion (bool): flag to perform fusion before second-stage retrieval - default True
reranker (Reranker): optional reranker for second-stage retrieval
reranker_query_index (int): retriever index to use for reranking query - defaults to last retriever
reranker_query_retriever (CQR): retriever for generating reranker query,
overrides reranker_query_index if provided
reranker_query_reformulator (CQR): CQR method for generating reranker query,
overrides reranker_query_index if provided
"""

def __init__(
    self,
    searcher: SimpleSearcher,
    reformulators: List[CQR],
    searcher_num_hits: int = 10,
    early_fusion: bool = True,
    reranker: Reranker = None,
    reranker_query_index: int = -1,
    reranker_query_reformulator: CQR = None,
):
    """Store the components and options used by the retrieval pipeline.

    Parameters:
        searcher: searcher queried once per reformulated query.
        reformulators: CQR methods whose rewritten queries drive
            first-stage retrieval.
        searcher_num_hits: number of hits requested from the searcher;
            coerced with ``int()`` so numeric strings are accepted.
        early_fusion: when True, first-stage results are fused before
            reranking rather than after.
        reranker: optional second-stage reranker; None disables reranking.
        reranker_query_index: index into the rewritten queries that selects
            the reranker query (default -1, i.e. the last reformulator).
        reranker_query_reformulator: optional CQR method used to build the
            reranker query; takes precedence over ``reranker_query_index``.
    """
    self.searcher = searcher
    self.reformulators = reformulators
    # Coerce once here so every later search call and slice sees an int.
    self.searcher_num_hits = int(searcher_num_hits)
    self.early_fusion = early_fusion
    self.reranker = reranker
    self.reranker_query_index = reranker_query_index
    self.reranker_query_reformulator = reranker_query_reformulator

def retrieve(self, query) -> List[JSimpleSearcherResult]:
    """Run first-stage retrieval for ``query`` and optionally rerank.

    Each reformulator rewrites the query and the searcher is called once
    per rewritten query. Without a reranker the per-reformulator hit lists
    are fused and returned. With a reranker, either the fused list is
    reranked (early fusion) or every hit list is reranked separately and
    the reranked lists are fused afterwards (late fusion).
    """
    cqr_hits = []
    cqr_queries = []
    for cqr in self.reformulators:
        new_query = cqr.rewrite(query)
        hits = self.searcher.search(new_query, k=self.searcher_num_hits)
        cqr_hits.append(hits)
        cqr_queries.append(new_query)

    # Merge results from multiple CQR methods if required
    if self.early_fusion or self.reranker is None:
        cqr_hits = reciprocal_rank_fusion(cqr_hits)

    # Return results if no reranker
    if self.reranker is None:
        return cqr_hits

    # Get query for reranker: a dedicated reformulator overrides the index
    if self.reranker_query_reformulator is None:
        rerank_query = cqr_queries[self.reranker_query_index]
    else:
        rerank_query = self.reranker_query_reformulator.rewrite(query)

    if self.early_fusion:
        # Only the top searcher_num_hits fused hits are sent to the reranker
        return self.rerank(rerank_query, cqr_hits[:self.searcher_num_hits])

    # Late fusion: keep every reranked list. Rebinding a single variable
    # inside the loop would discard all but the last reformulator's results
    # and hand the fusion step one flat list instead of one list per method.
    # NOTE(review): assumes self.rerank returns a hit list in the same form
    # reciprocal_rank_fusion accepts for first-stage results; confirm.
    reranked_lists = [self.rerank(rerank_query, hits) for hits in cqr_hits]
    return reciprocal_rank_fusion(reranked_lists)
Expand All @@ -91,8 +91,8 @@ def rerank(self, query, hits):
return reranked_hits

def reset_history(self):
    """Reset the history of every reformulator held by the pipeline.

    Covers the first-stage reformulators and, when one is configured, the
    separate reformulator used to build the reranker query.
    """
    for cqr in self.reformulators:
        cqr.reset_history()

    # The reranker-query reformulator is optional and is not necessarily
    # one of the first-stage reformulators, so reset it explicitly.
    if self.reranker_query_reformulator:
        self.reranker_query_reformulator.reset_history()
2 changes: 1 addition & 1 deletion chatty_goose/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class SearcherSettings(BaseSettings):
"""Settings for Anserini searcher"""

index_path: str # Lucene index path
index_path: str # Pre-built index name or path to Lucene index
k1: float = 0.82 # BM25 k parameter
b: float = 0.68 # BM25 b parameter
rm3: bool = False # use RM3
Expand Down
6 changes: 5 additions & 1 deletion chatty_goose/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from os import path
from typing import Dict, List, Tuple

from pygaggle.rerank.transformer import MonoBERT
Expand Down Expand Up @@ -49,7 +50,10 @@ def build_bert_reranker(


def build_searcher(settings: SearcherSettings) -> SimpleSearcher:
searcher = SimpleSearcher(settings.index_path)
if path.isdir(settings.index_path):
searcher = SimpleSearcher(settings.index_path)
else:
searcher = SimpleSearcher.from_prebuilt_index(settings.index_path)
searcher.set_bm25(float(settings.k1), float(settings.b))
logging.info(
"Initializing BM25, setting k1={} and b={}".format(settings.k1, settings.b)
Expand Down
17 changes: 5 additions & 12 deletions docs/cqr_experiments.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,24 @@

## Data Preparation

1. Download the pre-built CAsT 2019 index using Pyserini. This will download the entire index to `~/.cache/pyserini`.
1. Download either the [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_v1.0.json) or [evaluation](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) input query JSON file. These files can be found under `data/treccastweb/2019/data` if you cloned the submodules for this repo.

```
from pyserini.search import SimpleSearcher
SimpleSearcher.from_prebuilt_index('cast2019')
```

2. Download either the [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_v1.0.json) and [evaluation](https://github.com/daltonj/treccastweb/blob/master/2019/data/evaluation/evaluation_topics_v1.0.json) input query JSON files. These files can be found under `data/treccastweb/2019/data` if you cloned the submodules for this repo.

3. Download the evaluation answer files for [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_mod.qrel) or [evaluation](https://trec.nist.gov/data/cast/2019qrels.txt). The training answer file is found under `data/treccastweb/2019/data`.
2. Download the evaluation answer files for [training](https://github.com/daltonj/treccastweb/blob/master/2019/data/training/train_topics_mod.qrel) or [evaluation](https://trec.nist.gov/data/cast/2019qrels.txt). The training answer file is found under `data/treccastweb/2019/data`.

## Run CQR retrieval

The following command is for HQE, but you can also run other CQR methods using `t5` or `fusion` instead of `hqe` as the input to the `--experiment` flag.
The following command is for HQE, but you can also run other CQR methods using `t5` or `fusion` instead of `hqe` as the input to the `--experiment` flag. Running the command for the first time will download the CAsT 2019 index (or whatever index is specified for the `--index` flag). It is also possible to supply a path to a local directory containing the index.

```shell=bash
python -m experiments.run_retrieval \
--experiment hqe \
--hits 1000 \
--index $anserini_index_path \
--index cast2019 \
--qid_queries $input_query_json \
--output ./output/hqe_bm25
```

Running the experiment will output the retrieval results at the specified location in TSV format. By default, this will perform retrieval using only BM25, but you can add the `--rerank` flag to further rerank these results using BERT. For other command line arguments, see [run_retrieval.py](experiments/run_retrieval.py).
The experiment will output the retrieval results at the specified location in TSV format. By default, this will perform retrieval using only BM25, but you can add the `--rerank` flag to further rerank these results using BERT. For other command line arguments, see [run_retrieval.py](experiments/run_retrieval.py).

## Evaluate CQR results

Expand Down
2 changes: 1 addition & 1 deletion examples/messenger/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This guide is based on ParlAI's [chat service tutorial](https://parl.ai/docs/tut
3. Run the webhook server and Chatty Goose agent using our provided configuration. This assumes you have the ParlAI Python package installed and are inside the `chatty-goose` root repository folder.

```
python3.7 -m parlai.chat_service.services.messenger.run --config-path examples/messenger/config.yml
python -m parlai.chat_service.services.messenger.run --config-path examples/messenger/config.yml
```

4. Add the webhook URL outputted from the above command as a callback URL for the Messenger App settings, and set the verify token to `Messenger4ParlAI`. For Heroku, this URL should look like `https://firstname-parlai-messenger-chatbot.herokuapp.com/webhook`.
Expand Down
16 changes: 8 additions & 8 deletions experiments/run_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def run_experiment(rp: RetrievalPipeline):
)
searcher = build_searcher(searcher_settings)

# Initialize retrievers and reranker
retrievers = []
reranker_query_retriever = None
# Initialize CQR and reranker
reformulators = []
reranker_query_reformulator = None
reranker = build_bert_reranker(device=args.reranker_device) if args.rerank else None

if experiment == CQRType.HQE or experiment == CQRType.FUSION:
Expand All @@ -131,7 +131,7 @@ def run_experiment(rp: RetrievalPipeline):
verbose=args.verbose,
)
hqe_bm25 = HQE(searcher, hqe_bm25_settings)
retrievers.append(hqe_bm25)
reformulators.append(hqe_bm25)

if experiment == CQRType.T5 or experiment == CQRType.FUSION:
# Initialize T5
Expand All @@ -143,7 +143,7 @@ def run_experiment(rp: RetrievalPipeline):
verbose=args.verbose,
)
t5 = T5_NTR(t5_settings, device=args.t5_device)
retrievers.append(t5)
reformulators.append(t5)

if experiment == CQRType.HQE:
hqe_bert_settings = HQESettings(
Expand All @@ -153,14 +153,14 @@ def run_experiment(rp: RetrievalPipeline):
R_sub=args.R1_sub,
filter=PosFilter(args.filter),
)
reranker_query_retriever = HQE(searcher, hqe_bert_settings)
reranker_query_reformulator = HQE(searcher, hqe_bert_settings)

rp = RetrievalPipeline(
searcher,
retrievers,
reformulators,
searcher_num_hits=args.hits,
early_fusion=not args.late_fusion,
reranker=reranker,
reranker_query_retriever=reranker_query_retriever,
reranker_query_reformulator=reranker_query_reformulator,
)
run_experiment(rp)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
coloredlogs
parlai==1.1.0
pydantic>=1.5
pygaggle==0.0.2
pygaggle==0.0.3.1
spacy>=2.2.4
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
with open("requirements.txt") as f:
requirements = f.read().splitlines()

excluded = ["data*", "experiments*"]
excluded = ["data*", "examples*", "experiments*"]


setuptools.setup(
Expand All @@ -17,7 +17,7 @@
description="A conversational passage retrieval toolkit",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/jacklin64/chatty-goose",
url="https://github.com/castorini/chatty-goose",
install_requires=requirements,
packages=setuptools.find_packages(exclude=excluded),
classifiers=[
Expand Down

0 comments on commit 514d752

Please sign in to comment.