Skip to content

Commit

Permalink
fix sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 29, 2024
1 parent 96129e0 commit 93f56d9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion lmms_eval/api/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def get_dataset(self) -> datasets.Dataset:
if self.fewshot_indices:
self.dataset = self.dataset.select(self.fewshot_indices)
return self.dataset

def sample(self, n, rnd):
indices = rnd.sample(range(len(self.get_dataset())), n)
return self.get_dataset().select(indices)

def __getitem__(self, item):
return self.get_dataset()[item]
Expand Down Expand Up @@ -78,7 +82,7 @@ def sample(self, n):
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""

return self.rnd.sample(self.docs.get_dataset(), n)
return self.docs.sample(n)


class FirstNSampler(ContextSampler):
Expand Down
2 changes: 1 addition & 1 deletion lmms_eval/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def __init__(self, model_name) -> None: # TODO no super() call here
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
##########################################
# TODO: for test, will delete later
if self.config.task == "flickr30k_test":
if self.config.task == "textvqa_test":
pass
else:
pass
Expand Down

0 comments on commit 93f56d9

Please sign in to comment.