Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ComposedFunctional class #577

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Version 0.0.7 (unreleased)
----------------------------

• New module ``scico.trace`` for tracing function/method calls.
• New generic functional ``functional.ComposedFunctional`` representing
a functional composed with an orthogonal linear operator.
• Support ``jaxlib`` and ``jax`` versions 0.4.13 to 0.5.0.
• Support ``flax`` versions 0.8.0 to 0.10.2.

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ct_abel_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not
available, and due to the difficulty of supressing these warnings in a
available, and due to the difficulty of suppressing these warnings in a
way that does not force use of the CPU only. To enable GPU usage, comment
out the `os.environ` statements near the beginning of the script, and
change the value of the "gpu" entry in the `resources` dict from 0 to 1.
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/deconv_tv_admm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

This script is hard-coded to run on CPU only to avoid the large number of
warnings that are emitted when GPU resources are requested but not available,
and due to the difficulty of supressing these warnings in a way that does
and due to the difficulty of suppressing these warnings in a way that does
not force use of the CPU only. To enable GPU usage, comment out the
`os.environ` statements near the beginning of the script, and change the
value of the "gpu" entry in the `resources` dict from 0 to 1. Note that
Expand Down
11 changes: 9 additions & 2 deletions scico/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2024 by SCICO Developers
# Copyright (C) 2021-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -10,7 +10,13 @@
import sys

# isort: off
from ._functional import Functional, ScaledFunctional, SeparableFunctional, ZeroFunctional
from ._functional import (
Functional,
ComposedFunctional,
ScaledFunctional,
SeparableFunctional,
ZeroFunctional,
)
from ._norm import (
HuberNorm,
L0Norm,
Expand All @@ -33,6 +39,7 @@
"IsotropicTVNorm",
"TVNorm",
"Functional",
"ComposedFunctional",
"ScaledFunctional",
"SeparableFunctional",
"ZeroFunctional",
Expand Down
81 changes: 80 additions & 1 deletion scico/functional/_functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -16,6 +16,7 @@

import scico
from scico import numpy as snp
from scico.linop import LinearOperator
from scico.numpy import Array, BlockArray


Expand Down Expand Up @@ -201,6 +202,13 @@ def __init__(self, functional_list: List[Functional]):

super().__init__()

def __repr__(self) -> str:
    """Return base functional repr followed by one indented line per component."""
    # Delegates to Functional.__repr__ for the header, then lists each
    # component functional in self.functional_list on its own line.
    return (
        Functional.__repr__(self)
        + "\nComponents:\n"
        + "\n".join([" " + repr(f) for f in self.functional_list])
    )

def __call__(self, x: BlockArray) -> float:
if len(x.shape) == len(self.functional_list):
return snp.sum(snp.array([fi(xi) for fi, xi in zip(self.functional_list, x)]))
Expand Down Expand Up @@ -240,6 +248,77 @@ def prox(self, v: BlockArray, lam: float = 1.0, **kwargs) -> BlockArray:
)


class ComposedFunctional(Functional):
    r"""A functional constructed by composition.

    A functional constructed by composition of a functional with an
    orthogonal linear operator, i.e.

    .. math::
        f(\mb{x}) = g(A \mb{x}) \;,

    where :math:`f` is the composed functional, :math:`g` is the
    functional from which it is composed, and :math:`A` is an orthogonal
    linear operator. Note that the resulting :class:`Functional` can only
    be applied (either via evaluation or :meth:`prox` calls) to inputs
    of shape and dtype corresponding to the input specification of the
    linear operator.
    """

    def __init__(self, functional: Functional, linop: LinearOperator):
        r"""
        Args:
            functional: The functional :math:`g` to be composed.
            linop: The linear operator :math:`A` to be composed. Note
                that it is the user's responsibility to confirm that
                the linear operator is orthogonal. If it is not, the
                result of :meth:`prox` will be incorrect.
        """
        self.functional = functional
        self.linop = linop

        # The composition supports evaluation/prox exactly when the
        # wrapped functional does.
        self.has_eval = functional.has_eval
        self.has_prox = functional.has_prox

        super().__init__()

    def __repr__(self) -> str:
        """Return base repr followed by the reprs of the composed functional and linop."""
        return (
            Functional.__repr__(self)
            + "\nComposition of:\n"
            + self.functional.__repr__()
            + "\n"
            + self.linop.__repr__()
        )

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        # Evaluate f(x) = g(A x).
        return self.functional(self.linop(x))

    def prox(
        self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
    ) -> Union[Array, BlockArray]:
        r"""Evaluate proximal operator of a composed functional.

        Evaluate the proximal operator of :math:`f(\mb{x}) =
        g(A \mb{x})`, where :math:`A` is an orthogonal linear operator,
        via a special case of Theorem 6.15 of :cite:`beck-2017-first`,

        .. math::
            \prox_{\lambda f}(\mb{v}) = A^T \prox_{\lambda g}(A \mb{v}) \;.

        Examples of orthogonal linear operators in SCICO include
        :class:`.linop.Reshape` and :class:`.linop.Transpose`.

        Args:
            v: Input array :math:`\mb{v}`.
            lam: Proximal parameter :math:`\lambda`.
            kwargs: Additional arguments that may be used by derived
                classes.

        Returns:
            Result of evaluating the proximal operator at :math:`\mb{v}`.
        """
        return self.linop.T(self.functional.prox(self.linop(v), lam=lam, **kwargs))


