klax.nn.MLP
¤
Standard multi-layer perceptron (MLP), also known as a feed-forward network.
This class is modified from equinox.nn.MLP
to allow for custom initialization and a different number of nodes in each
hidden layer.
__init__(in_size, out_size, width_sizes, weight_init=init, bias_init=zeros, activation=softplus, final_activation=<lambda>, use_bias=True, use_final_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key)
¤
Initialize MLP.
| PARAMETER | DESCRIPTION |
|---|---|
| in_size | The input size. The input to the module should be a vector of shape (in_size,). |
| out_size | The output size. The output from the module will be a vector of shape (out_size,). |
| width_sizes | The sizes of each hidden layer in a list. |
| weight_init | The weight initializer. |
| bias_init | The bias initializer. Defaults to zeros. |
| activation | The activation function after each hidden layer. Defaults to softplus. |
| final_activation | The activation function after the output layer. |
| use_bias | Whether to add on a bias to internal layers. Defaults to True. |
| use_final_bias | Whether to add on a bias to the final layer. Defaults to True. |
| weight_wrap | An optional wrapper that is applied to all weights. Defaults to None. |
| bias_wrap | An optional wrapper that is applied to all biases. Defaults to None. |
| dtype | The dtype to use for all the weights and biases in this MLP. |
| key | A JAX PRNG key used for parameter initialization. (Keyword only argument.) |
Note
Note that in_size also supports the string "scalar" as a
special value. In this case the input to the module should be of
shape ().
Likewise out_size can also be a string "scalar", in which case
the output from the module will have shape ().
__call__(x, *, key=None)
¤
Forward pass through MLP.
| PARAMETER | DESCRIPTION |
|---|---|
| x | A JAX array with shape (in_size,), or shape () if in_size="scalar". |
| key | Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) |

| RETURNS | DESCRIPTION |
|---|---|
| Array | A JAX array with shape (out_size,), or shape () if out_size="scalar". |
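To illustrate what "different node numbers in the hidden layers" means in practice, here is a minimal sketch of such a forward pass in plain JAX. It is not the klax implementation: the helper names init_mlp and mlp_forward and the He-style initialisation are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_size, out_size, width_sizes):
    """Create (weight, bias) pairs for an MLP with per-layer hidden widths."""
    sizes = [in_size, *width_sizes, out_size]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # He-style normal initialisation (an assumption of this sketch).
        w = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp_forward(params, x):
    """Softplus after each hidden layer, identity after the output layer."""
    *hidden, (w_out, b_out) = params
    for w, b in hidden:
        x = jax.nn.softplus(w @ x + b)
    return w_out @ x + b_out


# Two hidden layers of width 8 and 16, mapping shape (3,) to shape (2,).
params = init_mlp(jax.random.PRNGKey(0), in_size=3, out_size=2, width_sizes=[8, 16])
y = mlp_forward(params, jnp.ones(3))
print(y.shape)  # (2,)
```

Each entry of width_sizes sets the number of rows of one hidden weight matrix, so the layers need not share a common width.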
Input Convex Neural Networks¤
Input Convex Neural Networks (ICNNs) are a family of neural network architectures introduced by
Amos et al. (2017) that are designed to produce outputs that
are provably convex with respect to (some of) their inputs. Klax provides two variants:
the Fully Input Convex Neural Network (FICNN) and the
Partially Input Convex Neural Network (PICNN).
Fully Input Convex Neural Network (FICNN)¤
An FICNN represents a function \(f : x \mapsto y\) where every element of the output \(y\) is a
convex function of the input \(x\).
Layer structure¤

Each FICNNLayer computes:
\[ y_{i+1} = \sigma\left(U_i y_i + W_i x + b_i\right) \]
where:
- \(y_i\) is the hidden state from the previous layer (initialised as \(y_0 = x\)),
- \(x\) is the original network input, passed directly to every layer via the passthrough connection,
- \(U_i\), \(W_i\), \(b_i\) are the weight matrices and bias of layer \(i\),
- \(\sigma\) is an activation function that must be convex and non-decreasing (default: softplus).
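A single layer update can be sketched in a few lines of JAX, assuming the update rule \(y_{i+1} = \sigma(U_i y_i + W_i x + b_i)\) that follows from the symbol definitions in this section. The function name ficnn_layer and the sizes used are hypothetical, and the non-negativity of \(U_i\) is obtained here with a plain absolute value rather than with klax's NonNegative.

```python
import jax
import jax.numpy as jnp


def ficnn_layer(U, W, b, y, x, activation=jax.nn.softplus):
    """One layer update: y_next = activation(U @ y + W @ x + b)."""
    return activation(U @ y + W @ x + b)


key_u, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jnp.array([0.5, -1.0])                     # original network input (size 2)
y = jnp.zeros(3)                               # previous hidden state (width 3)
U = jnp.abs(jax.random.normal(key_u, (4, 3)))  # non-negative weight on the convex path
W = jax.random.normal(key_w, (4, 2))           # unconstrained passthrough weight
b = jnp.zeros(4)

y_next = ficnn_layer(U, W, b, y, x)
print(y_next.shape)  # (4,)
```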
Convexity guarantee¤
Convexity is maintained by enforcing two constraints:
- Non-negative weights on the convex path. The weight matrix \(U_i\) acting on the previous hidden state \(y_i\) is constrained to be element-wise non-negative (via NonNegative) for all layers except the first. This ensures that the composition of affine maps and a convex activation remains convex.
- Convex, non-decreasing activation. The activation \(\sigma\) applied to each layer's pre-activation must itself be convex and non-decreasing.

The passthrough weight \(W_i\) has no sign constraint because \(x\) enters each layer as a fixed affine term, which does not break convexity.
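The two constraints can be checked numerically. The sketch below builds a small two-layer network of this form and tests midpoint convexity, \(f(\tfrac{a+b}{2}) \le \tfrac{1}{2}(f(a) + f(b))\), on random input pairs. It assumes a softplus reparametrisation to keep the hidden-state weight non-negative; how NonNegative enforces the constraint in klax is not shown here, and the function ficnn is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def ficnn(params, x):
    """Two-layer FICNN sketch with a scalar output."""
    W0, b0, U1_raw, W1, b1 = params
    y1 = jax.nn.softplus(W0 @ x + b0)  # first layer: no sign constraint needed
    U1 = jax.nn.softplus(U1_raw)       # reparametrise so that U1 >= 0 element-wise
    return (U1 @ y1 + W1 @ x + b1)[0]  # identity final activation


keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = (
    jax.random.normal(keys[0], (8, 2)),  # W0
    jnp.zeros(8),                        # b0
    jax.random.normal(keys[1], (1, 8)),  # raw (unconstrained) version of U1
    jax.random.normal(keys[2], (1, 2)),  # W1: passthrough weight, any sign
    jnp.zeros(1),                        # b1
)

a = jax.random.normal(keys[3], (100, 2))
b = jax.random.normal(keys[4], (100, 2))
f = jax.vmap(lambda x: ficnn(params, x))

# Midpoint convexity gap; non-negative (up to round-off) for a convex function.
gap = 0.5 * (f(a) + f(b)) - f(0.5 * (a + b))
print(bool(jnp.all(gap >= -1e-4)))
```

Replacing the softplus on U1_raw with the identity removes the guarantee, and the same check may then fail for some parameter draws.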
Optional: non-decreasing output¤
It is possible to additionally constrain \(U_0\) (the first-layer weight) and all passthrough weights \(W_i\) to be non-negative. This ensures the output is not only convex but also element-wise non-decreasing in \(x\). This is needed, for example, when the FICNN is composed with another convex function \(x(z)\) and the composition must remain convex in \(z\): a convex, non-decreasing function of a convex function is again convex.
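As a rough check of the non-decreasing property, the sketch below makes every weight matrix non-negative and verifies that the gradient of the output with respect to \(x\) has no negative entries. The softplus reparametrisation and the function name monotone_ficnn are assumptions of this example, not klax's mechanism.

```python
import jax
import jax.numpy as jnp


def monotone_ficnn(raw_params, x):
    """FICNN sketch in which every weight matrix is made non-negative."""
    W0, b0, U1, W1, b1 = raw_params
    W0, U1, W1 = (jax.nn.softplus(w) for w in (W0, U1, W1))
    y1 = jax.nn.softplus(W0 @ x + b0)
    return (U1 @ y1 + W1 @ x + b1)[0]


keys = jax.random.split(jax.random.PRNGKey(1), 4)
raw_params = (
    jax.random.normal(keys[0], (8, 2)),  # raw first-layer weight
    jnp.zeros(8),
    jax.random.normal(keys[1], (1, 8)),  # raw hidden-state weight
    jax.random.normal(keys[2], (1, 2)),  # raw passthrough weight
    jnp.zeros(1),
)

xs = jax.random.normal(keys[3], (100, 2))
# Gradient of the scalar output with respect to x, evaluated at each sample.
grads = jax.vmap(jax.grad(lambda x: monotone_ficnn(raw_params, x)))(xs)
print(bool(jnp.all(grads >= 0.0)))
```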
Usage¤
klax.nn.FICNN
¤
Fully input convex neural network (FICNN) from Amos et al.
An FICNN is a function x -> y where each element
of the output y is a convex function of the input x.
__init__(in_size, out_size, width_sizes, use_passthrough=True, non_decreasing=False, weight_init=init, bias_init=zeros, constrained_weight_init=init, constrained_bias_init=init, activation=softplus, final_activation=<lambda>, use_bias=True, use_final_bias=True, dtype=None, *, key)
¤
Initialize FICNN.
The FICNN's output y will be an element-wise convex function of the
input x.
Warning
To ensure convexity, the activation functions activation and
final_activation need to be convex and non-decreasing.
| PARAMETER | DESCRIPTION |
|---|---|
| in_size | Size of the input x. |
| out_size | Size of the output y. |
| width_sizes | The sizes of each hidden layer provided as a list. |
| use_passthrough | Whether to use passthrough layers. If true, each of the FICNN's hidden layers and the output layer are passed the original input x. Defaults to True. |
| non_decreasing | If true, the weights in the first layer and in the passthrough connections are constrained to be non-negative (using NonNegative), so that the output is also element-wise non-decreasing in x. Defaults to False. |
| weight_init | The weight initializer. |
| bias_init | The bias initializer used for the biases of unconstrained layers. Defaults to zeros. |
| constrained_weight_init | The weight initializer used for constrained weights. |
| constrained_bias_init | The bias initializer used for the biases of constrained layers. |
| activation | The activation function of each hidden layer. To ensure convexity this function must be convex and non-decreasing. Defaults to softplus. |
| final_activation | The activation function in the last layer. To ensure convexity of the overall FICNN this function must be convex and non-decreasing. Defaults to the identity. |
| use_bias | Whether to add on a bias in the hidden layers. Defaults to True. |
| use_final_bias | Whether to add on a bias to the final layer. Defaults to True. |
| dtype | The dtype to use for all the weights and biases in this FICNN. |
| key | A JAX PRNG key used for parameter initialization. (Keyword only argument.) |
__call__(x, *, key=None)
¤
Partially Input Convex Neural Network (PICNN)¤
A PICNN represents a function \(f : (x, p) \mapsto y\) where every element of the output \(y\) is
a convex function of \(x\), but can have an arbitrary relationship to the conditioning input \(p\).
Intuitively, a PICNN is an FICNN whose weights and biases are themselves functions of \(p\).
Layer structure¤

Each PICNNLayer maintains two parallel hidden states:
- Convex path \(y_i\): carries information about \(x\) and \(p\); convexity in \(x\) is enforced here.
- Arbitrary path \(u_i\): an unconstrained MLP processing the input \(p\) with no convexity constraints applied.
The two paths interact via interconnection sublayers that modulate the weights of the convex path using the current state \(u_i\) of the arbitrary path:
\[ y_{i+1} = \sigma_y\left(U_i(u_i)\, y_i + W_i(u_i)\, x + b_i(u_i)\right) \]
where:
- the weight \(U_i(u_i) = \tilde U_i \cdot \texttt{diag}(\sigma^{yu}(W^{yu}u_i+b^{yu}))\) is computed from a constant weight matrix \(\tilde U_i\) that is column-wise modulated by \(u_i\) via a single-layer MLP,
- the weight \(W_i(u_i) = \tilde W_i \cdot \texttt{diag}(\sigma^{xu}(W^{xu}u_i+b^{xu}))\) is similarly computed from a constant weight matrix \(\tilde W_i\) that is column-wise modulated by \(u_i\) via a single-layer MLP (only when use_passthrough=True),
- the bias \(b_i(u_i) = W^{bu}u_i+b^{bu}\) is provided through an additional affine projection of \(u_i\),
- \(\sigma^{yu}\) must be non-negative (default: softplus) to preserve convexity,
- \(\sigma^{xu}\) has no sign constraint by default (identity), but must be non-negative when non_decreasing=True.

Note that information from \(x\) never enters the arbitrary path \(u\), preserving the convexity guarantee in \(x\).
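The convex-path update of one layer can be sketched as follows, assuming it has the form \(y_{i+1} = \sigma_y(U_i(u_i) y_i + W_i(u_i) x + b_i(u_i))\) with the modulated weights and bias defined in this section. The function name picnn_layer, the parameter ordering and all sizes are hypothetical. The sketch uses the identity \(\tilde U \cdot \texttt{diag}(s)\, y = \tilde U (s \odot y)\) to avoid building the diagonal matrix.

```python
import jax
import jax.numpy as jnp


def picnn_layer(params, y, u, x):
    """Convex-path update of one PICNN layer (sketch)."""
    U_tilde, W_yu, b_yu, W_tilde, W_xu, b_xu, W_bu, b_bu = params
    s_y = jax.nn.softplus(W_yu @ u + b_yu)  # sigma^{yu}: non-negative modulation
    s_x = W_xu @ u + b_xu                   # sigma^{xu}: identity, any sign
    bias = W_bu @ u + b_bu                  # b(u) = W^{bu} u + b^{bu}
    # U_tilde @ diag(s_y) @ y is the same as U_tilde @ (s_y * y).
    return jax.nn.softplus(U_tilde @ (s_y * y) + W_tilde @ (s_x * x) + bias)


n_x, n_u, n_y, n_out = 2, 3, 4, 5
k = jax.random.split(jax.random.PRNGKey(0), 8)
params = (
    jnp.abs(jax.random.normal(k[0], (n_out, n_y))),  # U_tilde, kept non-negative
    jax.random.normal(k[1], (n_y, n_u)),             # W^{yu}
    jnp.ones(n_y),                                   # b^{yu}
    jax.random.normal(k[2], (n_out, n_x)),           # W_tilde
    jax.random.normal(k[3], (n_x, n_u)),             # W^{xu}
    jnp.ones(n_x),                                   # b^{xu}
    jax.random.normal(k[4], (n_out, n_u)),           # W^{bu}
    jnp.zeros(n_out),                                # b^{bu}
)
x = jax.random.normal(k[5], (n_x,))
u = jax.random.normal(k[6], (n_u,))
y = jax.random.normal(k[7], (n_y,))

y_next = picnn_layer(params, y, u, x)
print(y_next.shape)  # (5,)
```

Because both \(\tilde U_i\) and the softplus modulation are non-negative, the effective weight on \(y_i\) stays non-negative for any value of \(u_i\).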
Convexity guarantee¤
The convexity argument mirrors the FICNN:
- Non-negative weights on the convex path. The weight matrix \(U_i(u_i)\) acting on the previous hidden state \(y_i\) is constrained to be element-wise non-negative for all layers except the first. For this, \(\tilde U_i\) is constrained to be non-negative (via NonNegative). Additionally, \(\sigma^{yu}\) must be non-negative (e.g. softplus), ensuring the constraint is maintained regardless of the value of \(u_i\). This ensures that the composition of affine maps and a convex activation remains convex.
- Convex, non-decreasing activation. The activation \(\sigma_y\) must be convex and non-decreasing.
Optional: non-decreasing output¤
Analogously to the FICNN case, it's possible to ensure the output \(y\) is non-decreasing in \(x\) by extending the non-negativity constraints to the first-layer weight and to all passthrough weights (setting non_decreasing=True).
Usage¤
klax.nn.PICNN
¤
Partially input convex neural network (PICNN) from Amos et al.
A PICNN is a function (x, p) -> y where each element
of the output y is a convex function of the input x,
but can have an arbitrary relationship to the input p.
You can think of the PICNN as an FICNN mapping x to y, but
whose weights and biases are functions of the input p.
__init__(x_size, p_size, out_size, width_sizes, *, use_passthrough=True, non_decreasing=False, weight_init=init, bias_init=zeros, constrained_weight_init=init, constrained_bias_init=init, interconnect_weight_init=zeros, interconnect_bias_init=ones, activation_y=softplus, activation_u=softplus, activation_yu=softplus, activation_xu=<lambda>, final_activation_y=<lambda>, use_bias=True, use_final_bias=True, dtype=None, key)
¤
Initialize PICNN.
The PICNN's output y will be an element-wise convex function of the
input x, but can have an arbitrary functional relationship to the input p.
Warning
To ensure convexity, the activation functions need to have certain
properties, depending on the non_decreasing option:
| Activation | non_decreasing=False | non_decreasing=True |
|---|---|---|
| activation_y | Convex and non-decreasing | Convex and non-decreasing |
| final_activation_y | Convex and non-decreasing | Convex and non-decreasing |
| activation_u | Arbitrary | Arbitrary |
| activation_yu | Non-negative | Non-negative |
| activation_xu | Arbitrary | Non-negative |
| PARAMETER | DESCRIPTION |
|---|---|
| x_size | Size of the input x. |
| p_size | Size of the input p. |
| out_size | Size of the output y. |
| width_sizes | List of the sizes for each hidden layer. Each element of the list can be either an integer or a tuple of two integers. |
| use_passthrough | Whether to use passthrough layers. If true, the PICNN's input is passed again to each hidden layer (except for the first hidden layer, since it receives the original input anyway). Defaults to True. |
| non_decreasing | If true, the first-layer weight and all passthrough weights are additionally constrained to be non-negative, so that the output y is non-decreasing in x. Note: If true, activation_xu must be non-negative. Defaults to False. |
| weight_init | The weight initializer. |
| bias_init | The bias initializer. Defaults to zeros. |
| constrained_weight_init | The weight initializer used for constrained weights. |
| constrained_bias_init | The bias initializer used for the biases of constrained layers. |
| interconnect_weight_init | The weight initializer used for the interconnection sublayers. Defaults to zeros. |
| interconnect_bias_init | The bias initializer used for the interconnection sublayers. Defaults to ones. |
| activation_y | Activation function applied to the convex output. To ensure convexity, this function must be convex and non-decreasing. Defaults to softplus. |
| activation_u | Activation function applied to the arbitrary output. Defaults to softplus. |
| activation_yu | Activation function applied to the output of the sublayer computing the weight modulation in the convex path. To ensure convexity, this function is required to be non-negative. Defaults to jax.nn.softplus. |
| activation_xu | Activation function applied to the output of the sublayer computing the weight modulation in the passthrough path. To ensure convexity when non_decreasing=True, this function is required to be non-negative. |
| final_activation_y | Activation function applied to the convex output of the final layer. To ensure convexity, this function must be convex and non-decreasing. |
| use_bias | Whether to add on a bias in the hidden layers. Defaults to True. |
| use_final_bias | Whether to add on a bias to the final layer. Defaults to True. |
| dtype | The dtype to use for all the weights and biases in this PICNN. |
| key | A JAX PRNG key used for parameter initialization. (Keyword only argument.) |
__call__(x, p, *, key=None)
¤
Forward pass through the PICNN.
| PARAMETER | DESCRIPTION |
|---|---|
| x | A JAX array with shape (..., x_size). |
| p | A JAX array with shape (..., p_size). |
| key | Ignored; provided for compatibility with the rest of the Equinox API. Defaults to None. |

| RETURNS | DESCRIPTION |
|---|---|
| Float[Array, '... out_size'] | A JAX array with shape (..., out_size). |