Skip to content

Commit 1e3b332

Browse files
authored
Merge pull request #119 from cklb/Fix_missing_user_data
Provide user_data for rhs, jac and err calls when using IDA
2 parents 159dbb8 + 373511d commit 1e3b332

File tree

9 files changed

+177
-91
lines changed

9 files changed

+177
-91
lines changed

scikits/odes/dopri5.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@
4343

4444
from __future__ import print_function
4545

46+
import sys
4647
from collections import namedtuple
47-
from enum import IntEnum
48+
try:
49+
from enum import IntEnum
50+
except ImportError:
51+
from enum34 import IntEnum
4852
from warnings import warn
4953

5054
import numpy as np

scikits/odes/sundials/__init__.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,12 @@ def _get_num_args(func):
5151
"""
5252
Python 2/3 compatible method of getting number of args that `func` accepts
5353
"""
54-
if hasattr(inspect, "signature"):
55-
sig = inspect.signature(func)
56-
numargs = 0
57-
for param in sig.parameters.values():
58-
if param.kind in (
59-
inspect.Parameter.POSITIONAL_ONLY,
60-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
61-
inspect.Parameter.VAR_POSITIONAL,
62-
):
63-
numargs += 1
64-
return numargs
54+
if hasattr(inspect, "getfullargspec"):
55+
argspec = inspect.getfullargspec(func)
6556
else:
66-
return len(inspect.getargspec(func).args)
57+
argspec = inspect.getargspec(func)
58+
arg_cnt = 0
59+
for arg in argspec.args:
60+
if arg not in ("self", "cls"):
61+
arg_cnt += 1
62+
return arg_cnt

scikits/odes/sundials/cvode.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# cython: embedsignature=True
22
from cpython.exc cimport PyErr_CheckSignals
33
from collections import namedtuple
4-
from enum import IntEnum
4+
try:
5+
from enum import IntEnum
6+
except ImportError:
7+
from enum34 import IntEnum
58
import inspect
69
from warnings import warn
710

scikits/odes/sundials/cvodes.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# cython: embedsignature=True
22
from cpython.exc cimport PyErr_CheckSignals
33
from collections import namedtuple
4-
from enum import IntEnum
4+
try:
5+
from enum import IntEnum
6+
except ImportError:
7+
from enum34 import IntEnum
58
import inspect
69
from warnings import warn
710

scikits/odes/sundials/ida.pxd

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ cdef class IDA_JacRhsFunction:
3232
np.ndarray[DTYPE_t, ndim=1] ydot,
3333
np.ndarray[DTYPE_t, ndim=1] residual,
3434
DTYPE_t cj,
35-
np.ndarray[DTYPE_t, ndim=2] J) except? -1
35+
np.ndarray[DTYPE_t, ndim=2] J,
36+
object userdata = *) except? -1
37+
3638

3739
cdef class IDA_WrapJacRhsFunction(IDA_JacRhsFunction):
3840
cdef object _jacfn

scikits/odes/sundials/ida.pyx

+59-59
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# cython: embedsignature=True
22
from cpython.exc cimport PyErr_CheckSignals
33
from collections import namedtuple
4-
from enum import IntEnum
4+
try:
5+
from enum import IntEnum
6+
except ImportError:
7+
from enum34 import IntEnum
58
import inspect
69
from warnings import warn
710

