Skip to content

Commit

Permalink
update README to reflect its general applicability (#97)
Browse files Browse the repository at this point in the history
* update README

* fix

* fix tests
  • Loading branch information
chaoming0625 authored Feb 2, 2025
1 parent 862210b commit 0d01b3a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
53 changes: 51 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


# Physical units and unit-aware mathematical system in JAX
# ``BrainUnit``: physical units and unit-aware mathematical system for brain dynamics and AI4Science

<p align="center">
<img alt="Header image of brainunit." src="https://github.com/chaobrain/brainunit/blob/main/docs/_static/brainunit.png" width=50%>
Expand All @@ -20,7 +20,56 @@
</p>


[``brainunit``](https://github.com/chaor/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science
## Motivation


[``brainunit``](https://github.com/chaobrain/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science.

It is initially designed to enable unit-aware computations in brain dynamics modeling (see our [ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/)).

However, its features and capacities can be applied to general domains for scientific computing and AI for science.
We also provide ample examples and tutorials to help users integrate ``brainunit`` into their projects
(see [Unit-aware computation ecosystem](#unit-aware-computation-ecosystem) in the below).


## Features


The uniqueness of ``Brainunit`` lies in that it brings physical units handling and AI-driven computation together in a seamless way:

- It provides over 2,000 commonly used physical units and constants.
- It implements over 500 unit-aware mathematical functions.
- Its physical units and unit-aware functions are fully compatible with JAX, including autograd, JIT, vecterization, parallelization, and others.


A quick example:

```python

import brainunit as u

# Define a physical quantity
x = 3.0 * u.meter
x
# [out] 3. * meter

# autograd
f = lambda x: x ** 3
u.autograd.grad(f)(x)
# [out] 27. * meter2


# JIT
import jax
jax.jit(f)(x)
# [out] 27. * klitre

# vmap
jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
# [out] ArrayImpl([ 0., 1., 8., 27., 64., 125., 216., 343., 512., 729.],
# dtype=float32) * mvolt3
```



## Installation
Expand Down
9 changes: 5 additions & 4 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.ma.core import indices
from numpy.testing import assert_equal

import brainunit as u
Expand Down Expand Up @@ -768,23 +769,23 @@ def test_deepcopy(self):
assert d_copy["x"] == 2 * second
assert d["x"] == 1 * second

def test_numpy_functions_indices(self):
def test_indices_functions(self):
"""
Check numpy functions that return indices.
"""
values = [np.array([-4, 3, -2, 1, 0]), np.ones((3, 3)), np.array([17])]
units = [volt, second, siemens, mV, kHz]

# numpy functions
keep_dim_funcs = [np.argmin, np.argmax, np.argsort, np.nonzero]
indice_funcs = [u.math.argmin, u.math.argmax, u.math.argsort, u.math.nonzero]

for value, unit in itertools.product(values, units):
q_ar = value * unit
for func in keep_dim_funcs:
for func in indice_funcs:
test_ar = func(q_ar)
# Compare it to the result on the same value without units
comparison_ar = func(value)
test_ar = np.asarray(test_ar)
test_ar = u.math.asarray(test_ar)
comparison_ar = np.asarray(comparison_ar)
assert_equal(
test_ar,
Expand Down

0 comments on commit 0d01b3a

Please sign in to comment.