ODESolver
dynax.ODESolver
Integrate a function or submodule.
This class integrates \(\dot{y} = \text{func}(t, y, u, [\text{funcargs}])\), defined by a function or submodule `func`. Because the integration is performed with diffrax, the `ODESolver` is differentiable (for example with `jax.grad`).
Example
>>> import jax.numpy as jnp
>>> from dynax import ODESolver
>>> def func(t, y, u):
...     return -y + u
>>> model = ODESolver(func)
>>> ts = jnp.linspace(0, 1, 100)
>>> y0 = jnp.array([0.5, 1.0, 2.0])
>>> us = jnp.sin(ts)[:, None]  # example input: one channel, shape (100, 1)
>>> solution = model(ts, y0, us)
>>> print(solution.shape)
(100, 3)
__init__(func, augmentation=0, augmented_ic_learnable=False, solver=diffrax.Tsit5(), stepsize_controller=diffrax.PIDController(rtol=1e-06, atol=1e-06), max_steps=4096)
Initialize the ODESolver.
PARAMETER | DESCRIPTION |
---|---|
func | Function or submodel to integrate. The function arguments are `t`, `y`, `u` and, optionally, `funcargs`. |
augmentation | If … |
augmented_ic_learnable | If … |
solver | Specifies the diffrax solver to use for numerical integration. |
stepsize_controller | The diffrax stepsize controller to use for integration. |
max_steps | The maximum number of steps to take before quitting the computation unconditionally. A value of … |
Note
The arguments `solver`, `stepsize_controller`, and `max_steps` are passed directly to the `diffrax.diffeqsolve` function. For more information, see the diffrax documentation.
RAISES | DESCRIPTION |
---|---|
ValueError | If the provided `augmentation` is neither an array nor an integer. |
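The descriptions of `augmentation` are cut off above, so the following is only a guess at the mechanism, based on the documented `ValueError`. A common approach (augmented neural ODEs) appends extra states to the initial condition. The hypothetical helper `augment_initial_condition` treats an integer as the number of zero-initialized extra states and an array as their initial values, and it raises `ValueError` for anything else. Whether `ODESolver` follows these exact rules is an assumption.

```python
import jax.numpy as jnp
import numpy as np


def augment_initial_condition(y0, augmentation=0):
    """Hypothetical: append augmented states to the initial condition y0.

    Assumed rules: an int adds that many zero-initialized states, and an
    array gives the initial values of the augmented states directly.
    """
    # bool is a subclass of int in Python, so it is rejected explicitly.
    if isinstance(augmentation, (int, np.integer)) and not isinstance(augmentation, bool):
        aug0 = jnp.zeros(augmentation, dtype=y0.dtype)
    elif isinstance(augmentation, (np.ndarray, jnp.ndarray)):
        aug0 = jnp.asarray(augmentation, dtype=y0.dtype).ravel()
    else:
        raise ValueError("augmentation must be an integer or an array")
    return jnp.concatenate([y0, aug0])


y0 = jnp.array([1.0, 2.0])
z_int = augment_initial_condition(y0, 2)                  # appends two zeros
z_arr = augment_initial_condition(y0, jnp.array([0.5]))   # appends the given value
```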
__call__(ts, y0, us=None, funcargs=None)
Solve the ODE.
PARAMETER | DESCRIPTION |
---|---|
ts | Array of timesteps at which the solution is evaluated, with shape … |
y0 | Initial condition of the system state at time … |
us | Two-dimensional array of stacked input vectors at each timestep. Shape … |
funcargs | Optional additional arguments to pass to the function `func`. |
RETURNS | DESCRIPTION |
---|---|
Array | Solution of the ODE excluding the augmented states. Shape … |
get_augmented_trajectory(ts, y0, us=None, funcargs=None)
Return the ODE solution including the augmented states.
PARAMETER | DESCRIPTION |
---|---|
ts | Array of timesteps at which the solution is evaluated, with shape … |
y0 | Initial condition of the system state at time … |
us | Two-dimensional array of stacked input vectors at each timestep. Shape … |
funcargs | Optional additional arguments to pass to the function `func`. |
RETURNS | DESCRIPTION |
---|---|
Array | Solution of the ODE including the augmented states. Shape … |
get_solution(ts, y0, us=None, funcargs=None)
Return the `diffrax.Solution` object.
PARAMETER | DESCRIPTION |
---|---|
ts | Array of timesteps at which the solution is evaluated, with shape … |
y0 | Initial condition of the system state at time … |
us | Two-dimensional array of stacked input vectors at each timestep. Shape … |
funcargs | Optional additional arguments to pass to the function `func`. |
RETURNS | DESCRIPTION |
---|---|
diffrax.Solution | The diffrax solution object. |