An implementation of batch-normalized LSTM (BN-LSTM) in PyTorch.

Tim Cooijmans et al., Recurrent Batch Normalization (arXiv:1603.09025).

Forked from sysuNie.

Modified to be compatible with PyTorch 1.0.0.
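
For context, Recurrent Batch Normalization applies batch norm separately to the input-to-hidden and hidden-to-hidden transformations inside the LSTM, and again to the cell state before the output gate. The snippet below is a minimal, self-contained sketch of that cell update; the class and parameter names are illustrative, not this repo's BNLSTMCell, and the paper's per-time-step statistics are omitted for brevity.

import torch
import torch.nn as nn

class MinimalBNLSTMCell(nn.Module):
    """Illustrative single-step BN-LSTM following Cooijmans et al. (2016).

    Batch norm is applied separately to the input-to-hidden and
    hidden-to-hidden transforms, and to the cell state before the
    output gate. The paper additionally keeps separate BN statistics
    per time step and initializes the BN scale to 0.1; both are
    omitted here for brevity.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size, bias=False)
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))
        self.bn_ih = nn.BatchNorm1d(4 * hidden_size)
        self.bn_hh = nn.BatchNorm1d(4 * hidden_size)
        self.bn_c = nn.BatchNorm1d(hidden_size)

    def forward(self, x, state):
        h, c = state
        # Normalize the two affine transforms separately, then form the gates.
        gates = self.bn_ih(self.weight_ih(x)) + self.bn_hh(self.weight_hh(h)) + self.bias
        i, f, g, o = gates.chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        # The cell state is also normalized before the output gate.
        h = torch.sigmoid(o) * torch.tanh(self.bn_c(c))
        return h, c

# One step on a batch of 4 random inputs.
cell = MinimalBNLSTMCell(28, 512)
h0, c0 = torch.zeros(4, 512), torch.zeros(4, 512)
h1, c1 = cell(torch.rand(4, 28), (h0, c0))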

To use:

import torch
from batch_normalization_LSTM import BNLSTMCell, LSTM

# Wrap the batch-normalized cell in the repo's LSTM driver.
model = LSTM(cell_class=BNLSTMCell, input_size=28, hidden_size=512,
             batch_first=True, max_length=152)

if __name__ == "__main__":
    size = 28
    # With batch_first=True the input is interpreted as
    # (batch, seq_len, input_size): 300 sequences of 2 steps, 28 features each.
    dummy = torch.rand(300, 2, size)
    out = model(dummy)
    print(model)
    print(out[0])
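
The max_length argument presumably bounds the number of time steps for which separate batch-norm statistics are kept, mirroring the paper's per-time-step statistics (later steps would reuse the last set); check the LSTM class in batch_normalization_LSTM to confirm. Likewise, out[0] is printed above on the assumption that the wrapper follows the usual PyTorch convention of returning the per-step hidden states as the first element of its output.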