klax.nn.MLP

Standard Multi-Layer Perceptron; also known as a feed-forward network.

This class is a modified version of equinox.nn.MLP that allows custom initialization and a different number of nodes in each hidden layer.

__init__(in_size, out_size, width_sizes, weight_init=init, bias_init=zeros, activation=softplus, final_activation=<lambda>, use_bias=True, use_final_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key)

Initialize MLP.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_size,).

TYPE: Union[int, Literal['scalar']]

out_size

The output size. The output from the module will be a vector of shape (out_size,).

TYPE: Union[int, Literal['scalar']]

width_sizes

The sizes of each hidden layer in a list.

TYPE: Sequence[int]

weight_init

The weight initializer of type Initializer.

TYPE: Initializer DEFAULT: init

bias_init

The bias initializer of type Initializer.

TYPE: Initializer DEFAULT: zeros

activation

The activation function after each hidden layer.

TYPE: Callable DEFAULT: softplus

final_activation

The activation function after the output layer.

TYPE: Callable DEFAULT: <lambda>

use_bias

Whether to add on a bias to internal layers.

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer.

TYPE: bool DEFAULT: True

weight_wrap

An optional wrapper that is passed to all weights.

TYPE: type[Constraint] | type[Unwrappable[Array]] | None DEFAULT: None

bias_wrap

An optional wrapper that is passed to all biases.

TYPE: type[Constraint] | type[Unwrappable[Array]] | None DEFAULT: None

dtype

The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

Note that in_size also supports the string "scalar" as a special value. In this case the input to the module should be of shape ().

Likewise out_size can also be a string "scalar", in which case the output from the module will have shape ().

__call__(x, *, key=None)

Forward pass through MLP.

PARAMETER DESCRIPTION
x

A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

key

Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

TYPE: PRNGKeyArray | None DEFAULT: None

RETURNS DESCRIPTION
Array

A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)
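To make the roles of in_size, out_size and width_sizes concrete, here is a minimal pure-JAX sketch of an MLP with per-layer widths and the "scalar" convention. It is not klax's implementation: init_mlp and mlp_forward are hypothetical names, the He-normal-style weights and zero biases only stand in for weight_init and bias_init, and "scalar" sizes are assumed to be handled by treating them as size 1 internally.

```python
import jax
import jax.numpy as jnp


def init_mlp(in_size, out_size, width_sizes, *, key):
    """Hypothetical sketch: one (weight, bias) pair per layer.

    "scalar" sizes are treated as size 1 internally (an assumption).
    """
    n_in = 1 if in_size == "scalar" else in_size
    n_out = 1 if out_size == "scalar" else out_size
    sizes = [n_in, *width_sizes, n_out]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, (fan_in, fan_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        # He-normal-style weights and zero biases (assumed initializers).
        w = jax.random.normal(k, (fan_out, fan_in)) * jnp.sqrt(2.0 / fan_in)
        b = jnp.zeros((fan_out,))
        params.append((w, b))
    return params


def mlp_forward(params, x, in_size, out_size,
                activation=jax.nn.softplus, final_activation=lambda z: z):
    if in_size == "scalar":
        x = jnp.expand_dims(x, -1)  # shape () -> (1,)
    for w, b in params[:-1]:
        x = activation(w @ x + b)  # hidden layers
    w, b = params[-1]
    y = final_activation(w @ x + b)  # output layer
    if out_size == "scalar":
        y = jnp.squeeze(y, -1)  # shape (1,) -> ()
    return y
```

For example, width_sizes=[8, 4] with in_size=3 and out_size="scalar" yields weight matrices of shapes (8, 3), (4, 8) and (1, 4), and the output has shape ().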


Input Convex Neural Networks

Input Convex Neural Networks (ICNNs) are a family of neural network architectures introduced by Amos et al. (2017) that are designed to produce outputs that are provably convex with respect to (some of) their inputs. Klax provides two variants: the Fully Input Convex Neural Network (FICNN) and the Partially Input Convex Neural Network (PICNN).

Fully Input Convex Neural Network (FICNN)

An FICNN represents a function \(f : x \mapsto y\) where every element of the output \(y\) is a convex function of the input \(x\).

Layer structure

[Figure: FICNN layer]

Each FICNNLayer computes:

\[y_{i+1} = \sigma\!\left(U_i y_i + W_i x + b_i\right)\]

where:

  • \(y_i\) is the hidden state from the previous layer (initialised as \(y_0 = x\)),
  • \(x\) is the original network input, passed directly to every layer via the passthrough connection,
  • \(U_i\), \(W_i\), \(b_i\) are the weight matrices and bias of layer \(i\),
  • \(\sigma\) is an activation function that must be convex and non-decreasing (default: softplus).
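The layer recurrence above can be sketched in pure JAX as follows. This is an illustrative sketch, not klax's FICNNLayer: init_ficnn and ficnn_forward are hypothetical names, the parameter layout is assumed, and non-negativity of \(U_i\) for \(i \ge 1\) is imposed here with jnp.abs as a stand-in for klax's NonNegative wrapper, whose exact mechanism is not described in this section.

```python
import jax
import jax.numpy as jnp


def init_ficnn(in_size, out_size, width_sizes, *, key):
    """Hypothetical parameter layout: layer i holds (U_i, W_i, b_i)."""
    sizes = [in_size, *width_sizes, out_size]  # sizes of y_0 = x, y_1, ..., output
    params = []
    for i, k in enumerate(jax.random.split(key, len(sizes) - 1)):
        k_u, k_w = jax.random.split(k)
        U = jax.random.normal(k_u, (sizes[i + 1], sizes[i])) / jnp.sqrt(sizes[i])
        W = jax.random.normal(k_w, (sizes[i + 1], in_size)) / jnp.sqrt(in_size)
        b = jnp.zeros((sizes[i + 1],))
        params.append((U, W, b))
    return params


def ficnn_forward(params, x, activation=jax.nn.softplus,
                  final_activation=lambda z: z):
    y = x  # y_0 = x
    for i, (U, W, b) in enumerate(params):
        # U_0 acts on the affine y_0 = x and stays unconstrained; U_i for i >= 1
        # must be non-negative. jnp.abs is an assumed stand-in for NonNegative.
        U_eff = U if i == 0 else jnp.abs(U)
        z = U_eff @ y + W @ x + b  # convex path + passthrough + bias
        y = final_activation(z) if i == len(params) - 1 else activation(z)
    return y
```

Because every \(U_i\) with \(i \ge 1\) is non-negative and softplus is convex and non-decreasing, the Hessian of a scalar output with respect to \(x\) should be positive semidefinite at any input, which can be checked numerically with jax.hessian.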

Convexity guarantee

Convexity is maintained by enforcing two constraints:

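  1. Non-negative weights on the convex path. The weight matrix \(U_i\) acting on the previous hidden state \(y_i\) is constrained to be element-wise non-negative (via NonNegative) for all layers except the first. If every component of \(y_i\) is convex in \(x\), then \(U_i y_i\) is a non-negative combination of convex functions and hence convex; adding the affine term \(W_i x + b_i\) preserves this. The first layer is exempt because \(y_0 = x\) is affine, so \(U_0 y_0\) is affine for weights of any sign.
  2. Convex, non-decreasing activation. The activation \(\sigma\) applied to each layer's pre-activation must itself be convex and non-decreasing, because a convex, non-decreasing function of a convex function is again convex.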

The passthrough weight \(W_i\) has no sign constraint because \(x\) enters each layer as a fixed affine term, which does not break convexity.
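A one-dimensional example shows why the sign constraint on \(U_i\) matters. The sketch below (plain JAX; the functions hidden and next_unit are illustrative only) builds a second-layer unit \(\mathrm{softplus}(u\,y_1(x) + x)\) on top of the convex hidden unit \(y_1(x) = \mathrm{softplus}(x)\) and evaluates its second derivative on a grid for \(u = 2\) and \(u = -2\).

```python
import jax
import jax.numpy as jnp


def hidden(x):
    # A convex first hidden unit: y_1(x) = softplus(x).
    return jax.nn.softplus(x)


def next_unit(x, u):
    # Scalar FICNN-style unit with weight u on y_1 and passthrough weight 1.
    return jax.nn.softplus(u * hidden(x) + x)


def second_derivative(f):
    return jax.grad(jax.grad(f))


xs = jnp.linspace(-4.0, 4.0, 81)
d2_pos = jax.vmap(second_derivative(lambda x: next_unit(x, 2.0)))(xs)   # u >= 0
d2_neg = jax.vmap(second_derivative(lambda x: next_unit(x, -2.0)))(xs)  # u < 0
```

With \(u = 2\) the second derivative is non-negative everywhere; with \(u = -2\) it is about \(-0.1\) at \(x = 0\), so the unit is no longer convex.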

Optional: non-decreasing output

It is possible to additionally constrain \(U_0\) (the first-layer weight) and all passthrough weights \(W_i\) to be non-negative. This ensures the output is not only convex but also element-wise non-decreasing in \(x\). This is needed, for example, when the FICNN is composed with another convex function \(x(z)\) and the composition must remain convex in \(z\).
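The composition argument behind the non-decreasing option can be checked with a scalar example. Here \(x(z) = z^2\) is convex; composing it with the convex but decreasing map \(x \mapsto -x\) gives \(-z^2\), which is concave, whereas composing it with the convex, non-decreasing softplus keeps the result convex. The functions below are illustrative and unrelated to klax's API.

```python
import jax
import jax.numpy as jnp


def inner(z):
    # A convex inner function x(z) = z**2.
    return z ** 2


def outer_decreasing(x):
    # Convex (affine) but decreasing in x.
    return -x


def outer_non_decreasing(x):
    # Convex and non-decreasing in x.
    return jax.nn.softplus(x)


def d2(f):
    return jax.grad(jax.grad(f))


zs = jnp.linspace(-2.0, 2.0, 41)
g_bad = jax.vmap(d2(lambda z: outer_decreasing(inner(z))))(zs)       # -z**2, constant -2
g_good = jax.vmap(d2(lambda z: outer_non_decreasing(inner(z))))(zs)  # convex in z
```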

Usage

klax.nn.FICNN

Fully input convex neural network (FICNN) from Amos et al. (2017).

An FICNN is a function x -> y where each element of the output y is a convex function of the input x.

__init__(in_size, out_size, width_sizes, use_passthrough=True, non_decreasing=False, weight_init=init, bias_init=zeros, constrained_weight_init=init, constrained_bias_init=init, activation=softplus, final_activation=<lambda>, use_bias=True, use_final_bias=True, dtype=None, *, key)

Initialize FICNN.

The FICNN's output y will be an element-wise convex function of the input x.

Warning

To ensure convexity, the activation functions activation and final_activation need to be convex and non-decreasing.

PARAMETER DESCRIPTION
in_size

Size of the input x. Can be "scalar", to indicate a scalar input. The input to the FICNN should be a vector of shape (in_size,) or a scalar with shape () if in_size="scalar".

TYPE: Union[int, Literal['scalar']]

out_size

Size of the output y. Can be "scalar", to indicate a scalar output. The output of the FICNN will be a vector of shape (out_size,) or a scalar with shape () if out_size="scalar".

TYPE: Union[int, Literal['scalar']]

width_sizes

The sizes of each hidden layer provided as a list.

TYPE: Sequence[int]

use_passthrough

Whether to use passthrough layers. If true, each hidden layer and the output layer of the FICNN receive the original input x as an additional input. Defaults to True.

TYPE: bool DEFAULT: True

non_decreasing

If true, the weights in the first layer and in the passthrough connections are constrained using klax.NonNegative, so the output is element-wise convex and non-decreasing in each input. This is useful in the following scenario: suppose you want to model a function g(z) = FICNN(x(z)) as the composition of x(z) and FICNN(x) such that g is convex w.r.t. z. If x(z) is convex, the FICNN must be convex and non-decreasing in x to guarantee convexity of g(z). This option is used, for example, in material modeling applications, where the FICNN is a function of convex invariants; cf. Dammaß et al. (2025).

TYPE: bool DEFAULT: False

weight_init

The weight initializer of type SupportedInitializer used for unconstrained weights. Defaults to he_normal().

TYPE: SupportedInitializer DEFAULT: init

bias_init

The bias initializer used for the biases of unconstrained layers. Defaults to zeros.

TYPE: SupportedInitializer DEFAULT: zeros

constrained_weight_init

The weight initializer used for constrained weights. If None, then weight_init is used for constrained weights as well. Defaults to klax.hoedt_normal.

TYPE: SupportedInitializer | None DEFAULT: init

constrained_bias_init

The bias initializer used for the biases of constrained layers. If None, then bias_init is used for the biases in constrained layers as well. Defaults to zeros.

TYPE: SupportedInitializer | None DEFAULT: init

activation

The activation function of each hidden layer. To ensure convexity this function must be convex and non-decreasing. Defaults to jax.nn.softplus.

TYPE: Callable DEFAULT: softplus

final_activation

The activation function in the last layer. To ensure convexity of the overall FICNN this function must be convex and non-decreasing. Defaults to the identity.

TYPE: Callable DEFAULT: <lambda>

use_bias

Whether to add on a bias in the hidden layers. Defaults to True.

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. Defaults to True.

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this FICNN. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialization.

TYPE: PRNGKeyArray

__call__(x, *, key=None)


Partially Input Convex Neural Network (PICNN)

A PICNN represents a function \(f : (x, p) \mapsto y\) where every element of the output \(y\) is a convex function of \(x\), but can have an arbitrary relationship to the conditioning input \(p\). Intuitively, a PICNN is a FICNN whose weights and biases are themselves functions of \(p\).

Layer structure

[Figure: layer diagram]

Each PICNNLayer maintains two parallel hidden states:

  • Convex path \(y_i\): carries information about \(x\) and \(p\); convexity in \(x\) is enforced here.
  • Arbitrary path \(u_i\): an unconstrained MLP processing the input \(p\) with no convexity constraints applied.

The two paths interact via interconnection sublayers that modulate the weights of the convex path using the current state \(u_i\) of the arbitrary path:

\[ \begin{aligned} y_{i+1} &= \sigma^y\Bigl(U_i(u_i)\, y_i + W_i(u_i)\, x + b_i(u_i)\Bigr)\\ u_{i+1} &= \sigma^u\!\left(W_i^u\, u_i + b_i^u\right) \end{aligned} \]

where:

  • the weight \(U_i(u_i) = \tilde U_i \cdot \texttt{diag}\bigl(\sigma^{yu}(W_i^{yu} u_i + b_i^{yu})\bigr)\) is computed from a constant weight matrix \(\tilde U_i\) whose columns are scaled by the output of a single-layer MLP applied to \(u_i\),
  • the weight \(W_i(u_i) = \tilde W_i \cdot \texttt{diag}\bigl(\sigma^{xu}(W_i^{xu} u_i + b_i^{xu})\bigr)\) is computed analogously from a constant weight matrix \(\tilde W_i\) (only when use_passthrough=True),
  • the bias \(b_i(u_i) = W_i^{bu} u_i + b_i^{bu}\) is an additional affine projection of \(u_i\),
  • \(\sigma^{yu}\) must be non-negative (default: softplus) to preserve convexity,
  • \(\sigma^{xu}\) has no sign constraint by default (identity), but must be non-negative when non_decreasing=True.

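Note that information from \(x\) never enters the arbitrary path \(u\). For a fixed \(p\), the modulated weights \(U_i(u_i)\), \(W_i(u_i)\) and the bias \(b_i(u_i)\) are therefore constants with respect to \(x\), and the FICNN convexity argument applies unchanged.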
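The PICNN recurrence and the modulated weights above can be sketched in pure JAX. This is a hypothetical illustration, not klax's PICNNLayer: init_picnn, picnn_forward and the dictionary keys are invented here, hidden widths are plain integers shared by both paths, jnp.abs stands in for NonNegative on \(\tilde U_i\) (\(i \ge 1\)), the interconnection biases \(b_i^{yu}\) and \(b_i^{xu}\) start at one to mirror interconnect_bias_init=ones, and \(b_i^{bu}\) starts at zero (an assumption). The column-wise modulation \(\tilde U_i\,\mathrm{diag}(v)\) is computed as the broadcast product U_tilde * v[None, :].

```python
import jax
import jax.numpy as jnp


def _normal(key, shape):
    return jax.random.normal(key, shape) / jnp.sqrt(shape[-1])


def init_picnn(x_size, p_size, out_size, widths, *, key, interconnect_zero=False):
    """Hypothetical parameter layout for a PICNN with integer widths."""
    ys = [x_size, *widths, out_size]  # sizes of y_0 = x, ..., output
    us = [p_size, *widths]            # sizes of u_0 = p, ..., u_L
    # Interconnection weights (W^yu, W^xu, W^bu) optionally start at zero.
    ic = (lambda kk, s: jnp.zeros(s)) if interconnect_zero else _normal
    layers = []
    for i, k in enumerate(jax.random.split(key, len(ys) - 1)):
        ks = jax.random.split(k, 6)
        n_in, n_out, m = ys[i], ys[i + 1], us[i]
        has_u = i + 1 < len(us)  # the last layer needs no u update
        layers.append(dict(
            U=_normal(ks[0], (n_out, n_in)), W=_normal(ks[1], (n_out, x_size)),
            Wyu=ic(ks[2], (n_in, m)), byu=jnp.ones(n_in),
            Wxu=ic(ks[3], (x_size, m)), bxu=jnp.ones(x_size),
            Wbu=ic(ks[4], (n_out, m)), bbu=jnp.zeros(n_out),
            Wu=_normal(ks[5], (us[i + 1], m)) if has_u else None,
            bu=jnp.zeros(us[i + 1]) if has_u else None,
        ))
    return layers


def picnn_forward(layers, x, p, act_y=jax.nn.softplus, act_u=jax.nn.softplus,
                  act_yu=jax.nn.softplus, act_xu=lambda z: z,
                  final_act_y=lambda z: z):
    y, u = x, p  # y_0 = x, u_0 = p
    for i, lay in enumerate(layers):
        U_tilde = lay["U"] if i == 0 else jnp.abs(lay["U"])  # >= 0 for i >= 1
        U = U_tilde * act_yu(lay["Wyu"] @ u + lay["byu"])[None, :]  # U~ diag(.)
        W = lay["W"] * act_xu(lay["Wxu"] @ u + lay["bxu"])[None, :]  # W~ diag(.)
        b = lay["Wbu"] @ u + lay["bbu"]
        z = U @ y + W @ x + b
        last = i == len(layers) - 1
        y = final_act_y(z) if last else act_y(z)
        if not last:
            u = act_u(lay["Wu"] @ u + lay["bu"])  # the arbitrary path never sees x
    return y
```

For a fixed p the output should be convex in x, and passing interconnect_zero=True reproduces the effect described for interconnect_weight_init=zeros: the freshly initialized network's output does not depend on p.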

Convexity guarantee

The convexity argument mirrors the FICNN:

  1. Non-negative weights on the convex path. The weight matrix \(U_i(u_i)\) acting on the previous hidden state \(y_i\) is constrained to be element-wise non-negative for all layers except the first. For this, \(\tilde U_i\) is constrained to be non-negative (via NonNegative). Additionally, \(\sigma^{yu}\) must be non-negative (e.g. softplus), ensuring the constraint is maintained regardless of the value of \(u_i\). This ensures that the composition of affine maps and a convex activation remains convex.
  2. Convex, non-decreasing activation. The activation \(\sigma^y\) must be convex and non-decreasing.
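The role of the non-negative activation \(\sigma^{yu}\) can be illustrated directly: with \(\tilde U_i \ge 0\), the modulated matrix \(\tilde U_i\,\mathrm{diag}(\sigma^{yu}(\cdot))\) stays non-negative for every state \(u_i\) when \(\sigma^{yu}\) is softplus, but not when it is the identity. The random matrices below are illustrative stand-ins for one layer's parameters.

```python
import jax
import jax.numpy as jnp

k_u, k_w, k_s = jax.random.split(jax.random.PRNGKey(0), 3)
U_tilde = jnp.abs(jax.random.normal(k_u, (4, 5)))  # constrained: U~ >= 0
W_yu = jax.random.normal(k_w, (5, 3))              # unconstrained sublayer weights
b_yu = jnp.ones(5)
states = 10.0 * jax.random.normal(k_s, (100, 3))   # 100 arbitrary states u_i


def modulated(u, act):
    # U_i(u) = U~ diag(act(W^yu u + b^yu)): scales column j of U~ by entry j.
    return U_tilde * act(W_yu @ u + b_yu)[None, :]


U_softplus = jax.vmap(lambda u: modulated(u, jax.nn.softplus))(states)
U_identity = jax.vmap(lambda u: modulated(u, lambda z: z))(states)
```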

Optional: non-decreasing output

Analogously to the FICNN case, it's possible to ensure the output \(y\) is non-decreasing in \(x\) by extending the non-negativity constraints to the first-layer weight and to all passthrough weights (setting non_decreasing=True).

Usage

klax.nn.PICNN

Partially input convex neural network (PICNN) from Amos et al. (2017).

A PICNN is a function (x, p) -> y where each element of the output y is a convex function of the input x, but can have an arbitrary relationship to the input p. You can think of the PICNN as an FICNN mapping x to y, but whose weights and biases are functions of the input p.

__init__(x_size, p_size, out_size, width_sizes, *, use_passthrough=True, non_decreasing=False, weight_init=init, bias_init=zeros, constrained_weight_init=init, constrained_bias_init=init, interconnect_weight_init=zeros, interconnect_bias_init=ones, activation_y=softplus, activation_u=softplus, activation_yu=softplus, activation_xu=<lambda>, final_activation_y=<lambda>, use_bias=True, use_final_bias=True, dtype=None, key)

Initialize PICNN.

The PICNN's output y will be an element-wise convex function of the input x, but can have an arbitrary functional relationship to the input p.

Warning

To ensure convexity, the activation functions need to have certain properties, depending on the non_decreasing option:

| Activation | non_decreasing=False | non_decreasing=True |
|---|---|---|
| activation_y | Convex and non-decreasing | Convex and non-decreasing |
| final_activation_y | Convex and non-decreasing | Convex and non-decreasing |
| activation_u | Arbitrary | Arbitrary |
| activation_yu | Non-negative | Non-negative |
| activation_xu | Arbitrary | Non-negative |

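A quick numerical sanity check of the activation requirements in the table above can be written with jax.grad. The helper check_activation is hypothetical and not part of klax; sampling a grid can reveal a violation but cannot prove that a property holds.

```python
import jax
import jax.numpy as jnp


def check_activation(fn, xs=jnp.linspace(-6.0, 6.0, 121), tol=1e-6):
    """Hypothetical helper: test convexity, monotonicity and non-negativity
    of a scalar activation on a grid (a sanity check, not a proof)."""
    d1 = jax.vmap(jax.grad(fn))(xs)            # first derivative
    d2 = jax.vmap(jax.grad(jax.grad(fn)))(xs)  # second derivative
    return {
        "convex": bool(jnp.all(d2 >= -tol)),
        "non_decreasing": bool(jnp.all(d1 >= -tol)),
        "non_negative": bool(jnp.all(jax.vmap(fn)(xs) >= -tol)),
    }
```

For example, jax.nn.softplus passes all three checks, the identity is convex and non-decreasing but not non-negative (so it suits final_activation_y, or activation_xu with non_decreasing=False, but not activation_yu), and jnp.tanh fails the convexity check.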
PARAMETER DESCRIPTION
x_size

Size of the input x. Can be "scalar", to indicate a scalar input. The input to the PICNN should be a vector of shape (x_size,) or a scalar with shape () if x_size="scalar".

TYPE: Union[int, Literal['scalar']]

p_size

Size of the input p. Can be "scalar", to indicate a scalar input. The input to the PICNN should be a vector of shape (p_size,) or a scalar with shape () if p_size="scalar".

TYPE: Union[int, Literal['scalar']]

out_size

Size of the output y. Can be "scalar", to indicate a scalar output. The output of the PICNN will be a vector of shape (out_size,) or a scalar with shape () if out_size="scalar".

TYPE: Union[int, Literal['scalar']]

width_sizes

List of the sizes for each hidden layer. Each element of the list can be either an integer or a tuple of two integers (y_size, u_size). In the latter case, the layer's output sizes are defined individually for the convex output y(x, p) and the arbitrary output u(p). If a single integer is provided, it is used for both y_size and u_size.

TYPE: Sequence[int | tuple[int, int]]

use_passthrough

Whether to use passthrough layers. If true, the input x is passed again to each hidden layer (except the first hidden layer, which receives the original input anyway). Defaults to True.

TYPE: bool DEFAULT: True

non_decreasing

If true, the PICNN output is element-wise non-decreasing in the input x. This is useful in the following scenario: suppose you want to model a function g(z, p) = PICNN(x(z), p) as the composition of x(z) and PICNN(x, p) such that g is convex w.r.t. z. If x(z) is convex, the PICNN must be non-decreasing in x to guarantee convexity of g in z. This option is used, for example, in material modeling applications, where the PICNN is a function of convex invariants; cf. Dammaß et al. (2025).

Note: If use_passthrough=True and non_decreasing=True, then activation_xu has to be a non-negative function to guarantee convexity with respect to x. We recommend avoiding passthrough for non-decreasing PICNNs. Defaults to False.

TYPE: bool DEFAULT: False

weight_init

The weight initializer of type SupportedInitializer used for unconstrained weights. Defaults to he_normal().

TYPE: SupportedInitializer DEFAULT: init

bias_init

The bias initializer of type SupportedInitializer used for biases in sublayers without weight constraints. Defaults to zeros.

TYPE: SupportedInitializer DEFAULT: zeros

constrained_weight_init

The weight initializer of type SupportedInitializer used for constrained weights. Can be None, in which case weight_init is used. Defaults to hoedt_normal().

TYPE: SupportedInitializer | None DEFAULT: init

constrained_bias_init

The bias initializer of type SupportedInitializer used for biases in sublayers with weight constraints. Can be None, in which case bias_init is used. Defaults to hoedt_bias().

TYPE: SupportedInitializer | None DEFAULT: init

interconnect_weight_init

The weight initializer of type SupportedInitializer used for weights in interconnection sublayers. If these weights are initialized to zero, the output of the freshly initialized PICNN does not depend on the input p. Defaults to zeros.

TYPE: SupportedInitializer | None DEFAULT: zeros

interconnect_bias_init

The bias initializer of type SupportedInitializer used for biases in interconnection sublayers. Defaults to ones.

TYPE: SupportedInitializer | None DEFAULT: ones

activation_y

Activation function applied to the convex output y in all but the last layer. To ensure convexity, this function is required to be convex and non-decreasing. Defaults to jax.nn.softplus.

TYPE: Callable DEFAULT: softplus

activation_u

Activation function applied to the arbitrary output u. Defaults to jax.nn.softplus.

TYPE: Callable DEFAULT: softplus

activation_yu

Activation function applied to the output of the sublayer computing the weight modulation in the convex path. To ensure convexity, this function is required to be non-negative. Defaults to jax.nn.softplus.

TYPE: Callable DEFAULT: softplus

activation_xu

Activation function applied to the output of the sublayer computing the weight modulation in the passthrough path. When non_decreasing=True, this function must be non-negative to ensure convexity; otherwise it is unconstrained. Defaults to lambda x: x.

TYPE: Callable DEFAULT: <lambda>

final_activation_y

Activation function applied to the convex output y of the last layer. To ensure convexity, this function is required to be convex and non-decreasing. Defaults to lambda x: x.

TYPE: Callable DEFAULT: <lambda>

use_bias

Whether to add on a bias in the hidden layers. Defaults to True.

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. Defaults to True.

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this PICNN. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialization.

TYPE: PRNGKeyArray
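The width_sizes parameter accepts either a plain integer or a (y_size, u_size) tuple per hidden layer. A minimal sketch of how such a list can be expanded into explicit pairs (normalize_width_sizes is a hypothetical helper, not a klax function):

```python
def normalize_width_sizes(width_sizes):
    """Hypothetical helper: expand each entry of width_sizes into a
    (y_size, u_size) pair for the convex and arbitrary paths."""
    pairs = []
    for w in width_sizes:
        if isinstance(w, int):
            pairs.append((w, w))  # one integer sets both sizes
        else:
            y_size, u_size = w
            pairs.append((y_size, u_size))
    return pairs
```

For instance, [16, (32, 8)] becomes [(16, 16), (32, 8)]: the first hidden layer uses 16 units on both paths, the second uses 32 units on the convex path and 8 on the arbitrary path.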

__call__(x, p, *, key=None)

Forward pass through the PICNN.

PARAMETER DESCRIPTION
x

A JAX array with shape (x_size,). (Or shape () if x_size="scalar".) The output will be element-wise convex in this input.

TYPE: Float[Array, '... x_size']

p

A JAX array with shape (p_size,). (Or shape () if p_size="scalar".) The output can have an arbitrary relationship to this input.

TYPE: Float[Array, '... p_size']

key

Ignored; provided for compatibility with the rest of the Equinox API. Defaults to None.

TYPE: PRNGKeyArray | None DEFAULT: None

RETURNS DESCRIPTION
Float[Array, '... out_size']

A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)