Skip to content

Commit

Permalink
Merge branch 'main' into cristina/flax-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 30, 2023
2 parents 4a36aef + a041524 commit 2c50a84
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 88 deletions.
2 changes: 2 additions & 0 deletions docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ Contributors
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Li-Ta (Ollie) Lo <https://github.com/ollielo>`_ (ASTRA interface improvements)
- `Renat Sibgatulin <https://github.com/Sibgatulin>`_ (Docs corrections)
- `Salman Naqvi <https://github.com/shnaqvi>`_ (Contributions to approximate TV norm prox and proximal average implementation)
51 changes: 25 additions & 26 deletions examples/scripts/ct_projector_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,47 +119,46 @@
"""
Display timing results.
On our server, the SCICO projection is more than twice as fast as ASTRA
when both are run on the GPU, and about 10% slower when both are run the
CPU. The SCICO back projection is slow the first time it is run, probably
due to JIT overhead. After the first run, it is an order of magnitude
faster than ASTRA when both are run on the GPU, and about three times
faster when both are run on the CPU.
On our server, when using the GPU, the SCICO projector (both forward
and backward) is faster than ASTRA. When using the CPU, it is slower
for forward projection and faster for back projection. The SCICO object
initialization and first back projection are slow due to JIT
overhead.
On our server, using the GPU:
```
init astra 1.36e-03 s
init scico 1.37e+01 s
init astra 4.81e-02 s
init scico 2.53e-01 s
first fwd astra 6.92e-02 s
first fwd scico 2.95e-02 s
first fwd astra 4.44e-02 s
first fwd scico 2.82e-02 s
first back astra 4.20e-02 s
first back scico 7.63e+00 s
first back astra 3.31e-02 s
first back scico 2.80e-01 s
avg fwd astra 4.62e-02 s
avg fwd scico 1.61e-02 s
avg fwd astra 4.76e-02 s
avg fwd scico 2.83e-02 s
avg back astra 3.71e-02 s
avg back scico 1.05e-03 s
avg back astra 3.96e-02 s
avg back scico 6.80e-04 s
```
Using the CPU:
```
init astra 1.06e-03 s
init scico 1.00e+01 s
init astra 1.72e-02 s
init scico 2.88e+00 s
first fwd astra 9.16e-01 s
first fwd scico 1.04e+00 s
first fwd astra 1.02e+00 s
first fwd scico 2.40e+00 s
first back astra 9.39e-01 s
first back scico 1.00e+01 s
first back astra 1.03e+00 s
first back scico 3.53e+00 s
avg fwd astra 9.11e-01 s
avg fwd scico 1.03e+00 s
avg fwd astra 1.03e+00 s
avg fwd scico 2.54e+00 s
avg back astra 9.34e-01 s
avg back scico 2.62e-01 s
avg back astra 1.01e+00 s
avg back scico 5.98e-01 s
```
"""

Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/deconv_tv_padmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@
"""
Set up a proximal ADMM solver object.
"""
ρ = 1.0e-1 # ADMM penalty parameter
ρ = 5.0e-2 # ADMM penalty parameter
maxiter = 50 # number of ADMM iterations
mu, nu = ProximalADMM.estimate_parameters(D)
mu, nu = ProximalADMM.estimate_parameters(A)

