Skip to content

Commit

Permalink
Fix broken Enzyme reverse diff rule on Const evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
dominic-chang committed Oct 30, 2024
1 parent 2b217ce commit e7bd801
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
50 changes: 40 additions & 10 deletions ext/JacobiEllipticEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ d end
elseif EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
return (ϕ isa Const ? zero(ϕ.val) : ∂F_∂ϕ(ϕ.val, m.val)*ϕ.dval) +(m isa Const ? zero(m.val) : ∂F_∂m(ϕ.val, m.val)*m.dval)

else
return ntuple(i -> (ϕ isa Const ? zero(ϕ.val) : ∂F_∂ϕ(ϕ.val, m.val)*ϕ.dval[i]) + (m isa Const ? zero(m.val) : ∂F_∂m(ϕ.val, m.val)*m.dval[i]), Val(EnzymeRules.width(config)))

end
elseif EnzymeRules.needs_primal(config)
return func.val(ϕ.val, m.val)
Expand Down Expand Up @@ -79,17 +77,33 @@ function reverse(
dϕ = if ϕ isa Const
nothing
elseif EnzymeRules.width(config) == 1
∂F_∂ϕ(ϕ.val, m.val) * dret.val
if dret isa Type{<:Const}
zero(ϕ.val)
else
∂F_∂ϕ(ϕ.val, m.val) * dret.val
end
else
ntuple(i -> ∂F_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
if dret isa Type{<:Const}
ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂F_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end

dm = if m isa Const
nothing
elseif EnzymeRules.width(config) == 1
∂F_∂m(ϕ.val, m.val) * dret.val
if dret isa Type{<:Const}
zero(ϕ.val)
else
∂F_∂m(ϕ.val, m.val) * dret.val
end
else
ntuple(i -> ∂F_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
if dret isa Type{<:Const}
ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂F_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end
return (dϕ, dm)
end
Expand Down Expand Up @@ -167,17 +181,33 @@ function reverse(
dϕ = if ϕ isa Const
nothing
elseif EnzymeRules.width(config) == 1
∂E_∂ϕ(ϕ.val, m.val) * dret.val
if dret isa Type{<:Const}
zero(ϕ.val)
else
∂E_∂ϕ(ϕ.val, m.val) * dret.val
end
else
ntuple(i -> ∂E_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
if dret isa Type{<:Const}
ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂E_∂ϕ(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end

dm = if m isa Const
nothing
elseif EnzymeRules.width(config) == 1
∂E_∂m(ϕ.val, m.val) * dret.val
if dret isa Type{<:Const}
zero(ϕ.val)
else
∂E_∂m(ϕ.val, m.val) * dret.val
end
else
ntuple(i -> ∂E_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
if dret isa Type{<:Const}
ntuple(i -> zero(ϕ.val), Val(EnzymeRules.width(config)))
else
ntuple(i -> ∂E_∂m(ϕ.val, m.val) * dret.val[i], Val(EnzymeRules.width(config)))
end
end
return (dϕ, dm)
end
Expand Down
12 changes: 6 additions & 6 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ using SpecialFunctions
_F = alg.F
@test Zygote.gradient(ϕ -> _F(ϕ, m), ϕ)[1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5
@test ForwardDiff.derivative(ϕ -> _F(ϕ, m), ϕ) ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, ϕ -> _F(ϕ, m), Active, Active(ϕ))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, ϕ -> _F(ϕ, m), Duplicated, Duplicated(ϕ, 1.0))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, _F, Active, Active(ϕ), Const(m))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, _F, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] ≈ 1 / √(1 - m*sin(ϕ)^2) atol=1e-5

# 4. ∂m(F(ϕ, m)) == E(ϕ, m) / (2 * m * (1 - m)) - F(ϕ, m) / 2m - sin(2ϕ) / (4 * (1-m) * √(1 - m * sin(ϕ)^2))
@test Zygote.gradient(m -> _F(ϕ, m), m)[1] ≈
Expand All @@ -47,7 +47,7 @@ using SpecialFunctions
alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * √(1 - m * sin(ϕ)^2)) atol=1e-5
@test Enzyme.autodiff(Reverse, m -> _F(ϕ, m), Active, Active(m))[1][1] ≈
@test Enzyme.autodiff(Reverse, _F, Active, Const(ϕ), Active(m))[1][2] ≈
alg.E(ϕ, m) / (2 * m * (1 - m)) -
alg.F(ϕ, m) / 2 / m -
sin(2*ϕ) / (4 * (1 - m) * √(1 - m * sin(ϕ)^2)) atol=1e-5
Expand All @@ -56,13 +56,13 @@ using SpecialFunctions
# 5. ∂ϕ(E(ϕ, m)) == √(1 - m * sin(ϕ)^2)
@test Zygote.gradient(ϕ -> _E(ϕ, m), ϕ)[1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5
@test ForwardDiff.derivative(ϕ -> _E(ϕ, m), ϕ) ≈ √(1 - m * sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, ϕ -> _E(ϕ, m), Active, Active(ϕ))[1][1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, ϕ -> _E(ϕ, m), Duplicated, Duplicated(ϕ, 1.0))[1][1] ≈ √(1 - m*sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Reverse, _E, Active, Active(ϕ), Const(m))[1][1] ≈ √(1 - m * sin(ϕ)^2) atol=1e-5
@test Enzyme.autodiff(Forward, _E, Duplicated, Duplicated(ϕ, 1.0), Const(m))[1][1] ≈ √(1 - m*sin(ϕ)^2) atol=1e-5

# 6. ∂m(E(ϕ, m)) == (E(ϕ, m) - F(ϕ, m)) / 2m
@test Zygote.gradient(m -> _E(ϕ, m), m)[1] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m
@test ForwardDiff.derivative(m -> _E(ϕ, m), m) ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m
@test Enzyme.autodiff(Reverse, m -> _E(ϕ, m), Active, Active(m))[1][1] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m atol=1e-5
@test Enzyme.autodiff(Reverse, _E, Active, Const(ϕ), Active(m))[1][2] ≈ (alg.E(ϕ, m) - alg.F(ϕ, m)) / 2m atol=1e-5

end
end
Expand Down

2 comments on commit e7bd801

@dominic-chang
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Change:

  • Fix Enzyme Extension bugs with Const evaluation on reverse mode autodiff

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.3.1 already exists

Please sign in to comment.