Loss functions

Loss functions define the objective minimized during training with klax.fit. In klax, a loss receives (model, batch, run_state) and returns a scalar value. Built-in options like klax.mse and klax.mae cover common regression tasks, while custom losses can be created either with the loss decorator or by implementing a custom Loss class.


Ready-to-use loss functions¤

klax.mse(model, data, run_state) ¤

Mean squared error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.

klax.mae(model, data, run_state) ¤

Mean absolute error for a tuple of data (x, y).

The inputs x and the outputs y are expected to have the same batch axis and equal length along that axis.
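
As a rough illustration of what a mean absolute error with this signature computes, the following plain-JAX sketch mirrors the structure of the mse example in the next section. It is not klax's actual implementation: the name mae_sketch and the toy model are hypothetical, and the batch axis is assumed to be the leading axis.

```python
import jax
import jax.numpy as jnp


def mae_sketch(model, data, run_state=None):
    """Mean absolute error over a batch; run_state is accepted but unused."""
    x, y = data
    # Apply the single-sample model to every row along the leading batch axis.
    y_pred = jax.vmap(model)(x)
    return jnp.mean(jnp.abs(y_pred - y))


# Hypothetical toy model that doubles each input sample.
def toy_model(x):
    return 2.0 * x


x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([[2.0], [5.0], [5.0]])
mae_value = mae_sketch(toy_model, (x, y))
```

The absolute errors here are 0, 1, and 1, so the sketch returns their mean as a scalar array.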


Defining custom loss functions¤

To define a custom loss function, use the loss decorator. Behind the scenes, this converts your function into a Loss object. In advanced cases, where you need finer control over how the loss or its gradient is calculated, you can also implement a custom Loss object directly.

klax.loss(func) ¤

Convert a function into a klax.Loss object.

Example

To create a mean squared error loss using this decorator, you can do:

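import jax
import jax.numpy as jnp
import klax

@klax.loss
def mse(model, data, run_state):
    x, y = data
    # Apply the model to each sample along the batch axis.
    y_pred = jax.vmap(model)(x)
    return jnp.mean(jnp.square(y_pred - y))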

PARAMETER DESCRIPTION
func

Function that computes the loss. It must have the signature (model: PyTree, batch: PyTree, run_state: PyTree) -> Scalar.

TYPE: Callable[[PyTree, PyTree, PyTree], Shaped[Array, '']]

RETURNS DESCRIPTION
Loss

An instance of a subclass of klax.Loss that wraps the given function.

TYPE: Loss
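
Any function with the (model, batch, run_state) -> scalar signature can serve as the body of a custom loss. The sketch below writes a Huber-style loss in plain JAX so that it runs without klax; with klax installed, the same function would presumably be decorated with @klax.loss as in the mse example above. The names huber, DELTA, and toy_model are hypothetical and not part of klax.

```python
import jax
import jax.numpy as jnp

DELTA = 1.0  # Hypothetical threshold between the quadratic and linear regimes.


# With klax installed, this function would be decorated with @klax.loss.
def huber(model, batch, run_state):
    x, y = batch
    # Apply the single-sample model along the leading batch axis.
    y_pred = jax.vmap(model)(x)
    err = jnp.abs(y_pred - y)
    quadratic = 0.5 * err**2
    linear = DELTA * (err - 0.5 * DELTA)
    # Quadratic penalty for small errors, linear penalty for large ones.
    return jnp.mean(jnp.where(err <= DELTA, quadratic, linear))


# Hypothetical toy model that triples each scalar input.
def toy_model(x):
    return 3.0 * x


x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.5, 3.0, 9.0])
huber_value = huber(toy_model, (x, y), None)
```

The function returns a zero-dimensional array, which matches the Shaped[Array, ''] return type required of func.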

klax.Loss ¤

An abstract callable loss object.

Inherit from this class to define a custom loss that can be passed to fit. An instance of the loss class has two methods, value and value_and_grad, which determine how the loss value and its gradients are calculated. To define a custom loss, implement the value method. When the loss instance is called, the model is first unwrapped and then passed to the value method. By default, value_and_grad computes the gradient from the value method. Override it only to specify a custom gradient computation.

Example

A simple custom loss that computes the mean squared error between the predicted values y_pred and true values y for inputs x may be implemented as follows:

>>> import jax
>>> import jax.numpy as jnp
>>> import klax
>>> class MSE(klax.Loss):
...     def value(self, model, data, run_state):
...         x, y = data
...         y_pred = jax.vmap(model)(x)
...         return jnp.mean(jnp.square(y_pred - y))

__call__(model, batch, run_state) ¤

Compute the loss value used during training.

This method unwraps the model before computing the loss by calling the value method.

PARAMETER DESCRIPTION
model

The model parameters or structure on which the loss is evaluated.

TYPE: PyTree

batch

The input data or structure used for loss computation.

TYPE: PyTree[Any, T]

run_state

Auxiliary, user-defined runtime state.

TYPE: PyTree[Any]

RETURNS DESCRIPTION
Scalar

The computed loss value.

TYPE: Shaped[Array, '']

value(model, batch, run_state) ¤

Abstract method to compute the loss for a given model and data.

PARAMETER DESCRIPTION
model

The model parameters or structure on which the loss is evaluated.

TYPE: PyTree

batch

The input data or structure used for loss computation.

TYPE: PyTree[Any, T]

run_state

Auxiliary, user-defined runtime state.

TYPE: PyTree[Any]

RETURNS DESCRIPTION
Scalar

The computed loss value.

TYPE: Shaped[Array, '']
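
Since run_state is described only as auxiliary, user-defined state, a loss body may read whatever the user chose to store in it. The sketch below assumes, purely for illustration, that run_state is a dict carrying a hypothetical l2_weight entry and that the model is a dict of arrays; how klax.fit populates run_state is not covered here, so treat the layout as an assumption.

```python
import jax
import jax.numpy as jnp


def regularized_mse(model, batch, run_state):
    """MSE plus an L2 penalty whose weight is read from run_state."""
    x, y = batch
    y_pred = x @ model["w"]
    data_term = jnp.mean(jnp.square(y_pred - y))
    # Sum of squared entries over every array leaf of the model.
    leaves = jax.tree_util.tree_leaves(model)
    l2 = sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)
    # "l2_weight" is a hypothetical key chosen for this example.
    return data_term + run_state["l2_weight"] * l2


model = {"w": jnp.array([1.0, 2.0])}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])
reg_value = regularized_mse(model, (x, y), {"l2_weight": 0.1})
```

In this toy case the predictions match the targets exactly, so the returned value consists of the penalty term alone.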

value_and_grad(model, batch, run_state) ¤

Compute the loss value and its gradient.

This method computes the loss value and its gradient with respect to the model parameters by applying eqx.filter_value_and_grad to the value method.

PARAMETER DESCRIPTION
model

The model parameters or structure on which the loss is evaluated.

TYPE: PyTree[Any, M]

batch

The input data or structure used for loss computation.

TYPE: PyTree[Any, T]

run_state

Auxiliary, user-defined runtime state.

TYPE: PyTree[Any]

RETURNS DESCRIPTION
tuple[Shaped[Array, ''], PyTree[Any, M]]

Tuple of loss value and gradient with respect to the model.
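
The relationship between value and value_and_grad can be sketched in plain JAX. Here jax.value_and_grad stands in for eqx.filter_value_and_grad; the two should behave alike in this toy case because every leaf of the model is a floating-point array. The linear model stored as a dict and the function names are hypothetical and not klax's implementation.

```python
import jax
import jax.numpy as jnp


def value(model, batch, run_state):
    """Loss value for a linear model stored as a dict of float arrays."""
    x, y = batch
    y_pred = x @ model["w"] + model["b"]
    return jnp.mean(jnp.square(y_pred - y))


def value_and_grad(model, batch, run_state):
    # Differentiate with respect to the first argument (the model) only.
    return jax.value_and_grad(value)(model, batch, run_state)


model = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([2.0, 0.0])
loss_value, grads = value_and_grad(model, (x, y), None)
```

The returned grads has the same tree structure as model, which corresponds to the PyTree[Any, M] annotation shared by the model argument and the gradient in the return type.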