Skip to content

Commit

Permalink
Update to Enzyme v0.13
Browse files Browse the repository at this point in the history
  • Loading branch information
dominic-chang committed Sep 30, 2024
1 parent 44ac87c commit f523d54
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ JacobiEllipticZygoteExt = "Zygote"

[compat]
DocStringExtensions = "0.9"
Enzyme = "0.11, 0.12"
Enzyme = "0.13"
ForwardDiff = "0.10"
StaticArrays = "1.6"
Zygote = "0.6"
Expand Down
28 changes: 14 additions & 14 deletions ext/JacobiEllipticEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function forward(func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Cons
end

function augmented_primal(
config::ConfigWidth{N},
config::RevConfigWidth{N},
func::Const{typeof(JacobiElliptic.CarlsonAlg.F)},
::Union{Type{<:Const}, Type{<:Active}},
ϕ,
Expand All @@ -45,32 +45,32 @@ function augmented_primal(
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function reverse(::ConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Const, m::Active)
function reverse(::RevConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Const, m::Active)
dm = ∂F_∂m.val, m.val) * dret.val
return (nothing, dm)
end

function reverse(::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Const)
function reverse(::RevConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Const)
= ∂F_∂ϕ.val, m.val) * dret.val
return (dϕ, nothing)
end

function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Active) where N
function reverse(::RevConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Active) where N
ϕval = ϕ.val
mval = m.val
dm = ∂F_∂m(ϕval, mval) * dret.val
= ∂F_∂ϕ(ϕval, mval) * dret.val
return (dϕ, dm)
end

function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active)
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active)
return (nothing, zero(m.val))
end

function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated})
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated})
return (zero.val), nothing)
end
function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Active)
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Active)
return (zero.val), zero(m.val))
end

Expand Down Expand Up @@ -103,7 +103,7 @@ function forward(func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Cons
end

function augmented_primal(
config::ConfigWidth{N},
config::RevConfigWidth{N},
func::Const{typeof(JacobiElliptic.CarlsonAlg.E)},
::Union{Type{<:Const}, Type{<:Active}},
ϕ,
Expand All @@ -114,32 +114,32 @@ function augmented_primal(
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function reverse(::ConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Const, m::Active)
function reverse(::RevConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Const, m::Active)
dm = ∂E_∂m.val, m.val) * dret.val
return (nothing, dm)
end

function reverse(::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Const)
function reverse(::RevConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Const)
= ∂E_∂ϕ.val, m.val) * dret.val
return (dϕ, nothing)
end

function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Active) where N
function reverse(::RevConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Active) where N
ϕval = ϕ.val
mval = m.val
dm = ∂E_∂m(ϕval, mval) * dret.val
= ∂E_∂ϕ(ϕval, mval) * dret.val
return (dϕ, dm)
end

function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active)
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active)
return (nothing, zero(m.val))
end

function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated})
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated})
return (zero.val), nothing)
end
function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Active)
function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Active)
return (zero.val), zero(m.val))
end

Expand Down

0 comments on commit f523d54

Please sign in to comment.