#
# Write a symbol to python
#
from __future__ import annotations
import numbers
from collections import OrderedDict
from numpy.typing import ArrayLike
import numpy as np
import scipy.sparse
import pybamm
if pybamm.have_jax():
import jax
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform != "metal":
jax.config.update("jax_enable_x64", True)
class JaxCooMatrix:
"""
A sparse matrix in COO format, with internal arrays using jax device arrays
This matrix only has two operations supported, a multiply with a scalar, and a
dot product with a dense vector. It can also be converted to a dense 2D jax
device array
Parameters
----------
row: arraylike
1D array holding row indices of non-zero entries
col: arraylike
1D array holding col indices of non-zero entries
data: arraylike
1D array holding non-zero entries
shape: 2-element tuple (x, y)
where x is the number of rows, and y the number of columns of the matrix
"""
def __init__(
self, row: ArrayLike, col: ArrayLike, data: ArrayLike, shape: tuple[int, int]
):
if not pybamm.have_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
self.row = jax.numpy.array(row)
self.col = jax.numpy.array(col)
self.data = jax.numpy.array(data)
self.shape = shape
self.nnz = len(self.data)
def toarray(self):
"""convert sparse matrix to a dense 2D array"""
result = jax.numpy.zeros(self.shape, dtype=self.data.dtype)
return result.at[self.row, self.col].add(self.data)
def dot_product(self, b):
"""
dot product of matrix with a dense column vector b
Parameters
----------
b: jax device array
must have shape (n, 1)
"""
# assume b is a column vector
result = jax.numpy.zeros((self.shape[0], 1), dtype=b.dtype)
return result.at[self.row].add(self.data.reshape(-1, 1) * b[self.col])
def scalar_multiply(self, b: float):
"""
multiply of matrix with a scalar b
Parameters
----------
b: Number or 1 element jax device array
scalar value to multiply
"""
# assume b is a scalar or ndarray with 1 element
return JaxCooMatrix(self.row, self.col, (self.data * b).reshape(-1), self.shape)
def multiply(self, b):
"""
general matrix multiply not supported
"""
raise NotImplementedError
def __matmul__(self, b):
"""see self.dot_product"""
return self.dot_product(b)
def create_jax_coo_matrix(value: scipy.sparse):
"""
Creates a JaxCooMatrix from a scipy.sparse matrix
Parameters
----------
value: scipy.sparse matrix
the sparse matrix to be converted
"""
scipy_coo = value.tocoo()
row = jax.numpy.asarray(scipy_coo.row)
col = jax.numpy.asarray(scipy_coo.col)
data = jax.numpy.asarray(scipy_coo.data)
return JaxCooMatrix(row, col, data, value.shape)
def id_to_python_variable(symbol_id, constant=False):
"""
This function defines the format for the python variable names used in find_symbols
and to_python. Variable names are based on a nodes' id to make them unique
"""
if constant:
var_format = "const_{:05d}"
else:
var_format = "var_{:05d}"
# Need to replace "-" character to make them valid python variable names
return var_format.format(symbol_id).replace("-", "m")
def is_scalar(arg):
is_number = isinstance(arg, numbers.Number)
if is_number:
return True
else:
return np.all(np.array(arg.shape) == 1)
def find_symbols(
symbol: pybamm.Symbol,
constant_symbols: OrderedDict,
variable_symbols: OrderedDict,
output_jax=False,
):
"""
This function converts an expression tree to a dictionary of node id's and strings
specifying valid python code to calculate that nodes value, given y and t.
The function distinguishes between nodes that represent constant nodes in the tree
(e.g. a pybamm.Matrix), and those that are variable (e.g. subtrees that contain
pybamm.StateVector). The former are put in `constant_symbols`, the latter in
`variable_symbols`
Note that it is important that the arguments `constant_symbols` and
`variable_symbols` be an *ordered* dict, since the final ordering of the code lines
are important for the calculations. A dict is specified rather than a list so that
identical subtrees (which give identical id's) are not recalculated in the code
Parameters
----------
symbol : :class:`pybamm.Symbol`
The symbol or expression tree to convert
constant_symbol: collections.OrderedDict
The output dictionary of constant symbol ids to lines of code
variable_symbol: collections.OrderedDict
The output dictionary of variable (with y or t) symbol ids to lines of code
output_jax: bool
If True, only numpy and jax operations will be used in the generated code,
raises NotImplNotImplementedError if any SparseStack or Mat-Mat multiply
operations are used
"""
# constant symbols that are not numbers are stored in a list of constants, which are
# passed into the generated function constant symbols that are numbers are written
# directly into the code
if symbol.is_constant():
value = symbol.evaluate()
if not isinstance(value, numbers.Number):
if output_jax and scipy.sparse.issparse(value):
# convert any remaining sparse matrices to our custom coo matrix
constant_symbols[symbol.id] = create_jax_coo_matrix(value)
else:
constant_symbols[symbol.id] = value
return
# process children recursively
for child in symbol.children:
find_symbols(child, constant_symbols, variable_symbols, output_jax)
# calculate the variable names that will hold the result of calculating the
# children variables
children_vars = []
for child in symbol.children:
if child.is_constant():
child_eval = child.evaluate()
if isinstance(child_eval, numbers.Number):
children_vars.append(str(child_eval))
else:
children_vars.append(id_to_python_variable(child.id, True))
else:
children_vars.append(id_to_python_variable(child.id, False))
if isinstance(symbol, pybamm.BinaryOperator):
# Multiplication and Division need special handling for scipy sparse matrices
# TODO: we can pass through a dummy y and t to get the type and then hardcode
# the right line, avoiding these checks
if isinstance(symbol, pybamm.Multiplication):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if scipy.sparse.issparse(dummy_eval_left):
if output_jax and is_scalar(dummy_eval_right):
symbol_str = (
f"{children_vars[0]}.scalar_multiply({children_vars[1]})"
)
else:
symbol_str = f"{children_vars[0]}.multiply({children_vars[1]})"
elif scipy.sparse.issparse(dummy_eval_right):
symbol_str = f"{children_vars[1]}.multiply({children_vars[0]})"
else:
symbol_str = f"{children_vars[0]} * {children_vars[1]}"
elif isinstance(symbol, pybamm.Division):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if scipy.sparse.issparse(dummy_eval_left):
if output_jax and is_scalar(dummy_eval_right):
symbol_str = (
f"{children_vars[0]}.scalar_multiply(1/{children_vars[1]})"
)
else:
symbol_str = f"{children_vars[0]}.multiply(1/{children_vars[1]})"
else:
symbol_str = f"{children_vars[0]} / {children_vars[1]}"
elif isinstance(symbol, pybamm.Inner):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if scipy.sparse.issparse(dummy_eval_left):
if output_jax and is_scalar(dummy_eval_right):
symbol_str = (
f"{children_vars[0]}.scalar_multiply({children_vars[1]})"
)
else:
symbol_str = f"{children_vars[0]}.multiply({children_vars[1]})"
elif scipy.sparse.issparse(dummy_eval_right):
if output_jax and is_scalar(dummy_eval_left):
symbol_str = (
f"{children_vars[1]}.scalar_multiply({children_vars[0]})"
)
else:
symbol_str = f"{children_vars[1]}.multiply({children_vars[0]})"
else:
symbol_str = f"{children_vars[0]} * {children_vars[1]}"
elif isinstance(symbol, pybamm.Minimum):
symbol_str = f"np.minimum({children_vars[0]},{children_vars[1]})"
elif isinstance(symbol, pybamm.Maximum):
symbol_str = f"np.maximum({children_vars[0]},{children_vars[1]})"
elif isinstance(symbol, pybamm.MatrixMultiplication):
dummy_eval_left = symbol.children[0].evaluate_for_shape()
dummy_eval_right = symbol.children[1].evaluate_for_shape()
if output_jax and (
scipy.sparse.issparse(dummy_eval_left)
and scipy.sparse.issparse(dummy_eval_right)
):
raise NotImplementedError(
"sparse mat-mat multiplication not supported "
"for output_jax == True"
)
else:
symbol_str = (
children_vars[0] + " " + symbol.name + " " + children_vars[1]
)
else:
symbol_str = children_vars[0] + " " + symbol.name + " " + children_vars[1]
elif isinstance(symbol, pybamm.UnaryOperator):
# Index has a different syntax than other univariate operations
if isinstance(symbol, pybamm.Index):
symbol_str = f"{children_vars[0]}[{symbol.slice.start}:{symbol.slice.stop}]"
else:
symbol_str = symbol.name + children_vars[0]
elif isinstance(symbol, pybamm.Function):
children_str = ""
for child_var in children_vars:
if children_str == "":
children_str = child_var
else:
children_str += ", " + child_var
if isinstance(symbol.function, np.ufunc):
# write any numpy functions directly
symbol_str = f"np.{symbol.function.__name__}({children_str})"
else:
# unknown function, store it as a constant and call this in the
# generated code
constant_symbols[symbol.id] = symbol.function
funct_var = id_to_python_variable(symbol.id, True)
symbol_str = f"{funct_var}({children_str})"
elif isinstance(symbol, pybamm.Concatenation):
# no need to concatenate if there is only a single child
if isinstance(symbol, pybamm.NumpyConcatenation):
if len(children_vars) == 1:
symbol_str = children_vars[0]
else:
symbol_str = "np.concatenate(({}))".format(",".join(children_vars))
elif isinstance(symbol, pybamm.SparseStack):
if len(children_vars) == 1:
symbol_str = children_vars[0]
else:
if output_jax:
raise NotImplementedError
else:
symbol_str = "scipy.sparse.vstack(({}))".format(
",".join(children_vars)
)
# DomainConcatenation specifies a particular ordering for the concatenation,
# which we must follow
elif isinstance(symbol, pybamm.DomainConcatenation):
slice_starts = []
all_child_vectors = []
for i in range(symbol.secondary_dimensions_npts):
child_vectors = []
for child_var, slices in zip(children_vars, symbol._children_slices):
for child_dom, child_slice in slices.items():
slice_starts.append(symbol._slices[child_dom][i].start)
child_vectors.append(
f"{child_var}[{child_slice[i].start}:{child_slice[i].stop}]"
)
all_child_vectors.extend(
[v for _, v in sorted(zip(slice_starts, child_vectors))]
)
if len(children_vars) > 1 or symbol.secondary_dimensions_npts > 1:
symbol_str = "np.concatenate(({}))".format(",".join(all_child_vectors))
else:
symbol_str = "{}".format(",".join(children_vars))
else:
raise NotImplementedError
# Note: we assume that y is being passed as a column vector
elif isinstance(symbol, pybamm.StateVector):
indices = np.argwhere(symbol.evaluation_array).reshape(-1).astype(np.int32)
consecutive = np.all(indices[1:] - indices[:-1] == 1)
if len(indices) == 1 or consecutive:
symbol_str = f"y[{indices[0]}:{indices[-1] + 1}]"
else:
indices_array = pybamm.Array(indices)
constant_symbols[indices_array.id] = indices
index_name = id_to_python_variable(indices_array.id, True)
symbol_str = f"y[{index_name}]"
elif isinstance(symbol, pybamm.Time):
symbol_str = "t"
elif isinstance(symbol, pybamm.InputParameter):
symbol_str = f'inputs["{symbol.name}"]'
else:
raise NotImplementedError(
f"Conversion to python not implemented for a symbol of type '{type(symbol)}'"
)
variable_symbols[symbol.id] = symbol_str
def to_python(
symbol: pybamm.Symbol, debug=False, output_jax=False
) -> tuple[OrderedDict, str]:
"""
This function converts an expression tree into a dict of constant input values, and
valid python code that acts like the tree's :func:`pybamm.Symbol.evaluate` function
Parameters
----------
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
debug : bool
If set to True, the function also emits debug code
Returns
-------
collections.OrderedDict:
dict mapping node id to a constant value. Represents all the constant nodes in
the expression tree
str:
valid python code that will evaluate all the variable nodes in the tree.
output_jax: bool
If True, only numpy and jax operations will be used in the generated code.
Raises NotImplNotImplementedError if any SparseStack or Mat-Mat multiply
operations are used
"""
constant_values: OrderedDict = OrderedDict()
variable_symbols: OrderedDict = OrderedDict()
find_symbols(symbol, constant_values, variable_symbols, output_jax)
line_format = "{} = {}"
if debug: # pragma: no cover
variable_lines = [
f"print('{line_format.format(id_to_python_variable(symbol_id, False), symbol_line)}'); "
+ line_format.format(id_to_python_variable(symbol_id, False), symbol_line)
+ "; print(type({0}),np.shape({0}))".format(
id_to_python_variable(symbol_id, False)
)
for symbol_id, symbol_line in variable_symbols.items()
]
else:
variable_lines = [
line_format.format(id_to_python_variable(symbol_id, False), symbol_line)
for symbol_id, symbol_line in variable_symbols.items()
]
return constant_values, "\n".join(variable_lines)
[docs]
class EvaluatorPython:
"""
Converts a pybamm expression tree into pure python code that will calculate the
result of calling `evaluate(t, y)` on the given expression tree.
Parameters
----------
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
"""
def __init__(self, symbol: pybamm.Symbol):
constants, python_str = pybamm.to_python(symbol, debug=False)
# extract constants in generated function
for i, symbol_id in enumerate(constants.keys()):
const_name = id_to_python_variable(symbol_id, True)
python_str = f"{const_name} = constants[{i}]\n" + python_str
# constants passed in as an ordered dict, convert to list
self._constants = list(constants.values())
# indent code
python_str = " " + python_str
python_str = python_str.replace("\n", "\n ")
# add function def to first line
python_str = (
"def evaluate(constants, t=None, y=None, " "inputs=None):\n" + python_str
)
# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
result_var = id_to_python_variable(symbol.id, symbol.is_constant())
if symbol.is_constant():
result_value = symbol.evaluate()
# add return line
if symbol.is_constant() and isinstance(result_value, numbers.Number):
python_str = python_str + "\n return " + str(result_value)
else:
python_str = python_str + "\n return " + result_var
# store a copy of examine_jaxpr
python_str = python_str + "\nself._evaluate = evaluate"
self._python_str = python_str
self._result_var = result_var
self._symbol = symbol
# compile and run the generated python code,
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)
def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
result = self._evaluate(self._constants, t, y, inputs)
return result
def __getstate__(self):
# Control the state of instances of EvaluatorPython
# before pickling. Method "_evaluate" cannot be pickled.
# See https://github.com/pybamm-team/PyBaMM/issues/1283
state = self.__dict__.copy()
del state["_evaluate"]
return state
def __setstate__(self, state):
# Restore pickled attributes and
# compile code from "python_str"
# Execution of bytecode (re)adds attribute
# "_method"
self.__dict__.update(state)
compiled_function = compile(self._python_str, self._result_var, "exec")
exec(compiled_function)
class EvaluatorJax:
"""
Converts a pybamm expression tree into pure python code that will calculate the
result of calling `evaluate(t, y)` on the given expression tree. The resultant code
is compiled with JAX
Limitations: JAX currently does not work on expressions involving sparse matrices,
so any sparse matrices and operations involved sparse matrices are converted to
their dense equivilents before compilation
Parameters
----------
symbol : :class:`pybamm.Symbol`
The symbol to convert to python code
"""
def __init__(self, symbol: pybamm.Symbol):
if not pybamm.have_jax(): # pragma: no cover
raise ModuleNotFoundError(
"Jax or jaxlib is not installed, please see https://docs.pybamm.org/en/latest/source/user_guide/installation/gnu-linux-mac.html#optional-jaxsolver"
)
constants, python_str = pybamm.to_python(symbol, debug=False, output_jax=True)
# replace numpy function calls to jax numpy calls
python_str = python_str.replace("np.", "jax.numpy.")
# convert all numpy constants to device vectors
for symbol_id in constants:
if isinstance(constants[symbol_id], np.ndarray):
constants[symbol_id] = jax.device_put(constants[symbol_id])
# get a list of constant arguments to input to the function
self._arg_list = [
id_to_python_variable(symbol_id, True) for symbol_id in constants.keys()
]
# get a list of hashable arguments to make static
# a jax device array is not hashable
static_argnums = (
i
for i, c in enumerate(constants.values())
if not (isinstance(c, jax.Array))
)
# store constants
self._constants = tuple(constants.values())
# indent code
python_str = " " + python_str
python_str = python_str.replace("\n", "\n ")
# add function def to first line
args = "t=None, y=None, inputs=None"
if self._arg_list:
args = ",".join(self._arg_list) + ", " + args
python_str = f"def evaluate_jax({args}):\n" + python_str
# calculate the final variable that will output the result of calling `evaluate`
# on `symbol`
result_var = id_to_python_variable(symbol.id, symbol.is_constant())
if symbol.is_constant():
result_value = symbol.evaluate()
# add return line
if symbol.is_constant() and isinstance(result_value, numbers.Number):
python_str = python_str + "\n return " + str(result_value)
else:
python_str = python_str + "\n return " + result_var
# store a copy of examine_jaxpr
python_str = python_str + "\nself._evaluate_jax = evaluate_jax"
# store the final generated code
self._python_str = python_str
# compile and run the generated python code,
compiled_function = compile(python_str, result_var, "exec")
exec(compiled_function)
self._static_argnums = tuple(static_argnums)
self._jit_evaluate = jax.jit(
self._evaluate_jax, # type:ignore[attr-defined]
static_argnums=self._static_argnums,
)
def get_jacobian(self):
n = len(self._arg_list)
# forward mode autodiff wrt y, which is argument 1 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
self._jac_evaluate = jax.jit(
jacobian_evaluate, static_argnums=self._static_argnums
)
return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)
def get_jacobian_action(self):
return self.jvp
def get_sensitivities(self):
n = len(self._arg_list)
# forward mode autodiff wrt inputs, which is argument 2 after arg_list
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=2 + n)
self._sens_evaluate = jax.jit(
jacobian_evaluate, static_argnums=self._static_argnums
)
return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)
def debug(self, t=None, y=None, inputs=None):
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
# execute code
jaxpr = jax.make_jaxpr(self._evaluate_jax)(*self._constants, t, y, inputs).jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
print("constvars:", jaxpr.constvars)
for eqn in jaxpr.eqns:
print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
print()
print("jaxpr:", jaxpr)
def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
result = self._jit_evaluate(*self._constants, t, y, inputs)
return result
def jvp(self, t=None, y=None, v=None, inputs=None):
"""
evaluate jacobian vector product of function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
if v is not None and v.ndim == 1:
v = v.reshape(-1, 1)
def bind_t_and_inputs(the_y):
return self._jit_evaluate(*self._constants, t, the_y, inputs)
return jax.jvp(bind_t_and_inputs, (y,), (v,))[1]
class EvaluatorJaxJacobian:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants
def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
# execute code
result = self._jac_evaluate(*self._constants, t, y, inputs)
result = result.reshape(result.shape[0], -1)
return result
class EvaluatorJaxSensitivities:
def __init__(self, jac_evaluate, constants):
self._jac_evaluate = jac_evaluate
self._constants = constants
def __call__(self, t=None, y=None, inputs=None):
"""
evaluate function
"""
# generated code assumes y is a column vector
if y is not None and y.ndim == 1:
y = y.reshape(-1, 1)
# execute code
result = self._jac_evaluate(*self._constants, t, y, inputs)
result = {
key: value.reshape(value.shape[0], -1) for key, value in result.items()
}
return result