Calibration and Data Handling
klax.batch_data(data, batch_size=32, batch_axis=0, convert_to_numpy=True, *, key)
Create a Generator that draws subsets of data without replacement.
The data can be any PyTree with ArrayLike leaves. If batch_axis is passed, batch axes (including None for no batching) can be specified for every leaf individually.
A generator is returned that indefinitely yields batches of data with size batch_size. Examples are drawn without replacement until the remaining dataset is smaller than batch_size, at which point the dataset is reshuffled and the process starts over.
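The sampling scheme described above can be sketched on indices alone. This is a minimal illustration, not klax's implementation: the function name `batch_indices`, the integer `seed` argument, and the use of NumPy's random generator instead of a JAX key are assumptions made to keep the sketch self-contained.

```python
import numpy as np

def batch_indices(n_examples, batch_size, seed=0):
    """Yield index arrays forever, sampling without replacement per pass."""
    rng = np.random.default_rng(seed)
    # A dataset smaller than batch_size yields batches of the dataset size.
    size = min(batch_size, n_examples)
    while True:
        perm = rng.permutation(n_examples)  # reshuffle at the start of each pass
        # Emit full batches; stop once fewer than `size` indices remain.
        for start in range(0, n_examples - size + 1, size):
            yield perm[start:start + size]

batches = batch_indices(n_examples=10, batch_size=4)
# The first two batches come from the same permutation and do not overlap;
# the third is drawn after a reshuffle because only 2 examples were left.
first, second, third = next(batches), next(batches), next(batches)
```

With 10 examples and a batch size of 4, two batches are drawn per pass and the two leftover examples are dropped before the reshuffle.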
Example
This is an example for a nested PyTree, where the elements x and y have batch dimension along the first axis.
>>> import klax
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> x = jnp.array([1., 2.])
>>> y = jnp.array([[1.], [2.]])
>>> data = (x, {"a": 1.0, "b": y})
>>> batch_mask = (0, {"a": None, "b": 0})
>>> iter_data = klax.batch_data(
... data,
... 32,
... batch_mask,
... key=jax.random.key(0)
... )
>>>
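To make the per-leaf batch axes more concrete, the following sketch selects a batch from the same nested structure by hand. It is an illustration under assumptions, not klax code: `take_batch` is a hypothetical helper, it only handles tuples and dicts, it uses NumPy arrays in place of JAX arrays, and it expects the axis specification to mirror the structure of the data exactly.

```python
import numpy as np

def take_batch(data, batch_axis, idx):
    """Select `idx` along each leaf's batch axis; leaves with axis None pass through."""
    if isinstance(data, tuple):
        return tuple(take_batch(d, a, idx) for d, a in zip(data, batch_axis))
    if isinstance(data, dict):
        return {k: take_batch(v, batch_axis[k], idx) for k, v in data.items()}
    if batch_axis is None:
        return data  # unbatched leaf is returned unchanged in every batch
    return np.take(np.asarray(data), idx, axis=batch_axis)

x = np.array([1.0, 2.0])
y = np.array([[1.0], [2.0]])
data = (x, {"a": 1.0, "b": y})
batch_axis = (0, {"a": None, "b": 0})

# Pick the second example from every batched leaf.
batch = take_batch(data, batch_axis, idx=np.array([1]))
```

The scalar leaf "a" appears unchanged in the batch, while x and y are indexed along their first axis.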
PARAMETER | DESCRIPTION |
---|---|
data | The data that shall be batched. It can be any |
batch_size | The number of examples in a batch. |
batch_axis | PyTree of the batch axis indices. |
convert_to_numpy | If |
key | A |
RETURNS | DESCRIPTION |
---|---|
Generator[PyTree[Any], None, None] | A … called. |
YIELDS | DESCRIPTION |
---|---|
Generator[PyTree[Any], None, None] | A … batched leaves have |
Note
If the dataset is smaller than batch_size, the batches will have the size of the dataset.
klax.split_data(data, proportions, batch_axis=0, *, key)
Split a PyTree of data into multiple randomly drawn subsets.
This function is useful for splitting into training and test datasets. The axis of the split is controlled by the batch_axis argument, which specifies the batch axis for each leaf in data.
Example
This is an example for a nested PyTree of data.
>>> import klax
>>> import jax
>>>
>>> x = jax.numpy.array([1., 2., 3.])
>>> data = (x, {"a": 1.0, "b": x})
>>> s1, s2 = klax.split_data(
... data,
... (2, 1),
... key=jax.random.key(0)
... )
>>> s1
(Array([1., 2.], dtype=float32), {'a': 1.0, 'b': Array([1., 2.], dtype=float32)})
>>> s2
(Array([3.], dtype=float32), {'a': 1.0, 'b': Array([3.], dtype=float32)})
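One way to picture a proportional split is to permute the example indices and cut the permutation at boundaries derived from the proportions. The sketch below is an assumption-laden illustration, not klax's implementation: `split_indices` is a hypothetical name, the rounding rule for subset sizes is a guess, and NumPy's random generator stands in for a JAX key.

```python
import numpy as np

def split_indices(n_examples, proportions, seed=0):
    """Randomly partition range(n_examples) according to relative proportions."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_examples)
    # Cumulative proportions, scaled to the dataset size and rounded to indices.
    cum = np.cumsum(np.asarray(proportions, dtype=float))
    bounds = np.rint(cum * n_examples / cum[-1]).astype(int)
    # Cut the permutation at every boundary except the last (the dataset end).
    return np.split(perm, bounds[:-1])

# Proportions (2, 1) on three examples give subsets of size 2 and 1.
train_idx, test_idx = split_indices(3, (2, 1))
```

The resulting index arrays could then be applied along each leaf's batch axis, leaving unbatched leaves untouched.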
PARAMETER | DESCRIPTION |
---|---|
data | Data that shall be split. It can be any |
proportions | Proportions of the split that will be applied to the data, e.g., |
batch_axis | PyTree of the batch axis indices. |
key | A |
RETURNS | DESCRIPTION |
---|---|
tuple[PyTree[Any], ...] | Tuple of |
klax.fit(model, data, *, batch_size=32, batch_axis=0, validation_data=None, steps=1000, loss_fn=<klax._losses.MSE object at 0x7fb722df9670>, optimizer=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), init_opt_state=None, batcher=<function batch_data>, history=None, callbacks=None, key)
Trains a model using an optimizer from optax.
PARAMETER | DESCRIPTION |
---|---|
model | The model instance, which should be trained. It must be a subclass of |
data | The training data can be any |
batch_size | The number of examples in a batch. |
batch_axis | A |
validation_data | Arbitrary |
steps | Number of gradient updates to apply. (Defaults to 1000.) |
loss_fn | The loss function with call signature |
optimizer | The optimizer. Any optax gradient transform to calculate the updates for the model. (Defaults to optax.adam(1e-3).) |
init_opt_state | The initial state of the optimizer. If |
batcher | The data loader that splits inputs and targets into batches. (Defaults to |
history | A callback intended for tracking the training process. If no custom callback is passed the |
callbacks | Callback functions that are evaluated after every training step. They can be used to implement early stopping, custom history logging and more. The argument to the callback function is a CallbackArgs object. (Defaults to |
key | A |
Note
This function assumes that the batch dimension is always oriented along the first axis of any jax.Array.
RETURNS | DESCRIPTION |
---|---|
tuple[T, klax.HistoryCallback \| H] | A tuple of the trained model and the loss history. |
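The overall shape of such a training loop (mini-batches, a fixed number of gradient steps, a loss history, and per-step callbacks) can be sketched in plain NumPy. Everything here is hypothetical and deliberately simplified: `fit_linear` is not part of klax, hand-written gradient descent on a linear model stands in for an optax optimizer, and the callback convention (called with the step and loss, returning True to stop) is an assumption that does not reflect the CallbackArgs interface.

```python
import numpy as np

def mse(pred, target):
    """Mean squared error between predictions and targets."""
    return float(np.mean((pred - target) ** 2))

def fit_linear(x, y, steps=500, batch_size=32, lr=0.1, callbacks=(), seed=0):
    """Fit y ~ w * x + b with mini-batch gradient descent; return params and history."""
    rng = np.random.default_rng(seed)
    w, b = 0.0, 0.0
    history = []
    for step in range(steps):
        # Draw one mini-batch without replacement.
        idx = rng.choice(len(x), size=min(batch_size, len(x)), replace=False)
        xb, yb = x[idx], y[idx]
        err = w * xb + b - yb
        # Gradient step on the batch MSE.
        w -= lr * 2.0 * np.mean(err * xb)
        b -= lr * 2.0 * np.mean(err)
        # Track the loss on the full dataset after every step.
        loss = mse(w * x + b, y)
        history.append(loss)
        # Hypothetical convention: a callback returning True stops training.
        if any(cb(step, loss) for cb in callbacks):
            break
    return (w, b), history

x = np.linspace(-1.0, 1.0, 64)
y = 3.0 * x + 0.5
stop_when_small = lambda step, loss: loss < 1e-6  # simple early stopping
(w, b), history = fit_linear(x, y, callbacks=[stop_when_small])
```

As in the return value described above, the sketch hands back both the trained parameters and the recorded loss history.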