Skip to content

Commit

Permalink
fixed merged conflicts: stratified multilabel classification
Browse files Browse the repository at this point in the history
  • Loading branch information
dokato committed May 31, 2024
1 parent 5fa2aee commit 053cbaa
Show file tree
Hide file tree
Showing 12 changed files with 763 additions and 45 deletions.
3 changes: 3 additions & 0 deletions docs/mmteb/points/760.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"GitHub": "dokato", "Bug fixes": 2}
{"GitHub": "dokato", "New dataset": 6}
{"GitHub": "x-tabdeveloping", "Review PR": 2}
41 changes: 22 additions & 19 deletions mteb/abstasks/AbsTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import datasets
import numpy as np
import torch
from datasets import DatasetDict
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict
from sklearn.preprocessing import MultiLabelBinarizer

from mteb.abstasks.stratification import _iterative_train_test_split
from mteb.abstasks.TaskMetadata import TaskMetadata
from mteb.encoder_interface import Encoder, EncoderWithQueryCorpusEncode
from mteb.languages import LanguageScripts
Expand All @@ -22,28 +22,31 @@


def _multilabel_subsampling(
dataset_dict: datasets.DatasetDict,
dataset_dict: DatasetDict,
seed: int,
splits: list[str] = ["test"],
label: str = "label",
n_samples: int = 2048,
) -> datasets.DatasetDict:
"""Startified subsampling for multilabel problems."""
) -> DatasetDict:
"""Multilabel subsampling the dataset with stratification by the supplied label.
Returns a DatasetDict object.
Args:
dataset_dict: the DatasetDict object.
seed: the random seed.
splits: the splits of the dataset.
label: the label with which the stratified sampling is based on.
n_samples: Optional, number of samples to subsample. Default is max_n_samples.
"""
for split in splits:
labels = dataset_dict[split][label]
encoded_labels = MultiLabelBinarizer().fit_transform(labels)
idxs = np.arange(len(labels))
try:
idxs, *_ = train_test_split(
idxs,
encoded_labels,
stratify=encoded_labels,
random_state=seed,
train_size=n_samples,
)
except ValueError:
logger.warn("Couldn't subsample, continuing with full split.")
dataset_dict.update({split: dataset_dict[split].select(idxs)})
n_split = len(dataset_dict[split])
X_np = np.arange(n_split).reshape((-1, 1))
binarizer = MultiLabelBinarizer()
labels_np = binarizer.fit_transform(dataset_dict[split][label])
_, test_idx = _iterative_train_test_split(
X_np, labels_np, test_size=n_samples / n_split, random_state=seed
)
dataset_dict.update({split: Dataset.from_dict(dataset_dict[split][test_idx])})
return dataset_dict


Expand Down
4 changes: 2 additions & 2 deletions mteb/abstasks/AbsTaskMultilabelClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def _evaluate_subset(
y_test = binarizer.fit_transform(eval_split["label"])
# Stratified subsampling of test set to 2000 examples.
try:
if len(test_text) > 2000:
if len(test_text) > 2048:
test_text, _, y_test, _ = train_test_split(
test_text, y_test, stratify=y_test, train_size=2000
test_text, y_test, stratify=y_test, train_size=2048
)
except ValueError:
logger.warn("Couldn't subsample, continuing with the entire test set.")
Expand Down
Loading

0 comments on commit 053cbaa

Please sign in to comment.