From 2590a38cf28117764500517f3f3a831626bb624c Mon Sep 17 00:00:00 2001 From: Jacob Wilkins Date: Mon, 3 Mar 2025 15:26:25 +0000 Subject: [PATCH] Rework Units --- MDANSE/Src/MDANSE/Framework/Units.py | 954 +++++++++++++++++++-------- 1 file changed, 666 insertions(+), 288 deletions(-) diff --git a/MDANSE/Src/MDANSE/Framework/Units.py b/MDANSE/Src/MDANSE/Framework/Units.py index de12f5eb6..0df1dd0f2 100644 --- a/MDANSE/Src/MDANSE/Framework/Units.py +++ b/MDANSE/Src/MDANSE/Framework/Units.py @@ -14,10 +14,12 @@ # along with this program. If not, see . # import copy +import json import math import numbers from collections import defaultdict -import json +from functools import singledispatchmethod +from typing import Optional, Tuple from MDANSE.Core.Platform import PLATFORM from MDANSE.Core.Singleton import Singleton @@ -70,244 +72,230 @@ class UnitError(Exception): pass -def get_trailing_digits(s): - for i in range(len(s)): - if s[i:].isdigit(): - return s[:i], int(s[i:]) - else: - return s, 1 - - -def _parse_unit(iunit): - max_prefix_length = 0 - for p in _PREFIXES: - max_prefix_length = max(max_prefix_length, len(p)) - - iunit = iunit.strip() - - iunit, upower = get_trailing_digits(iunit) - if not iunit: - raise UnitError("Invalid unit") - - for i in range(len(iunit)): - if UNITS_MANAGER.has_unit(iunit[i:]): - prefix = iunit[:i] - iunit = iunit[i:] - break - else: - raise UnitError(f"The unit {iunit} is unknown") - - if prefix: - if prefix not in _PREFIXES: - raise UnitError(f"The prefix {prefix} is unknown") - prefix = _PREFIXES[prefix] +def get_trailing_digits(string: str) -> Tuple[str, int]: + """Get digits from the end of a string. + + Always returns ``1`` if no digits. + + Parameters + ---------- + string : str + FIXME: Add docs. + + Returns + ------- + str + String with digits stripped. + int + Digits from end of string as ``int``. + + Examples + -------- + >>> get_trailing_digits("str123") + ('str', 123) + >>> get_trailing_digits("nodigits") + ('nodigits', 1) + >>> get_trailing_digits("123preceding") + ('123preceding', 1) + """ + for i in range(len(string)): + if string[i:].isdigit(): + return string[:i], int(string[i:]) else: - prefix = 1.0 - - unit = UNITS_MANAGER.get_unit(iunit) + return string, 1 + + +class _Unit: + """Unit handler. + + Handles all basic functions of units with correct dimensionality + and string printing. + + Parameters + ---------- + uname : str + Name of the unit. + factor : float + Factor relative to internal units. + + Extra Parameters + ---------------- + kg : int + Mass dimension. + m : int + Length dimension. + s : int + Time dimension. + K : int + Temperature dimension. + mol : int + Count dimension. + A : int + Current dimension. + cd : int + Luminous intensity dimension. + rad : int + Angular dimension. + sr : int + Solid angular dimension. + """ - unit = _Unit(iunit, prefix * unit._factor, *unit._dimension) - - unit **= upower - - return unit - - -def _str_to_unit(s): - if UNITS_MANAGER.has_unit(s): - unit = UNITS_MANAGER.get_unit(s) - return copy.deepcopy(unit) - - else: - unit = _Unit("au", 1.0) - - splitted_units = s.split("/") - - if len(splitted_units) == 1: - units = splitted_units[0].split(" ") - for u in units: - u = u.strip() - unit *= _parse_unit(u) - unit._uname = s - - return unit - - elif len(splitted_units) == 2: - numerator = splitted_units[0].strip() - if numerator != "1": - numerator = numerator.split(" ") - for u in numerator: - u = u.strip() - unit *= _parse_unit(u) - - denominator = splitted_units[1].strip().split(" ") - for u in denominator: - u = u.strip() - unit /= _parse_unit(u) - - unit._uname = s - - return unit - - else: - raise UnitError(f"Invalid unit: {s}") - - -class _Unit(object): def __init__( - self, uname, factor, kg=0, m=0, s=0, K=0, mol=0, A=0, cd=0, rad=0, sr=0 + self, + uname: str, + factor: float, + kg: int = 0, + m: int = 0, + s: int = 0, + K: int = 0, + mol: int = 0, + A: int = 0, + cd: int = 0, + rad: int = 0, + sr: int = 0, ): self._factor = factor - self._dimension = [kg, m, s, K, mol, A, cd, rad, sr] - self._format = "g" - self._uname = uname - self._ounit = None - self._out_factor = None - self._equivalent = False def __add__(self, other): - """Add _Unit instances. To be added, the units has to be analog or equivalent. + """Add two _Unit instances. + + To be added, the units have to be analog or equivalent. + + Parameters + ---------- + other : _Unit + Unit to add. + Raises + ------ + UnitError + Units are not equivalent or incompatible. + + Examples + -------- >>> print(measure(10, 'm') + measure(20, 'km')) 20010 m """ - u = copy.deepcopy(self) if u.is_analog(other): u._factor += other._factor - return u elif self._equivalent: equivalence_factor = u.get_equivalence_factor(other) - if equivalence_factor is not None: - u._factor += other._factor / equivalence_factor - return u - else: + if equivalence_factor is None: raise UnitError("The units are not equivalent") + + u._factor += other._factor / equivalence_factor else: - raise UnitError("Incompatible units") + raise UnitError("Incompatible units.") - def __truediv__(self, other): - """Divide _Unit instances. + return u - >>> print(measure(100, 'V') / measure(10, 'kohm')) - 0.0100 A - """ + def __sub__(self, other): + """Subtract _Unit instances. + To be subtracted, the units have to be analog or equivalent. + + >>> print(measure(20, 'km') + measure(10, 'm')) + 20.01 km + """ u = copy.deepcopy(self) - if isinstance(other, numbers.Number): - u._factor /= other - return u - elif isinstance(other, _Unit): - u._div_by(other) - return u - else: - raise UnitError("Invalid operand") - def __float__(self): - """Return the value of a _Unit coerced to float. See __int__.""" + if u.is_analog(other): + u._factor -= other._factor + elif u._equivalent: + equivalence_factor = u.get_equivalence_factor(other) + if equivalence_factor is None: + raise UnitError("The units are not equivalent") - return float(self.toval()) + u._factor -= other._factor / equivalence_factor + else: + raise UnitError("Incompatible units") - def __floordiv__(self, other): - u = copy.deepcopy(self) - u._div_by(other) - u._factor = math.floor(u._factor) return u - def __iadd__(self, other): - """Add _Unit instances. See __add__.""" + def __truediv__(self, other): + """Divide two _Unit instances. - if self.is_analog(other): - self._factor += other._factor - return self - elif self._equivalent: - equivalence_factor = self.get_equivalence_factor(other) - if equivalence_factor is not None: - self._factor += other._factor / equivalence_factor - return self - else: - raise UnitError("The units are not equivalent") - else: - raise UnitError("Incompatible units") + To be divided, the units have to be analog or equivalent. - def __itruediv__(self, other): - """Divide _Unit instances. See __div__.""" + Parameters + ---------- + other : _Unit + Unit to add. + Raises + ------ + UnitError + Units are not equivalent or incompatible. + + Examples + -------- + >>> print(measure(100, 'V') / measure(10, 'kohm')) + 0.01 A1 + >>> print(measure(100, 'V') / 10) + 10 V + """ + u = copy.deepcopy(self) if isinstance(other, numbers.Number): - self._factor /= other - return self + u._factor /= other elif isinstance(other, _Unit): - self._div_by(other) - return self + u._div_by(other) else: raise UnitError("Invalid operand") - def __ifloordiv__(self, other): - self._div_by(other) - self._factor = math.floor(self._factor) - return self - - def __imul__(self, other): - """Multiply _Unit instances. See __mul__.""" + return u + def __floordiv__(self, other): + """Divide two _Unit instances and truncate. + + To be divided, the units have to be analog or equivalent. + + Parameters + ---------- + other : _Unit + Unit to add. + + Raises + ------ + UnitError + Units are not equivalent or incompatible. + + Examples + -------- + >>> print(measure(10, 'kohm') // measure(10, 'V')) + 1000 1 / A1 + >>> print(measure(15, 'ohm') // 10) + 1 ohm + """ + u = copy.deepcopy(self) if isinstance(other, numbers.Number): - self._factor *= other - return self + u._factor //= other elif isinstance(other, _Unit): - self._mult_by(other) - return self + u._div_by(other) + u._factor = math.floor(u._factor) else: raise UnitError("Invalid operand") - def __int__(self): - """Return the value of a _Unit coerced to integer. - - Note that this will happen to the value in the default output unit: - - >>> print(int(measure(10.5, 'm/s'))) - 10 - """ - - return int(self.toval()) - - def __ipow__(self, n): - self._factor = pow(self._factor, n) - for i in range(len(self._dimension)): - self._dimension[i] *= n - - self._ounit = None - self._out_factor = None - - return self - - def __isub__(self, other): - """Substract _Unit instances. See __sub__.""" - - if self.is_analog(other): - self._factor -= other._factor - return self - elif self._equivalent: - equivalence_factor = self.get_equivalence_factor(other) - if equivalence_factor is not None: - self._factor -= other._factor / equivalence_factor - return self - else: - raise UnitError("The units are not equivalent") - else: - raise UnitError("Incompatible units") + return u def __mul__(self, other): - """Multiply _Unit instances. + """Multiply _Unit instances or scaling factors. + Examples + -------- >>> print(measure(10, 'm/s') * measure(10, 's')) - 100.0000 m + 100 m1 + >>> print(measure(10, 'm') * measure(10, 's')) + 100 m1 s1 + >>> print(measure(10, 'm') * 10) + 100 m """ u = copy.deepcopy(self) @@ -320,7 +308,14 @@ def __mul__(self, other): else: raise UnitError("Invalid operand") - def __pow__(self, n): + def __pow__(self, n: float): + """Raise a _Unit to a factor. + + Examples + -------- + >>> print(measure(10.5, 'm/s')**2) + 110.25 m2 / s2 + """ output_unit = copy.copy(self) output_unit._ounit = None output_unit._out_factor = None @@ -330,49 +325,45 @@ def __pow__(self, n): return output_unit - def __radd__(self, other): - """Add _Unit instances. See __add__.""" - - return self.__add__(other) + def __float__(self) -> float: + """Return the value of a _Unit coerced to float. - def __rdiv__(self, other): - u = copy.deepcopy(self) - if isinstance(other, numbers.Number): - u._factor /= other - return u - elif isinstance(other, _Unit): - u._div_by(other) - return u - else: - raise UnitError("Invalid operand") + Examples + -------- + >>> float(measure(10.5, 'm/s')) + 10.5 - def __rmul__(self, other): - """Multiply _Unit instances. See __mul__.""" + See Also + -------- + __int__ : Truncate value. + """ + return float(self.toval()) - u = copy.deepcopy(self) - if isinstance(other, numbers.Number): - u._factor *= other - return u - elif isinstance(other, _Unit): - u._mult_by(other) - return u - else: - raise UnitError("Invalid operand") + def __int__(self) -> int: + """Return the value of a _Unit coerced to integer. - def __rsub__(self, other): - """Substract _Unit instances. See __sub__.""" + Notes + ------ + This will happen to the value in the default output unit: - return other.__sub__(self) + Examples + -------- + >>> int(measure(10.5, 'm/s')) + 10 + """ + return int(self.toval()) def __ceil__(self): """Ceil of a _Unit value in canonical units. + Examples + -------- >>> print(measure(10.2, 'm/s').ceiling()) - 10.0000 m / s + 11 m/s >>> print(measure(3.6, 'm/s').ounit('km/h').ceiling()) - 10.0 km / h + 13 km/h >>> print(measure(50.3, 'km/h').ceiling()) - 50.0 km / h + 51 km/h """ r = copy.deepcopy(self) @@ -389,12 +380,14 @@ def __ceil__(self): def __floor__(self): """Floor of a _Unit value in canonical units. + Examples + -------- >>> print(measure(10.2, 'm/s').floor()) - 10.0000 m / s + 10 m/s >>> print(measure(3.6, 'm/s').ounit('km/h').floor()) - 10.0 km / h + 12 km/h >>> print(measure(50.3, 'km/h').floor()) - 50.0 km / h + 50 km/h """ r = copy.deepcopy(self) @@ -411,12 +404,14 @@ def __floor__(self): def __round__(self, ndigits=None): """Round of a _Unit value in canonical units. + Examples + -------- >>> print(measure(10.2, 'm/s').round()) - 10.0000 m / s + 10 m/s >>> print(measure(3.6, 'm/s').ounit('km/h').round()) - 11.0 km / h + 13 km/h >>> print(measure(50.3, 'km/h').round()) - 50.0 km / h + 50 km/h """ r = copy.deepcopy(self) @@ -430,32 +425,141 @@ def __round__(self, ndigits=None): r._factor = round(r._factor, ndigits) return r - ceil = __ceil__ + ceiling = __ceil__ floor = __floor__ round = __round__ - def __sub__(self, other): - """Substract _Unit instances. To be substracted, the units has to be analog or equivalent. + def __iadd__(self, other): + """Add _Unit instances. - >>> print(measure(20, 'km') + measure(10, 'm')) - 19990 m + See Also + -------- + __add__ """ - u = copy.deepcopy(self) + if self.is_analog(other): + self._factor += other._factor + return self + elif self._equivalent: + equivalence_factor = self.get_equivalence_factor(other) + if equivalence_factor is not None: + self._factor += other._factor / equivalence_factor + return self + else: + raise UnitError("The units are not equivalent") + else: + raise UnitError("Incompatible units") - if u.is_analog(other): - u._factor -= other._factor - return u - elif u._equivalent: - equivalence_factor = u.get_equivalence_factor(other) + def __itruediv__(self, other): + """Divide _Unit instances. + + See Also + -------- + __div__ + """ + + if isinstance(other, numbers.Number): + self._factor /= other + return self + elif isinstance(other, _Unit): + self._div_by(other) + return self + else: + raise UnitError("Invalid operand") + + def __ifloordiv__(self, other): + """Divide _Unit instances and truncate. + + See Also + -------- + __truediv__ + """ + self._div_by(other) + self._factor = math.floor(self._factor) + return self + + def __imul__(self, other): + """ + Multiply _Unit instances. + + See Also + -------- + __mul__ + """ + + if isinstance(other, numbers.Number): + self._factor *= other + return self + elif isinstance(other, _Unit): + self._mult_by(other) + return self + else: + raise UnitError("Invalid operand") + + def __ipow__(self, n): + self._factor = pow(self._factor, n) + for i in range(len(self._dimension)): + self._dimension[i] *= n + + self._ounit = None + self._out_factor = None + + return self + + def __isub__(self, other): + """Subtract _Unit instances. See __sub__.""" + + if self.is_analog(other): + self._factor -= other._factor + return self + elif self._equivalent: + equivalence_factor = self.get_equivalence_factor(other) if equivalence_factor is not None: - u._factor -= other._factor / equivalence_factor - return u + self._factor -= other._factor / equivalence_factor + return self else: raise UnitError("The units are not equivalent") else: raise UnitError("Incompatible units") + def __radd__(self, other): + """Add _Unit instances. + + See Also + -------- + __add__ + """ + return self.__add__(other) + + def __rdiv__(self, other): + u = copy.deepcopy(self) + if isinstance(other, numbers.Number): + u._factor /= other + return u + elif isinstance(other, _Unit): + u._div_by(other) + return u + else: + raise UnitError("Invalid operand") + + def __rmul__(self, other): + """Multiply _Unit instances. See __mul__.""" + + u = copy.deepcopy(self) + if isinstance(other, numbers.Number): + u._factor *= other + return u + elif isinstance(other, _Unit): + u._mult_by(other) + return u + else: + raise UnitError("Invalid operand") + + def __rsub__(self, other): + """Subtract _Unit instances. See __sub__.""" + + return other.__sub__(self) + def __str__(self): unit = copy.copy(self) @@ -493,7 +597,26 @@ def __str__(self): return s - def _div_by(self, other): + def _div_by(self, other) -> None: + """Compute divided unit including new dimensionality. + + Parameters + ---------- + other : _Unit + Factor to divide by. + + Raises + ------ + UnitError + If other is not compatible. + + Examples + -------- + >>> a = measure(2., "ang") + >>> a._div_by(measure(4., "s")) + >>> print(a) + 5e-11 m1 / s1 + """ if self.is_analog(other): self._factor /= other._factor self._dimension = [0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -512,20 +635,39 @@ def _div_by(self, other): self._ounit = None self._out_factor = None - def _mult_by(self, other): + def _mult_by(self, other) -> None: + """Compute multiplied unit including new dimensionality. + + Parameters + ---------- + other : _Unit + Factor to multiply by. + + Raises + ------ + UnitError + If other is not compatible. + + Examples + -------- + >>> a = measure(1., "ang") + >>> a._mult_by(measure(3., "s")) + >>> print(a) + 3e-10 m1 s1 + """ if self.is_analog(other): self._factor *= other._factor for i in range(len(self._dimension)): self._dimension[i] = 2.0 * self._dimension[i] elif self._equivalent: equivalence_factor = self.get_equivalence_factor(other) - if equivalence_factor is not None: - self._factor *= other._factor / equivalence_factor - for i in range(len(self._dimension)): - self._dimension[i] = 2 * self._dimension[i] - return - else: + if equivalence_factor is None: raise UnitError("The units are not equivalent") + + self._factor *= other._factor / equivalence_factor + for i in range(len(self._dimension)): + self._dimension[i] = 2 * self._dimension[i] + return else: self._factor *= other._factor for i in range(len(self._dimension)): @@ -541,15 +683,13 @@ def dimension(self): return copy.copy(self._dimension) @property - def equivalent(self): + def equivalent(self) -> bool: """Getter for _equivalent attribute.""" return self._equivalent @equivalent.setter def equivalent(self, equivalent): - """Setter for _equivalent attribute.""" - self._equivalent = equivalent @property @@ -566,43 +706,106 @@ def format(self): @format.setter def format(self, fmt): - """Setter for the output format.""" - self._format = fmt - def is_analog(self, other): - """Returns True if the other unit is analog to this unit. Analog units are units whose dimension vector exactly matches.""" + def is_analog(self, other) -> bool: + """Whether two units are analog. + + Analog units are units whose dimension vector exactly matches. + + Parameters + ---------- + other : _Unit + Unit to test. + + Returns + ------- + bool + Whether two units are "analog". + + Examples + -------- + >>> a, b = measure(1., "km"), measure(1., "ang") + >>> a.is_analog(b) + True + >>> a, b = measure(1., "km"), measure(1., "ohm") + >>> a.is_analog(b) + False + """ return self._dimension == other._dimension - def get_equivalence_factor(self, other): - """Returns the equivalence factor if the other unit is equivalent to this unit. Equivalent units are units whose dimension are related through a constant + def get_equivalence_factor(self, other) -> Optional[float]: + """Returns the equivalence factor if other unit is equivalent. + + Equivalent units are units whose dimension are related through a constant (e.g. energy and mass, or frequency and temperature). + + Parameters + ---------- + other : _Unit + Potentially equivalent unit. + + Returns + ------- + Optional[float] + Equivalence factor to transform from one to the other + or ``None`` if not equivalent. + + See Also + -------- + _EQUIVALENCES : Dict of equivalent units. + add_equivalence : Add new equivalence to dict. + + Examples + -------- + >>> a = measure(1., "1/m") + >>> a.get_equivalence_factor(measure(1., "J/mol")) + 0.000119627 + >>> print(a.get_equivalence_factor(measure(1., "ang"))) + None """ _, upower = get_trailing_digits(self._uname) - dimension = tuple([d / upower for d in self._dimension]) + dimension = tuple(d / upower for d in self._dimension) + if dimension not in _EQUIVALENCES: return None powerized_equivalences = {} - for k, v in list(_EQUIVALENCES[dimension].items()): - pk = tuple([d * upower for d in k]) + for k, v in _EQUIVALENCES[dimension].items(): + pk = tuple(d * upower for d in k) powerized_equivalences[pk] = pow(v, upower) odimension = tuple(other._dimension) if odimension in powerized_equivalences: return powerized_equivalences[odimension] - else: - return None - def ounit(self, ounit): + return None + + def ounit(self, ounit: str): """Set the preferred unit for output. + Parameters + ---------- + ounit : str + Preferred output unit. + + Raises + ------ + UnitError + Units are incompatible. + + Notes + ----- + Returns a modified reference to self not a new object. + + Examples + -------- >>> a = measure(1, 'kg m2 / s2') >>> print(a) - 1.0000 kg m2 / s2 + 1 kg m2 / s2 >>> print(a.ounit('J')) - 1.0000 J + 1 J """ out_factor = _str_to_unit(ounit) @@ -611,35 +814,56 @@ def ounit(self, ounit): self._ounit = ounit self._out_factor = out_factor return self + elif self._equivalent: - if self.get_equivalence_factor(out_factor) is not None: - self._ounit = ounit - self._out_factor = out_factor - return self - else: + if self.get_equivalence_factor(out_factor) is None: raise UnitError("The units are not equivalents") + + self._ounit = ounit + self._out_factor = out_factor + return self + else: raise UnitError("The units are not compatible") def sqrt(self): """Square root of a _Unit. + Returns + ------- + _Unit + New unit which is sqrt of original. + + Examples + -------- >>> print(measure(4, 'm2/s2').sqrt()) - 2.0000 m / s + 2 m1.0 / s-1.0 """ return self**0.5 - def toval(self, ounit=""): + def toval(self, ounit: str = "") -> float: """Returns the numeric value of a unit. The value is given in ounit or in the default output unit. + Parameters + ---------- + ounit : str + Unit to convert to. + + Returns + ------- + float + Value in output unit. + + Examples + -------- >>> v = measure(100, 'km/h') >>> v.toval() 100.0 >>> v.toval(ounit='m/s') - 27.777777777777779 + 27.77777777777778 """ newu = copy.deepcopy(self) @@ -664,19 +888,119 @@ def toval(self, ounit=""): return newu._factor -class UnitsManagerEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, UnitsManager): - d = {} - for k, v in obj.units: - d[k] = {"factor": v.factor, "dimension": v.dimension} - return d - elif isinstance(obj, _Unit): - return {"factor": obj.factor, "dimension": obj.dimension} - return json.JSONEncoder.default(self, obj) +def _parse_unit(iunit: str) -> _Unit: + """Parse single unit as a string into a Unit type. + + Parameters + ---------- + iunit : str + String to parse. + + Returns + ------- + _Unit + Expected unit. + + Raises + ------ + UnitError + If string does not contain valid unit. + """ + + max_prefix_length = 0 + for p in _PREFIXES: + max_prefix_length = max(max_prefix_length, len(p)) + + iunit = iunit.strip() + + iunit, upower = get_trailing_digits(iunit) + if not iunit: + raise UnitError("Invalid unit") + + for i in range(len(iunit)): + if UNITS_MANAGER.has_unit(iunit[i:]): + prefix = iunit[:i] + iunit = iunit[i:] + break + else: + raise UnitError(f"The unit {iunit} is unknown") + + if prefix: + if prefix not in _PREFIXES: + raise UnitError(f"The prefix {prefix} is unknown") + prefix = _PREFIXES[prefix] + else: + prefix = 1.0 + + unit = UNITS_MANAGER.get_unit(iunit) + + unit = _Unit(iunit, prefix * unit._factor, *unit._dimension) + + unit **= upower + + return unit + + +def _str_to_unit(s: str) -> _Unit: + """Parse general string into unit description. + + Parameters + ---------- + s : str + String to parse. + + Returns + ------- + _Unit + Parsed unit. + + Raises + ------ + UnitError + String is not a valid unit specification. + """ + if UNITS_MANAGER.has_unit(s): + unit = UNITS_MANAGER.get_unit(s) + return copy.deepcopy(unit) + + else: + unit = _Unit("au", 1.0) + + splitted_units = s.split("/") + + if len(splitted_units) == 1: + units = splitted_units[0].split(" ") + for u in units: + u = u.strip() + unit *= _parse_unit(u) + unit._uname = s + + return unit + + elif len(splitted_units) == 2: + numerator = splitted_units[0].strip() + if numerator != "1": + numerator = numerator.split(" ") + for u in numerator: + u = u.strip() + unit *= _parse_unit(u) + + denominator = splitted_units[1].strip().split(" ") + for u in denominator: + u = u.strip() + unit /= _parse_unit(u) + + unit._uname = s + + return unit + + else: + raise UnitError(f"Invalid unit: {s}") class UnitsManager(metaclass=Singleton): + """Database dictionary for handling units.""" + _UNITS = {} _DEFAULT_DATABASE = ( @@ -695,39 +1019,49 @@ def add_unit( uname, factor, kg, m, s, K, mol, A, cd, rad, sr ) - def delete_unit(self, uname): + def delete_unit(self, uname) -> None: if uname in UnitsManager._UNITS: del UnitsManager._UNITS[uname] - def get_unit(self, uname): + def get_unit(self, uname) -> Optional[_Unit]: return UnitsManager._UNITS.get(uname, None) - def has_unit(self, uname): + def has_unit(self, uname) -> bool: return uname in UnitsManager._UNITS - def load(self): + def load(self) -> None: + """Load units from databases. + + Fill self with unit infomration. + """ UnitsManager._UNITS.clear() d = {} + with open(UnitsManager._DEFAULT_DATABASE, "r") as fin: d.update(json.load(fin)) + try: with open(UnitsManager._USER_DATABASE, "r") as fin: d.update(json.load(fin)) - except Exception: + + except FileNotFoundError: self.save() + finally: - for uname, udict in list(d.items()): + for uname, udict in d.items(): factor = udict.get("factor", 1.0) dim = udict.get("dimension", [0, 0, 0, 0, 0, 0, 0, 0, 0]) UnitsManager._UNITS[uname] = _Unit(uname, factor, *dim) def save(self): + """Write self to custom user database.""" with open(UnitsManager._USER_DATABASE, "w") as fout: json.dump(UnitsManager._UNITS, fout, indent=4, cls=UnitsManagerEncoder) @property def units(self): + """Direct access to unit database.""" return UnitsManager._UNITS @units.setter @@ -735,6 +1069,23 @@ def units(self, units): UnitsManager._UNITS = units +class UnitsManagerEncoder(json.JSONEncoder): + """Custom encoder for writing units.""" + + @singledispatchmethod + def default(self, obj): + return json.JSONEncoder.default(self, obj) + + @default.register(UnitsManager) + def _(self, obj): + return {k: {"factor": v.factor, "dimension": v.dimension} for k, v in obj.units} + + @default.register(_Unit) + def _(self, obj): + return {"factor": obj.factor, "dimension": obj.dimension} + + +#: Set of units considered directly or indirectly equivalent. _EQUIVALENCES = defaultdict(dict) @@ -743,7 +1094,34 @@ def add_equivalence(dim1, dim2, factor): _EQUIVALENCES[dim2][dim1] = 1.0 / factor -def measure(val, iunit="au", ounit="", equivalent=False): +def measure( + val: float, iunit: str = "au", ounit: str = "", *, equivalent: bool = False +) -> _Unit: + """Create a unit bearing object. + + Parses i/ounits and returns the relevant data object. + + Parameters + ---------- + val : float + Value for unit. + iunit : str + Input unit. + ounit : str + Desired output unit. + equivalent : bool + Whether the unit is to be considered "equivalent". + + Returns + ------- + _Unit + Desired unit. + + Examples + -------- + >>> print(measure(1., 'ang')) + 1 ang + """ if iunit: unit = _str_to_unit(iunit) unit *= val