
Calibration and Data Handling

klax.batch_data(data, batch_size=32, batch_axis=0, convert_to_numpy=True, *, key)

Create a Generator that draws subsets of data without replacement.

The data can be any PyTree with ArrayLike leaves. Through batch_axis, the batch axis (including None for no batching) can be specified individually for every leaf. The returned generator indefinitely yields batches of data with size batch_size. Examples are drawn without replacement until the remaining dataset is smaller than batch_size, at which point the dataset is reshuffled and the process starts over.
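The drawing scheme described above can be sketched for a single array of n examples. This is a minimal illustration under stated assumptions, not klax's implementation: the function batch_indices_sketch is hypothetical, and it uses a NumPy random Generator instead of a JAX key to keep the example short. Each pass hands out indices from one permutation until fewer than batch_size unused examples remain, and then a new permutation starts.

```python
import numpy as np


def batch_indices_sketch(n, batch_size, seed=0):
    """Yield index batches forever, drawing without replacement (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # A dataset smaller than batch_size yields batches of dataset size.
    batch_size = min(batch_size, n)
    while True:
        perm = rng.permutation(n)  # reshuffle at the start of every pass
        start = 0
        # End the pass once fewer than batch_size unused examples remain.
        while start + batch_size <= n:
            yield perm[start:start + batch_size]
            start += batch_size


gen = batch_indices_sketch(10, 3)
first_pass = [next(gen) for _ in range(3)]  # 9 distinct indices; 1 example is left over
```

With n=10 and batch_size=3, the fourth batch already comes from a fresh permutation, so the leftover example of the first pass is simply not used in that pass.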

Example

This example batches a nested PyTree in which the leaves x and y have their batch dimension along the first axis, while the scalar leaf "a" is not batched.

>>> import klax
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> x = jnp.array([1., 2.])
>>> y = jnp.array([[1.], [2.]])
>>> data = (x, {"a": 1.0, "b": y})
>>> batch_axis = (0, {"a": None, "b": 0})
>>> iter_data = klax.batch_data(
...     data,
...     batch_size=32,
...     batch_axis=batch_axis,
...     key=jax.random.key(0)
... )
>>> batch = next(iter_data)
>>>
PARAMETER DESCRIPTION
data

The data that shall be batched. It can be any PyTree with ArrayLike leaves.

TYPE: PyTree[Any]

batch_size

The number of examples in a batch.

TYPE: int DEFAULT: 32

batch_axis

PyTree of the batch axis indices. None indicates that the corresponding leaf or subtree in data has no batch axis. batch_axis must have the same structure as data or be a prefix of data's structure. (Defaults to 0, meaning all leaves in data are batched along their first dimension.)

TYPE: PyTree[None | int] DEFAULT: 0

convert_to_numpy

If True, batched data leaves are converted to NumPy arrays before batching. This can improve performance, because slicing NumPy arrays on the host is typically much faster than indexing JAX arrays.

TYPE: bool DEFAULT: True

key

A jax.random.PRNGKey used to provide randomness for batch generation. (Keyword-only argument.)

TYPE: PRNGKeyArray
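The prefix semantics of batch_axis can be illustrated by gathering one batch for a given index array. The helper take_batch below is hypothetical and not part of klax. It relies on jax.tree_util.tree_map accepting additional trees that have the first tree as a prefix, and it passes is_leaf so that None entries in batch_axis count as leaves instead of empty subtrees.

```python
import jax
import numpy as np


def take_batch(data, batch_axis, idx):
    """Select examples idx along each leaf's batch axis (hypothetical helper)."""

    def take_subtree(axis, subtree):
        if axis is None:
            return subtree  # no batch axis: the subtree passes through unchanged
        # Every leaf below this point of the prefix shares the same batch axis.
        return jax.tree_util.tree_map(lambda leaf: np.take(leaf, idx, axis=axis), subtree)

    # batch_axis may be a prefix of data; is_leaf keeps None as a leaf of batch_axis.
    return jax.tree_util.tree_map(take_subtree, batch_axis, data, is_leaf=lambda a: a is None)


x = np.array([1.0, 2.0])
y = np.array([[1.0], [2.0]])
data = (x, {"a": 1.0, "b": y})
batch = take_batch(data, (0, {"a": None, "b": 0}), np.array([1]))
prefix_batch = take_batch({"u": x, "v": y}, 0, np.array([0]))  # 0 applies to all leaves
```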

RETURNS DESCRIPTION
Generator[PyTree[Any], None, None]

A Generator that yields a random batch of data every time it is advanced with next().

YIELDS DESCRIPTION
PyTree[Any]

A PyTree[ArrayLike] with the same structure as data, in which all batched leaves have size batch_size along their batch axis.

Note

If the dataset is smaller than batch_size, the batched leaves of each yielded batch have the size of the dataset instead of batch_size.
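The conversion controlled by convert_to_numpy can be sketched with jax.tree_util.tree_map and np.asarray. The helper to_numpy_leaves is hypothetical, and unlike the description of convert_to_numpy it converts every jax.Array leaf rather than only the batched ones; it only illustrates moving leaves to host NumPy arrays so that the subsequent fancy indexing runs in NumPy.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_numpy_leaves(data):
    """Copy every jax.Array leaf of data to a host NumPy array (hypothetical helper)."""
    # Non-array leaves such as Python floats are returned unchanged.
    return jax.tree_util.tree_map(
        lambda leaf: np.asarray(leaf) if isinstance(leaf, jax.Array) else leaf, data
    )


data = (jnp.arange(6.0).reshape(3, 2), {"a": 1.0, "b": jnp.ones(3)})
np_data = to_numpy_leaves(data)
idx = np.array([2, 0])
batch_x = np_data[0][idx]  # fancy indexing now runs in NumPy on the host
```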

klax.split_data(data, proportions, batch_axis=0, *, key)

Split a PyTree of data into multiple randomly drawn subsets.

This function is useful for splitting data into training and test datasets. The axis along which each leaf is split is controlled by the batch_axis argument, which specifies the batch axis for each leaf in data.

Example

This example splits a nested PyTree of data.

>>> import klax
>>> import jax
>>>
>>> x = jax.numpy.array([1., 2., 3.])
>>> data = (x, {"a": 1.0, "b": x})
>>> s1, s2 = klax.split_data(
...     data,
...     (2, 1),
...     key=jax.random.key(0)
... )
>>> s1
(Array([1., 2.], dtype=float32), {'a': 1.0, 'b': Array([1., 2.], dtype=float32)})
>>> s2
(Array([3.], dtype=float32), {'a': 1.0, 'b': Array([3.], dtype=float32)})
PARAMETER DESCRIPTION
data

Data that shall be split. It can be any PyTree.

TYPE: PyTree[Any]

proportions

Proportions of the split that will be applied to the data, e.g., (80, 20) for an 80% to 20% split. The proportions must be non-negative.

TYPE: Sequence[int | float]

batch_axis

PyTree of the batch axis indices. None indicates that the corresponding leaf or subtree in data has no batch axis. batch_axis must have the same structure as data or be a prefix of data's structure. (Defaults to 0.)

TYPE: PyTree[None | int] DEFAULT: 0

key

A jax.random.PRNGKey used to provide randomness to the split. (Keyword-only argument.)

TYPE: PRNGKeyArray
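One way to turn relative proportions into subset sizes is to place the split points at the rounded cumulative proportions of a random permutation, which guarantees that the sizes add up to the dataset size. The sketch below is a hypothetical helper, split_indices_sketch, for a single batch axis; the rounding rule is an assumption and klax may distribute remainders differently. The same index sets would then be applied along each leaf's batch axis.

```python
import numpy as np


def split_indices_sketch(n, proportions, seed=0):
    """Randomly partition range(n) according to relative proportions (hypothetical helper)."""
    p = np.asarray(proportions, dtype=float)
    if np.any(p < 0):
        raise ValueError("proportions must be non-negative")
    # Cumulative split points, normalized by the total weight and rounded to whole examples.
    bounds = np.round(np.cumsum(p) / p.sum() * n).astype(int)
    perm = np.random.default_rng(seed).permutation(n)
    return np.split(perm, bounds[:-1])


x = np.array([1.0, 2.0, 3.0])
train_idx, test_idx = split_indices_sketch(len(x), (2, 1))
train, test = x[train_idx], x[test_idx]
```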

RETURNS DESCRIPTION
tuple[PyTree[Any], ...]

A tuple of PyTrees, one per entry in proportions, each with the same structure as data.

klax.fit(model, data, *, batch_size=32, batch_axis=0, validation_data=None, steps=1000, loss_fn=mse, optimizer=optax.adam(1e-3), init_opt_state=None, batcher=batch_data, history=None, callbacks=None, key)

Train a model using an optax optimizer.
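The procedure can be pictured as the loop below: draw a shuffled mini-batch, compute the loss and its gradient, update the parameters, and record the loss at a fixed interval. This is a minimal sketch under stated assumptions, not klax's implementation: fit_sketch and linear_mse are hypothetical, the model is a plain dict of parameters rather than an equinox.Module, and a plain gradient-descent step replaces the optax optimizer to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp
import numpy as np


def linear_mse(params, x, y):
    """Mean squared error of a hypothetical linear model y = w * x + b."""
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


def fit_sketch(params, x, y, *, steps=1000, batch_size=32, lr=0.1, log_every=100, seed=0):
    """Hypothetical training loop: plain SGD on shuffled mini-batches."""
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    batch_size = min(batch_size, n)
    loss_and_grad = jax.jit(jax.value_and_grad(linear_mse))
    history = []
    step = 0
    while step < steps:
        perm = rng.permutation(n)  # new shuffle for every pass over the data
        for start in range(0, n - batch_size + 1, batch_size):
            if step == steps:
                break
            idx = perm[start:start + batch_size]
            loss, grads = loss_and_grad(params, x[idx], y[idx])
            # Plain gradient-descent update in place of an optax transformation.
            params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
            step += 1
            if step % log_every == 0:
                history.append(float(loss))
    return params, history


x = np.linspace(-1.0, 1.0, 64, dtype=np.float32)
y = 2.0 * x + 1.0
params, history = fit_sketch({"w": jnp.array(0.0), "b": jnp.array(0.0)}, x, y, steps=500)
```

In fit itself, the update would instead come from the optax optimizer, the batches from batcher, and the logging from the history callback.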

PARAMETER DESCRIPTION
model

The model instance to train. It must be an instance of an equinox.Module subclass. The model may contain klax.Unwrappable wrappers.

TYPE: T

data

The training data can be any PyTree with ArrayLike leaves. Most likely you'll want data to be a tuple (x, y) with model inputs x and model outputs y.

TYPE: PyTree[Any]

batch_size

The number of examples in a batch.

TYPE: int DEFAULT: 32

batch_axis

A PyTree specifying which axis is the batch axis for the arrays in data. batch_axis must be a prefix of data, so different batch axes can be specified for different leaves of data. (Defaults to 0, meaning the first axis of every array in data is the batch dimension.)

TYPE: PyTree[None | int] DEFAULT: 0

validation_data

Arbitrary PyTree used for validation during training. Must have the same tree structure as data. (Defaults to None.)

TYPE: PyTree[Any] DEFAULT: None

steps

Number of gradient updates to apply. (Defaults to 1000.)

TYPE: int DEFAULT: 1000

loss_fn

The loss function with call signature (model: PyTree, data: PyTree, batch_axis: int | None | Sequence[Any]) -> float. (Defaults to mse.)

TYPE: klax.Loss DEFAULT: mse

optimizer

The optimizer: any optax gradient transformation used to compute the updates for the model. (Defaults to optax.adam(1e-3).)

TYPE: optax.GradientTransformation DEFAULT: optax.adam(1e-3)

init_opt_state

The initial state of the optimizer. If None, the optimizer is initialized from scratch. By providing a value for init_opt_state, the user can resume training from a previous state (e.g., obtained from the HistoryCallback.last_opt_state). (Defaults to None.)

TYPE: PyTree[Any] DEFAULT: None

batcher

The data loader that splits inputs and targets into batches. (Defaults to batch_data.)

TYPE: klax._datahandler.BatchGenerator DEFAULT: batch_data

history

A callback intended for tracking the training process. If no custom callback is passed, klax.HistoryCallback with a logging interval of 100 steps is used. To change the logging interval or verbosity of this default callback, pass a HistoryCallback object to this argument, e.g., history=HistoryCallback(log_every=10, verbose=False) for logging every 10th step without printing the loss.

TYPE: klax.HistoryCallback | H | None DEFAULT: None

callbacks

Callback functions that are evaluated after every training step. They can be used to implement early stopping, custom history logging and more. The argument passed to each callback is a CallbackArgs object. (Defaults to None. Keyword-only argument.)

TYPE: Iterable[klax.Callback] | None DEFAULT: None

key

A jax.random.PRNGKey used to provide randomness for batch generation. (Keyword-only argument.)

TYPE: PRNGKeyArray
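A custom loss_fn only needs to follow the call signature described above. The sketch below is a hypothetical mean squared error, mse_sketch, for data = (x, y) that vmaps the model over the batch axis. It is not klax's built-in mse, it assumes integer batch axes, and whether fit accepts a plain function in place of a klax.Loss object is an assumption.

```python
import jax
import jax.numpy as jnp


def mse_sketch(model, data, batch_axis=0):
    """Hypothetical mean squared error for data = (x, y) with integer batch axes."""
    x, y = data
    # Accept one int for both elements or a per-element tuple such as (0, 0).
    x_axis, y_axis = batch_axis if isinstance(batch_axis, tuple) else (batch_axis, batch_axis)
    y_pred = jax.vmap(model, in_axes=x_axis)(x)  # vmap stacks predictions along axis 0
    y_true = jnp.moveaxis(y, y_axis, 0)  # move the targets' batch axis to the front too
    return jnp.mean((y_pred - y_true) ** 2)


model = lambda v: 2.0 * v  # stand-in for a model that maps one example to one prediction
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 7.0])
loss = mse_sketch(model, (x, y), 0)
```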

Note

With the default batch_axis=0, this function assumes that the batch dimension is the first axis of every jax.Array in data.

RETURNS DESCRIPTION
tuple[T, klax.HistoryCallback | H]

A tuple of the trained model and the loss history.