Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature : homogenize imports of optional deps #440

Merged
merged 2 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions pylops/basicoperators/spread.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

try:
jit_message = deps.numba_import("the spread module")

if jit_message is None:
from numba import jit

from ._spread_numba import (
Expand All @@ -18,12 +21,6 @@
_rmatvec_numba_onthefly,
_rmatvec_numba_table,
)
except ModuleNotFoundError:
jit = None
jit_message = "Numba not available, reverting to numpy."
except Exception as e:
jit = None
jit_message = "Failed to import numba (error:%s), use numpy." % e

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -183,10 +180,10 @@ def __init__(

if engine not in ["numpy", "numba"]:
raise KeyError("engine must be numpy or numba")
if engine == "numba" and jit is not None:
if engine == "numba" and jit_message is None:
self.engine = "numba"
else:
if engine == "numba" and jit is None:
if engine == "numba" and jit is not None:
logging.warning(jit_message)
self.engine = "numpy"

Expand Down
13 changes: 5 additions & 8 deletions pylops/optimization/cls_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,15 @@
normal_equations_inversion,
regularized_inversion,
)
from pylops.utils import deps
from pylops.utils.backend import get_array_module, get_module_name
from pylops.utils.decorators import disable_ndarray_multiplication
from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike

try:
spgl1_message = deps.spgl1_import("the spgl1 solver")

if spgl1_message is None:
from spgl1 import spgl1 as ext_spgl1
except ModuleNotFoundError:
ext_spgl1 = None
spgl1_message = "Spgl1 not installed. " 'Run "pip install spgl1".'
except Exception as e:
ext_spgl1 = None
spgl1_message = f"Failed to import spgl1 (error:{e})."


def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
Expand Down Expand Up @@ -1763,7 +1760,7 @@ def setup(
Display setup log

"""
if ext_spgl1 is None:
if spgl1_message is not None:
raise ModuleNotFoundError(spgl1_message)

self.y = y
Expand Down
16 changes: 4 additions & 12 deletions pylops/signalprocessing/chirpradon3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,18 @@
import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray

from ._chirpradon3d import _chirp_radon_3d

try:
pyfftw_message = deps.pyfftw_import("the chirpradon3d module")

if pyfftw_message is None:
import pyfftw

from ._chirpradon3d import _chirp_radon_3d_fftw
except ModuleNotFoundError:
pyfftw = None
pyfftw_message = (
"Pyfftw not installed, use numpy or run "
'"pip install pyFFTW" or '
'"conda install -c conda-forge pyfftw".'
)
except Exception as e:
pyfftw = None
pyfftw_message = f"Failed to import pyfftw (error:{e}), use numpy."


logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down
17 changes: 5 additions & 12 deletions pylops/signalprocessing/dwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@

from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

try:
pywt_message = deps.pywt_import("the dwt module")

if pywt_message is None:
import pywt
except ModuleNotFoundError:
pywt = None
pywt_message = (
"Pywt package not installed. "
'Run "pip install PyWavelets" or '
'conda install pywavelets".'
)
except Exception as e:
pywt = None
pywt_message = f"Failed to import pywt (error:{e})."

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -113,7 +106,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "D",
) -> None:
if pywt is None:
if pywt_message is not None:
raise ModuleNotFoundError(pywt_message)
_checkwavelet(wavelet)

Expand Down
17 changes: 5 additions & 12 deletions pylops/signalprocessing/dwt2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,15 @@

from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

from .dwt import _adjointwavelet, _checkwavelet

try:
pywt_message = deps.pywt_import("the dwt2d module")

if pywt_message is None:
import pywt
except ModuleNotFoundError:
pywt = None
pywt_message = (
"Pywt package not installed. "
'Run "pip install PyWavelets" or '
'conda install pywavelets".'
)
except Exception as e:
pywt = None
pywt_message = f"Failed to import pywt (error:{e})."

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -90,7 +83,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "D",
) -> None:
if pywt is None:
if pywt_message is not None:
raise ModuleNotFoundError(pywt_message)
_checkwavelet(wavelet)

Expand Down
21 changes: 7 additions & 14 deletions pylops/signalprocessing/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,14 @@

from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms
from pylops.utils import deps
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

try:
pyfftw_message = deps.pyfftw_import("the fft module")

if pyfftw_message is None:
import pyfftw
except ModuleNotFoundError:
pyfftw = None
pyfftw_message = (
"Pyfftw not installed, use numpy or run "
'"pip install pyFFTW" or '
'"conda install -c conda-forge pyfftw".'
)
except Exception as e:
pyfftw = None
pyfftw_message = f"Failed to import pyfftw (error:{e}), use numpy."

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -544,7 +537,7 @@ def FFT(
signals.

"""
if engine == "fftw" and pyfftw is not None:
if engine == "fftw" and pyfftw_message is None:
f = _FFT_fftw(
dims,
axis=axis,
Expand All @@ -557,8 +550,8 @@ def FFT(
dtype=dtype,
**kwargs_fftw,
)
elif engine == "numpy" or (engine == "fftw" and pyfftw is None):
if engine == "fftw" and pyfftw is None:
elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None):
if engine == "fftw" and pyfftw_message is not None:
logging.warning(pyfftw_message)
f = _FFT_numpy(
dims,
Expand Down
9 changes: 5 additions & 4 deletions pylops/signalprocessing/radon2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import numpy as np

from pylops.basicoperators import Spread
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, NDArray

try:
jit_message = deps.numba_import("the radon2d module")

if jit_message is None:
from numba import jit

from ._radon2d_numba import (
Expand All @@ -18,8 +21,6 @@
_linear_numba,
_parabolic_numba,
)
except ModuleNotFoundError:
jit = None

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -246,7 +247,7 @@ def Radon2D(
# engine
if engine not in ["numpy", "numba"]:
raise KeyError("engine must be numpy or numba")
if engine == "numba" and jit is None:
if engine == "numba" and jit_message is not None:
engine = "numpy"
# axes
nt, nh, npx = taxis.size, haxis.size, pxaxis.size
Expand Down
9 changes: 5 additions & 4 deletions pylops/signalprocessing/radon3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import numpy as np

from pylops.basicoperators import Spread
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, NDArray

try:
jit_message = deps.numba_import("the radon3d module")

if jit_message is None:
from numba import jit

from ._radon3d_numba import (
Expand All @@ -18,8 +21,6 @@
_linear_numba,
_parabolic_numba,
)
except ModuleNotFoundError:
jit = None

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)

Expand Down Expand Up @@ -270,7 +271,7 @@ def Radon3D(
# engine
if engine not in ["numpy", "numba"]:
raise KeyError("engine must be numpy or numba")
if engine == "numba" and jit is None:
if engine == "numba" and jit_message is not None:
engine = "numpy"

# axes
Expand Down
Loading