Skip to content

Commit

Permalink
fall-back to normal matrix exponentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Feb 13, 2025
1 parent b692284 commit d5ba448
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 9 deletions.
30 changes: 21 additions & 9 deletions odetoolbox/system_of_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,31 @@
from .sympy_helpers import _custom_simplify_expr, _is_zero


class GetBlockDiagonalException(Exception):
"""
Thrown in case an error occurs while block diagonalising a matrix.
"""
pass


def get_block_diagonal_blocks(A):
assert A.shape[0] == A.shape[1], "matrix A should be square"

A_mirrored = (A + A.T) != 0 # make the matrix symmetric so we only have to check one triangle

graph_components = scipy.sparse.csgraph.connected_components(A_mirrored)[1]

assert all(np.diff(graph_components) >= 0), "Matrix is not ordered"
if not all(np.diff(graph_components) >= 0):
# matrix is not ordered
raise GetBlockDiagonalException()

blocks = []
for i in np.unique(graph_components):
idx = np.where(graph_components == i)[0]
assert all(np.diff(idx) > 0)
assert len(idx) == 1 or (len(np.unique(np.diff(idx))) == 1 and np.unique(np.diff(idx))[0] == 1)

if not all(np.diff(idx) > 0) or not (len(idx) == 1 or (len(np.unique(np.diff(idx))) == 1 and np.unique(np.diff(idx))[0] == 1)):
raise GetBlockDiagonalException()

idx_min = np.amin(idx)
idx_max = np.amax(idx)
block = A[idx_min:idx_max + 1, idx_min:idx_max + 1]
Expand Down Expand Up @@ -200,13 +211,14 @@ def _generate_propagator_matrix(self, A):
XXX: the default custom simplification expression does not work well with sympy 1.4 here. Consider replacing sympy.simplify() with _custom_simplify_expr() if sympy 1.4 support is dropped.
"""

# naive: calculate propagators in one step
# P_naive = sympy.simplify(sympy.exp(A * sympy.Symbol(Config().output_timestep_symbol)))

# optimized: be explicit about block diagonal elements; much faster!
blocks = get_block_diagonal_blocks(np.array(A))
propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol))) for block in blocks]
P = sympy.Matrix(scipy.linalg.block_diag(*propagators))
try:
blocks = get_block_diagonal_blocks(np.array(A))
propagators = [sympy.simplify(sympy.exp(sympy.Matrix(block) * sympy.Symbol(Config().output_timestep_symbol))) for block in blocks]
P = sympy.Matrix(scipy.linalg.block_diag(*propagators))
except:
# naive: calculate propagators in one step
P = sympy.simplify(sympy.exp(A * sympy.Symbol(Config().output_timestep_symbol)))

# check the result
if sympy.I in sympy.preorder_traversal(P):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_propagator_solver_homogeneous.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"dynamics": [
{
"expression": "V_m' = (-(V_m - E_L)) / tau_m + (I_kernel_exc__X__exc_spikes * 1.0 - I_kernel_inh__X__inh_spikes * 1.0 + I_e + I_stim) / C_m",
"initial_values": {
"V_m": "E_L"
}
},
{
"expression": "refr_t' = -1",
"initial_values": {
"refr_t": "0"
}
},
{
"expression": "I_kernel_exc__X__exc_spikes = (e / tau_syn_exc) * t * exp(-t / tau_syn_exc)",
"initial_values": {}
},
{
"expression": "I_kernel_inh__X__inh_spikes = (e / tau_syn_inh) * t * exp(-t / tau_syn_inh)",
"initial_values": {}
}
],
"options": {
"output_timestep_symbol": "__h"
},
"parameters": {
"C_m": "250",
"E_L": "(-70)",
"I_e": "0",
"V_reset": "(-70)",
"V_th": "(-55)",
"refr_T": "2",
"tau_m": "10",
"tau_syn_exc": "2",
"tau_syn_inh": "2"
}
}
57 changes: 57 additions & 0 deletions tests/test_propagator_solver_homogeneous.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# test_analytic_solver_integration.py
#
# This file is part of the NEST ODE toolbox.
#
# Copyright (C) 2017 The NEST Initiative
#
# The NEST ODE toolbox is free software: you can redistribute it
# and/or modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation, either version 2 of
# the License, or (at your option) any later version.
#
# The NEST ODE toolbox is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
#

import math
import numpy as np
import os
import sympy
import sympy.parsing.sympy_parser
import scipy
import scipy.special
import scipy.linalg
import scipy.integrate


try:
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
INTEGRATION_TEST_DEBUG_PLOTS = True
except ImportError:
INTEGRATION_TEST_DEBUG_PLOTS = False


from .context import odetoolbox
from odetoolbox.analytic_integrator import AnalyticIntegrator
from tests.test_utils import _open_json


class TestPropagatorSolverHomogeneous:
r"""Test ODE-toolbox ability to come up with a propagator solver for a matrix that is not block-diagonalisable, because it contains an autonomous ODE."""

def test_propagator_solver_homogeneous(self):
indict = _open_json("test_propagator_solver_homogeneous.json")
solver_dict = odetoolbox.analysis(indict, disable_stiffness_check=True, log_level="DEBUG")
assert len(solver_dict) == 1
solver_dict = solver_dict[0]
assert solver_dict["solver"] == "analytical"
assert float(solver_dict["propagators"]["__P__refr_t__refr_t"]) == 1.
assert solver_dict["propagators"]["__P__V_m__V_m"] == "1.0*exp(-__h/tau_m)"

0 comments on commit d5ba448

Please sign in to comment.