diff --git a/Project.toml b/Project.toml index 4874707..fa5a441 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,13 @@ name = "RecurrentLayers" uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c" authors = ["Francesco Martinuzzi"] -version = "0.1.5" +version = "0.2.0" [deps] Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" [compat] -Flux = "0.14, 0.15" +Flux = "0.16" julia = "1.10" [extras] diff --git a/docs/pages.jl b/docs/pages.jl index 1261f25..de181af 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -2,7 +2,7 @@ pages=[ "Home" => "index.md", "API Documentation" => [ "Cells" => "api/cells.md", - "Cell Wrappers" => "api/wrappers.md", + "Layers" => "api/layers.md", ], "Roadmap" => "roadmap.md" ] \ No newline at end of file diff --git a/docs/src/api/wrappers.md b/docs/src/api/layers.md similarity index 100% rename from docs/src/api/wrappers.md rename to docs/src/api/layers.md diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index d0eb969..4172e48 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -3,6 +3,7 @@ module RecurrentLayers using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform +#TODO add interlinks to initialstates in docstrings https://juliadocs.org/DocumenterInterLinks.jl/stable/ import Flux: initialstates, scan abstract type AbstractRecurrentCell end diff --git a/src/fastrnn_cell.jl b/src/fastrnn_cell.jl index d414191..67e3ca2 100644 --- a/src/fastrnn_cell.jl +++ b/src/fastrnn_cell.jl @@ -37,7 +37,19 @@ h_t &= \alpha \tilde{h}_t + \beta h_{t-1} # Forward - fastrnncell(inp, [state]) + fastrnncell(inp, state) + fastrnncell(inp) + +## Arguments +- `inp`: The input to the fastrnncell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the FastRNN. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; init_kernel = glorot_uniform, @@ -65,7 +77,7 @@ function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state) candidate_state = fastrnn.activation.(Wi * inp .+ Wh * state .+ b) new_state = alpha .* candidate_state .+ beta .* state - return new_state + return new_state, new_state end Base.show(io::IO, fastrnn::FastRNNCell) = @@ -102,7 +114,18 @@ h_t &= \alpha \tilde{h}_t + \beta h_{t-1} # Forward - fastrnn(inp, [state]) + fastrnn(inp, state) + fastrnn(inp) + +## Arguments +- `inp`: The input to the fastrnn. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the FastRNN. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; kwargs...) @@ -155,7 +178,20 @@ h_t &= \big((\zeta (1 - z_t) + \nu) \odot \tilde{h}_t\big) + z_t \odot h_{t-1} # Forward - fastgrnncell(inp, [state]) + fastgrnncell(inp, state) + fastgrnncell(inp) + +## Arguments + +- `inp`: The input to the fastgrnncell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the FastGRNN. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. 
""" function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; init_kernel = glorot_uniform, @@ -187,7 +223,7 @@ function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) candidate_state = tanh_fast.(partial_gate .+ bh) new_state = (zeta .* (ones(size(gate)) .- gate) .+ nu) .* candidate_state .+ gate .* state - return new_state + return new_state, new_state end Base.show(io::IO, fastgrnn::FastGRNNCell) = @@ -225,7 +261,19 @@ h_t &= \big((\zeta (1 - z_t) + \nu) \odot \tilde{h}_t\big) + z_t \odot h_{t-1} # Forward - fastgrnn(inp, [state]) + fastgrnn(inp, state) + fastgrnn(inp) + +## Arguments + +- `inp`: The input to the fastgrnn. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the FastGRNN. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; kwargs...) diff --git a/src/indrnn_cell.jl b/src/indrnn_cell.jl index c9a7228..94e5b06 100644 --- a/src/indrnn_cell.jl +++ b/src/indrnn_cell.jl @@ -33,8 +33,19 @@ See [`IndRNN`](@ref) for a layer that processes entire sequences. # Forward - rnncell(inp, [state]) - + indrnncell(inp, state) + indrnncell(inp) + +## Arguments +- `inp`: The input to the indrnncell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the IndRNNCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. 
""" function IndRNNCell((input_size, hidden_size)::Pair, σ=relu; init_kernel = glorot_uniform, @@ -50,7 +61,7 @@ function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) σ = NNlib.fast_act(indrnn.σ, inp) state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b) - return state + return state, state end function Base.show(io::IO, indrnn::IndRNNCell) @@ -84,6 +95,20 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence. ```math \mathbf{h}_{t} = \sigma(\mathbf{W} \mathbf{x}_t + \mathbf{u} \odot \mathbf{h}_{t-1} + \mathbf{b}) ``` +# Forward + + indrnn(inp, state) + indrnn(inp) + +## Arguments +- `inp`: The input to the indrnn. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the IndRNN. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...) cell = IndRNNCell(input_size, hidden_size, σ; kwargs...) diff --git a/src/lightru_cell.jl b/src/lightru_cell.jl index 87362d4..37de018 100644 --- a/src/lightru_cell.jl +++ b/src/lightru_cell.jl @@ -34,7 +34,19 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t. # Forward - rnncell(inp, [state]) + lightrucell(inp, state) + lightrucell(inp) + +## Arguments +- `inp`: The input to the lightrucell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the LightRUCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function LightRUCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -58,7 +70,7 @@ function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) candidate_state = @. tanh_fast(gxs[1]) forget_gate = sigmoid_fast(gxs[2] .+ Wh * state .+ b) new_state = @. (1 - forget_gate) * state + forget_gate * candidate_state - return new_state + return new_state, new_state end Base.show(io::IO, lightru::LightRUCell) = @@ -93,6 +105,21 @@ f_t &= \delta(W_f x_t + U_f h_{t-1} + b_f), \\ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t. \end{aligned} ``` + +# Forward + + lightru(inp, state) + lightru(inp) + +## Arguments +- `inp`: The input to the lightru. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the LightRU. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function LightRU((input_size, hidden_size)::Pair; kwargs...) cell = LightRUCell(input_size => hidden_size; kwargs...) diff --git a/src/ligru_cell.jl b/src/ligru_cell.jl index d622c14..259a901 100644 --- a/src/ligru_cell.jl +++ b/src/ligru_cell.jl @@ -36,7 +36,19 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t # Forward - rnncell(inp, [state]) + ligrucell(inp, state) + ligrucell(inp) + +## Arguments +- `inp`: The input to the ligrucell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the LiGRUCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function LiGRUCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -60,7 +72,7 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) forget_gate = @. sigmoid_fast(gxs[1] + ghs[1]) candidate_hidden = @. tanh_fast(gxs[2] + ghs[2]) new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_hidden - return new_state + return new_state, new_state end @@ -93,6 +105,21 @@ z_t &= \sigma(W_z x_t + U_z h_{t-1}), \\ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t \end{aligned} ``` + +# Forward + + ligru(inp, state) + ligru(inp) + +## Arguments +- `inp`: The input to the ligru. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the LiGRU. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function LiGRU((input_size, hidden_size)::Pair; kwargs...) cell = LiGRUCell(input_size => hidden_size; kwargs...) diff --git a/src/mgu_cell.jl b/src/mgu_cell.jl index b6ddd00..1a55e87 100644 --- a/src/mgu_cell.jl +++ b/src/mgu_cell.jl @@ -34,7 +34,19 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t # Forward - rnncell(inp, [state]) + mgucell(inp, state) + mgucell(inp) + +## Arguments +- `inp`: The input to the mgucell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the MGUCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, + a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function MGUCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -58,7 +70,7 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state) forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state) candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state)) new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state - return new_state + return new_state, new_state end Base.show(io::IO, mgu::MGUCell) = @@ -92,6 +104,21 @@ f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f), \\ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t \end{aligned} ``` + +# Forward + + mgu(inp, state) + mgu(inp) + +## Arguments +- `inp`: The input to the mgu. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the MGU. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function MGU((input_size, hidden_size)::Pair; kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) diff --git a/src/mut_cell.jl b/src/mut_cell.jl index dec5a8b..2820474 100644 --- a/src/mut_cell.jl +++ b/src/mut_cell.jl @@ -35,7 +35,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + \tanh(W_h x_t) + b_h) \odot z \\ # Forward - rnncell(inp, [state]) + mutcell(inp, state) + mutcell(inp) + +## Arguments +- `inp`: The input to the mutcell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the MUTCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, +a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function MUT1Cell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -62,7 +74,7 @@ function (mut::MUT1Cell)(inp::AbstractVecOrMat, state) ghs[2] * (reset_gate .* state) .+ tanh_fast(gxs[3]) ) #in the paper is tanh(x_t) but dimensionally it cannot work new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate) - return new_state + return new_state, new_state end Base.show(io::IO, mut::MUT1Cell) = @@ -96,6 +108,21 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + \tanh(W_h x_t) + b_h) \odot z \\ &\quad + h_t \odot (1 - z). \end{aligned} ``` + +# Forward + + mut(inp, state) + mut(inp) + +## Arguments +- `inp`: The input to the mut. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the MUT. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function MUT1((input_size, hidden_size)::Pair; kwargs...) cell = MUT1Cell(input_size => hidden_size; kwargs...) @@ -144,7 +171,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ # Forward - rnncell(inp, [state]) + mutcell(inp, state) + mutcell(inp) + +## Arguments +- `inp`: The input to the mutcell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the MUTCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, +a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function MUT2Cell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -170,7 +209,7 @@ function (mut::MUT2Cell)(inp::AbstractVecOrMat, state) reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2]*state) candidate_state = tanh_fast.(ghs[3] * (reset_gate .* state) .+ gxs[3]) new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate) - return new_state + return new_state, new_state end Base.show(io::IO, mut::MUT2Cell) = @@ -205,6 +244,21 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ &\quad + h_t \odot (1 - z). \end{aligned} ``` + +# Forward + + mut(inp, state) + mut(inp) + +## Arguments +- `inp`: The input to the mut. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the MUT. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function MUT2((input_size, hidden_size)::Pair; kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) @@ -253,7 +307,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ # Forward - rnncell(inp, [state]) + mutcell(inp, state) + mutcell(inp) + +## Arguments +- `inp`: The input to the mutcell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `state`: The hidden state of the MUTCell. It should be a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. 
+ +## Returns +- A tuple `(output, state)`, where both elements are given by the updated state `new_state`, +a tensor of size `hidden_size` or `hidden_size x batch_size`. """ function MUT3Cell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -278,7 +344,7 @@ function (mut::MUT3Cell)(inp::AbstractVecOrMat, state) reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2]*state) candidate_state = tanh_fast.(ghs[3] * (reset_gate .* state) .+ gxs[3]) new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate) - return new_state + return new_state, new_state end Base.show(io::IO, mut::MUT3Cell) = @@ -312,6 +378,21 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ &\quad + h_t \odot (1 - z). \end{aligned} ``` + +# Forward + + mut(inp, state) + mut(inp) + +## Arguments +- `inp`: The input to the mut. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `state`: The hidden state of the MUT. If given, it is a vector of size + `hidden_size` or a matrix of size `hidden_size x batch_size`. + If not provided, it is assumed to be a vector of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function MUT3((input_size, hidden_size)::Pair; kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) diff --git a/src/nas_cell.jl b/src/nas_cell.jl index c8ae7dc..e27f92d 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -78,7 +78,21 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) # Forward - rnncell(inp, [state]) + nascell(inp, (state, cstate)) + nascell(inp) + +## Arguments + +- `inp`: The input to the fastrnncell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the NASCell. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. 
+ If not provided, they are assumed to be vectors of zeros. + +## Returns +- A tuple `(output, state)`, where `output = new_state` is the new hidden state and + `state = (new_state, new_cstate)` is the new hidden and cell state. + They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ function NASCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -123,7 +137,7 @@ function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state)) new_state = tanh_fast(new_cstate .* l3_2) - return new_state, new_cstate + return new_state, (new_state, new_cstate) end Base.show(io::IO, nas::NASCell) = @@ -178,6 +192,21 @@ l_5 &= \tanh(l_3 + l_4) \\ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) \end{aligned} ``` + +# Forward + + nas(inp, (state, cstate)) + nas(inp) + +## Arguments +- `inp`: The input to the nas. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the NAS. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function NAS((input_size, hidden_size)::Pair; kwargs...) cell = NASCell(input_size => hidden_size; kwargs...) diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index 34a2ba4..001180b 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -37,17 +37,21 @@ h_t &= o_t \odot \sigma_h(c_t). # Forward - lstmcell(x, [h, c]) + peepholelstmcell(inp, (state, cstate)) + peepholelstmcell(inp) -The forward pass takes the following arguments: +## Arguments -- `x`: Input to the cell, which can be a vector of size `in` or a matrix of size `in x batch_size`. -- `h`: The hidden state vector of the cell, sized `out`, or a matrix of size `out x batch_size`. 
-- `c`: The candidate state, sized `out`, or a matrix of size `out x batch_size`. -If not provided, both `h` and `c` default to vectors of zeros. - -# Examples +- `inp`: The input to the peepholelstmcell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the PeepholeLSTMCell. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros. +## Returns +- A tuple `(output, state)`, where `output = new_state` is the new hidden state and + `state = (new_state, new_cstate)` is the new hidden and cell state. + They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ function PeepholeLSTMCell( (input_size, hidden_size)::Pair; @@ -70,7 +74,7 @@ function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, input, forget, cell, output = chunk(g, 4; dims = 1) new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell) new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate) - return new_state, new_cstate + return new_state, (new_state, new_cstate) end Base.show(io::IO, lstm::PeepholeLSTMCell) = @@ -108,6 +112,20 @@ c_t &= f_t \odot c_{t-1} + i_t \odot \sigma_c(W_c x_t + b_c), \\ h_t &= o_t \odot \sigma_h(c_t). \end{align} ``` +# Forward + + peepholelstm(inp, (state, cstate)) + peepholelstm(inp) + +## Arguments +- `inp`: The input to the peepholelstm. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the PeepholeLSTM. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. 
""" function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) cell = PeepholeLSTM(input_size => hidden_size; kwargs...) diff --git a/src/ran_cell.jl b/src/ran_cell.jl index f54dc1c..c94bb7f 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -17,8 +17,6 @@ The `RANCell`, introduced in [this paper](https://arxiv.org/pdf/1705.07393), is a recurrent cell layer which provides additional memory through the use of gates. -and returns both h_t anf c_t. - See [`RAN`](@ref) for a layer that processes entire sequences. # Arguments @@ -41,29 +39,20 @@ h_t &= g(c_t) # Forward - rancell(x, [h, c]) - -The forward pass takes the following arguments: + rancell(inp, (state, cstate)) + rancell(inp) -- `x`: Input to the cell, which can be a vector of size `in` or a matrix of size `in x batch_size`. -- `h`: The hidden state vector of the cell, sized `out`, or a matrix of size `out x batch_size`. -- `c`: The candidate state, sized `out`, or a matrix of size `out x batch_size`. -If not provided, both `h` and `c` default to vectors of zeros. +## Arguments +- `inp`: The input to the rancell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the RANCell. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros. -# Examples - -```julia -rancell = RANCell(3 => 5) -inp = rand(Float32, 3) -#initializing the hidden states, if we want to provide them -state = rand(Float32, 5) -c_state = rand(Float32, 5) - -#result with default initialization of internal states -result = rancell(inp) -#result with internal states provided -result_state = rancell(inp, (state, c_state)) -``` +## Returns +- A tuple `(output, state)`, where `output = new_state` is the new hidden state and + `state = (new_state, new_cstate)` is the new hidden and cell state. 
+ They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ function RANCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -88,7 +77,7 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) forget_gate = @. sigmoid_fast(gxs[3] + ghs[2]) candidate_state = @. input_gate * gxs[1] + forget_gate * c_state new_state = tanh_fast(candidate_state) - return new_state, candidate_state + return new_state, (new_state, candidate_state) end Base.show(io::IO, ran::RANCell) = @@ -129,6 +118,21 @@ c_t &= i_t \odot \tilde{c}_t + f_t \odot c_{t-1}, \\ h_t &= g(c_t) \end{aligned} ``` + +# Forward + + ran(inp, (state, cstate)) + ran(inp) + +## Arguments +- `inp`: The input to the ran. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the RAN. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function RAN((input_size, hidden_size)::Pair; kwargs...) cell = RANCell(input_size => hidden_size; kwargs...) diff --git a/src/rhn_cell.jl b/src/rhn_cell.jl index db61ab5..3df2bd9 100644 --- a/src/rhn_cell.jl +++ b/src/rhn_cell.jl @@ -130,7 +130,7 @@ function (rhn::RHNCell)(inp, state=nothing) end end - return current_state + return current_state, current_state end # TODO fix implementation here diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index bc099e3..117ba1c 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -39,7 +39,21 @@ y_t &= f(U_y h_t + W_y s_t) # Forward - rnncell(inp, [state, c_state]) + scrncell(inp, (state, cstate)) + scrncell(inp) + +## Arguments + +- `inp`: The input to the scrncell. It should be a vector of size `input_size` + or a matrix of size `input_size x batch_size`. 
+- `(state, cstate)`: A tuple containing the hidden and cell states of the SCRNCell. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. + If not provided, they are assumed to be vectors of zeros. + +## Returns +- A tuple `(output, state)`, where `output = new_state` is the new hidden state and + `state = (new_state, new_cstate)` is the new hidden and cell state. + They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ function SCRNCell((input_size, hidden_size)::Pair; init_kernel = glorot_uniform, @@ -54,12 +68,6 @@ function SCRNCell((input_size, hidden_size)::Pair; return SCRNCell(Wi, Wh, Wc, b, alpha) end -function (scrn::SCRNCell)(inp::AbstractVecOrMat) - state = zeros_like(inp, size(scrn.Wh, 2)) - c_state = zeros_like(state) - return scrn(inp, (state, c_state)) -end - function (scrn::SCRNCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(scrn, inp, 1 => size(scrn.Wi,2)) Wi, Wh, Wc, b = scrn.Wi, scrn.Wh, scrn.Wc, scrn.bias @@ -73,7 +81,7 @@ function (scrn::SCRNCell)(inp::AbstractVecOrMat, (state, c_state)) context_layer = (1 .- scrn.alpha) .* gxs[1] .+ scrn.alpha .* c_state hidden_layer = sigmoid_fast(gxs[2] .+ ghs[1] * state .+ gcs[1]) new_state = tanh_fast(ghs[2] * hidden_layer .+ gcs[2]) - return new_state, context_layer + return new_state, (new_state, context_layer) end Base.show(io::IO, scrn::SCRNCell) = @@ -112,6 +120,21 @@ h_t &= \sigma(W_h s_t + U_h h_{t-1} + b_h), \\ y_t &= f(U_y h_t + W_y s_t) \end{aligned} ``` + +# Forward + + scrn(inp, (state, cstate)) + scrn(inp) + +## Arguments +- `inp`: The input to the scrn. It should be a vector of size `input_size x len` + or a matrix of size `input_size x len x batch_size`. +- `(state, cstate)`: A tuple containing the hidden and cell states of the SCRN. + They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`. 
+ If not provided, they are assumed to be vectors of zeros. + +## Returns +- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. """ function SCRN((input_size, hidden_size)::Pair; kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...)