Multi-layer perceptrons

klax.nn.MLP ¤

Standard Multi-Layer Perceptron; also known as a feed-forward network.

This class is modified from equinox.nn.MLP to allow for custom initialization and different numbers of nodes in the hidden layers. Hence, it may also be used for encoder/decoder tasks.

__init__(in_size, out_size, width_sizes, weight_init=he_normal(), bias_init=zeros, activation=jax.nn.softplus, final_activation=identity, use_bias=True, use_final_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key) ¤

Initialize MLP.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_features,).

TYPE: int | Literal['scalar']

out_size

The output size. The output from the module will be a vector of shape (out_features,).

TYPE: int | Literal['scalar']

width_sizes

The sizes of each hidden layer in a list.

TYPE: Sequence[int]

weight_init

The weight initializer of type jax.nn.initializers.Initializer. (Defaults to he_normal().)

TYPE: jax.nn.initializers.Initializer DEFAULT: he_normal()

bias_init

The bias initializer of type jax.nn.initializers.Initializer. (Defaults to zeros.)

TYPE: jax.nn.initializers.Initializer DEFAULT: zeros

activation

The activation function after each hidden layer. (Defaults to jax.nn.softplus.)

TYPE: Callable DEFAULT: jax.nn.softplus

final_activation

The activation function after the output layer. (Defaults to the identity.)

TYPE: Callable DEFAULT: identity

use_bias

Whether to add on a bias to internal layers. (Defaults to True.)

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. (Defaults to True.)

TYPE: bool DEFAULT: True

weight_wrap

An optional wrapper that is passed to all weights.

TYPE: type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

bias_wrap

An optional wrapper that is passed to all biases.

TYPE: type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

dtype

The dtype to use for all the weights and biases in this MLP. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

Note that in_size also supports the string "scalar" as a special value. In this case the input to the module should be of shape ().

Likewise out_size can also be a string "scalar", in which case the output from the module will have shape ().

__call__(x, *, key=None) ¤

Forward pass through MLP.

PARAMETER DESCRIPTION
x

A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

key

Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

TYPE: PRNGKeyArray | None DEFAULT: None

RETURNS DESCRIPTION
Array

A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)


klax.nn.FICNN ¤

A fully input convex neural network (FICNN).

Each element of the output is a convex function of the input.

See: https://arxiv.org/abs/1609.07152

__init__(in_size, out_size, width_sizes, use_passthrough=True, non_decreasing=False, weight_init=he_normal(), bias_init=zeros, activation=jax.nn.softplus, final_activation=identity, use_bias=True, use_final_bias=True, dtype=None, *, key) ¤

Initialize FICNN.

Warning

Modifying final_activation to a non-convex function will break the convexity of the FICNN. Use this parameter with care.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_features,).

TYPE: int | Literal['scalar']

out_size

The output size. The output from the module will be a vector of shape (out_features,).

TYPE: int | Literal['scalar']

width_sizes

The sizes of each hidden layer in a list.

TYPE: Sequence[int]

use_passthrough

Whether to use passthrough layers. If true, the input is passed through to each hidden layer. Defaults to True.

TYPE: bool DEFAULT: True

non_decreasing

If true, the output is element-wise non-decreasing in each input. This is useful if the input x is a convex function of some other quantity z: if the FICNN f is non-decreasing, then the composition f(x(z)) remains convex with respect to z. Defaults to False.

TYPE: bool DEFAULT: False

weight_init

The weight initializer of type jax.nn.initializers.Initializer. Defaults to he_normal().

TYPE: jax.nn.initializers.Initializer DEFAULT: he_normal()

bias_init

The bias initializer of type jax.nn.initializers.Initializer. Defaults to zeros.

TYPE: jax.nn.initializers.Initializer DEFAULT: zeros

activation

The activation function of each hidden layer. To ensure convexity this function must be convex and non-decreasing. Defaults to jax.nn.softplus.

TYPE: Callable DEFAULT: jax.nn.softplus

final_activation

The activation function after the output layer. To ensure convexity this function must be convex and non-decreasing. (Defaults to the identity.)

TYPE: Callable DEFAULT: identity

use_bias

Whether to add on a bias in the hidden layers. (Defaults to True.)

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. Defaults to True.

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this FICNN. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

__call__(x, *, key=None) ¤

Forward pass through FICNN.

PARAMETER DESCRIPTION
x

A JAX array with shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

key

Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

TYPE: PRNGKeyArray | None DEFAULT: None

RETURNS DESCRIPTION
Array

A JAX array with shape (out_size,). (Or shape () if out_size="scalar".)