Loss functions
klax.Loss

An abstract callable loss object. It can be used to build custom losses that can be passed to klax.fit.
Example

A simple custom loss that computes the mean squared error between the predicted values y_pred and the true values y for inputs x may be implemented as follows:

>>> import jax
>>> import jax.numpy as jnp
>>> def mse(model, data, batch_axis=0):
...     x, y = data
...     if isinstance(batch_axis, tuple):
...         in_axes = batch_axis[0]
...     else:
...         in_axes = batch_axis
...     y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
...     return jnp.mean(jnp.square(y_pred - y))

Note that, since we aim to provide maximum flexibility, users have to take care of applying jax.vmap to the model themselves.
__call__(model, data, batch_axis)

Abstract method to compute the loss for a given model and data.

PARAMETER | DESCRIPTION
---|---
model | The model parameters or structure to evaluate the loss.
data | The input data or structure used for loss computation.
batch_axis | Specifies the axis or axes corresponding to the batch dimension in the data. Can be an integer, None, or a sequence of values.

RETURNS | DESCRIPTION
---|---
Scalar | The computed loss value.
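As a sketch of how a concrete loss can satisfy this interface, the following uses a plain Python abstract base class as a stand-in for klax.Loss (whose actual definition is not shown here); the class name LossSketch and the subclass MeanError are hypothetical, chosen only to illustrate the callable-object pattern with the documented signature:

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class LossSketch(ABC):
    # Hypothetical stand-in for klax.Loss: an abstract callable
    # with the documented (model, data, batch_axis) signature.
    @abstractmethod
    def __call__(self, model, data, batch_axis):
        ...


class MeanError(LossSketch):
    # Illustrative concrete loss: mean signed error over the batch axis.
    def __call__(self, model, data, batch_axis=0):
        x, y = data
        y_pred = jax.vmap(model, in_axes=batch_axis)(x)
        return jnp.mean(y_pred - y)


loss_fn = MeanError()
# Identity model, targets all zero: mean error is 1.0.
value = loss_fn(lambda x: x, (jnp.ones(4), jnp.zeros(4)))
```

An instance like loss_fn is then itself a callable scalar-valued loss, matching the shape expected by the interface above.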
klax.MSE(klax.Loss)

Mean squared error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.

__call__(model, data, batch_axis=0)
klax.mse(model, data, batch_axis=0)

Mean squared error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.
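To make the semantics concrete, here is a self-contained sketch that mirrors the documented behavior (it re-implements the computation with plain JAX rather than calling klax.mse, and assumes a single integer batch_axis):

```python
import jax
import jax.numpy as jnp


def mse_sketch(model, data, batch_axis=0):
    # Mirrors the documented klax.mse behavior: vectorize the model
    # over the batch axis of x, then average the squared residuals.
    x, y = data
    y_pred = jax.vmap(model, in_axes=batch_axis)(x)
    return jnp.mean(jnp.square(y_pred - y))


# Toy linear model: y = 2 * x.
model = lambda x: 2.0 * x

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 7.0])  # last target deviates by 1
loss = mse_sketch(model, (x, y))  # mean of [0, 0, 1] -> 1/3
```

Note that x and y share batch axis 0 and have equal length along it, as required.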
klax.MAE(klax.Loss)

Mean absolute error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.

__call__(model, data, batch_axis=0)
klax.mae(model, data, batch_axis=0)

Mean absolute error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.
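Analogously to the squared-error case, a self-contained sketch of the documented mean absolute error (again re-implemented with plain JAX rather than calling klax.mae, and assuming a single integer batch_axis) looks like this:

```python
import jax
import jax.numpy as jnp


def mae_sketch(model, data, batch_axis=0):
    # Mirrors the documented klax.mae behavior: vectorize the model
    # over the batch axis of x, then average the absolute residuals.
    x, y = data
    y_pred = jax.vmap(model, in_axes=batch_axis)(x)
    return jnp.mean(jnp.abs(y_pred - y))


# Toy affine model: y = x + 1.
model = lambda x: x + 1.0

x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 2.0, 2.0, 5.0])  # two targets deviate by +/-1
loss = mae_sketch(model, (x, y))  # mean of [0, 0, 1, 1] -> 0.5
```

Compared with the squared error, the absolute error penalizes large residuals linearly rather than quadratically, which makes it less sensitive to outliers in y.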