
Commit ca09e65

mohammadasghariDeepMind authored and DeepMind committed
Changed ActiveLearner to be generic in Data
PiperOrigin-RevId: 533570248
1 parent dbc1ea7 commit ca09e65

File tree: 2 files changed, +7 −7 lines changed

enn/active_learning/base.py

+6 −6
@@ -19,13 +19,13 @@
 import typing as tp
 
 import chex
+from enn import base as enn_base
 from enn import networks
-from enn.datasets import base as ds_base
 import haiku as hk
 import typing_extensions
 
 
-class ActiveLearner(abc.ABC):
+class ActiveLearner(abc.ABC, tp.Generic[enn_base.Data]):
   """Samples a batch from a pool of data for learning.
 
   An active learner selects an "acquisition batch" with acquisition_size
@@ -38,9 +38,9 @@ def sample_batch(
       self,
       params: hk.Params,
       state: hk.State,
-      batch: ds_base.ArrayBatch,
+      batch: enn_base.Data,
       key: chex.PRNGKey,
-  ) -> ds_base.ArrayBatch:
+  ) -> enn_base.Data:
     """Samples a batch from a pool of data for learning."""
 
   @property
@@ -57,13 +57,13 @@ def acquisition_size(self, size: int) -> None:
 PriorityOutput = tp.Tuple[chex.Array, tp.Dict[str, chex.Array]]
 
 
-class PriorityFn(typing_extensions.Protocol):
+class PriorityFn(typing_extensions.Protocol[enn_base.Data]):
 
   def __call__(
       self,
       params: hk.Params,
       state: hk.State,
-      batch: ds_base.ArrayBatch,
+      batch: enn_base.Data,
       key: chex.PRNGKey,
   ) -> PriorityOutput:
     """Assigns a priority score to a batch."""

enn/active_learning/prioritized.py

+1 −1
@@ -27,7 +27,7 @@
 import jax.numpy as jnp
 
 
-class PrioritizedBatcher(base.ActiveLearner):
+class PrioritizedBatcher(base.ActiveLearner[datasets.ArrayBatch]):
   """Prioritizes bathces based on a priority fn."""
 
   def __init__(
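
PriorityFn is likewise a generic Protocol, so any callable with the signature shown in base.py can act as the priority fn that PrioritizedBatcher relies on. Below is a minimal, hypothetical sketch (not part of this commit) of such a function for ArrayBatch data, assuming ArrayBatch stores its examples in an array field x with a leading batch dimension.

# Hypothetical sketch, not from the commit: a PriorityFn over ArrayBatch
# that scores every candidate example uniformly at random.
import chex
import haiku as hk
import jax

from enn import datasets
from enn.active_learning import base


def random_priority(
    params: hk.Params,
    state: hk.State,
    batch: datasets.ArrayBatch,
    key: chex.PRNGKey,
) -> base.PriorityOutput:
  """Assigns an i.i.d. uniform score to each example in the batch."""
  del params, state  # Priorities here do not depend on the model.
  # Assumes ArrayBatch carries its examples in an array field `x` whose
  # leading dimension is the batch size.
  scores = jax.random.uniform(key, shape=(batch.x.shape[0],))
  return scores, {}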
