Loss functions
Loss functions define the objective minimized during training with
klax.fit. In klax, a loss receives (model, batch, run_state) and
returns a scalar value. Built-in losses such as klax.mse and klax.mae
cover common regression tasks, while custom losses can be created either with
the klax.loss decorator or by subclassing
klax.Loss.
Ready-to-use loss functions
klax.mse(model, data, run_state)
Mean squared error for a tuple of data (x, y).
The inputs x and the outputs y are expected to have the same batch axis
and equal length along that axis.
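For intuition, here is a plain-JAX sketch of the quantity this loss computes, assuming the model is a per-sample callable that is mapped over the shared batch axis with jax.vmap. The names mse_sketch and double are hypothetical, and klax's actual implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def mse_sketch(model, data, run_state):
    """Hypothetical re-implementation of a mean squared error loss."""
    x, y = data  # x and y share the same leading batch axis
    y_pred = jax.vmap(model)(x)  # apply the per-sample model along that axis
    return jnp.mean(jnp.square(y_pred - y))


def double(x):
    # Hypothetical per-sample model: y = 2 * x
    return 2.0 * x


x = jnp.array([[1.0], [2.0], [3.0]])  # 3 samples with 1 feature each
y = jnp.array([[2.0], [4.0], [7.0]])  # 3 targets, same batch length as x

# run_state is unused in this sketch, so None is passed
loss_value = mse_sketch(double, (x, y), None)  # residuals 0, 0, -1 -> 1/3
```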
klax.mae(model, data, run_state)
Mean absolute error for a tuple of data (x, y).
The inputs x and the outputs y are expected to have the same batch axis
and equal length along that axis.
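A comparable sketch for the mean absolute error, again assuming a per-sample model mapped with jax.vmap (mae_sketch and double are hypothetical names, not klax code), also shows how the two losses react to a single large residual:

```python
import jax
import jax.numpy as jnp


def mae_sketch(model, data, run_state):
    """Hypothetical re-implementation of a mean absolute error loss."""
    x, y = data  # x and y share the same leading batch axis
    y_pred = jax.vmap(model)(x)
    return jnp.mean(jnp.abs(y_pred - y))


def double(x):
    # Hypothetical per-sample model: y = 2 * x
    return 2.0 * x


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 5.0, 3.0])  # residuals y_pred - y: 0, -1, 3

mae_value = mae_sketch(double, (x, y), None)  # (0 + 1 + 3) / 3 = 4/3

# Mean squared error on the same residuals, for comparison
residuals = jax.vmap(double)(x) - y
mse_value = jnp.mean(jnp.square(residuals))  # (0 + 1 + 9) / 3 = 10/3
```

The single residual of 3 contributes 9 of the 10 units of squared error but only 3 of the 4 units of absolute error, which is why MAE is less sensitive to outliers than MSE.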
Defining custom loss functions
To define a custom loss function, use the klax.loss decorator. Behind the scenes, it converts your function into a klax.Loss object. In advanced cases, where you need finer control over how the loss or its gradient is calculated, you can instead subclass klax.Loss directly.
klax.loss(func)
Convert a function into a klax.Loss object.
Example
To create a mean squared error loss using this decorator, you can do:
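import jax
import jax.numpy as jnp
import klax

@klax.loss
def mse(model, data, run_state):
    x, y = data  # inputs and targets share the batch axis
    y_pred = jax.vmap(model)(x)  # evaluate the model sample by sample
    return jnp.mean(jnp.square(y_pred - y))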
| PARAMETER | DESCRIPTION |
|---|---|
| func | Function that computes the loss. It must have the signature (model, batch, run_state) and return a scalar loss value. |

| RETURNS | DESCRIPTION |
|---|---|
| Loss | An instance of a subclass of klax.Loss. |
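To illustrate what converting a function into a loss object involves, the following simplified sketch wraps a function in an object whose value method delegates to it and whose __call__ forwards to value. FunctionLoss and as_loss are hypothetical names, not klax internals; the real klax.loss returns a klax.Loss subclass, which also unwraps the model and provides value_and_grad.

```python
import jax
import jax.numpy as jnp


class FunctionLoss:
    """Hypothetical loss object built from a plain function."""

    def __init__(self, func):
        self.func = func

    def value(self, model, batch, run_state):
        # Delegate to the wrapped function.
        return self.func(model, batch, run_state)

    def __call__(self, model, batch, run_state):
        # klax would unwrap the model before this call; the sketch does not.
        return self.value(model, batch, run_state)


def as_loss(func):
    """Hypothetical decorator: wrap `func` in a FunctionLoss instance."""
    return FunctionLoss(func)


@as_loss
def mse(model, data, run_state):
    x, y = data
    y_pred = jax.vmap(model)(x)
    return jnp.mean(jnp.square(y_pred - y))


def double(x):
    # Hypothetical per-sample model: y = 2 * x
    return 2.0 * x


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 7.0])
loss_value = mse(double, (x, y), None)  # the decorated name is now an object
```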
klax.Loss
An abstract callable loss object.
Inherit from this class to define a custom loss that can be passed to
klax.fit.
An instance of the loss class has two methods, value and value_and_grad,
which determine how the loss value and its gradient are calculated.
To define a custom loss, implement the value method. When the loss
instance is called, the model is first unwrapped and then passed to
the value method.
By default, value_and_grad computes the gradient from the value method.
Override it only if you need a custom gradient computation.
Example
A simple custom loss that computes the mean squared error between
the predicted values y_pred and true values y for inputs x may
be implemented as follows:
>>> import jax
>>> import jax.numpy as jnp
>>> import klax
>>> class MSE(klax.Loss):
...     def value(self, model, data, run_state):
...         x, y = data
...         y_pred = jax.vmap(model)(x)
...         return jnp.mean(jnp.square(y_pred - y))
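The mechanics described above can be sketched without klax, assuming a simplified, hypothetical base class LossSketch: subclasses implement only value, __call__ forwards to it, and the default value_and_grad differentiates value with respect to the model. In this sketch jax.value_and_grad stands in for eqx.filter_value_and_grad, the model-unwrapping step is omitted, and jax.tree_util.Partial serves as a minimal callable PyTree in place of an Equinox module.

```python
import abc

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


class LossSketch(abc.ABC):
    """Hypothetical, simplified stand-in for the klax.Loss pattern."""

    @abc.abstractmethod
    def value(self, model, batch, run_state):
        """Return a scalar loss; subclasses must implement this."""

    def value_and_grad(self, model, batch, run_state):
        # Default: differentiate `value` with respect to its first argument,
        # the model. klax applies eqx.filter_value_and_grad here instead.
        return jax.value_and_grad(self.value)(model, batch, run_state)

    def __call__(self, model, batch, run_state):
        # klax unwraps the model at this point; the sketch skips that step.
        return self.value(model, batch, run_state)


class MSE(LossSketch):
    # Only `value` is overridden; value_and_grad comes from the base class.
    def value(self, model, batch, run_state):
        x, y = batch
        y_pred = jax.vmap(model)(x)
        return jnp.mean(jnp.square(y_pred - y))


def linear(w, b, x):
    # Hypothetical per-sample model: y = w * x + b
    return w * x + b


# Partial is a callable PyTree whose bound arguments (w, b) are its leaves.
model = Partial(linear, jnp.array(2.0), jnp.array(0.0))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 7.0])

loss = MSE()
value = loss(model, (x, y), None)
value_again, grads = loss.value_and_grad(model, (x, y), None)
```

With residuals 0, 0 and -1, the loss is 1/3 and the gradients are -2 for w and -2/3 for b, returned inside a Partial that mirrors the model.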
__call__(model, batch, run_state)
Compute the loss value used during training.
This method unwraps the model and then computes the loss by calling
the value method.
| PARAMETER | DESCRIPTION |
|---|---|
| model | The model on which the loss is evaluated. |
| batch | The batch of input data used to compute the loss. |
| run_state | Auxiliary, user-defined runtime state. |

| RETURNS | DESCRIPTION |
|---|---|
| Scalar | The computed loss value. |
value(model, batch, run_state)
Abstract method to compute the loss for a given model and data.
| PARAMETER | DESCRIPTION |
|---|---|
| model | The model on which the loss is evaluated. |
| batch | The batch of input data used to compute the loss. |
| run_state | Auxiliary, user-defined runtime state. |

| RETURNS | DESCRIPTION |
|---|---|
| Scalar | The computed loss value. |
value_and_grad(model, batch, run_state)
Compute the loss value and its gradient.
This method computes the loss value and its gradient with respect to
the model parameters by applying eqx.filter_value_and_grad to the value
method.
| PARAMETER | DESCRIPTION |
|---|---|
| model | The model on which the loss is evaluated. |
| batch | The batch of input data used to compute the loss. |
| run_state | Auxiliary, user-defined runtime state. |

| RETURNS | DESCRIPTION |
|---|---|
| tuple[Shaped[Array, ''], PyTree[Any, M]] | Tuple of the loss value and its gradient with respect to the model. |
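Because the gradient is taken with respect to the model, it is itself a PyTree with the same structure as the model. The following minimal sketch uses it for one plain gradient-descent step. Here jax.value_and_grad stands in for eqx.filter_value_and_grad (the two agree in this case because every leaf of the model is a floating-point array), and the Partial-based model, the linear and mse_value helpers, and the learning rate are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def linear(w, b, x):
    # Hypothetical per-sample model: y = w * x + b
    return w * x + b


def mse_value(model, batch, run_state):
    x, y = batch
    return jnp.mean(jnp.square(jax.vmap(model)(x) - y))


# Hypothetical callable PyTree model with array leaves w and b
model = Partial(linear, jnp.array(2.0), jnp.array(0.0))
batch = (jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 4.0, 7.0]))

value, grads = jax.value_and_grad(mse_value)(model, batch, None)

# The gradient mirrors the model's PyTree structure ...
same_structure = (
    jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(model)
)

# ... so a gradient-descent step is a leaf-wise update over both trees.
learning_rate = 0.1
new_model = jax.tree_util.tree_map(
    lambda param, grad: param - learning_rate * grad, model, grads
)
new_value = mse_value(new_model, batch, None)  # lower than `value`
```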