Skip to content

Commit

Permalink
Bugfix #825 - Update paraphrase_mining
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Mar 22, 2021
1 parent 82c0987 commit 6353eb9
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sentence_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.0.2"
__version__ = "1.0.3"
__DOWNLOAD_SERVER__ = 'http://sbert.net/models/'
from .datasets import SentencesDataset, ParallelSentencesDataset
from .LoggingHandler import LoggingHandler
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def paraphrase_mining(model,
sentences: List[str],
show_progress_bar: bool = False,
batch_size:int = 32,
*args,
**kwargs):
"""
Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all
Expand All @@ -93,7 +94,7 @@ def paraphrase_mining(model,
# Compute embedding for the sentences
embeddings = model.encode(sentences, show_progress_bar=show_progress_bar, batch_size=batch_size, convert_to_tensor=True)

return paraphrase_mining_embeddings(embeddings, **kwargs)
return paraphrase_mining_embeddings(embeddings, *args, **kwargs)


def paraphrase_mining_embeddings(embeddings: Tensor,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@

setup(
name="sentence-transformers",
version="1.0.2",
version="1.0.3",
author="Nils Reimers",
author_email="info@nils-reimers.de",
description="Sentence Embeddings using BERT / RoBERTa / XLM-R",
long_description=readme,
long_description_content_type="text/markdown",
license="Apache License 2.0",
url="https://github.com/UKPLab/sentence-transformers",
download_url="https://github.com/UKPLab/sentence-transformers/archive/v1.0.2.zip",
download_url="https://github.com/UKPLab/sentence-transformers/archive/v1.0.3.zip",
packages=find_packages(),
install_requires=[
'transformers>=3.1.0,<5.0.0',
Expand Down
11 changes: 9 additions & 2 deletions tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
assert np.abs(max_acc - sklearn_acc) < 1e-6

def test_LabelAccuracyEvaluator(self):


"""Tests that the LabelAccuracyEvaluator can be loaded correctly"""
model = SentenceTransformer('paraphrase-distilroberta-base-v1')

nli_dataset_path = 'datasets/AllNLI.tsv.gz'
Expand All @@ -59,3 +58,11 @@ def test_LabelAccuracyEvaluator(self):
evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
acc = evaluator(model)
assert acc > 0.2

def test_ParaphraseMiningEvaluator(self):
"""Tests that the ParaphraseMiningEvaluator can be loaded"""
model = SentenceTransformer('paraphrase-distilroberta-base-v1')
sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0,1), (2,3)])
score = data_eval(model)
assert score > 0.99

0 comments on commit 6353eb9

Please sign in to comment.