Loss functions

klax.Loss

An abstract callable loss object.

It can be used to build custom losses that can be passed to klax.fit.

Example

A simple custom loss that computes the mean squared error between the predicted values y_pred and the true values y for inputs x may be implemented as follows:

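>>> import jax
>>> import jax.numpy as jnp
>>> def mse(model, data, batch_axis=0):
...    x, y = data
...    # A tuple batch_axis has one entry per element of data; use the one for x.
...    if isinstance(batch_axis, tuple):
...        in_axes = batch_axis[0]
...    else:
...        in_axes = batch_axis
...    # Apply the model to every sample along the batch axis of x.
...    y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
...    return jnp.mean(jnp.square(y_pred - y))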

Note that, since we aim to provide maximum flexibility, users have to apply jax.vmap to the model themselves.
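
To see how such a custom loss behaves, the following minimal sketch evaluates it on a toy model without involving klax.fit. The `linear_model` function and the sample arrays are hypothetical stand-ins that are not part of klax; the only assumption is that the model is a callable mapping one unbatched input to one output.

```python
import jax
import jax.numpy as jnp


def mse(model, data, batch_axis=0):
    x, y = data
    # A tuple batch_axis has one entry per element of data; use the one for x.
    in_axes = batch_axis[0] if isinstance(batch_axis, tuple) else batch_axis
    # The model only sees single samples, so vmap supplies the batching.
    y_pred = jax.vmap(model, in_axes=(in_axes,))(x)
    return jnp.mean(jnp.square(y_pred - y))


def linear_model(x):
    # Hypothetical model: maps one sample of shape (2,) to a scalar.
    return jnp.dot(jnp.array([1.0, 2.0]), x)


x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # batch of 3 samples
y = jnp.array([1.0, 2.0, 5.0])  # the last target is off by 2

loss = mse(linear_model, (x, y))                     # int batch axis
loss_tuple = mse(linear_model, (x, y), batch_axis=(0, 0))  # per-element axes
```

Both calls reduce over the same three samples, so they return the same scalar.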

__call__(model, data, batch_axis)

Abstract method to compute the loss for a given model and data.

PARAMETER DESCRIPTION
model

The model parameters or structure for which the loss is evaluated.

TYPE: PyTree

data

The input data or structure used for loss computation.

TYPE: PyTree

batch_axis

Specifies the axis or axes corresponding to the batch dimension in the data. Can be an integer, None, or a sequence of values.

TYPE: int | None | Sequence[Any]

RETURNS DESCRIPTION
Scalar

The computed loss value.

TYPE: Shaped[Array, '']
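
As a rough illustration of this call signature, the sketch below defines a plain callable class that accepts `(model, data, batch_axis)` and returns a scalar array. `PerSampleMSE` and `double` are hypothetical names. The class deliberately does not inherit from klax.Loss, since the subclassing requirements are not stated here. It also assumes that `batch_axis` is either one integer shared by `x` and `y` or a two-element sequence with one axis each, and it does not handle `batch_axis=None`.

```python
import jax
import jax.numpy as jnp


class PerSampleMSE:
    """Hypothetical callable with the (model, data, batch_axis) signature."""

    def __call__(self, model, data, batch_axis=0):
        x, y = data
        # Assumption: an int applies to x and y alike, a sequence gives one axis each.
        if isinstance(batch_axis, (tuple, list)):
            x_axis, y_axis = batch_axis
        else:
            x_axis = y_axis = batch_axis

        def sample_loss(x_i, y_i):
            # Loss of a single, unbatched sample.
            return jnp.mean(jnp.square(model(x_i) - y_i))

        # Map over the batch axes of x and y; the result has shape (batch,).
        losses = jax.vmap(sample_loss, in_axes=(x_axis, y_axis))(x, y)
        return jnp.mean(losses)  # scalar array with shape ()


def double(x):
    # Hypothetical model: one sample in, one sample out.
    return 2.0 * x


x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3), batch on axis 1
y = 2.0 * x.T + 1.0                                # shape (3, 2), batch on axis 0
loss = PerSampleMSE()(double, (x, y), batch_axis=(1, 0))
```

Mapping over both `x` and `y` inside one `jax.vmap` call is a design choice of this sketch: it lets each element of the data carry its own batch axis without manual transposes.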


klax.MSE(klax.Loss)

Mean squared error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.

__call__(model, data, batch_axis=0)

klax.mse(model, data, batch_axis=0)

Mean squared error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.
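
For reference, the quantity can be written out as below. This assumes the mean is taken over all elements, as in the custom mse example for klax.Loss; the exact reduction used by klax is not stated here. The symbols are introduced only for illustration: f is the model, N the number of samples along the batch axis, and D the number of output elements per sample.

```latex
\[
\mathrm{MSE}(f; x, y)
  = \frac{1}{N D} \sum_{i=1}^{N} \sum_{j=1}^{D}
    \bigl( f(x_i)_j - y_{i,j} \bigr)^2
\]
```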


klax.MAE(klax.Loss)

Mean absolute error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.

__call__(model, data, batch_axis=0)

klax.mae(model, data, batch_axis=0)

Mean absolute error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.
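
The following minimal sketch shows what a mean absolute error with this signature might compute. It is not klax's actual implementation: `mae_sketch` and `shift` are hypothetical names, the mean is assumed to run over all elements, and `batch_axis` is assumed to be a single integer shared by `x` and `y`, matching the requirement that both have the same batch axis.

```python
import jax
import jax.numpy as jnp


def mae_sketch(model, data, batch_axis=0):
    """Illustrative mean absolute error, not klax's implementation."""
    x, y = data
    # Put the batch axis of the predictions back where it was in x,
    # so that y_pred lines up with y (assumed to share that batch axis).
    y_pred = jax.vmap(model, in_axes=(batch_axis,), out_axes=batch_axis)(x)
    return jnp.mean(jnp.abs(y_pred - y))


def shift(x):
    # Hypothetical model: adds 1 to a single sample.
    return x + 1.0


x = jnp.array([0.0, 1.0, 2.0])  # batch of 3 scalar samples
y = jnp.array([1.0, 4.0, 1.0])  # targets; two of them are off by 2
loss = mae_sketch(shift, (x, y))
```

Compared with the squared error, the absolute error weights the two mismatched samples linearly, so a single large outlier influences the result less.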