Skip to content

Commit

Permalink
Use 1D five point sencil for derivative approx (#442)
Browse files Browse the repository at this point in the history
  • Loading branch information
rparini authored Dec 24, 2024
1 parent 68aaf92 commit 2737be6
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### Unreleased

- Changes derivative approximation to use 1D five-point sencil
- Change project to use pyproject.toml
- Set __version__ using setuptools-scm

Expand Down
6 changes: 3 additions & 3 deletions cxroots/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from .types import AnalyticFunc, ComplexScalarOrArray, ScalarOrArray


def central_diff(
def approx_deriv(
f: AnalyticFunc,
) -> AnalyticFunc:
h = 1e-6
h = 1e-5

@overload
def df(
Expand All @@ -20,6 +20,6 @@ def df(
def df(z: complex | float) -> complex: ...

def df(z: ScalarOrArray) -> ComplexScalarOrArray:
return (f(z + h) - f(z - h)) / (2 * h)
return (-f(z + 2 * h) + 8 * f(z + h) - 8 * f(z - h) + f(z - 2 * h)) / (12 * h)

return df
4 changes: 2 additions & 2 deletions cxroots/root_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy.typing as npt

from .contour_interface import ContourABC
from .derivative import central_diff
from .derivative import approx_deriv
from .types import AnalyticFunc, ComplexScalarOrArray, IntegrationMethod, ScalarOrArray

RombCallback = Callable[[complex, float | None, int], bool | None]
Expand Down Expand Up @@ -160,7 +160,7 @@ def _quad_prod(
rel_tol: float = 1.49e-08,
) -> complex:
if df is None:
df = central_diff(f)
df = approx_deriv(f)

def one(z: ScalarOrArray) -> int:
return 1
Expand Down
1 change: 0 additions & 1 deletion cxroots/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
[3, 3.0001, 3.0002, 8, 8.0002, 8 + 0.0001j],
[1, 1, 1, 1, 1, 1],
id="cluster_10^-4",
marks=pytest.mark.slow,
),
pytest.param(
[3, 3.00001, 3.00002, 8, 8.00002, 8 + 0.00001j],
Expand Down
6 changes: 3 additions & 3 deletions cxroots/tests/test_deriv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import pytest
from numpy import cos, sin

from cxroots.derivative import central_diff
from cxroots.derivative import approx_deriv


def test_central_diff():
def test_approx_deriv():
def f(z):
return z**10 - 2 * z**5 + sin(z) * cos(z / 2)

Expand All @@ -14,6 +14,6 @@ def df(z):

z = np.array([-1.234, 0.3 + 1j, 0.1j, -0.9 - 0.5j])

approx_df = central_diff(f)
approx_df = approx_deriv(f)

assert approx_df(z) == pytest.approx(df(z), abs=1e-8)
4 changes: 2 additions & 2 deletions docs/_modules/cxroots/root_counting.html
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ <h1>Source code for cxroots.root_counting</h1><div class="highlight"><pre>
<span class="kn">import</span> <span class="nn">numpy.typing</span> <span class="k">as</span> <span class="nn">npt</span>

<span class="kn">from</span> <span class="nn">.contour_interface</span> <span class="kn">import</span> <span class="n">ContourABC</span>
<span class="kn">from</span> <span class="nn">.derivative</span> <span class="kn">import</span> <span class="n">central_diff</span>
<span class="kn">from</span> <span class="nn">.derivative</span> <span class="kn">import</span> <span class="n">approx_deriv</span>
<span class="kn">from</span> <span class="nn">.types</span> <span class="kn">import</span> <span class="n">AnalyticFunc</span><span class="p">,</span> <span class="n">ComplexScalarOrArray</span><span class="p">,</span> <span class="n">IntegrationMethod</span><span class="p">,</span> <span class="n">ScalarOrArray</span>

<span class="n">RombCallback</span> <span class="o">=</span> <span class="n">Callable</span><span class="p">[[</span><span class="nb">complex</span><span class="p">,</span> <span class="nb">float</span> <span class="o">|</span> <span class="kc">None</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span> <span class="nb">bool</span> <span class="o">|</span> <span class="kc">None</span><span class="p">]</span>
Expand Down Expand Up @@ -263,7 +263,7 @@ <h1>Source code for cxroots.root_counting</h1><div class="highlight"><pre>
<span class="n">rel_tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.49e-08</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">complex</span><span class="p">:</span>
<span class="k">if</span> <span class="n">df</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">central_diff</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">approx_deriv</span><span class="p">(</span><span class="n">f</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">one</span><span class="p">(</span><span class="n">z</span><span class="p">:</span> <span class="n">ScalarOrArray</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="k">return</span> <span class="mi">1</span>
Expand Down

0 comments on commit 2737be6

Please sign in to comment.