From e7479b3572aa9acecc050532e52fcece9a2b3622 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 1 Dec 2018 00:44:01 +0100 Subject: [PATCH] fixed mistake --- src/functions/cubeNormL2.jl | 4 ++-- test/test_cubeNormL2.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/functions/cubeNormL2.jl b/src/functions/cubeNormL2.jl index cac7bc06..b2111bfd 100644 --- a/src/functions/cubeNormL2.jl +++ b/src/functions/cubeNormL2.jl @@ -40,12 +40,12 @@ end function prox!(y::AbstractArray{T}, f::CubeNormL2{R}, x::AbstractArray{T}, gamma::R=one(R)) where {R, T <: RealOrComplex{R}} norm_x = norm(x) - scale = 2 / (1 + sqrt(1 + 12 * f.lambda * norm_x)) + scale = 2 / (1 + sqrt(1 + 12 * gamma * f.lambda * norm_x)) y .= scale .* x return f.lambda * (scale * norm_x)^3 end function prox_naive(f::CubeNormL2{R}, x::AbstractArray{T}, gamma=one(R)) where {R, T <: RealOrComplex{R}} - y = 2 / (1 + sqrt(1 + 12 * f.lambda * norm(x))) * x + y = 2 / (1 + sqrt(1 + 12 * gamma * f.lambda * norm(x))) * x return y, f.lambda * norm(y)^3 end diff --git a/test/test_cubeNormL2.jl b/test/test_cubeNormL2.jl index 709eab42..a2a48f47 100644 --- a/test/test_cubeNormL2.jl +++ b/test/test_cubeNormL2.jl @@ -17,7 +17,7 @@ for R in [Float16, Float32, Float64] gamma = R(0.5)+rand(R) y, f_y = prox_test(f, x, gamma) grad_f_y, f_y = gradient(f, y) - @test grad_f_y ≈ x - y + @test grad_f_y ≈ (x - y)/gamma end end end