Skip to content

Commit

Permalink
solve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 2, 2025
2 parents 3217bf8 + f9f03d4 commit 145e670
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 10 deletions.
67 changes: 65 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


# Physical units and unit-aware mathematical system in JAX
# ``BrainUnit``: physical units and unit-aware mathematical system for brain dynamics and AI4Science

<p align="center">
<img alt="Header image of brainunit." src="https://github.com/chaobrain/brainunit/blob/main/docs/_static/brainunit.png" width=50%>
Expand All @@ -20,7 +20,70 @@
</p>


[``brainunit``](https://github.com/chaor/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science
## Motivation


[``brainunit``](https://github.com/chaobrain/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science.

It is initially designed to enable unit-aware computations in brain dynamics modeling (see our [ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/)).

However, its features and capacities can be applied to general domains in scientific computing and AI for science.
We also provide ample examples and tutorials to help users integrate ``brainunit`` into their projects
(see [Unit-aware computation ecosystem](#unit-aware-computation-ecosystem) in the below).


## Features


The uniqueness of ``Brainunit`` lies in that it brings physical units handling and AI-driven computation together in a seamless way:

- It provides over 2,000 commonly used physical units and constants.
- It implements over 500 unit-aware mathematical functions.
- Its physical units and unit-aware functions are fully compatible with JAX, including autograd, JIT, vecterization, parallelization, and others.


```mermaid
graph TD
A[BrainUnit] --> B[Physical Units]
A --> C[Mathematical Functions]
A --> D[JAX Integration]
B --> B1[2000+ Units]
B --> B2[Physical Constants]
C --> C1[500+ Unit-aware Functions]
D --> D1[Autograd]
D --> D2[JIT Compilation]
D --> D3[Vectorization]
D --> D4[Parallelization]
```

A quick example:

```python

import brainunit as u

# Define a physical quantity
x = 3.0 * u.meter
x
# [out] 3. * meter

# autograd
f = lambda x: x ** 3
u.autograd.grad(f)(x)
# [out] 27. * meter2


# JIT
import jax
jax.jit(f)(x)
# [out] 27. * klitre

# vmap
jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
# [out] ArrayImpl([ 0., 1., 8., 27., 64., 125., 216., 343., 512., 729.],
# dtype=float32) * mvolt3
```



## Installation
Expand Down
11 changes: 3 additions & 8 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,20 +776,15 @@ def test_numpy_functions_indices(self):
units = [volt, second, siemens, mV, kHz]

# numpy functions
keep_dim_funcs = [
np.argmin,
np.argmax,
# np.argsort, # TODO: after upgrading jax 0.5.0, argsort will raise an error
np.nonzero
]
indice_funcs = [u.math.argmin, u.math.argmax, u.math.argsort, u.math.nonzero]

for value, unit in itertools.product(values, units):
q_ar = value * unit
for func in keep_dim_funcs:
for func in indice_funcs:
test_ar = func(q_ar)
# Compare it to the result on the same value without units
comparison_ar = func(value)
test_ar = np.asarray(test_ar)
test_ar = u.math.asarray(test_ar)
comparison_ar = np.asarray(comparison_ar)
assert_equal(
test_ar,
Expand Down

0 comments on commit 145e670

Please sign in to comment.