
Unwrappables and Constraints

Basic classes and functions

klax.Unwrappable

An abstract class representing an unwrappable object.

Unwrappables replace PyTree nodes to apply custom behavior upon unwrapping. This class is a renamed copy of paramax.AbstractUnwrappable.

Note

Models containing Unwrappables need to be unwrapped or finalized before they are callable.

__init__()

Initialize self. See help(type(self)) for accurate signature.

klax.unwrap(tree)

Map across a PyTree and unwrap all klax.Unwrappable objects.

This leaves all other nodes unchanged. If nested, the innermost klax.Unwrappable is unwrapped first.

Example

Enforcing positivity.

>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.Parameterize(jnp.exp, jnp.zeros(3))
>>> klax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
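The mapping and the inside-out order can be illustrated with a minimal plain-JAX sketch that does not use klax. ToyParameterize and toy_unwrap are hypothetical stand-ins, and the real klax.unwrap may be implemented differently; the sketch only shows that non-wrapper nodes pass through unchanged and that nested wrappers are resolved from the innermost one outward.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyParameterize:
    """Hypothetical stand-in for klax.Parameterize (not the klax class)."""

    def __init__(self, fn, *args):
        self.fn = fn      # static callable, stored as aux data, not as a leaf
        self.args = args  # PyTree children: arrays or further wrappers

    def tree_flatten(self):
        return self.args, self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(fn, *children)


def is_wrapper(node):
    return isinstance(node, ToyParameterize)


def toy_unwrap(tree):
    """Unwrap every ToyParameterize in `tree`, innermost first."""

    def _unwrap(node):
        if not is_wrapper(node):
            return node  # all other nodes are left unchanged
        # Resolve wrappers nested inside the arguments before calling fn.
        inner_args = jax.tree_util.tree_map(_unwrap, node.args, is_leaf=is_wrapper)
        return node.fn(*inner_args)

    # is_leaf stops the traversal at wrappers so _unwrap sees them whole.
    return jax.tree_util.tree_map(_unwrap, tree, is_leaf=is_wrapper)


result = toy_unwrap(("abc", 1, ToyParameterize(jnp.exp, jnp.zeros(3))))

# Nested case: the inner wrapper (x + 1) is evaluated before the outer exp.
nested = ToyParameterize(jnp.exp, ToyParameterize(lambda x: x + 1.0, jnp.zeros(2)))
nested_result = toy_unwrap(nested)
```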

klax.contains_unwrappables(tree)

Check if a PyTree contains instances of klax.Unwrappable.

klax.Constraint(klax.Unwrappable)

An abstract constraint around a jax.Array.

A klax.Constraint is an extended version of klax.Unwrappable that marks an array in a PyTree as constrained. It provides the unwrap method known from klax.Unwrappable and adds an apply method for implementing constraints that are non-differentiable or could lead to vanishing gradients during optimization.

We intend the following usage of the unwrap and apply methods:

unwrap: Identical functionality to a klax.Unwrappable. Use this to implement constraints that are differentiable and should be applied within the training loop. For example, our implementation of klax.fit unwraps the model as part of the loss function, so the implementation of unwrap contributes to the gradients during training. A typical case is a positivity constraint that passes the array through jax.nn.softplus upon unwrapping.

apply: New functionality added with klax.Constraint. Use this to implement non-differentiable or zero-gradient constraints that should be applied after the parameter update and that modify the wrapped array without unwrapping it, such as clamping a parameter to a range of admissible values. Implementations of apply should return a modified copy of self.
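The division of labour between unwrap and apply can be sketched with a toy constraint in plain JAX. ToyClampedSoftplus is hypothetical and not part of klax: its unwrap uses jax.nn.softplus as in the positivity example, and its apply clamps the raw array, an operation whose gradient vanishes outside the admissible range and which therefore belongs after the update rather than inside the loss.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyClampedSoftplus:
    """Hypothetical constraint, not part of klax.

    unwrap: differentiable positivity via softplus, evaluated inside the loss.
    apply:  non-differentiable clamp of the raw array, run after the update.
    """

    def __init__(self, parameter, lower=-5.0, upper=5.0):
        self.parameter = parameter
        self.lower = lower
        self.upper = upper

    def unwrap(self):
        # Contributes to gradients: d softplus(x)/dx = sigmoid(x) > 0.
        return jax.nn.softplus(self.parameter)

    def apply(self):
        # Keeps the raw value out of the region where softplus has vanishing
        # gradients; returns a modified copy instead of mutating self.
        clamped = jnp.clip(self.parameter, self.lower, self.upper)
        return ToyClampedSoftplus(clamped, self.lower, self.upper)

    def tree_flatten(self):
        return (self.parameter,), (self.lower, self.upper)

    @classmethod
    def tree_unflatten(cls, bounds, children):
        return cls(children[0], *bounds)


c = ToyClampedSoftplus(jnp.array([-20.0, 0.0, 20.0]))
# Gradient w.r.t. the raw array, flowing through unwrap (as in a loss).
grads = jax.grad(lambda con: jnp.sum(con.unwrap()))(c)
# After an update step, apply enforces the non-differentiable part.
c_applied = c.apply()
```

Note how the gradient at the raw value -20 is nearly zero, which is the vanishing-gradient situation that the clamp in apply is meant to prevent.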

Note

Models containing Constraints need to be finalized before they are callable.

Warning

Constraint objects should not be nested, as this can lead to unexpected behavior or errors. To combine the effects of two constraints, implement a custom constraint that defines the combined effect in its unwrap and apply methods.

__init__()

Initialize self. See help(type(self)) for accurate signature.

klax.apply(tree)

Map across a PyTree and apply all Constraints.

This leaves all other nodes unchanged.

Example

Enforcing non-negativity.

>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.NonNegative(-1 * jnp.ones(3))
>>> klax.apply(("abc", 1, params))
('abc', 1, NonNegative(parameter=Array([0., 0., 0.], dtype=float32)))
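In a training loop, the intended order is to differentiate through the unwrapped model, update the parameters, and only then apply the constraints. The plain-JAX sketch below shows this pattern, known as projected gradient descent, for a non-negativity constraint. It is not klax.fit: train_step, loss and apply_non_negative are hypothetical names, and the identity unwrap is an assumption made for brevity.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size


def loss(w, x, y):
    # Hypothetical model whose unwrap step is the identity: the loss (and
    # hence the gradient) sees the stored parameter as is.
    return jnp.mean((x @ w - y) ** 2)


def apply_non_negative(w):
    # Non-differentiable projection, applied after the update like `apply`.
    return jnp.maximum(w, 0.0)


def train_step(w, x, y):
    grads = jax.grad(loss)(w, x, y)  # 1. differentiate through the loss
    w = w - LEARNING_RATE * grads    # 2. ordinary gradient update
    return apply_non_negative(w)     # 3. enforce the constraint


x = jnp.eye(2)
y = jnp.array([1.0, -1.0])  # the unconstrained optimum would be w = [1, -1]
w = jnp.zeros(2)
for _ in range(50):
    w = train_step(w, x, y)
```

The second entry is pushed below zero by every update and projected back to zero by every apply, so it settles at the boundary while the first entry converges freely.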

klax.contains_constraints(tree)

Check if a PyTree contains instances of klax.Constraint.

klax.finalize(tree)

Make a model containing Constraints callable.

This function combines the functionality of klax.apply and klax.unwrap.

Warning

For models/PyTrees containing Constraints, only finalize the model after the parameter update or after training with klax.fit. klax.finalize returns an unwrapped PyTree in which all constraints and wrappers have been applied, which also means that the returned PyTree is no longer constrained.

If you want to call a model that you intend to continue training afterwards, we recommend using a different name for the finalized model. For example:

>>> import klax
>>> finalized_model = klax.finalize(model)
>>> y = finalized_model(x)                 # Call the finalized model
>>> model, history = klax.fit(model, ...)  # Continue training the constrained model
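The relation between finalize, apply and unwrap can be sketched with a toy constraint in plain JAX. ToyNonNegative, toy_apply, toy_unwrap and toy_finalize are hypothetical, and the order (apply first, then unwrap) is an assumption drawn from the description, not confirmed klax internals.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyNonNegative:
    """Hypothetical constraint: identity unwrap, clamp on apply."""

    def __init__(self, parameter):
        self.parameter = parameter

    def unwrap(self):
        return self.parameter

    def apply(self):
        return ToyNonNegative(jnp.maximum(self.parameter, 0.0))

    def tree_flatten(self):
        return (self.parameter,), None

    @classmethod
    def tree_unflatten(cls, _, children):
        return cls(*children)


def _is_constraint(node):
    return isinstance(node, ToyNonNegative)


def toy_apply(tree):
    # Replace each constraint by its applied copy; other nodes unchanged.
    return jax.tree_util.tree_map(
        lambda n: n.apply() if _is_constraint(n) else n, tree, is_leaf=_is_constraint
    )


def toy_unwrap(tree):
    # Strip each constraint wrapper, leaving a plain array.
    return jax.tree_util.tree_map(
        lambda n: n.unwrap() if _is_constraint(n) else n, tree, is_leaf=_is_constraint
    )


def toy_finalize(tree):
    # Assumed order: enforce the constraint first, then remove the wrapper.
    return toy_unwrap(toy_apply(tree))


model = {"w": ToyNonNegative(jnp.array([-1.0, 2.0])), "b": jnp.array(0.5)}
finalized = toy_finalize(model)
```

Because the finalized tree holds plain arrays, a further parameter update on it would no longer be constrained, while `model` still carries its wrapper for continued training.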

Unwrappables

klax.Parameterize(klax.Unwrappable)

Unwrap an object by calling fn with *args and **kwargs.

All of fn, *args and **kwargs may contain trainable parameters.

Note

Unwrapping typically occurs after model initialization. Therefore, if the klax.Parameterize object may be created in a vectorized context, we recommend ensuring that fn still unwraps correctly, e.g. by supporting broadcasting.
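The broadcasting requirement can be illustrated in plain JAX, without klax. Using jnp.diag as fn (for instance to build a diagonal matrix parameter) is a hypothetical choice; the assumption is that unwrapping a klax.Parameterize created under jax.vmap calls fn on the whole batched array in the same way.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: four parameter vectors created under jax.vmap (e.g. an
# ensemble), so the wrapped array gains a leading batch axis.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
batched = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)  # shape (4, 3)

# jnp.diag is not batch-aware: given a 2D array it extracts a diagonal
# instead of building diagonal matrices, so unwrapping silently goes wrong.
wrong = jnp.diag(batched)  # shape (3,)

# A broadcasting-aware version of the same fn works with or without the
# batch axis, which is what unwrapping after vectorized creation needs.
diag_fn = jnp.vectorize(jnp.diag, signature="(n)->(n,n)")
right = diag_fn(batched)      # shape (4, 3, 3)
single = diag_fn(batched[0])  # shape (3, 3)
```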

Example

>>> from klax import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive)  # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)

PARAMETER DESCRIPTION
fn

Callable to call with *args and **kwargs.

TYPE: Callable[..., ~T]

*args

Positional arguments to pass to fn.

TYPE: Any

**kwargs

Keyword arguments to pass to fn.

TYPE: Any

__init__(fn, *args, **kwargs)

klax.NonTrainable(klax.Unwrappable)

Applies stop gradient to all ArrayLike leaves before unwrapping.

Useful to mark PyTrees (Arrays, Modules, etc.) as frozen/non-trainable. See also klax.non_trainable, which achieves similar behaviour by wrapping the ArrayLike leaves directly rather than the whole tree, and is probably the generally preferable approach. Note that the underlying parameters may still be affected by regularization, so it is advisable to also use this class as a suggestively named marker for filtering out parameters.
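The caveat about regularization follows from how stop_gradient works. In this plain-JAX sketch (not the klax implementation), the data term sees the parameter through jax.lax.stop_gradient, as NonTrainable does on unwrapping, while a hypothetical L2 penalty acts on the raw stored leaf and therefore still produces a gradient.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.5, -0.5, 1.0])


def data_loss(w):
    # Stand-in for the unwrapped, frozen parameter used by the model.
    return jnp.sum(jax.lax.stop_gradient(w) * x)


def l2_penalty(w):
    # Hypothetical regularizer acting on the raw leaf, before unwrapping.
    return 0.1 * jnp.sum(w**2)


grad_data = jax.grad(data_loss)(w)  # all zeros: the parameter is frozen
grad_reg = jax.grad(l2_penalty)(w)  # 0.2 * w: the penalty still moves it
```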

__init__(tree)

Initialize self. See help(type(self)) for accurate signature.

klax.non_trainable(tree)

Freeze parameters by wrapping inexact arrays.

This function wraps every inexact array or klax.Constraint in the PyTree in a klax.NonTrainable wrapper.

Note

Regularization is likely to apply before unwrapping. To avoid regularization impacting non-trainable parameters, they should be filtered out, for example using:

>>> import equinox as eqx
>>> from klax import Constraint, NonTrainable
>>> eqx.partition(
...     ...,
...     is_leaf=lambda leaf: isinstance(leaf, (NonTrainable, Constraint)),
... )

Wrapping the individual arrays in a model, rather than the entire tree, is often preferable because it keeps the model's attributes easier to access.
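The filtering recommended in the note amounts to splitting the tree at the marker nodes. The sketch below does this with jax.tree_util only, roughly mirroring what eqx.partition does with an is_leaf predicate; Frozen, split_frozen and l2_penalty are hypothetical names standing in for klax.NonTrainable and a user-defined regularizer.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Frozen:
    """Hypothetical marker, standing in for klax.NonTrainable."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, _, children):
        return cls(*children)


def is_frozen(node):
    return isinstance(node, Frozen)


def split_frozen(tree):
    """Return (trainable, frozen) trees, with None in place of the other part."""
    trainable = jax.tree_util.tree_map(
        lambda n: None if is_frozen(n) else n, tree, is_leaf=is_frozen
    )
    frozen = jax.tree_util.tree_map(
        lambda n: n if is_frozen(n) else None, tree, is_leaf=is_frozen
    )
    return trainable, frozen


def l2_penalty(trainable):
    # None entries have no leaves, so only trainable arrays are penalized.
    return sum(jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(trainable))


model = {"w": jnp.array([1.0, 2.0]), "b": Frozen(jnp.array([10.0]))}
trainable, frozen = split_frozen(model)
penalty = l2_penalty(trainable)  # ignores the frozen "b"
```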

PARAMETER DESCRIPTION
tree

The PyTree.

TYPE: PyTree

klax.Symmetric(klax.Unwrappable)

Ensures symmetry of a square matrix upon unwrapping.

Warning

Wrapping Symmetric around parameters that are already wrapped may lead to unexpected behavior and is generally discouraged.

__init__(parameter)

Initialize a Symmetric wrapper.

PARAMETER DESCRIPTION
parameter

Matrix array to be wrapped, of shape (..., N, N).

TYPE: Array
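One common way to make a square matrix symmetric is to average it with its transpose, which is differentiable and works on a batch of shape (..., N, N) by swapping only the last two axes. The sketch below uses this construction; whether klax.Symmetric uses exactly this formula is an assumption, and symmetrize is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def symmetrize(a):
    """Return the symmetric part of a batch of square matrices (..., N, N)."""
    # Swap only the last two axes so leading batch dimensions are preserved.
    return 0.5 * (a + jnp.swapaxes(a, -1, -2))


a = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 3))  # two 3x3 matrices
s = symmetrize(a)
```

Applying symmetrize to an already symmetric matrix leaves it unchanged, so unwrapping a parameter that happens to be symmetric does not alter it.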

klax.SkewSymmetric(klax.Unwrappable)

Ensures skew-symmetry of a square matrix upon unwrapping.

Warning

Wrapping SkewSymmetric around parameters that are already wrapped may lead to unexpected behavior and is generally discouraged.

__init__(parameter)

Initialize a SkewSymmetric wrapper.

PARAMETER DESCRIPTION
parameter

Matrix array to be wrapped, of shape (..., N, N).

TYPE: Array
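Analogously, the skew-symmetric part of a square matrix is half the difference between the matrix and its transpose. This is again a sketch under the assumption that klax.SkewSymmetric uses this construction; skew is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def skew(a):
    """Return the skew-symmetric part of a batch of square matrices (..., N, N)."""
    # Swap only the last two axes; the diagonal cancels to exactly zero.
    return 0.5 * (a - jnp.swapaxes(a, -1, -2))


a = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 4))
k = skew(a)
```

Together with the symmetric part, this recovers the original matrix: A = (A + A^T)/2 + (A - A^T)/2.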


Constraints

klax.NonNegative(klax.Constraint)

Applies a non-negative constraint.

PARAMETER DESCRIPTION
parameter

The jax.Array that is to be made non-negative upon unwrapping and applying.

TYPE: Array

__init__(parameter)

Initialize self. See help(type(self)) for accurate signature.
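The reason a clamp belongs in apply rather than in unwrap can be seen from its gradient. In this plain-JAX sketch, jnp.maximum(x, 0.0) is an assumed stand-in for the non-negativity projection (it reproduces the klax.apply output for the -1 entries in the klax.apply example); clamp_non_negative is a hypothetical name. Its gradient is zero for negative entries, so if it were used inside the loss those entries would receive no training signal.

```python
import jax
import jax.numpy as jnp


def clamp_non_negative(x):
    # Assumed projection for a non-negativity constraint.
    return jnp.maximum(x, 0.0)


p = jnp.array([-1.0, 0.5, 2.0])
projected = clamp_non_negative(p)
# Gradient of the clamp w.r.t. the raw array: zero wherever x < 0.
grad = jax.grad(lambda x: jnp.sum(clamp_non_negative(x)))(p)
```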