Integration Models
Integration models perform numerical integration of derivative models while supporting automatic differentiation through time.
dynax.ODESolver
Integrate a function or submodule.
This class integrates \(\dot{y} = \text{func}(t, y, u, [\text{funcargs}])\),
where func is a function or submodule. The integration is performed with
diffrax, which makes the ODESolver differentiable.
Example
>>> import jax.numpy as jnp
>>> from dynax import ODESolver
>>> def func(t, y, u):
...     return -y + u
>>> model = ODESolver(func)
>>> ts = jnp.linspace(0, 1, 100)
>>> y0 = jnp.array([0.5, 1.0, 2.0])
>>> us = jnp.sin(ts) # Example input
>>> solution = model(ts, y0, us)
>>> print(solution.shape)
(100, 3)
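To illustrate what "differentiable" means here, the following is a minimal sketch in plain JAX rather than dynax or diffrax code. It assumes a hand-written fixed-step RK4 integrator (`rk4_integrate`), a hypothetical `decay` parameter added to the example function, and an input that is held constant over each step. None of these names come from the library. The point is only that a loss on the trajectory can be differentiated with respect to a parameter of the derivative function.

```python
import jax
import jax.numpy as jnp


def func(t, y, u, decay):
    # Same form as the example above, with a hypothetical tunable decay rate.
    return -decay * y + u


def rk4_integrate(decay, ts, y0, us):
    """Fixed-step RK4 on the grid ts; the input is held constant within a step."""

    def step(y, inputs):
        t, dt, u = inputs
        k1 = func(t, y, u, decay)
        k2 = func(t + dt / 2, y + dt / 2 * k1, u, decay)
        k3 = func(t + dt / 2, y + dt / 2 * k2, u, decay)
        k4 = func(t + dt, y + dt * k3, u, decay)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next

    dts = jnp.diff(ts)
    _, ys = jax.lax.scan(step, y0, (ts[:-1], dts, us[:-1]))
    # Prepend the initial condition so the result has one row per entry of ts.
    return jnp.concatenate([y0[None], ys], axis=0)


def loss(decay, ts, y0, us, target):
    # Mean squared error between the simulated and the target trajectory.
    return jnp.mean((rk4_integrate(decay, ts, y0, us) - target) ** 2)


ts = jnp.linspace(0.0, 1.0, 100)
y0 = jnp.array([0.5, 1.0, 2.0])
us = jnp.sin(ts)

# Synthetic target generated with decay = 1.5, then differentiate at decay = 1.0.
target = rk4_integrate(1.5, ts, y0, us)
trajectory = rk4_integrate(1.0, ts, y0, us)
grad_decay = jax.grad(loss)(1.0, ts, y0, us, target)
```

Because the gradient flows through every integration step, an optimizer can adjust `decay` (or the weights of a learned derivative model) to fit measured trajectories.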
__init__(func, augmentation=0, augmented_ic_learnable=False, solver=diffrax.Tsit5(), stepsize_controller=diffrax.PIDController(rtol=1e-06, atol=1e-06), max_steps=4096)
Initialize the ODESolver.
| PARAMETER | DESCRIPTION |
|---|---|
| func | Function or submodel to integrate. The function arguments are |
| augmentation | If |
| augmented_ic_learnable | If |
| solver | Specifies the diffrax solver to use for numerical integration. |
| stepsize_controller | The diffrax stepsize controller to use for integration. |
| max_steps | The maximum number of steps to take before quitting the computation unconditionally. A value of |
Note
The arguments solver, stepsize_controller, and max_steps are passed directly
to the diffrax.diffeqsolve function. For more information, see the
diffrax documentation.
| RAISES | DESCRIPTION |
|---|---|
| ValueError | If the provided augmentation is neither an array nor an integer. |
__call__(ts, y0, us=None, funcargs=None)
Solve the ODE.
| PARAMETER | DESCRIPTION |
|---|---|
| ts | Array of timesteps at which the solution is evaluated, with shape |
| y0 | Initial condition of the system state at time |
| us | Two-dimensional array of stacked input vectors at each timestep. Shape |
| funcargs | Optional additional arguments to pass to the function |

| RETURNS | DESCRIPTION |
|---|---|
| Array | Solution of the ODE excluding the augmented states. Shape |
get_augmented_trajectory(ts, y0, us=None, funcargs=None)
Return the ODE solution including the augmented states.
| PARAMETER | DESCRIPTION |
|---|---|
| ts | Array of timesteps at which the solution is evaluated, with shape |
| y0 | Initial condition of the system state at time |
| us | Two-dimensional array of stacked input vectors at each timestep. Shape |
| funcargs | Optional additional arguments to pass to the function |

| RETURNS | DESCRIPTION |
|---|---|
| Array | Solution of the ODE including the augmented states. Shape |
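The relationship between the two return values can be pictured with a small sketch. It assumes one common convention from augmented neural ODEs: extra states are appended after the physical states and start at zero. Whether dynax orders or initializes them this way is not stated here, and `n_augmented` and the stand-in trajectory are hypothetical.

```python
import jax.numpy as jnp

n_augmented = 2                           # hypothetical number of extra states
y0 = jnp.array([0.5, 1.0, 2.0])           # physical initial condition
ts = jnp.linspace(0.0, 1.0, 100)

# Assumption: augmented states are appended and initialized to zero.
y0_augmented = jnp.concatenate([y0, jnp.zeros(n_augmented)])

# Stand-in for an integrated trajectory of the augmented system.
augmented_trajectory = y0_augmented * jnp.exp(-ts)[:, None]

# Dropping the trailing columns leaves only the physical states.
trajectory = augmented_trajectory[:, : y0.shape[0]]
```

Under this assumption the augmented trajectory has `y0.shape[0] + n_augmented` columns, and the non-augmented result is its leading slice.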
get_solution(ts, y0, us=None, funcargs=None)
Return the diffrax.Solution object.
| PARAMETER | DESCRIPTION |
|---|---|
| ts | Array of timesteps at which the solution is evaluated, with shape |
| y0 | Initial condition of the system state at time |
| us | Two-dimensional array of stacked input vectors at each timestep. Shape |
| funcargs | Optional additional arguments to pass to the function |

| RETURNS | DESCRIPTION |
|---|---|
| diffrax.Solution | diffrax solution object. |