Skip to content

Commit

Permalink
Fix NamedTuple equality for tuple comparison.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673512157
  • Loading branch information
timblakely authored and copybara-github committed Sep 11, 2024
1 parent 8a6dae5 commit 3379255
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
48 changes: 26 additions & 22 deletions connectomics/common/tuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ class XYZ(Generic[T], NamedTuple):
z: T

def __eq__(self, other):
if not (isinstance(other, XYZ) or isinstance(other, ZYX)):
return False
return self.x == other.x and self.y == other.y and self.z == other.z
if isinstance(other, XYZ) or isinstance(other, ZYX):
return self.x == other.x and self.y == other.y and self.z == other.z
# Defer to tuple equality.
return self[:] == other

@property
def xyz(self) -> 'XYZ[T]':
Expand Down Expand Up @@ -69,9 +70,10 @@ def zyx(self) -> 'ZYX[T]':
return self

def __eq__(self, other):
if not (isinstance(other, XYZ) or isinstance(other, ZYX)):
return False
return self.x == other.x and self.y == other.y and self.z == other.z
if isinstance(other, XYZ) or isinstance(other, ZYX):
return self.x == other.x and self.y == other.y and self.z == other.z
# Defer to tuple equality.
return self[:] == other


class XYZC(Generic[T], NamedTuple):
Expand All @@ -83,14 +85,15 @@ class XYZC(Generic[T], NamedTuple):
c: T

def __eq__(self, other):
if not (isinstance(other, XYZC) or isinstance(other, CZYX)):
return False
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)
if isinstance(other, XYZC) or isinstance(other, CZYX):
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)
# Defer to tuple equality.
return self[:] == other

@property
def xyz(self) -> 'XYZ[T]':
Expand Down Expand Up @@ -135,14 +138,15 @@ def czyx(self) -> 'CZYX[T]':
return self

def __eq__(self, other):
if not (isinstance(other, XYZC) or isinstance(other, CZYX)):
return False
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)
if isinstance(other, XYZC) or isinstance(other, CZYX):
return (
self.x == other.x
and self.y == other.y
and self.z == other.z
and self.c == other.c
)
# Defer to tuple equality.
return self[:] == other


def named_tuple_field(
Expand Down
6 changes: 6 additions & 0 deletions connectomics/common/tuples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def test_xyz_zyx(self):
self.assertEqual(zyx[1], y)
self.assertEqual(zyx[2], x)

self.assertEqual(xyz, (x, y, z))
self.assertEqual(zyx, (z, y, x))

def test_xyzc_czyx(self):
x, y, z, c = [1, 2, 3, 4]
xyz = tuples.XYZ(x, y, z)
Expand Down Expand Up @@ -86,6 +89,9 @@ def test_xyzc_czyx(self):
self.assertEqual(czyx[2], y)
self.assertEqual(czyx[3], x)

self.assertEqual(xyzc, (x, y, z, c))
self.assertEqual(czyx, (c, z, y, x))


if __name__ == '__main__':
absltest.main()

0 comments on commit 3379255

Please sign in to comment.