Skip to content

Commit

Permalink
merge conflicts fixed for stratification
Browse files Browse the repository at this point in the history
  • Loading branch information
dokato committed Jun 5, 2024
1 parent 437d6df commit d9c3710
Show file tree
Hide file tree
Showing 6 changed files with 475 additions and 24 deletions.
3 changes: 3 additions & 0 deletions docs/mmteb/points/760.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"GitHub": "dokato", "Bug fixes": 2}
{"GitHub": "dokato", "New dataset": 6}
{"GitHub": "x-tabdeveloping", "Review PR": 2}
45 changes: 22 additions & 23 deletions mteb/abstasks/AbsTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import datasets
import numpy as np
import torch
from datasets import DatasetDict
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from sklearn.preprocessing import MultiLabelBinarizer

from mteb.abstasks.stratification import _iterative_train_test_split
from mteb.abstasks.TaskMetadata import TaskMetadata
from mteb.encoder_interface import Encoder, EncoderWithQueryCorpusEncode
from mteb.languages import LanguageScripts
Expand All @@ -22,28 +22,31 @@


def _multilabel_subsampling(
dataset_dict: datasets.DatasetDict,
dataset_dict: DatasetDict,
seed: int,
splits: list[str] = ["test"],
label: str = "label",
n_samples: int = 2048,
) -> datasets.DatasetDict:
"""Startified subsampling for multilabel problems."""
) -> DatasetDict:
"""Multilabel subsampling the dataset with stratification by the supplied label.
Returns a DatasetDict object.
Args:
dataset_dict: the DatasetDict object.
seed: the random seed.
splits: the splits of the dataset.
label: the label on which the stratified sampling is based.
n_samples: Optional, number of samples to subsample. Default is 2048.
"""
for split in splits:
labels = dataset_dict[split][label]
encoded_labels = MultiLabelBinarizer().fit_transform(labels)
idxs = np.arange(len(labels))
try:
idxs, *_ = train_test_split(
idxs,
encoded_labels,
stratify=encoded_labels,
random_state=seed,
train_size=n_samples,
)
except ValueError:
logger.warning("Couldn't subsample, continuing with full split.")
dataset_dict.update({split: dataset_dict[split].select(idxs)})
n_split = len(dataset_dict[split])
X_np = np.arange(n_split).reshape((-1, 1))
binarizer = MultiLabelBinarizer()
labels_np = binarizer.fit_transform(dataset_dict[split][label])
_, test_idx = _iterative_train_test_split(
X_np, labels_np, test_size=n_samples / n_split, random_state=seed
)
dataset_dict.update({split: Dataset.from_dict(dataset_dict[split][test_idx])})
return dataset_dict


Expand Down Expand Up @@ -152,10 +155,6 @@ def stratified_subsampling(
raise e

for split in splits:
if len(dataset_dict[split][label]) < n_samples:
n_samples = len(dataset_dict[split][label]) - len(
set(dataset_dict[split][label])
)
dataset_dict.update(
{
split: dataset_dict[split].train_test_split(
Expand Down
Loading

0 comments on commit d9c3710

Please sign in to comment.