Skip to content

Commit

Permalink
Add test cases for ONNX converter of NStepGRU link
Browse files Browse the repository at this point in the history
  • Loading branch information
msakai committed Dec 18, 2019
1 parent ed659e1 commit bd6d32c
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/onnx_chainer_tests/functions_tests/test_rnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import testing

from onnx_chainer.testing import input_generator
Expand Down Expand Up @@ -44,3 +45,33 @@ def __call__(self, hx, ws1, ws2, ws3, bs, xs):
xs = input_generator.increasing(seq_length, batch_size, input_size)

self.expect(model, (hx, ws1, ws2, ws3, bs, xs))


@testing.parameterize(
{'n_layers': 1, 'name': 'TestNStepGRU_1_layer'},
{'n_layers': 2, 'name': 'TestNStepGRU_2_layer'},
)
class TestNStepGRULink(ONNXModelTest):
def test_output(self):
n_layers = self.n_layers
dropout_ratio = 0.0
batch_size = 3
input_size = 4
hidden_size = 5
seq_length = 6

class Model(chainer.Chain):
def __init__(self):
super().__init__()
with self.init_scope():
self.gru = L.NStepGRU(
n_layers, input_size, hidden_size, dropout_ratio)

def __call__(self, *xs):
hy, ys = self.gru(None, xs)
return [hy] + ys

model = Model()
xs = [input_generator.increasing(seq_length, input_size)
for i in range(batch_size)]
self.expect(model, xs, skip_opset_version=[7, 8])

0 comments on commit bd6d32c

Please sign in to comment.