From f6e18aa025ffb8fe4b7a910a7d9542cf6b87242a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Thu, 31 Aug 2023 13:22:09 +0200 Subject: [PATCH] Improve performance of sum of products with nonlinear expressions --- src/mutable_arithmetics.jl | 28 ++++------------------------ src/nlp_expr.jl | 14 +++++++++----- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/src/mutable_arithmetics.jl b/src/mutable_arithmetics.jl index d7b77506520..e9676cc580c 100644 --- a/src/mutable_arithmetics.jl +++ b/src/mutable_arithmetics.jl @@ -286,12 +286,7 @@ end function _MA.add_mul(lhs::AbstractJuMPScalar, x::_Scalar, y::_Scalar) T = _MA.promote_operation(_MA.add_mul, typeof(lhs), typeof(x), typeof(y)) expr = _MA.operate(convert, T, lhs) - # We can't use `operate!!` here because in the IsNotMutable case (e.g., - # NonlinearExpr), it will fallback to this method and cause a StackOverflow. - if _MA.mutability(T) == _MA.IsNotMutable() - return expr + _MA.operate(*, x, y) - end - return _MA.operate!(_MA.add_mul, expr, x, y) + return _MA.operate!!(_MA.add_mul, expr, x, y) end function _MA.add_mul( @@ -308,23 +303,13 @@ function _MA.add_mul( typeof.(args)..., ) expr = _MA.operate(convert, T, lhs) - # We can't use `operate!!` here because in the IsNotMutable case (e.g., - # NonlinearExpr), it will fallback to this method and cause a StackOverflow. - if _MA.mutability(T) == _MA.IsNotMutable() - return expr + _MA.operate(*, x, y, args...) - end - return _MA.operate!(_MA.add_mul, expr, x, y, args...) + return _MA.operate!!(_MA.add_mul, expr, x, y, args...) end function _MA.sub_mul(lhs::AbstractJuMPScalar, x::_Scalar, y::_Scalar) T = _MA.promote_operation(_MA.sub_mul, typeof(lhs), typeof(x), typeof(y)) expr = _MA.operate(convert, T, lhs) - # We can't use `operate!!` here because in the IsNotMutable case (e.g., - # NonlinearExpr), it will fallback to this method and cause a StackOverflow. - if _MA.mutability(T) == _MA.IsNotMutable() - return expr - _MA.operate(*, x, y) - end - return _MA.operate!(_MA.sub_mul, expr, x, y) + return _MA.operate!!(_MA.sub_mul, expr, x, y) end function _MA.sub_mul( @@ -341,10 +326,5 @@ function _MA.sub_mul( typeof.(args)..., ) expr = _MA.operate(convert, T, lhs) - # We can't use `operate!!` here because in the IsNotMutable case (e.g., - # NonlinearExpr), it will fallback to this method and cause a StackOverflow. - if _MA.mutability(T) == _MA.IsNotMutable() - return expr - _MA.operate(*, x, y, args...) - end - return _MA.operate!(_MA.sub_mul, expr, x, y, args...) + return _MA.operate!!(_MA.sub_mul, expr, x, y, args...) end diff --git a/src/nlp_expr.jl b/src/nlp_expr.jl index 6fe42fcb5ee..116968db25a 100644 --- a/src/nlp_expr.jl +++ b/src/nlp_expr.jl @@ -368,16 +368,20 @@ for f in (:+, :-, :*, :^, :/, :atan) end function _MA.operate!!( - ::typeof(_MA.add_mul), + op::_MA.AddSubMul, x::GenericNonlinearExpr, - y::AbstractJuMPScalar, -) + args::Vararg{Any,N}, +) where {N} _throw_if_not_real(x) if x.head == :+ - push!(x.args, y) + arg = *(args...) + if _MA.add_sub_op(op) != + + arg = _MA.add_sub_op(op)(arg) + end + push!(x.args, arg) return x end - return +(x, y) + return _MA.add_sub_op(op)(x, *(args...)) end """