
Callbacks

klax.CallbackArgs

Bundle of callback arguments designed to work in conjunction with klax.fit.

This class should not be instantiated directly. An instance of this class is passed to every callback during klax.fit. When writing a custom callback, use the properties of this class to access the current model, optimizer state, training data, and validation data during training.

This class implements lazily evaluated, cached values via properties: a property such as loss is computed only when it is first accessed, and the result is stored so that repeated accesses do not recompute it.
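The caching pattern described above can be sketched with Python's functools.cached_property. The class LazyArgsSketch below is hypothetical and is not klax's implementation; in particular, the assumption that update discards previously cached values is ours. The call counter shows that repeated accesses to loss evaluate get_loss only once per update.

```python
from functools import cached_property


class LazyArgsSketch:
    """Hypothetical stand-in for a lazily evaluated, cached loss property."""

    def __init__(self, get_loss, data):
        self._get_loss = get_loss
        self._data = data
        self._model = None

    def update(self, model):
        # Assumption: storing a new model invalidates the cached loss, so it
        # is recomputed on the next access.
        self._model = model
        self.__dict__.pop("loss", None)

    @cached_property
    def loss(self):
        # Computed on first access only; later accesses read the cached value.
        return self._get_loss(self._model, self._data)


# Usage: record every call to the loss function.
calls = []


def get_loss(model, data):
    calls.append((model, data))
    return model * data


args = LazyArgsSketch(get_loss, data=2.0)
args.update(3.0)
first, second = args.loss, args.loss  # get_loss runs once for both accesses
args.update(4.0)
third = args.loss  # the new model triggers exactly one recomputation
```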

loss property

Lazy-evaluated and cached training loss.

model property

Lazy-evaluated and cached model.

opt_state property

Lazy-evaluated and cached optimizer state.

val_loss property

Lazy-evaluated and cached validation loss.

__init__(get_loss, treedef_model, treedef_opt_state, data, val_data=None)

Initialize the callback arguments object.

PARAMETER DESCRIPTION
get_loss

Function that takes a model and a batch of data and returns the loss.

treedef_model

Tree structure of the model.

treedef_opt_state

Tree structure of the optax optimizer state.

data

PyTree of the training data.

val_data

PyTree of the validation data. If None, no validation loss is calculated and the val_loss property returns None. (Defaults to None.)
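For illustration, a get_loss with the documented shape (a model and a batch of data in, a scalar loss out) might look like the mean-squared-error function below. The names get_loss and linear_model and the (inputs, targets) batch layout are assumptions; plain Python floats are used so the sketch runs without JAX, whereas in practice the model and data would be JAX PyTrees.

```python
def get_loss(model, batch):
    """Hypothetical mean-squared-error loss for a batch of (inputs, targets)."""
    inputs, targets = batch
    squared_errors = [(model(x) - y) ** 2 for x, y in zip(inputs, targets)]
    return sum(squared_errors) / len(squared_errors)


def linear_model(x):
    # Hypothetical model: a fixed linear map y = 2x.
    return 2.0 * x


batch = ([1.0, 2.0, 3.0], [2.0, 4.0, 7.0])
loss = get_loss(linear_model, batch)  # only the last sample is off (by 1.0)
```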

update(flat_model, flat_opt_state, step)

Update the object with the current model and optimizer state.

This method is called repeatedly in klax.fit.

PARAMETER DESCRIPTION
flat_model

Flattened PyTree of the model.

TYPE: PyTree

flat_opt_state

Flattened PyTree of the optax optimizer state.

TYPE: PyTree

step

Current training step.

TYPE: int


klax.Callback

Abstract base class for callbacks.

Inherit from this class to create a custom callback.

__call__(cbargs)

Called after each step during training.

on_training_start(cbargs)

Called when training starts.

on_training_end(cbargs)

Called when training ends.
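As a sketch of a custom callback, the hypothetical BestValLossCallback below tracks the lowest validation loss seen during training, using only the documented val_loss property of the callback arguments. Falling back to a minimal stand-in base class when klax is not installed is an assumption made so the snippet runs on its own, and the fake cbargs objects in the usage part stand in for the klax.CallbackArgs instances that klax.fit would pass.

```python
from types import SimpleNamespace

try:
    from klax import Callback  # the documented abstract base class
except ImportError:
    # Minimal stand-in so this sketch runs without klax installed.
    class Callback:
        def __call__(self, cbargs):
            pass

        def on_training_start(self, cbargs):
            pass

        def on_training_end(self, cbargs):
            pass


class BestValLossCallback(Callback):
    """Hypothetical callback that tracks the lowest validation loss."""

    def on_training_start(self, cbargs):
        self.best_val_loss = float("inf")
        self.num_calls = 0

    def __call__(self, cbargs):
        self.num_calls += 1
        val_loss = cbargs.val_loss  # None if no validation data was given
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss

    def on_training_end(self, cbargs):
        print(f"Best validation loss: {self.best_val_loss}")


# Usage with fake callback arguments in place of klax.CallbackArgs.
callback = BestValLossCallback()
callback.on_training_start(SimpleNamespace(val_loss=None))
for val_loss in [0.5, 0.3, 0.4]:
    callback(SimpleNamespace(val_loss=val_loss))
callback.on_training_end(SimpleNamespace(val_loss=None))
```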


klax.HistoryCallback(klax.Callback)

Default callback for logging the training process.

Records loss histories, training time, and the last optimizer state.

__init__(log_every=100, verbose=True)

Initialize the HistoryCallback.

PARAMETER DESCRIPTION
log_every

Number of steps between consecutive logs of the training and validation losses. (Defaults to 100.)

verbose

If True, print the training progress and losses. (Defaults to True.)

__call__(cbargs)

Record the losses and step count.

Called at each step during training.
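The documented behavior (record the losses every log_every steps, optionally print them, and keep the training time and the last optimizer state) can be approximated by the hypothetical HistorySketch class below. It is not klax's HistoryCallback: the cbargs.step attribute, the attribute names, and the choice to log when the step is a multiple of log_every are assumptions. Because the losses are read only on logging steps, lazily evaluated losses such as those of klax.CallbackArgs would be computed only on those steps.

```python
import time
from types import SimpleNamespace


class HistorySketch:
    """Hypothetical recorder approximating the documented HistoryCallback."""

    def __init__(self, log_every=100, verbose=True):
        self.log_every = log_every
        self.verbose = verbose
        self.steps, self.loss, self.val_loss = [], [], []

    def on_training_start(self, cbargs):
        self.start_time = time.perf_counter()

    def __call__(self, cbargs):
        # Assumption: the current step is exposed as cbargs.step.
        if cbargs.step % self.log_every == 0:
            # The losses are read only here, so lazily evaluated losses are
            # computed on logging steps only.
            self.steps.append(cbargs.step)
            self.loss.append(cbargs.loss)
            self.val_loss.append(cbargs.val_loss)
            if self.verbose:
                print(f"Step {cbargs.step}: loss={cbargs.loss}, val_loss={cbargs.val_loss}")

    def on_training_end(self, cbargs):
        self.training_time = time.perf_counter() - self.start_time
        self.last_opt_state = cbargs.opt_state


# Usage with fake callback arguments for 300 steps.
history = HistorySketch(log_every=100, verbose=False)
history.on_training_start(None)
for step in range(1, 301):
    history(SimpleNamespace(step=step, loss=1.0 / step, val_loss=None))
history.on_training_end(SimpleNamespace(opt_state="final optimizer state"))
```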

load(filename) staticmethod

Load a HistoryCallback instance from a file.

PARAMETER DESCRIPTION
filename

The file path from which the instance should be loaded.

TYPE: str | pathlib.Path

RETURNS DESCRIPTION
klax.HistoryCallback

The loaded HistoryCallback instance.

RAISES DESCRIPTION
ValueError

If the file is not a valid pickle file or does not contain a HistoryCallback instance.

save(filename, overwrite=False, create_dir=True)

Save the HistoryCallback instance to a file using pickle.

PARAMETER DESCRIPTION
filename

The file path where the instance should be saved.

TYPE: str | pathlib.Path

overwrite

If True, overwrite the file if it already exists. If False, raise a FileExistsError if the file exists. (Defaults to False.)

TYPE: bool DEFAULT: False

create_dir

If True, create the parent directory if it does not exist. (Defaults to True.)

TYPE: bool DEFAULT: True

RAISES DESCRIPTION
FileExistsError

If the file already exists and overwrite is False.

ValueError

If the provided path is not a valid file path.
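The options and errors documented for save and load can be mirrored with pickle and pathlib. The helpers save_pickle and load_pickle below are hypothetical sketches of that behavior, not klax's code; the usage part writes into a temporary directory so nothing is left behind.

```python
import pickle
import tempfile
from pathlib import Path


def save_pickle(obj, filename, overwrite=False, create_dir=True):
    """Hypothetical helper mirroring the documented save() options."""
    path = Path(filename)
    if path.is_dir():
        raise ValueError(f"{path} is a directory, not a file path")
    if path.exists() and not overwrite:
        raise FileExistsError(f"{path} already exists")
    if create_dir:
        path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)


def load_pickle(filename, expected_type):
    """Hypothetical helper mirroring the documented load() checks."""
    try:
        with Path(filename).open("rb") as f:
            obj = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as err:
        raise ValueError(f"{filename} is not a valid pickle file") from err
    if not isinstance(obj, expected_type):
        raise ValueError(f"{filename} does not contain a {expected_type.__name__}")
    return obj


# Usage: round-trip an object and check that overwriting is refused by default.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "runs" / "history.pkl"  # "runs" does not exist yet
    save_pickle({"loss": [0.5, 0.1]}, target)
    restored = load_pickle(target, dict)
    try:
        save_pickle({}, target)  # overwrite=False, so this raises
        refused_overwrite = False
    except FileExistsError:
        refused_overwrite = True
```

Note that pickle files should only be loaded from trusted sources, since unpickling can execute arbitrary code.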

plot(*, ax=None, loss_options={}, val_loss_options={})

Plot the recorded training and validation losses.

Note

This method requires matplotlib.

PARAMETER DESCRIPTION
ax

Matplotlib axes to plot into. If None, a new axes is created. (Defaults to None.)

TYPE: Any DEFAULT: None

loss_options

Dictionary of keyword arguments passed to matplotlib's plot for the training loss. (Defaults to {}.)

TYPE: dict DEFAULT: {}

val_loss_options

Dictionary of keyword arguments passed to matplotlib's plot for the validation loss. (Defaults to {}.)

TYPE: dict DEFAULT: {}

RAISES DESCRIPTION
ImportError

If matplotlib is not installed.

on_training_start(cbargs)

Initialize the training start time.

Called at the beginning of training.

on_training_end(cbargs)

Record the training end time and the last optimizer state.

Called at the end of training.