import csv
import gzip
import os
import time

import pytest
from sentence_transformers import SentenceTransformer, util

from .test_examples import TIME_PERF_FACTOR


if os.environ.get("GAUDI2_CI", "0") == "1":
    # Gaudi2 CI baselines
    MODELS_TO_TEST = [
        ("sentence-transformers/all-mpnet-base-v2", 762.5595168883357),
        ("sentence-transformers/multi-qa-mpnet-base-dot-v1", 545.3360251829846),
        ("sentence-transformers/all-distilroberta-v1", 958.5097903298335),
        ("sentence-transformers/all-MiniLM-L12-v2", 3614.2610109716247),
        ("sentence-transformers/multi-qa-distilbert-cos-v1", 944.6166139694299),
        ("sentence-transformers/all-MiniLM-L6-v2", 2615.6975354038477),
        ("sentence-transformers/multi-qa-MiniLM-L6-cos-v1", 1208.3672807492396),
        ("sentence-transformers/paraphrase-multilingual-mpnet-base-v2", 2392.1654748794062),
        ("sentence-transformers/paraphrase-albert-small-v2", 3896.1911011860166),
        ("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 3558.0778715789693),
        ("sentence-transformers/paraphrase-MiniLM-L3-v2", 5734.318427972881),
        ("sentence-transformers/distiluse-base-multilingual-cased-v1", 3487.3319366004903),
        ("sentence-transformers/distiluse-base-multilingual-cased-v2", 3807.2486282025716),
    ]
else:
    # Gaudi1 CI baselines
    MODELS_TO_TEST = [
        ("sentence-transformers/all-mpnet-base-v2", 164.36556936723508),
        ("sentence-transformers/multi-qa-mpnet-base-dot-v1", 116.82789535569364),
        ("sentence-transformers/all-distilroberta-v1", 226.90237421623164),
        ("sentence-transformers/all-MiniLM-L12-v2", 1252.6261862281467),
        ("sentence-transformers/multi-qa-distilbert-cos-v1", 216.47035182888888),
        ("sentence-transformers/all-MiniLM-L6-v2", 1109.160132821451),
        ("sentence-transformers/multi-qa-MiniLM-L6-cos-v1", 471.14320842607674),
        ("sentence-transformers/paraphrase-multilingual-mpnet-base-v2", 518.4762252952173),
        ("sentence-transformers/paraphrase-albert-small-v2", 1139.806075824319),
        ("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 1253.06776127632),
        ("sentence-transformers/paraphrase-MiniLM-L3-v2", 3029.398417051629),
        ("sentence-transformers/distiluse-base-multilingual-cased-v1", 947.844857744754),
        ("sentence-transformers/distiluse-base-multilingual-cased-v2", 947.7317550605878),
    ]


def _test_sentence_transformers(
    model_name: str,
    baseline: float,
):
    model = SentenceTransformer(model_name)

    nli_dataset_path = "/tmp/datasets/AllNLI.tsv.gz"
    sentences = set()
    max_sentences = 10000

    # Download the dataset if needed
    if not os.path.exists(nli_dataset_path):
        util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path)

    with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn:
        reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            sentences.add(row["sentence1"])
            if len(sentences) >= max_sentences:
                break

    sentences = list(sentences)

    for i in range(2):
        start_time = time.perf_counter()
        _ = model.encode(sentences, batch_size=32)
        end_time = time.perf_counter()
        diff_time = end_time - start_time
        measured_throughput = len(sentences) / diff_time
    # Only assert the last measured throughput, as the first iteration is used as a warmup
    assert measured_throughput >= (2 - TIME_PERF_FACTOR) * baseline
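    # Worked example of the tolerance above (illustrative only: the actual value of
    # TIME_PERF_FACTOR lives in test_examples.py and is not shown in this diff).
    # If TIME_PERF_FACTOR were 1.05, the check becomes
    #   measured_throughput >= (2 - 1.05) * baseline = 0.95 * baseline,
    # i.e. a regression of up to ~5% versus the recorded baseline is tolerated.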


@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST)
def test_compute_embeddings_throughput(model_name: str, baseline: float):
    _test_sentence_transformers(model_name, baseline)
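

# Minimal usage sketch (assumption: this file is collected by pytest as part of the
# repository's test suite; the path below is illustrative and not confirmed by this diff):
#
#   python -m pytest tests/test_sentence_transformers.py -k "all-MiniLM-L6-v2" -v
#
# pytest's -k expression matches the parametrized test ids, so a single
# (model_name, baseline) pair can be checked without running the full model matrix.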