[Deepmind Publication] [arXiv Paper]
This repository contains an imitation of Deepmind's Gato architecture in TensorFlow.
Since Deepmind only describes parts of the architecture in its paper, we still don't know much about the model.
However, I believe the paper gives enough detail to imitate the architecture, and I'm trying to do that with the open-source community's help.
Currently, the repository supports the following operations:
- Transformer (via `GatoTransformer`)
- Patch Position Encodings (via `PatchPositionEncoding`)
- Embedding Function (via `ResidualEmbedding`)
- Local Observation Position Encodings (via `LocalPositionEncoding`)
- Tokenizing Continuous Values (via `ContinuousValueTokenizer`)
- Shared Embedding (via `DiscreteEmbedding`)
Action tokens are still a mystery in the paper; I need your help.
However, the repository still lacks the following:
- Datasets (most important)
- Pre-trained tokenizers
- Training strategy
Still, you can explore the basic architecture of Gato based on the paper.
Appendix C.1. Transformer Hyperparameters
In the paper, Deepmind tested Gato with three architecture variants: 1.18B, 364M, and 79M parameters.
I have named them `large()`, `baseline()`, and `small()` respectively in `GatoConfig`.
Hyperparameters | Large(1.18B) | Baseline(364M) | Small(79M) |
---|---|---|---|
Transformer blocks | 24 | 12 | 8 |
Attention heads | 16 | 12 | 24 |
Layer width | 2048 | 1536 | 768 |
Feedforward hidden size | 8192 | 6144 | 3072 |
Key/value size | 128 | 128 | 32 |
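As a rough, illustrative summary of the table above, the three variants can be written down as plain Python data. The field names here are my own and not necessarily the attributes exposed by `GatoConfig`.

```python
from dataclasses import dataclass


@dataclass
class VariantHyperparameters:
    # Field names are illustrative; the actual GatoConfig attributes may differ.
    transformer_blocks: int
    attention_heads: int
    layer_width: int
    feedforward_hidden_size: int
    key_value_size: int


# Values taken from the table above (Appendix C.1.).
LARGE = VariantHyperparameters(24, 16, 2048, 8192, 128)     # 1.18B parameters
BASELINE = VariantHyperparameters(12, 12, 1536, 6144, 128)  # 364M parameters
SMALL = VariantHyperparameters(8, 24, 768, 3072, 32)        # 79M parameters
```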
Appendix C.2. Embedding Function
There is no mention of how many residual blocks should be stacked for token embeddings.
Therefore, I have left it configurable in `GatoConfig`.
However many residual layers there are, full pre-activation is the key.
The blocks consist of:
- Version 2 ResNet architecture (based on ResNet50V2)
- GroupNorm (instead of LayerNorm)
- GeLU (instead of ReLU)
Since GroupNorm is not supported in TensorFlow itself, you need to install `tensorflow-addons`.
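Below is a minimal sketch of that block structure, assuming a bottleneck layout (1x1 → 3x3 → 1x1 convolutions) with a projected shortcut; the actual `ResidualEmbedding` layer may differ in kernel sizes, group counts, and shortcut handling.

```python
import tensorflow as tf
import tensorflow_addons as tfa


def residual_block_v2(inputs, filters, groups=32):
    """Full pre-activation (v2) bottleneck block with GroupNorm and GeLU.

    Illustrative only; `groups` must evenly divide the channel dimension.
    """
    def norm_act_conv(x, out_filters, kernel_size):
        # Pre-activation: normalize and activate before the convolution.
        x = tfa.layers.GroupNormalization(groups=groups)(x)
        x = tf.nn.gelu(x)
        return tf.keras.layers.Conv2D(out_filters, kernel_size, padding='same')(x)

    x = norm_act_conv(inputs, filters, 1)
    x = norm_act_conv(x, filters, 3)
    x = norm_act_conv(x, filters * 4, 1)

    # Project the shortcut when the channel count changes (an assumption,
    # mirroring the usual ResNet50V2 handling).
    shortcut = inputs
    if inputs.shape[-1] != filters * 4:
        shortcut = tf.keras.layers.Conv2D(filters * 4, 1, padding='same')(inputs)
    return x + shortcut
```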
Appendix C.3. Position Encodings
Like Vision Transformer (ViT) by Google, Gato takes the input images as raster-ordered 16x16 patches.
Unlike the Vision Transformer model, however, Gato divides its patch position encoding strategy into two phases: training and evaluation.
For high-performance computation in TensorFlow, I have used closed-form expressions for each patch's quantized interval boundaries, `from` and `to`, respectively.
Following Appendix B of the paper, a position index is uniformly sampled from this interval during training, while the (rounded) mean of the interval is used during evaluation.
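A minimal sketch of that train/eval behavior follows; the position vocabulary size (128) and the exact quantization are assumptions for illustration, not necessarily what `PatchPositionEncoding` does.

```python
import tensorflow as tf


def patch_position_indices(num_patches, vocab_size=128, training=True):
    """One position index per patch along a single axis (row or column).

    Patch i covers the normalized interval [i / num_patches, (i + 1) / num_patches),
    quantized into `vocab_size` discrete positions ("from" and "to").
    Training samples an index uniformly from the interval; evaluation uses its mean.
    """
    patch = tf.range(num_patches, dtype=tf.float32)
    start = tf.cast(tf.floor(patch / num_patches * vocab_size), tf.int32)          # "from"
    stop = tf.cast(tf.floor((patch + 1.0) / num_patches * vocab_size), tf.int32)   # "to"
    if training:
        # Uniformly sample an index from [from, to) for each patch.
        offset = tf.random.uniform([num_patches]) * tf.cast(stop - start, tf.float32)
        return start + tf.cast(offset, tf.int32)
    # Evaluation: deterministically use the (rounded) mean of the interval.
    return (start + stop) // 2


# The resulting indices would then look up a learned embedding table, e.g.:
# tf.keras.layers.Embedding(128, layer_width)(patch_position_indices(14))
```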
In the definition of Appendix B, text tokens, image patch tokens, and discrete & continuous values are all observation tokens.
When Gato receives those values, they must be encoded with their own (local) time steps.
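A rough sketch of that idea, assuming a learned embedding table indexed by each token's position within its observation; the maximum length, width, and call signature are assumptions, not the actual `LocalPositionEncoding` API.

```python
import tensorflow as tf


class LocalPositionSketch(tf.keras.layers.Layer):
    """Adds a learned embedding for each token's local position within an observation.

    Illustrative only; the real LocalPositionEncoding layer may differ in how it
    resets positions between observations and in its maximum sequence length.
    """

    def __init__(self, max_observation_length=512, layer_width=768, **kwargs):
        super().__init__(**kwargs)
        self.embedding = tf.keras.layers.Embedding(max_observation_length, layer_width)

    def call(self, token_embeddings):
        # token_embeddings: (batch, observation_length, layer_width)
        length = tf.shape(token_embeddings)[1]
        positions = tf.range(length)  # local time steps 0..L-1
        return token_embeddings + self.embedding(positions)
```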
To run the code in this repository, install the required packages: `pip install tensorflow tensorflow-addons`.
This repository is still a work in progress.
Currently, no downloads or executables are provided.
I welcome any contributors who can help.
Licensed under the MIT license.