JAX Solver

class pybamm.JaxSolver(method='RK45', root_method=None, rtol=1e-06, atol=1e-06, extrap_tol=None, extra_options=None)

Solve a discretised model using a JAX compiled solver.

Note: this solver will not work with models that have termination events or that have not been converted to JAX format.

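Raises:
  • RuntimeError – if the model has any termination events.

  • RuntimeError – if model.convert_to_format is not 'jax'.
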
Parameters:
  • method (str, optional) – ‘RK45’ (default) uses jax.experimental.ode.odeint; ‘BDF’ uses the custom jax_bdf_integrate (see jax_bdf_integrate.py for details).

  • root_method (str, optional) – Method used to calculate consistent initial conditions. By default this is the Newton chord method internal to the JAX BDF solver; otherwise, choose from the default options listed in the documentation for pybamm.BaseSolver.

  • rtol (float, optional) – The relative tolerance for the solver (default is 1e-6).

  • atol (float, optional) – The absolute tolerance for the solver (default is 1e-6).

  • extrap_tol (float, optional) – The tolerance used to decide whether extrapolation has occurred (default is 0).

  • extra_options (dict, optional) – Any options to pass to the solver. Please consult JAX documentation for details.

Extends: pybamm.solvers.base_solver.BaseSolver
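To illustrate how rtol and atol combine, the sketch below implements the mixed error test used by many adaptive ODE solvers (for example SciPy's explicit Runge-Kutta methods): each component of a local error estimate is divided by atol + rtol * |y|, and a step is accepted when the RMS of the result is at most 1. Whether each JAX solver uses exactly this norm is an assumption here, and scaled_rms_error is a hypothetical helper, not part of PyBaMM.

```python
import numpy as np


def scaled_rms_error(err, y_old, y_new, rtol=1e-6, atol=1e-6):
    """Hypothetical helper: RMS norm of a local error estimate, with each
    component scaled by atol + rtol * max(|y_old|, |y_new|)."""
    scale = atol + rtol * np.maximum(np.abs(y_old), np.abs(y_new))
    return float(np.sqrt(np.mean((err / scale) ** 2)))


y_old = np.array([1.0, 1e-8])
y_new = np.array([0.99, 1e-8])

# Scaled errors are 0.25 and ~0.5, so the step is accepted (norm <= 1).
norm = scaled_rms_error(np.array([5e-7, 5e-7]), y_old, y_new)
accepted = norm <= 1.0

# A larger error in the first component pushes the norm above 1 (rejected).
rejected_norm = scaled_rms_error(np.array([5e-6, 0.0]), y_old, y_new)
```

For the second component, whose magnitude (1e-8) is far below atol, the scale is essentially atol alone: the absolute tolerance sets the error floor for states near zero.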

create_solve(model, t_eval)

Return a compiled JAX function that solves an ODE model with input arguments.

Parameters:
  • model (pybamm.BaseModel) – The model whose solution to calculate.

  • t_eval (numpy.array, size (k,)) – The times at which to compute the solution

Returns:

A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.

Return type:

function
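The f(inputs) contract can be illustrated without PyBaMM or JAX. In the sketch below, make_solve is a hypothetical stand-in for create_solve: the right-hand side and t_eval are fixed when the solver function is built, and each call supplies only the inputs dict. A plain-Python, fixed-step classical RK4 integrator stands in for the JAX-compiled solve, so nothing here is actually compiled.

```python
import math
from typing import Callable, Dict, List

Rhs = Callable[[float, float, Dict[str, float]], float]


def make_solve(rhs: Rhs, y0: float, t_eval: List[float],
               substeps: int = 100) -> Callable[[Dict[str, float]], List[float]]:
    """Hypothetical analogue of create_solve: fix the model (rhs, y0) and
    t_eval once, and return a function whose only argument is inputs."""

    def solve(inputs: Dict[str, float]) -> List[float]:
        y = y0
        ys = [y0]  # solution at t_eval[0]
        for t_start, t_end in zip(t_eval[:-1], t_eval[1:]):
            h = (t_end - t_start) / substeps
            for i in range(substeps):
                t = t_start + i * h
                # Classical fourth-order Runge-Kutta step
                k1 = rhs(t, y, inputs)
                k2 = rhs(t + h / 2, y + h * k1 / 2, inputs)
                k3 = rhs(t + h / 2, y + h * k2 / 2, inputs)
                k4 = rhs(t + h, y + h * k3, inputs)
                y += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
            ys.append(y)  # solution at t_end
        return ys

    return solve


# Exponential decay dy/dt = -k y, with the rate k supplied as an input.
decay = make_solve(lambda t, y, inp: -inp["k"] * y, y0=1.0, t_eval=[0.0, 0.5, 1.0])
fast = decay({"k": 2.0})
slow = decay({"k": 0.5})
exact_fast = math.exp(-2.0)  # exact y(1) for k = 2
exact_slow = math.exp(-0.5)  # exact y(1) for k = 0.5
```

Because only inputs changes between calls, the same function can be evaluated for many parameter values without rebuilding it; with JAX, this is presumably what lets one compiled function be reused across inputs.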

get_solve(model, t_eval)

Return a compiled JAX function that solves an ODE model with input arguments.

Parameters:
  • model (pybamm.BaseModel) – The model whose solution to calculate.

  • t_eval (numpy.array, size (k,)) – The times at which to compute the solution

Returns:

A function with signature f(inputs), where inputs is a dict containing any input parameters to pass to the model when solving.

Return type:

function

pybamm.jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-06, atol=1e-06, mass=None)

Backward differentiation formula (BDF) implicit multistep integrator. The basic algorithm is derived in Byrne and Hindmarsh [1]. This particular implementation follows the MATLAB routine ode15s described in Shampine and Reichelt [2]: it uses the numerical differentiation formulas (NDF) for improved stability, with correspondingly different error constants, and calculates the Jacobian J at (t_{n+1}, y^0_{n+1}), where y^0_{n+1} is the predicted solution at t_{n+1}. The code was based on the SciPy implementation (Virtanen et al. [3]), which also mainly follows Shampine and Reichelt [2] but uses the more standard Jacobian update.
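For reference, the NDF family mentioned above can be written as follows, in the form given by Shampine and Reichelt [2] for y' = f(t, y) without a mass matrix. Here ∇ is the backward difference, h the step size, k the order and κ the NDF coefficient; these symbols are introduced for this illustration and are not defined elsewhere in this documentation. Setting κ = 0 recovers the BDF of order k, and the Newton (chord) iterations use the Jacobian of f evaluated at the predictor (t_{n+1}, y^0_{n+1}).

```latex
\sum_{m=1}^{k} \frac{1}{m}\, \nabla^{m} y_{n+1}
  - h\, f\!\left(t_{n+1}, y_{n+1}\right)
  - \kappa\, \gamma_k \left( y_{n+1} - y^{0}_{n+1} \right) = 0,
\qquad
\gamma_k = \sum_{j=1}^{k} \frac{1}{j},
\qquad
y^{0}_{n+1} = \sum_{m=0}^{k} \nabla^{m} y_{n}
```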

Parameters:
  • func (callable) – function to evaluate the time derivative of the solution y at time t as func(y, t, *args), producing the same shape/structure as y0.

  • y0 (ndarray) – initial state vector

  • t_eval (ndarray) – time points to evaluate the solution, has shape (m,)

  • args (optional) – Tuple of additional arguments for func, which must be arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types.

  • rtol (float, optional) – relative tolerance for the solver

  • atol (float, optional) – absolute tolerance for the solver

  • mass (ndarray, optional) – diagonal of the mass matrix, with shape (n,)

Returns:

y – calculated state vector at each of the m time points

Return type:

ndarray with shape (n, m)
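To make the ingredients above concrete (an implicit step, a Newton chord iteration with a Jacobian frozen at the predicted state, and a diagonal mass matrix), here is a deliberately simplified NumPy sketch of a fixed-step, first-order BDF (backward Euler) for mass * y' = func(y, t, *args). bdf1_integrate and fd_jacobian are hypothetical helpers that mirror the documented func(y, t, *args) and mass conventions and the (n, m) return shape. Unlike jax_bdf_integrate, the sketch has no variable order or step size, no NDF correction and no error control, uses the previous value as its predictor, and recomputes a finite-difference Jacobian at every step.

```python
import numpy as np


def fd_jacobian(func, y, t, args, eps=1e-8):
    """Forward-difference approximation of d func / d y at (y, t)."""
    f0 = np.asarray(func(y, t, *args), dtype=float)
    J = np.empty((f0.size, y.size))
    for j in range(y.size):
        step = np.zeros_like(y)
        step[j] = eps * max(1.0, abs(y[j]))
        J[:, j] = (np.asarray(func(y + step, t, *args)) - f0) / step[j]
    return J


def bdf1_integrate(func, y0, t_eval, *args, mass=None, substeps=200,
                   newton_iters=10, tol=1e-10):
    """Hypothetical fixed-step backward Euler for diag(mass) @ y' = func(y, t, *args).
    Zero entries in mass mark algebraic equations. Returns shape (n, m)."""
    y = np.asarray(y0, dtype=float)
    d = np.ones(y.size) if mass is None else np.asarray(mass, dtype=float)
    out = np.empty((y.size, len(t_eval)))
    out[:, 0] = y
    for i in range(1, len(t_eval)):
        h = (t_eval[i] - t_eval[i - 1]) / substeps
        for s in range(substeps):
            t_new = t_eval[i - 1] + (s + 1) * h
            y_pred = y.copy()  # simplest predictor: the previous value
            # Chord method: Jacobian evaluated once at the predictor, then frozen
            A = np.diag(d) - h * fd_jacobian(func, y_pred, t_new, args)
            y_new = y_pred
            for _ in range(newton_iters):
                # Backward Euler residual: d * (y_new - y) - h * func(y_new, t_new)
                residual = d * (y_new - y) - h * np.asarray(func(y_new, t_new, *args))
                dy = np.linalg.solve(A, residual)
                y_new = y_new - dy
                if np.linalg.norm(dy) < tol:
                    break
            y = y_new
        out[:, i] = y
    return out


# Semi-explicit DAE: y0' = -k * y0 (differential), 0 = 2 * y0 - y1 (algebraic).
def rhs(y, t, k):
    return np.array([-k * y[0], 2.0 * y[0] - y[1]])


sol = bdf1_integrate(rhs, np.array([1.0, 2.0]), [0.0, 0.5, 1.0], 1.0,
                     mass=np.array([1.0, 0.0]))
```

Being first order with a fixed step, this sketch is only accurate to roughly the step size (about 5e-4 at t = 1 here), which is why production BDF codes vary both order and step.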

References