
IDAKLU-JAX interface

The IDAKLU-JAX interface requires PyBaMM to be installed with the optional JAX solver enabled (pip install "pybamm[jax]") and Python 3.9 or newer.

PyBaMM provides two mechanisms to interface battery models with JAX. The first (JaxSolver) implements PyBaMM models directly in native JAX, and as such provides the greatest flexibility. However, such models can be very slow to compile, especially on their first run, and can require large amounts of memory.

The second (the IDAKLU-JAX interface) instead provides a JAX-compliant interface to the IDAKLU solver. IDAKLU is a fast, compiled solver based on SUNDIALS. By exposing the IDAKLU solver to JAX, we provide a fast solver capable of interfacing with third-party JAX-compatible software libraries, such as numpyro.

Despite the apparent advantages, there are some limitations to this approach. The most notable is that model derivatives are limited to first-order (i.e. sensitivities), since the IDAKLU solver is not capable of auto-differentiation.

Set up a basic DFN model

To demonstrate the use of the IDAKLU-JAX interface, we first set up a basic model, choosing the DFN model in this case. We provide two inputs to the model and specify a list of variables of interest (output_variables). Specifying output_variables is strongly recommended to reduce the computational load, while inputs are only required when derivatives are to be taken.

[1]:
%pip install "pybamm[jax]" -q    # install PyBaMM with JAX support if it is not installed
import pybamm
import time
import numpy as np
import jax
import jax.numpy as jnp
Note: you may need to restart the kernel to use updated packages.
[2]:
# We will want to differentiate our model, so let's define two input parameters
inputs = {
    "Current function [A]": 0.222,
    "Separator porosity": 0.3,
}

# Set-up the model
model = pybamm.lithium_ion.DFN()
geometry = model.default_geometry
param = model.default_parameter_values
param.update({key: "[input]" for key in inputs.keys()})
param.process_geometry(geometry)
param.process_model(model)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: 20, var.x_s: 20, var.x_p: 20, var.r_n: 10, var.r_p: 10}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

# Use a short time-vector for this example, and declare which variables to track
t_eval = np.linspace(0, 360, 10)
output_variables = [
    "Voltage [V]",
    "Current [A]",
    "Time [min]",
]

# Create the IDAKLU Solver object
idaklu_solver = pybamm.IDAKLUSolver(
    rtol=1e-6,
    atol=1e-6,
    output_variables=output_variables,
)

Next, we jaxify the IDAKLU solver, passing the same arguments we would use for a standard IDAKLU solve. The only difference is that the jaxify() function returns an IDAKLUJax object instead of a Solution object. We keep track of this object and can request a JAX expression from it using the get_jaxpr() method, as below.

[3]:
# This is how we would normally perform a solve using IDAKLU
sim = idaklu_solver.solve(
    model,
    t_eval,
    inputs=inputs,
    calculate_sensitivities=True,
)

# Instead, we Jaxify the IDAKLU solver using similar arguments...
jax_solver = idaklu_solver.jaxify(
    model,
    t_eval,
)

# ... and then obtain a JAX expression for the solve
f = jax_solver.get_jaxpr()
print(f"JAX expression: {f}")
JAX expression: <function IDAKLUJax._jaxify.<locals>.f at 0x132d50f40>

The JAX expression (named f in our example) is a function that can be used and evaluated like any other native JAX expression. This means that it can be included in broader JAX expressions and can even be JIT compiled. The only limitations are that: 1) derivatives cannot be taken beyond first-order, which is the limit of our IDAKLU solver implementation, and 2) output_variables must be specified either at the IDAKLUSolver stage or at the jaxify stage (you can create many jaxified expressions from a single solver object).
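For example, the same solver object can be jaxified a second time with a narrower set of outputs. The following is a hedged sketch: we assume here that jaxify accepts an output_variables keyword mirroring the IDAKLUSolver option used earlier:

# Sketch (assumed keyword): a second jaxified expression from the same solver,
# restricting outputs at the jaxify stage rather than the solver stage
jax_solver_voltage = idaklu_solver.jaxify(
    model,
    t_eval,
    output_variables=["Voltage [V]"],
)
f_voltage = jax_solver_voltage.get_jaxpr()  # evaluates like f, returning only Voltage [V]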

Here is the most basic usage example:

[4]:
# Print all output variables, evaluated over a given time vector
data = f(t_eval, inputs)
print(data)
[[3.81930814e+000 2.22000000e-001 4.95024341e-316]
 [3.81346107e+000 2.22000000e-001 6.66666667e-001]
 [3.81080090e+000 2.22000000e-001 1.33333333e+000]
 [3.80885531e+000 2.22000000e-001 2.00000000e+000]
 [3.80714541e+000 2.22000000e-001 2.66666667e+000]
 [3.80552362e+000 2.22000000e-001 3.33333333e+000]
 [3.80393909e+000 2.22000000e-001 4.00000000e+000]
 [3.80237338e+000 2.22000000e-001 4.66666667e+000]
 [3.80081962e+000 2.22000000e-001 5.33333333e+000]
 [3.79927489e+000 2.22000000e-001 6.00000000e+000]]

Here we see an (N×3) matrix, where N is the number of time samples in t_eval and the three columns correspond to our three output_variables. We can evaluate the expression at any point within our time span (e.g. f(0.0, inputs)), or at multiple points (such as the full range of t_eval, as in our example). To help isolate output variables, the IDAKLU-JAX interface provides several helper functions. Below we demonstrate isolating a single variable using the get_var helper; you can also isolate multiple variables, provided as a list, using the get_vars helper function.
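For illustration, the single-point evaluation mentioned above looks like this (a minimal sketch; any time within the solved span works):

# Evaluate all output variables at a single time point
point = f(0.0, inputs)
print(point)  # one value per output variable at t = 0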

[5]:
# Isolate a single variable
data = jax_solver.get_var("Voltage [V]")(t_eval, inputs)
print(f"Isolating a single variable returns an array of shape {data.shape}")
print(data)

# Isolate two variables from the solver
data = jax_solver.get_vars(
    [
        "Voltage [V]",
        "Current [A]",
    ],
)(t_eval, inputs)
print(f"\nIsolating two variables returns an array of shape {data.shape}")
print(data)
Isolating a single variable returns an array of shape (10,)
[3.81930814 3.81346107 3.8108009  3.80885531 3.80714541 3.80552362
 3.80393909 3.80237338 3.80081962 3.79927489]

Isolating two variables returns an array of shape (10, 2)
[[3.81930814 0.222     ]
 [3.81346107 0.222     ]
 [3.8108009  0.222     ]
 [3.80885531 0.222     ]
 [3.80714541 0.222     ]
 [3.80552362 0.222     ]
 [3.80393909 0.222     ]
 [3.80237338 0.222     ]
 [3.80081962 0.222     ]
 [3.79927489 0.222     ]]

As with any JAX expression, we can create new expressions by encapsulating them in outer functions (as further demonstrated below). The method jax_solver.get_var() does this for you by encapsulating f with a function that isolates a given variable of interest. We then evaluate that new expression by passing our usual arguments (t_eval, inputs).
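To illustrate, a hand-rolled equivalent of get_var is just a thin wrapper that slices one column out of f's output (an illustrative sketch only, not the library's actual implementation):

# Illustrative only: roughly what jax_solver.get_var("Voltage [V]") provides
def voltage(t, inputs):
    out = f(t, inputs)  # (N, 3): one column per entry of output_variables
    return out[:, output_variables.index("Voltage [V]")]

print(voltage(t_eval, inputs))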

To compute the Jacobian matrix (the matrix of derivatives of the output variables with respect to each input parameter), make use of the forward-mode (jax.jacfwd) or reverse-mode (jax.jacrev) Jacobian functions.

When calling these functions, note that argnums=1 signifies that we are taking the Jacobian with respect to the second argument (indexing from 0: inputs). Since inputs is a dictionary of input parameters, the result will also be a dictionary of derivatives, keyed by input parameter. The two methods (jacfwd and jacrev) produce the same output; only their mode of differentiation differs. In general, the forward method tends to be slightly faster than the reverse method for our IDAKLU implementation.

[6]:
# Calculate the Jacobian matrix (via forward autodiff)
t_start = time.time()
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
print(f"Jacobian forward method ran in {time.time()-t_start:0.3} secs")
print(out)

# Calculate Jacobian matrix (via backward autodiff)
t_start = time.time()
out = jax.jacrev(f, argnums=1)(t_eval, inputs)
print(f"\nJacobian reverse method ran in {time.time()-t_start:0.3} secs")
print(out)
Jacobian forward method ran in 0.125 secs
{'Current function [A]': Array([[-0.13643792,  1.        ,  0.        ],
       [-0.16400861,  1.        ,  0.        ],
       [-0.17630142,  1.        ,  0.        ],
       [-0.18509421,  1.        ,  0.        ],
       [-0.19273301,  1.        ,  0.        ],
       [-0.19993145,  1.        ,  0.        ],
       [-0.20692727,  1.        ,  0.        ],
       [-0.21380043,  1.        ,  0.        ],
       [-0.22057579,  1.        ,  0.        ],
       [-0.2272616 ,  1.        ,  0.        ]], dtype=float64), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],
       [0.00797   , 0.        , 0.        ],
       [0.0095281 , 0.        , 0.        ],
       [0.01024868, 0.        , 0.        ],
       [0.01053737, 0.        , 0.        ],
       [0.0106461 , 0.        , 0.        ],
       [0.01068649, 0.        , 0.        ],
       [0.01070164, 0.        , 0.        ],
       [0.01070816, 0.        , 0.        ],
       [0.01071172, 0.        , 0.        ]], dtype=float64)}

Jacobian reverse method ran in 0.196 secs
{'Current function [A]': Array([[-0.13643792,  1.        ,  0.        ],
       [-0.16400861,  1.        ,  0.        ],
       [-0.17630142,  1.        ,  0.        ],
       [-0.18509421,  1.        ,  0.        ],
       [-0.19273301,  1.        ,  0.        ],
       [-0.19993145,  1.        ,  0.        ],
       [-0.20692727,  1.        ,  0.        ],
       [-0.21380043,  1.        ,  0.        ],
       [-0.22057579,  1.        ,  0.        ],
       [-0.2272616 ,  1.        ,  0.        ]],      dtype=float64, weak_type=True), 'Separator porosity': Array([[0.00579553, 0.        , 0.        ],
       [0.00797   , 0.        , 0.        ],
       [0.0095281 , 0.        , 0.        ],
       [0.01024868, 0.        , 0.        ],
       [0.01053737, 0.        , 0.        ],
       [0.0106461 , 0.        , 0.        ],
       [0.01068649, 0.        , 0.        ],
       [0.01070164, 0.        , 0.        ],
       [0.01070816, 0.        , 0.        ],
       [0.01071172, 0.        , 0.        ]],      dtype=float64, weak_type=True)}

To extract the relevant data vector from the above expression, we can again make use of the get_var() helper function, which can also take numpy arrays as input, for example:

[7]:
# Isolate the derivative of Voltage with respect to the Current function:
out = jax.jacfwd(f, argnums=1)(t_eval, inputs)
data = jax_solver.get_var(out["Current function [A]"], "Voltage [V]")
print(data)
[-0.13643792 -0.16400861 -0.17630142 -0.18509421 -0.19273301 -0.19993145
 -0.20692727 -0.21380043 -0.22057579 -0.2272616 ]

The gradient (grad) function, on the other hand, requires the underlying function to return a scalar value. The function must therefore be called separately for each time sample, and can only be evaluated for one output variable at a time. We can satisfy these restrictions with our JAX expression f through use of the get_var and vmap functions (the latter providing vector-mapping over time).

[8]:
# Example evaluation using the `grad` function
t_start = time.time()
data = jax.vmap(
    jax.grad(
        jax_solver.get_var("Voltage [V]"),
        argnums=1,  # take derivative with respect to `inputs`
    ),
    in_axes=(0, None),  # map time over the 0th dimension and do not map inputs
)(t_eval, inputs)
print(f"Gradient method ran in {time.time()-t_start:0.3} secs")
print(data)
Gradient method ran in 0.105 secs
{'Current function [A]': Array([-0.13643792, -0.16400861, -0.17630142, -0.18509421, -0.19273301,
       -0.19993145, -0.20692727, -0.21380043, -0.22057579, -0.2272616 ],      dtype=float64), 'Separator porosity': Array([0.00579553, 0.00797   , 0.0095281 , 0.01024868, 0.01053737,
       0.0106461 , 0.01068649, 0.01070164, 0.01070816, 0.01071172],      dtype=float64)}
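As a quick sanity check (an illustrative aside, not part of the original workflow), these per-input gradient vectors should match the corresponding "Voltage [V]" column of the Jacobian computed earlier:

# Cross-check: the vmapped gradients should equal the Voltage [V] column
# of the Jacobian for each input parameter
jac = jax.jacfwd(f, argnums=1)(t_eval, inputs)
for name in inputs:
    column = jax_solver.get_var(jac[name], "Voltage [V]")
    assert np.allclose(data[name], column)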

A use-case example

As a use-case example, consider a fitting procedure where we want to compare simulation data against some experimental data. We achieve this by computing the sum of squared errors (SSE) between the two. Many fitting procedures converge more quickly (with fewer iterations) if both the value and the gradient of the SSE function are provided. By making use of JAX expressions we can derive these effortlessly.

Note: We do not need to map over time when calling value_and_grad in this example as the sse function returns a scalar (despite taking vector inputs).

[9]:
# Simulate some experimental data using our original parameter settings
data = sim["Voltage [V]"](t_eval)


# Sum-of-squared errors
def sse(t, inputs):
    modelled = jax_solver.get_var("Voltage [V]")(t, inputs)
    return jnp.sum((modelled - data) ** 2)


# Provide some predicted model inputs (these could come from a fitting procedure)
inputs_pred = {
    "Current function [A]": 0.150,
    "Separator porosity": 0.333,
}

# Get the value and gradient of the SSE function
t_start = time.time()
value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, inputs_pred)
print(f"Value and gradient computed in {time.time()-t_start:0.3} secs")
print("SSE value: ", value)
print("SSE gradient (wrt each input): ", gradient)
Value and gradient computed in 0.095 secs
SSE value:  0.0020846163677995366
SSE gradient (wrt each input):  {'Current function [A]': array(-0.05775429), 'Separator porosity': array(0.00146983)}
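To connect this to an actual optimiser, the value-and-gradient pair can be fed straight into a gradient-based routine. The following is a hedged sketch (it assumes scipy is available; the flat-vector packing helper is our own, not part of PyBaMM):

from scipy.optimize import minimize

param_names = list(inputs_pred.keys())


def sse_flat(x):
    # Re-pack the optimiser's flat parameter vector into the inputs dictionary
    candidate = dict(zip(param_names, x))
    value, gradient = jax.value_and_grad(sse, argnums=1)(t_eval, candidate)
    # With jac=True, scipy expects the objective to return (value, gradient)
    return float(value), np.array([gradient[name] for name in param_names])


result = minimize(
    sse_flat,
    x0=np.array([inputs_pred[name] for name in param_names]),
    jac=True,  # the objective supplies its own gradient
    method="L-BFGS-B",
)
print(result.x)  # fitted values for the two input parameters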

All of the above expressions can be JIT compiled (on CPU) using the jax.jit directive. Practically, this routes back through the Python interface of the IDAKLU solver, so it is provided only to afford maximum downstream compatibility (where JIT may be called outside of the user’s control).
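For instance (a minimal sketch, reusing the voltage expression from above):

# JIT-compile the voltage expression; with the IDAKLU backend this routes back
# through the Python-level solver call, so expect compatibility rather than speed
f_jit = jax.jit(jax_solver.get_var("Voltage [V]"))
print(f_jit(t_eval, inputs))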