Logging

Klax logging is built around metrics that are evaluated during training and stored in a History. A metric is any named callable that receives the current TrainingContext. The default MetricLogger callback evaluates the metrics every log_every steps, records their values in the history, and optionally prints them to the console or shows a progress bar. This is the default mechanism used by klax.fit when make_logger=True.

klax.Metric

A Metric computes values that should be recorded in the training history.

Metrics are callables that take the current TrainingContext and return a value, which the MetricLogger adds to the training history. Additionally, metrics have a name and a verbose property that determine how they are logged.
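To make this contract concrete, here is a minimal sketch in plain Python, assuming only what is stated above: a metric is a callable with `name` and `verbose` attributes that receives a context. The `Context` dataclass with its `step` and `loss` fields, the `MetricLike` protocol, and `LossMetric` are hypothetical stand-ins, not klax's TrainingContext or Metric classes.

```python
from dataclasses import dataclass
from typing import Any, Protocol


@dataclass
class Context:
    """Hypothetical stand-in for klax.TrainingContext."""
    step: int
    loss: float


class MetricLike(Protocol):
    """Hypothetical structural type: a named callable mapping a context to a value."""
    name: str
    verbose: bool

    def __call__(self, context: Context) -> Any: ...


@dataclass
class LossMetric:
    """Hypothetical metric that reads the current loss off the context."""
    name: str = "loss"
    verbose: bool = True

    def __call__(self, context: Context) -> Any:
        return context.loss


m: MetricLike = LossMetric()
value = m(Context(step=10, loss=0.25))
print(m.name, value)  # loss 0.25
```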

klax.metric(name, verbose=False)

Turn a function into a Metric using a decorator factory.

Intended usage:

from klax import metric

@metric(name="my_metric", verbose=False)
def compute_my_metric(context): ...

PARAMETER DESCRIPTION
name

Name of the metric.

TYPE: str

verbose

Whether to print the metric's values to the console during training. Defaults to False.

TYPE: bool DEFAULT: False
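Conceptually, a decorator factory like this only needs to attach the name and verbosity to the wrapped function. The sketch below illustrates that pattern; `_FunctionMetric` is a hypothetical wrapper class, and the plain dict used as the context stands in for a TrainingContext, so this is not klax's actual implementation.

```python
from typing import Any, Callable


class _FunctionMetric:
    """Hypothetical wrapper: a plain function plus metric metadata."""

    def __init__(self, func: Callable[[Any], Any], name: str, verbose: bool) -> None:
        self.func = func
        self.name = name
        self.verbose = verbose

    def __call__(self, context: Any) -> Any:
        # Delegate evaluation to the wrapped function.
        return self.func(context)


def metric(name: str, verbose: bool = False) -> Callable[[Callable[[Any], Any]], _FunctionMetric]:
    """Decorator factory: calling metric(...) returns the actual decorator."""
    def decorator(func: Callable[[Any], Any]) -> _FunctionMetric:
        return _FunctionMetric(func, name=name, verbose=verbose)
    return decorator


@metric(name="step_squared", verbose=True)
def step_squared(context: dict) -> int:
    # The context is a plain dict here purely for illustration.
    return context["step"] ** 2


print(step_squared.name, step_squared({"step": 3}))  # step_squared 9
```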

klax.BatchMetric

Metric that evaluates a loss-like function on batches of data.

A BatchMetric uses its own data and batch generator to turn the evaluation of a loss-like function into a Metric.

__init__(name, func, data, batcher, batch_size, batch_axes=0, verbose=False, jit_compile=True, *, key)

Initialize the BatchMetric.

PARAMETER DESCRIPTION
name

Name of the metric.

TYPE: str

func

The evaluation function to compute. It should take the model, a batch of data, and the auxiliary runtime state as input.

TYPE: Callable[[PyTree, PyTree[Any, T], PyTree], Any]

data

The dataset to generate batches from.

TYPE: PyTree[Any, T]

batcher

Batch generator factory.

TYPE: Batcher

batch_size

The size of each batch.

TYPE: int

batch_axes

The axes corresponding to the batch dimension in the data.

TYPE: PyTree[int | None, 'T ...'] DEFAULT: 0

verbose

Whether to print the metric's values to the console during training. Defaults to False.

TYPE: bool DEFAULT: False

jit_compile

If True, func is JIT-compiled with eqx.filter_jit.

TYPE: bool DEFAULT: True

key

PRNG key for random number generation.

TYPE: PRNGKeyArray

__call__(context)

Compute the metric.

PARAMETER DESCRIPTION
context

TrainingContext to evaluate in.

TYPE: TrainingContext

RETURNS DESCRIPTION
Any

The metric value on the sampled batch.
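The following pure-Python sketch mimics the behavior described for BatchMetric: the metric owns its data and its own batch generator, and each call evaluates `func(model, batch, aux)` on the next sampled batch. Everything here is hypothetical: `random_batcher` stands in for a klax `Batcher`, an integer seed replaces the JAX PRNG key, `Context.model` and `Context.aux` stand in for the model and auxiliary runtime state held by the TrainingContext, and JIT compilation is omitted.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence


@dataclass
class Context:
    """Hypothetical stand-in for klax.TrainingContext."""
    model: Callable[[float], float]
    aux: Any = None


def random_batcher(data: Sequence, batch_size: int, seed: int) -> Iterator[list]:
    """Hypothetical batcher: reshuffles each epoch and yields full batches forever."""
    rng = random.Random(seed)
    indices = list(range(len(data)))
    while True:
        rng.shuffle(indices)
        # Drop the last incomplete batch of each epoch.
        for start in range(0, len(indices) - batch_size + 1, batch_size):
            yield [data[i] for i in indices[start:start + batch_size]]


class SimpleBatchMetric:
    """Hypothetical metric that evaluates a loss-like function on its own batches."""

    def __init__(self, name: str, func: Callable, data: Sequence,
                 batch_size: int, seed: int = 0, verbose: bool = False) -> None:
        if not 1 <= batch_size <= len(data):
            raise ValueError("batch_size must be between 1 and len(data)")
        self.name = name
        self.verbose = verbose
        self.func = func
        # The metric keeps its own generator, independent of the training batches.
        self._batches = random_batcher(data, batch_size, seed)

    def __call__(self, context: Context) -> Any:
        batch = next(self._batches)
        return self.func(context.model, batch, context.aux)


def mse(model, batch, aux):
    """Mean squared error over a batch of (x, y) pairs; aux is unused here."""
    return sum((model(x) - y) ** 2 for x, y in batch) / len(batch)


# Data from y = 2x, so a model that doubles its input has zero error.
data = [(float(x), 2.0 * x) for x in range(8)]
val_loss = SimpleBatchMetric("val_loss", mse, data, batch_size=4)
print(val_loss(Context(model=lambda x: 2.0 * x)))  # 0.0
```

Keeping a separate generator means evaluating the metric never consumes or reorders the batches used for training.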


klax.History

Dict-like object for storing a training history with metadata and utility methods.

The training history stores metric values together with the training steps they correspond to, as well as the total training time, the total number of steps, and the final optimizer state. It also provides methods for saving the history to and loading it from disk, plotting metrics, and extending the history.

append(step, key, value)

Add a new value to the history.

PARAMETER DESCRIPTION
step

Training step the value belongs to.

TYPE: int

key

Metric name.

TYPE: str

value

Metric value.

TYPE: Any

extend(other)

Extend this history with the contents of another history.

PARAMETER DESCRIPTION
other

Another History instance to extend with.

TYPE: History

keys()

Get the list of metric names stored in the history.

RETURNS DESCRIPTION
list[str]

A list of metric names.
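The bookkeeping behind append, extend, and keys can be sketched with a dict that maps each metric name to parallel lists of steps and values. `MiniHistory` is a hypothetical simplification: it ignores the metadata (training time, total steps, optimizer state) that klax's History also tracks, and it assumes extend simply appends the other history's entries per metric.

```python
from __future__ import annotations

from typing import Any


class MiniHistory:
    """Hypothetical history: metric name -> (steps, values)."""

    def __init__(self) -> None:
        self._data: dict[str, tuple[list[int], list[Any]]] = {}

    def append(self, step: int, key: str, value: Any) -> None:
        steps, values = self._data.setdefault(key, ([], []))
        steps.append(step)
        values.append(value)

    def extend(self, other: MiniHistory) -> None:
        # Assumption: entries are appended per metric; no step offsetting or metadata merge.
        for key in other.keys():
            steps, values = other[key]
            # Snapshot the pairs so extending a history with itself terminates.
            for step, value in list(zip(steps, values)):
                self.append(step, key, value)

    def keys(self) -> list[str]:
        return list(self._data)

    def __getitem__(self, key: str) -> tuple[list[int], list[Any]]:
        return self._data[key]


h = MiniHistory()
h.append(100, "loss", 0.5)
h.append(200, "loss", 0.3)

later = MiniHistory()
later.append(300, "loss", 0.2)
later.append(300, "val_loss", 0.4)

h.extend(later)
print(h.keys())   # ['loss', 'val_loss']
print(h["loss"])  # ([100, 200, 300], [0.5, 0.3, 0.2])
```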

plot(*keys, ax=None, **kwargs)

Plot stored metrics using matplotlib.

Note

This method requires matplotlib.

PARAMETER DESCRIPTION
keys

Metric names to plot. If empty, all metrics are plotted.

TYPE: str

ax

Matplotlib axes to plot into. If None, a new axes object is created. Defaults to None.

TYPE: Any DEFAULT: None

kwargs

Additional keyword arguments passed to matplotlib's plot.

TYPE: Any

RAISES DESCRIPTION
ImportError

If matplotlib is not installed.

save(path)

Persist the history to disk using pickle.

PARAMETER DESCRIPTION
path

Destination filepath where the history will be stored.

TYPE: str | Path

load(path) classmethod

Restore a history saved with save().

PARAMETER DESCRIPTION
path

Filepath to load the serialized history from.

TYPE: str | Path

RETURNS DESCRIPTION
History

A populated History instance.
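Because save() uses pickle, a save/load round trip can be sketched as below. `PickledHistory` is hypothetical, and it pickles only its internal dict rather than the whole object, which may differ from what klax does. As with any pickle file, only load histories from trusted sources, since unpickling can execute arbitrary code.

```python
from __future__ import annotations

import pickle
import tempfile
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class PickledHistory:
    """Hypothetical history: metric name -> (steps, values)."""
    data: dict = field(default_factory=dict)

    def save(self, path: str | Path) -> None:
        # Serialize only the plain-dict state with pickle.
        with open(path, "wb") as f:
            pickle.dump(self.data, f)

    @classmethod
    def load(cls, path: str | Path) -> PickledHistory:
        # Only unpickle files from trusted sources.
        with open(path, "rb") as f:
            return cls(data=pickle.load(f))


history = PickledHistory(data={"loss": ([100, 200], [0.5, 0.3])})
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "history.pkl"
    history.save(path)
    restored = PickledHistory.load(path)

print(restored == history)  # True
```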


klax.MetricLogger

Callback for logging metrics in a History during training.

__init__(log_every=100, metrics=None, verbose=2, history=None)

Initialize the MetricLogger.

PARAMETER DESCRIPTION
log_every

Frequency of logging metrics (in steps).

TYPE: int DEFAULT: 100

metrics

Sequence of metrics to evaluate. If multiple metrics share the same name, later metrics overwrite earlier ones.

TYPE: Sequence[Metric] | None DEFAULT: None

verbose

Verbosity level for logging metrics to the console. If 0, no metrics are printed. If 1, metric values are printed. If 2, a progress bar is shown.

TYPE: Literal[0, 1, 2] DEFAULT: 2

history

An existing history object to log metrics to. If None, a new history object will be created.

TYPE: History | None DEFAULT: None

add_metric(metric)

Add a metric to be logged during training.

Warning

Existing metrics sharing the same name will be overwritten.

PARAMETER DESCRIPTION
metric

The metric to be added.

TYPE: Metric
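A simplified version of this callback's logic is sketched below. It assumes that metrics are stored in a dict keyed by name (which yields the overwrite behavior described for metrics and add_metric) and that values are recorded whenever the step is a multiple of log_every. `Named`, `MiniMetricLogger`, its `on_step` hook, and the dict-based history are hypothetical; the progress bar for verbose=2 is omitted, and the interplay between the logger's and each metric's verbose flag is an assumption.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any, Callable, Sequence


class Named:
    """Hypothetical minimal metric: a named callable with a verbose flag."""

    def __init__(self, name: str, fn: Callable[[Any], Any], verbose: bool = True) -> None:
        self.name, self.fn, self.verbose = name, fn, verbose

    def __call__(self, context: Any) -> Any:
        return self.fn(context)


class MiniMetricLogger:
    """Hypothetical, simplified logging callback (no progress bar)."""

    def __init__(self, log_every: int = 100, metrics: Sequence[Any] | None = None,
                 verbose: int = 1, history: dict | None = None) -> None:
        self.log_every = log_every
        self.verbose = verbose
        self.metrics: dict[str, Any] = {}
        for m in metrics or []:
            self.add_metric(m)
        self.history = history if history is not None else {}

    def add_metric(self, metric: Any) -> None:
        # Keyed by name, so a metric with an existing name replaces the old one.
        self.metrics[metric.name] = metric

    def on_step(self, context: Any) -> None:
        # Only evaluate metrics every `log_every` steps.
        if context.step % self.log_every != 0:
            return
        printed = []
        for name, metric in self.metrics.items():
            value = metric(context)
            steps, values = self.history.setdefault(name, ([], []))
            steps.append(context.step)
            values.append(value)
            if metric.verbose:
                printed.append(f"{name}={value}")
        # Assumption: verbose=1 prints the verbose metrics; verbose=0 stays silent.
        if self.verbose == 1 and printed:
            print(f"step {context.step}: " + ", ".join(printed))


logger = MiniMetricLogger(
    log_every=2,
    metrics=[Named("loss", lambda c: c.loss), Named("loss", lambda c: 2 * c.loss)],
)
logger.add_metric(Named("step", lambda c: c.step, verbose=False))
for step in range(1, 7):
    logger.on_step(SimpleNamespace(step=step, loss=10 * step))
print(logger.history["loss"])  # ([2, 4, 6], [40, 80, 120])
```

Because the second metric named "loss" replaces the first, the recorded values are twice the context's loss.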