
Model tools

Extracting model information

klax.count_parameters(model)

Count the number of trainable parameters in a model.

Under the hood, this counts the elements of all inexact (floating-point or complex) JAX/NumPy arrays in the pytree that are not wrapped in klax.NonTrainable.
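The counting rule described above can be sketched in plain JAX. This is a minimal illustration, not klax's implementation. `NonTrainable` below is a hypothetical stand-in for `klax.NonTrainable`, and `count_trainable` is a hypothetical helper name. The sketch only assumes that wrapped values should be skipped and that only inexact array elements count.

```python
import jax
import jax.numpy as jnp
import numpy as np


class NonTrainable:
    """Hypothetical stand-in for klax.NonTrainable: marks a wrapped value as frozen."""

    def __init__(self, value):
        self.value = value


def count_trainable(tree) -> int:
    """Count inexact array elements in `tree`, skipping NonTrainable-wrapped values."""
    # Stop traversal at NonTrainable wrappers so their contents are never visited.
    leaves = jax.tree_util.tree_leaves(
        tree, is_leaf=lambda x: isinstance(x, NonTrainable)
    )
    total = 0
    for leaf in leaves:
        if isinstance(leaf, NonTrainable):
            continue  # frozen value: excluded from the count
        # Only floating-point/complex arrays count; ints, bools and Python scalars do not.
        if isinstance(leaf, (jax.Array, np.ndarray)) and jnp.issubdtype(leaf.dtype, jnp.inexact):
            total += leaf.size
    return total


params = {
    "w": jnp.ones((3, 4)),                 # 12 float elements
    "b": jnp.zeros(4),                     # 4 float elements
    "frozen": NonTrainable(jnp.ones(10)),  # wrapped, so excluded
    "step": jnp.array(0),                  # integer array, so excluded
}
print(count_trainable(params))  # 16
```

Treating the wrapper as a leaf via `is_leaf` keeps the sketch correct even if such a wrapper were registered as a pytree node, since its contents are never flattened.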

Warning

If you use jax.lax.stop_gradient, or any method other than klax.NonTrainable, to prevent arrays from receiving gradient updates, this function will overestimate the number of trainable parameters. In that case, use klax.NonTrainable or count the trainable parameters manually.
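The overestimate arises because `jax.lax.stop_gradient` acts inside the computation, not on the pytree itself. The sketch below illustrates this in plain JAX under assumed, illustrative parameter names. `"b"` receives an all-zero gradient, yet it remains an ordinary float array leaf. A count based only on leaf dtypes, like the one this function performs, cannot tell it apart from a trainable array.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones(3), "b": jnp.ones(3)}


def loss(p, x):
    # "b" is frozen only within this computation, via stop_gradient.
    return jnp.sum(p["w"] * x + jax.lax.stop_gradient(p["b"]))


grads = jax.grad(loss)(params, jnp.arange(3.0))
print(grads["b"])  # [0. 0. 0.]: "b" never receives an update

# A leaf-based count still sees "b" as a plain float array.
n = sum(
    leaf.size
    for leaf in jax.tree_util.tree_leaves(params)
    if jnp.issubdtype(leaf.dtype, jnp.inexact)
)
print(n)  # 6, although only the 3 elements of "w" are effectively trainable
```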

PARAMETER DESCRIPTION
model

Arbitrary pytree.

TYPE: PyTree

RETURNS DESCRIPTION
int

Integer count of inexact JAX/NumPy array elements.