Calibration and Data Handling
klax.batch_data(data, batch_size, batch_axes=0, convert_to_numpy=True, *, key)
Create a Generator that draws subsets of data without replacement.
The data can be any PyTree with ArrayLike leaves. If batch_axes is
passed, batch axes (including None for no batching) can be specified for
every leaf individually.
A generator is returned that indefinitely yields batches of data with size
batch_size. Examples are drawn without replacement until the remaining
dataset is smaller than batch_size, at which point the dataset will be
reshuffled and the process starts over.
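The sampling scheme just described can be sketched in plain Python. This is a minimal illustration, not klax's implementation. The hypothetical batch_indices generator yields lists of example indices rather than batched PyTrees. It uses Python's random module with an integer seed in place of a JAX PRNG key, and it caps the batch size at the dataset size, as the note for this function describes.

```python
import random
from typing import Iterator, List


def batch_indices(num_examples: int, batch_size: int, seed: int = 0) -> Iterator[List[int]]:
    """Yield batches of example indices indefinitely (sketch, not klax code)."""
    if num_examples <= 0:
        raise ValueError("the dataset must contain at least one example")
    rng = random.Random(seed)  # stands in for a JAX PRNG key
    # Cap the batch size at the dataset size, as the note on batch_data describes.
    batch_size = min(batch_size, num_examples)
    while True:
        # Start a new pass with a fresh random permutation of all indices.
        perm = list(range(num_examples))
        rng.shuffle(perm)
        start = 0
        # Draw full batches without replacement while enough examples remain.
        while start + batch_size <= num_examples:
            yield perm[start:start + batch_size]
            start += batch_size
        # Fewer than batch_size examples are left: reshuffle and start over.
```

For instance, with 10 examples and batch_size=3, each pass yields three disjoint batches. The one leftover example is then discarded, and the indices are reshuffled.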
Example
This is an example for a nested PyTree, where the elements x and y
have their batch dimension along the first axis.
>>> import klax
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> x = jnp.array([1., 2.])
>>> y = jnp.array([[1.], [2.]])
>>> data = (x, {"a": 1.0, "b": y})
>>> batch_axes = (0, {"a": None, "b": 0})
>>> batch = klax.batch_data(
... data,
... 32,
... batch_axes,
... key=jax.random.key(0)
... )
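The per-leaf batch_axes tree from the example can be illustrated with a small pure-Python sketch. The helper take_batch is hypothetical and is not part of klax. It walks a nested tuple/dict tree alongside its batch_axes tree, selects the given example indices from leaves with axis 0, and passes leaves with axis None through unchanged. It assumes leaves are plain Python lists batched along axis 0. It also assumes that a single axis value can stand in for a whole subtree, as the default batch_axes=0 suggests.

```python
from typing import Any, Sequence


def take_batch(data: Any, batch_axes: Any, idx: Sequence[int]) -> Any:
    """Select examples idx from every leaf of data (hypothetical sketch)."""
    # batch_axes mirrors the structure of data: recurse into both together.
    if isinstance(batch_axes, dict):
        return {k: take_batch(data[k], batch_axes[k], idx) for k in data}
    if isinstance(batch_axes, tuple):
        return tuple(take_batch(d, a, idx) for d, a in zip(data, batch_axes))
    # batch_axes is a single value (0 or None) applied to the whole subtree.
    if isinstance(data, dict):
        return {k: take_batch(v, batch_axes, idx) for k, v in data.items()}
    if isinstance(data, tuple):
        return tuple(take_batch(v, batch_axes, idx) for v in data)
    if batch_axes is None:
        return data  # unbatched leaf: returned as-is
    if batch_axes == 0:
        return [data[i] for i in idx]  # list leaf batched along axis 0
    raise ValueError(f"unsupported batch axis in this sketch: {batch_axes!r}")
```

With the data from the example written as lists, take_batch(data, batch_axes, [1]) keeps "a" as 1.0 and picks the second entry of x and y.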
| PARAMETER | DESCRIPTION |
|---|---|
| data | The data that shall be batched. It can be any PyTree with ArrayLike leaves. |
| batch_size | The number of examples in a batch. |
| batch_axes | PyTree of the batch axis indices. |
| convert_to_numpy | If … |
| key | A … |
| RETURNS | DESCRIPTION |
|---|---|
| None | A … |

| YIELDS | DESCRIPTION |
|---|---|
| PyTree[Any, T] | A … |
Note
If the dataset contains fewer examples than batch_size, the batch size
actually used is reduced accordingly.
klax.split_data(data, proportions, batch_axes=0, *, key)
Split a PyTree of data into multiple randomly drawn subsets.
This function is useful for splitting data into training and test datasets.
The axis of the split is controlled by the batch_axes argument, which
specifies the batch axis for each leaf in data.
Example
This is an example for a nested PyTree of data.
>>> import klax
>>> import jax
>>>
>>> x = jax.numpy.array([1., 2., 3.])
>>> data = (x, {"a": 1.0, "b": x})
>>> s1, s2 = klax.split_data(
... data,
... (2, 1),
... key=jax.random.key(0)
... )
>>> s1
(Array([1., 2.], dtype=float32), {'a': 1.0, 'b': Array([1., 2.], dtype=float32)})
>>> s2
(Array([3.], dtype=float32), {'a': 1.0, 'b': Array([3.], dtype=float32)})
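The idea behind the split can be approximated with the following pure-Python sketch, which partitions example indices rather than the PyTree itself. The helper split_indices is hypothetical. Its rounding rule is also an assumption: each share is floored and the remainder goes to the last subset. klax's exact rounding and ordering are not documented here, so the subsets will generally differ from the output shown above.

```python
import random
from typing import List, Sequence


def split_indices(num_examples: int, proportions: Sequence[float], seed: int = 0) -> List[List[int]]:
    """Randomly partition range(num_examples) by proportions (hypothetical sketch)."""
    total = float(sum(proportions))
    # Assumed rounding rule: floor each share, give the remainder to the last subset.
    sizes = [int(num_examples * p / total) for p in proportions]
    sizes[-1] += num_examples - sum(sizes)
    # Shuffle once, then cut the permutation into consecutive, disjoint chunks.
    perm = list(range(num_examples))
    random.Random(seed).shuffle(perm)
    subsets, start = [], 0
    for size in sizes:
        subsets.append(perm[start:start + size])
        start += size
    return subsets
```

With proportions (2, 1) and three examples, as in the example above, the sketch returns subsets of sizes 2 and 1 that together cover every index exactly once.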
| PARAMETER | DESCRIPTION |
|---|---|
| data | Data that shall be split. It can be any PyTree. |
| proportions | Proportions of the split that will be applied to the data, e.g., … |
| batch_axes | PyTree of the batch axis indices. |
| key | A … |
| RETURNS | DESCRIPTION |
|---|---|
| tuple[PyTree[Any], ...] | Tuple of … |
klax.fit(model, data, *, batch_size=32, batch_axes=0, run_state=None, validation_data=None, steps=1000, loss=mse, optimizer=optax.adam(1e-3), init_opt_state=None, batcher=batch_data, make_logger=True, metrics=None, log_every=100, verbose=2, callbacks=None, key)
Train a model using an optimizer from optax.
This is a convenient wrapper around run_training_loop
that sets up optimizer, training state and callbacks.
| PARAMETER | DESCRIPTION |
|---|---|
| model | The model instance to be trained. It must be a subclass of … |
| data | The training data can be any … |
| batch_size | The number of examples in a batch. |
| batch_axes | A … Example: For a dataset of 100 examples … |
| run_state | Auxiliary runtime state that is passed to the loss function and can be updated via callbacks. Defaults to None. |
| validation_data | Arbitrary … |
| steps | Number of gradient updates to apply. Defaults to 1000. |
| loss | The loss function with call signature … |
| optimizer | The optimizer: any optax gradient transformation used to compute the updates for the model. Defaults to optax.adam(1e-3). |
| init_opt_state | The initial state of the optimizer. If … |
| batcher | The data loader that splits inputs and targets into batches. Defaults to batch_data. |
| make_logger | Whether to create a … |
| metrics | Sequence of metrics to be evaluated at regular intervals during training. You can overwrite the default "loss" and "validation_loss" metrics by adding custom metrics with the same name. Defaults to None. |
| log_every | Interval for both metric evaluation and progress logging. A value … |
| verbose | Integer controlling the verbosity during training. 0: nothing is printed. 1: a message is printed every … |
| callbacks | List of callbacks. They can be used to implement early stopping, custom logging, and more. The argument to the callback function is a CallbackArgs object. Defaults to None. |
| key | A … |
| RETURNS | DESCRIPTION |
|---|---|
| tuple[T, History \| None] | A tuple of the trained model and the training history. |
Note
The returned history will be None if make_logger=False.
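The following pure-Python sketch illustrates how steps, log_every and make_logger interact. It fits a one-parameter linear model with mini-batch gradient descent on the mean squared error. It is a hypothetical stand-in: fit_linear is not a klax function. It uses a hand-written gradient instead of JAX autodiff and a fixed learning rate instead of an optax optimizer. For brevity, each step also samples an independent mini-batch rather than using batch_data's reshuffling scheme. The loss on the full dataset is recorded every log_every steps, and the history is None when make_logger=False, mirroring the note above.

```python
import random
from typing import List, Optional, Tuple


def fit_linear(
    xs: List[float],
    ys: List[float],
    *,
    steps: int = 1000,
    batch_size: int = 2,
    learning_rate: float = 0.01,
    log_every: int = 100,
    make_logger: bool = True,
    seed: int = 0,
) -> Tuple[float, Optional[List[Tuple[int, float]]]]:
    """Fit y = w * x by mini-batch gradient descent on the MSE (sketch)."""
    rng = random.Random(seed)
    w = 0.0
    n = len(xs)
    # Mimic make_logger=False by returning no history at all.
    history: Optional[List[Tuple[int, float]]] = [] if make_logger else None
    for step in range(1, steps + 1):
        # Draw a mini-batch of distinct indices (capped at the dataset size).
        idx = rng.sample(range(n), min(batch_size, n))
        # Gradient of mean((w * x - y) ** 2) over the batch, with respect to w.
        grad = sum(2.0 * (w * xs[i] - ys[i]) * xs[i] for i in idx) / len(idx)
        w -= learning_rate * grad
        # Evaluate the full-data loss every log_every steps.
        if history is not None and step % log_every == 0:
            loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
            history.append((step, loss))
    return w, history
```

For example, fit_linear([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]) drives w toward 2.0 and, with the defaults steps=1000 and log_every=100, records ten history entries.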
klax.run_training_loop(context, callbacks)
klax.make_step(state_leaves, state_treedef, batch, loss, optimizer)
klax.TrainingState
Dataclass of things that are expected to change during training.
This consists of:
- model: The eqx.Module (or more generally any PyTree) representing the trainable model.
- opt_state: The state of the optax optimizer.
- run_state: The user-defined state of the training run, which is passed to the loss function and may be modified via callbacks.
- step: The current optimization step count of the training.
klax.TrainingContext
Collection of all training relevant objects.
This includes:
- state: The TrainingState.
- optimizer: The optax optimizer.
- loss: The loss function.
- batch_generator: The generator object responsible for creating data batches.
- steps: The total number of scheduled optimization steps for the training run.
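The following sketch shows how the TrainingState and TrainingContext fields fit together in a training loop. It is hypothetical throughout. The class definitions only mirror the listed fields, the model is a single float, and the optimizer is just an SGD learning rate. The run_loop helper is not klax's run_training_loop, and it calls a hand-written sgd_step instead of klax.make_step. Callbacks here receive the context directly and return True to request a stop. klax passes a CallbackArgs object instead, and its stopping protocol is not described in this section.

```python
import itertools
from dataclasses import dataclass, replace
from typing import Any, Callable, Iterator, Optional, Sequence


@dataclass
class TrainingState:
    model: Any      # any PyTree; a single float parameter in this sketch
    opt_state: Any  # optimizer state (unused by plain SGD)
    run_state: Any  # user-defined state of the training run
    step: int = 0


@dataclass
class TrainingContext:
    state: TrainingState
    optimizer: Any  # here: simply an SGD learning rate
    loss: Callable[[Any, Any], float]
    batch_generator: Iterator[Any]
    steps: int


def sgd_step(state: TrainingState, batch: float, learning_rate: float) -> TrainingState:
    # Hand-written gradient of loss(model, batch) = (model - batch) ** 2.
    grad = 2.0 * (state.model - batch)
    return replace(state, model=state.model - learning_rate * grad, step=state.step + 1)


def run_loop(
    context: TrainingContext,
    callbacks: Sequence[Callable[[TrainingContext], Optional[bool]]],
) -> TrainingState:
    while context.state.step < context.steps:
        batch = next(context.batch_generator)
        # Replace the state in the (mutable) context after each update.
        context.state = sgd_step(context.state, batch, context.optimizer)
        # Run every callback; any truthy return value requests a stop.
        if any([cb(context) for cb in callbacks]):
            break
    return context.state


def early_stopping(ctx: TrainingContext) -> bool:
    # Hypothetical criterion: stop once the loss on a fixed target is tiny.
    return ctx.loss(ctx.state.model, 3.0) < 1e-6


ctx = TrainingContext(
    state=TrainingState(model=0.0, opt_state=None, run_state=None),
    optimizer=0.25,
    loss=lambda model, batch: (model - batch) ** 2,
    batch_generator=itertools.repeat(3.0),
    steps=100,
)
final_state = run_loop(ctx, [early_stopping])
```

With a learning rate of 0.25, each step halves the distance to the target 3.0, so the early-stopping callback ends the run after 12 of the 100 scheduled steps.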