File tree 2 files changed +7
-7
lines changed
2 files changed +7
-7
lines changed Original file line number Diff line number Diff line change 19
19
import typing as tp
20
20
21
21
import chex
22
+ from enn import base as enn_base
22
23
from enn import networks
23
- from enn .datasets import base as ds_base
24
24
import haiku as hk
25
25
import typing_extensions
26
26
27
27
28
- class ActiveLearner (abc .ABC ):
28
+ class ActiveLearner (abc .ABC , tp . Generic [ enn_base . Data ] ):
29
29
"""Samples a batch from a pool of data for learning.
30
30
31
31
An active learner selects an "acquisition batch" with acquisition_size
@@ -38,9 +38,9 @@ def sample_batch(
38
38
self ,
39
39
params : hk .Params ,
40
40
state : hk .State ,
41
- batch : ds_base . ArrayBatch ,
41
+ batch : enn_base . Data ,
42
42
key : chex .PRNGKey ,
43
- ) -> ds_base . ArrayBatch :
43
+ ) -> enn_base . Data :
44
44
"""Samples a batch from a pool of data for learning."""
45
45
46
46
@property
@@ -57,13 +57,13 @@ def acquisition_size(self, size: int) -> None:
57
57
PriorityOutput = tp .Tuple [chex .Array , tp .Dict [str , chex .Array ]]
58
58
59
59
60
- class PriorityFn (typing_extensions .Protocol ):
60
+ class PriorityFn (typing_extensions .Protocol [ enn_base . Data ] ):
61
61
62
62
def __call__ (
63
63
self ,
64
64
params : hk .Params ,
65
65
state : hk .State ,
66
- batch : ds_base . ArrayBatch ,
66
+ batch : enn_base . Data ,
67
67
key : chex .PRNGKey ,
68
68
) -> PriorityOutput :
69
69
"""Assigns a priority score to a batch."""
Original file line number Diff line number Diff line change 27
27
import jax .numpy as jnp
28
28
29
29
30
- class PrioritizedBatcher (base .ActiveLearner ):
30
+ class PrioritizedBatcher (base .ActiveLearner [ datasets . ArrayBatch ] ):
31
31
"""Prioritizes bathces based on a priority fn."""
32
32
33
33
def __init__ (
You can’t perform that action at this time.
0 commit comments