Klax

A lightweight machine learning package for computational mechanics built on JAX.


Warning

Klax is still in early development and will likely see significant API changes in the near future. Likewise, the documentation is still under heavy development.

Overview

Klax provides:

  • Specialized machine learning architectures: MLPs with customizable initialization, fully and partially input convex neural networks (ICNNs), matrix-valued neural networks (e.g., networks that output skew-symmetric matrices), and more.
  • Parameter constraints: Differentiable and non-differentiable parameter constraints, e.g., non-negativity and symmetry constraints.
  • Highly customizable training and logging utilities: Methods for calibrating arbitrary trainable PyTrees with custom loss functions, callbacks, and metrics logging.
  • Full JAX compatibility: Seamless integration with JAX's automatic differentiation and acceleration.
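To illustrate what "fully input convex" means, here is a minimal sketch of a fully input convex neural network written in plain JAX. It is not Klax's ICNN class; the functions `init_icnn` and `icnn_forward` and the dict-based parameter layout are hypothetical. Convexity in the input follows from two ingredients: a convex, non-decreasing activation (softplus) and non-negative hidden-to-hidden weights, enforced here with a ReLU on unconstrained raw weights.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, in_size, width, depth):
    """Hypothetical initializer returning a dict PyTree of ICNN parameters."""
    keys = jax.random.split(key, 2 * depth + 2)
    params = {"x_layers": [], "z_layers": []}
    # Passthrough layers acting directly on the input x (unconstrained weights).
    for i in range(depth + 1):
        out = width if i < depth else 1
        params["x_layers"].append(
            {"W": 0.1 * jax.random.normal(keys[i], (out, in_size)), "b": jnp.zeros(out)}
        )
    # Hidden-to-hidden layers; made non-negative in the forward pass.
    for i in range(depth):
        out = width if i < depth - 1 else 1
        params["z_layers"].append(
            {"W_raw": 0.1 * jax.random.normal(keys[depth + 1 + i], (out, width))}
        )
    return params


def icnn_forward(params, x):
    """Scalar output that is convex in x by construction."""
    act = jax.nn.softplus  # convex and non-decreasing
    first = params["x_layers"][0]
    z = act(first["W"] @ x + first["b"])  # convex: softplus of an affine map
    n = len(params["z_layers"])
    for k, z_layer in enumerate(params["z_layers"]):
        x_layer = params["x_layers"][k + 1]
        W_z = jax.nn.relu(z_layer["W_raw"])  # ReLU-based non-negativity constraint
        pre = W_z @ z + x_layer["W"] @ x + x_layer["b"]
        z = pre if k == n - 1 else act(pre)  # no activation on the output layer
    return z[0]
```

Batched evaluation works through `jax.vmap(lambda x: icnn_forward(params, x))`. Note that the ReLU constraint gives zero gradients for clamped raw weights, which is exactly the kind of non-differentiable constraint mentioned above.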

Klax is built around the highly successful JAX, Equinox, and Optax projects and is designed to be minimally intrusive. All models inherit directly from equinox.Module without additional abstraction layers, ensuring full compatibility with the ecosystem.

The constraint system is derived from Paramax's paramax.AbstractUnwrappable, extending it to support non-differentiable/zero-gradient parameter constraints such as ReLU-based non-negativity constraints.
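The idea behind an "unwrappable" constraint can be sketched in plain JAX: a parameter is stored in an unconstrained form inside a wrapper, and the constrained value is only produced when the PyTree is unwrapped before use. The sketch below is a simplified stand-in, not Klax's or Paramax's implementation; the `NonNegative` class and `unwrap` helper are hypothetical. It uses a ReLU, so negative raw entries are clamped to zero and receive a zero gradient.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class NonNegative:
    """Hypothetical wrapper: stores a raw parameter, returns relu(raw) on unwrap."""

    def __init__(self, raw):
        self.raw = raw

    def unwrap(self):
        # ReLU clamps negative entries to zero; their gradient is zero.
        return jax.nn.relu(self.raw)

    def tree_flatten(self):
        return (self.raw,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def unwrap(tree):
    """Replace every NonNegative wrapper in a PyTree by its constrained value."""
    return jax.tree_util.tree_map(
        lambda node: node.unwrap() if isinstance(node, NonNegative) else node,
        tree,
        is_leaf=lambda node: isinstance(node, NonNegative),
    )


params = {"w": NonNegative(jnp.array([-1.0, 2.0])), "b": jnp.array(0.5)}


def loss(params):
    p = unwrap(params)  # constraints are applied inside the differentiated function
    return jnp.sum(p["w"] ** 2) + p["b"]


# Gradients have the same PyTree structure, so grads["w"] is a NonNegative
# holding the gradient with respect to the raw (unconstrained) values.
grads = jax.grad(loss)(params)
```

Because the clamped entry receives no gradient, a gradient-based optimizer will not move it back into the feasible region on its own. This is the practical difference from a smooth reparametrization such as softplus.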

The training utilities (klax.fit, klax.Loss, klax.Callback) are designed to operate on arbitrarily structured model and data PyTrees, fully utilizing the flexibility of JAX and Equinox. While they cover the most common machine learning use cases, as well as our own specialized requirements, they remain entirely optional: the machine learning architectures implemented in Klax work seamlessly in any JAX-compatible training loop.
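As an illustration of a "JAX-compatible training loop" that does not use klax.fit, the sketch below trains a small model whose parameters and data are both nested dict PyTrees. The model, the data layout, and the learning rate are hypothetical choices for illustration. Plain gradient descent via `jax.tree_util.tree_map` keeps the example dependency-free; in practice an Optax optimizer would take its place.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # hypothetical step size


def model(params, x):
    """Hypothetical model: any function of a parameter PyTree works."""
    hidden = jnp.tanh(params["layer1"]["W"] @ x + params["layer1"]["b"])
    return params["layer2"]["W"] @ hidden + params["layer2"]["b"]


def loss_fn(params, data):
    """Mean squared error over a batch stored as a dict PyTree."""
    preds = jax.vmap(model, in_axes=(None, 0))(params, data["x"])
    return jnp.mean((preds - data["y"]) ** 2)


@jax.jit
def step(params, data):
    loss, grads = jax.value_and_grad(loss_fn)(params, data)
    # Gradient descent applied leaf by leaf across the whole parameter PyTree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "layer1": {"W": 0.5 * jax.random.normal(key1, (16, 2)), "b": jnp.zeros(16)},
    "layer2": {"W": 0.5 * jax.random.normal(key2, (1, 16)), "b": jnp.zeros(1)},
}
x = jax.random.uniform(key3, (64, 2), minval=-1.0, maxval=1.0)
data = {"x": x, "y": jnp.sum(x**2, axis=1, keepdims=True)}  # toy target

initial_loss = loss_fn(params, data)
for _ in range(500):
    params, loss = step(params, data)
```

With Equinox models, which may contain non-array fields, the filtered transformations such as `eqx.filter_value_and_grad` and `eqx.filter_jit` typically play the role of `jax.value_and_grad` and `jax.jit` here.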

Currently, Klax's training utilities are built around Optax, but different optimization libraries could be supported in the future if desired.

Installation

Klax can be installed via pip using

pip install klax

To add the latest release to a Python uv project, run

uv add klax

or directly install the main branch via

uv add "klax @ git+https://github.com/Drenderer/klax.git@main"

Getting Started

If you're new to the JAX ecosystem, we recommend looking at the JAX Quickstart guide, which provides a concise overview of JAX's core functionality. You may also want to consult the Equinox documentation, since all Klax models are built on Equinox.

Finally, check out our examples section.

Citation

Acknowledgement

Klax is built on top of several powerful frameworks:

JAX - For automatic differentiation and acceleration
Equinox - For neural network primitives
Optax - For optimization utilities
Paramax - For constraints (we decided to embed Paramax directly into Klax due to the need for non-differentiable constraints).