Calibration and Data Handling

klax.batch_data(data, batch_size, batch_axes=0, convert_to_numpy=True, *, key) ¤

Create a Generator that draws subsets of data without replacement.

The data can be any PyTree with ArrayLike leaves. If batch_axes is passed, batch axes (including None for no batching) can be specified for every leaf individually. A generator is returned that indefinitely yields batches of data with size batch_size. Examples are drawn without replacement until the remaining dataset is smaller than batch_size, at which point the dataset will be reshuffled and the process starts over.

Example

This is an example for a nested PyTree, where the elements x and y have batch dimension along the first axis.

>>> import klax
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> x = jnp.array([1., 2.])
>>> y = jnp.array([[1.], [2.]])
>>> data = (x, {"a": 1.0, "b": y})
>>> batch_axes = (0, {"a": None, "b": 0})
>>> batch = klax.batch_data(
...     data,
...     32,
...     batch_axes,
...     key=jax.random.key(0)
... )
PARAMETER DESCRIPTION
data

The data that shall be batched. It can be any PyTree with ArrayLike leaves.

TYPE: PyTree[Any, T]

batch_size

The number of examples in a batch.

TYPE: int

batch_axes

PyTree of the batch axis indices. None is used to indicate that the corresponding leaf or subtree in data does not have a batch axis. batch_axes must have the same structure as data or be a prefix of data. (Defaults to 0, meaning all leaves in data are batched along their first dimension.)

TYPE: PyTree[None | int] DEFAULT: 0

convert_to_numpy

If True, batched data leaves will be converted to NumPy arrays before batching. This is useful for performance reasons, as NumPy's slicing is much faster than JAX's.

TYPE: bool DEFAULT: True

key

A jax.random.PRNGKey used to provide randomness for batch generation. (Keyword only argument.)

TYPE: PRNGKeyArray

RETURNS DESCRIPTION
None

A Generator that yields a random batch of data.

YIELDS DESCRIPTION
PyTree[Any, T]

A PyTree[ArrayLike] with the same structure as data, in which all batched leaves have size batch_size along their batch axis.

Note

If the dataset contains fewer examples than batch_size, the batch size that is actually used is reduced.
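
The sampling scheme described for batch_data can be pictured with a small NumPy-only sketch for a single array. This is an illustration of the documented behavior, not klax's implementation: the name sketch_batches, the seed handling, and the choice to cap the batch size at the dataset size are assumptions.

```python
import numpy as np


def sketch_batches(data, batch_size, seed=0):
    """Yield batches drawn without replacement, reshuffling when too few rows remain."""
    rng = np.random.default_rng(seed)
    n = data.shape[0]
    # Assumption: a dataset smaller than batch_size reduces the batch size to n.
    batch_size = min(batch_size, n)
    while True:
        perm = rng.permutation(n)
        # Consume the permutation in full batches; leftover rows are skipped
        # and the data is reshuffled on the next pass.
        for start in range(0, n - batch_size + 1, batch_size):
            yield data[perm[start:start + batch_size]]


data = np.arange(10.0)
batches = sketch_batches(data, batch_size=4)
first, second, third = (next(batches) for _ in range(3))
```

With 10 examples and a batch size of 4, the first two batches come from one shuffle and share no examples; the third batch is drawn after a reshuffle.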


klax.split_data(data, proportions, batch_axes=0, *, key) ¤

Split a PyTree of data into multiple randomly drawn subsets.

This function is useful for splitting into training and test datasets. The axis of the split is controlled by the batch_axes argument, which specifies the batch axis for each leaf in data.

Example

This is an example for a nested PyTree of data.

>>> import klax
>>> import jax
>>>
>>> x = jax.numpy.array([1., 2., 3.])
>>> data = (x, {"a": 1.0, "b": x})
>>> s1, s2 = klax.split_data(
...     data,
...     (2, 1),
...     key=jax.random.key(0)
... )
>>> s1
(Array([1., 2.], dtype=float32), {'a': 1.0, 'b': Array([1., 2.], dtype=float32)})
>>> s2
(Array([3.], dtype=float32), {'a': 1.0, 'b': Array([3.], dtype=float32)})
PARAMETER DESCRIPTION
data

Data that shall be split. It can be any PyTree.

TYPE: PyTree[Any]

proportions

Proportions of the split that will be applied to the data, e.g., (80, 20) for an 80% to 20% split. The proportions must be non-negative.

TYPE: Sequence[int | float]

batch_axes

PyTree of the batch axis indices. None is used to indicate that the corresponding leaf or subtree in data does not have a batch axis. batch_axes must have the same structure as data or be a prefix of data. (Defaults to 0.)

TYPE: PyTree[None | int] DEFAULT: 0

key

A jax.random.PRNGKey used to provide randomness to the split. (Keyword only argument.)

TYPE: PRNGKeyArray

RETURNS DESCRIPTION
tuple[PyTree[Any], ...]

Tuple of PyTrees.
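
How proportions might translate into subset sizes can be illustrated for a single array. The sketch below is not klax's implementation: the helper name sketch_split and the rounding rule (cumulative proportions rounded to the nearest index) are assumptions, and klax may round differently.

```python
import numpy as np


def sketch_split(data, proportions, seed=0):
    """Split the first axis of `data` into randomly drawn, disjoint subsets."""
    props = np.asarray(proportions, dtype=float)
    if np.any(props < 0) or props.sum() == 0:
        raise ValueError("proportions must be non-negative and not all zero")
    n = data.shape[0]
    # Assumption: split points are the rounded cumulative proportions.
    bounds = np.round(np.cumsum(props) / props.sum() * n).astype(int)
    perm = np.random.default_rng(seed).permutation(n)
    # np.split cuts the shuffled indices at every boundary except the last (== n).
    return tuple(data[idx] for idx in np.split(perm, bounds[:-1]))


x = np.arange(10.0)
train, test = sketch_split(x, (80, 20))
```

Here train holds 8 of the 10 examples and test the remaining 2; together they cover the dataset exactly once.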

klax.fit(model, data, *, batch_size=32, batch_axes=0, run_state=None, validation_data=None, steps=1000, loss=mse, optimizer=optax.adam(1e-3), init_opt_state=None, batcher=batch_data, make_logger=True, metrics=None, log_every=100, verbose=2, callbacks=None, key) ¤

Train a model using an optimizer from optax.

This is a convenient wrapper around run_training_loop that sets up optimizer, training state and callbacks.

PARAMETER DESCRIPTION
model

The model instance to be trained. Its class must be a subclass of equinox.Module. The model may contain klax.Unwrappable wrappers.

TYPE: T

data

The training data can be any PyTree with at least some ArrayLike leaves. Most likely you'll want data to be a tuple (x, y) with model inputs x and model outputs y.

TYPE: PyTree[Any]

batch_size

The number of examples in a batch.

TYPE: int DEFAULT: 32

batch_axes

A PyTree denoting which axis is the batch axis for arrays in data. batch_axes must be a prefix of data. By specifying batch_axes as a PyTree it is possible to specify different batch axes for different leaves of data. (Defaults to 0, meaning the first axes of arrays in data are batch dimensions.)

Example: For a dataset of 100 examples data = (x, (y1, y2), "some_string") where x has shape (100, 32), y1 has shape (100,) and y2 has shape (10, 100), the appropriate batch_axes would be batch_axes = (0, (0, 1), None), indicating that the batch axis for x is the first axis (0), for y1 also the first axis (0), for y2 the second axis (1), and that the string has no batch axis (None).

TYPE: PyTree[None | int] DEFAULT: 0
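
The effect of per-leaf batch axes can be sketched with plain NumPy. The helper take_batch below is hypothetical and only handles tuples and lists, whereas klax works on general PyTrees; it is meant to show how one set of example indices is applied along a different axis for each leaf.

```python
import numpy as np


def take_batch(data, batch_axes, idx):
    """Select examples `idx` from every leaf along that leaf's own batch axis."""
    if isinstance(data, (tuple, list)):
        # An int or None given for a whole subtree applies to all of its leaves.
        if not isinstance(batch_axes, (tuple, list)):
            batch_axes = [batch_axes] * len(data)
        return type(data)(take_batch(d, a, idx) for d, a in zip(data, batch_axes))
    if batch_axes is None:
        return data  # no batch axis: the leaf is passed through unchanged
    return np.take(data, idx, axis=batch_axes)


x, y1, y2 = np.zeros((100, 32)), np.zeros(100), np.zeros((10, 100))
data = (x, (y1, y2), "some_string")
batch_axes = (0, (0, 1), None)
bx, (by1, by2), label = take_batch(data, batch_axes, np.arange(8))
```

For a batch of 8 examples, bx has shape (8, 32), by1 has shape (8,), by2 has shape (10, 8), and the string is returned unchanged.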

run_state

Auxiliary runtime state that is passed to the loss function. Can be updated via callbacks. Defaults to None.

TYPE: PyTree[Any] DEFAULT: None

validation_data

Arbitrary PyTree used for validation during training. Must have the same tree structure as data. Internally, the validation data is used to create a BatchMetric for logging. Each time the metric is evaluated, the loss is computed on a batch from the validation dataset with batch size 4*batch_size. Defaults to None.

TYPE: PyTree[Any] DEFAULT: None

steps

Number of gradient updates to apply. Defaults to 1000.

TYPE: int DEFAULT: 1000

loss

The loss function with call signature (model: PyTree, data: PyTree, run_state: PyTree) -> float. Defaults to mse.

TYPE: Loss DEFAULT: mse
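
A custom loss only has to follow the call signature (model, data, run_state) given above. The sketch below is a hypothetical weighted mean squared error. It uses array methods only, so it runs on NumPy arrays here, whereas in training the arguments would be JAX arrays. The "weight" entry in run_state and the assumption that model accepts a whole batch at once are illustrative and not part of klax.

```python
import numpy as np


def weighted_mse(model, data, run_state):
    """Mean squared error scaled by a weight stored in run_state."""
    x, y = data
    # Assumption: `model` accepts a whole batch at once.
    y_pred = model(x)
    # Assumption: run_state is a dict with a "weight" entry (hypothetical layout).
    return run_state["weight"] * ((y_pred - y) ** 2).mean()


def double(x):
    # Stand-in for a trained model.
    return 2.0 * x


x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 7.0])
value = weighted_mse(double, (x, y), {"weight": 0.5})
```

Because run_state is passed on every call, a value such as the weight could be changed between steps, which is what the documentation means by run_state being updatable via callbacks.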

optimizer

The optimizer: any optax gradient transformation used to compute the updates for the model. Defaults to optax.adam(1e-3).

TYPE: GradientTransformation | GradientTransformationExtraArgs DEFAULT: optax.adam(1e-3)

init_opt_state

The initial state of the optimizer. If None, the optimizer is initialized from scratch. By providing a value for init_opt_state, the user can resume training from a previous state (e.g., obtained from the HistoryCallback.last_opt_state). Defaults to None.

TYPE: PyTree[Any] DEFAULT: None

batcher

The data loader that splits inputs and targets into batches. Defaults to batch_data.

TYPE: Batcher DEFAULT: batch_data

make_logger

Whether to create a MetricLogger. If False, the arguments metrics, log_every and verbose have no effect, and fit will return None instead of a History. This is useful for implementing custom logging.

TYPE: bool DEFAULT: True

metrics

Sequence of metrics to be evaluated at regular intervals during training. You can override the default "loss" and "validation_loss" metrics by adding custom metrics with the same name. Defaults to None.

TYPE: Sequence[Metric] | None DEFAULT: None

log_every

Interval for both metric evaluation and progress logging. A value log_every=n means that every n steps during the training the metrics are evaluated and (if verbose>0) the training progress and selected metric values are printed.

TYPE: int DEFAULT: 100

verbose

Integer controlling the verbosity during training. Defaults to 2.

  • 0: Nothing is printed.
  • 1: A message is printed every log_every steps.
  • 2: A progress bar is used and updated every log_every steps.

TYPE: Literal[0, 1, 2] DEFAULT: 2

callbacks

List of Callbacks. They can be used to implement early stopping, custom logging and more. The argument to the callback function is a CallbackArgs object. Defaults to None.

TYPE: Sequence[Callback] | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for batch generation.

TYPE: PRNGKeyArray

RETURNS DESCRIPTION
tuple[T, History | None]

A tuple of the trained model and the training history.

Note

The returned history will be None if make_logger=False.


klax.run_training_loop(context, callbacks) ¤

klax.make_step(state_leaves, state_treedef, batch, loss, optimizer) ¤


klax.TrainingState ¤

Dataclass of things that are expected to change during training.

This consists of:

  • model: The eqx.Module (or more generally any PyTree) representing the trainable model.
  • opt_state: The state of the optax optimizer.
  • run_state: The user-defined state of the training run, which is passed to the loss function and may be modified via callbacks.
  • step: The current optimization step count of the training.

klax.TrainingContext ¤

Collection of all training relevant objects.

This includes:

  • state: The TrainingState
  • optimizer: The optax optimizer
  • loss: The Loss function
  • batch_generator: The generator object responsible for creating data batches
  • steps: The total number of scheduled optimization steps for the training run
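
How the two containers fit together can be sketched with plain dataclasses. SketchState, SketchContext and sketch_loop are hypothetical stand-ins that only mirror the field names listed above; they are not klax's class definitions or its training loop, and the optimizer update is left out.

```python
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Iterator


@dataclass
class SketchState:
    """Hypothetical mirror of the TrainingState fields listed above."""
    model: Any
    opt_state: Any
    run_state: Any
    step: int = 0


@dataclass
class SketchContext:
    """Hypothetical mirror of the TrainingContext fields listed above."""
    state: SketchState
    optimizer: Any
    loss: Callable[[Any, Any, Any], float]
    batch_generator: Iterator[Any]
    steps: int


def sketch_loop(context):
    """Toy loop: evaluate the loss on each batch and advance the step counter."""
    history = []
    for _ in range(context.steps):
        batch = next(context.batch_generator)
        state = context.state
        history.append(context.loss(state.model, batch, state.run_state))
        # A real loop would apply an optimizer update to state.model here.
        state.step += 1
    return history


context = SketchContext(
    state=SketchState(model=lambda x: 2 * x, opt_state=None, run_state=None),
    optimizer=None,
    loss=lambda model, batch, run_state: abs(model(batch[0]) - batch[1]),
    batch_generator=itertools.repeat((1.0, 3.0)),
    steps=5,
)
history = sketch_loop(context)
```

The split mirrors the descriptions above: everything that changes during training lives in the state, while the context holds the fixed pieces (loss, optimizer, batch source, step budget) around it.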