diff --git a/Project.toml b/Project.toml index 1e44d5e..3250436 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,7 @@ JacobiEllipticZygoteExt = "Zygote" [compat] DocStringExtensions = "0.9" -Enzyme = "0.11, 0.12" +Enzyme = "0.13" ForwardDiff = "0.10" StaticArrays = "1.6" Zygote = "0.6" diff --git a/ext/JacobiEllipticEnzymeExt.jl b/ext/JacobiEllipticEnzymeExt.jl index 8405523..bfdd491 100644 --- a/ext/JacobiEllipticEnzymeExt.jl +++ b/ext/JacobiEllipticEnzymeExt.jl @@ -34,7 +34,7 @@ function forward(func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Cons end function augmented_primal( - config::ConfigWidth{N}, + config::RevConfigWidth{N}, func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Union{Type{<:Const}, Type{<:Active}}, ϕ, @@ -45,17 +45,17 @@ function augmented_primal( return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function reverse(::ConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Const, m::Active) +function reverse(::RevConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Const, m::Active) dm = ∂F_∂m(ϕ.val, m.val) * dret.val return (nothing, dm) end -function reverse(::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Const) +function reverse(::RevConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Const) dϕ = ∂F_∂ϕ(ϕ.val, m.val) * dret.val return (dϕ, nothing) end -function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Active) where N +function reverse(::RevConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, dret::Active, tape, ϕ::Active, m::Active) where N ϕval = ϕ.val mval = m.val dm = ∂F_∂m(ϕval, mval) * dret.val @@ -63,14 +63,14 @@ function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, return (dϕ, dm) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active) return (nothing, zero(m.val)) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated}) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated}) return (zero(ϕ.val), nothing) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Active) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.F)}, ::Type{<:Const}, tape, ϕ::Active, m::Active) return (zero(ϕ.val), zero(m.val)) end @@ -103,7 +103,7 @@ function forward(func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Cons end function augmented_primal( - config::ConfigWidth{N}, + config::RevConfigWidth{N}, func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Union{Type{<:Const}, Type{<:Active}}, ϕ, @@ -114,17 +114,17 @@ function augmented_primal( return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function reverse(::ConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Const, m::Active) +function reverse(::RevConfigWidth{1}, func::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Const, m::Active) dm = ∂E_∂m(ϕ.val, m.val) * dret.val return (nothing, dm) end -function reverse(::ConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Const) +function reverse(::RevConfigWidth{1}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Const) dϕ = ∂E_∂ϕ(ϕ.val, m.val) * dret.val return (dϕ, nothing) end -function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Active) where N +function reverse(::RevConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, dret::Active, tape, ϕ::Active, m::Active) where N ϕval = ϕ.val mval = m.val dm = ∂E_∂m(ϕval, mval) * dret.val @@ -132,14 +132,14 @@ function reverse(::ConfigWidth{N}, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, return (dϕ, dm) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Union{Const, Duplicated}, m::Active) return (nothing, zero(m.val)) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated}) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Union{Const, Duplicated}) return (zero(ϕ.val), nothing) end -function reverse(::ConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Active) +function reverse(::RevConfigWidth, ::Const{typeof(JacobiElliptic.CarlsonAlg.E)}, ::Type{<:Const}, tape, ϕ::Active, m::Active) return (zero(ϕ.val), zero(m.val)) end