Skip to content

Commit

Permalink
tests and fixes for stackedrnn
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 22, 2024
1 parent 0ee418b commit 36c585e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/wrappers/stackedrnn.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/
struct StackedRNN{L,D,S}
layers::L
droput::D
dropout::D
states::S
end

Expand Down Expand Up @@ -39,7 +39,7 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...;
@warn("Dropout is not applied when num_layers is 1.")
end

for (idx,layer) in enumerate(num_layers)
for idx in 1:num_layers
in_size = idx == 1 ? input_size : hidden_size
push!(layers, rlayer(in_size => hidden_size, args...; kwargs...))
end
Expand All @@ -52,7 +52,7 @@ end

function (stackedrnn::StackedRNN)(inp::AbstractArray)
for (idx,(layer, state)) in enumerate(zip(stackedrnn.layers, stackedrnn.states))
inp = layer(inp, state0)
inp = layer(inp, state)
if !(idx == length(stackedrnn.layers))
inp = stackedrnn.dropout(inp)
end
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ end

@safetestset "Layers" begin
include("test_layers.jl")
end

@safetestset "Wrappers" begin
include("test_wrappers.jl")
end
20 changes: 20 additions & 0 deletions test/test_wrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using RecurrentLayers
using Flux
using Test

layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN]

@testset "Sizes for StackedRNN with layer: $layer" for layer in layers
wrap = StackedRNN(layer, 2 => 4)

inp = rand(Float32, 2, 3, 1)
output = wrap(inp)
@test output isa Array{Float32, 3}
@test size(output) == (4, 3, 1)

inp = rand(Float32, 2, 3)
output = wrap(inp)
@test output isa Array{Float32, 2}
@test size(output) == (4, 3)
end

0 comments on commit 36c585e

Please sign in to comment.