
Commit

fixed demo and documentation
MLRichter committed Jan 24, 2021
1 parent 5ec350a commit f936226
Showing 4 changed files with 140 additions and 153 deletions.
275 changes: 127 additions & 148 deletions README.md
@@ -2,180 +2,159 @@

[![PyPI version](https://badge.fury.io/py/delve.svg)](https://badge.fury.io/py/delve) [![Build Status](https://travis-ci.org/delve-team/delve.svg?branch=master)](https://travis-ci.org/delve-team/delve) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)

Delve is a Python package for visualizing deep learning model training.
Delve is a Python package for analyzing the inference dynamics of your model.

![playground](https://github.com/justinshenk/playground/blob/master/saturation_demo.gif)

Use Delve if you need a lightweight PyTorch or Keras extension that:
- Plots live statistics of network layer inputs to TensorBoard or terminal
- Performs spectral analysis to identify layer saturation for network pruning
- Is easily extendible and configurable
Use Delve if you need a lightweight PyTorch extension that:
- Gives you insight into the inference dynamics of your architecture
- Allows you to optimize and adjust neural network models for your dataset
  without much trial and error
- Allows you to analyze the eigenspaces of your data at different stages of inference
- Provides basic tooling for experiment logging

------------------

## Motivation

Designing a deep neural network involves optimizing over a wide range of parameters and hyperparameters. Delve allows you to visualize your layer saturation during training so you can grow and shrink layers as needed.
Designing a deep neural network is a trial-and-error-heavy process that mostly revolves around comparing
performance metrics of different runs.
One of the key issues with this development process is that the resulting metrics do not easily translate
back into concrete design improvements.
Delve provides you with spectral analysis tools that allow you to investigate the inference
dynamics evolving in the model while training.
This allows you to spot underutilized and unused layers, as well as mismatches between
object size and neural architecture, among other inefficiencies.
These observations can be translated directly into design changes in the architecture even before
the model has fully converged, allowing for a quicker and more guided design process.

## Demo

![live layer saturation demo](images/layer-saturation-convnet.gif)

![example_fc.gif](images/example_fc.gif)

## Getting Started
## Installation

```bash
pip install delve
```

### Layer Saturation

#### PyTorch

`delve.CheckLayerSat` can be configured as follows:

```
savefile (str) : destination for summaries; depending on the saving strategy, this may be a directory or a file
save_to (str) : specifies the saving strategy
    supported writers are:
        console : print the stats to the console every time save() is called
        tensorboard : logs everything in TensorBoard format; in this case savefile must be a directory
        csv : creates a csv file with each column corresponding to a logged variable; every time save() is called
              a new line is appended to the file
layerwise_sat (bool) : toggles whether layerwise saturation should be saved by the writer
average_sat (bool) : toggles whether average saturation should be saved by the writer
ignore_layer_names (list) : a list of layer names, as specified in the modules; the layers specified will be excluded
    from the computation. Useful for excluding layers which are forced into a specific saturation, like softmaxes
    or other output layers.
include_conv (bool) : toggles whether convolutional layers should be included
conv_method (str) : the method used to pool the latent space of convolutional layers; default is 'median', valid inputs
    are 'median', 'mean' and 'max'
sat_threshold (float) : the saturation threshold for computing the dimensionality of the latent representations;
    default is 0.99. This value may be any floating point number between 0 and 1.
modules (torch modules or list of modules) : layer-containing object (may contain submodules)
log_interval (int) : steps between writing summaries
stats (list of str) : list of stats to collect
    supported stats are:
        lsat : layer saturation
conv_method : method for calculating saturation; use `cumvar99` or `all`.
    See https://github.com/justinshenk/playground for a comparison of how they work.
include_conv : bool, setting to False includes only linear layers
verbose (bool) : print saturation for every layer during training
```
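For example, a configuration combining several of these documented options might look like the following sketch (`model` and the ignored layer name `"output"` are placeholders for your own network):

```python
from delve import CheckLayerSat

tracker = CheckLayerSat(
    "experiment_log",               # savefile: file/directory for the csv writer
    save_to="csv",                  # append one csv line per save() call
    modules=model,                  # the network (or list of layers) to track
    conv_method="mean",             # pool conv feature maps with their mean
    sat_threshold=0.99,             # variance ratio that defines saturation
    ignore_layer_names=["output"],  # skip layers forced into a fixed saturation
)
```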
Pass either a PyTorch model or `torch.nn.Linear` layers to `CheckLayerSat`:

```python
from delve import CheckLayerSat

model = TwoLayerNet()  # PyTorch network
stats = CheckLayerSat('runs', model)  # logging directory and input

...  # setup data loader

for i, data in enumerate(train_loader):
    stats.saturation()  # output saturation
```

Only fully-connected and convolutional layers are currently supported.

To log the saturation to console, call `stats.saturation()`. For example:

```bash
Regression - SixLayerNet - Hidden layer size 10 │
loss=0.231825: 68%|████████████████████▎ | 1350/2000 [00:04<00:02, 289.30it/s]│
linear1: 90%|█████████████████████████████████▎ | 90.0/100 [00:00<00:00, 453.47it/s]│
linear2: 18%|██████▊ | 18.0/100 [00:00<00:00, 90.68it/s]│
linear3: 32%|███████████▊ | 32.0/100 [00:00<00:00, 161.22it/s]│
linear4: 32%|███████████▊ | 32.0/100 [00:00<00:00, 161.24it/s]│
linear5: 28%|██████████▎ | 28.0/100 [00:00<00:00, 141.11it/s]│
linear6: 90%|██████████████████████████████████▏ | 90.0/100 [00:01<00:00, 56.04it/s]
```

#### Keras

Two classes are provided in `delve.kerascallback`: `CustomTensorBoard` and `SaturationLogger`.

`CustomTensorBoard` takes the following parameters:

| Argument | Description |
| --- | --- |
| `log_dir` | location for writing summaries |
| `user_defined_freq` | frequency for writing summaries |
| `kwargs` | passed to `tf.keras.callbacks.TensorBoard` |

`SaturationLogger` takes the following parameters:

| Argument | Description |
| --- | --- |
| `model` | Keras model |
| `input_data` | data for passing through the model |
| `print_freq` | frequency for printing |

Example usage:

``` python
from delve.kerascallback import CustomTensorBoard, SaturationLogger

...

# Tensorboard logging
tbCallBack = CustomTensorBoard(log_dir='./runs', user_defined_freq=1)

# Console logging
saturation_logger = SaturationLogger(model, input_data=input_x_train[:2], print_freq=1)

...

# Add callback to Keras `fit` method
model.fit(x_train, y_train,
          epochs=100,
          batch_size=128,
          callbacks=[saturation_logger])  # can also pass tbCallBack
```

Output:

```bash
Epoch 29/100
128/1000 [==>...........................] - ETA: 0s - loss: 2.2783 - acc: 0.1406
dense_1 : %0.83 | dense_2 : %0.79 | dense_3 : %0.67 |
```

#### Optimize neural network topology

Ever wonder how big your fully-connected layers should be? Delve helps you visualize the effect of modifying the layer size on your layer saturation.

For example, see how modifying the hidden layer size of this network affects the saturation of the second layer but not the first. Multiple runs show that the saturation of the fully-connected "linear2" layer (light blue: 256-wide; orange: 8-wide) is sensitive to layer size:

![saturation](images/layer1-saturation.png)

![saturation](images/layer2-saturation.png)

This comparison suggests that the 8-unit layer (light blue) is too saturated and that a larger layer is needed.

### Log spectral analysis

Writes the top 5 eigenvalues of each layer to TensorBoard summaries:

```python
# PyTorch-only
stats = CheckLayerSat('runs', layers, 'spectrum')
```

![spectrum](images/spectrum.png)

### Intrinsic dimensionality

View the intrinsic dimensionality of models in realtime:

![intrinsic_dimensionality-layer2](images/layer2-intrinsic.png)

### Using Layer Saturation to improve model performance
The saturation metric is the core feature of Delve. By default, saturation is a value between 0 and 1.0, computed
for any convolutional, LSTM or dense layer in the network.
Saturation describes the percentage of eigendirections required for explaining 99% of the variance.
Simply speaking, it tells you how much your data is "filling up" the individual layers inside
your model. A minimal sketch of this computation follows below.
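To make the definition concrete, here is a minimal sketch of the computation on the flattened activations of a single layer, using NumPy (an illustration of the definition above, not Delve's internal implementation):

```python
import numpy as np

def saturation(activations: np.ndarray, threshold: float = 0.99) -> float:
    """Fraction of eigendirections needed to explain `threshold` of the variance.

    `activations` holds one flattened feature vector per sample, shape (n, d).
    """
    cov = np.cov(activations, rowvar=False)              # (d, d) feature covariance
    eigvals = np.linalg.eigvalsh(cov)[::-1]              # eigenvalues, descending
    explained = np.cumsum(eigvals) / eigvals.sum()       # cumulative variance ratio
    k = int(np.searchsorted(explained, threshold)) + 1   # directions to reach threshold
    return k / activations.shape[1]
```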

![saturation](images/layer1-saturation.png)
In the image below you can see how saturation portraits inefficiencies in your neural network.
The depicted model is ResNet18 trained on 32 pixel images, which is way to small for
a model with a receptive field exceeding 400 pixels in the final layers.

![saturation](images/layer2-saturation.png)
![demo1](images/resnet.png)

### Log spectral analysis
To visualize what this poorly chosen input resolution does to the inference, we trained logistic regressions on the output of
every layer to solve the same task as the model.
You can clearly see that only the first half of the model (at best) is improving
the intermedia solutions of our logistic regression "probes".
The layers following this are contributing nothing to the quality of the prediction!
You also see that saturation is extremly low for this layers!
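For reference, a minimal sketch of such a probe, under our reading of the setup (not necessarily the exact code used for the figures): a logistic regression fitted on the frozen activations of one layer.

```python
from sklearn.linear_model import LogisticRegression

def probe_accuracy(train_acts, train_labels, test_acts, test_labels):
    """Fit a logistic regression on one layer's activations and score it."""
    clf = LogisticRegression(max_iter=1000)
    clf.fit(train_acts.reshape(len(train_acts), -1), train_labels)
    return clf.score(test_acts.reshape(len(test_acts), -1), test_labels)
```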

We call this a *tail*, and it can be removed by either increasing the input resolution or
(which is more economical) reducing the receptive field size to match the object size of your
dataset.

![demo2](images/resnetBetter.png)

We did this by removing the first two downsampling layers, which quarters the growth
of the receptive field of the network. This not only reduces the number of
parameters but also makes better use of the available parameters, by making more layers
contribute effectively!
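As an illustration of such a change on a torchvision ResNet18 (an assumption for the sake of the example, not necessarily the exact modification behind the figure), the first two downsampling stages can be removed by replacing the stride-2 stem convolution and the stride-2 max-pool:

```python
from torch import nn
from torchvision.models import resnet18

model = resnet18(num_classes=10)
# replace the stride-2 7x7 stem with a stride-1 3x3 convolution
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
# drop the stride-2 max-pool, the second downsampling stage
model.maxpool = nn.Identity()
```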

__For more details, check our publications on this topic__
- [Spectral Analysis of Latent Representations](https://arxiv.org/abs/1907.08589)
- [Feature Space Saturation during Training](https://arxiv.org/abs/2006.08679)
- Size Matters (soon)

## Demo

```python
import torch
from delve import CheckLayerSat
from torch.cuda import is_available
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models.vgg import vgg16
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

if __name__ == "__main__":

    # setup compute device
    device = "cuda:0" if is_available() else "cpu"

    # Get some data
    train_data = CIFAR10(root="./tmp", train=True,
                         download=True, transform=Compose([ToTensor()]))
    test_data = CIFAR10(root="./tmp", train=False,
                        download=True, transform=Compose([ToTensor()]))

    train_loader = DataLoader(train_data, batch_size=1024,
                              shuffle=True, num_workers=6,
                              pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=1024,
                             shuffle=False, num_workers=6,
                             pin_memory=True)

    # instantiate model
    model = vgg16(num_classes=10).to(device)

    # instantiate optimizer and loss
    optimizer = Adam(params=model.parameters())
    criterion = CrossEntropyLoss().to(device)

    # initialize delve
    tracker = CheckLayerSat("my_experiment", save_to="plotcsv",
                            modules=model, device=device)

    # begin training
    for epoch in range(10):
        model.train()
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        total = 0
        test_loss = 0
        correct = 0
        model.eval()
        with torch.no_grad():
            for images, labels in tqdm(test_loader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)
                correct += torch.sum(predicted == labels).item()
                test_loss += loss.item()

        # add some additional metrics we want to keep track of
        tracker.add_scalar("accuracy", correct / total)
        tracker.add_scalar("loss", test_loss / total)

        # add saturation to the mix
        tracker.add_saturations()

    # close the tracker to finish training
    tracker.close()
```

### Why this name, Delve?

2 changes: 1 addition & 1 deletion delve/version.py
@@ -1 +1 @@
__version__ = '0.1.42'
__version__ = '0.1.43'
7 changes: 5 additions & 2 deletions delve/writers.py
@@ -461,5 +461,8 @@ def plot_stat_level_from_results(savepath, epoch, stat, primary_metric=None, fon


def plot_scatter_from_results(savepath, epoch, stat, df):
    ax = plot_stat(df=df, savepath=savepath, epoch=epoch, stat=stat, line=False, save=True, samples=True, ylim=None)
    return ax
    if len(df) > 0:
        ax = plot_stat(df=df, savepath=savepath, epoch=epoch, stat=stat, line=False, save=True, samples=True, ylim=None)
        return ax
    else:
        return None
9 changes: 7 additions & 2 deletions example.py
@@ -1,4 +1,7 @@
#!/usr/bin/env python
from os import mkdir
from os.path import exists

import torch
from tqdm import trange

@@ -17,6 +20,8 @@ def forward(self, x):
        return y_pred


if not exists("regression/"):
    mkdir("regression/")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)

@@ -34,7 +39,7 @@ def forward(self, x):
x, y, model = x.to(device), y.to(device), model.to(device)

layers = [model.linear1, model.linear2]
stats = CheckLayerSat('regression/h{}'.format(h), layers, device=device)
stats = CheckLayerSat('regression/h{}'.format(h), save_to="csv", modules=layers, device=device)

loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
@@ -49,7 +54,7 @@ def forward(self, x):
loss.backward()
optimizer.step()

stats.saturation()
stats.add_saturations()
steps_iter.write('\n')
stats.close()
steps_iter.close()
