Skip to content

Commit

Permalink
Add ReconcileAndFilterFlows processor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665974395
  • Loading branch information
timblakely authored and copybara-github committed Sep 11, 2024
1 parent 43c6de2 commit df6a1ae
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 9 deletions.
193 changes: 193 additions & 0 deletions connectomics/common/tuples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type-safe tuple and NamedTuple utilities."""

import dataclasses
from typing import Any, Callable, Generic, NamedTuple, Type, TypeVar
import dataclasses_json

T = TypeVar('T', int, float)
C = TypeVar('C')


class XYZ(Generic[T], NamedTuple):
"""XYZ is a named tuple for a 3-dimensional vector.
Allows static type checker to differentiate between XYZ and ZYX, and allows
switching between the two via named properties.
"""

x: T
y: T
z: T

def __eq__(self, other):
if not (isinstance(other, XYZ) or isinstance(other, ZYX)):
return False
return self.x == other.x and self.y == other.y and self.z == other.z

@property
def xyz(self) -> 'XYZ[T]':
return self

# Allow swizzling into ZYX format.
@property
def zyx(self) -> 'ZYX[T]':
return ZYX(*self[::-1])


class ZYX(Generic[T], NamedTuple):
"""ZYX is a named tuple for a 3-dimensional vector.
Allows static type checker to differentiate between XYZ and ZYX, and allows
switching between the two via named properties.
"""

z: T
y: T
x: T

# Allow swizzling into XYZ format.
@property
def xyz(self) -> 'XYZ[T]':
return XYZ(*self[::-1])

@property
def zyx(self) -> 'ZYX[T]':
return self

def __eq__(self, other):
if not (isinstance(other, XYZ) or isinstance(other, ZYX)):
return False
return self.x == other.x and self.y == other.y and self.z == other.z


class XYZC(Generic[T], NamedTuple):
"""XYZC is a named tuple for a 4-dimensional vector."""

x: T
y: T
z: T
c: T

def __eq__(self, other):
if not (isinstance(other, XYZC) or isinstance(other, CZYX)):
return False
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)

@property
def xyz(self) -> 'XYZ[T]':
return XYZ(self.x, self.y, self.z)

@property
def zyx(self) -> 'ZYX[T]':
return ZYX(self.z, self.y, self.x)

@property
def xyzc(self) -> 'XYZC[T]':
return self

@property
def czyx(self) -> 'CZYX[T]':
return CZYX(*self[::-1])


class CZYX(Generic[T], NamedTuple):
"""CZYX is a named tuple for a 4-dimensional vector."""

c: T
z: T
y: T
x: T

# Allow swizzling into XYZ format.
@property
def xyz(self) -> 'XYZ[T]':
return XYZ(self.x, self.y, self.z)

@property
def zyx(self) -> 'ZYX[T]':
return ZYX(self.z, self.y, self.x)

@property
def xyzc(self) -> 'XYZC[T]':
return XYZC(*self[::-1])

@property
def czyx(self) -> 'CZYX[T]':
return self

def __eq__(self, other):
if not (isinstance(other, XYZC) or isinstance(other, CZYX)):
return False
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)


def named_tuple_field(
cls: C,
encoder: Callable[..., Any] | None = None,
decoder: Callable[..., C] | Type[C] | None = tuple,
):
"""Add metadata to allow NamedTuple decoding in dataclasses.
Example usage:
@dataclass
class Foo:
location: XYZ[float] = named_tuple_field(XYZ)
dest_voxel: CZYX[float] = named_tuple_field(CZYX)
Args:
cls: The NamedTuple class to use.
encoder: The encoder to use for the NamedTuple.
decoder: The decoder to use for the NamedTuple.
Returns:
A dataclass field that will decode to the given NamedTuple.
"""
return dataclasses.field(
metadata={
'named_tuple_type': cls,
**dataclasses_json.config(encoder=encoder, decoder=decoder),
}
)


@dataclasses.dataclass(frozen=True)
class DataclassWithNamedTuples:
"""Parent class that allows dataclasses to have NamedTuple members.
Subclass to allow dataclasses to accept generic constructor arguments,
ensuring runtime NamedTuples.
"""

def __post_init__(self):
for field in dataclasses.fields(self):
named_tuple_type = field.metadata.get('named_tuple_type', None)
if not named_tuple_type:
continue
# Use object.__setattr__, since setattr won't work on frozen dataclasses.
object.__setattr__(
self, field.name, named_tuple_type(*getattr(self, field.name))
)
91 changes: 91 additions & 0 deletions connectomics/common/tuples_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tuples."""

from absl.testing import absltest
from connectomics.common import tuples


class NamedTupleTest(absltest.TestCase):

def test_xyz_zyx(self):
x, y, z = [1, 2, 3]
xyz = tuples.XYZ(x, y, z)
zyx = tuples.ZYX(z, y, x)

