Unwrappables and Constraints
Basic classes and functions
klax.Unwrappable
An abstract class representing an unwrappable object.
Unwrappables replace PyTree nodes to apply custom behavior upon unwrapping.
This class is a renamed copy of paramax.AbstractUnwrappable.
Note
Models containing Unwrappables need to be unwrapped or finalized before they are callable.
__init__()
Initialize self. See help(type(self)) for accurate signature.
klax.unwrap(tree)
Map across a PyTree and unwrap all klax.Unwrappable objects.
This leaves all other nodes unchanged. If nested, the innermost klax.Unwrappable is unwrapped first.
Example
Enforcing positivity.
>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.Parameterize(jnp.exp, jnp.zeros(3))
>>> klax.unwrap(("abc", 1, params))
('abc', 1, Array([1., 1., 1.], dtype=float32))
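The innermost-first rule for nested wrappers can be illustrated with a small pure-Python sketch. This is not klax's implementation: the `Wrapper` class and the `unwrap_tree` function below are hypothetical stand-ins that only mimic the documented behavior on tuples, lists and dicts, using NumPy instead of JAX.

```python
import numpy as np


class Wrapper:
    """Hypothetical stand-in for an Unwrappable: applies fn to its content."""

    def __init__(self, fn, content):
        self.fn = fn
        self.content = content

    def unwrap(self):
        return self.fn(self.content)


def unwrap_tree(tree):
    """Recursively unwrap all Wrapper nodes, innermost first."""
    if isinstance(tree, Wrapper):
        # Resolve the contents first, so nested wrappers unwrap inside-out.
        inner = unwrap_tree(tree.content)
        return Wrapper(tree.fn, inner).unwrap()
    if isinstance(tree, (tuple, list)):
        return type(tree)(unwrap_tree(x) for x in tree)
    if isinstance(tree, dict):
        return {k: unwrap_tree(v) for k, v in tree.items()}
    return tree  # all other nodes are left unchanged


# exp is applied first (inner wrapper), then +1 (outer wrapper).
nested = Wrapper(lambda x: x + 1.0, Wrapper(np.exp, np.zeros(3)))
result = unwrap_tree(("abc", 1, nested))
print(result)  # ('abc', 1, array([2., 2., 2.]))
```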
klax.contains_unwrappables(tree)
Check if a PyTree contains instances of klax.Unwrappable.
klax.Constraint(klax.Unwrappable)
An abstract constraint around a jax.Array.
A klax.Constraint is an extended version of klax.Unwrappable that marks an array in a PyTree as constrained. It implements the unwrap method known from klax.Unwrappable and adds an apply method for implementing constraints that are non-differentiable or could lead to vanishing gradients during optimization.
We intend the unwrap and apply methods to be used as follows:
unwrap: Same functionality as in klax.Unwrappable. Use this to implement constraints that are differentiable and shall be applied within the training loop. For example, our implementation of klax.fit unwraps the model as part of the loss function. Thus, the implementation of unwrap contributes to the gradients during training. An example would be a positivity constraint that passes the array through jax.nn.softplus upon unwrapping.
apply: New functionality added with klax.Constraint. Use this to implement non-differentiable or zero-gradient constraints that shall be applied after the parameter update and that modify the wrapped array without unwrapping it. Consequently, it is suitable for implementing non-differentiable constraints, such as clamping a parameter to a range of admissible values. Apply methods should return a modified copy of self.
Note
Models containing Constraints need to be finalized before they are callable.
Warning
Constraint objects should not be nested, as this can lead to unexpected behavior or errors. To combine the effects of two constraints, implement a custom constraint and define the combined effects via unwrap and apply.
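Following the recommendation above, here is a sketch of a single custom constraint that combines two effects (non-negativity and symmetry) instead of nesting two constraints. The class `NonNegativeSymmetric` is hypothetical and NumPy-based: `apply` clamps negative entries to zero, and `unwrap` returns the symmetric part (A + Aᵀ)/2, which preserves non-negativity because averaging non-negative entries cannot produce negative ones.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class NonNegativeSymmetric:
    """Hypothetical combined constraint: one wrapper, two effects."""

    parameter: np.ndarray  # shape (..., N, N)

    def unwrap(self):
        # Symmetric part over the last two axes (differentiable).
        p = self.parameter
        return 0.5 * (p + np.swapaxes(p, -1, -2))

    def apply(self):
        # Clamp negative entries after the parameter update.
        return dataclasses.replace(
            self, parameter=np.maximum(self.parameter, 0.0)
        )


c = NonNegativeSymmetric(np.array([[1.0, -2.0], [3.0, 4.0]]))
m = c.apply().unwrap()
print(m)  # symmetric and non-negative
```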
__init__()
Initialize self. See help(type(self)) for accurate signature.
klax.apply(tree)
Map across a PyTree and apply all Constraints.
This leaves all other nodes unchanged.
Example
Enforcing non-negativity.
>>> import klax
>>> import jax.numpy as jnp
>>> params = klax.NonNegative(-1 * jnp.ones(3))
>>> klax.apply(("abc", 1, params))
('abc', 1, NonNegative(parameter=Array([0., 0., 0.], dtype=float32)))
klax.contains_constraints(tree)
Check if a PyTree contains instances of klax.Constraint.
klax.finalize(tree)
Make a model containing Constraints callable.
This function combines the functionality of klax.apply and klax.unwrap.
Warning
For models/PyTrees containing Constraints, only finalize the model after the parameter update or after training with klax.fit. This is because klax.finalize returns an unwrapped PyTree in which all constraints and wrappers have been applied. However, this also means that the returned PyTree is no longer constrained.
If you want to call a model that you intend to continue training afterwards, we recommend using a different name for the finalized model. For example:
>>> finalized_model = klax.finalize(model)
>>> y = finalized_model(x)  # Call the finalized model
>>> model, history = klax.fit(model, ...)  # Continue training the constrained model
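A minimal sketch of the finalize semantics described above, assuming that finalizing amounts to applying every constraint and then unwrapping (the order is an assumption of this sketch). The `Clamp01` class and the `apply_tree`, `unwrap_tree` and `finalize_tree` functions are hypothetical, only traverse tuples, and use NumPy. The last line shows why the finalized tree should get its own name: it holds a plain array, so the constraint is gone.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class Clamp01:
    """Hypothetical constraint keeping values in [0, 1]."""

    parameter: np.ndarray

    def unwrap(self):
        return self.parameter

    def apply(self):
        return dataclasses.replace(
            self, parameter=np.clip(self.parameter, 0.0, 1.0)
        )


def apply_tree(tree):
    # Apply every constraint; leave other nodes unchanged.
    if isinstance(tree, Clamp01):
        return tree.apply()
    if isinstance(tree, tuple):
        return tuple(apply_tree(x) for x in tree)
    return tree


def unwrap_tree(tree):
    # Replace every constraint by its unwrapped array.
    if isinstance(tree, Clamp01):
        return tree.unwrap()
    if isinstance(tree, tuple):
        return tuple(unwrap_tree(x) for x in tree)
    return tree


def finalize_tree(tree):
    # Assumed order: enforce constraints first, then unwrap.
    return unwrap_tree(apply_tree(tree))


model = ("layer", Clamp01(np.array([-0.5, 0.3, 2.0])))
finalized = finalize_tree(model)
print(finalized[1])                       # values clamped to [0, 1]
print(isinstance(finalized[1], Clamp01))  # False: no longer constrained
```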
Unwrappables
klax.Parameterize(klax.Unwrappable)
Unwrap an object by calling fn with *args and **kwargs.
All of fn, *args and **kwargs may contain trainable parameters.
Note
Unwrapping typically occurs after model initialization. Therefore, if the klax.Parameterize object may be created in a vectorized context, we recommend ensuring that fn still unwraps correctly, e.g. by supporting broadcasting.
Example
>>> from klax import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> positive = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(positive) # Applies exp on unwrapping
Array([1., 1., 1.], dtype=float32)
PARAMETER | DESCRIPTION |
---|---|
fn | Callable to call with *args and **kwargs. |
*args | Positional arguments to pass to fn. |
**kwargs | Keyword arguments to pass to fn. |
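A simplified, NumPy-based sketch of how a Parameterize-style wrapper defers a function call until unwrapping. `Deferred` and `gram` are hypothetical stand-ins, not klax.Parameterize. The second wrapper illustrates the note on vectorized contexts: if the wrapped argument gains a leading batch dimension, fn must still work on the stacked array, which `gram` does by transposing only the last two axes.

```python
import numpy as np


class Deferred:
    """Hypothetical Parameterize-like wrapper: calls fn only on unwrap."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)


def gram(x):
    # Broadcasting-safe Gram matrix: works for (N, N) and (..., N, N).
    # Using x.T instead would reverse all axes for batched input.
    return x @ np.swapaxes(x, -1, -2)


single = Deferred(gram, np.eye(2))
batched = Deferred(gram, np.stack([np.eye(2), 2 * np.eye(2)]))  # (2, 2, 2)

print(single.unwrap().shape)   # (2, 2)
print(batched.unwrap().shape)  # (2, 2, 2)
```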
__init__(fn, *args, **kwargs)
klax.NonTrainable(klax.Unwrappable)
Applies stop gradient to all ArrayLike leaves before unwrapping.
Useful to mark PyTrees (Arrays, Modules, etc.) as frozen/non-trainable. See also klax.non_trainable, which is probably the generally preferable way to achieve similar behavior, as it wraps the ArrayLike leaves directly rather than the whole tree. Note that the underlying parameters may still be affected by regularization, so it is generally advisable to also use this suggestively named class as a marker for filtering out these parameters.
__init__(tree)
Initialize self. See help(type(self)) for accurate signature.
klax.non_trainable(tree)
Freeze parameters by wrapping inexact arrays.
This function places a klax.NonTrainable wrapper around every inexact array or klax.Constraint in the PyTree.
Note
Regularization is likely to be applied before unwrapping. To prevent regularization from affecting non-trainable parameters, filter them out, for example using:
>>> import equinox as eqx
>>> from klax import Constraint, NonTrainable
>>> eqx.partition(
...     ...,
...     is_leaf=lambda leaf: isinstance(leaf, (NonTrainable, Constraint)),
... )
Wrapping the individual arrays of a model, rather than the entire tree, is often preferable because it keeps the model's attributes easy to access.
PARAMETER | DESCRIPTION |
---|---|
tree | The PyTree. |
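To illustrate which leaves count as "inexact arrays" (floating-point or complex dtypes), here is a NumPy sketch that wraps only those leaves of a nested tuple/dict tree in a marker class. `Frozen` and `freeze_inexact` are hypothetical; unlike klax.non_trainable, this sketch does not handle klax.Constraint leaves and applies no stop-gradient.

```python
import numpy as np


class Frozen:
    """Hypothetical marker wrapper for a non-trainable leaf."""

    def __init__(self, leaf):
        self.leaf = leaf


def freeze_inexact(tree):
    # Wrap floating-point/complex arrays; leave everything else as-is.
    if isinstance(tree, np.ndarray) and np.issubdtype(tree.dtype, np.inexact):
        return Frozen(tree)
    if isinstance(tree, tuple):
        return tuple(freeze_inexact(x) for x in tree)
    if isinstance(tree, dict):
        return {k: freeze_inexact(v) for k, v in tree.items()}
    return tree


tree = {"weight": np.ones(3), "index": np.arange(3), "name": "layer"}
frozen = freeze_inexact(tree)
print({k: type(v).__name__ for k, v in frozen.items()})
# {'weight': 'Frozen', 'index': 'ndarray', 'name': 'str'}
```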
klax.Symmetric(klax.Unwrappable)
Ensures symmetry of a square matrix upon unwrapping.
Warning
Wrapping Symmetric around parameters that are already wrapped may lead to unexpected behavior and is generally discouraged.
__init__(parameter)
Initialize a Symmetric wrapper.
PARAMETER | DESCRIPTION |
---|---|
parameter | Matrix to be wrapped, as an array of shape (..., N, N). |
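One common way to make a square matrix symmetric is to take its symmetric part, (A + Aᵀ)/2, over the last two axes so that leading batch dimensions are preserved. The NumPy sketch below shows this projection; whether klax.Symmetric uses exactly this formula is an assumption, and `symmetric_part` is a hypothetical helper.

```python
import numpy as np


def symmetric_part(a):
    # (A + A^T) / 2 over the last two axes; supports shape (..., N, N).
    # An already symmetric matrix is returned unchanged.
    return 0.5 * (a + np.swapaxes(a, -1, -2))


a = np.arange(8.0).reshape(2, 2, 2)  # batch of two 2x2 matrices
s = symmetric_part(a)
print(s[0])  # [[0.  1.5]
             #  [1.5 3. ]]
```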
klax.SkewSymmetric(klax.Unwrappable)
Ensures skew-symmetry of a square matrix upon unwrapping.
Warning
Wrapping SkewSymmetric around parameters that are already wrapped may lead to unexpected behavior and is generally discouraged.
__init__(parameter)
Initialize a SkewSymmetric wrapper.
PARAMETER | DESCRIPTION |
---|---|
parameter | Matrix to be wrapped, as an array of shape (..., N, N). |
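Analogously, the skew-symmetric part (A − Aᵀ)/2 satisfies Kᵀ = −K and therefore has a zero diagonal. The NumPy sketch below assumes, without confirmation from this page, that klax.SkewSymmetric uses this projection; `skew_symmetric_part` is a hypothetical helper.

```python
import numpy as np


def skew_symmetric_part(a):
    # (A - A^T) / 2 over the last two axes; supports shape (..., N, N).
    return 0.5 * (a - np.swapaxes(a, -1, -2))


a = np.array([[1.0, 2.0], [4.0, 3.0]])
k = skew_symmetric_part(a)
print(k)  # [[ 0. -1.]
          #  [ 1.  0.]]
```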
Constraints
klax.NonNegative(klax.Constraint)
Applies a non-negative constraint.
PARAMETER | DESCRIPTION |
---|---|
parameter | The array to be constrained. |
__init__(parameter)
Initialize self. See help(type(self)) for accurate signature.
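The klax.apply example earlier on this page shows NonNegative mapping -1 to 0. A NumPy sketch of such a projection, assuming it is an element-wise clamp at zero, max(x, 0); `NonNegativeSketch` is hypothetical. A clamp like this has zero gradient for negative entries, which is why such a constraint belongs in apply rather than in unwrap.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class NonNegativeSketch:
    """Hypothetical stand-in mirroring the NonNegative example."""

    parameter: np.ndarray

    def unwrap(self):
        # No transformation on unwrap in this sketch.
        return self.parameter

    def apply(self):
        # Element-wise clamp at zero, returned as a modified copy.
        return dataclasses.replace(
            self, parameter=np.maximum(self.parameter, 0.0)
        )


p = NonNegativeSketch(-1 * np.ones(3))
q = p.apply()
print(q.unwrap())  # all entries clamped to zero
```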