Solve a discretised model using a JAX compiled solver.
Note: this solver does not support models that have termination events or that have not been converted to jax format.
Raises:
RuntimeError – if the model has any termination events
RuntimeError – if model.convert_to_format != ‘jax’
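The two restrictions above amount to a pre-flight check on the model. The sketch below is hypothetical: ToyModel, Event and check_jax_compatible are illustrative stand-ins rather than PyBaMM classes, and the event type is a plain string instead of PyBaMM's own event-type values.

```python
from dataclasses import dataclass, field


@dataclass
class Event:
    # Hypothetical stand-in for a model event.
    name: str
    event_type: str  # e.g. "TERMINATION"


@dataclass
class ToyModel:
    # Hypothetical stand-in for a discretised model; only the two
    # attributes used by the documented checks are included.
    convert_to_format: str = "python"
    events: list = field(default_factory=list)


def check_jax_compatible(model: ToyModel) -> None:
    """Raise RuntimeError if the model breaks either documented restriction."""
    if any(event.event_type == "TERMINATION" for event in model.events):
        raise RuntimeError("the JAX solver does not support termination events")
    if model.convert_to_format != "jax":
        raise RuntimeError("model.convert_to_format must be 'jax'")


check_jax_compatible(ToyModel(convert_to_format="jax"))  # passes silently
```

Events of other types pass this check; only termination events and a non-jax format are rejected.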
Parameters:
method (str) – ‘RK45’ (default) uses jax.experimental.ode.odeint; ‘BDF’ uses the custom jax_bdf_integrate (see jax_bdf_integrate.py for details).
root_method (str, optional) – Method used to calculate consistent initial conditions. By default this uses the Newton chord method internal to the JAX BDF solver; otherwise, choose from the set of default options defined in the docs for pybamm.BaseSolver.
rtol (float, optional) – The relative tolerance for the solver (default is 1e-6).
atol (float, optional) – The absolute tolerance for the solver (default is 1e-6).
extrap_tol (float, optional) – The tolerance used to decide whether extrapolation occurs (default is 0).
extra_options (dict, optional) – Any options to pass to the solver. Please consult the JAX documentation for details.
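The default root_method is described only as a Newton chord method. As a rough illustration of what a chord iteration does, the NumPy sketch below solves the algebraic equation of a hypothetical semi-explicit DAE for a consistent initial value: the Jacobian is approximated once by forward differences at the initial guess and then reused, rather than recomputed at every iteration. The name chord_newton and the test equation are assumptions, not the solver's internals.

```python
import numpy as np


def chord_newton(g, x0, tol=1e-10, max_iter=50, eps=1e-7):
    """Solve g(x) = 0 by a chord (simplified Newton) iteration.

    The Jacobian is approximated once, by forward differences at x0, and
    then reused for every iteration.
    """
    x = np.asarray(x0, dtype=float).copy()
    n = x.size
    g0 = g(x)
    jac = np.empty((n, n))
    for j in range(n):
        e = np.zeros(n)
        e[j] = eps
        jac[:, j] = (g(x + e) - g0) / eps
    for _ in range(max_iter):
        step = np.linalg.solve(jac, -g(x))
        x = x + step
        if np.linalg.norm(step) < tol:
            return x
    raise RuntimeError("chord iteration did not converge")


# Hypothetical semi-explicit DAE: the differential state y_d is given, and the
# algebraic state y_a must satisfy g(y_a) = y_a**3 + y_a - y_d = 0.
y_d0 = 2.0
y_a0 = chord_newton(lambda y_a: y_a**3 + y_a - y_d0, np.array([0.9]))
```

Freezing the Jacobian makes each iteration cheap (one residual evaluation and one linear solve) at the cost of linear rather than quadratic convergence, and it relies on a reasonable initial guess: with the Jacobian frozen at 0.5 instead of 0.9, the iteration for this equation does not converge.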
Return a compiled JAX function that solves an ODE model with input arguments.
Parameters:
model (pybamm.BaseModel) – The model whose solution to calculate.
t_eval (numpy.array, size (k,)) – The times at which to compute the solution.
Returns:
A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.
Return type:
function
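The pattern described above, a factory that fixes the model and t_eval once and returns a function of the input parameters only, can be sketched without JAX. In this sketch, create_solve_sketch, the rhs(t, y, inputs) convention and the fixed-step RK4 integrator are all assumptions made for illustration; the documented method instead returns a JAX-compiled function built from the model.

```python
import numpy as np


def create_solve_sketch(rhs, y0, t_eval):
    """Return solve(inputs) integrating y' = rhs(t, y, inputs) over t_eval.

    Hypothetical stand-in: a fixed-step classical RK4 in NumPy replaces the
    JAX-compiled solver, taking one RK4 step between consecutive entries
    of t_eval.
    """
    t_eval = np.asarray(t_eval, dtype=float)
    y0 = np.asarray(y0, dtype=float)

    def solve(inputs):
        y = y0
        columns = [y]
        for t0, t1 in zip(t_eval[:-1], t_eval[1:]):
            h = t1 - t0
            k1 = rhs(t0, y, inputs)
            k2 = rhs(t0 + h / 2, y + h / 2 * k1, inputs)
            k3 = rhs(t0 + h / 2, y + h / 2 * k2, inputs)
            k4 = rhs(t1, y + h * k3, inputs)
            y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            columns.append(y)
        return np.stack(columns, axis=1)  # shape (n, k)

    return solve


# y' = -a * y, where the decay rate "a" is an input parameter.
solve = create_solve_sketch(lambda t, y, p: -p["a"] * y, [1.0], np.linspace(0.0, 1.0, 101))
y_fast = solve({"a": 2.0})
y_slow = solve({"a": 0.5})
```

Fixing t_eval up front keeps every array shape constant between calls, which in general is what lets a jax.jit-compiled function be reused for new input dicts (with the same keys and shapes) without recompiling.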
Backward Difference Formula (BDF) implicit multistep integrator. The basic algorithm is derived in [2]. This particular implementation follows the MATLAB routine ode15s described in [1] and the SciPy implementation [3]: it uses the NDF formulas for improved stability, with correspondingly different error constants, and evaluates the Jacobian J(t_{n+1}, y^0_{n+1}) at the predicted state. It was based on the implementation in the SciPy library [3], which also mainly follows [1] but uses the more standard Jacobian update.
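To make these ingredients concrete, the sketch below implements only the simplest relevant case: a fixed-step, second-order BDF method (started with one backward Euler step) in NumPy. It shows the implicit BDF update and a Newton-type iteration whose Jacobian is evaluated once per step at the predicted point (t_{n+1}, y^0_{n+1}). The variable step size, variable order, NDF modifications and error control of the documented integrator are deliberately left out, and bdf2_fixed_step and its jac argument are hypothetical.

```python
import numpy as np


def bdf2_fixed_step(func, y0, t_eval, jac, newton_tol=1e-10, max_iter=20):
    """Fixed-step BDF2 for y' = func(y, t) on a uniformly spaced t_eval.

    jac(y, t) must return the Jacobian d func / d y as an (n, n) array.
    Returns an array of shape (n, m) for m = len(t_eval).
    """
    t_eval = np.asarray(t_eval, dtype=float)
    h = t_eval[1] - t_eval[0]
    y0 = np.asarray(y0, dtype=float)
    eye = np.eye(y0.size)
    columns = [y0]

    def implicit_solve(c, t_new, y_pred, rhs_const):
        # Solve y - c*h*func(y, t_new) = rhs_const. The iteration matrix uses
        # the Jacobian evaluated once, at the predicted point (t_new, y_pred).
        iteration_matrix = eye - c * h * jac(y_pred, t_new)
        y = y_pred.copy()
        for _ in range(max_iter):
            residual = y - c * h * func(y, t_new) - rhs_const
            dy = np.linalg.solve(iteration_matrix, -residual)
            y = y + dy
            if np.linalg.norm(dy) < newton_tol:
                return y
        raise RuntimeError("Newton iteration did not converge")

    # Start-up step, BDF1 (backward Euler): y_1 - h*f(y_1) = y_0.
    columns.append(implicit_solve(1.0, t_eval[1], y0, y0))
    # BDF2: y_{n+1} - (4/3) y_n + (1/3) y_{n-1} = (2/3) h f(y_{n+1}).
    for k in range(2, t_eval.size):
        y_prev, y_curr = columns[-2], columns[-1]
        y_pred = 2.0 * y_curr - y_prev  # linear extrapolation, y^0_{n+1}
        rhs_const = (4.0 * y_curr - y_prev) / 3.0
        columns.append(implicit_solve(2.0 / 3.0, t_eval[k], y_pred, rhs_const))
    return np.stack(columns, axis=1)


# Stiff linear test problem y' = A y with eigenvalues -1 and -1000.
A = np.diag([-1.0, -1000.0])
t = np.linspace(0.0, 1.0, 101)  # h = 0.01, so h * (-1000) = -10
y = bdf2_fixed_step(lambda y, t: A @ y, [1.0, 1.0], t, jac=lambda y, t: A)
```

With h·λ = -10 for the fast mode, forward Euler or classical RK4 would be unstable at this step size, while the implicit BDF2 update keeps that mode damped.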
Parameters:
func (callable) – function that evaluates the time derivative of the solution y at time t as func(y, t, *args), producing the same shape/structure as y0.
y0 (ndarray) – initial state vector.
t_eval (ndarray) – time points at which to evaluate the solution, with shape (m,).
args (tuple, optional) – additional arguments for func, which must be arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types.
rtol (float, optional) – relative tolerance for the solver.
atol (float, optional) – absolute tolerance for the solver.
mass (ndarray, optional) – diagonal of the mass matrix, with shape (n,).
Returns:
y – calculated state vector at each of the m time points.
Return type:
ndarray with shape (n, m)
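The func(y, t, *args) convention, the rtol/atol pair and the diagonal mass matrix presumably fit together in the standard form M y' = func(y, t, *args), where a zero on the diagonal of M turns that row into an algebraic equation. The sketch below illustrates this with backward Euler rather than BDF, and uses a scaled RMS norm built from rtol and atol as the stopping test for its Newton iteration. implicit_euler_mass and rms_norm are hypothetical names; the returned array follows the (n, m) shape documented above.

```python
import numpy as np


def rms_norm(x, scale):
    # Root-mean-square of x measured in units of the tolerance scale.
    return np.sqrt(np.mean((x / scale) ** 2))


def implicit_euler_mass(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6,
                        mass=None, max_iter=10):
    """Backward Euler for diag(mass) * y' = func(y, t, *args).

    Rows with a zero mass entry are algebraic: 0 = func(y, t, *args)[i].
    Returns an array of shape (n, m) for m = len(t_eval).
    """
    t_eval = np.asarray(t_eval, dtype=float)
    y = np.asarray(y0, dtype=float)
    n = y.size
    mass = np.ones(n) if mass is None else np.asarray(mass, dtype=float)
    columns = [y]
    for t_old, t_new in zip(t_eval[:-1], t_eval[1:]):
        h = t_new - t_old
        y_old = y

        def residual(z):
            # Discretised equations: M (z - y_old) - h * func(z, t_new) = 0.
            return mass * (z - y_old) - h * func(z, t_new, *args)

        # Forward-difference Jacobian of the residual, built once per step.
        r0 = residual(y_old)
        eps = 1e-8
        jac = np.empty((n, n))
        for j in range(n):
            e = np.zeros(n)
            e[j] = eps
            jac[:, j] = (residual(y_old + e) - r0) / eps

        z = y_old.copy()
        for _ in range(max_iter):
            dz = np.linalg.solve(jac, -residual(z))
            z = z + dz
            # Stop once the update is small relative to atol + rtol * |z|.
            if rms_norm(dz, atol + rtol * np.abs(z)) < 1.0:
                break
        else:
            raise RuntimeError("Newton iteration did not converge")
        y = z
        columns.append(y)
    return np.stack(columns, axis=1)


def rhs(y, t, rate):
    # y[0] decays; y[1] is algebraic and must equal 2 * y[0].
    return np.array([-rate * y[0], y[1] - 2.0 * y[0]])


t_eval = np.linspace(0.0, 1.0, 101)
y = implicit_euler_mass(rhs, [1.0, 2.0], t_eval, 1.0, mass=np.array([1.0, 0.0]))
```

The extra positional argument 1.0 is forwarded to rhs as rate, mirroring how *args reach func. The initial state [1.0, 2.0] already satisfies the algebraic equation; finding such a consistent state is the job that root_method performs in the solver.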
References
[1] L. F. Shampine, M. W. Reichelt, “The MATLAB ODE Suite”, SIAM J. Sci. Comput., Vol. 18, No. 1, pp. 1-22, January 1997.
[2] G. D. Byrne, A. C. Hindmarsh, “A Polyalgorithm for the Numerical Solution of Ordinary Differential Equations”, ACM Transactions on Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975.
[3] P. Virtanen, R. Gommers, T. E. Oliphant, M. Haberland, T. Reddy, D. Cournapeau, … & S. J. van der Walt, “SciPy 1.0: Fundamental Algorithms for Scientific Computing in Python”, Nature Methods, Vol. 17, No. 3, pp. 261-272, 2020.