diff --git a/protoscribe/pmmx/utils/seqio_utils.py b/protoscribe/pmmx/utils/seqio_utils.py index b532f36..2914ad3 100644 --- a/protoscribe/pmmx/utils/seqio_utils.py +++ b/protoscribe/pmmx/utils/seqio_utils.py @@ -13,6 +13,7 @@ # limitations under the License. """General utility functions for t5x.""" + import time from typing import Any, Callable, Optional, Type @@ -87,6 +88,11 @@ def get_dataset( cfg.mixture_or_task_name not in mixtures): define_task_fn(cfg.mixture_or_task_name) + if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase): + mixture_or_task = cfg.mixture_or_task_name + else: + mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name) + shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards) if cfg.seed is None: