SRIP IITGN 2022 Screening Round Solutions
JAX is an automatic differentiation toolbox that aims to bring NumPy-style differentiable programming to GPUs and TPUs. With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. JAX uses XLA to compile and run NumPy code on accelerators such as GPUs and TPUs. Python, as an interpreted language, is slow; to train networks at scale we need fast compilation and parallel computation. Precompiled CUDA kernels already provide a set of primitive operations that can be executed massively in parallel on an NVIDIA GPU, but ideally we want to launch as few kernels as possible to reduce communication time and memory load. XLA optimizes memory bandwidth by “fusing” operations, which reduces the number of intermediate results that have to be materialised and returned.
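As a minimal illustration of this (the function here is purely illustrative and not part of the submitted solutions), a chain of element-wise NumPy-style operations can be handed to `jax.jit`, which lets XLA fuse them into a single compiled computation instead of materialising every intermediate array:

```python
import jax
import jax.numpy as jnp

# A chain of element-wise operations; under jit, XLA can fuse these into one
# kernel instead of producing a separate intermediate array for each step.
def scaled_gelu(x):
    return 0.5 * x * (1.0 + jnp.tanh(0.79788456 * (x + 0.044715 * x**3)))

fast_scaled_gelu = jax.jit(scaled_gelu)   # compiled once, reused on every call

x = jnp.linspace(-3.0, 3.0, 1_000_000)
y = fast_scaled_gelu(x)                   # runs as one fused XLA computation on CPU/GPU/TPU
```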
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. The three main ones are:
- jit(), for speeding up your code
- grad(), for taking derivatives
- vmap(), for automatic vectorization or batching.
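A small sketch of how the three transformations compose (the linear model and loss here are illustrative assumptions, not the actual solution code):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # squared-error loss for a simple linear model
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_loss = jax.grad(loss)                               # gradient with respect to w
batched_predict = jax.vmap(jnp.dot, in_axes=(None, 0))   # map over rows of x, reuse the same w

w = jnp.array([1.0, -2.0])
x = jnp.array([[0.5, 1.0], [2.0, 3.0]])
y = jnp.array([0.0, 1.0])

print(grad_loss(w, x, y))       # dloss/dw, same shape as w
print(batched_predict(w, x))    # per-example predictions
print(jax.jit(loss)(w, x, y))   # compiled version of the same loss
```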
Although I tried to implement the solutions efficiently, there are still places where the implementation could be improved. Here is a short list:
- Use of pytrees for the model parameter update (see the sketch after this list)
- Use of more evaluation metrics for MNIST classification
- The accuracy of the neural network could be increased with common regularisation techniques such as batch normalisation.
- The variation of covariance could be made more interactive.
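A minimal sketch of the pytree-based update mentioned in the first bullet; the parameter shapes and the `sgd_update` name are assumptions made for illustration:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters of a two-layer MNIST network, stored as a pytree (nested dict).
params = {
    "layer1": {"w": jnp.ones((784, 128)), "b": jnp.zeros(128)},
    "layer2": {"w": jnp.ones((128, 10)),  "b": jnp.zeros(10)},
}

def sgd_update(params, grads, lr=0.01):
    # tree_map applies the update rule to every leaf, so no per-layer loop is needed.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# In training, grads would come from jax.grad(loss_fn)(params, batch); here the
# parameters themselves stand in for the gradients just to show the call shape.
new_params = sgd_update(params, params)
```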