@@ -127,7 +130,8 @@ cdef class IDA_RhsFunction:
127130
recoverable failure, negative for unrecoverable failure (as per IDA
128131
documentation).
129132
"""
130-
cpdef int evaluate(self, DTYPE_t t,
133+
cpdef int evaluate(self,
134+
DTYPE_t t,
131135
np.ndarray[DTYPE_t, ndim=1] y,
132136
np.ndarray[DTYPE_t, ndim=1] ydot,
133137
np.ndarray[DTYPE_t, ndim=1] result,
@@ -139,16 +143,14 @@ cdef class IDA_WrapRhsFunction(IDA_RhsFunction):
139143
"""
140144
set some residual equations as a ResFunction executable class
141145
"""
142-
self.with_userdata = 0
143-
nrarg = _get_num_args(resfn)
144-
if nrarg > 5:
145-
# hopefully a class method
146-
self.with_userdata = 1
147-
elif nrarg == 5 and inspect.isfunction(resfn):
146+
if _get_num_args(resfn) == 5:
148147
self.with_userdata = 1
148+
else:
149+
self.with_userdata = 0
149150
self._resfn = resfn
150151

151-
cpdef int evaluate(self, DTYPE_t t,
152+
cpdef int evaluate(self,
153+
DTYPE_t t,
152154
np.ndarray[DTYPE_t, ndim=1] y,
153155
np.ndarray[DTYPE_t, ndim=1] ydot,
154156
np.ndarray[DTYPE_t, ndim=1] result,
@@ -197,7 +199,8 @@ cdef class IDA_RootFunction:
197199
Note that evaluate must return a integer, 0 for success, non-zero for error
198200
(as per IDA documentation).
199201
"""
200-
cpdef int evaluate(self, DTYPE_t t,
202+
cpdef int evaluate(self,
203+
DTYPE_t t,
201204
np.ndarray[DTYPE_t, ndim=1] y,
202205
np.ndarray[DTYPE_t, ndim=1] ydot,
203206
np.ndarray[DTYPE_t, ndim=1] g,
@@ -209,16 +212,14 @@ cdef class IDA_WrapRootFunction(IDA_RootFunction):
209212
"""
210213
set root-ing condition(equations) as a RootFunction executable class
211214
"""
212-
self.with_userdata = 0
213-
nrarg = _get_num_args(rootfn)
214-
if nrarg > 5:
215-
#hopefully a class method, self gives 5 arg!
216-
self.with_userdata = 1
217-
elif nrarg == 5 and inspect.isfunction(rootfn):
215+
if _get_num_args(rootfn) == 5:
218216
self.with_userdata = 1
217+
else:
218+
self.with_userdata = 0
219219
self._rootfn = rootfn
220220

221-
cpdef int evaluate(self, DTYPE_t t,
221+
cpdef int evaluate(self,
222+
DTYPE_t t,
222223
np.ndarray[DTYPE_t, ndim=1] y,
223224
np.ndarray[DTYPE_t, ndim=1] ydot,
224225
np.ndarray[DTYPE_t, ndim=1] g,
@@ -268,12 +269,14 @@ cdef class IDA_JacRhsFunction:
268269
recoverable failure, negative for unrecoverable failure (as per IDA
269270
documentation).
270271
"""
271-
cpdef int evaluate(self, DTYPE_t t,
272+
cpdef int evaluate(self,
273+
DTYPE_t t,
272274
np.ndarray[DTYPE_t, ndim=1] y,
273275
np.ndarray[DTYPE_t, ndim=1] ydot,
274276
np.ndarray[DTYPE_t, ndim=1] residual,
275277
DTYPE_t cj,
276-
np.ndarray J) except? -1:
278+
np.ndarray J,
279+
object userdata = None) except? -1:
277280
"""
278281
Returns the Jacobi matrix of the residual function, as
279282
d(res)/d y + cj d(res)/d ydot
@@ -291,24 +294,29 @@ cdef class IDA_WrapJacRhsFunction(IDA_JacRhsFunction):
291294
"""
292295
Set some jacobian equations as a JacResFunction executable class.
293296
"""
297+
if _get_num_args(jacfn) == 7:
298+
self.with_userdata = 1
299+
else:
300+
self.with_userdata = 0
294301
self._jacfn = jacfn
295302

296-
cpdef int evaluate(self, DTYPE_t t,
303+
cpdef int evaluate(self,
304+
DTYPE_t t,
297305
np.ndarray[DTYPE_t, ndim=1] y,
298306
np.ndarray[DTYPE_t, ndim=1] ydot,
299307
np.ndarray[DTYPE_t, ndim=1] residual,
300308
DTYPE_t cj,
301-
np.ndarray J) except? -1:
309+
np.ndarray J,
310+
object userdata = None) except? -1:
302311
"""
303312
Returns the Jacobi matrix (for dense the full matrix, for band only
304313
bands. Result has to be stored in the variable J, which is preallocated
305314
to the corresponding size.
306315
"""
307-
## if self.with_userdata == 1:
308-
## self._jacfn(t, y, ydot, cj, J, userdata)
309-
## else:
310-
## self._jacfn(t, y, ydot, cj, J)
311-
user_flag = self._jacfn(t, y, ydot, residual, cj, J)
316+
if self.with_userdata == 1:
317+
user_flag = self._jacfn(t, y, ydot, residual, cj, J, userdata)
318+
else:
319+
user_flag = self._jacfn(t, y, ydot, residual, cj, J)
312320
if user_flag is None:
313321
user_flag = 0
314322
return user_flag
@@ -336,7 +344,7 @@ cdef int _jacdense(realtype tt, realtype cj,
336344
nv_s2ndarray(yy, yy_tmp)
337345
nv_s2ndarray(yp, yp_tmp)
338346
nv_s2ndarray(rr, residual_tmp)
339-
user_flag = aux_data.jac.evaluate(tt, yy_tmp, yp_tmp, residual_tmp, cj, jac_tmp)
347+
user_flag = aux_data.jac.evaluate(tt, yy_tmp, yp_tmp, residual_tmp, cj, jac_tmp, aux_data.user_data)
340348

341349
if parallel_implementation:
342350
raise NotImplemented
@@ -355,7 +363,8 @@ cdef class IDA_PrecSetupFunction:
355363
recoverable failure, negative for unrecoverable failure (as per CVODE
356364
documentation).
357365
"""
358-
cpdef int evaluate(self, DTYPE_t t,
366+
cpdef int evaluate(self,
367+
DTYPE_t t,
359368
np.ndarray[DTYPE_t, ndim=1] y,
360369
np.ndarray[DTYPE_t, ndim=1] yp,
361370
np.ndarray[DTYPE_t, ndim=1] rr,
@@ -377,16 +386,14 @@ cdef class IDA_WrapPrecSetupFunction(IDA_PrecSetupFunction):
377386
set a precondititioning setup method as a IDA_PrecSetupFunction
378387
executable class
379388
"""
380-
self.with_userdata = 0
381-
nrarg = _get_num_args(prec_setupfn)
382-
if nrarg > 5:
383-
#hopefully a class method, self gives 6 arg!
384-
self.with_userdata = 1
385-
elif nrarg == 5 and inspect.isfunction(prec_setupfn):
389+
if _get_num_args(prec_setupfn) == 6:
386390
self.with_userdata = 1
391+
else:
392+
self.with_userdata = 0
387393
self._prec_setupfn = prec_setupfn
388394

389-
cpdef int evaluate(self, DTYPE_t t,
395+
cpdef int evaluate(self,
396+
DTYPE_t t,
390397
np.ndarray[DTYPE_t, ndim=1] y,
391398
np.ndarray[DTYPE_t, ndim=1] yp,
392399
np.ndarray[DTYPE_t, ndim=1] rr,
@@ -433,7 +440,8 @@ cdef class IDA_PrecSolveFunction:
433440
recoverable failure, negative for unrecoverable failure (as per CVODE
434441
documentation).
435442
"""
436-
cpdef int evaluate(self, DTYPE_t t,
443+
cpdef int evaluate(self,
444+
DTYPE_t t,
437445
np.ndarray[DTYPE_t, ndim=1] y,
438446
np.ndarray[DTYPE_t, ndim=1] yp,
439447
np.ndarray[DTYPE_t, ndim=1] r,
@@ -460,16 +468,14 @@ cdef class IDA_WrapPrecSolveFunction(IDA_PrecSolveFunction):
460468
set a precondititioning solve method as a IDA_PrecSolveFunction
461469
executable class
462470
"""
463-
self.with_userdata = 0
464-
nrarg = _get_num_args(prec_solvefn)
465-
if nrarg > 9:
466-
#hopefully a class method, self gives 10 arg!
467-
self.with_userdata = 1
468-
elif nrarg == 9 and inspect.isfunction(prec_solvefn):
471+
if _get_num_args(prec_solvefn) == 9:
469472
self.with_userdata = 1
473+
else:
474+
self.with_userdata = 0
470475
self._prec_solvefn = prec_solvefn
471476

472-
cpdef int evaluate(self, DTYPE_t t,
477+
cpdef int evaluate(self,
478+
DTYPE_t t,
473479
np.ndarray[DTYPE_t, ndim=1] y,
474480
np.ndarray[DTYPE_t, ndim=1] yp,
475481
np.ndarray[DTYPE_t, ndim=1] r,
@@ -567,13 +573,10 @@ cdef class IDA_WrapJacTimesVecFunction(IDA_JacTimesVecFunction):
567573
set a jacobian-times-vector method as a IDA_JacTimesVecFunction
568574
executable class
569575
"""
570-
self.with_userdata = 0
571-
nrarg = _get_num_args(jac_times_vecfn)
572-
if nrarg > 8:
573-
#hopefully a class method, self gives 9 arg!
574-
self.with_userdata = 1
575-
elif nrarg == 8 and inspect.isfunction(jac_times_vecfn):
576+
if _get_num_args(jac_times_vecfn) == 8:
576577
self.with_userdata = 1
578+
else:
579+
self.with_userdata = 0
577580
self._jac_times_vecfn = jac_times_vecfn
578581

579582
cpdef int evaluate(self,
@@ -655,13 +658,10 @@ cdef class IDA_WrapJacTimesSetupFunction(IDA_JacTimesSetupFunction):
655658
set a jacobian-times-vector method setup as a IDA_JacTimesSetupFunction
656659
executable class
657660
"""
658-
self.with_userdata = 0
659-
nrarg = _get_num_args(jac_times_setupfn)
660-
if nrarg > 6:
661-
#hopefully a class method, self gives 7 arg!
662-
self.with_userdata = 1
663-
elif nrarg == 6 and inspect.isfunction(jac_times_setupfn):
661+
if _get_num_args(jac_times_setupfn) == 6:
664662
self.with_userdata = 1
663+
else:
664+
self.with_userdata = 0
665665
self._jac_times_setupfn = jac_times_setupfn
666666

667667
cpdef int evaluate(self,
@@ -734,10 +734,10 @@ cdef class IDA_WrapErrHandler(IDA_ErrHandler):
734734
"""
735735
set some (c/p)ython function as the error handler
736736
"""
737-
nrarg = _get_num_args(err_handler)
738-
self.with_userdata = (nrarg > 5) or (
739-
nrarg == 5 and inspect.isfunction(err_handler)
740-
)
737+
if _get_num_args(err_handler) == 5:
738+
self.with_userdata = 1
739+
else:
740+
self.with_userdata = 0
741741
self._err_handler = err_handler
742742

743743
cpdef evaluate(self,

scikits/odes/sundials/idas.pyx

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# cython: embedsignature=True
22
from cpython.exc cimport PyErr_CheckSignals
33
from collections import namedtuple
4-
from enum import IntEnum
4+
try:
5+
from enum import IntEnum
6+
except ImportError:
7+
from enum34 import IntEnum
58
import inspect
69
from warnings import warn
710

0 commit comments

Comments
 (0)