# Linear Layers
## `klax.nn.Linear`
Performs a linear transformation.

This class is modified from `equinox.nn.Linear` to allow for custom initialization.
### `__init__(in_features, out_features, weight_init, bias_init=zeros, use_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key)`
Initialize the linear layer.
| PARAMETER | DESCRIPTION |
| --- | --- |
| `in_features` | The input size. The input to the layer should be a vector of shape `(in_features,)`. |
| `out_features` | The output size. The output from the layer will be a vector of shape `(out_features,)`. |
| `weight_init` | The weight initializer of type `jax.nn.initializers.Initializer`. |
| `bias_init` | The bias initializer of type `jax.nn.initializers.Initializer`. Defaults to `zeros`. |
| `use_bias` | Whether to add on a bias as well. Defaults to `True`. |
| `weight_wrap` | An optional wrapper that can be passed to enforce weight constraints. |
| `bias_wrap` | An optional wrapper that can be passed to enforce bias constraints. |
| `dtype` | The dtype to use for the weight and the bias in this layer. Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode. |
| `key` | A `jax.random.PRNGKey` used to provide randomness for parameter initialization. (Keyword only argument.) |
**Note**

`in_features` also supports the string `"scalar"` as a special value. In this case the input to the layer should be of shape `()`. Likewise, `out_features` can also be the string `"scalar"`, in which case the output from the layer will have shape `()`.

Further note that some `jax.nn.initializers.Initializer`s do not work if one of `in_features` or `out_features` is zero. Likewise, some `jax.nn.initializers.Initializer`s do not work when `dtype` is `jax.numpy.complex64`.
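
As a minimal construction sketch, assuming only the signature above together with the standard `jax.nn.initializers`, a layer with a custom weight initializer can be built like this:

>>> import jax.numpy as jnp
>>> import jax.random as jrandom
>>> from jax.nn.initializers import glorot_normal
>>> import klax
>>>
>>> key = jrandom.PRNGKey(0)
>>> # Pass the initializer as the third positional argument.
>>> linear = klax.nn.Linear(3, 4, glorot_normal(), key=key)
>>> linear(jnp.ones(3)).shape
(4,)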
### `__call__(x, *, key=None)`
Forward pass of the linear transformation.
| PARAMETER | DESCRIPTION |
| --- | --- |
| `x` | The input. Should be a JAX array of shape `(in_features,)`, or of shape `()` if `in_features="scalar"`. |
| `key` | Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) |
**Note**

If you want to use higher-order tensors as inputs (for example featuring batch dimensions), use `jax.vmap`. For an input `x` of shape `(batch, in_features)`, applying `jax.vmap(linear)` will produce the appropriate output of shape `(batch, out_features)`. For example:

>>> import jax
>>> from jax.nn.initializers import he_normal
>>> import jax.random as jrandom
>>> import klax
>>>
>>> key = jrandom.PRNGKey(0)
>>> keys = jrandom.split(key)
>>> x = jrandom.uniform(keys[0], (10,))
>>> linear = klax.nn.Linear(
...     "scalar",
...     "scalar",
...     he_normal(),
...     key=keys[1],
... )
>>> jax.vmap(linear)(x).shape
(10,)
| RETURNS | DESCRIPTION |
| --- | --- |
| `Array` | A JAX array of shape `(out_features,)`, or of shape `()` if `out_features="scalar"`. |
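
For completeness, here is a batched variant of the example above with non-scalar features; this is a sketch that assumes only the documented constructor:

>>> import jax
>>> import jax.random as jrandom
>>> from jax.nn.initializers import he_normal
>>> import klax
>>>
>>> keys = jrandom.split(jrandom.PRNGKey(0))
>>> x = jrandom.uniform(keys[0], (32, 3))  # shape (batch, in_features)
>>> linear = klax.nn.Linear(3, 4, he_normal(), key=keys[1])
>>> jax.vmap(linear)(x).shape
(32, 4)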
## `klax.nn.InputSplitLinear`
Performs a linear transformation for multiple inputs.

The transformation is of the form

    y = x_1 @ W_1 + x_2 @ W_2 + ... + x_n @ W_n + b

for inputs `x_1, ..., x_n`.

This layer is useful for formulating transformations with multiple inputs, where different inputs require different weight constraints or initializations for the corresponding weight matrices.
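
To see why this is simply a partitioned linear map, here is a small sketch in plain `jax.numpy` (independent of klax) verifying that summing the per-input products equals one matrix product over the concatenated input:

>>> import jax.numpy as jnp
>>>
>>> x1, x2 = jnp.ones(3), jnp.ones(2)
>>> W1 = jnp.full((3, 4), 0.5)
>>> W2 = jnp.full((2, 4), 2.0)
>>> b = jnp.zeros(4)
>>> y_split = x1 @ W1 + x2 @ W2 + b
>>> # Stacking the weights row-wise recovers an ordinary linear layer.
>>> y_concat = jnp.concatenate([x1, x2]) @ jnp.concatenate([W1, W2]) + b
>>> bool(jnp.allclose(y_split, y_concat))
True

The advantage of the split form is that each `W_i` can receive its own initializer or constraint wrapper, which a single concatenated weight matrix cannot express.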
### `__init__(in_features, out_features, weight_inits, bias_init=zeros, use_bias=True, weight_wraps=None, bias_wrap=None, dtype=None, *, key)`
Initialize the input split linear layer.
| PARAMETER | DESCRIPTION |
| --- | --- |
| `in_features` | The input sizes of each input. The n-th input to the layer should be a vector of shape `(in_features[n],)`. |
| `out_features` | The output size. The output from the layer will be a vector of shape `(out_features,)`. |
| `weight_inits` | Weight initializer or sequence of weight initializers of type `jax.nn.initializers.Initializer`. |
| `bias_init` | The bias initializer of type `jax.nn.initializers.Initializer`. Defaults to `zeros`. |
| `use_bias` | Whether to add on a bias as well. Defaults to `True`. |
| `weight_wraps` | One or a list/tuple of wrappers that can be passed to enforce weight constraints. By specifying a sequence, it is possible to apply a different wrapper to each weight matrix. The sequence must have the same length as `in_features`. |
| `bias_wrap` | An optional wrapper that can be passed to enforce bias constraints. |
| `dtype` | The dtype to use for the weight and the bias in this layer. Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode. |
| `key` | A `jax.random.PRNGKey` used to provide randomness for parameter initialization. (Keyword only argument.) |
**Note**

`in_features` also supports the string `"scalar"` as a special value. In this case the respective input to the layer should be of shape `()`. Likewise, `out_features` can also be the string `"scalar"`, in which case the output from the layer will have shape `()`.

Further note that some `jax.nn.initializers.Initializer`s do not work if one of `in_features` or `out_features` is zero. Likewise, some `jax.nn.initializers.Initializer`s do not work when `dtype` is `jax.numpy.complex64`.
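
As a hedged construction sketch, using the parameter table above (which documents that `weight_inits` may be a sequence of the same length as `in_features`), a layer with a different initializer per input could look like this:

>>> import jax.numpy as jnp
>>> import jax.random as jrandom
>>> from jax.nn.initializers import glorot_normal, he_normal
>>> import klax
>>>
>>> key = jrandom.PRNGKey(0)
>>> layer = klax.nn.InputSplitLinear(
...     (3, 2),                        # two inputs of sizes 3 and 2
...     4,
...     (he_normal(), glorot_normal()),  # one initializer per weight matrix
...     key=key,
... )
>>> layer(jnp.ones(3), jnp.ones(2)).shape
(4,)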
### `__call__(*xs, key=None)`
Forward pass of the linear transformation.
| PARAMETER | DESCRIPTION |
| --- | --- |
| `xs` | The inputs. Should be n JAX arrays `x_i` of shape `(in_features[i],)`. |
| `key` | Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) |
| RETURNS | DESCRIPTION |
| --- | --- |
| `Array` | A JAX array of shape `(out_features,)`, or of shape `()` if `out_features="scalar"`. |
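
A batched forward pass can again be expressed with `jax.vmap`. The sketch below assumes nothing beyond the documented signature and maps both inputs over a shared leading batch axis:

>>> import jax
>>> import jax.random as jrandom
>>> from jax.nn.initializers import he_normal
>>> import klax
>>>
>>> keys = jrandom.split(jrandom.PRNGKey(0), 3)
>>> layer = klax.nn.InputSplitLinear((3, 2), 4, he_normal(), key=keys[0])
>>> x1 = jrandom.uniform(keys[1], (10, 3))  # shape (batch, in_features[0])
>>> x2 = jrandom.uniform(keys[2], (10, 2))  # shape (batch, in_features[1])
>>> jax.vmap(layer)(x1, x2).shape
(10, 4)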