
ODESolver

dynax.ODESolver

Integrate a function or submodule.

This class integrates \(\dot{y} = \text{func}(t, y, u, [\text{funcargs}])\), where the right-hand side is given by a function or submodule func. The integration is performed with diffrax, which makes ODESolver differentiable with JAX transformations such as jax.grad.

Example
>>> import jax.numpy as jnp
>>> from dynax import ODESolver
>>> def func(t, y, u):
...     return -y + u
>>> model = ODESolver(func)
>>> ts = jnp.linspace(0, 1, 100)
>>> y0 = jnp.array([0.5, 1.0, 2.0])
>>> us = jnp.sin(ts)[:, None]  # input signal, shape (k, m) = (100, 1)
>>> solution = model(ts, y0, us)
>>> print(solution.shape)
(100, 3)

__init__(func, augmentation=0, augmented_ic_learnable=False, solver=diffrax.Tsit5(), stepsize_controller=diffrax.PIDController(rtol=1e-06, atol=1e-06), max_steps=4096)

Initialize the ODESolver.

PARAMETER DESCRIPTION
func

Function or submodule to integrate. It is called with the arguments (t, y, u, [funcargs]): a scalar time t, an N-dimensional state vector y, an m-dimensional input vector u, and an optional pytree funcargs. It must return an N-dimensional vector, the time derivative of the state.

TYPE: Callable[..., Array]

augmentation

If augmentation is an array, it holds the initial values of the augmented states, which are appended to the initial state y0 before integration. If augmentation is an integer, it gives the number of extra state dimensions to append, and their initial condition is set to all zeros. Defaults to 0.

TYPE: int | Array DEFAULT: 0

augmented_ic_learnable

If True, the initial condition of the augmented states is updated during training. Defaults to False.

TYPE: bool DEFAULT: False

solver

The diffrax solver used for numerical integration.

TYPE: diffrax.AbstractSolver DEFAULT: diffrax.Tsit5()

stepsize_controller

The diffrax stepsize controller to use for integration.

TYPE: diffrax.AbstractStepSizeController DEFAULT: diffrax.PIDController(rtol=1e-06, atol=1e-06)

max_steps

The maximum number of steps to take before quitting the computation unconditionally. A value of None means no limit is imposed.

TYPE: int | None DEFAULT: 4096

Note

The arguments solver, stepsize_controller, and max_steps are passed directly to diffrax.diffeqsolve. See the diffrax documentation for details.

RAISES DESCRIPTION
ValueError

If augmentation is neither an array nor an integer.
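The documented behaviour of augmentation (an integer appends that many zero-initialized states, an array is appended as given, anything else raises ValueError) can be sketched in plain JAX. The helper augment_initial_state is hypothetical and only mirrors the description above; it is not dynax's implementation, and the append order is an assumption.

```python
import jax
import jax.numpy as jnp


def augment_initial_state(y0, augmentation):
    """Hypothetical helper mirroring the documented `augmentation` semantics."""
    if isinstance(augmentation, int):
        # Integer: that many extra states, initialized to zero.
        aug_ic = jnp.zeros(augmentation, dtype=y0.dtype)
    elif isinstance(augmentation, jax.Array):
        # Array: explicit initial values of the augmented states.
        aug_ic = augmentation
    else:
        raise ValueError("augmentation must be an int or an array")
    # Assumption: augmented states are appended after the original n states.
    return jnp.concatenate([y0, aug_ic])


y0 = jnp.array([0.5, 1.0, 2.0])
print(augment_initial_state(y0, 2).shape)                 # (5,)
print(augment_initial_state(y0, jnp.array([0.1])).shape)  # (4,)
```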

__call__(ts, y0, us=None, funcargs=None)

Solve ODE.

PARAMETER DESCRIPTION
ts

Array of timesteps at which the solution is evaluated, with shape (k,).

TYPE: Array

y0

Initial condition of the system state at time ts[0].

TYPE: Array

us

Two-dimensional array of input vectors stacked along time, one row per time point in ts, with shape (k, m). Defaults to None.

TYPE: Array | None DEFAULT: None

funcargs

Optional additional arguments to pass to the function func.

TYPE: PyTree DEFAULT: None

RETURNS DESCRIPTION
Array

Solution of the ODE excluding the augmented states, with shape (k, n), where n is the dimension of y0.

get_augmented_trajectory(ts, y0, us=None, funcargs=None)

Return the ODE solution including the augmented states.

PARAMETER DESCRIPTION
ts

Array of timesteps at which the solution is evaluated, with shape (k,).

TYPE: Array

y0

Initial condition of the system state at time ts[0].

TYPE: Array

us

Two-dimensional array of input vectors stacked along time, one row per time point in ts, with shape (k, m). Defaults to None.

TYPE: Array | None DEFAULT: None

funcargs

Optional additional arguments to pass to the function func.

TYPE: PyTree DEFAULT: None

RETURNS DESCRIPTION
Array

Solution of the ODE including the augmented states, with shape (k, N), where N is n plus the number of augmented states.
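The outputs of __call__ (shape (k, n)) and get_augmented_trajectory (shape (k, N)) differ only by the augmented states. Assuming, for illustration, that the augmented states occupy the trailing columns, the relationship is a column slice. The array below is a stand-in with hypothetical sizes, not real solver output.

```python
import jax.numpy as jnp

k, n, n_aug = 100, 3, 2  # hypothetical sizes, so N = n + n_aug = 5

# Stand-in for a get_augmented_trajectory result with shape (k, N).
augmented = jnp.arange(k * (n + n_aug), dtype=jnp.float32).reshape(k, n + n_aug)

# Assumption: the first n columns are the original states.
trajectory = augmented[:, :n]  # shape (k, n), like __call__
aug_states = augmented[:, n:]  # shape (k, n_aug), the augmented states
print(trajectory.shape, aug_states.shape)  # (100, 3) (100, 2)
```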

get_solution(ts, y0, us=None, funcargs=None)

Return the diffrax.Solution object.

PARAMETER DESCRIPTION
ts

Array of timesteps at which the solution is evaluated, with shape (k,).

TYPE: Array

y0

Initial condition of the system state at time ts[0].

TYPE: Array

us

Two-dimensional array of input vectors stacked along time, one row per time point in ts, with shape (k, m). Defaults to None.

TYPE: Array | None DEFAULT: None

funcargs

Optional additional arguments to pass to the function func.

TYPE: PyTree DEFAULT: None

RETURNS DESCRIPTION
diffrax.Solution

diffrax solution object.