Skip to content

Commit

Permalink
Add environ module
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 20, 2025
1 parent ce19440 commit 1600162
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 87 deletions.
2 changes: 2 additions & 0 deletions brainunit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import _matplotlib_compat
from . import autograd
from . import constants
from . import environ
from . import fft
from . import lax
from . import linalg
Expand All @@ -38,6 +39,7 @@
'math',
'linalg',
'autograd',
'environ',
'fft',
'constants',
'sparse'
Expand Down
96 changes: 9 additions & 87 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from jax.interpreters.partial_eval import DynamicJaxprTracer
from jax.tree_util import register_pytree_node_class

from .environ import (get_compute_mode,
SI_MODE,
NON_SI_MODE)
from ._misc import set_module_as
from ._sparse_base import SparseMatrix

Expand Down Expand Up @@ -65,7 +68,6 @@

# advanced functions
'get_or_create_dimension',
'convert_in_si',
]

StaticScalar = Union[
Expand Down Expand Up @@ -1233,6 +1235,7 @@ def _assert_same_base(u1, u2):
f"But we got {u1.base} != {u1.base}.")


# TODO: Cannot find compound standard unit
def _find_standard_unit(dim: Dimension, base, scale, factor) -> Tuple[Optional[str], bool, bool]:
"""
Find a standard unit for the given dimension, base, scale, and factor.
Expand Down Expand Up @@ -2199,6 +2202,11 @@ def __init__(
# dimension
self._unit = unit

if get_compute_mode() == SI_MODE:
self._mantissa = self._mantissa * self._unit.factor
self._unit = self._unit.factorless()


@property
def at(self):
"""
Expand Down Expand Up @@ -4944,89 +4952,3 @@ def _assign_unit(f, val, unit):

def _is_quantity(x):
return isinstance(x, Quantity)


def _convert_in_si(x):
"""
Convert a Quantity to a Quantity in SI units.
"""
if isinstance(x, Quantity) or isinstance(x, Unit):
return x.factorless()
return x


def convert_in_si():
"""
Convert all the local variables in SI units.
This function traverses the local variables in the calling scope and converts all `Quantity`
instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents.
The conversion is performed by calling the `factorless()` method on each `Quantity` instance,
which convert the unit and returns the quantities in SI units.
Notes:
- This function modifies the local variables in the calling scope.
- Only `Quantity` instances are affected; other types of variables remain unchanged.
- If a `Quantity` instance is nested within a list, tuple, or dictionary, it will be
recursively converted to its SI unit equivalent.
Examples:
>>> import brainunit as u
>>> time1 = 1 * u.second
>>> time2 = 1 * u.minute
>>> time3 = time1 + time2
>>> time4 = time2 + time1
>>> time3
61. * second
>>> time4
1.0166667 * minute
>>> u.convert_in_si() # Convert all local variables to SI units
>>> time3 = time1 + time2
>>> time4 = time2 + time1
>>> time3
61. * second
>>> time4
6.1 * dasecond
>>> length1 = 1 * u.inch
>>> result1 = time1 * length1
>>> result2 = u.math.multiply(time1, length1)
>>> result1
1. * second * inch
>>> result2
1. * second * inch
>>> u.convert_in_si() # Convert all local variables to SI units
>>> result1 = time1 * length1
>>> result2 = u.math.multiply(time1, length1)
>>> result1
0.0254 * second * meter
>>> result2
0.0254 * second * meter
>>> dict1 = {
... 'time1': 1 * u.second,
... 'time2': 1 * u.minute,
... 'length1': 1 * u.inch,
...}
>>> u.convert_in_si() # Convert all local variables to SI units
>>> dict1
{'length1': 0.0254 * meter, 'time1': 1 * second, 'time2': 6. * dasecond}
Raises:
None: This function does not raise any exceptions explicitly, but may propagate
exceptions from `factorless()` or `jax.tree.map()` if they fail.
See Also:
- `Quantity.factorless()`: Method used to convert `Quantity` instances to SI units.
"""
frame = inspect.currentframe()
try:
caller_frame = frame.f_back
caller_globals = caller_frame.f_globals

for key, val in list(caller_globals.items()):
caller_globals[key] = jax.tree.map(_convert_in_si, val, is_leaf=lambda x: _is_quantity(x))
finally:
del frame
216 changes: 216 additions & 0 deletions brainunit/environ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
import os
import re
import threading
from collections import defaultdict
from typing import Any, Callable, Dict, Hashable