for tup in [xyz, zyx]:
self.assertEqual(tup.x, x)
self.assertEqual(tup.y, y)
self.assertEqual(tup.z, z)

self.assertEqual(xyz, zyx)

self.assertEqual(xyz.xyz, xyz)
self.assertEqual(xyz.zyx, zyx)
self.assertEqual(xyz.zyx.xyz, xyz)

self.assertEqual(zyx.zyx, zyx)
self.assertEqual(zyx.xyz, xyz)
self.assertEqual(zyx.xyz.zyx, zyx)

self.assertEqual(xyz[0], x)
self.assertEqual(xyz[1], y)
self.assertEqual(xyz[2], z)

self.assertEqual(zyx[0], z)
self.assertEqual(zyx[1], y)
self.assertEqual(zyx[2], x)

def test_xyzc_czyx(self):
x, y, z, c = [1, 2, 3, 4]
xyz = tuples.XYZ(x, y, z)
zyx = tuples.ZYX(z, y, x)
xyzc = tuples.XYZC(x, y, z, c)
czyx = tuples.CZYX(c, z, y, x)

for tup in [xyzc, czyx]:
self.assertEqual(tup.x, x)
self.assertEqual(tup.y, y)
self.assertEqual(tup.z, z)
self.assertEqual(tup.c, c)

self.assertEqual(xyzc, czyx)

self.assertEqual(xyzc.xyz, xyz)
self.assertEqual(xyzc.zyx, zyx)
self.assertEqual(xyzc.xyzc, xyzc)
self.assertEqual(xyzc.czyx, czyx)

self.assertEqual(czyx, xyzc)

self.assertEqual(czyx.xyz, xyz)
self.assertEqual(czyx.zyx, zyx)
self.assertEqual(czyx.czyx, czyx)
self.assertEqual(czyx.xyzc, xyzc)

self.assertEqual(xyzc[0], x)
self.assertEqual(xyzc[1], y)
self.assertEqual(xyzc[2], z)
self.assertEqual(xyzc[3], c)

self.assertEqual(czyx[0], c)
self.assertEqual(czyx[1], z)
self.assertEqual(czyx[2], y)
self.assertEqual(czyx[3], x)


if __name__ == '__main__':
absltest.main()
25 changes: 16 additions & 9 deletions connectomics/volume/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@

from connectomics.common import bounding_box
from connectomics.common import file
from connectomics.common import tuples
from connectomics.volume import decorators
import dataclasses_json
import numpy as np
import numpy.typing as npt


XYZ = tuples.XYZ


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class VolumeMetadata:
class VolumeMetadata(tuples.DataclassWithNamedTuples):
"""Metadata associated with a Volume.
Attributes:
Expand All @@ -38,9 +42,12 @@ class VolumeMetadata:
num_channels: Number of channels in the volume.
dtype: Datatype of the volume. Must be numpy compatible.
"""
volume_size: tuple[int, int, int]
pixel_size: tuple[float, float, float]
bounding_boxes: list[bounding_box.BoundingBox]

volume_size: XYZ[int] = tuples.named_tuple_field(XYZ)
pixel_size: XYZ[float] = tuples.named_tuple_field(XYZ)
bounding_boxes: list[bounding_box.BoundingBox] = dataclasses.field(
default_factory=list
)
num_channels: int = 1
dtype: npt.DTypeLike = dataclasses.field(
metadata=dataclasses_json.config(
Expand All @@ -50,17 +57,16 @@ class VolumeMetadata:
default=np.uint8,
)

def scale(
self, scale_factors: float | Sequence[float]
) -> 'VolumeMetadata':
def scale(self, scale_factors: float | Sequence[float]) -> 'VolumeMetadata':
"""Scales the volume metadata by the given scale factors.
`scale_factors` must be a single float that will be applied multiplicatively
to the volume size and pixel size, or a 3-element sequence of floats that
will be applied to XYZ dimensions respectively.
Args:
scale_factors: The scale factors to apply.
Returns:
A new VolumeMetadata with the scaled values.
"""
Expand Down Expand Up @@ -91,6 +97,7 @@ class Volume:
path: The path to the volume.
meta: The volume metadata.
"""

path: pathlib.Path
meta: VolumeMetadata

Expand Down
3 changes: 3 additions & 0 deletions connectomics/volume/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def test_volume_metadata(self):
bounding_boxes=[BBOX([10, 10, 10], [100, 100, 100])],
)

self.assertIsInstance(meta.volume_size, metadata.XYZ)
self.assertIsInstance(meta.pixel_size, metadata.XYZ)

# No scale
scaled = meta.scale([1, 1, 1])
self.assertCountEqual(scaled.volume_size, [100, 100, 100])
Expand Down

0 comments on commit df6a1ae

Please sign in to comment.