Skip to content

Commit add94ff

Browse files
DeepMindmohammadasghari
DeepMind
authored andcommitted
Internal change.
PiperOrigin-RevId: 441861990 Change-Id: Ie82f0da222bd0c23b81084f53ba3d91558941739
1 parent 08968ff commit add94ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+585
-50
lines changed

enn/experiments/neurips_2021/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/agent_factories.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/agents.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/enn_losses.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/load.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/plotting.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/run_testbed.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/run_thompson.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/testbed.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/experiments/neurips_2021/thompson.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/extra/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/extra/kmeans.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/extra/kmeans_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/extra/vae.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/extra/vae_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/__init__.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#
@@ -61,6 +60,9 @@
6160
from enn.networks.indexers import GaussianWithUnitIndexer
6261
from enn.networks.indexers import PrngIndexer
6362
from enn.networks.indexers import ScaledGaussianIndexer
63+
# LeNet (MNIST)
64+
from enn.networks.lenet import EnsembleLeNet5ENN
65+
from enn.networks.lenet import LeNet5
6466
# Priors
6567
from enn.networks.priors import convert_enn_to_prior_fn
6668
from enn.networks.priors import EnnWithAdditivePrior
@@ -69,3 +71,17 @@
6971
from enn.networks.priors import make_random_feat_gp
7072
from enn.networks.priors import NetworkWithAdditivePrior
7173
from enn.networks.priors import PriorFn
74+
# ResNet (Imagenet)
75+
from enn.networks.resnet import EnsembleResNetENN
76+
from enn.networks.resnet import resnet_model
77+
# ResNet Configs (Imagenet)
78+
from enn.networks.resnet_lib import ResBlockV2
79+
from enn.networks.resnet_lib import ResNet
80+
from enn.networks.resnet_lib import RESNET_101
81+
from enn.networks.resnet_lib import RESNET_152
82+
from enn.networks.resnet_lib import RESNET_200
83+
from enn.networks.resnet_lib import RESNET_50
84+
from enn.networks.resnet_lib import ResNetConfig
85+
# VGG (Cifar10)
86+
from enn.networks.vgg import EnsembleVGGENN
87+
from enn.networks.vgg import VGG

enn/networks/bbb.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/bbb_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/categorical_ensembles.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/categorical_ensembles_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/dropout.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/dropout_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/einsum_mlp.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/einsum_mlp_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/ensembles.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/ensembles_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/epinet.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/epinet_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#
@@ -20,10 +19,10 @@
2019

2120
from absl.testing import absltest
2221
from absl.testing import parameterized
23-
from enn import networks
2422
from enn import supervised
2523
from enn import utils
2624
from enn.networks import epinet
25+
from enn.networks import indexers
2726
import haiku as hk
2827

2928

@@ -65,7 +64,7 @@ def enn_ctor():
6564
index_dim=index_dim,
6665
)
6766
enn = utils.epistemic_network_from_module(
68-
enn_ctor, networks.GaussianIndexer(index_dim))
67+
enn_ctor, indexers.GaussianIndexer(index_dim))
6968

7069
experiment = test_experiment.experiment_ctor(enn)
7170
experiment.train(10)

enn/networks/gaussian_enn.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/gaussian_enn_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/hypermodels.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/hypermodels_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/index_mlp.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/index_mlp_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/indexers.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/indexers_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/lenet.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# pylint: disable=g-bad-file-header
2+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ============================================================================
16+
"""Network definitions for LeNet5."""
17+
from typing import Sequence
18+
19+
from absl import logging
20+
import chex
21+
from enn import base
22+
from enn.networks import ensembles
23+
import haiku as hk
24+
import jax
25+
26+
_LeNet5_CHANNELS = (6, 16, 120)
27+
28+
29+
class LeNet5(hk.Module):
30+
"""VGG Network with batchnorm and without maxpool."""
31+
32+
def __init__(self,
33+
num_output_classes: int,
34+
lenet_output_channels: Sequence[int] = _LeNet5_CHANNELS,):
35+
super().__init__()
36+
logging.info('Initializing a LeNet5.')
37+
self._output_channels = lenet_output_channels
38+
num_channels = len(self._output_channels)
39+
40+
self._conv_modules = [
41+
hk.Conv2D( # pylint: disable=g-complex-comprehension
42+
output_channels=self._output_channels[i],
43+
kernel_shape=5,
44+
padding='SAME',
45+
name=f'conv_2d_{i}') for i in range(num_channels)
46+
]
47+
self._mp_modules = [
48+
hk.MaxPool( # pylint: disable=g-complex-comprehension
49+
window_shape=2, strides=2, padding='SAME',
50+
name=f'max_pool_{i}') for i in range(num_channels)
51+
]
52+
self._flatten_module = hk.Flatten()
53+
self._linear_module = hk.Linear(84, name='linear')
54+
self._logits_module = hk.Linear(num_output_classes, name='logits')
55+
56+
def __call__(self, inputs: chex.Array) -> chex.Array:
57+
net = inputs
58+
for conv_layer, mp_layer in zip(self._conv_modules, self._mp_modules):
59+
net = conv_layer(net)
60+
net = jax.nn.relu(net)
61+
net = mp_layer(net)
62+
net = self._flatten_module(net)
63+
net = self._linear_module(net)
64+
net = jax.nn.relu(net)
65+
return self._logits_module(net)
66+
67+
68+
class EnsembleLeNet5ENN(base.EpistemicNetworkWithState):
69+
"""Ensemble of LeNet5 Networks created using einsum ensemble."""
70+
71+
def __init__(self,
72+
num_output_classes: int,
73+
num_ensemble: int = 1,):
74+
def net_fn(x: chex.Array) -> chex.Array:
75+
return LeNet5(num_output_classes)(x)
76+
transformed = hk.without_apply_rng(hk.transform_with_state(net_fn))
77+
enn = ensembles.EnsembleWithState(transformed, num_ensemble)
78+
super().__init__(enn.apply, enn.init, enn.indexer)

enn/networks/lenet_test.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# pylint: disable=g-bad-file-header
2+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ============================================================================
16+
17+
"""Tests for ENN Networks."""
18+
from absl.testing import absltest
19+
from absl.testing import parameterized
20+
from enn.networks import lenet
21+
import haiku as hk
22+
import jax
23+
24+
25+
class NetworkTest(parameterized.TestCase):
26+
27+
@parameterized.product(
28+
num_classes=[2, 10],
29+
batch_size=[1, 10],
30+
image_size=[2, 10],
31+
)
32+
def test_forward_pass(
33+
self,
34+
num_classes: int,
35+
batch_size: int,
36+
image_size: int,
37+
):
38+
"""Tests forward pass and output shape."""
39+
enn = lenet.EnsembleLeNet5ENN(
40+
num_output_classes=num_classes,
41+
)
42+
rng = hk.PRNGSequence(0)
43+
image_shape = [image_size, image_size, 3]
44+
x = jax.random.normal(next(rng), shape=[batch_size,] + image_shape)
45+
index = enn.indexer(next(rng))
46+
params, state = enn.init(next(rng), x, index)
47+
out, unused_new_state = enn.apply(params, state, x, index)
48+
self.assertEqual(out.shape, (batch_size, num_classes))
49+
50+
51+
if __name__ == '__main__':
52+
absltest.main()

enn/networks/priors.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

enn/networks/priors_test.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# python3
21
# pylint: disable=g-bad-file-header
32
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
43
#

0 commit comments

Comments
 (0)