Matrix-valued functions
klax.nn.Matrix
An unconstrained matrix-valued function based on an MLP.
The MLP maps the input to a vector of elements, which is then transformed into a matrix.
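Here is a minimal sketch of the idea in plain JAX. It assumes the flat MLP output is simply reshaped into the requested matrix shape; how klax actually performs this transformation is not stated here. `init_mlp` and `matrix_forward` are hypothetical helpers, not part of the klax API. The hidden activation is softplus, which matches the default in the signature below.

```python
import math

import jax
import jax.numpy as jnp


def init_mlp(key, in_size, width_sizes, out_size):
    """Hypothetical helper: build MLP parameters as a list of (weight, bias) pairs."""
    sizes = [in_size, *width_sizes, out_size]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        # Simple 1/sqrt(fan_in) scaling; not necessarily klax's default initializer.
        w = jax.random.normal(subkey, (d_out, d_in)) / jnp.sqrt(d_in)
        b = jnp.zeros(d_out)  # zero bias, like the bias_init default
        params.append((w, b))
    return params


def matrix_forward(params, x, shape):
    """Hypothetical forward pass: x of shape (in_size,) -> array of the given shape."""
    h = x
    for w, b in params[:-1]:
        h = jax.nn.softplus(w @ h + b)  # hidden layers use softplus
    w, b = params[-1]
    v = w @ h + b  # identity final activation gives the flat vector of elements
    return v.reshape(shape)  # assumed vector-to-matrix transformation


shape = (3, 4)
params = init_mlp(jax.random.PRNGKey(0), in_size=2, width_sizes=[8, 8],
                  out_size=math.prod(shape))
A = matrix_forward(params, jnp.array([0.5, -1.0]), shape)
print(A.shape)  # (3, 4)
```

The MLP's output size in this sketch is the product of the matrix dimensions. This is why the matrix shape alone determines the final layer width.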
__init__(in_size, shape, width_sizes, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=<function softplus>, final_activation=<function Matrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)
Initialize the `Matrix`.
PARAMETER | DESCRIPTION |
---|---|
in_size | The input size. The input to the module should be a vector of shape `(in_size,)`. |
shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
width_sizes | The sizes of each hidden layer of the underlying MLP, given as a list. |
weight_init | The weight initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
bias_init | The bias initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to `zeros`.) |
activation | The activation function after each hidden layer. (Defaults to `softplus`.) |
final_activation | The activation function after the output layer. (Defaults to the identity.) |
use_bias | Whether to add a bias to the internal layers. (Defaults to `True`.) |
use_final_bias | Whether to add a bias to the final layer. (Defaults to `True`.) |
dtype | The dtype to use for all the weights and biases in this MLP. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the parameters. (Keyword-only argument.) |
Note
`in_size` also supports the string `"scalar"` as a special value. In this case the input to the module should be of shape `()`.
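The `"scalar"` convention can be pictured as an input check that promotes a shape-`()` input to a length-1 vector before it reaches the MLP. The helper `prepare_input` is hypothetical. It only illustrates the documented shape rules, and klax may handle this differently internally.

```python
import jax.numpy as jnp


def prepare_input(x, in_size):
    """Hypothetical helper: check x against in_size and return a 1-D vector."""
    if in_size == "scalar":
        if jnp.shape(x) != ():
            raise ValueError(f"expected shape (), got {jnp.shape(x)}")
        return jnp.reshape(x, (1,))  # treat the scalar as a length-1 vector
    if jnp.shape(x) != (in_size,):
        raise ValueError(f"expected shape ({in_size},), got {jnp.shape(x)}")
    return x


print(prepare_input(jnp.asarray(2.0), "scalar").shape)  # (1,)
print(prepare_input(jnp.ones(3), 3).shape)  # (3,)
```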
__call__(x)
Forward pass through `Matrix`.
PARAMETER | DESCRIPTION |
---|---|
x | The input. Should be a JAX array of shape `(in_size,)`, or `()` if `in_size="scalar"`. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |
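The forward pass operates on a single input. In JAX, a batch of inputs is usually handled by mapping the call with `jax.vmap`. The sketch uses a hypothetical stand-in function `matrix_fn` so that it runs without klax installed. Applying `jax.vmap` to a `Matrix` instance in the same way is an assumption based on the unbatched signature documented above.

```python
import jax
import jax.numpy as jnp


def matrix_fn(x):
    """Hypothetical stand-in for a module mapping x of shape (2,) to a (2, 2) matrix."""
    return jnp.outer(x, x) + jnp.eye(2)


batch = jnp.ones((5, 2))  # 5 inputs, each of shape (in_size,) = (2,)
out = jax.vmap(matrix_fn)(batch)  # maps over the leading batch axis
print(out.shape)  # (5, 2, 2)
```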
klax.nn.ConstantMatrix
A constant, unconstrained matrix.
It is a wrapper around a constant array that implements the matrix-valued function interface.
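Here is a minimal sketch of the wrapper pattern that assumes nothing about klax internals. The class `ConstantMatrixSketch` is hypothetical. It stores a fixed array and accepts an input only to discard it, so it can be called exactly like an input-dependent matrix function.

```python
import jax
import jax.numpy as jnp


class ConstantMatrixSketch:
    """Hypothetical stand-in for a constant matrix-valued function."""

    def __init__(self, shape, *, key):
        # Fixed array; klax's actual initializer and storage are not assumed here.
        self.array = jax.random.normal(key, shape)

    def __call__(self, x):
        del x  # the input is ignored; it only keeps the call signature uniform
        return self.array


m = ConstantMatrixSketch((2, 2), key=jax.random.PRNGKey(0))
print(bool(jnp.array_equal(m(jnp.zeros(3)), m(jnp.ones(5)))))  # True
```

Because the call signature is uniform, a constant matrix can stand in wherever an input-dependent one is expected.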
__init__(shape, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)
Initialize the `ConstantMatrix`.
PARAMETER | DESCRIPTION |
---|---|
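shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
init | The array initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
dtype | The dtype of the constant array. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the array. (Keyword-only argument.) |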
__call__(x)
Forward pass through `ConstantMatrix`.
PARAMETER | DESCRIPTION |
---|---|
x | Ignored; provided for compatibility with the rest of the matrix-valued function API. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |
klax.nn.SkewSymmetricMatrix
A skew-symmetric matrix-valued function based on an MLP.
The MLP maps the input to a vector of elements that are transformed into a skew-symmetric matrix.
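A skew-symmetric $n \times n$ matrix ($A^\top = -A$) has a zero diagonal and only $n(n-1)/2$ free entries. The sketch assumes the MLP output vector fills the strict upper triangle, which is then mirrored with a sign flip. `vector_to_skew` is a hypothetical helper, and klax's actual construction may differ. It could, for example, antisymmetrize a full $n \times n$ output instead.

```python
import jax.numpy as jnp


def vector_to_skew(v, n):
    """Hypothetical helper: fill the strict upper triangle with v, then antisymmetrize."""
    rows, cols = jnp.triu_indices(n, k=1)  # n*(n-1)/2 index pairs above the diagonal
    upper = jnp.zeros((n, n), dtype=v.dtype).at[rows, cols].set(v)
    return upper - upper.T  # A.T == -A, with zeros on the diagonal


n = 4
v = jnp.arange(1.0, n * (n - 1) // 2 + 1)  # 6 free entries, standing in for an MLP output
A = vector_to_skew(v, n)
print(bool(jnp.allclose(A, -A.T)))  # True
```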
__init__(in_size, shape, width_sizes, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=<function softplus>, final_activation=<function SkewSymmetricMatrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)
Initialize the `SkewSymmetricMatrix`.
PARAMETER | DESCRIPTION |
---|---|
in_size | The input size. The input to the module should be a vector of shape `(in_size,)`. |
shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
width_sizes | The sizes of each hidden layer of the underlying MLP, given as a list. |
weight_init | The weight initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
bias_init | The bias initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to `zeros`.) |
activation | The activation function after each hidden layer. (Defaults to `softplus`.) |
final_activation | The activation function after the output layer. (Defaults to the identity.) |
use_bias | Whether to add a bias to the internal layers. (Defaults to `True`.) |
use_final_bias | Whether to add a bias to the final layer. (Defaults to `True`.) |
dtype | The dtype to use for all the weights and biases in this MLP. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the parameters. (Keyword-only argument.) |
Note
`in_size` also supports the string `"scalar"` as a special value. In this case the input to the module should be of shape `()`.
__call__(x)
Forward pass through `SkewSymmetricMatrix`.
PARAMETER | DESCRIPTION |
---|---|
x | The input. Should be a JAX array of shape `(in_size,)`, or `()` if `in_size="scalar"`. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |
klax.nn.ConstantSkewSymmetricMatrix
A constant skew-symmetric matrix.
It is a wrapper around a constant, skew-symmetry-constrained array that implements the matrix-valued function interface.
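One common way to keep a constant matrix skew-symmetric is to store an unconstrained square array $W$ and return its skew-symmetric part $(W - W^\top)/2$. This satisfies the constraint for any value of $W$. `ConstantSkewSketch` is a hypothetical illustration of that technique, not klax's implementation.

```python
import jax
import jax.numpy as jnp


class ConstantSkewSketch:
    """Hypothetical constant skew-symmetric matrix built from an unconstrained array."""

    def __init__(self, n, *, key):
        self.raw = jax.random.normal(key, (n, n))  # unconstrained parameter W

    def __call__(self, x):
        del x  # ignored, as in the documented interface
        return 0.5 * (self.raw - self.raw.T)  # skew-symmetric part of W


S = ConstantSkewSketch(3, key=jax.random.PRNGKey(1))(None)
print(bool(jnp.allclose(S, -S.T)))  # True
```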
__init__(shape, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)
Initialize the `ConstantSkewSymmetricMatrix`.
PARAMETER | DESCRIPTION |
---|---|
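shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
init | The array initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
dtype | The dtype of the constant array. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the array. (Keyword-only argument.) |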
__call__(x)
Forward pass through `ConstantSkewSymmetricMatrix`.
PARAMETER | DESCRIPTION |
---|---|
x | Ignored; provided for compatibility with the rest of the matrix-valued function API. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |
klax.nn.SPDMatrix
A symmetric positive definite matrix-valued function based on an MLP.
The output vector $v$ of the MLP is mapped to a matrix $B$. The module's output is then computed via $A = B B^*$, where $B^*$ denotes the conjugate transpose of $B$ (for real-valued $B$, simply $B^\top$).
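The sketch below makes three assumptions. First, $B$ is real-valued, so $B^* = B^\top$. Second, $B$ is square and built by reshaping $n^2$ MLP outputs. Third, `epsilon` is added to the diagonal, as described for the `epsilon` parameter. `vector_to_spd` is a hypothetical helper, and the fixed vector stands in for an MLP output.

```python
import jax.numpy as jnp


def vector_to_spd(v, n, epsilon=1e-6):
    """Hypothetical helper: map a flat vector of n*n entries to an SPD matrix."""
    B = v.reshape(n, n)  # assumed: the MLP output is reshaped into a square B
    # For real-valued B the conjugate transpose B^* is simply B.T.
    return B @ B.T + epsilon * jnp.eye(n, dtype=v.dtype)


v = jnp.array([2.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 3.0])  # stands in for an MLP output
A = vector_to_spd(v, 3)
print(bool(jnp.allclose(A, A.T)))  # True
print(bool(jnp.all(jnp.linalg.eigvalsh(A) > 0)))  # True
```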
__init__(in_size, shape, width_sizes, epsilon=1e-06, weight_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, activation=<function softplus>, final_activation=<function SPDMatrix.<lambda>>, use_bias=True, use_final_bias=True, dtype=None, *, key)
Initialize the `SPDMatrix`.
PARAMETER | DESCRIPTION |
---|---|
in_size | The input size. The input to the module should be a vector of shape `(in_size,)`. |
shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
width_sizes | The sizes of each hidden layer of the underlying MLP, given as a list. |
epsilon | Small value that is added to the diagonal of the output matrix to ensure positive definiteness. If only positive semi-definiteness is required, set `epsilon=0`. |
weight_init | The weight initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
bias_init | The bias initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to `zeros`.) |
activation | The activation function after each hidden layer. (Defaults to `softplus`.) |
final_activation | The activation function after the output layer. (Defaults to the identity.) |
use_bias | Whether to add a bias to the internal layers. (Defaults to `True`.) |
use_final_bias | Whether to add a bias to the final layer. (Defaults to `True`.) |
dtype | The dtype to use for all the weights and biases in this MLP. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the parameters. (Keyword-only argument.) |
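The following derivation shows why adding `epsilon` to the diagonal gives positive definiteness. It assumes the output is $A = B B^* + \varepsilon I$, i.e. the product described for this class with `epsilon` added to its diagonal. For any nonzero vector $x$:

```latex
x^* A x = x^* \left( B B^* + \varepsilon I \right) x
        = (B^* x)^* (B^* x) + \varepsilon \, x^* x
        = \lVert B^* x \rVert_2^2 + \varepsilon \lVert x \rVert_2^2
        > 0 \qquad \text{for all } x \neq 0 \text{ when } \varepsilon > 0.
```

With `epsilon=0`, the same computation only gives $x^* A x \ge 0$. That is positive semi-definiteness, which matches the `epsilon` description.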
Note
`in_size` also supports the string `"scalar"` as a special value. In this case the input to the module should be of shape `()`.
__call__(x)
Forward pass through `SPDMatrix`.
PARAMETER | DESCRIPTION |
---|---|
x | The input. Should be a JAX array of shape `(in_size,)`, or `()` if `in_size="scalar"`. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |
klax.nn.ConstantSPDMatrix
A constant symmetric positive definite matrix-valued function.
It is a wrapper around a constant symmetric positive semi-definite matrix that implements the matrix-valued function interface.
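This is a minimal sketch of a constant SPD wrapper. It assumes the wrapper stores an unconstrained square array $B$ and returns $B B^\top + \varepsilon I$ while ignoring its input. `ConstantSPDSketch` is hypothetical. The final Cholesky factorization serves only as a practical positive-definiteness check, since JAX typically returns NaNs when the matrix is not positive definite.

```python
import jax
import jax.numpy as jnp


class ConstantSPDSketch:
    """Hypothetical constant SPD matrix: B @ B.T + epsilon * I for a stored B."""

    def __init__(self, n, epsilon=1e-6, *, key):
        self.B = jax.random.normal(key, (n, n))  # unconstrained parameter
        self.epsilon = epsilon

    def __call__(self, x):
        del x  # ignored; keeps the matrix-valued function interface
        n = self.B.shape[0]
        return self.B @ self.B.T + self.epsilon * jnp.eye(n)


A = ConstantSPDSketch(3, epsilon=1e-3, key=jax.random.PRNGKey(2))(None)
L = jnp.linalg.cholesky(A)  # lower-triangular factor with A == L @ L.T
print(bool(jnp.all(jnp.isfinite(L))))  # True
```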
__init__(shape, epsilon=1e-06, init=<function variance_scaling.<locals>.init>, dtype=None, *, key)
Initialize the `ConstantSPDMatrix`.
PARAMETER | DESCRIPTION |
---|---|
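shape | The matrix shape. The output from the module will be an array with the specified `shape`. |
epsilon | Small value that is added to the diagonal of the output matrix to ensure positive definiteness. If only positive semi-definiteness is required, set `epsilon=0`. |
init | The initializer, a JAX initializer such as those in `jax.nn.initializers`. (Defaults to a variance-scaling initializer.) |
dtype | The dtype of the constant array. (Defaults to either `jax.numpy.float32` or `jax.numpy.float64`, depending on whether JAX is in 64-bit mode.) |
key | A JAX PRNG key used to initialize the array. (Keyword-only argument.) |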
__call__(x)
Forward pass through `ConstantSPDMatrix`.
PARAMETER | DESCRIPTION |
---|---|
x | Ignored; provided for compatibility with the rest of the matrix-valued function API. |
RETURNS | DESCRIPTION |
---|---|
Array | A JAX array of shape `shape`. |