IDAKLU-JAX Interface

class pybamm.IDAKLUJax(solver, model, t_eval, output_variables=None, calculate_sensitivities=True)

JAX wrapper for the IDAKLU solver.

Objects of this class should be created via an IDAKLUSolver object, using its jaxify() method.

Log information for this module is available via the logger named 'pybamm.solvers.idaklu_jax'.
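To see these messages, the logger can be configured with Python's standard logging module. The following is a minimal sketch that relies only on the standard library and the logger name given above; choosing DEBUG as the level is purely illustrative. If PyBaMM resets this logger's level when it is imported, apply the setting after import pybamm.

```python
import logging

# Attach a stderr handler to the root logger so propagated records are printed
logging.basicConfig(level=logging.WARNING)

# Logger name taken from the documentation above
idaklu_jax_logger = logging.getLogger("pybamm.solvers.idaklu_jax")

# Lower the threshold for this module only; other loggers keep their levels
idaklu_jax_logger.setLevel(logging.DEBUG)
```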

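Parameters:
  • solver (pybamm.IDAKLUSolver) – The IDAKLU solver object to be wrapped

  • model (pybamm.BaseModel) – The model to be solved

  • t_eval (numeric type) – The times at which to compute the solution

  • output_variables (list of str, optional) – The variables to be returned. If None, the variables in the model are used.

  • calculate_sensitivities (bool, optional) – Whether to calculate sensitivities. Default is True.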

get_jaxpr()

Returns a JAX expression representing the IDAKLU-wrapped solver object

Returns:

A JAX expression with the following call signature:

f(t, inputs=None)

where:

  • t (float | np.ndarray) – Time sample or vector of time samples

  • inputs (dict, optional) – Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}

Return type:

Callable
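The call signature can be illustrated without running a solver. The sketch below defines a hypothetical stand-in, f, that has the same f(t, inputs=None) signature. It is an analytic toy and not a PyBaMM model. Its output layout (always 2-D, with one row per time sample and one column per output variable) is an assumption made for illustration and is not a documented property of get_jaxpr().

```python
import jax
import jax.numpy as jnp

# Hypothetical output ordering and default inputs for the toy stand-in
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]
DEFAULT_INPUTS = {"Current function [A]": 0.222, "Separator porosity": 0.3}


def f(t, inputs=None):
    """Toy stand-in with the f(t, inputs=None) signature (not a PyBaMM solver)."""
    inputs = DEFAULT_INPUTS if inputs is None else inputs
    current = inputs["Current function [A]"]
    porosity = inputs["Separator porosity"]
    t = jnp.atleast_1d(t)  # accept a scalar time or a vector of times
    # Toy voltage: linear drop in time, scaled by current and porosity
    voltage = 4.2 - 0.1 * current * t / 3600.0 / porosity
    current_out = current * jnp.ones_like(t)
    # One row per time sample, one column per output variable (assumed layout)
    return jnp.stack([voltage, current_out], axis=1)


t_eval = jnp.linspace(0.0, 3600.0, 5)
data = f(t_eval)  # vector of times, default inputs
single = f(1800.0, {"Current function [A]": 0.5, "Separator porosity": 0.3})  # scalar time
jitted = jax.jit(f)  # the stand-in composes with jax.jit
```

Because the stand-in is built entirely from jax.numpy operations, it can be passed to jax.jit, jax.jacfwd and similar transforms, much as the get_var examples below do with the callable returned by get_jaxpr().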

get_var(*args)

Helper function to extract a single variable

Isolates a single variable from the model output. Can be called on a JAX expression (which returns a JAX expression), or on a numeric (np.ndarray) object (which returns a slice of the output).

Example call using default JAX expression, returns a JAX expression:

f = idaklu_jax.get_var("Voltage [V]")
data = f(t, inputs=None)

Example call using a custom function, returns a JAX expression:

f = idaklu_jax.get_var(jax.jit(f), "Voltage [V]")
data = f(t, inputs=None)

Example call to slice a matrix, returns an np.ndarray:

data = idaklu_jax.get_var(
    jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
    'Voltage [V]'
)
Parameters:
  • f (Callable | np.ndarray, optional) – Expression or array from which to extract the target variable

  • varname (str) – The name of the variable to extract

Returns:

  • Callable – If called with a JAX expression, returns a JAX expression with the following call signature:

    f(t, inputs=None)

    where:

      • t (float | np.ndarray) – Time sample or vector of time samples

      • inputs (dict, optional) – Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}

  • np.ndarray – If called with a numeric (np.ndarray) object, returns a slice of the output corresponding to the target variable.
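The two calling modes described above can be mimicked with a short sketch. Both get_var_sketch and toy_model are hypothetical helpers and are not part of PyBaMM. The column-per-variable layout they assume may differ from the real output. The sketch only demonstrates the dispatch idea: a callable gets wrapped, and an array gets sliced.

```python
import numpy as np

# Hypothetical output ordering; the column order is an assumption
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]


def get_var_sketch(f, varname):
    """Isolate one variable from a callable or from an already-computed array."""
    index = OUTPUT_VARIABLES.index(varname)
    if callable(f):
        # Callable input: return a wrapper that keeps the f(t, inputs=None) signature
        def isolated(t, inputs=None):
            return np.asarray(f(t, inputs))[:, index]

        return isolated
    # Numeric input: return the column for the requested variable
    return np.asarray(f)[:, index]


def toy_model(t, inputs=None):
    # Hypothetical stand-in: one row per time, one column per variable
    t = np.atleast_1d(t)
    return np.stack([4.2 - 1e-4 * t, 0.222 * np.ones_like(t)], axis=1)


t = np.array([0.0, 1800.0, 3600.0])
voltage_fn = get_var_sketch(toy_model, "Voltage [V]")  # returns a callable
voltage = voltage_fn(t)
current = get_var_sketch(toy_model(t), "Current [A]")  # returns an array slice
```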

get_vars(*args)

Helper function to extract a list of variables

Isolates a list of variables from the model output. Can be called on a JAX expression (which returns a JAX expression), or on a numeric (np.ndarray) object (which returns a slice of the output).

Example call using default JAX expression, returns a JAX expression:

f = idaklu_jax.get_vars(["Voltage [V]", "Current [A]"])
data = f(t, inputs=None)

Example call using a custom function, returns a JAX expression:

f = idaklu_jax.get_vars(jax.jit(f), ["Voltage [V]", "Current [A]"])
data = f(t, inputs=None)

Example call to slice a matrix, returns an np.ndarray:

data = idaklu_jax.get_vars(
    jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
    ["Voltage [V]", "Current [A]"]
)
Parameters:
  • f (Callable | np.ndarray, optional) – Expression or array from which to extract the target variables

  • varname (list of str) – The names of the variables to extract

Returns:

  • Callable – If called with a JAX expression, returns a JAX expression with the following call signature:

    f(t, inputs=None)

    where:

      • t (float | np.ndarray) – Time sample or vector of time samples

      • inputs (dict, optional) – Dictionary of input values, e.g. {"Current function [A]": 0.222, "Separator porosity": 0.3}

  • np.ndarray – If called with a numeric (np.ndarray) object, returns a slice of the output corresponding to the target variables.
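The matrix-slicing examples take their input from jax.jacfwd(f, argnums=1)(t_eval, inputs). Since inputs is a dict, JAX returns the Jacobian as a dict keyed by input name, and each entry has the same shape as the output of f. Indexing it with ['Current function [A]'] therefore yields a numeric array that can then be sliced. The sketch below demonstrates this with a hypothetical toy model and a hypothetical get_vars_sketch helper. Neither belongs to PyBaMM, and the column layout is an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical output ordering; rows = times, columns = variables (assumed layout)
OUTPUT_VARIABLES = ["Voltage [V]", "Current [A]"]


def f(t, inputs):
    # Toy stand-in (not a PyBaMM model)
    current = inputs["Current function [A]"]
    t = jnp.atleast_1d(t)
    voltage = 4.2 - 0.1 * current * t / 3600.0
    return jnp.stack([voltage, current * jnp.ones_like(t)], axis=1)


def get_vars_sketch(array, varnames):
    # Select the columns for the requested variables, in the requested order
    indices = jnp.array([OUTPUT_VARIABLES.index(name) for name in varnames])
    return array[:, indices]


t_eval = jnp.linspace(0.0, 3600.0, 5)
inputs = {"Current function [A]": 0.222}

# Forward-mode Jacobian with respect to the inputs dict (argnums=1): the result
# is a dict keyed by input name, each entry shaped like the output of f
jac = jax.jacfwd(f, argnums=1)(t_eval, inputs)
d_current = jac["Current function [A]"]  # shape (5, 2)
sens = get_vars_sketch(d_current, ["Voltage [V]", "Current [A]"])
```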

jax_grad(t: ndarray = None, inputs: dict | None = None, output_variables: list[str] | None = None)

Helper function to compute the gradient of a jaxified expression

Returns a numeric (np.ndarray) object (not a JAX expression). Parameters are inferred from the base object, but can be overridden.

Parameters:
  • t (float | np.ndarray, optional) – Time sample or vector of time samples. If None, inferred from the base object.

  • inputs (dict, optional) – Dictionary of input values. If None, inferred from the base object.

  • output_variables (list of str, optional) – The variables to be returned. If None, the variables in the model are used.
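Standard JAX transforms can express gradients with respect to the inputs at each time sample. Here, jax.grad differentiates a scalar-valued function of a single time sample, and jax.vmap maps it across the time vector while every sample shares the same inputs. This is a minimal sketch built on a hypothetical scalar toy function. It shows one way to obtain such gradients and is not necessarily how jax_grad is implemented. It also returns a dict of arrays keyed by input name, not a single np.ndarray.

```python
import jax
import jax.numpy as jnp


def voltage(t, inputs):
    # Hypothetical scalar-valued toy for a single time sample (not a PyBaMM model)
    current = inputs["Current function [A]"]
    porosity = inputs["Separator porosity"]
    return 4.2 - 0.1 * current * t / 3600.0 / porosity


t_eval = jnp.linspace(0.0, 3600.0, 5)
inputs = {"Current function [A]": 0.222, "Separator porosity": 0.3}

# Gradient w.r.t. the inputs dict (argnums=1), mapped over the time axis only;
# in_axes=(0, None) shares the same inputs across all time samples
grad_fn = jax.vmap(jax.grad(voltage, argnums=1), in_axes=(0, None))
grads = grad_fn(t_eval, inputs)  # dict: input name -> array of shape (5,)
```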

jax_value(t: ndarray = None, inputs: dict | None = None, output_variables: list[str] | None = None)

Helper function to compute the value of a jaxified expression

Returns a numeric (np.ndarray) object (not a JAX expression). Parameters are inferred from the base object, but can be overridden.

Parameters:
  • t (float | np.ndarray, optional) – Time sample or vector of time samples. If None, inferred from the base object.

  • inputs (dict, optional) – Dictionary of input values. If None, inferred from the base object.

  • output_variables (list of str, optional) – The variables to be returned. If None, the variables in the model are used.
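In the same way, jax.vmap can evaluate a scalar-valued per-time function over a vector of time samples. The sketch below relies on a hypothetical toy voltage function. It only illustrates the evaluation pattern and is not the implementation of jax_value.

```python
import jax
import jax.numpy as jnp


def voltage(t, inputs):
    # Hypothetical scalar-valued toy for a single time sample (not a PyBaMM model)
    current = inputs["Current function [A]"]
    return 4.2 - 0.1 * current * t / 3600.0


t_eval = jnp.linspace(0.0, 3600.0, 5)
inputs = {"Current function [A]": 0.222}

# Map over the time axis only; every sample shares the same inputs dict
values = jax.vmap(voltage, in_axes=(0, None))(t_eval, inputs)
```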

jaxify(model, t_eval, *, output_variables=None, calculate_sensitivities=True)

JAXify the model and solver

Creates a JAX expression representing the IDAKLU-wrapped solver object.

Parameters:
  • model (pybamm.BaseModel) – The model to be solved

  • t_eval (numeric type, optional) – The times at which to compute the solution. If None, the times in the model are used.

  • output_variables (list of str, optional) – The variables to be returned. If None, the variables in the model are used.

  • calculate_sensitivities (bool, optional) – Whether to calculate sensitivities. Default is True.