Skip to content

Commit

Permalink
Merge branch 'main' into add_memory_usage
Browse files: browse the repository at this point in the history
  • Loading branch information
Samoed authored Feb 5, 2025
2 parents 8919302 + fc6696f commit 609b7e9
Show file tree
Hide file tree
Showing 24 changed files with 146 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:

- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: "pip"

- name: Install dependencies
Expand Down
9 changes: 8 additions & 1 deletion mteb/abstasks/AbsTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import random
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from copy import copy
Expand Down Expand Up @@ -68,7 +69,9 @@ class AbsTask(ABC):
def __init__(self, seed: int = 42, **kwargs: Any):
self.save_suffix = kwargs.get("save_suffix", "")
if self.save_suffix:
logger.warning("`save_suffix` will be removed in v2.0.0.")
warnings.warn(
"`save_suffix` will be removed in v2.0.0.", DeprecationWarning
)

self.seed = seed
random.seed(self.seed)
Expand Down Expand Up @@ -252,6 +255,10 @@ def _calculate_metrics_from_split(

@property
def metadata_dict(self) -> dict[str, Any]:
warnings.warn(
"`metadata_dict` will be removed in v2.0. Use task.metadata instead.",
DeprecationWarning,
)
return dict(self.metadata)

@property
Expand Down
25 changes: 24 additions & 1 deletion mteb/abstasks/AbsTaskClassification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from collections import Counter, defaultdict
from typing import Any

Expand Down Expand Up @@ -75,15 +76,31 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
if method != "logReg":
warnings.warn(
"Passing `method` to AbsTaskClassification is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)
self.method = method

if n_experiments:
warnings.warn(
"Passing `n_experiments` to AbsTaskClassification is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)

# Bootstrap parameters
self.n_experiments: int = ( # type: ignore
n_experiments
if n_experiments is not None
else self.metadata_dict.get("n_experiments", 10)
)

if k != 3:
warnings.warn(
"Passing `k` to AbsTaskClassification is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)
# kNN parameters
self.k = k

Expand All @@ -103,6 +120,12 @@ def evaluate(
if not self.data_loaded:
self.load_data()

if train_split != "train":
warnings.warn(
"Passing `train_split` to AbsTaskClassification.evaluate is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)

scores = {}
hf_subsets = list(self.dataset) if self.is_multilingual else ["default"]
if subsets_to_run is not None:
Expand Down Expand Up @@ -150,7 +173,7 @@ def _evaluate_subset(
) # we store idxs to make the shuffling reproducible
for i in range(self.n_experiments):
logger.info(
"=" * 10 + f" Experiment {i+1}/{self.n_experiments} " + "=" * 10
"=" * 10 + f" Experiment {i + 1}/{self.n_experiments} " + "=" * 10
)
# Bootstrap `self.samples_per_label` samples per label for each split
X_sampled, y_sampled, idxs = self._undersample_data(
Expand Down
5 changes: 5 additions & 0 deletions mteb/abstasks/AbsTaskInstructionRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import warnings
from collections import defaultdict
from time import time
from typing import Any
Expand Down Expand Up @@ -319,6 +320,10 @@ def __init__(
self.do_length_ablation = kwargs.get("do_length_ablation", False)
if self.do_length_ablation:
logger.info("Running length ablation also...")
warnings.warn(
"`AbsTaskInstructionRetrieval` will be merged with Retrieval in v2.0.0.",
DeprecationWarning,
)

def load_data(self, **kwargs):
if self.data_loaded:
Expand Down
14 changes: 13 additions & 1 deletion mteb/abstasks/AbsTaskMultilabelClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
import logging
import warnings
from collections import Counter, defaultdict
from typing import Any

Expand Down Expand Up @@ -103,6 +104,11 @@ def __init__(
super().__init__(**kwargs)
self.batch_size = batch_size

if n_experiments:
warnings.warn(
"Passing `n_experiments` to AbsTaskMultilabelClassification is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)
# Bootstrap parameters
self.n_experiments = n_experiments or getattr(self, "n_experiments", 10)

Expand All @@ -129,6 +135,12 @@ def evaluate(
if not self.data_loaded:
self.load_data()

if train_split != "train":
warnings.warn(
"Passing `train_split` to AbsTaskClassification.evaluate is deprecated and will be removed in v2.0.0.",
DeprecationWarning,
)

scores = {}
hf_subsets = list(self.dataset) if self.is_multilingual else ["default"]
# If subsets_to_run is specified, filter the hf_subsets accordingly
Expand Down Expand Up @@ -215,7 +227,7 @@ def _evaluate_subset(
for i_experiment, sample_indices in enumerate(train_samples):
logger.info(
"=" * 10
+ f" Experiment {i_experiment+1}/{self.n_experiments} "
+ f" Experiment {i_experiment + 1}/{self.n_experiments} "
+ "=" * 10
)
X_train = np.stack([unique_train_embeddings[idx] for idx in sample_indices])
Expand Down
5 changes: 5 additions & 0 deletions mteb/abstasks/AbsTaskReranking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Any

from datasets import Dataset
Expand Down Expand Up @@ -70,6 +71,10 @@ class AbsTaskReranking(AbsTask):
abstask_prompt = "Retrieve text based on user query."

def __init__(self, **kwargs):
warnings.warn(
"`AbsTaskReranking` will be merged with AbsTaskRetrieval in v2.0.0.",
DeprecationWarning,
)
super().__init__(**kwargs)

def _evaluate_subset(
Expand Down
5 changes: 5 additions & 0 deletions mteb/abstasks/AbsTaskRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import warnings
from collections import defaultdict
from pathlib import Path
from time import time
Expand Down Expand Up @@ -44,6 +45,10 @@ def __init__(
# By default fetch qrels from same repo not a second repo with "-qrels" like in original
self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo
else:
warnings.warn(
"Loading from local files will be removed in v2.0.0.",
DeprecationWarning,
)
# data folder would contain these files:
# (1) fiqa/corpus.jsonl (format: jsonlines)
# (2) fiqa/queries.jsonl (format: jsonlines)
Expand Down
2 changes: 1 addition & 1 deletion mteb/abstasks/AbsTaskSpeedTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_system_info(self) -> dict[str, str]:
list_gpus.append(
{
"gpu_name": gpu.name,
"gpu_total_memory": f"{gpu.memoryTotal/1024.0} GB",
"gpu_total_memory": f"{gpu.memoryTotal / 1024.0} GB",
}
)
info["gpu_info"] = list_gpus
Expand Down
2 changes: 1 addition & 1 deletion mteb/abstasks/Image/AbsTaskImageClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _evaluate_subset(
) # we store idxs to make the shuffling reproducible
for i in range(self.n_experiments):
logger.info(
"=" * 10 + f" Experiment {i+1}/{self.n_experiments} " + "=" * 10
"=" * 10 + f" Experiment {i + 1}/{self.n_experiments} " + "=" * 10
)
# Bootstrap `self.samples_per_label` samples per label for each split
undersampled_train, idxs = self._undersample_data(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _evaluate_subset(
for i_experiment, sample_indices in enumerate(train_samples):
logger.info(
"=" * 10
+ f" Experiment {i_experiment+1}/{self.n_experiments} "
+ f" Experiment {i_experiment + 1}/{self.n_experiments} "
+ "=" * 10
)
X_train = np.stack([unique_train_embeddings[idx] for idx in sample_indices])
Expand Down
5 changes: 5 additions & 0 deletions mteb/abstasks/MultiSubsetLoader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import warnings

import datasets


Expand All @@ -8,6 +10,9 @@ def load_data(self, **kwargs):
"""Load dataset containing multiple subsets from HuggingFace hub"""
if self.data_loaded:
return
warnings.warn(
"`MultiSubsetLoader` will be removed in v2.0.0.", DeprecationWarning
)

if hasattr(self, "fast_loading") and self.fast_loading:
self.fast_load()
Expand Down
7 changes: 7 additions & 0 deletions mteb/abstasks/MultilingualTask.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import annotations

import warnings

from .AbsTask import AbsTask
from .MultiSubsetLoader import MultiSubsetLoader


class MultilingualTask(MultiSubsetLoader, AbsTask):
def __init__(self, hf_subsets: list[str] | None = None, **kwargs):
super().__init__(**kwargs)
warnings.warn(
"`MultilingualTask` will be removed in v2.0. In the future, checking whether a task is multilingual"
" will be based solely on `metadata.eval_langs`, which should be a dictionary for multilingual tasks.",
DeprecationWarning,
)
if isinstance(hf_subsets, list):
hf_subsets = [
lang for lang in hf_subsets if lang in self.metadata.eval_langs
Expand Down
3 changes: 2 additions & 1 deletion mteb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import argparse
import json
import logging
import warnings
from pathlib import Path

import torch
Expand Down Expand Up @@ -368,7 +369,7 @@ def main():

# If no subcommand is provided, default to run with a deprecation warning
if not hasattr(args, "func"):
logger.warning(
warnings.warn(
"Using `mteb` without a subcommand is deprecated. Use `mteb run` instead.",
DeprecationWarning,
)
Expand Down
21 changes: 11 additions & 10 deletions mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import traceback
import warnings
from collections.abc import Iterable, Sequence
from copy import copy, deepcopy
from datetime import datetime
Expand Down Expand Up @@ -96,29 +97,29 @@ def deprecation_warning(
self, task_types, task_categories, task_langs, tasks, version
):
if task_types is not None:
logger.warning(
"The `task_types` argument is deprecated and will be removed in the next release. "
warnings.warn(
"The `task_types` argument is deprecated and will be removed in the 2.0 release. "
+ "Please use `tasks = mteb.get_tasks(... task_types = [...])` to filter tasks instead."
)
if task_categories is not None:
logger.warning(
"The `task_categories` argument is deprecated and will be removed in the next release. "
warnings.warn(
"The `task_categories` argument is deprecated and will be removed in the 2.0 release. "
+ "Please use `tasks = mteb.get_tasks(... categories = [...])` to filter tasks instead."
)
if task_langs is not None:
logger.warning(
"The `task_langs` argument is deprecated and will be removed in the next release. "
warnings.warn(
"The `task_langs` argument is deprecated and will be removed in the 2.0 release. "
+ "Please use `tasks = mteb.get_tasks(... languages = [...])` to filter tasks instead. "
+ "Note that this uses 3 letter language codes (ISO 639-3)."
)
if version is not None:
logger.warning(
"The `version` argument is deprecated and will be removed in the next release."
warnings.warn(
"The `version` argument is deprecated and will be removed in the 2.0 release."
)
task_contains_strings = any(isinstance(x, str) for x in tasks or [])
if task_contains_strings:
logger.warning(
"Passing task names as strings is deprecated and will be removed in the next release. "
warnings.warn(
"Passing task names as strings is deprecated and will be removed in 2.0 release. "
+ "Please use `tasks = mteb.get_tasks(tasks=[...])` method to get tasks instead."
)

Expand Down
15 changes: 15 additions & 0 deletions mteb/evaluation/evaluators/ClassificationEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from typing import Any

import numpy as np
Expand Down Expand Up @@ -40,6 +41,10 @@ def __init__(
):
super().__init__(**kwargs)
if limit is not None:
warnings.warn(
"Limiting the number of samples with `limit` for evaluation will be removed in v2.0.0.",
DeprecationWarning,
)
sentences_train = sentences_train[:limit]
y_train = y_train[:limit]
sentences_test = sentences_test[:limit]
Expand Down Expand Up @@ -113,6 +118,11 @@ def __init__(
):
super().__init__(**kwargs)
if limit is not None:
warnings.warn(
"Limiting the number of samples with `limit` for evaluation will be removed in v2.0.0.",
DeprecationWarning,
)

sentences_train = sentences_train[:limit]
y_train = y_train[:limit]
sentences_test = sentences_test[:limit]
Expand Down Expand Up @@ -268,6 +278,11 @@ def __init__(
self.encode_kwargs["batch_size"] = 32

if limit is not None:
warnings.warn(
"Limiting the number of samples with `limit` for evaluation will be removed in v2.0.0.",
DeprecationWarning,
)

sentences_train = sentences_train[:limit]
y_train = y_train[:limit]
sentences_test = sentences_test[:limit]
Expand Down
Loading

0 comments on commit 609b7e9

Please sign in to comment.