Skip to content

Commit

Permalink
fixed mistake
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Dec 2, 2018
1 parent 8ab41c3 commit e7479b3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/functions/cubeNormL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ end

function prox!(y::AbstractArray{T}, f::CubeNormL2{R}, x::AbstractArray{T}, gamma::R=one(R)) where {R, T <: RealOrComplex{R}}
norm_x = norm(x)
scale = 2 / (1 + sqrt(1 + 12 * f.lambda * norm_x))
scale = 2 / (1 + sqrt(1 + 12 * gamma * f.lambda * norm_x))
y .= scale .* x
return f.lambda * (scale * norm_x)^3
end

function prox_naive(f::CubeNormL2{R}, x::AbstractArray{T}, gamma=one(R)) where {R, T <: RealOrComplex{R}}
y = 2 / (1 + sqrt(1 + 12 * f.lambda * norm(x))) * x
y = 2 / (1 + sqrt(1 + 12 * gamma * f.lambda * norm(x))) * x
return y, f.lambda * norm(y)^3
end
2 changes: 1 addition & 1 deletion test/test_cubeNormL2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ for R in [Float16, Float32, Float64]
gamma = R(0.5)+rand(R)
y, f_y = prox_test(f, x, gamma)
grad_f_y, f_y = gradient(f, y)
@test grad_f_y x - y
@test grad_f_y (x - y)/gamma
end
end
end
Expand Down

0 comments on commit e7479b3

Please sign in to comment.