Callbacks
klax.CallbackArgs
A callback argument designed to work in conjunction with klax.fit.
This class should not be instantiated directly. An instance of this class is passed to every callback object in the fit function. When writing a custom callback, use the properties of this class to access the current model, optimizer state, training data, and validation data during training.
This class implements cached and lazy-evaluated values via property methods. This means that properties like loss are only calculated if they are used, and are stored so that they are not calculated multiple times.
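The lazy, cached behaviour described above can be illustrated with a minimal sketch. `LazyArgs` is a hypothetical stand-in, not the klax implementation, and it assumes that `functools.cached_property` is an acceptable way to model "compute on first access, then reuse".

```python
from functools import cached_property


class LazyArgs:
    """Hypothetical stand-in for klax.CallbackArgs (not the klax code)."""

    def __init__(self, get_loss, model, data):
        self._get_loss = get_loss
        self._model = model
        self._data = data

    @cached_property
    def loss(self):
        # Runs only on first access; the result is stored on the instance.
        return self._get_loss(self._model, self._data)


calls = []


def get_loss(model, data):
    calls.append(1)  # count how often the loss is actually evaluated
    return sum((model * x - y) ** 2 for x, y in data) / len(data)


args = LazyArgs(get_loss, model=2.0, data=[(1.0, 2.0), (2.0, 5.0)])
n_before = len(calls)  # nothing has been evaluated yet
first = args.loss      # triggers the single evaluation
second = args.loss     # served from the cache
n_after = len(calls)
```

A callback that never reads `loss` therefore pays nothing for it, and several callbacks reading it in the same step share one evaluation.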
loss
property
Lazy-evaluated and cached training loss.
model
property
Lazy-evaluated and cached model.
opt_state
property
Lazy-evaluated and cached optimizer state.
val_loss
property
Lazy-evaluated and cached validation loss.
__init__(get_loss, treedef_model, treedef_opt_state, data, val_data=None)
Initialize the callback arguments object.
get_loss: Function that takes a model and a batch of data and returns the loss.
treedef_model: Tree structure of the model.
treedef_opt_state: Tree structure of the optax optimizer state.
data: PyTree of the training data.
val_data: PyTree of the validation data. If None, no validation loss is calculated and the val_loss property returns None.
update(flat_model, flat_opt_state, step)
Update the object with the current model and optimizer state.
This method is called repeatedly in klax.fit.
PARAMETER | DESCRIPTION |
---|---|
flat_model | Flattened PyTree of the model. |
flat_opt_state | Flattened PyTree of the optimizer state. |
step | Current step count of the training. |
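The split between tree structures (passed to `__init__`) and flattened PyTrees (passed to `update`) matches the usual JAX flatten/unflatten pattern. The sketch below shows only that pattern using `jax.tree_util`. The toy dictionary model is an assumption, and how klax combines the two pieces internally is not documented here.

```python
import jax

# Toy PyTree standing in for a model (an assumption for illustration).
model = {"weight": 1.5, "bias": 0.5}

# Split the PyTree into a flat list of leaves and a structure description.
flat_model, treedef_model = jax.tree_util.tree_flatten(model)

# Pretend a training step changed the leaves, then rebuild the full PyTree
# from the stored structure and the new flat leaves.
new_flat_model = [leaf * 2 for leaf in flat_model]
rebuilt = jax.tree_util.tree_unflatten(treedef_model, new_flat_model)
```

Because the structure is fixed during training, only the flat leaves need to be handed over at each step.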
klax.Callback
klax.HistoryCallback(klax.Callback)
Default callback for logging a training process.
Records loss histories, training time, and the last optimizer state.
__init__(log_every=100, verbose=True)
Initialize the HistoryCallback.
log_every: Number of steps between logs of the training and validation losses. (Defaults to 100.)
verbose: If True, prints the training progress and losses. (Defaults to True.)
__call__(cbargs)
Record the losses and step count.
Called at each step during training.
load(filename)
staticmethod
Load a HistoryCallback instance from a file.
PARAMETER | DESCRIPTION |
---|---|
filename | The file path from which the instance should be loaded. |

RETURNS | DESCRIPTION |
---|---|
klax.HistoryCallback | The loaded HistoryCallback instance. |

RAISES | DESCRIPTION |
---|---|
ValueError | If the file is not a valid pickle file or does not contain a HistoryCallback instance. |
save(filename, overwrite=False, create_dir=True)
Save the HistoryCallback instance to a file using pickle.
PARAMETER | DESCRIPTION |
---|---|
filename | The file path where the instance should be saved. |
overwrite | If True, overwrite the file if it already exists. If False, raise a FileExistsError if the file exists. (Defaults to False.) |
create_dir | If True, create the parent directory if it does not exist. (Defaults to True.) |

RAISES | DESCRIPTION |
---|---|
FileExistsError | If the file already exists and overwrite is False. |
ValueError | If the provided path is not a valid file path. |
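The documented `save` and `load` semantics (pickle storage, `overwrite`, `create_dir`, and a ValueError for unexpected content) can be sketched with the standard library alone. `save_obj` and `load_obj` are hypothetical helpers written for this illustration, and a plain dictionary stands in for a HistoryCallback instance. None of this is the klax code.

```python
import pickle
import tempfile
from pathlib import Path


def save_obj(obj, filename, overwrite=False, create_dir=True):
    """Hypothetical helper mirroring the documented save() semantics."""
    path = Path(filename)
    if path.exists() and not overwrite:
        raise FileExistsError(f"{path} already exists")
    if create_dir:
        path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(obj, f)


def load_obj(filename, expected_type):
    """Hypothetical helper mirroring the documented load() semantics."""
    try:
        with open(filename, "rb") as f:
            obj = pickle.load(f)
    except pickle.UnpicklingError as err:
        raise ValueError(f"{filename} is not a valid pickle file") from err
    if not isinstance(obj, expected_type):
        raise ValueError(f"{filename} does not contain a {expected_type.__name__}")
    return obj


# Stand-in for a HistoryCallback instance.
history = {"steps": [0, 100], "loss": [1.0, 0.4]}

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "runs" / "history.pkl"  # parent directory is created
    save_obj(history, target)
    restored = load_obj(target, dict)
    try:
        save_obj(history, target)  # second save without overwrite=True
        second_save_failed = False
    except FileExistsError:
        second_save_failed = True
```

Defaulting `overwrite` to False protects an existing training history from being silently replaced by a later run.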
plot(*, ax=None, loss_options={}, val_loss_options={})
Plot the recorded training and validation losses.
Note: This method requires matplotlib.
PARAMETER | DESCRIPTION |
---|---|
ax | Matplotlib axes to plot into. (Defaults to None.) |
loss_options | Dictionary of keyword arguments passed to matplotlib when plotting the training loss. |
val_loss_options | Dictionary of keyword arguments passed to matplotlib when plotting the validation loss. |

RAISES | DESCRIPTION |
---|---|
ImportError | If matplotlib is not installed. |
on_training_start(cbargs)
Initialize the training start time.
Called at the beginning of training.
on_training_end(cbargs)
Record the training end time and the last optimizer state.
Called at the end of training.
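The three hooks documented for HistoryCallback (`on_training_start`, `__call__`, `on_training_end`) suggest the following shape for a custom callback. This is a sketch under stated assumptions: `TimingCallback` is hypothetical, a plain class replaces a `klax.Callback` subclass so the code runs without klax, the loop is a stand-in for `klax.fit`, and the `step` attribute on the argument object is assumed rather than documented above.

```python
import time
from types import SimpleNamespace


class TimingCallback:
    """Hypothetical callback exposing the three documented hooks."""

    def __init__(self, log_every=100):
        self.log_every = log_every
        self.steps = []
        self.losses = []
        self.training_time = None

    def on_training_start(self, cbargs):
        self._start = time.perf_counter()

    def __call__(self, cbargs):
        # Read cbargs.loss only on logging steps, so a lazily evaluated
        # loss would not be computed on every step.
        if cbargs.step % self.log_every == 0:
            self.steps.append(cbargs.step)
            self.losses.append(cbargs.loss)

    def on_training_end(self, cbargs):
        self.training_time = time.perf_counter() - self._start


# Stand-in for the training loop inside klax.fit (assumed call order).
callback = TimingCallback(log_every=100)
cbargs = SimpleNamespace(step=0, loss=1.0)  # stand-in for klax.CallbackArgs
callback.on_training_start(cbargs)
for step in range(1, 301):
    cbargs.step = step
    cbargs.loss = 1.0 / step  # fake, steadily decreasing loss
    callback(cbargs)
callback.on_training_end(cbargs)
```

With `log_every=100` and 300 steps, the sketch records three entries, one per logging step.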