Unwrappables and Constraints
Basic classes and functions
klax.Unwrappable
An abstract class representing an unwrappable object.
Unwrappables replace PyTree nodes to apply custom behavior upon unwrapping.
This class is a renamed copy of paramax.AbstractUnwrappable.
Note
Models containing Unwrappables need to be unwrapped or finalized before they are callable.
__init__()
Initialize self. See help(type(self)) for accurate signature.
klax.unwrap(tree)
Map across a PyTree and unwrap all klax.Unwrappable objects.
This leaves all other nodes unchanged. If nested, the innermost
klax.Unwrappable is unwrapped first.
Example
Enforcing positivity.
>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.Parameterize(jnp.exp, jnp.zeros(3))
>>> klax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
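The innermost-first rule can be illustrated with a minimal, hypothetical sketch of such a mapping, built only on jax.tree_util. The classes Unwrappable and Param and the function unwrap below are illustrative stand-ins written for this page, not klax's actual implementation.

```python
import jax
import jax.numpy as jnp


class Unwrappable:
    """Hypothetical base class standing in for klax.Unwrappable."""

    def unwrap(self):
        raise NotImplementedError


@jax.tree_util.register_pytree_node_class
class Param(Unwrappable):
    """Hypothetical wrapper that applies `fn` to its (possibly wrapped) argument."""

    def __init__(self, fn, arg):
        self.fn = fn
        self.arg = arg

    def tree_flatten(self):
        # `arg` is a child (it may itself be wrapped); `fn` is static auxiliary data.
        return (self.arg,), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        return cls(fn, *children)

    def unwrap(self):
        return self.fn(self.arg)


def unwrap(tree):
    """Replace every Unwrappable in `tree` by its unwrapped value, innermost first."""

    def _unwrap(tree, include_self):
        def is_leaf(x):
            # Stop at Unwrappables, except at the node currently being opened up.
            return isinstance(x, Unwrappable) and (include_self or x is not tree)

        def _map(leaf):
            if isinstance(leaf, Unwrappable):
                # First unwrap everything nested inside `leaf`, then `leaf` itself.
                return _unwrap(leaf, include_self=False).unwrap()
            return leaf

        return jax.tree_util.tree_map(_map, tree, is_leaf=is_leaf)

    return _unwrap(tree, include_self=True)


# The inner wrapper doubles, the outer one adds 1: the result is 2 * 1 + 1 = 3.
nested = Param(lambda x: x + 1, Param(lambda x: 2 * x, jnp.ones(2)))
out = unwrap(("abc", 1, nested))
```

Unwrapping from the outside in would fail here, because the outer function would receive a Param object instead of an array.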
klax.contains_unwrappables(tree)
Check if a PyTree contains instances of klax.Unwrappable.
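A check of this kind can be sketched as a scan over the leaves of the tree that stops at wrapper nodes. Wrapper and contains_wrappers are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Wrapper:
    """Hypothetical stand-in for a klax.Unwrappable subclass."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def contains_wrappers(tree):
    # Treat Wrapper nodes as leaves so they are not flattened into their arrays.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: isinstance(x, Wrapper))
    return any(isinstance(leaf, Wrapper) for leaf in leaves)


found = contains_wrappers({"a": jnp.ones(2), "b": [Wrapper(jnp.zeros(3))]})
```

Without the is_leaf predicate, the registered Wrapper would be flattened into its array and the check would never find it.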
klax.Constraint(klax.Unwrappable)
An abstract constraint around a jax.Array.
A klax.Constraint is an extended version of klax.Unwrappable
that marks an array in a PyTree as constrained. It implements the familiar
unwrap method from klax.Unwrappable and adds an apply
method for implementing constraints that are
non-differentiable or could lead to vanishing gradients during optimization.
We intend the following usage of the unwrap and apply methods:
unwrap: Identical functionality to a klax.Unwrappable.
Use this to implement constraints that are differentiable
and should be applied within the training loop. For example, our
implementation of klax.fit unwraps the model as part of the
loss function, so the implementation of unwrap contributes to the
gradients during training. An example would be a positivity constraint
that passes the array through jax.nn.softplus upon unwrapping.
apply: New functionality added with klax.Constraint.
Use this to implement non-differentiable or zero-gradient constraints
that should be applied after the parameter update and that modify the
wrapped array without unwrapping it. Consequently, it is suitable for
non-differentiable constraints such as clamping
a parameter to a range of admissible values. The apply method should
return a modified copy of self.
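The intended division of labour can be sketched with a hypothetical box constraint and a hand-written gradient-descent loop that mirrors the described workflow: unwrap inside the loss, apply after each update. Clamp, loss, the bounds and the learning rate are assumptions made for illustration; klax.fit handles this internally and may differ in detail.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Clamp:
    """Hypothetical constraint keeping `parameter` inside [lo, hi] (not a klax class)."""

    def __init__(self, parameter, lo=0.0, hi=1.0):
        self.parameter, self.lo, self.hi = parameter, lo, hi

    def tree_flatten(self):
        # Only the array is trainable; the bounds are static auxiliary data.
        return (self.parameter,), (self.lo, self.hi)

    @classmethod
    def tree_unflatten(cls, bounds, children):
        return cls(children[0], *bounds)

    def unwrap(self):
        # Differentiable part, evaluated inside the loss (here just the identity).
        return self.parameter

    def apply(self):
        # Non-differentiable part, run after the update; returns a modified copy.
        return Clamp(jnp.clip(self.parameter, self.lo, self.hi), self.lo, self.hi)


def loss(c):
    # The unconstrained minimum at x = 2 lies outside the admissible range [0, 1].
    return jnp.sum((c.unwrap() - 2.0) ** 2)


c = Clamp(jnp.array([0.5, 0.2]))
for _ in range(20):
    grads = jax.grad(loss)(c)  # a Clamp-shaped pytree of gradients
    c = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, c, grads)  # plain SGD step
    c = c.apply()  # enforce the constraint after the parameter update
```

After a few steps both entries are pinned at the upper bound 1.0, even though the clamp itself contributes no gradient.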
Note
Models containing Constraints need to be finalized before they are callable.
Warning
Constraint objects should not be nested, as this
can lead to unexpected behavior or errors. To combine the effects of
two constraints, implement a custom constraint and define the combined
effects via unwrap and apply.
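As a sketch of this recommendation, the hypothetical class below folds a softplus positivity constraint and an upper bound into one object instead of nesting two constraints. The class name, the bound hi and the inverse-softplus trick are illustrative assumptions; a real klax constraint would subclass klax.Constraint.

```python
import jax
import jax.numpy as jnp


class SoftplusUpperBound:
    """Hypothetical single constraint combining two effects (not a klax class):
    positivity via softplus on unwrap, and an upper bound `hi` enforced by apply."""

    def __init__(self, parameter, hi=1.0):
        self.parameter, self.hi = parameter, hi

    def unwrap(self):
        # Differentiable effect: softplus maps any real value into (0, inf).
        return jax.nn.softplus(self.parameter)

    def apply(self):
        # Non-differentiable effect: clamp the raw value so that
        # softplus(raw) <= hi, using the inverse softplus log(expm1(hi)).
        raw_max = jnp.log(jnp.expm1(self.hi))
        return SoftplusUpperBound(jnp.minimum(self.parameter, raw_max), self.hi)


c = SoftplusUpperBound(jnp.array([-3.0, 0.0, 10.0])).apply()
values = c.unwrap()  # all entries lie in (0, 1]
```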
__init__()
Initialize self. See help(type(self)) for accurate signature.
klax.apply(tree)
Map across a PyTree and apply all Constraints.
This leaves all other nodes unchanged.
Example
Enforcing non-negativity.
>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.NonNegative(-1 * jnp.ones(3))
>>> klax.apply(("abc", 1, params))
('abc', 1, NonNegative(parameter=Array([0., 0., 0.], dtype=float32)))
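The mapping itself can be sketched with jax.tree_util.tree_map and an is_leaf predicate that stops at constraint nodes. NonNeg and apply_constraints are hypothetical stand-ins for klax.NonNegative and klax.apply; using a NamedTuple makes the constraint a PyTree node without explicit registration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NonNeg(NamedTuple):
    """Hypothetical non-negativity constraint (a stand-in, not klax.NonNegative)."""

    parameter: jax.Array

    def apply(self):
        # Return a clamped copy; the original object is left untouched.
        return NonNeg(jnp.maximum(self.parameter, 0.0))


def _is_constraint(x):
    return isinstance(x, NonNeg)


def apply_constraints(tree):
    # Stop at constraint nodes, call their apply method, keep everything else.
    return jax.tree_util.tree_map(
        lambda x: x.apply() if _is_constraint(x) else x, tree, is_leaf=_is_constraint
    )


out = apply_constraints(("abc", 1, NonNeg(-1 * jnp.ones(3))))
```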
klax.contains_constraints(tree)
Check if a PyTree contains instances of klax.Constraint.
klax.finalize(tree)
Make a model containing Constraints callable.
This function combines the functionalities of klax.apply and
klax.unwrap.
Warning
For models/PyTrees containing Constraints, only
finalize the model after the parameter update or after
training with klax.fit. This is because klax.finalize
returns an unwrapped PyTree where all constraints and wrappers have
been applied. However, this also means that the returned PyTree is no
longer constrained.
If you want to call a model that you intend to continue training afterwards, we recommend using a different name for the finalized model. For example:
>>> finalized_model = klax.finalize(model)
>>> y = finalized_model(x)  # Call finalized model
>>> model, history = klax.fit(model, ...)  # Continue training with constrained model
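Conceptually, finalizing can be pictured as applying each constraint and then unwrapping it to a plain array. The sketch below assumes that order and uses a hypothetical NonNeg class and finalize function. It also shows why a separate name is useful: the original constrained model is left untouched.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class NonNeg(NamedTuple):
    """Hypothetical non-negativity constraint (a stand-in, not klax.NonNegative)."""

    parameter: jax.Array

    def unwrap(self):
        return self.parameter

    def apply(self):
        return NonNeg(jnp.maximum(self.parameter, 0.0))


def _is_constraint(x):
    return isinstance(x, NonNeg)


def finalize(tree):
    # Assumed order: apply each constraint first, then unwrap it to a plain array.
    return jax.tree_util.tree_map(
        lambda x: x.apply().unwrap() if _is_constraint(x) else x,
        tree,
        is_leaf=_is_constraint,
    )


model = {"w": NonNeg(jnp.array([-1.0, 2.0])), "b": jnp.array(0.5)}
finalized_model = finalize(model)  # plain arrays, no longer constrained
```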
Unwrappables
klax.Parameterize(klax.Unwrappable)
Unwrap an object by calling fn with *args and **kwargs.
All of fn, *args and **kwargs may contain trainable parameters.
Note
Unwrapping typically occurs after model initialization. Therefore, if
the klax.Parameterize object may be created in a vectorized
context, we recommend ensuring that fn still unwraps correctly,
e.g. by supporting broadcasting.
Example
>>> from klax import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive) # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)
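To show how gradients reach the raw array through fn, here is a minimal PyTree sketch of the same idea. ParameterizeSketch is a hypothetical class. Unlike the description above, it treats fn as static auxiliary data, so parameters stored inside fn would not be trained.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ParameterizeSketch:
    """Hypothetical re-creation of the Parameterize idea (not klax's code).
    Simplification: `fn` is kept static here; only `args`/`kwargs` are traced."""

    def __init__(self, fn, *args, **kwargs):
        self.fn, self.args, self.kwargs = fn, args, kwargs

    def tree_flatten(self):
        return (self.args, self.kwargs), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        args, kwargs = children
        return cls(fn, *args, **kwargs)

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)


positive = ParameterizeSketch(jnp.exp, jnp.zeros(3))
# Gradients flow through fn to the raw, unconstrained array: d/dx exp(x) = 1 at x = 0.
grads = jax.grad(lambda p: jnp.sum(p.unwrap()))(positive)
```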
| PARAMETER | DESCRIPTION |
|---|---|
| fn | Callable to call with *args and **kwargs. |
| *args | Positional arguments to pass to fn. |
| **kwargs | Keyword arguments to pass to fn. |
__init__(fn, *args, **kwargs)
klax.NonTrainable(klax.Unwrappable)
¤
Applies stop gradient to all ArrayLike leaves before unwrapping.
See also klax.non_trainable, which is probably the generally preferable
way to achieve similar behaviour: it wraps the ArrayLike leaves
directly rather than the whole tree. This class is useful to mark PyTrees (Arrays, Modules,
etc.) as frozen/non-trainable. Note that the underlying parameters may
still be affected by regularization, so it is generally advisable to also use
this suggestively named class to filter such parameters out.
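The effect of stopping gradients before unwrapping can be sketched as follows. FrozenSketch is a hypothetical stand-in for klax.NonTrainable and assumes that every wrapped leaf is an array.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class FrozenSketch:
    """Hypothetical stand-in for klax.NonTrainable; assumes all wrapped leaves are arrays."""

    def __init__(self, tree):
        self.tree = tree

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])

    def unwrap(self):
        # Block gradient flow through every leaf of the wrapped tree.
        return jax.tree_util.tree_map(jax.lax.stop_gradient, self.tree)


params = {"frozen": FrozenSketch(jnp.array(3.0)), "free": jnp.array(2.0)}


def loss(p):
    return p["frozen"].unwrap() * p["free"]


grads = jax.grad(loss)(params)  # zero gradient for the frozen entry
```

The frozen value still lives in params, so a regularizer that sums over all leaves would still see it; this is why filtering is advised above.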
__init__(tree)
Initialize self. See help(type(self)) for accurate signature.
klax.non_trainable(tree)
Freeze parameters by wrapping inexact arrays.
This function wraps a klax.NonTrainable wrapper around every inexact
array or klax.Constraint in the PyTree.
Note
Regularization is likely to apply before unwrapping. To avoid regularization impacting non-trainable parameters, they should be filtered out, for example using:
>>> eqx.partition(
... ...,
... is_leaf=lambda leaf: isinstance(leaf, (NonTrainable, Constraint)),
... )
Wrapping the individual arrays in a model, rather than the entire tree, is often preferable because the model's attributes remain directly accessible.
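The wrapping rule and the filtering advice can also be sketched without Equinox, using only jax.tree_util. Frozen, freeze and l2_trainable are hypothetical names. The sketch wraps each inexact array leaf (leaving, for example, integer counters alone) and then skips the wrapped leaves when computing an L2 penalty.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical stand-in for klax.NonTrainable."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def _is_inexact_array(x):
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def freeze(tree):
    # Wrap every inexact (floating or complex) array leaf individually.
    return jax.tree_util.tree_map(lambda x: Frozen(x) if _is_inexact_array(x) else x, tree)


def l2_trainable(tree):
    # Treat Frozen nodes as leaves so the regularizer can skip them.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: isinstance(x, Frozen))
    return sum(jnp.sum(x**2) for x in leaves if _is_inexact_array(x))


model = {
    "frozen": freeze({"w": jnp.array([1.0, 2.0]), "steps": jnp.array(3)}),
    "b": jnp.array([3.0]),
}
penalty = l2_trainable(model)  # only "b" contributes
```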
| PARAMETER | DESCRIPTION |
|---|---|
| tree | The PyTree. |
klax.Symmetric(klax.Unwrappable)
Ensures symmetry of a square matrix upon unwrapping.
Warning
Wrapping Symmetric around parameters that are
already wrapped may lead to unexpected behavior and is
generally discouraged.
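One common way to obtain a symmetric matrix from an unconstrained one is to average it with its transpose over the last two axes, which also works for batched arrays of shape (..., N, N). Whether klax.Symmetric uses exactly this formula is an assumption here; symmetrize is a hypothetical helper.

```python
import jax.numpy as jnp


def symmetrize(a):
    # (A + A^T) / 2, transposing only the last two axes so batch axes are kept.
    return 0.5 * (a + jnp.swapaxes(a, -1, -2))


a = jnp.arange(8.0).reshape(2, 2, 2)  # batch of two 2x2 matrices
s = symmetrize(a)
```

The map leaves an already symmetric matrix unchanged, so applying it repeatedly is harmless.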
__init__(parameter)
Initialize a Symmetric wrapper.
| PARAMETER | DESCRIPTION |
|---|---|
| parameter | Matrix array of shape (..., N, N) to be wrapped. |
klax.SkewSymmetric(klax.Unwrappable)
Ensures skew-symmetry of a square matrix upon unwrapping.
Warning
Wrapping SkewSymmetric around parameters that are
already wrapped may lead to unexpected behavior and is
generally discouraged.
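Analogously, a skew-symmetric matrix can be obtained by taking half the difference between a matrix and its transpose. That klax.SkewSymmetric uses this exact formula is an assumption; skew_symmetrize is a hypothetical helper.

```python
import jax.numpy as jnp


def skew_symmetrize(a):
    # (A - A^T) / 2: the result satisfies S^T = -S and has a zero diagonal.
    return 0.5 * (a - jnp.swapaxes(a, -1, -2))


a = jnp.arange(8.0).reshape(2, 2, 2)  # batch of two 2x2 matrices
s = skew_symmetrize(a)
```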
__init__(parameter)
Initialize a SkewSymmetric wrapper.
| PARAMETER | DESCRIPTION |
|---|---|
| parameter | Wrapped matrix as array of shape (..., N, N). |
Constraints
klax.NonNegative(klax.Constraint)
Applies a non-negative constraint.
| PARAMETER | DESCRIPTION |
|---|---|
| parameter | The |
__init__(parameter)
Initialize self. See help(type(self)) for accurate signature.