From 2364b21c3075601440c4dfd11a90eb19d193c3e3 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 25 Dec 2024 11:39:26 +0800 Subject: [PATCH] Remove `automatically_register_units` (#85) * update unit repr * remove `automatically_register_units` * update ignore --- .gitignore | 2 ++ brainunit/__init__.py | 2 +- brainunit/_base.py | 20 ++++++++++++++++---- brainunit/_unit_common.py | 2 -- brainunit/sparse/_csr.py | 2 +- brainunit/sparse/_csr_test.py | 23 +++++++++++++++++++++++ dev/units_template.py | 1 - 7 files changed, 43 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 99a9d53..11a6355 100644 --- a/.gitignore +++ b/.gitignore @@ -225,3 +225,5 @@ cython_debug/ /docs/apis/brainunit.linalg.rst /docs/apis/brainunit.math.rst /docs/apis/changelog.md +/dist-hist/ +/dist-hist/ diff --git a/brainunit/__init__.py b/brainunit/__init__.py index 5cea734..f1520b5 100644 --- a/brainunit/__init__.py +++ b/brainunit/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -__version__ = "0.0.3" +__version__ = "0.0.4" from . import _matplotlib_compat from . import autograd diff --git a/brainunit/_base.py b/brainunit/_base.py index ea4ba90..e5f33d5 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -1760,9 +1760,15 @@ def __repr__(self) -> str: return f'Unit({self.base}^{self.scale})' else: if self.factor == 1.: - return f'{self.base}^{self.scale} * {self.name}' + if self.scale == 0: + return f'{self.name}' + else: + return f'{self.base}^{self.scale} * {self.name}' else: - return f'{self.factor} * {self.base}^{self.scale} * {self.name}' + if self.scale == 0: + return f'{self.factor} * {self.name}' + else: + return f'{self.factor} * {self.base}^{self.scale} * {self.name}' def __str__(self) -> str: if self.is_fullname: @@ -1771,9 +1777,15 @@ def __str__(self) -> str: return f'Unit({self.base}^{self.scale})' else: if self.factor == 1.: - return f'{self.base}^{self.scale} * {self.dispname}' + if self.scale == 0: + return f'{self.dispname}' + else: + return f'{self.base}^{self.scale} * {self.dispname}' else: - return f'{self.factor} * {self.base}^{self.scale} * {self.dispname}' + if self.scale == 0: + return f'{self.factor} * {self.dispname}' + else: + return f'{self.factor} * {self.base}^{self.scale} * {self.dispname}' def __mul__(self, other) -> 'Unit' | Quantity: # self * other diff --git a/brainunit/_unit_common.py b/brainunit/_unit_common.py index 61c8e93..ae12394 100644 --- a/brainunit/_unit_common.py +++ b/brainunit/_unit_common.py @@ -2097,8 +2097,6 @@ ] -Unit.automatically_register_units = False - #### FUNDAMENTAL UNITS metre = Unit.create(get_or_create_dimension(m=1), "metre", "m") meter = Unit.create(get_or_create_dimension(m=1), "meter", "m") diff --git a/brainunit/sparse/_csr.py b/brainunit/sparse/_csr.py index 94316e1..2751410 100644 --- a/brainunit/sparse/_csr.py +++ b/brainunit/sparse/_csr.py @@ -101,7 +101,7 @@ def with_data(self, data: jax.Array | Quantity) -> CSR: assert data.shape == self.data.shape assert data.dtype == self.data.dtype assert get_unit(data) == get_unit(self.data) - return CSR((data, self.indices, self.indptr), shape=self.shape) + return self.__class__((data, self.indices, self.indptr), shape=self.shape) def todense(self): return csr_todense(self) diff --git a/brainunit/sparse/_csr_test.py b/brainunit/sparse/_csr_test.py index 8fe144a..67f56b9 100644 --- a/brainunit/sparse/_csr_test.py +++ b/brainunit/sparse/_csr_test.py @@ -52,6 +52,29 @@ def test_matvec(self): ) ) + def test_matvec_non_unit(self): + data = bst.random.rand(10, 20) + data = data * (data < 0.3) + + csr = u.sparse.CSR.fromdense(data) + + x = bst.random.random((10,)) + + self.assertTrue( + u.math.allclose( + x @ data, + x @ csr + ) + ) + + x = bst.random.random((20,)) + self.assertTrue( + u.math.allclose( + data @ x, + csr @ x + ) + ) + def test_matmul(self): for ux, uy in [ (u.ms, u.mV), diff --git a/dev/units_template.py b/dev/units_template.py index 23148e1..3643646 100644 --- a/dev/units_template.py +++ b/dev/units_template.py @@ -25,7 +25,6 @@ {all} -Unit.automatically_register_units = False #### FUNDAMENTAL UNITS metre = Unit.create(get_or_create_dimension(m=1), "metre", "m")