From d938f8f0def37e1de16c7c5a42387fb43e79fdda Mon Sep 17 00:00:00 2001 From: Mohammad Asghari Date: Fri, 19 May 2023 15:47:02 -0700 Subject: [PATCH] Changed ActiveLearner to be generic in Data PiperOrigin-RevId: 533570248 --- enn/active_learning/base.py | 12 ++++++------ enn/active_learning/prioritized.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/enn/active_learning/base.py b/enn/active_learning/base.py index adfd267..7fef99c 100644 --- a/enn/active_learning/base.py +++ b/enn/active_learning/base.py @@ -19,13 +19,13 @@ import typing as tp import chex +from enn import base as enn_base from enn import networks -from enn.datasets import base as ds_base import haiku as hk import typing_extensions -class ActiveLearner(abc.ABC): +class ActiveLearner(abc.ABC, tp.Generic[enn_base.Data]): """Samples a batch from a pool of data for learning. An active learner selects an "acquisition batch" with acquisition_size @@ -38,9 +38,9 @@ def sample_batch( self, params: hk.Params, state: hk.State, - batch: ds_base.ArrayBatch, + batch: enn_base.Data, key: chex.PRNGKey, - ) -> ds_base.ArrayBatch: + ) -> enn_base.Data: """Samples a batch from a pool of data for learning.""" @property @@ -57,13 +57,13 @@ def acquisition_size(self, size: int) -> None: PriorityOutput = tp.Tuple[chex.Array, tp.Dict[str, chex.Array]] -class PriorityFn(typing_extensions.Protocol): +class PriorityFn(typing_extensions.Protocol[enn_base.Data]): def __call__( self, params: hk.Params, state: hk.State, - batch: ds_base.ArrayBatch, + batch: enn_base.Data, key: chex.PRNGKey, ) -> PriorityOutput: """Assigns a priority score to a batch.""" diff --git a/enn/active_learning/prioritized.py b/enn/active_learning/prioritized.py index 824895b..4a3a3f6 100644 --- a/enn/active_learning/prioritized.py +++ b/enn/active_learning/prioritized.py @@ -27,7 +27,7 @@ import jax.numpy as jnp -class PrioritizedBatcher(base.ActiveLearner): +class PrioritizedBatcher(base.ActiveLearner[datasets.ArrayBatch]): """Prioritizes bathces based on a priority fn.""" def __init__(