diff --git a/MDANSE/Src/MDANSE/Framework/Units.py b/MDANSE/Src/MDANSE/Framework/Units.py
index c97e35588..3c09c4b06 100644
--- a/MDANSE/Src/MDANSE/Framework/Units.py
+++ b/MDANSE/Src/MDANSE/Framework/Units.py
@@ -14,10 +14,12 @@
# along with this program. If not, see .
#
import copy
+import json
import math
import numbers
-
-import json
+from collections import defaultdict
+from functools import singledispatchmethod
+from typing import Optional, Tuple
from MDANSE.Core.Platform import PLATFORM
from MDANSE.Core.Singleton import Singleton
@@ -70,159 +72,370 @@ class UnitError(Exception):
pass
-def get_trailing_digits(s):
- for i in range(len(s)):
- if s[i:].isdigit():
- return s[:i], int(s[i:])
+def get_trailing_digits(string: str) -> Tuple[str, int]:
+ """Get digits from the end of a string.
+
+ Always returns ``1`` if no digits.
+
+ Parameters
+ ----------
+ string : str
+ String to parse.
+
+ Returns
+ -------
+ str
+ String with digits stripped.
+ int
+ Digits from end of string as ``int``.
+
+ Examples
+ --------
+ >>> get_trailing_digits("str123")
+ ('str', 123)
+ >>> get_trailing_digits("nodigits")
+ ('nodigits', 1)
+ >>> get_trailing_digits("123preceding")
+ ('123preceding', 1)
+ """
+ for i in range(len(string)):
+ if string[i:].isdigit():
+ return string[:i], int(string[i:])
else:
- return s, 1
+ return string, 1
+
+
+class _Unit:
+ """Unit handler.
+
+ Handles all basic functions of units with correct dimensionality
+ and string printing.
+
+ Parameters
+ ----------
+ uname : str
+ Name of the unit.
+ factor : float
+ Factor relative to internal units.
+
+ Extra Parameters
+ ----------------
+ kg : int
+ Mass dimension.
+ m : int
+ Length dimension.
+ s : int
+ Time dimension.
+ K : int
+ Temperature dimension.
+ mol : int
+ Count dimension.
+ A : int
+ Current dimension.
+ cd : int
+ Luminous intensity dimension.
+ rad : int
+ Angular dimension.
+ sr : int
+ Solid angular dimension.
+ """
+ def __init__(
+ self,
+ uname: str,
+ factor: float,
+ kg: int = 0,
+ m: int = 0,
+ s: int = 0,
+ K: int = 0,
+ mol: int = 0,
+ A: int = 0,
+ cd: int = 0,
+ rad: int = 0,
+ sr: int = 0,
+ ):
+ self._factor = factor
+ self._dimension = [kg, m, s, K, mol, A, cd, rad, sr]
+ self._format = "g"
+ self._uname = uname
+ self._ounit = None
+ self._out_factor = None
+ self._equivalent = False
-def _parse_unit(iunit):
- max_prefix_length = 0
- for p in _PREFIXES:
- max_prefix_length = max(max_prefix_length, len(p))
+ def __add__(self, other):
+ """Add two _Unit instances.
- iunit = iunit.strip()
+ To be added, the units have to be analog or equivalent.
- iunit, upower = get_trailing_digits(iunit)
- if not iunit:
- raise UnitError("Invalid unit")
+ Parameters
+ ----------
+ other : _Unit
+ Unit to add.
- for i in range(len(iunit)):
- if UNITS_MANAGER.has_unit(iunit[i:]):
- prefix = iunit[:i]
- iunit = iunit[i:]
- break
- else:
- raise UnitError(f"The unit {iunit} is unknown")
+ Raises
+ ------
+ UnitError
+ Units are not equivalent or incompatible.
- if prefix:
- if prefix not in _PREFIXES:
- raise UnitError(f"The prefix {prefix} is unknown")
- prefix = _PREFIXES[prefix]
- else:
- prefix = 1.0
+ Examples
+ --------
+ >>> print(measure(10, 'm') + measure(20, 'km'))
+ 20010 m
+ """
+ u = copy.deepcopy(self)
- unit = UNITS_MANAGER.get_unit(iunit)
+ if u.is_analog(other):
+ u._factor += other._factor
+ elif self._equivalent:
+ equivalence_factor = u.get_equivalence_factor(other)
+ if equivalence_factor is None:
+ raise UnitError("The units are not equivalent")
- unit = _Unit(iunit, prefix * unit._factor, *unit._dimension)
+ u._factor += other._factor / equivalence_factor
+ else:
+ raise UnitError("Incompatible units.")
- unit **= upower
+ return u
- return unit
+ def __sub__(self, other):
+ """Subtract _Unit instances.
+ To be subtracted, the units have to be analog or equivalent.
-def _str_to_unit(s):
- if UNITS_MANAGER.has_unit(s):
- unit = UNITS_MANAGER.get_unit(s)
- return copy.deepcopy(unit)
+ >>> print(measure(20, 'km') + measure(10, 'm'))
+ 20.01 km
+ """
+ u = copy.deepcopy(self)
- else:
- unit = _Unit("au", 1.0)
+ if u.is_analog(other):
+ u._factor -= other._factor
+ elif u._equivalent:
+ equivalence_factor = u.get_equivalence_factor(other)
+ if equivalence_factor is None:
+ raise UnitError("The units are not equivalent")
- splitted_units = s.split("/")
+ u._factor -= other._factor / equivalence_factor
+ else:
+ raise UnitError("Incompatible units")
- if len(splitted_units) == 1:
- units = splitted_units[0].split(" ")
- for u in units:
- u = u.strip()
- unit *= _parse_unit(u)
- unit._uname = s
+ return u
- return unit
+ def __truediv__(self, other):
+ """Divide two _Unit instances.
- elif len(splitted_units) == 2:
- numerator = splitted_units[0].strip()
- if numerator != "1":
- numerator = numerator.split(" ")
- for u in numerator:
- u = u.strip()
- unit *= _parse_unit(u)
+ To be divided, the units have to be analog or equivalent.
- denominator = splitted_units[1].strip().split(" ")
- for u in denominator:
- u = u.strip()
- unit /= _parse_unit(u)
+ Parameters
+ ----------
+ other : _Unit
+ Unit to add.
- unit._uname = s
+ Raises
+ ------
+ UnitError
+ Units are not equivalent or incompatible.
- return unit
+ Examples
+ --------
+ >>> print(measure(100, 'V') / measure(10, 'kohm'))
+ 0.01 A1
+ >>> print(measure(100, 'V') / 10)
+ 10 V
+ """
+ u = copy.deepcopy(self)
+ if isinstance(other, numbers.Number):
+ u._factor /= other
+ elif isinstance(other, _Unit):
+ u._div_by(other)
+ else:
+ raise UnitError("Invalid operand")
+
+ return u
+ def __floordiv__(self, other):
+ """Divide two _Unit instances and truncate.
+
+ To be divided, the units have to be analog or equivalent.
+
+ Parameters
+ ----------
+ other : _Unit
+ Unit to add.
+
+ Raises
+ ------
+ UnitError
+ Units are not equivalent or incompatible.
+
+ Examples
+ --------
+ >>> print(measure(10, 'kohm') // measure(10, 'V'))
+ 1000 1 / A1
+ >>> print(measure(15, 'ohm') // 10)
+ 1 ohm
+ """
+ u = copy.deepcopy(self)
+ if isinstance(other, numbers.Number):
+ u._factor //= other
+ elif isinstance(other, _Unit):
+ u._div_by(other)
+ u._factor = math.floor(u._factor)
else:
- raise UnitError(f"Invalid unit: {s}")
+ raise UnitError("Invalid operand")
+ return u
-class _Unit(object):
- def __init__(
- self, uname, factor, kg=0, m=0, s=0, K=0, mol=0, A=0, cd=0, rad=0, sr=0
- ):
- self._factor = factor
+ def __mul__(self, other):
+ """Multiply _Unit instances or scaling factors.
- self._dimension = [kg, m, s, K, mol, A, cd, rad, sr]
+ Examples
+ --------
+ >>> print(measure(10, 'm/s') * measure(10, 's'))
+ 100 m1
+ >>> print(measure(10, 'm') * measure(10, 's'))
+ 100 m1 s1
+ >>> print(measure(10, 'm') * 10)
+ 100 m
+ """
- self._format = "g"
+ u = copy.deepcopy(self)
+ if isinstance(other, numbers.Number):
+ u._factor *= other
+ return u
+ elif isinstance(other, _Unit):
+ u._mult_by(other)
+ return u
+ else:
+ raise UnitError("Invalid operand")
- self._uname = uname
+ def __pow__(self, n: float):
+ """Raise a _Unit to a factor.
- self._ounit = None
+ Examples
+ --------
+ >>> print(measure(10.5, 'm/s')**2)
+ 110.25 m2 / s2
+ """
+ output_unit = copy.copy(self)
+ output_unit._ounit = None
+ output_unit._out_factor = None
+ output_unit._factor = pow(output_unit._factor, n)
+ for i in range(len(output_unit._dimension)):
+ output_unit._dimension[i] *= n
- self._out_factor = None
+ return output_unit
- self._equivalent = False
+ def __float__(self) -> float:
+ """Return the value of a _Unit coerced to float.
- def __add__(self, other):
- """Add _Unit instances. To be added, the units has to be analog or equivalent.
+ Examples
+ --------
+ >>> float(measure(10.5, 'm/s'))
+ 10.5
- >>> print(measure(10, 'm') + measure(20, 'km'))
- 20010 m
+ See Also
+ --------
+ __int__ : Truncate value.
"""
+ return float(self.toval())
- u = copy.deepcopy(self)
+ def __int__(self) -> int:
+ """Return the value of a _Unit coerced to integer.
- if u.is_analog(other):
- u._factor += other._factor
- return u
- elif self._equivalent:
- equivalence_factor = u.get_equivalence_factor(other)
- if equivalence_factor is not None:
- u._factor += other._factor / equivalence_factor
- return u
- else:
- raise UnitError("The units are not equivalent")
+ Notes
+ ------
+ This will happen to the value in the default output unit:
+
+ Examples
+ --------
+ >>> int(measure(10.5, 'm/s'))
+ 10
+ """
+ return int(self.toval())
+
+ def __ceil__(self):
+ """Ceil of a _Unit value in canonical units.
+
+ Examples
+ --------
+ >>> print(measure(10.2, 'm/s').ceiling())
+ 11 m/s
+ >>> print(measure(3.6, 'm/s').ounit('km/h').ceiling())
+ 13 km/h
+ >>> print(measure(50.3, 'km/h').ceiling())
+ 51 km/h
+ """
+
+ r = copy.deepcopy(self)
+
+ if r._ounit is not None:
+ val = math.ceil(r.toval(r._ounit))
+ newu = _Unit("au", val)
+ newu *= _str_to_unit(r._ounit)
+ return newu.ounit(r._ounit)
else:
- raise UnitError("Incompatible units")
+ r._factor = math.ceil(r._factor)
+ return r
- def __truediv__(self, other):
- """Divide _Unit instances.
+ def __floor__(self):
+ """Floor of a _Unit value in canonical units.
- >>> print(measure(100, 'V') / measure(10, 'kohm'))
- 0.0100 A
+ Examples
+ --------
+ >>> print(measure(10.2, 'm/s').floor())
+ 10 m/s
+ >>> print(measure(3.6, 'm/s').ounit('km/h').floor())
+ 12 km/h
+ >>> print(measure(50.3, 'km/h').floor())
+ 50 km/h
"""
- u = copy.deepcopy(self)
- if isinstance(other, numbers.Number):
- u._factor /= other
- return u
- elif isinstance(other, _Unit):
- u._div_by(other)
- return u
+ r = copy.deepcopy(self)
+
+ if r._ounit is not None:
+ val = math.floor(r.toval(r._ounit))
+ newu = _Unit("au", val)
+ newu *= _str_to_unit(r._ounit)
+ return newu.ounit(r._ounit)
else:
- raise UnitError("Invalid operand")
+ r._factor = math.floor(r._factor)
+ return r
- def __float__(self):
- """Return the value of a _Unit coerced to float. See __int__."""
+ def __round__(self, ndigits=None):
+ """Round of a _Unit value in canonical units.
- return float(self.toval())
+ Examples
+ --------
+ >>> print(measure(10.2, 'm/s').round())
+ 10 m/s
+ >>> print(measure(3.6, 'm/s').ounit('km/h').round())
+ 13 km/h
+ >>> print(measure(50.3, 'km/h').round())
+ 50 km/h
+ """
- def __floordiv__(self, other):
- u = copy.deepcopy(self)
- u._div_by(other)
- u._factor = math.floor(u._factor)
- return u
+ r = copy.deepcopy(self)
+
+ if r._ounit is not None:
+ val = round(r.toval(r._ounit), ndigits)
+ newu = _Unit("au", val)
+ newu *= _str_to_unit(r._ounit)
+ return newu.ounit(r._ounit)
+ else:
+ r._factor = round(r._factor, ndigits)
+ return r
+
+ ceiling = __ceil__
+ floor = __floor__
+ round = __round__
def __iadd__(self, other):
- """Add _Unit instances. See __add__."""
+ """Add _Unit instances.
+
+ See Also
+ --------
+ __add__
+ """
if self.is_analog(other):
self._factor += other._factor
@@ -238,7 +451,12 @@ def __iadd__(self, other):
raise UnitError("Incompatible units")
def __itruediv__(self, other):
- """Divide _Unit instances. See __div__."""
+ """Divide _Unit instances.
+
+ See Also
+ --------
+ __div__
+ """
if isinstance(other, numbers.Number):
self._factor /= other
@@ -250,12 +468,24 @@ def __itruediv__(self, other):
raise UnitError("Invalid operand")
def __ifloordiv__(self, other):
+ """Divide _Unit instances and truncate.
+
+ See Also
+ --------
+ __truediv__
+ """
self._div_by(other)
self._factor = math.floor(self._factor)
return self
def __imul__(self, other):
- """Multiply _Unit instances. See __mul__."""
+ """
+ Multiply _Unit instances.
+
+ See Also
+ --------
+ __mul__
+ """
if isinstance(other, numbers.Number):
self._factor *= other
@@ -266,17 +496,6 @@ def __imul__(self, other):
else:
raise UnitError("Invalid operand")
- def __int__(self):
- """Return the value of a _Unit coerced to integer.
-
- Note that this will happen to the value in the default output unit:
-
- >>> print(int(measure(10.5, 'm/s')))
- 10
- """
-
- return int(self.toval())
-
def __ipow__(self, n):
self._factor = pow(self._factor, n)
for i in range(len(self._dimension)):
@@ -288,7 +507,7 @@ def __ipow__(self, n):
return self
def __isub__(self, other):
- """Substract _Unit instances. See __sub__."""
+ """Subtract _Unit instances. See __sub__."""
if self.is_analog(other):
self._factor -= other._factor
@@ -303,36 +522,13 @@ def __isub__(self, other):
else:
raise UnitError("Incompatible units")
- def __mul__(self, other):
- """Multiply _Unit instances.
-
- >>> print(measure(10, 'm/s') * measure(10, 's'))
- 100.0000 m
- """
-
- u = copy.deepcopy(self)
- if isinstance(other, numbers.Number):
- u._factor *= other
- return u
- elif isinstance(other, _Unit):
- u._mult_by(other)
- return u
- else:
- raise UnitError("Invalid operand")
-
- def __pow__(self, n):
- output_unit = copy.copy(self)
- output_unit._ounit = None
- output_unit._out_factor = None
- output_unit._factor = pow(output_unit._factor, n)
- for i in range(len(output_unit._dimension)):
- output_unit._dimension[i] *= n
-
- return output_unit
-
def __radd__(self, other):
- """Add _Unit instances. See __add__."""
+ """Add _Unit instances.
+ See Also
+ --------
+ __add__
+ """
return self.__add__(other)
def __rdiv__(self, other):
@@ -360,32 +556,10 @@ def __rmul__(self, other):
raise UnitError("Invalid operand")
def __rsub__(self, other):
- """Substract _Unit instances. See __sub__."""
+ """Subtract _Unit instances. See __sub__."""
return other.__sub__(self)
- def __sub__(self, other):
- """Substract _Unit instances. To be substracted, the units has to be analog or equivalent.
-
- >>> print(measure(20, 'km') + measure(10, 'm'))
- 19990 m
- """
-
- u = copy.deepcopy(self)
-
- if u.is_analog(other):
- u._factor -= other._factor
- return u
- elif u._equivalent:
- equivalence_factor = u.get_equivalence_factor(other)
- if equivalence_factor is not None:
- u._factor -= other._factor / equivalence_factor
- return u
- else:
- raise UnitError("The units are not equivalent")
- else:
- raise UnitError("Incompatible units")
-
def __str__(self):
unit = copy.copy(self)
@@ -423,7 +597,26 @@ def __str__(self):
return s
- def _div_by(self, other):
+ def _div_by(self, other) -> None:
+ """Compute divided unit including new dimensionality.
+
+ Parameters
+ ----------
+ other : _Unit
+ Factor to divide by.
+
+ Raises
+ ------
+ UnitError
+ If other is not compatible.
+
+ Examples
+ --------
+ >>> a = measure(2., "ang")
+ >>> a._div_by(measure(4., "s"))
+ >>> print(a)
+ 5e-11 m1 / s1
+ """
if self.is_analog(other):
self._factor /= other._factor
self._dimension = [0, 0, 0, 0, 0, 0, 0, 0, 0]
@@ -442,20 +635,39 @@ def _div_by(self, other):
self._ounit = None
self._out_factor = None
- def _mult_by(self, other):
+ def _mult_by(self, other) -> None:
+ """Compute multiplied unit including new dimensionality.
+
+ Parameters
+ ----------
+ other : _Unit
+ Factor to multiply by.
+
+ Raises
+ ------
+ UnitError
+ If other is not compatible.
+
+ Examples
+ --------
+ >>> a = measure(1., "ang")
+ >>> a._mult_by(measure(3., "s"))
+ >>> print(a)
+ 3e-10 m1 s1
+ """
if self.is_analog(other):
self._factor *= other._factor
for i in range(len(self._dimension)):
self._dimension[i] = 2.0 * self._dimension[i]
elif self._equivalent:
equivalence_factor = self.get_equivalence_factor(other)
- if equivalence_factor is not None:
- self._factor *= other._factor / equivalence_factor
- for i in range(len(self._dimension)):
- self._dimension[i] = 2 * self._dimension[i]
- return
- else:
+ if equivalence_factor is None:
raise UnitError("The units are not equivalent")
+
+ self._factor *= other._factor / equivalence_factor
+ for i in range(len(self._dimension)):
+ self._dimension[i] = 2 * self._dimension[i]
+ return
else:
self._factor *= other._factor
for i in range(len(self._dimension)):
@@ -464,28 +676,6 @@ def _mult_by(self, other):
self._ounit = None
self._out_factor = None
- def ceil(self):
- """Ceil of a _Unit value in canonical units.
-
- >>> print(measure(10.2, 'm/s').ceiling())
- 10.0000 m / s
- >>> print(measure(3.6, 'm/s').ounit('km/h').ceiling())
- 10.0 km / h
- >>> print(measure(50.3, 'km/h').ceiling())
- 50.0 km / h
- """
-
- r = copy.deepcopy(self)
-
- if r._ounit is not None:
- val = math.ceil(r.toval(r._ounit))
- newu = _Unit("au", val)
- newu *= _str_to_unit(r._ounit)
- return newu.ounit(r._ounit)
- else:
- r._factor = math.ceil(r._factor)
- return r
-
@property
def dimension(self):
"""Getter for _dimension attribute. Returns a copy."""
@@ -493,15 +683,13 @@ def dimension(self):
return copy.copy(self._dimension)
@property
- def equivalent(self):
+ def equivalent(self) -> bool:
"""Getter for _equivalent attribute."""
return self._equivalent
@equivalent.setter
def equivalent(self, equivalent):
- """Setter for _equivalent attribute."""
-
self._equivalent = equivalent
@property
@@ -510,28 +698,6 @@ def factor(self):
return self._factor
- def floor(self):
- """Floor of a _Unit value in canonical units.
-
- >>> print(measure(10.2, 'm/s').floor())
- 10.0000 m / s
- >>> print(measure(3.6, 'm/s').ounit('km/h').floor())
- 10.0 km / h
- >>> print(measure(50.3, 'km/h').floor())
- 50.0 km / h
- """
-
- r = copy.deepcopy(self)
-
- if r._ounit is not None:
- val = math.floor(r.toval(r._ounit))
- newu = _Unit("au", val)
- newu *= _str_to_unit(r._ounit)
- return newu.ounit(r._ounit)
- else:
- r._factor = math.floor(r._factor)
- return r
-
@property
def format(self):
"""Getter for the output format."""
@@ -540,43 +706,106 @@ def format(self):
@format.setter
def format(self, fmt):
- """Setter for the output format."""
-
self._format = fmt
- def is_analog(self, other):
- """Returns True if the other unit is analog to this unit. Analog units are units whose dimension vector exactly matches."""
+ def is_analog(self, other) -> bool:
+ """Whether two units are analog.
+
+ Analog units are units whose dimension vector exactly matches.
+
+ Parameters
+ ----------
+ other : _Unit
+ Unit to test.
+
+ Returns
+ -------
+ bool
+ Whether two units are "analog".
+
+ Examples
+ --------
+ >>> a, b = measure(1., "km"), measure(1., "ang")
+ >>> a.is_analog(b)
+ True
+ >>> a, b = measure(1., "km"), measure(1., "ohm")
+ >>> a.is_analog(b)
+ False
+ """
return self._dimension == other._dimension
- def get_equivalence_factor(self, other):
- """Returns the equivalence factor if the other unit is equivalent to this unit. Equivalent units are units whose dimension are related through a constant
+ def get_equivalence_factor(self, other) -> Optional[float]:
+ """Returns the equivalence factor if other unit is equivalent.
+
+ Equivalent units are units whose dimension are related through a constant
(e.g. energy and mass, or frequency and temperature).
+
+ Parameters
+ ----------
+ other : _Unit
+ Potentially equivalent unit.
+
+ Returns
+ -------
+ Optional[float]
+ Equivalence factor to transform from one to the other
+ or ``None`` if not equivalent.
+
+ See Also
+ --------
+ _EQUIVALENCES : Dict of equivalent units.
+ add_equivalence : Add new equivalence to dict.
+
+ Examples
+ --------
+ >>> a = measure(1., "1/m")
+ >>> a.get_equivalence_factor(measure(1., "J/mol"))
+ 0.000119627
+ >>> print(a.get_equivalence_factor(measure(1., "ang")))
+ None
"""
_, upower = get_trailing_digits(self._uname)
- dimension = tuple([d / upower for d in self._dimension])
+ dimension = tuple(d / upower for d in self._dimension)
+
if dimension not in _EQUIVALENCES:
return None
powerized_equivalences = {}
- for k, v in list(_EQUIVALENCES[dimension].items()):
- pk = tuple([d * upower for d in k])
+ for k, v in _EQUIVALENCES[dimension].items():
+ pk = tuple(d * upower for d in k)
powerized_equivalences[pk] = pow(v, upower)
odimension = tuple(other._dimension)
if odimension in powerized_equivalences:
return powerized_equivalences[odimension]
- else:
- return None
- def ounit(self, ounit):
+ return None
+
+ def ounit(self, ounit: str):
"""Set the preferred unit for output.
+ Parameters
+ ----------
+ ounit : str
+ Preferred output unit.
+
+ Raises
+ ------
+ UnitError
+ Units are incompatible.
+
+ Notes
+ -----
+ Returns a modified reference to self not a new object.
+
+ Examples
+ --------
>>> a = measure(1, 'kg m2 / s2')
>>> print(a)
- 1.0000 kg m2 / s2
+ 1 kg m2 / s2
>>> print(a.ounit('J'))
- 1.0000 J
+ 1 J
"""
out_factor = _str_to_unit(ounit)
@@ -585,57 +814,56 @@ def ounit(self, ounit):
self._ounit = ounit
self._out_factor = out_factor
return self
+
elif self._equivalent:
- if self.get_equivalence_factor(out_factor) is not None:
- self._ounit = ounit
- self._out_factor = out_factor
- return self
- else:
+ if self.get_equivalence_factor(out_factor) is None:
raise UnitError("The units are not equivalents")
- else:
- raise UnitError("The units are not compatible")
-
- def round(self):
- """Round of a _Unit value in canonical units.
-
- >>> print(measure(10.2, 'm/s').round())
- 10.0000 m / s
- >>> print(measure(3.6, 'm/s').ounit('km/h').round())
- 11.0 km / h
- >>> print(measure(50.3, 'km/h').round())
- 50.0 km / h
- """
- r = copy.deepcopy(self)
+ self._ounit = ounit
+ self._out_factor = out_factor
+ return self
- if r._ounit is not None:
- val = round(r.toval(r._ounit))
- newu = _Unit("au", val)
- newu *= _str_to_unit(r._ounit)
- return newu.ounit(r._ounit)
else:
- r._factor = round(r._factor)
- return r
+ raise UnitError("The units are not compatible")
def sqrt(self):
"""Square root of a _Unit.
+ Returns
+ -------
+ _Unit
+ New unit which is sqrt of original.
+
+ Examples
+ --------
>>> print(measure(4, 'm2/s2').sqrt())
- 2.0000 m / s
+ 2 m1.0 / s-1.0
"""
return self**0.5
- def toval(self, ounit=""):
+ def toval(self, ounit: str = "") -> float:
"""Returns the numeric value of a unit.
The value is given in ounit or in the default output unit.
+ Parameters
+ ----------
+ ounit : str
+ Unit to convert to.
+
+ Returns
+ -------
+ float
+ Value in output unit.
+
+ Examples
+ --------
>>> v = measure(100, 'km/h')
>>> v.toval()
100.0
>>> v.toval(ounit='m/s')
- 27.777777777777779
+ 27.77777777777778
"""
newu = copy.deepcopy(self)
@@ -660,19 +888,119 @@ def toval(self, ounit=""):
return newu._factor
-class UnitsManagerEncoder(json.JSONEncoder):
- def default(self, obj):
- if isinstance(obj, UnitsManager):
- d = {}
- for k, v in obj.units:
- d[k] = {"factor": v.factor, "dimension": v.dimension}
- return d
- elif isinstance(obj, _Unit):
- return {"factor": obj.factor, "dimension": obj.dimension}
- return json.JSONEncoder.default(self, obj)
+def _parse_unit(iunit: str) -> _Unit:
+ """Parse single unit as a string into a Unit type.
+
+ Parameters
+ ----------
+ iunit : str
+ String to parse.
+
+ Returns
+ -------
+ _Unit
+ Expected unit.
+
+ Raises
+ ------
+ UnitError
+ If string does not contain valid unit.
+ """
+
+ max_prefix_length = 0
+ for p in _PREFIXES:
+ max_prefix_length = max(max_prefix_length, len(p))
+
+ iunit = iunit.strip()
+
+ iunit, upower = get_trailing_digits(iunit)
+ if not iunit:
+ raise UnitError("Invalid unit")
+
+ for i in range(len(iunit)):
+ if UNITS_MANAGER.has_unit(iunit[i:]):
+ prefix = iunit[:i]
+ iunit = iunit[i:]
+ break
+ else:
+ raise UnitError(f"The unit {iunit} is unknown")
+
+ if prefix:
+ if prefix not in _PREFIXES:
+ raise UnitError(f"The prefix {prefix} is unknown")
+ prefix = _PREFIXES[prefix]
+ else:
+ prefix = 1.0
+
+ unit = UNITS_MANAGER.get_unit(iunit)
+
+ unit = _Unit(iunit, prefix * unit._factor, *unit._dimension)
+
+ unit **= upower
+
+ return unit
+
+
+def _str_to_unit(s: str) -> _Unit:
+ """Parse general string into unit description.
+
+ Parameters
+ ----------
+ s : str
+ String to parse.
+
+ Returns
+ -------
+ _Unit
+ Parsed unit.
+
+ Raises
+ ------
+ UnitError
+ String is not a valid unit specification.
+ """
+ if UNITS_MANAGER.has_unit(s):
+ unit = UNITS_MANAGER.get_unit(s)
+ return copy.deepcopy(unit)
+
+ else:
+ unit = _Unit("au", 1.0)
+
+ splitted_units = s.split("/")
+
+ if len(splitted_units) == 1:
+ units = splitted_units[0].split(" ")
+ for u in units:
+ u = u.strip()
+ unit *= _parse_unit(u)
+ unit._uname = s
+
+ return unit
+
+ elif len(splitted_units) == 2:
+ numerator = splitted_units[0].strip()
+ if numerator != "1":
+ numerator = numerator.split(" ")
+ for u in numerator:
+ u = u.strip()
+ unit *= _parse_unit(u)
+
+ denominator = splitted_units[1].strip().split(" ")
+ for u in denominator:
+ u = u.strip()
+ unit /= _parse_unit(u)
+
+ unit._uname = s
+
+ return unit
+
+ else:
+ raise UnitError(f"Invalid unit: {s}")
class UnitsManager(metaclass=Singleton):
+ """Database dictionary for handling units."""
+
_UNITS = {}
_DEFAULT_DATABASE = (
@@ -691,39 +1019,49 @@ def add_unit(
uname, factor, kg, m, s, K, mol, A, cd, rad, sr
)
- def delete_unit(self, uname):
+ def delete_unit(self, uname) -> None:
if uname in UnitsManager._UNITS:
del UnitsManager._UNITS[uname]
- def get_unit(self, uname):
+ def get_unit(self, uname) -> Optional[_Unit]:
return UnitsManager._UNITS.get(uname, None)
- def has_unit(self, uname):
+ def has_unit(self, uname) -> bool:
return uname in UnitsManager._UNITS
- def load(self):
+ def load(self) -> None:
+ """Load units from databases.
+
+ Fill self with unit infomration.
+ """
UnitsManager._UNITS.clear()
d = {}
+
with open(UnitsManager._DEFAULT_DATABASE, "r") as fin:
d.update(json.load(fin))
+
try:
with open(UnitsManager._USER_DATABASE, "r") as fin:
d.update(json.load(fin))
- except Exception:
+
+ except FileNotFoundError:
self.save()
+
finally:
- for uname, udict in list(d.items()):
+ for uname, udict in d.items():
factor = udict.get("factor", 1.0)
dim = udict.get("dimension", [0, 0, 0, 0, 0, 0, 0, 0, 0])
UnitsManager._UNITS[uname] = _Unit(uname, factor, *dim)
def save(self):
+ """Write self to custom user database."""
with open(UnitsManager._USER_DATABASE, "w") as fout:
json.dump(UnitsManager._UNITS, fout, indent=4, cls=UnitsManagerEncoder)
@property
def units(self):
+ """Direct access to unit database."""
return UnitsManager._UNITS
@units.setter
@@ -731,15 +1069,59 @@ def units(self, units):
UnitsManager._UNITS = units
-_EQUIVALENCES = {}
+class UnitsManagerEncoder(json.JSONEncoder):
+ """Custom encoder for writing units."""
+
+ @singledispatchmethod
+ def default(self, obj):
+ return json.JSONEncoder.default(self, obj)
+ @default.register(UnitsManager)
+ def _(self, obj):
+ return {k: {"factor": v.factor, "dimension": v.dimension} for k, v in obj.units}
-def add_equivalence(dim1, dim2, factor):
- _EQUIVALENCES.setdefault(dim1, {}).__setitem__(dim2, factor)
- _EQUIVALENCES.setdefault(dim2, {}).__setitem__(dim1, 1.0 / factor)
+ @default.register(_Unit)
+ def _(self, obj):
+ return {"factor": obj.factor, "dimension": obj.dimension}
-def measure(val, iunit="au", ounit="", equivalent=False):
+#: Set of units considered directly or indirectly equivalent.
+_EQUIVALENCES = defaultdict(dict)
+
+
+def add_equivalence(dim1, dim2, factor):
+ _EQUIVALENCES[dim1][dim2] = factor
+ _EQUIVALENCES[dim2][dim1] = 1.0 / factor
+
+
+def measure(
+ val: float, iunit: str = "au", ounit: str = "", *, equivalent: bool = False
+) -> _Unit:
+ """Create a unit bearing object.
+
+ Parses i/ounits and returns the relevant data object.
+
+ Parameters
+ ----------
+ val : float
+ Value for unit.
+ iunit : str
+ Input unit.
+ ounit : str
+ Desired output unit.
+ equivalent : bool
+ Whether the unit is to be considered "equivalent".
+
+ Returns
+ -------
+ _Unit
+ Desired unit.
+
+ Examples
+ --------
+ >>> print(measure(1., 'ang'))
+ 1 ang
+ """
if iunit:
unit = _str_to_unit(iunit)
unit *= val
diff --git a/MDANSE/Tests/UnitTests/test_units.py b/MDANSE/Tests/UnitTests/test_units.py
index 28e07d624..342e7c90c 100644
--- a/MDANSE/Tests/UnitTests/test_units.py
+++ b/MDANSE/Tests/UnitTests/test_units.py
@@ -13,170 +13,104 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
#
-import unittest
-
-from MDANSE.Framework.Units import _PREFIXES, UnitError, measure
-
-
-class TestUnits(unittest.TestCase):
- """
- Unittest for the geometry-related functions
- """
-
- def test_basic_units(self):
- m = measure(1.0, "kg")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "m")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "s")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "K")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "mol")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "A")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "cd")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "rad")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- m = measure(1.0, "sr")
- self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09)
-
- def test_prefix(self):
- m = measure(1.0, "s")
- self.assertAlmostEqual(m.toval("ys"), 1.0 / _PREFIXES["y"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("zs"), 1.0 / _PREFIXES["z"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("as"), 1.0 / _PREFIXES["a"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("fs"), 1.0 / _PREFIXES["f"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("ps"), 1.0 / _PREFIXES["p"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("ns"), 1.0 / _PREFIXES["n"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("us"), 1.0 / _PREFIXES["u"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("ms"), 1.0 / _PREFIXES["m"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("cs"), 1.0 / _PREFIXES["c"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("ds"), 1.0 / _PREFIXES["d"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("das"), 1.0 / _PREFIXES["da"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("hs"), 1.0 / _PREFIXES["h"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("ks"), 1.0 / _PREFIXES["k"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Ms"), 1.0 / _PREFIXES["M"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Gs"), 1.0 / _PREFIXES["G"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Ts"), 1.0 / _PREFIXES["T"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Ps"), 1.0 / _PREFIXES["P"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Es"), 1.0 / _PREFIXES["E"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Zs"), 1.0 / _PREFIXES["Z"], delta=1.0e-09)
- self.assertAlmostEqual(m.toval("Ys"), 1.0 / _PREFIXES["Y"], delta=1.0e-09)
-
- def test_composite_units(self):
- m = measure(1.0, "m/s")
- self.assertAlmostEqual(m.toval("km/h"), 3.6, delta=1.0e-09)
-
- def test_add_units(self):
- m1 = measure(1.0, "s")
- m2 = measure(1.0, "ms")
-
- m = m1 + m2
- self.assertAlmostEqual(m.toval("s"), 1.001, delta=1.0e-09)
-
- m += m2
- self.assertAlmostEqual(m.toval("s"), 1.002, delta=1.0e-09)
-
- def test_substract_units(self):
- m1 = measure(1.0, "s")
- m2 = measure(1.0, "ms")
-
- m = m1 - m2
- self.assertAlmostEqual(m.toval("s"), 0.999, delta=1.0e-09)
-
- m -= m2
- self.assertAlmostEqual(m.toval("s"), 0.998, delta=1.0e-09)
-
- def test_product_units(self):
- m1 = measure(1.0, "m")
- m2 = measure(5.0, "hm")
-
- m = m1 * m2
- self.assertAlmostEqual(m.toval("m2"), 500, delta=1.0e-09)
-
- m *= measure(10, "cm")
- self.assertAlmostEqual(m.toval("m3"), 50, delta=1.0e-09)
-
- m *= 20
- self.assertAlmostEqual(m.toval("m3"), 1000, delta=1.0e-09)
-
- def test_divide_units(self):
- m1 = measure(1.0, "m")
- m2 = measure(5.0, "hm")
-
- m = m1 / m2
- self.assertAlmostEqual(m.toval("au"), 0.002, delta=1.0e-09)
-
- m /= 0.0001
- self.assertAlmostEqual(m.toval("au"), 20.0, delta=1.0e-09)
-
- m /= m2
- self.assertRaises(UnitError, m.toval, "au")
- self.assertAlmostEqual(m.toval("1/m"), 4.0e-02, delta=1.0e-09)
-
- def test_floor_unit(self):
- self.assertAlmostEqual(
- measure(10.2, "m/s").floor().toval(), 10.0, delta=1.0e-09
- )
- self.assertAlmostEqual(
- measure(3.6, "m/s").ounit("km/h").floor().toval(), 12.0, delta=1.0e-09
- )
- self.assertAlmostEqual(
- measure(50.3, "km/h").floor().toval(), 50.0, delta=1.0e-09
- )
-
- def test_ceil_unit(self):
- self.assertAlmostEqual(measure(10.2, "m/s").ceil().toval(), 11.0, delta=1.0e-09)
- self.assertAlmostEqual(
- measure(3.6, "m/s").ounit("km/h").ceil().toval(), 13.0, delta=1.0e-09
- )
- self.assertAlmostEqual(
- measure(50.3, "km/h").ceil().toval(), 51.0, delta=1.0e-09
- )
-
- def test_round_unit(self):
- self.assertAlmostEqual(
- measure(10.2, "m/s").round().toval(), 10.0, delta=1.0e-09
- )
- self.assertAlmostEqual(
- measure(3.6, "m/s").ounit("km/h").round().toval(), 13.0, delta=1.0e-09
- )
- self.assertAlmostEqual(
- measure(50.3, "km/h").round().toval(), 50.0, delta=1.0e-09
- )
-
- def test_int_unit(self):
- self.assertEqual(int(measure(10.2, "km/h")), 10)
-
- def test_sqrt_unit(self):
- m = measure(4.0, "m2/s2")
-
- m = m.sqrt()
-
- self.assertAlmostEqual(m.toval(), 2.0, delta=1.0e-09)
- self.assertEqual(m.dimension, [0, 1, -1, 0, 0, 0, 0, 0, 0])
-
- def test_power_unit(self):
- m = measure(4.0, "m")
- m **= 3
- self.assertAlmostEqual(m.toval(), 64.0, delta=1.0e-09)
- self.assertEqual(m.dimension, [0, 3, 0, 0, 0, 0, 0, 0, 0])
-
- def test_equivalent_units(self):
- m = measure(1.0, "eV", equivalent=True)
- self.assertAlmostEqual(m.toval("THz"), 241.799, delta=1.0e-03)
- self.assertAlmostEqual(m.toval("K"), 11604.52, delta=1.0e-02)
-
- m = measure(1.0, "eV", equivalent=False)
- self.assertRaises(UnitError, m.toval, "THz")
+from contextlib import nullcontext
+from math import ceil, floor
+from operator import (add, iadd, imul, ipow, isub, itruediv, mul, pow, sub,
+ truediv)
+from random import random
+from typing import TypeVar, Union
+
+import pytest
+from MDANSE.Framework.Units import _PREFIXES, UnitError, _Unit, measure
+
+T = TypeVar("T")
+
+def _measure_or_val(m_o_v: T) -> Union[_Unit, T]:
+ """Convert a value to a measure or leave if not (val, unit)."""
+ if isinstance(m_o_v, (tuple, list)):
+ return measure(*m_o_v)
+ return m_o_v
+
+@pytest.mark.parametrize("unit", [
+ "kg", "m", "s", "K", "mol", "A", "cd", "rad", "sr",
+])
+def test_basic_units(unit):
+ m = measure(1.0, unit)
+ assert m.toval() == pytest.approx(1.0)
+
+@pytest.mark.parametrize("prefix", [
+ "y", "z", "a", "f", "p", "n", "u", "m", "c", "d",
+ "da", "h", "k", "M", "G", "T", "P", "E", "Z", "Y",
+])
+def test_prefixes(prefix):
+ val = random()
+ m = measure(val, "s")
+ assert m.toval(f"{prefix}s") == pytest.approx(val / _PREFIXES[prefix])
+
+@pytest.mark.parametrize("from_, equivalent, to, expected", [
+ ((1., "m/s"), False, "km/h", nullcontext(3.6)),
+ ((1., "eV"), False, "THz", pytest.raises(UnitError)),
+ ((1., "eV"), True, "THz", nullcontext(241.799)),
+ ((1., "eV"), True, "K", nullcontext(11604.52)),
+])
+def test_conversion(from_, equivalent, to, expected):
+ m = measure(*from_, equivalent=equivalent)
+ with expected as val:
+ assert m.toval(to) == pytest.approx(val)
+
+@pytest.mark.parametrize("in_units, op, out_unit, out_val", [
+ (((1., "s"), (1., "ms")), add, "s", 1.001),
+ (((1., "s"), (1., "ms")), sub, "s", 0.999),
+
+ (((1., "m"), (5., "hm")), mul, "m2", 500.),
+ (((500., "m2"), (10., "cm")), mul, "m3", 50.),
+ (((50., "m3"), 20.), mul, "m3", 1000.),
+
+ (((1., "m"), (5., "hm")), truediv, "au", 0.002),
+ (((0.002, "au"), 0.0001), truediv, "au", 20.),
+ (((20., "au"), (5., "hm")), truediv, "1/m", 4.e-2),
+
+ (((4., "m"), 3), pow, "m3", 64.),
+
+ # In-place
+ (((1., "s"), (1., "ms")), iadd, "s", 1.001),
+ (((1., "s"), (1., "ms")), isub, "s", 0.999),
+
+ (((1., "m"), (5., "hm")), imul, "m2", 500.),
+ (((500., "m2"), (10., "cm")), imul, "m3", 50.),
+ (((50., "m3"), 20.), imul, "m3", 1000.),
+
+ (((1., "m"), (5., "hm")), itruediv, "au", 0.002),
+ (((0.002, "au"), 0.0001), itruediv, "au", 20.),
+ (((20., "au"), (5., "hm")), itruediv, "1/m", 4.e-2),
+
+ (((4., "m"), 3), ipow, "m3", 64.),
+
+ # Other ops
+ (((10.2, "m/s"),), floor, None, 10.),
+ (((3.6, "m/s"),), floor, None, 3.),
+ (((50.3, "km/h"),), floor, None, 50.),
+
+ (((10.2, "m/s"),), ceil, None, 11.),
+ (((3.6, "m/s"),), ceil, None, 4.),
+ (((50.3, "km/h"),), ceil, None, 51.),
+
+ (((10.2, "m/s"),), round, None, 10.),
+ (((3.6, "m/s"),), round, None, 4.),
+ (((50.3, "km/h"),), round, None, 50.),
+
+ (((50.3, "km/h"),), round, None, 50.),
+
+])
+def test_operators(in_units, op, out_unit, out_val):
+ m = op(*map(_measure_or_val, in_units))
+ assert m.toval(out_unit) == pytest.approx(out_val)
+
+
+def test_sqrt():
+ m = measure(4.0, "m2/s2")
+
+ m = m.sqrt()
+
+ assert m.toval() == 2.0
+ assert m.dimension == [0, 1, -1, 0, 0, 0, 0, 0, 0]