Model tools
Extracting model information
klax.count_parameters(model)
Count the number of trainable parameters in a model.
Under the hood, this counts the elements of all inexact
(floating-point or complex) JAX/NumPy arrays in the pytree
that are not wrapped in klax.NonTrainable.
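The counting rule described above can be illustrated with a minimal sketch. It assumes plain NumPy arrays nested in dicts, lists, and tuples. A simple recursive walk stands in for JAX's pytree flattening. The `NonTrainable` class and the `count_parameters_sketch` function are hypothetical stand-ins written for illustration, not klax's actual implementation.

```python
import numpy as np


class NonTrainable:
    """Hypothetical stand-in for klax.NonTrainable: wraps a value to mark it frozen."""

    def __init__(self, value):
        self.value = value


def count_parameters_sketch(tree) -> int:
    """Count elements of inexact (float/complex) arrays not wrapped in NonTrainable."""
    if isinstance(tree, NonTrainable):
        return 0  # wrapped subtrees are skipped entirely
    if isinstance(tree, dict):
        return sum(count_parameters_sketch(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_parameters_sketch(v) for v in tree)
    if isinstance(tree, np.ndarray) and np.issubdtype(tree.dtype, np.inexact):
        return tree.size  # number of elements, not number of arrays
    return 0  # integer/bool arrays, Python scalars, None, etc. are not counted here


# Hypothetical model for illustration.
model = {
    "linear": {"weight": np.ones((3, 2)), "bias": np.zeros(3)},  # 6 + 3 elements
    "scale": NonTrainable(np.ones(4)),  # wrapped: not counted
    "step": np.array(0),  # integer array: not inexact, not counted
}
print(count_parameters_sketch(model))  # 9
```

Only the float weight and bias contribute. The wrapped array and the integer array are ignored, which mirrors the description of klax.count_parameters.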
Warning
If you use jax.lax.stop_gradient or any method other than
klax.NonTrainable to keep arrays from receiving gradient
updates, this function will overestimate the number of
trainable parameters!
Consider using klax.NonTrainable or counting the trainable
parameters manually.
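The overestimate described in the warning can be illustrated with a short sketch of counting the trainable parameters manually. It assumes a hypothetical model stored as a dict of NumPy arrays. The "encoder" is assumed to be frozen elsewhere via jax.lax.stop_gradient, so it is not wrapped and a leaf-based count cannot tell that it is frozen.

```python
import numpy as np

# Hypothetical model: "encoder" is frozen elsewhere with jax.lax.stop_gradient,
# so nothing in the pytree itself marks it as non-trainable.
model = {
    "encoder": {"w": np.zeros((16, 8)), "b": np.zeros(16)},
    "head": {"w": np.zeros((1, 16)), "b": np.zeros(1)},
}

# Manual count: sum the sizes of only the arrays that actually receive updates.
trainable = [model["head"]["w"], model["head"]["b"]]
n_trainable = sum(a.size for a in trainable)  # 16 + 1 = 17

# Counting every inexact array includes the frozen encoder as well.
n_all = sum(a.size for sub in model.values() for a in sub.values())  # 128 + 16 + 17 = 161

print(n_trainable, n_all)  # 17 161
```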
| PARAMETER | DESCRIPTION |
|---|---|
| model | Arbitrary pytree. TYPE: |
| RETURNS | DESCRIPTION |
|---|---|
| int | Integer count of inexact JAX/NumPy array elements not wrapped in klax.NonTrainable. |