__all__ = [
# functions for environment settings
'set', 'context', 'get', 'all',
# functions for getting default behaviors
'get_compute_mode',
# constants
'SI_MODE', 'NON_SI_MODE'
]

SI_MODE: str = 'si'
NON_SI_MODE: str = 'non_si'


@dataclasses.dataclass
class DefaultContext(threading.local):
# default environment settings
settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)
# current environment settings
contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list))
# environment functions
functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict)

DEFAULT = DefaultContext()
_NOT_PROVIDE = object()


@contextlib.contextmanager
def context(**kwargs):
r"""
Context-manager that sets a computing environment for brainunit.
For instance::
>>> import brainunit as u
>>> global_1 = 2 * u.kmh
>>> global_2 = 0
>>> def create_a(a):
... return a.mantissa * 2 * u.minute
>>> with u.environ.context(compute_mode='si'):
... a = create_a([1, 2, 3] * u.minute) # If input is [1, 2, 3] * u.second, the result would differ
... b = [4, 5, 6] * u.inch
... global_2 = (b / a) / global_1
"""
if 'compute_mode' in kwargs:
if kwargs['compute_mode'] == SI_MODE:
_convert_to_si_quantity(**kwargs)
else:
pass

try:
for k, v in kwargs.items():

# update the current environment
DEFAULT.contexts[k].append(v)

# restore the environment functions
if k in DEFAULT.functions:
DEFAULT.functions[k](v)

# yield the current all environment information
yield all()
finally:

for k, v in kwargs.items():

# restore the current environment
DEFAULT.contexts[k].pop()

# restore the environment functions
if k in DEFAULT.functions:
DEFAULT.functions[k](get(k))


def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None):
"""
Get one of the default computation environment.
Returns
-------
item: Any
The default computation environment.
"""
if key in DEFAULT.contexts:
if len(DEFAULT.contexts[key]) > 0:
return DEFAULT.contexts[key][-1]
if key in DEFAULT.settings:
return DEFAULT.settings[key]

if default is _NOT_PROVIDE:
if desc is not None:
raise KeyError(
f"'{key}' is not found in the context. \n"
f"You can set it by `brainstate.share.context({key}=value)` "
f"locally or `brainstate.share.set({key}=value)` globally. \n"
f"Description: {desc}"
)
else:
raise KeyError(
f"'{key}' is not found in the context. \n"
f"You can set it by `brainstate.share.context({key}=value)` "
f"locally or `brainstate.share.set({key}=value)` globally."
)
return default


def all() -> dict:
"""
Get all the current default computation environment.
Returns
-------
r: dict
The current default computation environment.
"""
r = dict()
for k, v in DEFAULT.contexts.items():
if v:
r[k] = v[-1]
for k, v in DEFAULT.settings.items():
if k not in r:
r[k] = v
return r


def get_compute_mode() -> str:
"""
Get the current compute mode.
Returns
-------
mode: str
The current compute mode.
"""
return get('compute_mode')

def set(
compute_mode: str = None,
**kwargs
):
"""
Set the global default computation environment.
Args:
compute_mode: str, optional
The default compute mode. Default is computing in 'si'.
"""
if compute_mode is not None:
assert compute_mode in ['si', 'non_si'], f"compute_mode must be 'si' or 'non_si'. Got: {compute_mode}"
kwargs['compute_mode'] = compute_mode

# set default environment
DEFAULT.settings.update(kwargs)

# update the environment functions
for k, v in kwargs.items():
if k in DEFAULT.functions:
DEFAULT.functions[k](v)

def _convert_to_si_quantity(**kwargs):
"""
Convert all the local variables in SI units.
Traverses the local variables in the calling scope and converts all `Quantity`
instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents.
The conversion is performed by calling the `factorless()` method on each `Quantity` instance,
which convert the unit and returns the quantities in SI units.
"""
set(compute_mode=kwargs['compute_mode'])
from ._base import Quantity, Unit
frame = inspect.currentframe().f_back.f_back.f_back
original = {k: v for k, v in frame.f_locals.items()
if isinstance(v, (Quantity, Unit))}

try:
# Convert to SI
for k, v in original.items():
frame.f_locals[k] = v.factorless()
yield
finally:
# Restore original values
for k, v in original.items():
frame.f_locals[k] = v

set(compute_mode=NON_SI_MODE)
Loading

0 comments on commit 1600162

Please sign in to comment.