NOTE(review): line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy; verify the context lines against the target tree before
applying. The `index` hashes are those of the original change and do not reflect
the edits below (rnd is now forwarded to FewShotDataset-style `sample`, and the new
method is documented).

diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py
index c9084e9b3..9c7b3bf7a 100644
--- a/lmms_eval/api/samplers.py
+++ b/lmms_eval/api/samplers.py
@@ -19,6 +19,17 @@ def get_dataset(self) -> datasets.Dataset:
             if self.fewshot_indices:
                 self.dataset = self.dataset.select(self.fewshot_indices)
         return self.dataset
+
+    def sample(self, n, rnd):
+        """
+        Draw `n` docs from the dataset at indices chosen by `rnd`.
+
+        `rnd` must expose `sample(population, k)` (e.g. a `random.Random`
+        instance); passing it in lets the caller control seeding.
+        """
+        dataset = self.get_dataset()
+        indices = rnd.sample(range(len(dataset)), n)
+        return dataset.select(indices)
 
     def __getitem__(self, item):
         return self.get_dataset()[item]
@@ -78,7 +89,7 @@ def sample(self, n):
         Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
         """
 
-        return self.rnd.sample(self.docs.get_dataset(), n)
+        return self.docs.sample(n, self.rnd)
 
 
 class FirstNSampler(ContextSampler):
diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index 2fff7252a..abd36470f 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -553,7 +553,7 @@ def __init__(self, model_name) -> None: # TODO no super() call here
         self._filters = [build_filter_ensemble("none", [["take_first", None]])]
         ##########################################
         # TODO: for test, will delete later
-        if self.config.task == "flickr30k_test":
+        if self.config.task == "textvqa_test":
             pass
         else:
             pass