Skip to content

Commit f462624

Browse files
authored
Return [-1,1] for the derivative of abs at 0 (#703)
1 parent d4a8536 commit f462624

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

ext/IntervalArithmeticDiffRulesExt.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module IntervalArithmeticDiffRulesExt
22

3-
using IntervalArithmetic, DiffRules
3+
using IntervalArithmetic
4+
import DiffRules
45

5-
function DiffRules._abs_deriv(x::Interval)
6-
r = sign(bareinterval(x))
6+
function DiffRules._abs_deriv(x::Interval{T}) where {T<:IntervalArithmetic.NumTypes}
7+
r = ifelse(isthinzero(x), bareinterval(-one(T), one(T)), sign(bareinterval(x)))
78
d = decoration(x)
89
d = min(d, ifelse(in_interval(0, x), trv, d)) # if `x` contains 0, then `trv` decoration
910
return IntervalArithmetic._unsafe_interval(r, d, isguaranteed(x))

ext/IntervalArithmeticForwardDiffExt.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,14 @@ end
101101
function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
102102
X = value(dual)
103103
input_domain = Domain(X)
104-
if !overlap_domain(input_domain, piecewise)
104+
if !overlap_domain(input_domain, piecewise)
105105
return Dual{T}(emptyinterval(X), emptyinterval(X) .* partials(dual))
106106
end
107107

108108
if !in_domain(input_domain, piecewise)
109109
dec = trv
110110
elseif any(x -> in_domain(x, input_domain), discontinuities(piecewise, 1))
111-
dec = def
111+
dec = def
112112
else
113113
dec = com
114114
end
@@ -135,5 +135,7 @@ function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
135135
return Dual{T}(primal, tuple(partial...))
136136
end
137137

138+
ForwardDiff.DiffRules._abs_deriv(x::Dual{T,<:Interval}) where {T} =
139+
Dual{T}(ForwardDiff.DiffRules._abs_deriv(value(x)), zero(partials(x)))
138140

139-
end
141+
end

test/interval_tests/forwarddiff.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
@testset "abs" begin
1111
@test ForwardDiff.derivative(abs, interval(-2, -1)) === interval(-1, -1, com)
1212
@test ForwardDiff.derivative(abs, interval( 1, 2)) === interval( 1, 1, com)
13-
@test ForwardDiff.derivative(abs, interval( 0 )) === interval( 0 , trv)
13+
@test ForwardDiff.derivative(abs, interval( 0 )) === interval(-1, 1, trv)
1414
@test ForwardDiff.derivative(abs, interval(-1, 0)) === interval(-1, 0, trv)
1515
@test ForwardDiff.derivative(abs, interval( 0, 1)) === interval( 0, 1, trv)
1616
@test ForwardDiff.derivative(abs, interval(-2, 2)) === interval(-1, 1, trv)
@@ -21,7 +21,7 @@ end
2121
g(x) = abs(x)^2
2222
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
2323
@test all(ForwardDiff.gradient( v -> g(v[1]), [interval(-1, 1)]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
24-
@test_broken all(ForwardDiff.hessian( v -> g(v[1]), [interval( 0 )]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
24+
@test all(ForwardDiff.hessian( v -> g(v[1]), [interval( 0 )]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
2525
end
2626

2727
@testset "sin" begin

0 commit comments

Comments
 (0)