Releases: patrick-kidger/equinox
Equinox v0.12.1
Hotfix to work around JAX bug jax-ml/jax#27545 (#988)
Equinox v0.12.0 release notes here: https://github.com/patrick-kidger/equinox/releases/tag/v0.12.0
Full Changelog: v0.12.0...v0.12.1
Equinox v0.12.0
- BREAKING CHANGE: `eqx.field(converter=...)` now runs after `__post_init__` instead of before. This dramatically simplifies some internals and improves compatibility with other libraries. (#969, #975)
- Fixed a warning when using `eqx.filter_closure_convert` with JAX 0.5.3. (#979, #981)
- Reduced the overhead of `eqx.filter_jit`. (Thanks @ZagButNoZig! #973, #980, #983)
- Shiny new documentation!
New Contributors
- @ZagButNoZig made their first contribution in #980
Full Changelog: v0.11.12...v0.12.0
Equinox v0.11.12
This is primarily a compatibility release.
- Fixes for compatibility with JAX 0.5.1 (#959, #960).
- Fixes for compatibility with pyright 1.1.394 (#956, #960).
- `eqx.nn.Linear(0, ...)` no longer crashes; zero-size tensors are now propagated through it. (Thanks @aseyboldt! #950)
- Pretty-printing (`eqx.tree_pformat` and `eqx.tree_pprint`) now uses the new Wadler-Lindig pretty-printing library. (#924)
- Many doc improvements. (Thanks @matthewfeickert, @TugdualKerjan, @struan-robertson, @danielward27! #930, #937, #941, #950)
New Contributors
- @matthewfeickert made their first contribution in #930
- @struan-robertson made their first contribution in #941
- @aseyboldt made their first contribution in #950
- @TugdualKerjan made their first contribution in #947
Full Changelog: v0.11.11...v0.11.12
Equinox v0.11.11
JAX 0.4.38 moved a number of APIs, with a deprecation warning at the old locations, e.g. `jax.core.Jaxpr` -> `jax.extend.core.Jaxpr`. With this release we've updated to the new locations and are back to being warning-free under this JAX release! (Thanks @FFroehlich @DrJessop, #913, #915, #917)
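Downstream code that wants to support JAX versions on both sides of this move can use an import fallback. The sketch below is a general pattern, not necessarily how Equinox itself handles it; the function `f` is an arbitrary example used only to produce a jaxpr.

```python
import jax
import jax.numpy as jnp

# Prefer the new location; fall back for JAX versions that predate it.
try:
    from jax.extend.core import Jaxpr
except ImportError:
    from jax.core import Jaxpr


def f(x):
    return jnp.sin(x) * 2.0


# `make_jaxpr` returns a ClosedJaxpr; its `.jaxpr` attribute is the Jaxpr itself.
closed = jax.make_jaxpr(f)(1.0)
print(isinstance(closed.jaxpr, Jaxpr))
```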
New Contributors
Full Changelog: v0.11.10...v0.11.11
Equinox v0.11.10
This is a JAX 0.4.36 compatibility release.
In this release, JAX changed how custom primitive rules are called: they are now always called, instead of only when the data requires them. This required some updates in Equinox to avoid crashes in the downstream ecosystem. (patrick-kidger/diffrax#532, jax-ml/jax#25289 + links therein.)
Full Changelog: v0.11.9...v0.11.10
Equinox v0.11.9
This is an important bugfix release.
- Fixed `eqx.filter_vmap` with `out_axes` other than 0 or 1 producing outputs with the wrong axis order. (Thanks @remifan! #900, #901)
Full Changelog: v0.11.8...v0.11.9
Equinox v0.11.8
The main change in this release is JAX 0.4.34 compatibility: JAX introduced breaking changes in that release, and Equinox is now compatible with them. (#871)
Bugfixes
- Accessing the concrete implementation of an abstract class attribute within `__init_subclass__` should no longer crash. (Plus `__init_subclass__` should probably be better-behaved overall.)
Miscellaneous
- JAX 0.4.33 introduced a change that broke the nice display of error messages from `eqx.error_if`. With this release, we are back to having nice error messages again!
- `eqx.nn.StateIndex` can now be passed through `jax.jit` (and not just `eqx.filter_jit`). (Thanks @NeilGirdhar! #843)
- Normalization layers now upcast to at least 32-bit precision. (Thanks @AakashKumarNain! #876)
- Poetry has a bug in its interpretation of `~=` version constraints. We now work around that for better compatibility with certain kinds of Poetry installations. (Thanks @norpadon! #878)
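For reference, PEP 440 defines the compatible-release operator so that `~=0.11.8` means `>=0.11.8, ==0.11.*`. The sketch below is a simplified, hypothetical checker (`matches_compatible_release` is not part of Equinox, Poetry, or any packaging library); it only handles purely numeric versions such as `0.11.8` and ignores pre-releases, post-releases, and other PEP 440 features.

```python
def _release(version: str) -> tuple[int, ...]:
    # Parse a purely numeric dotted version such as "0.11.8".
    return tuple(int(part) for part in version.split("."))


def matches_compatible_release(version: str, spec: str) -> bool:
    """Simplified check of `version` against `~=spec` (numeric versions only)."""
    base = _release(spec)
    if len(base) < 2:
        raise ValueError("'~=' requires at least two release components")
    v = _release(version)
    # Pad with zeros so that e.g. "1.0" compares equal to "1.0.0".
    width = max(len(v), len(base))
    v_padded = v + (0,) * (width - len(v))
    base_padded = base + (0,) * (width - len(base))
    # All components of `spec` except the last must match exactly.
    prefix = base[:-1]
    return v_padded >= base_padded and v_padded[: len(prefix)] == prefix


print(matches_compatible_release("0.11.12", "0.11.8"))  # True
print(matches_compatible_release("0.12.0", "0.11.8"))   # False
```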
Documentation
- Updated CNN example to work with recent JAX versions. (Thanks @pasq-cat! #880, #881)
- Updated `eqx.tree_at` documentation for clarity. (Thanks @jeertmans! #872, #874, #877)
New Contributors
Full Changelog: v0.11.7...v0.11.8
Equinox v0.11.7
Quick release. JAX 0.4.32 / 0.4.33 just introduced a breaking change; this release ensures Equinox is compatible with this. (#856)
Full Changelog: v0.11.6...v0.11.7
Equinox v0.11.6
This is primarily a bug fix release.
- Runtime error messages (those from `eqx.error_if`, in particular when wrapped with `eqx.filter_jit`) should now be compatible with PyCharm's debugger, and with certain multithreaded contexts. (Thanks @adam-hartshorne, @dlwh! #828, #844, #849)
- Marking a `jax.Array` or `np.ndarray` as an `eqx.field(static=True)` will now raise a warning. This was technically okay in certain very narrow contexts (e.g. to smuggle an array into a JIT'd region without it being traced), but in practice it was nearly always just a common new-user footgun. (Thanks @lockwo! #800)
- Using `eqx.tree_at` to replace empty tuples is improved. (Thanks @danielward27! #818, #819)
- `eqx.nn.RotaryEmbedding` no longer promotes input dtypes to at least float32. (Thanks @knyazer! #836)
- Mypy now understands that `eqx.Module`s are dataclasses. (Pyright always did, but mypy needed a slightly different approach to appreciate this fact.) (Thanks @NeilGirdhar! #822)
- Multiple `eqx.Module`s participating in co-operative multiple inheritance (at least 5 inheriting from each other seem to be necessary?), with some of them overriding the `__post_init__`s of others, should now follow their expected resolution order. (Thanks @NeilGirdhar! #832, #834)
- We now have a `.editorconfig` file. (Thanks @NeilGirdhar! #821)
- Doc improvements. (Thanks @garymm, @ColCarroll! #804, #805)
New Contributors
- @garymm made their first contribution in #804
- @ColCarroll made their first contribution in #805
- @NeilGirdhar made their first contribution in #823
Full Changelog: v0.11.5...v0.11.6
Equinox v0.11.5
JAX compatibility
Recent versions of JAX (0.4.28+) have made some changes to:
- Hashing of tracers;
- Tree-map'ing over Nones;
- Callbacks;
- Pretty-printing.
With this update, we should now be compatible with both old and new versions of JAX: this fixes both some new crashes, and some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)
Better errors
- The error messages from `eqx.error_if` are now substantially more informative: they include traceback information including the stack, and mention the availability of the `EQX_ON_ERROR` variable. We also do a much better job of hiding the large unhelpful printouts that XLA gives by default. (#785, #803)
- The default value of `EQX_ON_ERROR_BREAKPOINT_FRAMES` is now `1`. (#777) The impact of this is that using `eqx.error_if` alongside `EQX_ON_ERROR=breakpoint` will now:
  - reliably always open a debugger, rather than sometimes crashing at trace-time due to upstream JAX bug #16732;
  - however, by default the debugger will no longer include any additional stack frames above it (accessed via `u`);
  - much of the above is now explained in a printed-out informative message prior to the debugger opening.
Bugfixes
- `eqx.filter_{jacfwd, jacrev}` now only apply filtering to their inputs, not their outputs. Previously this was problematic, as there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (#734, thanks @lockwo!)
- `eqx.tree_at` can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)
- `eqx.filter_custom_jvp` no longer raises a trace-time crash in some scenarios in which its `**kwargs` were erroneously counted as having tangents. (#745 (comment), #749)
- Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using `optimistix.BFGS` around `diffrax.diffeqsolve`. (#777)
- Fixed a trace-time crash when:
  - using a checkpointed while loop...
  - ...with a body function that has a closed-over tracer...
  - ...and that closed-over tracer is differentiated...
  - ...and there are no other closed-over tracers that are differentiated...
  - ...and the dependency on that tracer is only linear.
  - (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)
- Fixed a trace-time crash when composing the grad of vmap of `lineax.linear_solve`. (patrick-kidger/lineax#101, #795, thanks @rhacking!)
- `eqx.nn.RMSNorm` now uses at least 32-bit precision for numerical stability. (#723, thanks @AakashKumarNain!)
New features
- `eqx.nn.{Linear,Conv,GRUCell,LSTMCell}` now support complex dtypes. (#765, thanks @ChenAo-Phys!)
- Added `eqx.nn.RotaryEmbedding(..., theta=...)`. (#735, thanks @Artur-Galstyan!)
Other changes
- Several doc fixes. (#708, #731, #733, #747, #750, #757 + several other PRs, thanks @Artur-Galstyan, @matteoguarrera, @lockwo, @nasyxx!)
- Several internal test fixes as downstream libraries have changed slightly. (#740, #742 + several other PRs, big thanks to @GaetanLepage for reporting many of these!)
- There is now a Mistral 7B implementation using JAX+Equinox available over in AakashKumarNain/mistral_jax!
New Contributors
- @nasyxx made their first contribution in #708
- @jakevdp made their first contribution in #724
- @matteoguarrera made their first contribution in #739
Full Changelog: v0.11.4...v0.11.5