Skip to content

Commit

Permalink
Avoid ValidationError in NStepGRU link converter test case by a hack
Browse files Browse the repository at this point in the history
At the moment, this hack in necessary for avoiding error like:
ValidationError: Nodes in a graph must be topologically sorted, however input 'v330' of node:
input: "Permutate_0_const_empty" input: "v330" input: "Permutate_0_const_range" output: "Permutate_0_tmp_0" name: "Permutate_0_tmp_0" op_type: "Scatter"
is not output of any previous nodes.
  • Loading branch information
msakai committed Dec 18, 2019
1 parent bd6d32c commit 7f77a76
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion tests/onnx_chainer_tests/functions_tests/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import chainer.functions as F
import chainer.links as L
from chainer import testing
import numpy as np

from onnx_chainer import onnx_helper
from onnx_chainer.testing import input_generator
from onnx_chainer_tests.helper import ONNXModelTest

Expand Down Expand Up @@ -47,6 +49,25 @@ def __call__(self, hx, ws1, ws2, ws3, bs, xs):
self.expect(model, (hx, ws1, ws2, ws3, bs, xs))


def convert_Permutate(params):
gb = onnx_helper.GraphBuilder()
# indices_name = params.context.get_name(func.indices)
indices_name = params.context.add_const(params.func.indices,
'indices') # XXX
if params.func.inv:
empty = params.context.add_const(
np.zeros(dtype=np.int64, shape=params.func.indices.shape), 'empty')
r = params.context.add_const(
np.arange(len(params.func.indices), dtype=np.int64),
'range')
op = 'ScatterElements' if params.opset_version == 11 else 'Scatter'
indices_name = gb.op(op, [empty, indices_name, r])
params.input_names.append(indices_name)
gb.op_output_named('Gather', params.input_names, params.output_names,
axis=params.func.axis)
return gb.nodes()


@testing.parameterize(
{'n_layers': 1, 'name': 'TestNStepGRU_1_layer'},
{'n_layers': 2, 'name': 'TestNStepGRU_2_layer'},
Expand Down Expand Up @@ -74,4 +95,16 @@ def __call__(self, *xs):
model = Model()
xs = [input_generator.increasing(seq_length, input_size)
for i in range(batch_size)]
self.expect(model, xs, skip_opset_version=[7, 8])

# XXX: Replace Permutate converter for avoiding error like:
# ValidationError: Nodes in a graph must be topologically sorted, \
# however input 'v330' of node:
# input: "Permutate_0_const_empty" input: "v330" \
# input: "Permutate_0_const_range" output: "Permutate_0_tmp_0" \
# name: "Permutate_0_tmp_0" op_type: "Scatter"
# is not output of any previous nodes.
addon_converters = {
'Permutate': convert_Permutate,
}
self.expect(model, xs, skip_opset_version=[7, 8],
external_converters=addon_converters)

0 comments on commit 7f77a76

Please sign in to comment.