Multi-layer perceptrons
klax.nn.MLP
Standard Multi-Layer Perceptron; also known as a feed-forward network.
This class is a modified form of equinox.nn.MLP
that allows for custom initialization and different node numbers in the hidden
layers. Hence, it may also be used for encoder/decoder tasks.
__init__(in_size, out_size, width_sizes, weight_init=<variance-scaling initializer>, bias_init=<zeros initializer>, activation=<softplus>, final_activation=<identity>, use_bias=True, use_final_bias=True, weight_wrap=None, bias_wrap=None, dtype=None, *, key)
Initialize MLP.
PARAMETER | DESCRIPTION |
---|---|
`in_size` | The input size. The input to the module should be a vector of shape `(in_size,)`. |
`out_size` | The output size. The output from the module will be a vector of shape `(out_size,)`. |
`width_sizes` | The sizes of each hidden layer in a list. |
`weight_init` | The weight initializer. Defaults to a variance-scaling initializer. |
`bias_init` | The bias initializer. Defaults to zeros. |
`activation` | The activation function after each hidden layer. Defaults to softplus. |
`final_activation` | The activation function after the output layer. Defaults to the identity. |
`use_bias` | Whether to add on a bias to internal layers. Defaults to `True`. |
`use_final_bias` | Whether to add on a bias to the final layer. Defaults to `True`. |
`weight_wrap` | An optional wrapper that is passed to all weights. Defaults to `None`. |
`bias_wrap` | An optional wrapper that is passed to all biases. Defaults to `None`. |
`dtype` | The dtype to use for all the weights and biases in this MLP. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode. |
`key` | A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.) |
Note
Note that `in_size` also supports the string `"scalar"` as a special value. In this case the input to the module should be of shape `()`. Likewise, `out_size` can also be the string `"scalar"`, in which case the output from the module will have shape `()`.
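As a minimal construction sketch (assuming `klax` is importable and follows the Equinox-style constructor documented above; the sizes and hidden widths are arbitrary examples):

```python
import jax
import klax

# PRNG key for parameter initialization (keyword-only argument).
key = jax.random.PRNGKey(0)

# 3 inputs -> hidden layers of width 64 and 32 -> 2 outputs.
# Defaults: softplus hidden activations, identity final activation.
mlp = klax.nn.MLP(
    in_size=3,
    out_size=2,
    width_sizes=[64, 32],
    key=key,
)
```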
__call__(x, *, key=None)
Forward pass through MLP.
PARAMETER | DESCRIPTION |
---|---|
`x` | A JAX array with shape `(in_size,)` (or shape `()` if `in_size` is `"scalar"`). |
`key` | Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) |

RETURNS | DESCRIPTION |
---|---|
`Array` | A JAX array with shape `(out_size,)` (or shape `()` if `out_size` is `"scalar"`). |
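A hedged sketch of the forward pass, reusing the `mlp` instance from the construction example above. The module maps a single vector, so batching is done with `jax.vmap` (standard JAX/Equinox practice; the exact klax idiom may differ):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3,))        # single input of shape (in_size,)
y = mlp(x)                # output of shape (out_size,) == (2,)

xs = jnp.ones((10, 3))    # a batch of 10 inputs
ys = jax.vmap(mlp)(xs)    # map over the leading batch axis -> shape (10, 2)
```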
klax.nn.FICNN
A fully input convex neural network (FICNN).
Each element of the output is a convex function of the input.
See: https://arxiv.org/abs/1609.07152
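For orientation, the ICNN construction in the linked paper (Amos et al., 2017) keeps each output convex in the input by using non-negative weights on the hidden state together with passthrough connections from the input. Schematically (our notation, not klax's):

```latex
% ICNN layer recursion: z_0 = x, and for each layer i
%   z_{i+1} = g_i( W_i^{(z)} z_i + W_i^{(x)} x + b_i ),
% with W_i^{(z)} >= 0 elementwise and g_i convex and non-decreasing,
% so every component of the final output is convex in x.
z_{i+1} = g_i\!\left( W_i^{(z)} z_i + W_i^{(x)} x + b_i \right),
\qquad W_i^{(z)} \ge 0, \quad g_i \text{ convex and non-decreasing}
```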
__init__(in_size, out_size, width_sizes, use_passthrough=True, non_decreasing=False, weight_init=<variance-scaling initializer>, bias_init=<zeros initializer>, activation=<softplus>, final_activation=<identity>, use_bias=True, use_final_bias=True, dtype=None, *, key)
Initialize FICNN.
Warning
Modifying final_activation
to a non-convex function will break
the convexity of the FICNN. Use this parameter with care.
PARAMETER | DESCRIPTION |
---|---|
`in_size` | The input size. The input to the module should be a vector of shape `(in_size,)`. |
`out_size` | The output size. The output from the module will be a vector of shape `(out_size,)`. |
`width_sizes` | The sizes of each hidden layer in a list. |
`use_passthrough` | Whether to use passthrough layers. If true, the input is passed through to each hidden layer. Defaults to `True`. |
`non_decreasing` | If true, the output is element-wise non-decreasing in each input. Defaults to `False`. |
`weight_init` | The weight initializer. Defaults to a variance-scaling initializer. |
`bias_init` | The bias initializer. Defaults to zeros. |
`activation` | The activation function of each hidden layer. To ensure convexity this function must be convex and non-decreasing. Defaults to softplus. |
`final_activation` | The activation function after the output layer. To ensure convexity this function must be convex and non-decreasing. Defaults to the identity. |
`use_bias` | Whether to add on a bias in the hidden layers. Defaults to `True`. |
`use_final_bias` | Whether to add on a bias to the final layer. Defaults to `True`. |
`dtype` | The dtype to use for all the weights and biases in this FICNN. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode. |
`key` | A `jax.random.PRNGKey` used to provide randomness for parameter initialisation. (Keyword only argument.) |
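A hedged construction sketch, analogous to the MLP example above (parameter names are taken from the signature; the sizes are arbitrary):

```python
import jax
import klax

key = jax.random.PRNGKey(42)

# A convex map from a 4-dimensional input to a 1-dimensional output.
# The default softplus activation is convex and non-decreasing,
# so convexity of the output in the input is preserved.
ficnn = klax.nn.FICNN(
    in_size=4,
    out_size=1,
    width_sizes=[32, 32],
    key=key,
)
```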
__call__(x, *, key=None)
Forward pass through FICNN.
PARAMETER | DESCRIPTION |
---|---|
`x` | A JAX array with shape `(in_size,)`. |
`key` | Ignored; provided for compatibility with the rest of the Equinox API. (Keyword only argument.) |

RETURNS | DESCRIPTION |
---|---|
`Array` | A JAX array with shape `(out_size,)`. |
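To make the convexity claim concrete, here is a hedged numerical sanity check based on the midpoint form of Jensen's inequality, f((a + b)/2) <= (f(a) + f(b))/2, applied elementwise to the `ficnn` instance from the sketch above:

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (4,))
b = jax.random.normal(key_b, (4,))

lhs = ficnn((a + b) / 2.0)           # f at the midpoint
rhs = (ficnn(a) + ficnn(b)) / 2.0    # average of f at the endpoints

# Every output element of a convex function satisfies the midpoint inequality
# (small tolerance added for floating-point error).
assert jnp.all(lhs <= rhs + 1e-6)
```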