BrainUnit provides physical units and a unit-aware mathematical system in JAX for brain dynamics modeling. It brings rigorous physical units into high-performance, AI-driven abstract numerical computing.
BrainUnit was initially designed to enable unit-aware computations in brain dynamics modeling (see our BDP ecosystem). However, its features and capabilities apply to general domains in scientific computing and AI4Science. Since February 2025, BrainUnit has been fully integrated into SAIUnit (the Unit system for Scientific AI).
The functionality is the same in both `brainunit` and `saiunit`: their functions and data structures are interoperable and share the same set of APIs, which eliminates any potential conflicts. This means that

```python
import brainunit as u
```

is equivalent to

```python
import saiunit as u
```
For users primarily engaged in general scientific computing, `saiunit` is likely the preferred choice. However, for those focused on brain modeling, we recommend `brainunit`, as it is more closely aligned with our specialized brain dynamics programming ecosystem.
The official documentation of BrainUnit is hosted on Read the Docs: https://brainunit.readthedocs.io
`brainunit` can be seamlessly integrated into every aspect of our brain dynamics programming ecosystem, such as the checkpointing of `braintools`, the event-driven operators in `brainevent`, the state-based JIT compilation in `brainstate`, online learning rules in `brainscale`, and more.

A quick example of this kind of integration:
```python
import braintools
import brainevent.nn
import brainstate
import brainunit as u


class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.n_exc = 3200
        self.n_inh = 800
        self.num = self.n_exc + self.n_inh

        # LIF neurons with a refractory period; every parameter carries a physical unit
        self.N = brainstate.nn.LIFRef(
            self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
            tau=20. * u.ms, tau_ref=5. * u.ms,
            V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
        )

        # Excitatory projection: event-driven sparse connectivity (brainevent),
        # conductance in mS, reversal potential in mV
        self.E = brainstate.nn.AlignPostProj(
            comm=brainevent.nn.FixedProb(self.n_exc, self.num, 0.02, 0.6 * u.mS),
            syn=brainstate.nn.Expon.desc(self.num, tau=5. * u.ms),
            out=brainstate.nn.COBA.desc(E=0. * u.mV),
            post=self.N
        )

        # Inhibitory projection
        self.I = brainstate.nn.AlignPostProj(
            comm=brainevent.nn.FixedProb(self.n_inh, self.num, 0.02, 6.7 * u.mS),
            syn=brainstate.nn.Expon.desc(self.num, tau=10. * u.ms),
            out=brainstate.nn.COBA.desc(E=-80. * u.mV),
            post=self.N
        )

    def update(self, t, inp):
        # Expose the current time t to the model through the brainstate environment
        with brainstate.environ.context(t=t):
            spk = self.N.get_spike() != 0.
            self.E(spk[:self.n_exc])   # spikes from excitatory neurons
            self.I(spk[self.n_exc:])   # spikes from inhibitory neurons
            self.N(inp)
            return self.N.get_spike()

    def save_checkpoint(self):
        # Save all model states (including unit-carrying ones) with braintools
        braintools.file.msgpack_save('states.msgpack', self.states())
```
You can install `brainunit` via pip:

```bash
pip install brainunit --upgrade
```
We are building the brain dynamics programming (BDP) ecosystem, and `brainunit` has been deeply integrated into it.