class FunctionalSum(Functional):
r"""A sum of two functionals."""

Expand Down
29 changes: 29 additions & 0 deletions scico/test/functional/test_composed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np

from jax import config

from prox import prox_test

from scico import functional, linop
from scico.random import randn

# enable 64-bit mode for output dtype checks
config.update("jax_enable_x64", True)


class TestComposed:
    """Tests for :class:`functional.ComposedFunctional` with an orthogonal linop."""

    def setup_method(self):
        # Fixed-shape random test input.
        self.shape = (2, 3, 4)
        self.dtype = np.float32
        self.x, key = randn(self.shape, key=None, dtype=self.dtype)
        # Compose an L2 norm with an orthogonal reshape operator.
        reshape_op = linop.Reshape(self.x.shape, (2, -1), input_dtype=self.dtype)
        self.composed = functional.ComposedFunctional(functional.L2Norm(), reshape_op)

    def test_eval(self):
        # Reshape preserves the L2 norm, so the composed evaluation must
        # agree with direct evaluation of the wrapped functional.
        np.testing.assert_allclose(self.composed(self.x), self.composed.functional(self.x))

    def test_prox(self):
        # Generic prox consistency check from the shared test helper.
        prox_test(self.x, self.composed.__call__, self.composed.prox, 1.0)
2 changes: 1 addition & 1 deletion scico/test/functional/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_separable_grad(test_separable_obj):
np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1))
np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb))

# Tests the separable grad with warnings being supressed
# Test the separable grad with warnings being suppressed
fv1 = test_separable_obj.f.grad(test_separable_obj.v1)
gv2 = test_separable_obj.g.grad(test_separable_obj.v2)
fgv = test_separable_obj.fg.grad(test_separable_obj.vb)
Expand Down
1 change: 1 addition & 0 deletions scico/test/functional/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TestCheckAttrs:
functional.Functional,
functional.ScaledFunctional,
functional.SeparableFunctional,
functional.ComposedFunctional,
functional.ProximalAverage,
]
to_check = []
Expand Down
2 changes: 1 addition & 1 deletion scico/test/functional/test_separable.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_separable_grad(test_separable_obj):
np.testing.assert_warns(test_separable_obj.f.grad(test_separable_obj.v1))
np.testing.assert_warns(test_separable_obj.fg.grad(test_separable_obj.vb))

# Tests the separable grad with warnings being supressed
# Tests the separable grad with warnings being suppressed
fv1 = test_separable_obj.f.grad(test_separable_obj.v1)
gv2 = test_separable_obj.g.grad(test_separable_obj.v2)
fgv = test_separable_obj.fg.grad(test_separable_obj.vb)
Expand Down
5 changes: 3 additions & 2 deletions scico/test/test_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from scico.denoiser import DnCNN, bm3d, bm4d, have_bm3d, have_bm4d
from scico.metric import rel_res
from scico.random import randn
from scico.test.osver import osx_ver_geq_than

Expand Down Expand Up @@ -126,14 +127,14 @@ def setup_method(self):
def test_single_channel(self):
no_jit = self.dncnn(self.x_sngchn)
jitted = jax.jit(self.dncnn)(self.x_sngchn)
np.testing.assert_allclose(no_jit, jitted, rtol=1e-3)
assert rel_res(no_jit, jitted) < 1e-6
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

def test_multi_channel(self):
no_jit = self.dncnn(self.x_mltchn)
jitted = jax.jit(self.dncnn)(self.x_mltchn)
np.testing.assert_allclose(no_jit, jitted, rtol=1e-3)
assert rel_res(no_jit, jitted) < 1e-6
assert no_jit.dtype == np.float32
assert jitted.dtype == np.float32

Expand Down
Loading