
Linear Layers

klax.nn.Linear

Performs a linear transformation.

This class is modified from equinox.nn.Linear to allow for custom initialization.

__init__(in_features, out_features, weight_init, bias_init=<function zeros>, use_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key)

Initialize the linear layer.

PARAMETER DESCRIPTION
in_features

The input size. The input to the layer should be a vector of shape (in_features,).

TYPE: int | Literal['scalar']

out_features

The output size. The output from the layer will be a vector of shape (out_features,).

TYPE: int | Literal['scalar']

weight_init

The weight initializer of type jax.nn.initializers.Initializer.

TYPE: jax.nn.initializers.Initializer

bias_init

The bias initializer of type jax.nn.initializers.Initializer.

TYPE: jax.nn.initializers.Initializer DEFAULT: <function zeros>

use_bias

Whether to add on a bias as well.

TYPE: bool DEFAULT: True

weight_wrap

An optional wrapper that can be passed to enforce weight constraints.

TYPE: type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

bias_wrap

An optional wrapper that can be passed to enforce bias constraints.

TYPE: type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

dtype

The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

Note that in_features also supports the string "scalar" as a special value. In this case the input to the layer should be of shape ().

Likewise out_features can also be a string "scalar", in which case the output from the layer will have shape ().

Further note that some jax.nn.initializers.Initializers do not work if either in_features or out_features is zero.

Likewise, some jax.nn.initializers.Initializers do not work when dtype is jax.numpy.complex64.
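As a rough illustration of how custom initializers plug in, the following plain-JAX sketch creates parameters the way the arguments above describe. It is not klax's implementation: init_linear_params is a hypothetical helper, the (in_features, out_features) weight layout is an assumption chosen to match the x @ W form documented for InputSplitLinear, and treating "scalar" as a size-1 dimension is likewise assumed. Initializers from jax.nn.initializers are all called as init(key, shape, dtype).

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import he_normal, zeros


def init_linear_params(in_features, out_features, weight_init, bias_init=zeros,
                       use_bias=True, dtype=jnp.float32, *, key):
    """Hypothetical sketch of custom parameter initialization.

    Assumes a weight of shape (in, out) and that "scalar" is treated as a
    size-1 dimension when the parameters are created.
    """
    n_in = 1 if in_features == "scalar" else in_features
    n_out = 1 if out_features == "scalar" else out_features
    wkey, bkey = jax.random.split(key)
    # Each initializer is called with (key, shape, dtype).
    weight = weight_init(wkey, (n_in, n_out), dtype)
    bias = bias_init(bkey, (n_out,), dtype) if use_bias else None
    return weight, bias


weight, bias = init_linear_params(3, 2, he_normal(), key=jax.random.PRNGKey(0))
print(weight.shape, bias.shape)  # (3, 2) (2,)
```

Swapping he_normal() for any other jax.nn.initializers initializer (or a custom function with the same call signature) changes only how the weight values are drawn.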

__call__(x, *, key=None)

Forward pass of the linear transformation.

PARAMETER DESCRIPTION
x

The input. Should be a JAX array of shape (in_features,). (Or shape () if in_features="scalar".)

TYPE: Array

key

Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

TYPE: PRNGKeyArray | None DEFAULT: None

Note

If you want to use higher-order tensors as inputs (for example, inputs with batch dimensions), use jax.vmap. For an input x of shape (batch, in_features), jax.vmap(linear)(x) produces the appropriate output of shape (batch, out_features). In the following example, in_features and out_features are both "scalar", so a batch of 10 scalar inputs of shape (10,) yields an output of shape (10,):

>>> import jax
>>> from jax.nn.initializers import he_normal
>>> import jax.random as jrandom
>>> import klax
>>>
>>> key = jrandom.PRNGKey(0)
>>> keys = jrandom.split(key)
>>> x = jrandom.uniform(keys[0], (10,))
>>> linear = klax.nn.Linear(
...     "scalar",
...     "scalar",
...     he_normal(),
...     key=keys[1]
... )
>>> jax.vmap(linear)(x).shape
(10,)

RETURNS DESCRIPTION
Array

A JAX array of shape (out_features,). (Or shape () if out_features="scalar".)
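The batching behaviour described in the note also applies to vector inputs. The sketch below uses a hypothetical linear_forward function in plain JAX in place of a klax.nn.Linear instance, assuming an (in_features, out_features) weight layout and the documented "scalar" handling (a () input treated as (1,), a size-1 output squeezed back to ()). jax.vmap then maps over the leading batch axis while the parameters stay shared.

```python
import jax
import jax.numpy as jnp


def linear_forward(weight, bias, x, in_features, out_features):
    """Hypothetical single-sample forward pass (not klax's implementation).

    Assumes weight has shape (in, out); a "scalar" input of shape () is
    treated as shape (1,), and a "scalar" output is squeezed back to ().
    """
    if in_features == "scalar":
        x = jnp.broadcast_to(x, (1,))
    y = x @ weight
    if bias is not None:
        y = y + bias
    if out_features == "scalar":
        y = jnp.squeeze(y)
    return y


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (3, 2))  # in_features=3, out_features=2
bias = jnp.zeros((2,))
x = jax.random.uniform(key_x, (10, 3))     # shape (batch, in_features)

# vmap maps over the leading axis of x; weight and bias are closed over and shared.
batched = jax.vmap(lambda xi: linear_forward(weight, bias, xi, 3, 2))
print(batched(x).shape)  # (10, 2)
```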


klax.nn.InputSplitLinear

Performs a linear transformation for multiple inputs.

The transformation has the form y = x_1 @ W_1 + x_2 @ W_2 + ... + x_n @ W_n + b for inputs x_1, ..., x_n.

This layer is useful for formulating transformations with multiple inputs where different inputs require different weight constraints or initialization for the corresponding weight matrices.
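Because the terms are summed, this computes the same function as a single linear map applied to the concatenated input [x_1, ..., x_n] with the W_i stacked row-wise; splitting the weight is what lets each block have its own initializer or constraint. A minimal plain-JAX check of this identity follows, assuming each W_i has shape (in_features[i], out_features) as the x_i @ W_i form suggests. split_linear is a hypothetical helper, not the klax implementation.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import he_normal, lecun_normal


def split_linear(weights, bias, *xs):
    """Compute x_1 @ W_1 + ... + x_n @ W_n + b for 1-D inputs (hypothetical helper)."""
    if len(xs) != len(weights):
        raise ValueError(f"expected {len(weights)} inputs, got {len(xs)}")
    y = bias
    for x, w in zip(xs, weights):
        y = y + x @ w
    return y


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
# Each block may use its own initializer.
w1 = he_normal()(k1, (3, 4), jnp.float32)     # W_1 for an input with 3 features
w2 = lecun_normal()(k2, (2, 4), jnp.float32)  # W_2 for an input with 2 features
b = jnp.zeros((4,))
x1 = jax.random.normal(k3, (3,))
x2 = jax.random.normal(k4, (2,))

y_split = split_linear((w1, w2), b, x1, x2)
# The same map written as one linear layer on the concatenated input.
y_concat = jnp.concatenate([x1, x2]) @ jnp.concatenate([w1, w2], axis=0) + b
print(bool(jnp.allclose(y_split, y_concat, atol=1e-5)))  # True
```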

__init__(in_features, out_features, weight_inits, bias_init=<function zeros>, use_bias=True, weight_wraps=None, bias_wrap=None, dtype=None, *, key)

Initialize the input split linear layer.

PARAMETER DESCRIPTION
in_features

The input size of each input. The n-th input to the layer should be a vector of shape (in_features[n],).

TYPE: Sequence[int | Literal['scalar']]

out_features

The output size. The output from the layer will be a vector of shape (out_features,).

TYPE: int | Literal['scalar']

weight_inits

Weight initializer or sequence of weight initializers of type jax.nn.initializers.Initializer. By specifying a sequence it is possible to apply a different initializer to each weight matrix. The sequence must have the same length as in_features.

TYPE: Sequence[jax.nn.initializers.Initializer] | jax.nn.initializers.Initializer

bias_init

The bias initializer of type jax.nn.initializers.Initializer.

TYPE: jax.nn.initializers.Initializer DEFAULT: <function zeros>

use_bias

Whether to add on a bias as well.

TYPE: bool DEFAULT: True

weight_wraps

A single wrapper, or a list/tuple of wrappers, that can be passed to enforce weight constraints. By specifying a sequence it is possible to apply a different wrapper to each weight matrix. The sequence must have the same length as in_features.

TYPE: Sequence[type[klax.Constraint] | type[klax.Unwrappable[Array]] | None] | type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

bias_wrap

An optional wrapper that can be passed to enforce bias constraints.

TYPE: type[klax.Constraint] | type[klax.Unwrappable[Array]] | None DEFAULT: None

dtype

The dtype to use for the weight and the bias in this layer. Defaults to either jax.numpy.float32 or jax.numpy.float64 depending on whether JAX is in 64-bit mode.

TYPE: type | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for parameter initialisation. (Keyword only argument.)

TYPE: PRNGKeyArray

Note

Note that each entry of in_features also supports the string "scalar" as a special value. In this case the respective input to the layer should be of shape ().

Likewise out_features can also be a string "scalar", in which case the output from the layer will have shape ().

Further note that some jax.nn.initializers.Initializers do not work if any entry of in_features or out_features is zero.

Likewise, some jax.nn.initializers.Initializers do not work when dtype is jax.numpy.complex64.
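The rule that a single initializer or wrapper applies to every input, while a sequence must have the same length as in_features, could be handled along the lines of this sketch. broadcast_per_input is a hypothetical helper rather than part of klax, only lists and tuples are treated as sequences here, and the per-input weight shape (one row for a "scalar" input) is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import he_normal, lecun_normal


def broadcast_per_input(value, n_inputs, name):
    """Return a list with one entry per input (hypothetical helper).

    A single value is repeated for every input; a list or tuple must already
    have exactly one entry per input.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != n_inputs:
            raise ValueError(f"{name} has length {len(value)}, expected {n_inputs}")
        return list(value)
    return [value] * n_inputs


in_features = (3, "scalar", 2)
out_features = 4
weight_inits = broadcast_per_input(
    [he_normal(), lecun_normal(), he_normal()], len(in_features), "weight_inits"
)
keys = jax.random.split(jax.random.PRNGKey(0), len(in_features))
# One weight per input; a "scalar" input is given a single row (assumption).
weights = [
    init(k, (1 if n == "scalar" else n, out_features), jnp.float32)
    for init, k, n in zip(weight_inits, keys, in_features)
]
print([w.shape for w in weights])  # [(3, 4), (1, 4), (2, 4)]
```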

__call__(*xs, key=None)

Forward pass of the linear transformation.

PARAMETER DESCRIPTION
xs

The inputs. Should be n JAX arrays x_i of shape (in_features[i],). (Or shape () if in_features[i]="scalar".)

TYPE: Array

key

Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.)

TYPE: PRNGKeyArray | None DEFAULT: None

RETURNS DESCRIPTION
Array

A JAX array of shape (out_features,). (Or shape () if out_features="scalar".)
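The variadic calling convention can be sketched in plain JAX with a hypothetical make_split_linear factory standing in for an InputSplitLinear instance. The inputs are passed positionally, a "scalar" input has shape (), and jax.vmap batches all inputs along their leading axis at once. The (in_features[i], out_features) weight layout is an assumption, not taken from klax.

```python
import jax
import jax.numpy as jnp


def make_split_linear(weights, bias, in_features, out_features):
    """Build a hypothetical forward function with the (*xs) calling convention."""

    def forward(*xs):
        if len(xs) != len(weights):
            raise ValueError(f"expected {len(weights)} inputs, got {len(xs)}")
        y = bias
        for x, w, n_in in zip(xs, weights, in_features):
            # A "scalar" input of shape () is treated as shape (1,).
            x = jnp.broadcast_to(x, (1,)) if n_in == "scalar" else x
            y = y + x @ w
        return jnp.squeeze(y) if out_features == "scalar" else y

    return forward


weights = [jnp.ones((3, 2)), jnp.full((1, 2), 2.0)]
forward = make_split_linear(weights, jnp.zeros((2,)), (3, "scalar"), 2)
print(forward(jnp.ones(3), jnp.array(0.5)))  # [4. 4.]

xs1 = jnp.ones((5, 3))           # batch of 5 vector inputs
xs2 = jnp.linspace(0.0, 1.0, 5)  # batch of 5 scalar inputs
print(jax.vmap(forward)(xs1, xs2).shape)  # (5, 2)
```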