Skip to content

Commit

Permalink
Refactor units testing to use pytest
Browse files Browse the repository at this point in the history
Refactor _Unit class to use dunders better
  • Loading branch information
oerc0122 committed Feb 26, 2025
1 parent 6b6b3b5 commit 8fd4226
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 237 deletions.
144 changes: 74 additions & 70 deletions MDANSE/Src/MDANSE/Framework/Units.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import copy
import math
import numbers

from collections import defaultdict
import json

from MDANSE.Core.Platform import PLATFORM
Expand Down Expand Up @@ -364,6 +364,76 @@ def __rsub__(self, other):

return other.__sub__(self)

def __ceil__(self):
"""Ceil of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').ceiling())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').ceiling())
10.0 km / h
>>> print(measure(50.3, 'km/h').ceiling())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = math.ceil(r.toval(r._ounit))
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = math.ceil(r._factor)
return r

def __floor__(self):
"""Floor of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').floor())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').floor())
10.0 km / h
>>> print(measure(50.3, 'km/h').floor())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = math.floor(r.toval(r._ounit))
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = math.floor(r._factor)
return r

def __round__(self, ndigits=None):
"""Round of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').round())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').round())
11.0 km / h
>>> print(measure(50.3, 'km/h').round())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = round(r.toval(r._ounit), ndigits)
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = round(r._factor, ndigits)
return r

ceil = __ceil__
floor = __floor__
round = __round__

def __sub__(self, other):
"""Substract _Unit instances. To be substracted, the units has to be analog or equivalent.
Expand Down Expand Up @@ -464,28 +534,6 @@ def _mult_by(self, other):
self._ounit = None
self._out_factor = None

def ceil(self):
"""Ceil of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').ceiling())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').ceiling())
10.0 km / h
>>> print(measure(50.3, 'km/h').ceiling())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = math.ceil(r.toval(r._ounit))
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = math.ceil(r._factor)
return r

@property
def dimension(self):
"""Getter for _dimension attribute. Returns a copy."""
Expand All @@ -510,28 +558,6 @@ def factor(self):

return self._factor

def floor(self):
"""Floor of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').floor())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').floor())
10.0 km / h
>>> print(measure(50.3, 'km/h').floor())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = math.floor(r.toval(r._ounit))
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = math.floor(r._factor)
return r

@property
def format(self):
"""Getter for the output format."""
Expand Down Expand Up @@ -595,28 +621,6 @@ def ounit(self, ounit):
else:
raise UnitError("The units are not compatible")

def round(self):
"""Round of a _Unit value in canonical units.
>>> print(measure(10.2, 'm/s').round())
10.0000 m / s
>>> print(measure(3.6, 'm/s').ounit('km/h').round())
11.0 km / h
>>> print(measure(50.3, 'km/h').round())
50.0 km / h
"""

r = copy.deepcopy(self)

if r._ounit is not None:
val = round(r.toval(r._ounit))
newu = _Unit("au", val)
newu *= _str_to_unit(r._ounit)
return newu.ounit(r._ounit)
else:
r._factor = round(r._factor)
return r

def sqrt(self):
"""Square root of a _Unit.
Expand Down Expand Up @@ -731,12 +735,12 @@ def units(self, units):
UnitsManager._UNITS = units


_EQUIVALENCES = {}
_EQUIVALENCES = defaultdict(dict)


def add_equivalence(dim1, dim2, factor):
_EQUIVALENCES.setdefault(dim1, {}).__setitem__(dim2, factor)
_EQUIVALENCES.setdefault(dim2, {}).__setitem__(dim1, 1.0 / factor)
_EQUIVALENCES[dim1][dim2] = factor
_EQUIVALENCES[dim2][dim1] = 1.0 / factor


def measure(val, iunit="au", ounit="", equivalent=False):
Expand Down
Loading

0 comments on commit 8fd4226

Please sign in to comment.