From 34ec0a0d5e8760034c614e1832ad0fb60a1c62ef Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 13 Aug 2024 09:50:22 +1200 Subject: [PATCH] Fix various bugs with Real-valued Hermitian matrices (#3805) --- src/operators.jl | 62 +++++++++++++++++++++++++++++++++++++---- src/sd.jl | 16 +++++++++++ test/test_constraint.jl | 19 +++++++++++++ test/test_operator.jl | 38 +++++++++++++++++++++++-- 4 files changed, 128 insertions(+), 7 deletions(-) diff --git a/src/operators.jl b/src/operators.jl index 25cd8f14c79..11671393eb8 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -333,11 +333,9 @@ end function Base.:+(a::GenericAffExpr, q::GenericQuadExpr) return GenericQuadExpr(a + q.aff, copy(q.terms)) end -function Base.:-(a::GenericAffExpr, q::GenericQuadExpr) - result = -q - # This makes an unnecessary copy of aff, but it's important for a to appear - # first. - result.aff = a + result.aff +function Base.:-(a::GenericAffExpr{S}, q::GenericQuadExpr{T}) where {S,T} + result = -_copy_convert_coef(_MA.promote_operation(-, S, T), q) + add_to_expression!(result.aff, a) return result end @@ -571,3 +569,57 @@ function Base.complex( ) return r + im * i end + +# These methods exist in LinearAlgebra for subtypes of Real. Without them, we +# return a `Matrix` which looses the Hermitian information. +function Base.:+( + A::LinearAlgebra.Symmetric{V,Matrix{V}}, + B::LinearAlgebra.Hermitian, +) where { + V<:Union{ + GenericVariableRef{<:Real}, + GenericAffExpr{<:Real}, + GenericQuadExpr{<:Real}, + }, +} + return LinearAlgebra.Hermitian(A) + B +end + +function Base.:+( + A::LinearAlgebra.Hermitian, + B::LinearAlgebra.Symmetric{V,Matrix{V}}, +) where { + V<:Union{ + GenericVariableRef{<:Real}, + GenericAffExpr{<:Real}, + GenericQuadExpr{<:Real}, + }, +} + return A + LinearAlgebra.Hermitian(B) +end + +function Base.:-( + A::LinearAlgebra.Symmetric{V,Matrix{V}}, + B::LinearAlgebra.Hermitian, +) where { + V<:Union{ + GenericVariableRef{<:Real}, + GenericAffExpr{<:Real}, + GenericQuadExpr{<:Real}, + }, +} + return LinearAlgebra.Hermitian(A) - B +end + +function Base.:-( + A::LinearAlgebra.Hermitian, + B::LinearAlgebra.Symmetric{V,Matrix{V}}, +) where { + V<:Union{ + GenericVariableRef{<:Real}, + GenericAffExpr{<:Real}, + GenericQuadExpr{<:Real}, + }, +} + return A - LinearAlgebra.Hermitian(B) +end diff --git a/src/sd.jl b/src/sd.jl index a00b85735ac..f07b1a09d65 100644 --- a/src/sd.jl +++ b/src/sd.jl @@ -703,6 +703,22 @@ function build_constraint( return VectorConstraint(x, MOI.Zeros(length(x)), shape) end +# If we have a real-valued Hermitian matrix, then it is actually Symmetric, and +# not Complex-valued Hermitian. +function build_constraint( + error_fn::Function, + H::LinearAlgebra.Hermitian{V}, + set::Zeros, +) where { + V<:Union{ + GenericVariableRef{<:Real}, + GenericAffExpr{<:Real}, + GenericQuadExpr{<:Real}, + }, +} + return build_constraint(error_fn, LinearAlgebra.Symmetric(H), set) +end + reshape_set(s::MOI.Zeros, ::HermitianMatrixShape) = Zeros() function build_constraint(error_fn::Function, ::AbstractMatrix, ::Nonnegatives) diff --git a/test/test_constraint.jl b/test/test_constraint.jl index 282d299de9c..fa6dc9b04a7 100644 --- a/test/test_constraint.jl +++ b/test/test_constraint.jl @@ -2088,4 +2088,23 @@ function test_abstract_vector_orthants() return end +function test_real_hermitian_in_zeros() + model = Model() + @variable(model, x[1:2, 1:2], Symmetric) + c = @constraint(model, LinearAlgebra.Hermitian(x) in Zeros()) + obj = constraint_object(c) + @test obj.func == [x[1, 1], x[1, 2], x[2, 2]] + @test obj.shape == SymmetricMatrixShape(2; needs_adjoint_dual = true) + H = LinearAlgebra.Hermitian([1 2; 2 3]) + c = @constraint(model, x == H) + obj = constraint_object(c) + @test obj.func == [x[1, 1] - 1, x[1, 2] - 2, x[2, 2] - 3] + @test obj.shape == SymmetricMatrixShape(2; needs_adjoint_dual = true) + c = @constraint(model, LinearAlgebra.Hermitian(x .^ 2) in Zeros()) + obj = constraint_object(c) + @test obj.func == [x[1, 1]^2, x[1, 2]^2, x[2, 2]^2] + @test obj.shape == SymmetricMatrixShape(2; needs_adjoint_dual = true) + return +end + end # module diff --git a/test/test_operator.jl b/test/test_operator.jl index d8d27270b41..b29d2b357e0 100644 --- a/test/test_operator.jl +++ b/test/test_operator.jl @@ -469,7 +469,7 @@ function test_extension_basic_operators_affexpr( @test_expression_with_string aff - aff "0 x" # 4-4 AffExpr--QuadExpr @test_expression_with_string aff2 + q "2.5 y*z + 1.2 y + 7.1 x + 3.7" - @test_expression_with_string aff2 - q "-2.5 y*z + 1.2 y - 7.1 x - 1.3" + @test_expression_with_string aff2 - q "-2.5 y*z - 7.1 x + 1.2 y - 1.3" @test_expression_with_string aff2 * q "(1.2 y + 1.2) * (2.5 y*z + 7.1 x + 2.5)" @test_expression_with_string aff2 / q "(1.2 y + 1.2) / (2.5 y*z + 7.1 x + 2.5)" @test transpose(aff) === aff @@ -499,7 +499,7 @@ function test_extension_basic_operators_quadexpr( @test_expression_with_string q * 2 "5 y*z + 14.2 x + 5" @test_expression_with_string q / 2 "1.25 y*z + 3.55 x + 1.25" @test q == q - @test_expression_with_string aff2 - q "-2.5 y*z + 1.2 y - 7.1 x - 1.3" + @test_expression_with_string aff2 - q "-2.5 y*z - 7.1 x + 1.2 y - 1.3" # 4-2 QuadExpr--Variable @test_expression_with_string q + w "2.5 y*z + 7.1 x + w + 2.5" @test_expression_with_string q - w "2.5 y*z + 7.1 x - w + 2.5" @@ -689,4 +689,38 @@ function test_base_complex() return end +function test_aff_minus_quad() + model = Model() + @variable(model, x) + a, b = 1.0 * x, (2 + 3im) * x^2 + @test a - b == -(b - a) + @test b - a == -(a - b) + a, b = (1.0 + 2im) * x, 3 * x^2 + 4 * x + @test a - b == -(b - a) + @test b - a == -(a - b) + return +end + +function test_hermitian_and_symmetric() + model = Model() + @variable(model, A[1:2, 1:2], Symmetric) + @variable(model, B[1:2, 1:2], Hermitian) + for (x, y) in ( + (A, B), + (B, A), + (1.0 * A, B), + (B, 1.0 * A), + (1.0 * A, 1.0 * B), + (1.0 * B, 1.0 * A), + (1.0 * LinearAlgebra.Symmetric(A .* A), 1.0 * B), + (1.0 * B, 1.0 * LinearAlgebra.Symmetric(A .* A)), + ) + @test x + y isa LinearAlgebra.Hermitian + @test x + y == x .+ y + @test x - y isa LinearAlgebra.Hermitian + @test x - y == x .- y + end + return +end + end