Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcasts which are type unstable with Dual numbers #1441

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Initial edits for real inputs
DomCRose committed Jul 13, 2023
commit 722277542dd9ea5027e17f58863f5bcfe16aa975
40 changes: 19 additions & 21 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
@@ -281,40 +281,38 @@ end
@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
out = dual_function(f).(args...)
T = eltype(out)
T <: Union{Dual, Complex{<:Dual}} || return (out, _ -> nothing)
if any(eltype(a) <: Complex for a in args)
_broadcast_forward_complex(T, out, args...)
if !isconcretetype(T) || T <: Union{Dual, Complex{<:Dual}}
if any(eltype(a) <: Complex for a in args)
return _broadcast_forward_complex(out, args...)
else
return _broadcast_forward(out, args...)
end
else
_broadcast_forward(T, out, args...)
return (out, _ -> nothing)
end
end

# Real input and real output pullback
@inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}
# Real input
@inline _extract_value(x) = value(x)
@inline _extract_value(x::Complex) = Complex(value(real(x)), value(imag(x)))
@inline _broadcast_scalar_pullback(ȳ, out, i) = ȳ * partials(out, i)
@inline function _broadcast_scalar_pullback::Complex, out, i)
return Complex(real(ȳ) * partials(real(out), i), imag(ȳ) * partials(imag(out), i))
end
@inline function _broadcast_forward(out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> value(x), out)
y = broadcast(x -> _extract_value(x), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out))
unbroadcast(args[i],
broadcast((y1, o1) -> _broadcast_scalar_pullback(y1, o1, i), ȳ, out)
)
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# This handles the complex output and real input pullback
@inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N}
valN = Val(N)
y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out)
function bc_fwd_back(ȳ)
dargs = ntuple(valN) do i
unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out))
end
(nothing, nothing, dargs...) # nothings for broadcasted & f
end
return y, bc_fwd_back
end

# This handles complex input and real output. We use the gradient definition from ChainRules here
# since it agrees with what Zygote did for real(x).
@inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N}