Skip to content

Commit d214600

Browse files
authored
unthunk in some rules (#2058)
* unthunk for ∇batchnorm * unthunk some rrules * unthunk in multigate rrule
1 parent 31e4dd0 commit d214600

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

src/cuda/cudnn.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
1515
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
1616
function batchnorm_pullback(Δ)
17-
grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
17+
grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
1818
(NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
1919
end
2020
y, batchnorm_pullback

src/functor.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
119119
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
120120

121121
function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
122-
Array(x), d -> (NoTangent(), CUDA.cu(d),)
122+
Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
123123
end
124124

125125
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
126-
adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
126+
adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
127127
end
128128

129129
# CPU/GPU movement conveniences
@@ -227,3 +227,4 @@ f64(m) = paramtype(Float64, m)
227227
# Functors for certain Julia data structures
228228
@functor Cholesky
229229
trainable(c::Cholesky) = ()
230+

src/layers/recurrent.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
99
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
1010
function multigate_pullback(dy)
1111
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
12-
foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
12+
foreach(multigate(dx, h, c), unthunk(dy)) do dxᵢ, dyᵢ
1313
dyᵢ isa AbstractZero && return
1414
@. dxᵢ += dyᵢ
1515
end

0 commit comments

Comments
 (0)