Commit

make lint
isaac-chung committed Feb 5, 2025
1 parent 829272c commit cd2839d
Showing 9 changed files with 63 additions and 63 deletions.
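
Every hunk shown below rewraps a multi-line assert statement between two equivalent layouts. Both forms are the same Python statement; only the line wrapping differs, so runtime behavior is unchanged. A minimal sketch of the two layouts, taken verbatim from the first changed file (which layout the project's pinned formatter enforces depends on the formatter version, which this page does not name):

# Layout 1: condition kept on the assert line, message wrapped in parentheses
assert row_label_set.issubset(all_labels_set), (
    "The clusters are not sampled from the same distribution as they have different labels."
)

# Layout 2: the condition call is split across lines, message kept after the closing parenthesis
assert row_label_set.issubset(
    all_labels_set
), "The clusters are not sampled from the same distribution as they have different labels."
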
12 changes: 6 additions & 6 deletions mteb/abstasks/AbsTaskClusteringFast.py
@@ -320,9 +320,9 @@ def convert_to_fast(

# check that it is the same distribution
row_label_set = set(lab)
- assert row_label_set.issubset(all_labels_set), (
-     "The clusters are not sampled from the same distribution as they have different labels."
- )
+ assert row_label_set.issubset(
+     all_labels_set
+ ), "The clusters are not sampled from the same distribution as they have different labels."

for l, s in zip(lab, sents):
if s not in sent_set:
@@ -353,6 +353,6 @@ def check_label_distribution(ds: DatasetDict) -> None:

# check that it is the same distribution
row_label_set = set(lab)
- assert row_label_set.issubset(all_labels_set), (
-     "The clusters are not sampled from the same distribution as they have different labels."
- )
+ assert row_label_set.issubset(
+     all_labels_set
+ ), "The clusters are not sampled from the same distribution as they have different labels."
6 changes: 3 additions & 3 deletions mteb/abstasks/AbsTaskInstructionRetrieval.py
@@ -374,9 +374,9 @@ def load_data(self, **kwargs):
doc["id"]: {"title": doc["title"], "text": doc["text"]}
for doc in corpus
}
- assert len(top_ranked) == len(queries), (
-     f"Top ranked not loaded properly! Expected {len(self.queries)} but got {len(self.top_ranked)}."
- )
+ assert (
+     len(top_ranked) == len(queries)
+ ), f"Top ranked not loaded properly! Expected {len(self.queries)} but got {len(self.top_ranked)}."

(
self.corpus[split],
6 changes: 3 additions & 3 deletions mteb/evaluation/MTEB.py
@@ -70,9 +70,9 @@ def __init__(
if isinstance(tasks[0], Benchmark):
self.benchmarks = tasks
self._tasks = self._tasks = list(chain.from_iterable(tasks)) # type: ignore
- assert task_types is None and task_categories is None, (
-     "Cannot specify both `tasks` and `task_types`/`task_categories`"
- )
+ assert (
+     task_types is None and task_categories is None
+ ), "Cannot specify both `tasks` and `task_types`/`task_categories`"
else:
self._task_types = task_types
self._task_categories = task_categories
6 changes: 3 additions & 3 deletions mteb/models/rerankers_custom.py
@@ -85,9 +85,9 @@ def predict(self, input_to_rerank, **kwargs):
assert len(queries) == len(passages)
query_passage_tuples = list(zip(queries, passages))
scores = self.model.compute_score(query_passage_tuples, normalize=True)
- assert len(scores) == len(queries), (
-     f"Expected {len(queries)} scores, got {len(scores)}"
- )
+ assert len(scores) == len(
+     queries
+ ), f"Expected {len(queries)} scores, got {len(scores)}"
return scores


12 changes: 6 additions & 6 deletions mteb/tasks/BitextMining/vie/VieMedEVBitextMining.py
@@ -54,19 +54,19 @@ def dataset_transform(self):
# Pairs are in two halves
en_sentences = all_texts[:mid_index]
vie_sentences = all_texts[mid_index:]
- assert len(en_sentences) == len(vie_sentences), (
-     "The split does not result in equal halves."
- )
+ assert len(en_sentences) == len(
+     vie_sentences
+ ), "The split does not result in equal halves."

# Downsample
indices = list(range(len(en_sentences)))
random.shuffle(indices)
sample_indices = indices[:TEST_SAMPLES]
en_sentences = [en_sentences[i] for i in sample_indices]
vie_sentences = [vie_sentences[i] for i in sample_indices]
- assert len(en_sentences) == len(vie_sentences) == TEST_SAMPLES, (
-     f"Exceeded {TEST_SAMPLES} samples for 'test' split."
- )
+ assert (
+     len(en_sentences) == len(vie_sentences) == TEST_SAMPLES
+ ), f"Exceeded {TEST_SAMPLES} samples for 'test' split."

# Return dataset
ds["test"] = datasets.Dataset.from_dict(
6 changes: 3 additions & 3 deletions scripts/running_model/check_results.py
@@ -174,9 +174,9 @@ def normalize_results(results):
# [t.task_name for t in mteb_results['GritLM/GritLM-7B']["13f00a0e36500c80ce12870ea513846a066004af"] if t.task_name == "SemRel24STS"]
# it is there

- assert [len(revisions.keys()) == 1 for model, revisions in mteb_results.items()], (
-     "Some models have more than one revision"
- )
+ assert [
+     len(revisions.keys()) == 1 for model, revisions in mteb_results.items()
+ ], "Some models have more than one revision"

results_df = results_to_dataframe(mteb_results)

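A side note on the scripts/running_model/check_results.py hunk above: its assert tests a list comprehension, and a non-empty list is always truthy in Python, so the check passes regardless of how many revisions each model has; this commit only rewraps the line. A hypothetical rewrite that would actually enforce the one-revision rule (not part of this commit) could use all():

# Hypothetical fix, not in this commit: fail if any model has more than one revision
assert all(
    len(revisions) == 1 for revisions in mteb_results.values()
), "Some models have more than one revision"
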
36 changes: 18 additions & 18 deletions tests/test_TaskMetadata.py
@@ -516,17 +516,17 @@ def test_disallow_trust_remote_code_in_new_datasets():
"SwednClusteringS2S",
]

- assert 135 == len(exceptions), (
-     "The number of exceptions has changed. Please do not add new datasets to this list."
- )
+ assert (
+     135 == len(exceptions)
+ ), "The number of exceptions has changed. Please do not add new datasets to this list."

exceptions = []

for task in get_tasks():
if task.metadata.dataset.get("trust_remote_code", False):
- assert task.metadata.name not in exceptions, (
-     f"Dataset {task.metadata.name} should not trust remote code"
- )
+ assert (
+     task.metadata.name not in exceptions
+ ), f"Dataset {task.metadata.name} should not trust remote code"


def test_empy_descriptive_stat_in_new_datasets():
@@ -1088,26 +1088,26 @@ def test_empy_descriptive_stat_in_new_datasets():
"SummEvalFrSummarization.v2",
]

- assert 553 == len(exceptions), (
-     "The number of exceptions has changed. Please do not add new datasets to this list."
- )
+ assert (
+     553 == len(exceptions)
+ ), "The number of exceptions has changed. Please do not add new datasets to this list."

exceptions = []

for task in get_tasks():
if task.metadata.descriptive_stats is None:
- assert task.metadata.name not in exceptions, (
-     f"Dataset {task.metadata.name} should have descriptive stats"
- )
+ assert (
+     task.metadata.name not in exceptions
+ ), f"Dataset {task.metadata.name} should have descriptive stats"


@pytest.mark.parametrize("task", get_tasks())
def test_eval_langs_correctly_specified(task: AbsTask):
if task.is_multilingual:
- assert isinstance(task.metadata.eval_langs, dict), (
-     f"{task.metadata.name} should have eval_langs as a dict"
- )
+ assert isinstance(
+     task.metadata.eval_langs, dict
+ ), f"{task.metadata.name} should have eval_langs as a dict"
else:
- assert isinstance(task.metadata.eval_langs, list), (
-     f"{task.metadata.name} should have eval_langs as a list"
- )
+ assert isinstance(
+     task.metadata.eval_langs, list
+ ), f"{task.metadata.name} should have eval_langs as a list"
36 changes: 18 additions & 18 deletions tests/test_cli.py
@@ -17,18 +17,18 @@ def test_available_tasks():
command = f"{sys.executable} -m mteb available_tasks"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
assert result.returncode == 0, "Command failed"
assert "Banking77Classification" in result.stdout, (
"Sample task Banking77Classification task not found in available tasks"
)
assert (
"Banking77Classification" in result.stdout
), "Sample task Banking77Classification task not found in available tasks"


def test_available_benchmarks():
command = f"{sys.executable} -m mteb available_benchmarks"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
assert result.returncode == 0, "Command failed"
assert "MTEB(eng, classic)" in result.stdout, (
"Sample benchmark MTEB(eng, classic) task not found in available benchmarks"
)
assert (
"MTEB(eng, classic)" in result.stdout
), "Sample benchmark MTEB(eng, classic) task not found in available benchmarks"


run_task_fixures = [
@@ -75,12 +75,12 @@ def test_run_task(
f"tests/results/test_model/{model_name_as_path}/{model_revision}"
)
assert results_path.exists(), "Output folder not created"
assert "model_meta.json" in [f.name for f in list(results_path.glob("*.json"))], (
"model_meta.json not found in output folder"
)
assert f"{task_name}.json" in [f.name for f in list(results_path.glob("*.json"))], (
f"{task_name} not found in output folder"
)
assert "model_meta.json" in [
f.name for f in list(results_path.glob("*.json"))
], "model_meta.json not found in output folder"
assert f"{task_name}.json" in [
f.name for f in list(results_path.glob("*.json"))
], f"{task_name} not found in output folder"


def test_create_meta():
@@ -117,9 +117,9 @@ def test_create_meta():
for key in frontmatter_gold:
assert key in frontmatter, f"Key {key} not found in output"

- assert frontmatter[key] == frontmatter_gold[key], (
-     f"Value for {key} does not match"
- )
+ assert (
+     frontmatter[key] == frontmatter_gold[key]
+ ), f"Value for {key} does not match"

# ensure that the command line interface works as well
command = f"{sys.executable} -m mteb create_meta --results_folder {results} --output_path {output_path} --overwrite"
@@ -178,9 +178,9 @@ def test_create_meta_from_existing(existing_readme_name: str, gold_readme_name:
for key in frontmatter_gold:
assert key in frontmatter, f"Key {key} not found in output"

- assert frontmatter[key] == frontmatter_gold[key], (
-     f"Value for {key} does not match"
- )
+ assert (
+     frontmatter[key] == frontmatter_gold[key]
+ ), f"Value for {key} does not match"
assert readme_output == gold_readme
# ensure that the command line interface works as well
command = f"{sys.executable} -m mteb create_meta --results_folder {results} --output_path {output_path} --from_existing {existing_readme} --overwrite"
6 changes: 3 additions & 3 deletions tests/test_tasks/test_all_abstasks.py
@@ -111,6 +111,6 @@ def test_superseded_dataset_exists():
tasks = mteb.get_tasks(exclude_superseded=False)
for task in tasks:
if task.superseded_by:
- assert task.superseded_by in TASKS_REGISTRY, (
-     f"{task} is superseded by {task.superseded_by} but {task.superseded_by} is not in the TASKS_REGISTRY"
- )
+ assert (
+     task.superseded_by in TASKS_REGISTRY
+ ), f"{task} is superseded by {task.superseded_by} but {task.superseded_by} is not in the TASKS_REGISTRY"
