Source code for pybamm.expression_tree.operations.jacobian

#
# Calculate the Jacobian of a symbol
#
from __future__ import annotations
import pybamm


[docs] class Jacobian: """ Helper class to calculate the Jacobian of an expression. Parameters ---------- known_jacs: dict {variable ids -> :class:`pybamm.Symbol`} cached jacobians clear_domain: bool whether or not the Jacobian clears the domain (default True) """ def __init__( self, known_jacs: dict[pybamm.Symbol, pybamm.Symbol] | None = None, clear_domain: bool = True, ): self._known_jacs = known_jacs or {} self._clear_domain = clear_domain
[docs] def jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol) -> pybamm.Symbol: """ This function recurses down the tree, computing the Jacobian using the Jacobians defined in classes derived from pybamm.Symbol. E.g. the Jacobian of a 'pybamm.Multiplication' is computed via the product rule. If the Jacobian of a symbol has already been calculated, the stored value is returned. Note: The Jacobian is the derivative of a symbol with respect to a (slice of) a State Vector. Parameters ---------- symbol : :class:`pybamm.Symbol` The symbol to calculate the Jacobian of variable : :class:`pybamm.Symbol` The variable with respect to which to differentiate Returns ------- :class:`pybamm.Symbol` Symbol representing the Jacobian """ try: return self._known_jacs[symbol] except KeyError: jac = self._jac(symbol, variable) self._known_jacs[symbol] = jac return jac
def _jac(self, symbol: pybamm.Symbol, variable: pybamm.Symbol): """See :meth:`Jacobian.jac()`.""" if isinstance(symbol, pybamm.BinaryOperator): left, right = symbol.children # process children left_jac = self.jac(left, variable) right_jac = self.jac(right, variable) # _binary_jac defined in derived classes for specific rules jac = symbol._binary_jac(left_jac, right_jac) elif isinstance(symbol, pybamm.UnaryOperator): child_jac = self.jac(symbol.child, variable) # type: ignore[has-type] # _unary_jac defined in derived classes for specific rules jac = symbol._unary_jac(child_jac) elif isinstance(symbol, pybamm.Function): children_jacs: list[None | pybamm.Symbol] = [None] * len(symbol.children) for i, child in enumerate(symbol.children): children_jacs[i] = self.jac(child, variable) # _function_jac defined in function class jac = symbol._function_jac(children_jacs) elif isinstance(symbol, pybamm.Concatenation): children_jacs = [self.jac(child, variable) for child in symbol.children] if len(children_jacs) == 1: jac = children_jacs[0] else: jac = symbol._concatenation_jac(children_jacs) else: try: jac = symbol._jac(variable) except NotImplementedError as error: raise NotImplementedError( f"Cannot calculate Jacobian of symbol of type '{type(symbol)}'" ) from error # Jacobian by default removes the domain(s) if self._clear_domain: jac.clear_domains() return jac