Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add call tracing module #569

Merged
merged 60 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
307b862
Add decorator for call tracing
bwohlberg May 10, 2024
f2231d5
Add function for applying a decorator to all functions in a module
bwohlberg May 10, 2024
cec4a26
Add some comments
bwohlberg May 10, 2024
107138e
Add function to enable tracing for most scico functions
bwohlberg May 10, 2024
7554bfa
Improve trace call detail
bwohlberg May 13, 2024
0774abc
Merge branch 'main' into brendt/trace
bwohlberg May 13, 2024
005f821
Suppress typing errors
bwohlberg May 13, 2024
b376aac
Rectify oversight
bwohlberg May 13, 2024
7465042
No apparent reason for method to be dynamic here
bwohlberg May 13, 2024
1137a93
Some improvements
bwohlberg May 13, 2024
5213c93
Merge branch 'main' into brendt/trace
bwohlberg May 14, 2024
9cd8adb
Merge branch 'main' into brendt/trace
bwohlberg May 16, 2024
6c67a41
Merge branch 'main' into brendt/trace
bwohlberg May 16, 2024
922c64e
Add missing copyright statement
bwohlberg May 16, 2024
e17ce59
Fix incorrect jit of method
bwohlberg May 17, 2024
f622c83
Add missing type annotations
bwohlberg May 17, 2024
b1f3dfd
Merge branch 'main' into brendt/trace
bwohlberg May 22, 2024
236f981
Merge branch 'main' into brendt/trace
bwohlberg Jun 3, 2024
970424c
Merge branch 'main' into brendt/trace
bwohlberg Jun 22, 2024
c2365e6
Update submodule
bwohlberg Jul 25, 2024
6d43437
Merge branch 'main' into brendt/trace
bwohlberg Jul 25, 2024
f7eb380
Merge branch 'main' into brendt/trace
bwohlberg Jul 31, 2024
335c73e
Move trace functions to their own module
bwohlberg Aug 1, 2024
48a5556
Add some comments
bwohlberg Aug 1, 2024
6d3226a
Extend colour to all args
bwohlberg Aug 1, 2024
c1e3390
Fix handling of static and class methods
bwohlberg Aug 1, 2024
baf116f
Add trace usage example
bwohlberg Aug 1, 2024
ab5e3eb
Different colour for function name
bwohlberg Aug 1, 2024
519dd1e
Add display of return values
bwohlberg Aug 2, 2024
753bdc1
Improve placement of register_variable calls
bwohlberg Aug 2, 2024
8ecf2d2
Trivial edit
bwohlberg Aug 2, 2024
a12044a
Bug fix
bwohlberg Aug 2, 2024
9efd471
Add comments and improve some variable names
bwohlberg Aug 7, 2024
2d86156
Re-write of apply_decorator function in progress
bwohlberg Aug 13, 2024
70c3485
Improve verbose output
bwohlberg Aug 13, 2024
e247c58
Add docstrings
bwohlberg Aug 13, 2024
fee4460
Suppress mypy complaints
bwohlberg Aug 13, 2024
a5ac2e9
Output format improvement
bwohlberg Aug 13, 2024
1890f71
Clean up
bwohlberg Aug 13, 2024
cec8214
Output format improvement
bwohlberg Aug 13, 2024
0b0ae37
Add additional colour coding
bwohlberg Aug 13, 2024
aeee771
Add colorama dependency
bwohlberg Aug 13, 2024
9baf6a5
Update change log
bwohlberg Aug 13, 2024
3700e96
Exclude trace module from coverage
bwohlberg Aug 13, 2024
1981139
Add docs
bwohlberg Aug 13, 2024
fee1605
Minor edit
bwohlberg Aug 14, 2024
3dd05c6
Merge branch 'main' into brendt/trace
bwohlberg Aug 14, 2024
627ef34
Merge branch 'brendt/trace-alt-ver' into brendt/trace
bwohlberg Aug 14, 2024
500dc37
Add option for displaying jax array device and sharding information
bwohlberg Aug 15, 2024
c594fff
Merge branch 'main' into brendt/trace
bwohlberg Oct 2, 2024
310a1e4
Merge branch 'main' into brendt/trace
bwohlberg Oct 30, 2024
118afea
Merge branch 'main' into brendt/trace
bwohlberg Nov 4, 2024
6f8c99d
Typo fix
bwohlberg Nov 4, 2024
e249fea
Merge branch 'main' into brendt/trace
bwohlberg Nov 4, 2024
56fc918
Update change summary
bwohlberg Nov 5, 2024
dd88cb4
Improve docs
bwohlberg Nov 6, 2024
4feba96
Suppress mypy errors
bwohlberg Nov 6, 2024
931f212
Suppress mypy errors
bwohlberg Nov 6, 2024
fa567c5
Merge branch 'main' into brendt/trace
bwohlberg Nov 8, 2024
cddf9b4
Address PR review comments
bwohlberg Nov 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
[run]
source = scico
command_line = -m pytest
omit =
omit =
scico/test/*
scico/plot.py
scico/trace.py

[report]
# Regexes for lines to exclude from consideration
Expand Down
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SCICO Release Notes
Version 0.0.7 (unreleased)
----------------------------

No changes yet.
New module ``scico.trace`` for tracing function/method calls.



Expand Down
1 change: 1 addition & 0 deletions examples/examples_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
-r ../requirements.txt
colorama
colour_demosaicing
svmbir>=0.4.0
astra-toolbox
Expand Down
23 changes: 13 additions & 10 deletions examples/jnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,17 @@ def py_file_to_string(src):
if re.match("^import|^from .* import", line):
import_seen = True
lines.append(line)
# Backtrack through list of lines to find last import statement
n = 1
for line in lines[-2::-1]:
if re.match("^(import|from)", line):
break
else:
n += 1
# Insert notebook plotting config directly after last import statement
lines.insert(-n, "plot.config_notebook_plotting()\n")

if "plot" in "".join(lines):
# Backtrack through list of lines to find last import statement
n = 1
for line in lines[-2::-1]:
if re.match("^(import|from)", line):
break
else:
n += 1
# Insert notebook plotting config directly after last import statement
lines.insert(-n, "plot.config_notebook_plotting()\n")

# Process remainder of source file
for line in srcfile:
Expand All @@ -73,7 +75,8 @@ def py_file_to_string(src):
n += 1
else:
break
lines = lines[0:-n]
if n > 0:
lines = lines[0:-n]

return "".join(lines)

Expand Down
110 changes: 110 additions & 0 deletions examples/scripts/trace_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
SCICO Call Tracing
bwohlberg marked this conversation as resolved.
Show resolved Hide resolved
==================

This example demonstrates the call tracing functionality provided by the
[trace](../_autosummary/scico.trace.rst) module. It is based on the
[non-negative BPDN example](sparsecode_nn_admm.rst).
"""


import numpy as np

import jax

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, MatrixSubproblemSolver
from scico.trace import register_variable, trace_scico_calls
from scico.util import device_info

"""
Initialize tracing. JIT must be disabled for correct tracing.

The call tracing mechanism prints the name, arguments, and return values
of functions/methods as they are called. Module and class names are
printed in light red, function and method names in dark red, arguments
and return values in light blue, and the names of registered variables
in light yellow. When a method defined in a class is called for an object
of a derived class type, the class of that object is printed in light
magenta, in square brackets. Function names and return values are
distinguished by initial ">>" and "<<" characters respectively.
"""
jax.config.update("jax_disable_jit", True)
trace_scico_calls()


"""
Create random dictionary, reference random sparse representation, and
test signal consisting of the synthesis of the reference sparse
representation.
"""
m = 32 # signal size
n = 128 # dictionary size
s = 10 # sparsity level

np.random.seed(1)
D = np.random.randn(m, n).astype(np.float32)
D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary

xt = np.zeros(n, dtype=np.float32) # true signal
idx = np.random.randint(low=0, high=n, size=s) # support of xt
xt[idx] = np.random.rand(s)
y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal

xt = snp.array(xt) # convert to jax array
y = snp.array(y) # convert to jax array


"""
Register a variable so that it can be referenced by name in the call trace.
Any hashable object and numpy arrays may be registered, but JAX arrays
cannot.
"""
register_variable(D, "D")


"""
Set up the forward operator and ADMM solver object.
"""
lmbda = 1e-1
A = linop.MatrixOperator(D)
register_variable(A, "A")
f = loss.SquaredL2Loss(y=y, A=A)
g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()]
C_list = [linop.Identity((n)), linop.Identity((n))]
rho_list = [1.0, 1.0]
maxiter = 1 # number of ADMM iterations (set to small value to simplify trace output)

register_variable(f, "f")
register_variable(g_list[0], "g_list[0]")
register_variable(g_list[1], "g_list[1]")
register_variable(C_list[0], "C_list[0]")
register_variable(C_list[1], "C_list[1]")

solver = ADMM(
f=f,
g_list=g_list,
C_list=C_list,
rho_list=rho_list,
x0=A.adj(y),
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
maxiter=maxiter,
subproblem_solver=MatrixSubproblemSolver(),
itstat_options={"display": True, "period": 5},
)

register_variable(solver, "solver")


"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
x = solver.solve()
mse = metric.mse(xt, x)
18 changes: 13 additions & 5 deletions scico/optimize/_pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from functools import partial
from typing import Optional, Union

import jax
Expand Down Expand Up @@ -101,15 +102,22 @@ def __init__(
self.L: float = L0 # reciprocal of step size (estimate of Lipschitz constant of ∇f)
self.fixed_point_residual = snp.inf

def x_step(v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:
return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)

self.x_step = jax.jit(x_step)

self.x: Union[Array, BlockArray] = x0 # current estimate of solution

super().__init__(**kwargs)

def x_step(self, v: Union[Array, BlockArray], L: float) -> Union[Array, BlockArray]:
"""Compute update for variable `x`."""
return PGM._x_step(self.f, self.g, v, L)

@staticmethod
@partial(jax.jit, static_argnums=(0, 1))
def _x_step(
f: Functional, g: Functional, v: Union[Array, BlockArray], L: float
) -> Union[Array, BlockArray]:
"""Jit-able static method for computing update for variable `x`."""
return g.prox(v - 1.0 / L * f.grad(v), 1.0 / L)

def _working_vars_finite(self) -> bool:
"""Determine where ``NaN`` of ``Inf`` encountered in solve.

Expand Down
1 change: 1 addition & 0 deletions scico/optimize/pgm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand Down
Loading
Loading