
Matrix-valued functions

klax.nn.Matrix

An unconstrained matrix-valued function based on an MLP.

The MLP maps the input to a vector of elements, which is then transformed into a matrix.
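The vector-to-matrix idea can be sketched in plain JAX, assuming that the final MLP layer has prod(shape) outputs that are reshaped row-major into the matrix. The helpers normalize_shape, init_mlp, mlp and matrix_forward are hypothetical and not part of klax; the library's actual implementation may differ.

```python
import math

import jax
import jax.numpy as jnp


def normalize_shape(shape):
    # A single integer N is shorthand for a square (N, N) matrix.
    return (shape, shape) if isinstance(shape, int) else tuple(shape)


def init_mlp(key, in_size, out_size, width_sizes):
    # One (weight, bias) pair per layer: He-normal weights, zero biases.
    sizes = [in_size, *width_sizes, out_size]
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, m, n in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.nn.initializers.he_normal()(k, (m, n))
        params.append((w, jnp.zeros((n,))))
    return params


def mlp(params, x):
    # Softplus after each hidden layer, identity after the output layer.
    for w, b in params[:-1]:
        x = jax.nn.softplus(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def matrix_forward(params, shape, x):
    # The MLP output has prod(shape) entries; reshape it into the matrix.
    return mlp(params, x).reshape(shape)


shape = normalize_shape(3)  # (3, 3)
params = init_mlp(jax.random.PRNGKey(0), 2, math.prod(shape), [16, 16])
A = matrix_forward(params, shape, jnp.ones(2))
print(A.shape)  # (3, 3)
```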

__init__(in_size, shape, width_sizes, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=jax.nn.softplus, final_activation=<function Matrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)

Initialize the Matrix.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_size,).

TYPE: int | Literal['scalar']

shape

The matrix shape. The output from the module will be an array with the specified shape. For square matrices a single integer N can be used as a shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

width_sizes

The sizes of each hidden layer of the underlying MLP in a list.

TYPE: Sequence[int]

weight_init

The weight initializer of type jax.nn.initializers.Initializer. (Defaults to he_normal())

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

bias_init

The bias initializer of type jax.nn.initializers.Initializer. (Defaults to zeros)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function zeros>

activation

The activation function after each hidden layer. (Defaults to softplus.)

TYPE: Callable DEFAULT: jax.nn.softplus

final_activation

The activation function after the output layer. (Defaults to the identity.)

TYPE: Callable DEFAULT: <function Matrix.<lambda>>

use_bias

Whether to add on a bias to internal layers. (Defaults to True.)

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. (Defaults to True.)

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this MLP. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

in_size also supports the string "scalar" as a special value. In this case, the input to the module should be an array of shape ().
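A minimal sketch of how the "scalar" special value could be handled, assuming that a shape-() input is promoted to a length-1 vector before it reaches the first dense layer. This behaviour is an assumption, and the function prepare_input is hypothetical rather than part of klax.

```python
import jax.numpy as jnp


def prepare_input(x, in_size):
    # With in_size="scalar" the module receives a shape-() array; promote it
    # to a length-1 vector so an ordinary dense layer can consume it.
    x = jnp.asarray(x)
    if in_size == "scalar":
        if x.shape != ():
            raise ValueError(f"expected shape (), got {x.shape}")
        return jnp.reshape(x, (1,))
    if x.shape != (in_size,):
        raise ValueError(f"expected shape ({in_size},), got {x.shape}")
    return x


print(prepare_input(2.0, "scalar").shape)  # (1,)
print(prepare_input(jnp.ones(3), 3).shape)  # (3,)
```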

__call__(x)

Forward pass through Matrix.

PARAMETER DESCRIPTION
x

The input. Should be a JAX array of shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.


klax.nn.ConstantMatrix

A constant, unconstrained matrix.

It is a wrapper around a constant array that implements the matrix-valued function interface.
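As a rough illustration of this wrapper pattern, the sketch below stores one array, created with the documented default initializer variance_scaling(scale=1, mode="fan_avg", distribution="normal"), and returns it for any input. The class ConstantMatrixSketch is hypothetical and only mimics the call interface; it is not the klax class.

```python
import jax
import jax.numpy as jnp


class ConstantMatrixSketch:
    """Hypothetical stand-in: a fixed array behind a callable interface."""

    def __init__(self, shape, *, key, dtype=jnp.float32):
        shape = (shape, shape) if isinstance(shape, int) else tuple(shape)
        # Variance scaling with scale=1, fan_avg mode and a (non-truncated)
        # normal distribution, matching the documented default.
        init = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "normal")
        self.array = init(key, shape, dtype)

    def __call__(self, x):
        # x is ignored; it only keeps the signature compatible with
        # input-dependent modules such as Matrix.
        del x
        return self.array


m = ConstantMatrixSketch(3, key=jax.random.PRNGKey(0))
print(m(jnp.zeros(5)).shape)  # (3, 3)
```

Because the input is discarded, a constant module can be swapped in wherever an input-dependent matrix function is expected.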

__init__(shape, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)

Initialize the ConstantMatrix.

PARAMETER DESCRIPTION
shape

The matrix shape. The output from the module will be an Array with the specified shape. For square matrices, a single integer N can be used as shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

init

The array initializer of type jax.nn.initializers.Initializer. (Defaults to variance_scaling(scale=1, mode="fan_avg", distribution="normal").)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

dtype

The dtype of the constant array. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

__call__(x)

Forward pass through ConstantMatrix.

PARAMETER DESCRIPTION
x

Ignored; provided for compatibility with the rest of the matrix-valued function API.

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.


klax.nn.SkewSymmetricMatrix

A skew-symmetric matrix-valued function based on an MLP.

The MLP maps the input to a vector of elements that are transformed into a skew-symmetric matrix.
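One way to turn a vector into a skew-symmetric matrix is to fill the strictly upper triangle with N(N-1)/2 entries and subtract the transpose, which guarantees A^T = -A and a zero diagonal. This is a sketch under that assumption; klax may parameterize the matrix differently (for example, by reshaping a full N x N output). The function vector_to_skew is hypothetical.

```python
import jax.numpy as jnp


def vector_to_skew(v, n):
    # v holds the n*(n-1)/2 strictly upper-triangular entries, row by row.
    rows, cols = jnp.triu_indices(n, k=1)
    upper = jnp.zeros((n, n), dtype=v.dtype).at[rows, cols].set(v)
    # Subtracting the transpose mirrors the entries with opposite sign and
    # leaves the diagonal at zero, so the result satisfies A.T == -A.
    return upper - upper.T


A = vector_to_skew(jnp.arange(1.0, 4.0), 3)
# A == [[0, 1, 2], [-1, 0, 3], [-2, -3, 0]]
```

Storing only the upper triangle means the MLP never has to learn the constraint; it holds by construction for every output.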

__init__(in_size, shape, width_sizes, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=jax.nn.softplus, final_activation=<function SkewSymmetricMatrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)

Initialize the SkewSymmetricMatrix.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_size,).

TYPE: int | Literal['scalar']

shape

The matrix shape. The output from the module will be an Array with the specified shape. For square matrices, a single integer N can be used as shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

width_sizes

The sizes of each hidden layer of the underlying MLP in a list.

TYPE: Sequence[int]

weight_init

The weight initializer of type jax.nn.initializers.Initializer. (Defaults to he_normal())

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

bias_init

The bias initializer of type jax.nn.initializers.Initializer. (Defaults to zeros)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function zeros>

activation

The activation function after each hidden layer. (Defaults to softplus).

TYPE: Callable DEFAULT: jax.nn.softplus

final_activation

The activation function after the output layer. (Defaults to the identity.)

TYPE: Callable DEFAULT: <function SkewSymmetricMatrix.<lambda>>

use_bias

Whether to add on a bias to internal layers. (Defaults to True.)

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. (Defaults to True.)

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this MLP. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

in_size also supports the string "scalar" as a special value. In this case, the input to the module should be an array of shape ().

__call__(x)

Forward pass through SkewSymmetricMatrix.

PARAMETER DESCRIPTION
x

The input. Should be a JAX array of shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.


klax.nn.ConstantSkewSymmetricMatrix

A constant skew-symmetric matrix.

It wraps a constant array that is constrained to be skew-symmetric and implements the matrix-valued function interface.
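A common way to keep a trainable constant skew-symmetric is to store an unconstrained square array W and return (W - W^T)/2, so the constraint holds for any W. Whether klax uses this exact parameterization is an assumption; skew_from_free_params is a hypothetical name. The last line checks a defining property: the quadratic form x^T A x of a skew-symmetric A vanishes for every x.

```python
import jax
import jax.numpy as jnp


def skew_from_free_params(w):
    # Any square w yields a skew-symmetric matrix: (w - w.T).T == -(w - w.T).
    return 0.5 * (w - w.T)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
init = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "normal")
w = init(key_w, (4, 4))       # unconstrained trainable array
A = skew_from_free_params(w)  # constant skew-symmetric output

# x^T A x == 0 for every x when A is skew-symmetric (up to rounding).
x = jax.random.normal(key_x, (4,))
print(jnp.allclose(A, -A.T), jnp.abs(x @ A @ x) < 1e-4)
```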

__init__(shape, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)

Initialize the ConstantSkewSymmetricMatrix.

PARAMETER DESCRIPTION
shape

The matrix shape. The output from the module will be an Array with the specified shape. For square matrices, a single integer N can be used as shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

init

The array initializer of type jax.nn.initializers.Initializer. (Defaults to variance_scaling(scale=1, mode="fan_avg", distribution="normal").)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

dtype

The dtype of the constant array. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

__call__(x)

Forward pass through ConstantSkewSymmetricMatrix.

PARAMETER DESCRIPTION
x

Ignored; provided for compatibility with the rest of the matrix-valued function API.

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.


klax.nn.SPDMatrix

A symmetric positive definite matrix-valued function based on an MLP.

The output vector v of the MLP is mapped to a matrix B. The module's output is then computed as A = B @ B*, where B* denotes the conjugate transpose of B, and epsilon is added to the diagonal of A.
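The factorization step can be sketched for a single square B (the MLP that produces B is omitted). The function spd_from_factor is hypothetical. The example uses a deliberately rank-deficient B to show why epsilon matters: with epsilon = 0 the smallest eigenvalue of A is zero (only positive semi-definite), while epsilon > 0 shifts every eigenvalue up by epsilon.

```python
import jax.numpy as jnp


def spd_from_factor(b, epsilon=1e-6):
    # For a square B, A = B @ B^* (B^* = conjugate transpose) is Hermitian
    # positive semi-definite; adding epsilon * I raises every eigenvalue
    # by epsilon, which makes A positive definite whenever epsilon > 0.
    n = b.shape[0]
    return b @ jnp.conj(b).T + epsilon * jnp.eye(n, dtype=b.dtype)


# A deliberately rank-deficient factor: its last row is zero.
b = jnp.array([[1.0, 2.0, 0.0],
               [3.0, 4.0, 0.0],
               [0.0, 0.0, 0.0]])

eig_psd = jnp.linalg.eigvalsh(spd_from_factor(b, epsilon=0.0))
eig_pd = jnp.linalg.eigvalsh(spd_from_factor(b, epsilon=1e-2))
print(eig_psd.min(), eig_pd.min())  # about 0.0 and 0.01
```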

__init__(in_size, shape, width_sizes, epsilon=1e-06, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=jax.nn.softplus, final_activation=<function SPDMatrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)

Initialize the SPDMatrix.

PARAMETER DESCRIPTION
in_size

The input size. The input to the module should be a vector of shape (in_size,).

TYPE: int | Literal['scalar']

shape

The matrix shape. The output from the module will be an Array with the specified shape. For square matrices, a single integer N can be used as shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

width_sizes

The sizes of each hidden layer of the underlying MLP in a list.

TYPE: Sequence[int]

epsilon

Small value that is added to the diagonal of the output matrix to ensure positive definiteness. If only positive semi-definiteness is required, set epsilon = 0. (Defaults to 1e-6.)

TYPE: float DEFAULT: 1e-06

weight_init

The weight initializer of type jax.nn.initializers.Initializer. (Defaults to he_normal())

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

bias_init

The bias initializer of type jax.nn.initializers.Initializer. (Defaults to zeros)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function zeros>

activation

The activation function after each hidden layer. (Defaults to softplus)

TYPE: Callable DEFAULT: jax.nn.softplus

final_activation

The activation function after the output layer. (Defaults to the identity.)

TYPE: Callable DEFAULT: <function SPDMatrix.<lambda>>

use_bias

Whether to add on a bias to internal layers. (Defaults to True.)

TYPE: bool DEFAULT: True

use_final_bias

Whether to add on a bias to the final layer. (Defaults to True.)

TYPE: bool DEFAULT: True

dtype

The dtype to use for all the weights and biases in this MLP. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

in_size also supports the string "scalar" as a special value. In this case, the input to the module should be an array of shape ().

__call__(x)

Forward pass through SPDMatrix.

PARAMETER DESCRIPTION
x

The input. Should be a JAX array of shape (in_size,). (Or shape () if in_size="scalar".)

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.


klax.nn.ConstantSPDMatrix

A constant symmetric positive definite matrix-valued function.

It is a wrapper around a constant symmetric positive definite matrix (positive semi-definite if epsilon = 0) that implements the matrix-valued function interface.
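The sketch below combines the constant-module pattern with the documented construction A = B @ B* + epsilon * I for a real, square B initialized with variance_scaling(scale=1, mode="fan_avg", distribution="normal"). The class ConstantSPDSketch is hypothetical, not the klax class. A Cholesky factorization is a practical positive-definiteness check; in JAX, a failed factorization shows up as NaNs in the result rather than as an exception.

```python
import jax
import jax.numpy as jnp


class ConstantSPDSketch:
    """Hypothetical stand-in: stores B and returns B @ B.T + epsilon * I."""

    def __init__(self, n, epsilon=1e-6, *, key):
        init = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "normal")
        self.b = init(key, (n, n))
        self.epsilon = epsilon

    def __call__(self, x):
        del x  # the input is ignored, as for the other constant modules
        n = self.b.shape[0]
        return self.b @ self.b.T + self.epsilon * jnp.eye(n, dtype=self.b.dtype)


spd = ConstantSPDSketch(4, key=jax.random.PRNGKey(0))
A = spd(None)
L = jnp.linalg.cholesky(A)  # finite entries indicate A is positive definite
print(bool(jnp.all(jnp.isfinite(L))))
```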

__init__(shape, epsilon=1e-06, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)

Initialize the ConstantSPDMatrix.

PARAMETER DESCRIPTION
shape

The matrix shape. The output from the module will be an Array with the specified shape. For square matrices, a single integer N can be used as shorthand for (N, N).

TYPE: int | AtLeast2DTuple[int]

epsilon

Small value that is added to the diagonal of the output matrix to ensure positive definiteness. If only positive semi-definiteness is required, set epsilon = 0. (Defaults to 1e-6.)

TYPE: float DEFAULT: 1e-06

init

The initializer of type jax.nn.initializers.Initializer for the constant matrix B that produces the module's output via A = B@B*. (Defaults to variance_scaling(scale=1, mode="fan_avg", distribution="normal").)

TYPE: jax.nn.initializers.Initializer DEFAULT: <function variance_scaling.<locals>.init>

dtype

The dtype of the stored constant matrix B. (Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.)

TYPE: Any | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

__call__(x)

Forward pass through ConstantSPDMatrix.

PARAMETER DESCRIPTION
x

Ignored; provided for compatibility with the rest of the matrix-valued function API.

TYPE: Array

RETURNS DESCRIPTION
Array

A JAX array whose shape is given by the shape argument.