diff --git a/MDANSE/Src/MDANSE/Framework/Units.py b/MDANSE/Src/MDANSE/Framework/Units.py index c97e35588..3c09c4b06 100644 --- a/MDANSE/Src/MDANSE/Framework/Units.py +++ b/MDANSE/Src/MDANSE/Framework/Units.py @@ -14,10 +14,12 @@ # along with this program. If not, see . # import copy +import json import math import numbers - -import json +from collections import defaultdict +from functools import singledispatchmethod +from typing import Optional, Tuple from MDANSE.Core.Platform import PLATFORM from MDANSE.Core.Singleton import Singleton @@ -70,159 +72,370 @@ class UnitError(Exception): pass -def get_trailing_digits(s): - for i in range(len(s)): - if s[i:].isdigit(): - return s[:i], int(s[i:]) +def get_trailing_digits(string: str) -> Tuple[str, int]: + """Get digits from the end of a string. + + Always returns ``1`` if no digits. + + Parameters + ---------- + string : str + String to parse. + + Returns + ------- + str + String with digits stripped. + int + Digits from end of string as ``int``. + + Examples + -------- + >>> get_trailing_digits("str123") + ('str', 123) + >>> get_trailing_digits("nodigits") + ('nodigits', 1) + >>> get_trailing_digits("123preceding") + ('123preceding', 1) + """ + for i in range(len(string)): + if string[i:].isdigit(): + return string[:i], int(string[i:]) else: - return s, 1 + return string, 1 + + +class _Unit: + """Unit handler. + + Handles all basic functions of units with correct dimensionality + and string printing. + + Parameters + ---------- + uname : str + Name of the unit. + factor : float + Factor relative to internal units. + + Extra Parameters + ---------------- + kg : int + Mass dimension. + m : int + Length dimension. + s : int + Time dimension. + K : int + Temperature dimension. + mol : int + Count dimension. + A : int + Current dimension. + cd : int + Luminous intensity dimension. + rad : int + Angular dimension. + sr : int + Solid angular dimension. + """ + def __init__( + self, + uname: str, + factor: float, + kg: int = 0, + m: int = 0, + s: int = 0, + K: int = 0, + mol: int = 0, + A: int = 0, + cd: int = 0, + rad: int = 0, + sr: int = 0, + ): + self._factor = factor + self._dimension = [kg, m, s, K, mol, A, cd, rad, sr] + self._format = "g" + self._uname = uname + self._ounit = None + self._out_factor = None + self._equivalent = False -def _parse_unit(iunit): - max_prefix_length = 0 - for p in _PREFIXES: - max_prefix_length = max(max_prefix_length, len(p)) + def __add__(self, other): + """Add two _Unit instances. - iunit = iunit.strip() + To be added, the units have to be analog or equivalent. - iunit, upower = get_trailing_digits(iunit) - if not iunit: - raise UnitError("Invalid unit") + Parameters + ---------- + other : _Unit + Unit to add. - for i in range(len(iunit)): - if UNITS_MANAGER.has_unit(iunit[i:]): - prefix = iunit[:i] - iunit = iunit[i:] - break - else: - raise UnitError(f"The unit {iunit} is unknown") + Raises + ------ + UnitError + Units are not equivalent or incompatible. - if prefix: - if prefix not in _PREFIXES: - raise UnitError(f"The prefix {prefix} is unknown") - prefix = _PREFIXES[prefix] - else: - prefix = 1.0 + Examples + -------- + >>> print(measure(10, 'm') + measure(20, 'km')) + 20010 m + """ + u = copy.deepcopy(self) - unit = UNITS_MANAGER.get_unit(iunit) + if u.is_analog(other): + u._factor += other._factor + elif self._equivalent: + equivalence_factor = u.get_equivalence_factor(other) + if equivalence_factor is None: + raise UnitError("The units are not equivalent") - unit = _Unit(iunit, prefix * unit._factor, *unit._dimension) + u._factor += other._factor / equivalence_factor + else: + raise UnitError("Incompatible units.") - unit **= upower + return u - return unit + def __sub__(self, other): + """Subtract _Unit instances. + To be subtracted, the units have to be analog or equivalent. -def _str_to_unit(s): - if UNITS_MANAGER.has_unit(s): - unit = UNITS_MANAGER.get_unit(s) - return copy.deepcopy(unit) + >>> print(measure(20, 'km') + measure(10, 'm')) + 20.01 km + """ + u = copy.deepcopy(self) - else: - unit = _Unit("au", 1.0) + if u.is_analog(other): + u._factor -= other._factor + elif u._equivalent: + equivalence_factor = u.get_equivalence_factor(other) + if equivalence_factor is None: + raise UnitError("The units are not equivalent") - splitted_units = s.split("/") + u._factor -= other._factor / equivalence_factor + else: + raise UnitError("Incompatible units") - if len(splitted_units) == 1: - units = splitted_units[0].split(" ") - for u in units: - u = u.strip() - unit *= _parse_unit(u) - unit._uname = s + return u - return unit + def __truediv__(self, other): + """Divide two _Unit instances. - elif len(splitted_units) == 2: - numerator = splitted_units[0].strip() - if numerator != "1": - numerator = numerator.split(" ") - for u in numerator: - u = u.strip() - unit *= _parse_unit(u) + To be divided, the units have to be analog or equivalent. - denominator = splitted_units[1].strip().split(" ") - for u in denominator: - u = u.strip() - unit /= _parse_unit(u) + Parameters + ---------- + other : _Unit + Unit to add. - unit._uname = s + Raises + ------ + UnitError + Units are not equivalent or incompatible. - return unit + Examples + -------- + >>> print(measure(100, 'V') / measure(10, 'kohm')) + 0.01 A1 + >>> print(measure(100, 'V') / 10) + 10 V + """ + u = copy.deepcopy(self) + if isinstance(other, numbers.Number): + u._factor /= other + elif isinstance(other, _Unit): + u._div_by(other) + else: + raise UnitError("Invalid operand") + + return u + def __floordiv__(self, other): + """Divide two _Unit instances and truncate. + + To be divided, the units have to be analog or equivalent. + + Parameters + ---------- + other : _Unit + Unit to add. + + Raises + ------ + UnitError + Units are not equivalent or incompatible. + + Examples + -------- + >>> print(measure(10, 'kohm') // measure(10, 'V')) + 1000 1 / A1 + >>> print(measure(15, 'ohm') // 10) + 1 ohm + """ + u = copy.deepcopy(self) + if isinstance(other, numbers.Number): + u._factor //= other + elif isinstance(other, _Unit): + u._div_by(other) + u._factor = math.floor(u._factor) else: - raise UnitError(f"Invalid unit: {s}") + raise UnitError("Invalid operand") + return u -class _Unit(object): - def __init__( - self, uname, factor, kg=0, m=0, s=0, K=0, mol=0, A=0, cd=0, rad=0, sr=0 - ): - self._factor = factor + def __mul__(self, other): + """Multiply _Unit instances or scaling factors. - self._dimension = [kg, m, s, K, mol, A, cd, rad, sr] + Examples + -------- + >>> print(measure(10, 'm/s') * measure(10, 's')) + 100 m1 + >>> print(measure(10, 'm') * measure(10, 's')) + 100 m1 s1 + >>> print(measure(10, 'm') * 10) + 100 m + """ - self._format = "g" + u = copy.deepcopy(self) + if isinstance(other, numbers.Number): + u._factor *= other + return u + elif isinstance(other, _Unit): + u._mult_by(other) + return u + else: + raise UnitError("Invalid operand") - self._uname = uname + def __pow__(self, n: float): + """Raise a _Unit to a factor. - self._ounit = None + Examples + -------- + >>> print(measure(10.5, 'm/s')**2) + 110.25 m2 / s2 + """ + output_unit = copy.copy(self) + output_unit._ounit = None + output_unit._out_factor = None + output_unit._factor = pow(output_unit._factor, n) + for i in range(len(output_unit._dimension)): + output_unit._dimension[i] *= n - self._out_factor = None + return output_unit - self._equivalent = False + def __float__(self) -> float: + """Return the value of a _Unit coerced to float. - def __add__(self, other): - """Add _Unit instances. To be added, the units has to be analog or equivalent. + Examples + -------- + >>> float(measure(10.5, 'm/s')) + 10.5 - >>> print(measure(10, 'm') + measure(20, 'km')) - 20010 m + See Also + -------- + __int__ : Truncate value. """ + return float(self.toval()) - u = copy.deepcopy(self) + def __int__(self) -> int: + """Return the value of a _Unit coerced to integer. - if u.is_analog(other): - u._factor += other._factor - return u - elif self._equivalent: - equivalence_factor = u.get_equivalence_factor(other) - if equivalence_factor is not None: - u._factor += other._factor / equivalence_factor - return u - else: - raise UnitError("The units are not equivalent") + Notes + ------ + This will happen to the value in the default output unit: + + Examples + -------- + >>> int(measure(10.5, 'm/s')) + 10 + """ + return int(self.toval()) + + def __ceil__(self): + """Ceil of a _Unit value in canonical units. + + Examples + -------- + >>> print(measure(10.2, 'm/s').ceiling()) + 11 m/s + >>> print(measure(3.6, 'm/s').ounit('km/h').ceiling()) + 13 km/h + >>> print(measure(50.3, 'km/h').ceiling()) + 51 km/h + """ + + r = copy.deepcopy(self) + + if r._ounit is not None: + val = math.ceil(r.toval(r._ounit)) + newu = _Unit("au", val) + newu *= _str_to_unit(r._ounit) + return newu.ounit(r._ounit) else: - raise UnitError("Incompatible units") + r._factor = math.ceil(r._factor) + return r - def __truediv__(self, other): - """Divide _Unit instances. + def __floor__(self): + """Floor of a _Unit value in canonical units. - >>> print(measure(100, 'V') / measure(10, 'kohm')) - 0.0100 A + Examples + -------- + >>> print(measure(10.2, 'm/s').floor()) + 10 m/s + >>> print(measure(3.6, 'm/s').ounit('km/h').floor()) + 12 km/h + >>> print(measure(50.3, 'km/h').floor()) + 50 km/h """ - u = copy.deepcopy(self) - if isinstance(other, numbers.Number): - u._factor /= other - return u - elif isinstance(other, _Unit): - u._div_by(other) - return u + r = copy.deepcopy(self) + + if r._ounit is not None: + val = math.floor(r.toval(r._ounit)) + newu = _Unit("au", val) + newu *= _str_to_unit(r._ounit) + return newu.ounit(r._ounit) else: - raise UnitError("Invalid operand") + r._factor = math.floor(r._factor) + return r - def __float__(self): - """Return the value of a _Unit coerced to float. See __int__.""" + def __round__(self, ndigits=None): + """Round of a _Unit value in canonical units. - return float(self.toval()) + Examples + -------- + >>> print(measure(10.2, 'm/s').round()) + 10 m/s + >>> print(measure(3.6, 'm/s').ounit('km/h').round()) + 13 km/h + >>> print(measure(50.3, 'km/h').round()) + 50 km/h + """ - def __floordiv__(self, other): - u = copy.deepcopy(self) - u._div_by(other) - u._factor = math.floor(u._factor) - return u + r = copy.deepcopy(self) + + if r._ounit is not None: + val = round(r.toval(r._ounit), ndigits) + newu = _Unit("au", val) + newu *= _str_to_unit(r._ounit) + return newu.ounit(r._ounit) + else: + r._factor = round(r._factor, ndigits) + return r + + ceiling = __ceil__ + floor = __floor__ + round = __round__ def __iadd__(self, other): - """Add _Unit instances. See __add__.""" + """Add _Unit instances. + + See Also + -------- + __add__ + """ if self.is_analog(other): self._factor += other._factor @@ -238,7 +451,12 @@ def __iadd__(self, other): raise UnitError("Incompatible units") def __itruediv__(self, other): - """Divide _Unit instances. See __div__.""" + """Divide _Unit instances. + + See Also + -------- + __div__ + """ if isinstance(other, numbers.Number): self._factor /= other @@ -250,12 +468,24 @@ def __itruediv__(self, other): raise UnitError("Invalid operand") def __ifloordiv__(self, other): + """Divide _Unit instances and truncate. + + See Also + -------- + __truediv__ + """ self._div_by(other) self._factor = math.floor(self._factor) return self def __imul__(self, other): - """Multiply _Unit instances. See __mul__.""" + """ + Multiply _Unit instances. + + See Also + -------- + __mul__ + """ if isinstance(other, numbers.Number): self._factor *= other @@ -266,17 +496,6 @@ def __imul__(self, other): else: raise UnitError("Invalid operand") - def __int__(self): - """Return the value of a _Unit coerced to integer. - - Note that this will happen to the value in the default output unit: - - >>> print(int(measure(10.5, 'm/s'))) - 10 - """ - - return int(self.toval()) - def __ipow__(self, n): self._factor = pow(self._factor, n) for i in range(len(self._dimension)): @@ -288,7 +507,7 @@ def __ipow__(self, n): return self def __isub__(self, other): - """Substract _Unit instances. See __sub__.""" + """Subtract _Unit instances. See __sub__.""" if self.is_analog(other): self._factor -= other._factor @@ -303,36 +522,13 @@ def __isub__(self, other): else: raise UnitError("Incompatible units") - def __mul__(self, other): - """Multiply _Unit instances. - - >>> print(measure(10, 'm/s') * measure(10, 's')) - 100.0000 m - """ - - u = copy.deepcopy(self) - if isinstance(other, numbers.Number): - u._factor *= other - return u - elif isinstance(other, _Unit): - u._mult_by(other) - return u - else: - raise UnitError("Invalid operand") - - def __pow__(self, n): - output_unit = copy.copy(self) - output_unit._ounit = None - output_unit._out_factor = None - output_unit._factor = pow(output_unit._factor, n) - for i in range(len(output_unit._dimension)): - output_unit._dimension[i] *= n - - return output_unit - def __radd__(self, other): - """Add _Unit instances. See __add__.""" + """Add _Unit instances. + See Also + -------- + __add__ + """ return self.__add__(other) def __rdiv__(self, other): @@ -360,32 +556,10 @@ def __rmul__(self, other): raise UnitError("Invalid operand") def __rsub__(self, other): - """Substract _Unit instances. See __sub__.""" + """Subtract _Unit instances. See __sub__.""" return other.__sub__(self) - def __sub__(self, other): - """Substract _Unit instances. To be substracted, the units has to be analog or equivalent. - - >>> print(measure(20, 'km') + measure(10, 'm')) - 19990 m - """ - - u = copy.deepcopy(self) - - if u.is_analog(other): - u._factor -= other._factor - return u - elif u._equivalent: - equivalence_factor = u.get_equivalence_factor(other) - if equivalence_factor is not None: - u._factor -= other._factor / equivalence_factor - return u - else: - raise UnitError("The units are not equivalent") - else: - raise UnitError("Incompatible units") - def __str__(self): unit = copy.copy(self) @@ -423,7 +597,26 @@ def __str__(self): return s - def _div_by(self, other): + def _div_by(self, other) -> None: + """Compute divided unit including new dimensionality. + + Parameters + ---------- + other : _Unit + Factor to divide by. + + Raises + ------ + UnitError + If other is not compatible. + + Examples + -------- + >>> a = measure(2., "ang") + >>> a._div_by(measure(4., "s")) + >>> print(a) + 5e-11 m1 / s1 + """ if self.is_analog(other): self._factor /= other._factor self._dimension = [0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -442,20 +635,39 @@ def _div_by(self, other): self._ounit = None self._out_factor = None - def _mult_by(self, other): + def _mult_by(self, other) -> None: + """Compute multiplied unit including new dimensionality. + + Parameters + ---------- + other : _Unit + Factor to multiply by. + + Raises + ------ + UnitError + If other is not compatible. + + Examples + -------- + >>> a = measure(1., "ang") + >>> a._mult_by(measure(3., "s")) + >>> print(a) + 3e-10 m1 s1 + """ if self.is_analog(other): self._factor *= other._factor for i in range(len(self._dimension)): self._dimension[i] = 2.0 * self._dimension[i] elif self._equivalent: equivalence_factor = self.get_equivalence_factor(other) - if equivalence_factor is not None: - self._factor *= other._factor / equivalence_factor - for i in range(len(self._dimension)): - self._dimension[i] = 2 * self._dimension[i] - return - else: + if equivalence_factor is None: raise UnitError("The units are not equivalent") + + self._factor *= other._factor / equivalence_factor + for i in range(len(self._dimension)): + self._dimension[i] = 2 * self._dimension[i] + return else: self._factor *= other._factor for i in range(len(self._dimension)): @@ -464,28 +676,6 @@ def _mult_by(self, other): self._ounit = None self._out_factor = None - def ceil(self): - """Ceil of a _Unit value in canonical units. - - >>> print(measure(10.2, 'm/s').ceiling()) - 10.0000 m / s - >>> print(measure(3.6, 'm/s').ounit('km/h').ceiling()) - 10.0 km / h - >>> print(measure(50.3, 'km/h').ceiling()) - 50.0 km / h - """ - - r = copy.deepcopy(self) - - if r._ounit is not None: - val = math.ceil(r.toval(r._ounit)) - newu = _Unit("au", val) - newu *= _str_to_unit(r._ounit) - return newu.ounit(r._ounit) - else: - r._factor = math.ceil(r._factor) - return r - @property def dimension(self): """Getter for _dimension attribute. Returns a copy.""" @@ -493,15 +683,13 @@ def dimension(self): return copy.copy(self._dimension) @property - def equivalent(self): + def equivalent(self) -> bool: """Getter for _equivalent attribute.""" return self._equivalent @equivalent.setter def equivalent(self, equivalent): - """Setter for _equivalent attribute.""" - self._equivalent = equivalent @property @@ -510,28 +698,6 @@ def factor(self): return self._factor - def floor(self): - """Floor of a _Unit value in canonical units. - - >>> print(measure(10.2, 'm/s').floor()) - 10.0000 m / s - >>> print(measure(3.6, 'm/s').ounit('km/h').floor()) - 10.0 km / h - >>> print(measure(50.3, 'km/h').floor()) - 50.0 km / h - """ - - r = copy.deepcopy(self) - - if r._ounit is not None: - val = math.floor(r.toval(r._ounit)) - newu = _Unit("au", val) - newu *= _str_to_unit(r._ounit) - return newu.ounit(r._ounit) - else: - r._factor = math.floor(r._factor) - return r - @property def format(self): """Getter for the output format.""" @@ -540,43 +706,106 @@ def format(self): @format.setter def format(self, fmt): - """Setter for the output format.""" - self._format = fmt - def is_analog(self, other): - """Returns True if the other unit is analog to this unit. Analog units are units whose dimension vector exactly matches.""" + def is_analog(self, other) -> bool: + """Whether two units are analog. + + Analog units are units whose dimension vector exactly matches. + + Parameters + ---------- + other : _Unit + Unit to test. + + Returns + ------- + bool + Whether two units are "analog". + + Examples + -------- + >>> a, b = measure(1., "km"), measure(1., "ang") + >>> a.is_analog(b) + True + >>> a, b = measure(1., "km"), measure(1., "ohm") + >>> a.is_analog(b) + False + """ return self._dimension == other._dimension - def get_equivalence_factor(self, other): - """Returns the equivalence factor if the other unit is equivalent to this unit. Equivalent units are units whose dimension are related through a constant + def get_equivalence_factor(self, other) -> Optional[float]: + """Returns the equivalence factor if other unit is equivalent. + + Equivalent units are units whose dimension are related through a constant (e.g. energy and mass, or frequency and temperature). + + Parameters + ---------- + other : _Unit + Potentially equivalent unit. + + Returns + ------- + Optional[float] + Equivalence factor to transform from one to the other + or ``None`` if not equivalent. + + See Also + -------- + _EQUIVALENCES : Dict of equivalent units. + add_equivalence : Add new equivalence to dict. + + Examples + -------- + >>> a = measure(1., "1/m") + >>> a.get_equivalence_factor(measure(1., "J/mol")) + 0.000119627 + >>> print(a.get_equivalence_factor(measure(1., "ang"))) + None """ _, upower = get_trailing_digits(self._uname) - dimension = tuple([d / upower for d in self._dimension]) + dimension = tuple(d / upower for d in self._dimension) + if dimension not in _EQUIVALENCES: return None powerized_equivalences = {} - for k, v in list(_EQUIVALENCES[dimension].items()): - pk = tuple([d * upower for d in k]) + for k, v in _EQUIVALENCES[dimension].items(): + pk = tuple(d * upower for d in k) powerized_equivalences[pk] = pow(v, upower) odimension = tuple(other._dimension) if odimension in powerized_equivalences: return powerized_equivalences[odimension] - else: - return None - def ounit(self, ounit): + return None + + def ounit(self, ounit: str): """Set the preferred unit for output. + Parameters + ---------- + ounit : str + Preferred output unit. + + Raises + ------ + UnitError + Units are incompatible. + + Notes + ----- + Returns a modified reference to self not a new object. + + Examples + -------- >>> a = measure(1, 'kg m2 / s2') >>> print(a) - 1.0000 kg m2 / s2 + 1 kg m2 / s2 >>> print(a.ounit('J')) - 1.0000 J + 1 J """ out_factor = _str_to_unit(ounit) @@ -585,57 +814,56 @@ def ounit(self, ounit): self._ounit = ounit self._out_factor = out_factor return self + elif self._equivalent: - if self.get_equivalence_factor(out_factor) is not None: - self._ounit = ounit - self._out_factor = out_factor - return self - else: + if self.get_equivalence_factor(out_factor) is None: raise UnitError("The units are not equivalents") - else: - raise UnitError("The units are not compatible") - - def round(self): - """Round of a _Unit value in canonical units. - - >>> print(measure(10.2, 'm/s').round()) - 10.0000 m / s - >>> print(measure(3.6, 'm/s').ounit('km/h').round()) - 11.0 km / h - >>> print(measure(50.3, 'km/h').round()) - 50.0 km / h - """ - r = copy.deepcopy(self) + self._ounit = ounit + self._out_factor = out_factor + return self - if r._ounit is not None: - val = round(r.toval(r._ounit)) - newu = _Unit("au", val) - newu *= _str_to_unit(r._ounit) - return newu.ounit(r._ounit) else: - r._factor = round(r._factor) - return r + raise UnitError("The units are not compatible") def sqrt(self): """Square root of a _Unit. + Returns + ------- + _Unit + New unit which is sqrt of original. + + Examples + -------- >>> print(measure(4, 'm2/s2').sqrt()) - 2.0000 m / s + 2 m1.0 / s-1.0 """ return self**0.5 - def toval(self, ounit=""): + def toval(self, ounit: str = "") -> float: """Returns the numeric value of a unit. The value is given in ounit or in the default output unit. + Parameters + ---------- + ounit : str + Unit to convert to. + + Returns + ------- + float + Value in output unit. + + Examples + -------- >>> v = measure(100, 'km/h') >>> v.toval() 100.0 >>> v.toval(ounit='m/s') - 27.777777777777779 + 27.77777777777778 """ newu = copy.deepcopy(self) @@ -660,19 +888,119 @@ def toval(self, ounit=""): return newu._factor -class UnitsManagerEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, UnitsManager): - d = {} - for k, v in obj.units: - d[k] = {"factor": v.factor, "dimension": v.dimension} - return d - elif isinstance(obj, _Unit): - return {"factor": obj.factor, "dimension": obj.dimension} - return json.JSONEncoder.default(self, obj) +def _parse_unit(iunit: str) -> _Unit: + """Parse single unit as a string into a Unit type. + + Parameters + ---------- + iunit : str + String to parse. + + Returns + ------- + _Unit + Expected unit. + + Raises + ------ + UnitError + If string does not contain valid unit. + """ + + max_prefix_length = 0 + for p in _PREFIXES: + max_prefix_length = max(max_prefix_length, len(p)) + + iunit = iunit.strip() + + iunit, upower = get_trailing_digits(iunit) + if not iunit: + raise UnitError("Invalid unit") + + for i in range(len(iunit)): + if UNITS_MANAGER.has_unit(iunit[i:]): + prefix = iunit[:i] + iunit = iunit[i:] + break + else: + raise UnitError(f"The unit {iunit} is unknown") + + if prefix: + if prefix not in _PREFIXES: + raise UnitError(f"The prefix {prefix} is unknown") + prefix = _PREFIXES[prefix] + else: + prefix = 1.0 + + unit = UNITS_MANAGER.get_unit(iunit) + + unit = _Unit(iunit, prefix * unit._factor, *unit._dimension) + + unit **= upower + + return unit + + +def _str_to_unit(s: str) -> _Unit: + """Parse general string into unit description. + + Parameters + ---------- + s : str + String to parse. + + Returns + ------- + _Unit + Parsed unit. + + Raises + ------ + UnitError + String is not a valid unit specification. + """ + if UNITS_MANAGER.has_unit(s): + unit = UNITS_MANAGER.get_unit(s) + return copy.deepcopy(unit) + + else: + unit = _Unit("au", 1.0) + + splitted_units = s.split("/") + + if len(splitted_units) == 1: + units = splitted_units[0].split(" ") + for u in units: + u = u.strip() + unit *= _parse_unit(u) + unit._uname = s + + return unit + + elif len(splitted_units) == 2: + numerator = splitted_units[0].strip() + if numerator != "1": + numerator = numerator.split(" ") + for u in numerator: + u = u.strip() + unit *= _parse_unit(u) + + denominator = splitted_units[1].strip().split(" ") + for u in denominator: + u = u.strip() + unit /= _parse_unit(u) + + unit._uname = s + + return unit + + else: + raise UnitError(f"Invalid unit: {s}") class UnitsManager(metaclass=Singleton): + """Database dictionary for handling units.""" + _UNITS = {} _DEFAULT_DATABASE = ( @@ -691,39 +1019,49 @@ def add_unit( uname, factor, kg, m, s, K, mol, A, cd, rad, sr ) - def delete_unit(self, uname): + def delete_unit(self, uname) -> None: if uname in UnitsManager._UNITS: del UnitsManager._UNITS[uname] - def get_unit(self, uname): + def get_unit(self, uname) -> Optional[_Unit]: return UnitsManager._UNITS.get(uname, None) - def has_unit(self, uname): + def has_unit(self, uname) -> bool: return uname in UnitsManager._UNITS - def load(self): + def load(self) -> None: + """Load units from databases. + + Fill self with unit infomration. + """ UnitsManager._UNITS.clear() d = {} + with open(UnitsManager._DEFAULT_DATABASE, "r") as fin: d.update(json.load(fin)) + try: with open(UnitsManager._USER_DATABASE, "r") as fin: d.update(json.load(fin)) - except Exception: + + except FileNotFoundError: self.save() + finally: - for uname, udict in list(d.items()): + for uname, udict in d.items(): factor = udict.get("factor", 1.0) dim = udict.get("dimension", [0, 0, 0, 0, 0, 0, 0, 0, 0]) UnitsManager._UNITS[uname] = _Unit(uname, factor, *dim) def save(self): + """Write self to custom user database.""" with open(UnitsManager._USER_DATABASE, "w") as fout: json.dump(UnitsManager._UNITS, fout, indent=4, cls=UnitsManagerEncoder) @property def units(self): + """Direct access to unit database.""" return UnitsManager._UNITS @units.setter @@ -731,15 +1069,59 @@ def units(self, units): UnitsManager._UNITS = units -_EQUIVALENCES = {} +class UnitsManagerEncoder(json.JSONEncoder): + """Custom encoder for writing units.""" + + @singledispatchmethod + def default(self, obj): + return json.JSONEncoder.default(self, obj) + @default.register(UnitsManager) + def _(self, obj): + return {k: {"factor": v.factor, "dimension": v.dimension} for k, v in obj.units} -def add_equivalence(dim1, dim2, factor): - _EQUIVALENCES.setdefault(dim1, {}).__setitem__(dim2, factor) - _EQUIVALENCES.setdefault(dim2, {}).__setitem__(dim1, 1.0 / factor) + @default.register(_Unit) + def _(self, obj): + return {"factor": obj.factor, "dimension": obj.dimension} -def measure(val, iunit="au", ounit="", equivalent=False): +#: Set of units considered directly or indirectly equivalent. +_EQUIVALENCES = defaultdict(dict) + + +def add_equivalence(dim1, dim2, factor): + _EQUIVALENCES[dim1][dim2] = factor + _EQUIVALENCES[dim2][dim1] = 1.0 / factor + + +def measure( + val: float, iunit: str = "au", ounit: str = "", *, equivalent: bool = False +) -> _Unit: + """Create a unit bearing object. + + Parses i/ounits and returns the relevant data object. + + Parameters + ---------- + val : float + Value for unit. + iunit : str + Input unit. + ounit : str + Desired output unit. + equivalent : bool + Whether the unit is to be considered "equivalent". + + Returns + ------- + _Unit + Desired unit. + + Examples + -------- + >>> print(measure(1., 'ang')) + 1 ang + """ if iunit: unit = _str_to_unit(iunit) unit *= val diff --git a/MDANSE/Tests/UnitTests/test_units.py b/MDANSE/Tests/UnitTests/test_units.py index 28e07d624..342e7c90c 100644 --- a/MDANSE/Tests/UnitTests/test_units.py +++ b/MDANSE/Tests/UnitTests/test_units.py @@ -13,170 +13,104 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -import unittest - -from MDANSE.Framework.Units import _PREFIXES, UnitError, measure - - -class TestUnits(unittest.TestCase): - """ - Unittest for the geometry-related functions - """ - - def test_basic_units(self): - m = measure(1.0, "kg") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "m") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "s") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "K") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "mol") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "A") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "cd") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "rad") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - m = measure(1.0, "sr") - self.assertAlmostEqual(m.toval(), 1.0, delta=1.0e-09) - - def test_prefix(self): - m = measure(1.0, "s") - self.assertAlmostEqual(m.toval("ys"), 1.0 / _PREFIXES["y"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("zs"), 1.0 / _PREFIXES["z"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("as"), 1.0 / _PREFIXES["a"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("fs"), 1.0 / _PREFIXES["f"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("ps"), 1.0 / _PREFIXES["p"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("ns"), 1.0 / _PREFIXES["n"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("us"), 1.0 / _PREFIXES["u"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("ms"), 1.0 / _PREFIXES["m"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("cs"), 1.0 / _PREFIXES["c"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("ds"), 1.0 / _PREFIXES["d"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("das"), 1.0 / _PREFIXES["da"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("hs"), 1.0 / _PREFIXES["h"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("ks"), 1.0 / _PREFIXES["k"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Ms"), 1.0 / _PREFIXES["M"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Gs"), 1.0 / _PREFIXES["G"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Ts"), 1.0 / _PREFIXES["T"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Ps"), 1.0 / _PREFIXES["P"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Es"), 1.0 / _PREFIXES["E"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Zs"), 1.0 / _PREFIXES["Z"], delta=1.0e-09) - self.assertAlmostEqual(m.toval("Ys"), 1.0 / _PREFIXES["Y"], delta=1.0e-09) - - def test_composite_units(self): - m = measure(1.0, "m/s") - self.assertAlmostEqual(m.toval("km/h"), 3.6, delta=1.0e-09) - - def test_add_units(self): - m1 = measure(1.0, "s") - m2 = measure(1.0, "ms") - - m = m1 + m2 - self.assertAlmostEqual(m.toval("s"), 1.001, delta=1.0e-09) - - m += m2 - self.assertAlmostEqual(m.toval("s"), 1.002, delta=1.0e-09) - - def test_substract_units(self): - m1 = measure(1.0, "s") - m2 = measure(1.0, "ms") - - m = m1 - m2 - self.assertAlmostEqual(m.toval("s"), 0.999, delta=1.0e-09) - - m -= m2 - self.assertAlmostEqual(m.toval("s"), 0.998, delta=1.0e-09) - - def test_product_units(self): - m1 = measure(1.0, "m") - m2 = measure(5.0, "hm") - - m = m1 * m2 - self.assertAlmostEqual(m.toval("m2"), 500, delta=1.0e-09) - - m *= measure(10, "cm") - self.assertAlmostEqual(m.toval("m3"), 50, delta=1.0e-09) - - m *= 20 - self.assertAlmostEqual(m.toval("m3"), 1000, delta=1.0e-09) - - def test_divide_units(self): - m1 = measure(1.0, "m") - m2 = measure(5.0, "hm") - - m = m1 / m2 - self.assertAlmostEqual(m.toval("au"), 0.002, delta=1.0e-09) - - m /= 0.0001 - self.assertAlmostEqual(m.toval("au"), 20.0, delta=1.0e-09) - - m /= m2 - self.assertRaises(UnitError, m.toval, "au") - self.assertAlmostEqual(m.toval("1/m"), 4.0e-02, delta=1.0e-09) - - def test_floor_unit(self): - self.assertAlmostEqual( - measure(10.2, "m/s").floor().toval(), 10.0, delta=1.0e-09 - ) - self.assertAlmostEqual( - measure(3.6, "m/s").ounit("km/h").floor().toval(), 12.0, delta=1.0e-09 - ) - self.assertAlmostEqual( - measure(50.3, "km/h").floor().toval(), 50.0, delta=1.0e-09 - ) - - def test_ceil_unit(self): - self.assertAlmostEqual(measure(10.2, "m/s").ceil().toval(), 11.0, delta=1.0e-09) - self.assertAlmostEqual( - measure(3.6, "m/s").ounit("km/h").ceil().toval(), 13.0, delta=1.0e-09 - ) - self.assertAlmostEqual( - measure(50.3, "km/h").ceil().toval(), 51.0, delta=1.0e-09 - ) - - def test_round_unit(self): - self.assertAlmostEqual( - measure(10.2, "m/s").round().toval(), 10.0, delta=1.0e-09 - ) - self.assertAlmostEqual( - measure(3.6, "m/s").ounit("km/h").round().toval(), 13.0, delta=1.0e-09 - ) - self.assertAlmostEqual( - measure(50.3, "km/h").round().toval(), 50.0, delta=1.0e-09 - ) - - def test_int_unit(self): - self.assertEqual(int(measure(10.2, "km/h")), 10) - - def test_sqrt_unit(self): - m = measure(4.0, "m2/s2") - - m = m.sqrt() - - self.assertAlmostEqual(m.toval(), 2.0, delta=1.0e-09) - self.assertEqual(m.dimension, [0, 1, -1, 0, 0, 0, 0, 0, 0]) - - def test_power_unit(self): - m = measure(4.0, "m") - m **= 3 - self.assertAlmostEqual(m.toval(), 64.0, delta=1.0e-09) - self.assertEqual(m.dimension, [0, 3, 0, 0, 0, 0, 0, 0, 0]) - - def test_equivalent_units(self): - m = measure(1.0, "eV", equivalent=True) - self.assertAlmostEqual(m.toval("THz"), 241.799, delta=1.0e-03) - self.assertAlmostEqual(m.toval("K"), 11604.52, delta=1.0e-02) - - m = measure(1.0, "eV", equivalent=False) - self.assertRaises(UnitError, m.toval, "THz") +from contextlib import nullcontext +from math import ceil, floor +from operator import (add, iadd, imul, ipow, isub, itruediv, mul, pow, sub, + truediv) +from random import random +from typing import TypeVar, Union + +import pytest +from MDANSE.Framework.Units import _PREFIXES, UnitError, _Unit, measure + +T = TypeVar("T") + +def _measure_or_val(m_o_v: T) -> Union[_Unit, T]: + """Convert a value to a measure or leave if not (val, unit).""" + if isinstance(m_o_v, (tuple, list)): + return measure(*m_o_v) + return m_o_v + +@pytest.mark.parametrize("unit", [ + "kg", "m", "s", "K", "mol", "A", "cd", "rad", "sr", +]) +def test_basic_units(unit): + m = measure(1.0, unit) + assert m.toval() == pytest.approx(1.0) + +@pytest.mark.parametrize("prefix", [ + "y", "z", "a", "f", "p", "n", "u", "m", "c", "d", + "da", "h", "k", "M", "G", "T", "P", "E", "Z", "Y", +]) +def test_prefixes(prefix): + val = random() + m = measure(val, "s") + assert m.toval(f"{prefix}s") == pytest.approx(val / _PREFIXES[prefix]) + +@pytest.mark.parametrize("from_, equivalent, to, expected", [ + ((1., "m/s"), False, "km/h", nullcontext(3.6)), + ((1., "eV"), False, "THz", pytest.raises(UnitError)), + ((1., "eV"), True, "THz", nullcontext(241.799)), + ((1., "eV"), True, "K", nullcontext(11604.52)), +]) +def test_conversion(from_, equivalent, to, expected): + m = measure(*from_, equivalent=equivalent) + with expected as val: + assert m.toval(to) == pytest.approx(val) + +@pytest.mark.parametrize("in_units, op, out_unit, out_val", [ + (((1., "s"), (1., "ms")), add, "s", 1.001), + (((1., "s"), (1., "ms")), sub, "s", 0.999), + + (((1., "m"), (5., "hm")), mul, "m2", 500.), + (((500., "m2"), (10., "cm")), mul, "m3", 50.), + (((50., "m3"), 20.), mul, "m3", 1000.), + + (((1., "m"), (5., "hm")), truediv, "au", 0.002), + (((0.002, "au"), 0.0001), truediv, "au", 20.), + (((20., "au"), (5., "hm")), truediv, "1/m", 4.e-2), + + (((4., "m"), 3), pow, "m3", 64.), + + # In-place + (((1., "s"), (1., "ms")), iadd, "s", 1.001), + (((1., "s"), (1., "ms")), isub, "s", 0.999), + + (((1., "m"), (5., "hm")), imul, "m2", 500.), + (((500., "m2"), (10., "cm")), imul, "m3", 50.), + (((50., "m3"), 20.), imul, "m3", 1000.), + + (((1., "m"), (5., "hm")), itruediv, "au", 0.002), + (((0.002, "au"), 0.0001), itruediv, "au", 20.), + (((20., "au"), (5., "hm")), itruediv, "1/m", 4.e-2), + + (((4., "m"), 3), ipow, "m3", 64.), + + # Other ops + (((10.2, "m/s"),), floor, None, 10.), + (((3.6, "m/s"),), floor, None, 3.), + (((50.3, "km/h"),), floor, None, 50.), + + (((10.2, "m/s"),), ceil, None, 11.), + (((3.6, "m/s"),), ceil, None, 4.), + (((50.3, "km/h"),), ceil, None, 51.), + + (((10.2, "m/s"),), round, None, 10.), + (((3.6, "m/s"),), round, None, 4.), + (((50.3, "km/h"),), round, None, 50.), + + (((50.3, "km/h"),), round, None, 50.), + +]) +def test_operators(in_units, op, out_unit, out_val): + m = op(*map(_measure_or_val, in_units)) + assert m.toval(out_unit) == pytest.approx(out_val) + + +def test_sqrt(): + m = measure(4.0, "m2/s2") + + m = m.sqrt() + + assert m.toval() == 2.0 + assert m.dimension == [0, 1, -1, 0, 0, 0, 0, 0, 0]