solver = ProximalADMM(
f=f,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/denoise_tv_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"""
Denoise with anisotropic total variation for comparison.
"""
# Tune the weight to give the same data fidelty as the isotropic case.
# Tune the weight to give the same data fidelity as the isotropic case.
λ_aniso = 1.2e0
g_aniso = λ_aniso * functional.L1Norm()

Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/denoise_tv_apgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def prox(self, v: Array, lam: float, **kwargs) -> Array:

"""
Use RobustLineSearchStepSize object and set up AcceleratedPGM solver
object. Weight was tuned to give the same data fidelty as the
object. Weight was tuned to give the same data fidelity as the
isotropic case. Run the solver.
"""

Expand Down
18 changes: 12 additions & 6 deletions scico/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
ident: Optional[dict] = None,
display: bool = False,
period: int = 1,
shift_cycles: bool = True,
overwrite: bool = True,
colsep: int = 2,
):
Expand All @@ -48,14 +49,17 @@ def __init__(
fields: A dictionary associating field names with format
strings for displaying the corresponding values.
ident: A dictionary associating field names.
with corresponding valid identifiers for use within the
namedtuple used to record results. Defaults to ``None``.
with corresponding valid identifiers for use within the
namedtuple used to record results. Defaults to ``None``.
display: Flag indicating whether results should be printed
to stdout. Defaults to ``False``.
period: Only display one result in every cycle of length
`period`.
shift_cycles: If ``True``, apply an offset to the iteration
count so that display cycles end at 0, `period` - 1, etc.
Otherwise, cycles end at `period`, 2 * `period`, etc.
overwrite: If ``True``, display all results, but each one
overwrites the next, except for one result per cycle.
overwrites the next, except for one result per cycle.
colsep: Number of spaces seperating fields in displayed
tables. Defaults to 2.
Expand All @@ -69,6 +73,8 @@ def __init__(
raise TypeError("Parameter fields must be an instance of dict.")
# Subsampling rate of results that are to be displayed
self.period: int = period
# Offset to iteration count for determining start of period
self.period_offset = 1 if shift_cycles else 0
# Flag indicating whether to display and overwrite, or not display at all
self.overwrite: bool = overwrite
# Number of spaces seperating fields in displayed tables
Expand Down Expand Up @@ -159,13 +165,13 @@ def insert(self, values: Union[List, Tuple]):
print(self.disphdr)
self.disphdr = None
if self.overwrite:
if (len(self.iterations) - 1) % self.period == 0:
if (len(self.iterations) - self.period_offset) % self.period == 0:
end = "\n"
else:
end = "\r"
print((" " * self.colsep).join(self.fieldformat) % values, end=end)
else:
if (len(self.iterations) - 1) % self.period == 0:
if (len(self.iterations) - self.period_offset) % self.period == 0:
print((" " * self.colsep).join(self.fieldformat) % values)

def end(self):
Expand All @@ -180,7 +186,7 @@ def end(self):
self.display
and self.overwrite
and self.period > 1
and (len(self.iterations) - 1) % self.period
and (len(self.iterations) - self.period_offset) % self.period
):
print()

Expand Down
164 changes: 116 additions & 48 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,43 @@ def __init__(self, projector):
currently the only option is
:class:`Parallel2dProjector`
"""
self.projector = projector
self._eval = projector.project

super().__init__(
input_shape=projector.im_shape,
output_shape=(len(projector.angles), *projector.det_shape),
output_shape=(len(projector.angles), projector.det_count),
)


class Parallel2dProjector:
"""Parallel ray, single axis, 2D X-ray projector."""
"""Parallel ray, single axis, 2D X-ray projector.
This implementation approximates the projection of each rectangular
pixel as a boxcar function (whereas the exact projection is a
trapezoid). Detector pixels are modeled as bins (rather than points)
and this approximation allows fast calculation of the contribution
of each pixel to each bin because the integral of the boxcar is
simple.
By requiring the side length of the pixels to be less than or equal
to the bin width (which is assumed to be 1.0), we ensure that each
pixel contributes to at most two bins, which accelerates the
accumulation of pixel values into bins (equivalently, makes the
linear operator sparse).
`x0`, `dx`, and `y0` should be expressed in units such that the
detector spacing `dy` is 1.0.
"""

def __init__(
self,
im_shape: Shape,
angles: ArrayLike,
x0: Optional[ArrayLike] = None,
dx: Optional[ArrayLike] = None,
y0: Optional[float] = None,
det_count: Optional[int] = None,
dither: bool = True,
):
r"""
Args:
Expand All @@ -66,67 +86,115 @@ def __init__(
angle of 0 corresponds to summing rows, an angle of pi/2
corresponds to summing columns, and an angle of pi/4
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `-im_shape / 2`.
dx: Image pixel side length in x- and y-direction. Should be
<= 1.0 in each dimension. By default, [1.0, 1.0].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
defaults to the size of the diagonal of `im_shape`.
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
self.im_shape = im_shape
self.angles = angles

im_shape = np.array(im_shape)
self.nx = np.array(im_shape)

x0 = -(im_shape - 1) / 2
if x0 is None:
x0 = -self.nx / 2
self.x0 = x0
if dx is None:
dx = np.ones(2)
self.dx = dx

if det_count is None:
det_count = int(np.ceil(np.linalg.norm(im_shape)))
self.det_shape = (det_count,)

y0 = -det_count / 2

@jax.vmap
def compute_inds(angle: float) -> ArrayLike:
"""Project pixel positions on to a detector at the given
angle, determine which detector element they contribute to.
"""
x = jnp.stack(
jnp.meshgrid(
*(
jnp.arange(shape_i) * step_i + start_i
for start_i, step_i, shape_i in zip(x0, [1, 1], im_shape)
),
indexing="ij",
),
axis=-1,
)
self.det_count = det_count
self.ny = det_count

# dither
if dither:
key = jax.random.PRNGKey(0)
x = x + jax.random.uniform(key, shape=x.shape, minval=-0.5, maxval=0.5)
if y0 is None:
y0 = -self.ny / 2
self.y0 = y0
self.dy = 1.0

# project
Px = x[..., 0] * jnp.cos(angle) + x[..., 1] * jnp.sin(angle)
if any(self.dx > self.dy):
raise ValueError(
f"This projector assumes dx <= dy, but dx was {self.dx} and dy was {self.dy}."
)

# quantize
inds = jnp.floor((Px - y0)).astype(int)
def project(self, im):
"""Compute X-ray projection."""
return _project(im, self.x0, self.dx, self.y0, self.ny, self.angles)


@partial(jax.jit, static_argnames=["ny"])
def _project(im, x0, dx, y0, ny, angles):
r"""
Args:
im: Input array, (M, N).
x0: (x, y) position of the corner of the pixel im[0,0].
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
y0: Location of the edge of the first detector bin.
ny: Number of detector bins.
angles: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
"""
nx = im.shape
inds, weights = _calc_weights(x0, dx, nx, angles, y0)
# Handle out of bounds indices. In the .at call, inds >= y0 are
# ignored, while inds < 0 wrap around. So we set inds < 0 to y0.
inds = jnp.where(inds > 0, inds, ny)

# map negative inds to y_size, which is out of bounds and will be ignored
# otherwise they index from the end like x[-1]
inds = jnp.where(inds < 0, det_count, inds)
y = (
jnp.zeros((len(angles), ny))
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
)

return inds
y = y.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds + 1].add(im * (1 - weights))

inds = compute_inds(angles) # (len(angles), *im_shape)
return y

@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
"""Compute the projection at a single angle."""
return jnp.zeros(det_count).at[inds].add(im)

@jax.jit
def project(im: ArrayLike) -> ArrayLike:
"""Compute the projection for all angles."""
return project_inds(im, inds)
@partial(jax.jit, static_argnames=["nx", "y0"])
@partial(jax.vmap, in_axes=(None, None, None, 0, None))
def _calc_weights(x0, dx, nx, angle, y0):
"""
self.project = project
Args:
x0: Location of the corner of the pixel im[0,0].
dx: Pixel side length in x- and y-direction. Units are such
that the detector bins have length 1.0.
nx: Input image shape.
angle: (num_angles,) array of angles in radians. Pixels are
projected onto units vectors pointing in these directions.
(This argument is `vmap`ed.)
y0: Location of the edge of the first detector bin.
"""
u = [jnp.cos(angle), jnp.sin(angle)]
Px0 = x0[0] * u[0] + x0[1] * u[1] - y0
Pdx = [dx[0] * u[0], dx[1] * u[1]]
Pxmin = jnp.min(jnp.array([Px0, Px0 + Pdx[0], Px0 + Pdx[1], Px0 + Pdx[0] + Pdx[1]]))

Px = (
Pxmin
+ Pdx[0] * jnp.arange(nx[0]).reshape(-1, 1)
+ Pdx[1] * jnp.arange(nx[1]).reshape(1, -1)
)

# detector bin inds
inds = jnp.floor(Px).astype(int)

# weights
Pdx = jnp.array(u) * jnp.array(dx)
diag1 = jnp.abs(Pdx[0] + Pdx[1])
diag2 = jnp.abs(Pdx[0] - Pdx[1])
w = jnp.max(jnp.array([diag1, diag2]))
f = jnp.min(jnp.array([diag1, diag2]))

width = (w + f) / 2
distance_to_next = 1 - (Px - inds) # always in (0, 1]
weights = jnp.minimum(distance_to_next, width) / width

return inds, weights
Loading

0 comments on commit 2c50a84

Please sign in to comment.