
Releases: patrick-kidger/equinox

Equinox v0.12.1

27 Mar 22:12

Hotfix to work around JAX bug jax-ml/jax#27545 (#988)

Equinox v0.12.0 release notes here: https://github.com/patrick-kidger/equinox/releases/tag/v0.12.0

Full Changelog: v0.12.0...v0.12.1

Equinox v0.12.0

27 Mar 15:03
  • BREAKING CHANGE: eqx.field(converter=...) now runs after __post_init__ instead of before. This dramatically simplifies some internals and improves compatibility with other libraries. (#969, #975)
  • Fixed warning when using eqx.filter_closure_convert with JAX 0.5.3. (#979, #981)
  • Reduced overhead of eqx.filter_jit. (Thanks @ZagButNoZig! #973, #980, #983)
  • Shiny new documentation!

New Contributors

Full Changelog: v0.11.12...v0.12.0

Equinox v0.11.12

25 Feb 18:57

This is primarily a compatibility release.

New Contributors

Full Changelog: v0.11.11...v0.11.12

Equinox v0.11.11

24 Dec 12:08

JAX 0.4.38 moved a number of APIs to new locations, deprecating the old ones (e.g. jax.core.Jaxpr -> jax.extend.core.Jaxpr). With this release we have updated to the new locations, so Equinox is once again warning-free under this JAX release! (Thanks @FFroehlich @DrJessop, #913, #915, #917)
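For downstream code facing the same move, one common pattern is to try the new location first and fall back to the old one. This is a generic sketch, not Equinox's own code; it assumes only that `Jaxpr` lives in `jax.extend.core` on newer JAX and in `jax.core` on older JAX, as the note above describes.

```python
import jax

try:
    # Newer JAX: the public location after the move.
    from jax.extend.core import Jaxpr
except ImportError:
    # Older JAX: only the original location exists.
    from jax.core import Jaxpr

# make_jaxpr returns a ClosedJaxpr whose `.jaxpr` attribute is a Jaxpr.
closed = jax.make_jaxpr(lambda x: x * 2.0)(1.0)
print(isinstance(closed.jaxpr, Jaxpr))  # True
```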

New Contributors

Full Changelog: v0.11.10...v0.11.11

Equinox v0.11.10

08 Dec 02:44

This is a JAX 0.4.36 compatibility release.

In this release, JAX changed how custom primitive rules are called: they are now always called, rather than only when the data requires them to be. That required some updates in Equinox to avoid crashes in the downstream ecosystem. (patrick-kidger/diffrax#532, jax-ml/jax#25289 + links therein.)

Full Changelog: v0.11.9...v0.11.10

Equinox v0.11.9

24 Nov 15:01

This is an important bugfix release.

  • Fixed eqx.filter_vmap with out_axes values other than 0 or 1 producing outputs with the wrong axis order. (Thanks @remifan! #900, #901)

Full Changelog: v0.11.8...v0.11.9

Equinox v0.11.8

18 Oct 17:19
689c35a

The main change in this release is compatibility with JAX 0.4.34, which introduced breaking changes. (#871)

Bugfixes

  • Accessing the concrete implementation of an abstract class attribute within __init_subclass__ should no longer crash. (Plus probably better-behaved __init_subclass__ overall.)

Miscellaneous

  • JAX 0.4.33 introduced a change that broke the nicely formatted error messages from eqx.error_if. With this release, those nice error messages are back!
  • eqx.nn.StateIndex can now be passed through jax.jit (and not just eqx.filter_jit). (Thanks @NeilGirdhar! #843)
  • Normalization layers now upcast to at least 32-bit precision. (Thanks @AakashKumarNain! #876)
  • Poetry has a bug in its interpretation of ~= version constraints. We now work around that for better compatibility with certain kinds of Poetry installations. (Thanks @norpadon! #878)

Documentation

New Contributors

Full Changelog: v0.11.7...v0.11.8

Equinox v0.11.7

18 Sep 17:10
31b554f

Quick release. JAX 0.4.32 / 0.4.33 just introduced a breaking change; this release ensures Equinox is compatible with it. (#856)

Full Changelog: v0.11.6...v0.11.7

Equinox v0.11.6

14 Sep 09:34

This is primarily a bug fix release.

  • Runtime error messages (those from eqx.error_if, in particular when wrapped with eqx.filter_jit) should now be compatible with PyCharm's debugger, and with certain multithreaded contexts. (Thanks @adam-hartshorne, @dlwh! #828, #844, #849)

  • Marking a jax.Array or np.ndarray as an eqx.field(static=True) will now raise a warning. This was technically okay in certain very narrow contexts (e.g. to smuggle an array into a JIT'd region without it being traced), but in practice it was nearly always a common new-user footgun. (Thanks @lockwo! #800)

  • eqx.tree_at has improved support for replacing empty tuples. (Thanks @danielward27! #818, #819)

  • eqx.nn.RotaryEmbedding no longer promotes input dtypes to at least float32. (Thanks @knyazer! #836)

  • Mypy now understands that eqx.Modules are dataclasses. (Pyright always did, but mypy needed a slightly different approach to appreciate this fact.) (Thanks @NeilGirdhar! #822)

  • Multiple eqx.Modules participating in co-operative multiple inheritance, with some of them overriding the __post_init__s of others, should now follow their expected resolution order. (Triggering the old behaviour seems to have required at least 5 classes inheriting from each other.) (Thanks @NeilGirdhar! #832, #834)

  • We now have an .editorconfig file. (Thanks @NeilGirdhar! #821)

  • Doc improvements. (Thanks @garymm, @ColCarroll! #804, #805)

New Contributors

Full Changelog: v0.11.5...v0.11.6

Equinox v0.11.5

18 Aug 19:11

JAX compatibility

Recent versions of JAX (0.4.28+) have made some changes to:

  • Hashing of tracers;
  • Tree-map'ing over Nones;
  • Callbacks;
  • Pretty-printing.

With this update, we should now be compatible with both old and new versions of JAX, fixing some new crashes as well as some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)

Better errors

  • The error messages from eqx.error_if are now substantially more informative: they include traceback information (including the stack) and mention the availability of the EQX_ON_ERROR variable. We also do a much better job of hiding the large, unhelpful printouts that XLA gives by default. (#785, #803)

  • The default value of EQX_ON_ERROR_BREAKPOINT_FRAMES is now 1. (#777) The impact of this is that using eqx.error_if alongside EQX_ON_ERROR=breakpoint will now:

    • reliably always open a debugger, rather than sometimes crashing at trace-time due to upstream JAX bug #16732.
    • however, by default the debugger will no longer include any additional stack frames above it (accessed via u).
    • much of the above is now explained in a printed-out informative message prior to the debugger opening.

Bugfixes

  • eqx.filter_{jacfwd, jacrev} now apply filtering only to their inputs, not to their outputs. Previously, filtering the outputs was problematic, as there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (#734, thanks @lockwo!)

  • eqx.tree_at can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)

  • eqx.filter_custom_jvp no longer raises a trace-time crash in some scenarios in which its **kwargs were erroneously counted as having tangents. (#745 (comment), #749)

  • Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using optimistix.BFGS around diffrax.diffeqsolve. (#777)

  • Fixed a trace-time crash when:

    • using a checkpointed while loop...
    • ...with a body function that has a closed-over tracer...
    • ...and that closed-over tracer is differentiated...
    • ...and there are no other closed-over tracers that are differentiated...
    • ...and the dependency on that tracer is only linear.
    • (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)
  • Fixed a trace-time crash when composing the grad of vmap of lineax.linear_solve. (patrick-kidger/lineax#101, #795, thanks @rhacking!)

  • eqx.nn.RMSNorm now uses at least 32-bit precision for numerical stability (#723, thanks @AakashKumarNain!)

New features

Other changes

New Contributors

Full Changelog: v0.11.4...v0.11.5