
Commit

moved stratified to abstask
dokato committed May 24, 2024
1 parent 9451e38 commit e89f14b
Showing 2 changed files with 22 additions and 53 deletions.
41 changes: 22 additions & 19 deletions mteb/abstasks/AbsTask.py
@@ -8,10 +8,10 @@
 import datasets
 import numpy as np
 import torch
-from datasets import DatasetDict
-from sklearn.model_selection import train_test_split
+from datasets import Dataset, DatasetDict
 from sklearn.preprocessing import MultiLabelBinarizer
 
+from mteb.abstasks.stratification import _iterative_train_test_split
 from mteb.abstasks.TaskMetadata import TaskMetadata
 from mteb.encoder_interface import Encoder, EncoderWithQueryCorpusEncode
 from mteb.languages import LanguageScripts
@@ -22,28 +22,31 @@
 
 
 def _multilabel_subsampling(
-    dataset_dict: datasets.DatasetDict,
+    dataset_dict: DatasetDict,
     seed: int,
     splits: list[str] = ["test"],
     label: str = "label",
     n_samples: int = 2048,
-) -> datasets.DatasetDict:
-    """Startified subsampling for multilabel problems."""
+) -> DatasetDict:
+    """Multilabel subsampling the dataset with stratification by the supplied label.
+    Returns a DatasetDict object.
+
+    Args:
+        dataset_dict: the DatasetDict object.
+        seed: the random seed.
+        splits: the splits of the dataset.
+        label: the label with which the stratified sampling is based on.
+        n_samples: Optional, number of samples to subsample. Default is max_n_samples.
+    """
     for split in splits:
-        labels = dataset_dict[split][label]
-        encoded_labels = MultiLabelBinarizer().fit_transform(labels)
-        idxs = np.arange(len(labels))
-        try:
-            idxs, *_ = train_test_split(
-                idxs,
-                encoded_labels,
-                stratify=encoded_labels,
-                random_state=seed,
-                train_size=n_samples,
-            )
-        except ValueError:
-            logger.warn("Couldn't subsample, continuing with full split.")
-        dataset_dict.update({split: dataset_dict[split].select(idxs)})
+        n_split = len(dataset_dict[split])
+        X_np = np.arange(n_split).reshape((-1, 1))
+        binarizer = MultiLabelBinarizer()
+        labels_np = binarizer.fit_transform(dataset_dict[split][label])
+        _, test_idx = _iterative_train_test_split(
+            X_np, labels_np, test_size=n_samples / n_split, random_state=seed
+        )
+        dataset_dict.update({split: Dataset.from_dict(dataset_dict[split][test_idx])})
     return dataset_dict
 
 
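For context, a minimal usage sketch (not part of the commit) of the relocated helper on a toy multilabel dataset. It assumes the module layout introduced here (`_multilabel_subsampling` as a module-level function in `mteb/abstasks/AbsTask.py`); the toy data, split name, and the small `n_samples` value are invented for illustration, and the exact subsample size depends on how the iterative split rounds the requested fraction.

```python
from datasets import Dataset, DatasetDict

# Assumes the layout from this commit: the helper now lives in AbsTask.py.
from mteb.abstasks.AbsTask import _multilabel_subsampling

# Toy multilabel split: 20 documents, each with one or more string labels.
toy_split = Dataset.from_dict(
    {
        "text": [f"document {i}" for i in range(20)],
        "label": [["news"], ["sports"], ["news", "sports"], ["politics"]] * 5,
    }
)
dataset = DatasetDict({"test": toy_split})

# Stratified subsampling down to roughly 8 examples (test_size = 8 / 20 = 0.4),
# keeping each label's share close to its share in the full split.
subsampled = _multilabel_subsampling(
    dataset, seed=42, splits=["test"], label="label", n_samples=8
)
print(len(subsampled["test"]))
```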
34 changes: 0 additions & 34 deletions mteb/abstasks/AbsTaskMultilabelClassification.py
@@ -6,15 +6,12 @@
 from typing import Any
 
 import numpy as np
-from datasets import Dataset, DatasetDict
 from sklearn.base import ClassifierMixin, clone
 from sklearn.metrics import f1_score, label_ranking_average_precision_score
 from sklearn.model_selection import train_test_split
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.preprocessing import MultiLabelBinarizer
 
-from mteb.abstasks.stratification import _iterative_train_test_split
-
 from ..MTEBResults import ScoresDict
 from .AbsTask import AbsTask
@@ -149,34 +146,3 @@ def _undersample_data_indices(self, y, samples_per_label, idxs=None):
                 for label in y[i]:
                     label_counter[label] += 1
         return sample_indices, idxs
-
-    def stratified_multilabel_subsampling(
-        self,
-        dataset_dict: DatasetDict,
-        seed: int,
-        splits: list[str] = ["test"],
-        label: str = "label",
-        n_samples: int = 2048,
-    ) -> DatasetDict:
-        """Multilabel subsampling the dataset with stratification by the supplied label.
-        Returns a DatasetDict object.
-
-        Args:
-            dataset_dict: the DatasetDict object.
-            seed: the random seed.
-            splits: the splits of the dataset.
-            label: the label with which the stratified sampling is based on.
-            n_samples: Optional, number of samples to subsample. Default is max_n_samples.
-        """
-        for split in splits:
-            n_split = len(dataset_dict[split])
-            X_np = np.arange(n_split).reshape((-1, 1))
-            binarizer = MultiLabelBinarizer()
-            labels_np = binarizer.fit_transform(dataset_dict[split][label])
-            _, test_idx = _iterative_train_test_split(
-                X_np, labels_np, test_size=n_samples / n_split, random_state=seed
-            )
-            dataset_dict.update(
-                {split: Dataset.from_dict(dataset_dict[split][test_idx])}
-            )
-        return dataset_dict
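Both the removed method and its replacement in AbsTask.py first binarize the multilabel column with MultiLabelBinarizer and then stratify on the resulting indicator matrix; the difference is the splitter. The deleted code relied on sklearn's train_test_split, which treats every distinct label combination as a single class and raises a ValueError when a combination occurs only once (hence the try/except fallback), whereas _iterative_train_test_split balances each label's frequency independently. A small illustration of the binarization step, with toy labels invented for the example:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# Each example carries a list of labels; the binarizer maps it to a
# fixed-width 0/1 indicator row, one column per distinct label.
labels = [["news"], ["sports"], ["news", "sports"], ["politics"]]
binarizer = MultiLabelBinarizer()
indicator = binarizer.fit_transform(labels)

print(binarizer.classes_)  # ['news' 'politics' 'sports']
print(indicator)
# [[1 0 0]
#  [0 0 1]
#  [1 0 1]
#  [0 1 0]]
```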
