#
# Function classes and methods
#
from __future__ import annotations
import numpy as np
from scipy import special
import sympy
from typing import Callable
from collections.abc import Sequence
from typing_extensions import TypeVar
import pybamm
from pybamm.util import import_optional_dependency
[docs]
class Function(pybamm.Symbol):
    """
    A node in the expression tree representing an arbitrary function.

    Parameters
    ----------
    function : method
        A function can have 0 or many inputs. If no inputs are given,
        self.evaluate() simply returns func(). Otherwise, self.evaluate(t, y, u)
        returns func(child0.evaluate(t, y, u), child1.evaluate(t, y, u), etc).
    children : :class:`pybamm.Symbol`
        The children nodes to apply the function to
    name : str, optional
        Name of the node. If omitted, the name is "function (<function name>)",
        built from ``function.__name__`` (or ``function.__class__`` when the
        callable has no ``__name__``).
    derivative : str, optional
        Which derivative to use when differentiating ("autograd" or "derivative").
        Default is "autograd".
    differentiated_function : method, optional
        The function which was differentiated to obtain this one. Default is None.
    """

    def __init__(
        self,
        function: Callable,
        *children: pybamm.Symbol,
        name: str | None = None,
        derivative: str | None = "autograd",
        differentiated_function: Callable | None = None,
    ):
        # Turn numbers into scalars
        children = list(children)
        for idx, child in enumerate(children):
            if isinstance(child, (float, int, np.number)):
                children[idx] = pybamm.Scalar(child)

        # Only build a default name when the caller did not supply one; the
        # name itself is stored by the parent constructor below.
        if name is None:
            try:
                name = f"function ({function.__name__})"
            except AttributeError:
                name = f"function ({function.__class__})"

        domains = self.get_children_domains(children)

        self.function = function
        self.derivative = derivative
        self.differentiated_function = differentiated_function

        super().__init__(name, children=children, domains=domains)

    def __str__(self):
        """See :meth:`pybamm.Symbol.__str__()`."""
        label = self.name
        # Default names look like "function (f)": show just "f". A name given
        # explicitly by the caller is shown unchanged rather than sliced.
        if label.startswith("function (") and label.endswith(")"):
            label = label[10:-1]
        # Joining (instead of trimming a trailing ", ") keeps the output
        # well-formed when the function has no children.
        arguments = ", ".join(str(child) for child in self.children)
        return f"{label}({arguments})"

    def diff(self, variable: pybamm.Symbol):
        """See :meth:`pybamm.Symbol.diff()`."""
        if variable == self:
            return pybamm.Scalar(1)

        children = self.orphans
        # If variable appears in a child, differentiate the function with
        # respect to that child and apply the chain rule; children that do not
        # contain the variable contribute nothing.
        partial_derivatives = [
            self._function_diff(children, i) * child.diff(variable)
            for i, child in enumerate(self.children)
            if variable in child.pre_order()
        ]

        derivative = sum(partial_derivatives)
        if derivative == 0:
            return pybamm.Scalar(0)

        return derivative

    def _function_diff(self, children: Sequence[pybamm.Symbol], idx: int):
        """
        Derivative with respect to child number 'idx'.
        See :meth:`pybamm.Symbol._diff()`.

        Raises
        ------
        ValueError
            If ``self.derivative`` is "derivative" and there is more than one
            child.

        Notes
        -----
        Nothing is returned (i.e. ``None``) when ``self.derivative`` is neither
        "autograd" nor "derivative".
        """
        # Store differentiated function, needed in case we want to convert to CasADi
        if self.derivative == "autograd":
            # autograd is an optional dependency: import it only on the path
            # that actually needs it
            autograd = import_optional_dependency("autograd")
            return Function(
                autograd.elementwise_grad(self.function, idx),
                *children,
                differentiated_function=self.function,
            )
        elif self.derivative == "derivative":
            if len(children) > 1:
                raise ValueError(
                    "\n"
                    "differentiation using '.derivative()' not implemented "
                    "for functions\n"
                    "with more than one child\n"
                )
            else:
                # keep using "derivative" as derivative
                return pybamm.Function(
                    self.function.derivative(),  # type: ignore[attr-defined]
                    *children,
                    derivative="derivative",
                    differentiated_function=self.function,
                )

    def _function_jac(self, children_jacs):
        """
        Calculate the Jacobian of a function.

        ``children_jacs[i]`` is the Jacobian of child number ``i``.
        """
        if all(child.evaluates_to_constant_number() for child in self.children):
            jacobian = pybamm.Scalar(0)
        else:
            # if at least one child contains variable dependence, then
            # calculate the required partial Jacobians and add them
            jacobian = None
            children = self.orphans
            for i, child in enumerate(children):
                if not child.evaluates_to_constant_number():
                    jac_fun = self._function_diff(children, i) * children_jacs[i]
                    jac_fun.clear_domains()
                    if jacobian is None:
                        jacobian = jac_fun
                    else:
                        jacobian += jac_fun

        return jacobian

    def evaluate(
        self,
        t: float | None = None,
        y: np.ndarray | None = None,
        y_dot: np.ndarray | None = None,
        inputs: dict | str | None = None,
    ):
        """See :meth:`pybamm.Symbol.evaluate()`."""
        evaluated_children = [
            child.evaluate(t, y, y_dot, inputs) for child in self.children
        ]
        return self._function_evaluate(evaluated_children)

    def _evaluates_on_edges(self, dimension: str) -> bool:
        """See :meth:`pybamm.Symbol._evaluates_on_edges()`."""
        return any(child.evaluates_on_edges(dimension) for child in self.children)

    def is_constant(self):
        """See :meth:`pybamm.Symbol.is_constant()`."""
        return all(child.is_constant() for child in self.children)

    def _evaluate_for_shape(self):
        """
        Default behaviour: apply the function to the shape-evaluated children.
        See :meth:`pybamm.Symbol.evaluate_for_shape()`
        """
        evaluated_children = [child.evaluate_for_shape() for child in self.children]
        return self._function_evaluate(evaluated_children)

    def _function_evaluate(self, evaluated_children):
        """Call the wrapped function on the already-evaluated children."""
        return self.function(*evaluated_children)

    def create_copy(self):
        """See :meth:`pybamm.Symbol.new_copy()`."""
        children_copy = [child.new_copy() for child in self.children]
        return self._function_new_copy(children_copy)

    def _function_new_copy(self, children: list) -> Function:
        """
        Returns a new copy of the function.

        Inputs
        ------
        children : : list
            A list of the children of the function

        Returns
        -------
            : :pybamm.Function
            A new copy of the function
        """
        return pybamm.simplify_if_constant(
            pybamm.Function(
                self.function,
                *children,
                name=self.name,
                derivative=self.derivative,
                differentiated_function=self.differentiated_function,
            )
        )

    def _sympy_operator(self, child):
        """Apply appropriate SymPy operators."""
        return child

    def to_equation(self):
        """Convert the node and its subtree into a SymPy equation."""
        if self.print_name is not None:
            return sympy.Symbol(self.print_name)

        eq_list = [child.to_equation() for child in self.children]
        return self._sympy_operator(*eq_list)

    def to_json(self):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "pybamm.Function: Serialisation is only implemented for discretised models."
        )

    @classmethod
    def _from_json(cls, snippet):
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "pybamm.Function: Please use a discretised model when reading in from JSON."
        )
[docs]
class SpecificFunction(Function):
    """
    Base class for the named single-child functions defined in this module;
    subclasses provide their own `diff` rules directly.

    Parameters
    ----------
    function : method
        Function to be applied to child
    child : :class:`pybamm.Symbol`
        The child to apply the function to
    """

    def __init__(self, function: Callable, child: pybamm.Symbol):
        super().__init__(function, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """
        Rebuild a SpecificFunction instance while deserialising a JSON file.

        Parameters
        ----------
        snippet: dict
            Must hold the callable under "function" and the child to apply it
            to as the first entry of "children".
        """
        # Allocate without running the subclass __init__, then initialise via
        # the constructor of the class that follows SpecificFunction in the MRO.
        rebuilt = cls.__new__(cls)
        func = snippet["function"]
        only_child = snippet["children"][0]
        super(SpecificFunction, rebuilt).__init__(func, only_child)
        return rebuilt

    def _function_new_copy(self, children):
        """See :meth:`pybamm.Function._function_new_copy()`"""
        duplicate = self.__class__(*children)
        return pybamm.simplify_if_constant(duplicate)

    def _sympy_operator(self, child):
        """Apply the SymPy function named after this class (lower-cased)."""
        return getattr(sympy, self.__class__.__name__.lower())(child)

    def to_json(self):
        """
        Method to serialise a SpecificFunction object into JSON.
        """
        return {
            "name": self.name,
            "id": self.id,
            "function": self.function.__name__,
        }
SF = TypeVar("SF", bound=SpecificFunction)


def simplified_function(func_class: type[SF], child: pybamm.Symbol):
    """
    Apply simplifications before constructing ``func_class(child)``.

    When ``child`` is a broadcast, the function is pushed inside it
    (recursively) and the broadcast is re-applied around the result.
    Currently only implemented for one-child functions.
    """
    if not isinstance(child, pybamm.Broadcast):
        return pybamm.simplify_if_constant(func_class(child))  # type: ignore[call-arg, arg-type]

    # Move the function inside the broadcast, recursing on the broadcast's child
    inner = simplified_function(func_class, child.orphans[0])
    inner = pybamm.simplify_if_constant(inner)
    return child._unary_new_copy(inner)
[docs]
class Arcsinh(SpecificFunction):
    """Inverse hyperbolic sine function (wraps ``np.arcsinh``)."""

    def __init__(self, child):
        super().__init__(np.arcsinh, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.arcsinh
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return 1 / sqrt(argument**2 + 1)

    def _sympy_operator(self, child):
        """Override :meth:`pybamm.Function._sympy_operator`"""
        return sympy.asinh(child)
[docs]
def arcsinh(child: pybamm.Symbol):
    """Return the inverse hyperbolic sine of ``child`` as an expression node."""
    result = simplified_function(Arcsinh, child)
    return result
[docs]
class Arctan(SpecificFunction):
    """Inverse tangent function (wraps ``np.arctan``)."""

    def __init__(self, child):
        super().__init__(np.arctan, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.arctan
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return 1 / (argument**2 + 1)

    def _sympy_operator(self, child):
        """Override :meth:`pybamm.Function._sympy_operator`"""
        return sympy.atan(child)
[docs]
def arctan(child: pybamm.Symbol):
    """Returns inverse tangent (arctan) function of child."""
    return simplified_function(Arctan, child)
[docs]
class Cos(SpecificFunction):
    """Cosine function (wraps ``np.cos``)."""

    def __init__(self, child):
        super().__init__(np.cos, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.cos
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return -sin(argument)
[docs]
def cos(child: pybamm.Symbol):
    """Return the cosine of ``child`` as an expression node."""
    result = simplified_function(Cos, child)
    return result
[docs]
class Cosh(SpecificFunction):
    """Hyperbolic cosine function (wraps ``np.cosh``)."""

    def __init__(self, child):
        super().__init__(np.cosh, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.cosh
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return sinh(argument)
[docs]
def cosh(child: pybamm.Symbol):
    """Return the hyperbolic cosine of ``child`` as an expression node."""
    result = simplified_function(Cosh, child)
    return result
[docs]
class Erf(SpecificFunction):
    """Error function (wraps ``scipy.special.erf``)."""

    def __init__(self, child):
        super().__init__(special.erf, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = special.erf
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return 2 / np.sqrt(np.pi) * exp(-(argument**2))
[docs]
def erf(child: pybamm.Symbol):
    """Return the error function of ``child`` as an expression node."""
    result = simplified_function(Erf, child)
    return result
[docs]
def erfc(child: pybamm.Symbol):
    """Return the complementary error function of ``child``, built as 1 - erf."""
    erf_of_child = simplified_function(Erf, child)
    return 1 - erf_of_child
[docs]
class Exp(SpecificFunction):
    """Exponential function (wraps ``np.exp``)."""

    def __init__(self, child):
        super().__init__(np.exp, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.exp
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return exp(argument)
[docs]
def exp(child: pybamm.Symbol):
    """Return the exponential of ``child`` as an expression node."""
    result = simplified_function(Exp, child)
    return result
[docs]
class Log(SpecificFunction):
    """Natural logarithm function (wraps ``np.log``)."""

    def __init__(self, child):
        super().__init__(np.log, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.log
        return super()._from_json(snippet)

    def _function_evaluate(self, evaluated_children):
        """Evaluate ``np.log`` with numpy's "invalid value" warnings silenced."""
        # don't raise RuntimeWarning for NaNs
        with np.errstate(invalid="ignore"):
            value = np.log(*evaluated_children)
        return value

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return 1 / argument
[docs]
def log(child, base="e"):
    """
    Return the logarithm of ``child`` in the given base (default 'e').

    For any base other than the string "e", the natural logarithm is divided
    by ``np.log(base)``.
    """
    natural_log = simplified_function(Log, child)
    if base == "e":
        return natural_log
    # Change of base: log_b(x) = ln(x) / ln(b)
    return natural_log / np.log(base)
[docs]
def log10(child: pybamm.Symbol):
    """Return the base-10 logarithm of ``child`` as an expression node."""
    result = log(child, base=10)
    return result
[docs]
class Max(SpecificFunction):
    """Max function (wraps ``np.max``)."""

    def __init__(self, child):
        super().__init__(np.max, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.max
        return super()._from_json(snippet)

    def _evaluate_for_shape(self):
        """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`"""
        # Max will always return a scalar, so the shape placeholder is 1x1 NaN
        placeholder = np.ones((1, 1))
        return np.nan * placeholder
[docs]
def max(child: pybamm.Symbol):
    """
    Return the max function of ``child``.

    Not to be confused with :meth:`pybamm.maximum`, which returns the larger
    of two objects.
    """
    node = Max(child)
    return pybamm.simplify_if_constant(node)
[docs]
class Min(SpecificFunction):
    """Min function (wraps ``np.min``)."""

    def __init__(self, child):
        super().__init__(np.min, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.min
        return super()._from_json(snippet)

    def _evaluate_for_shape(self):
        """See :meth:`pybamm.Symbol.evaluate_for_shape_using_domain()`"""
        # Min will always return a scalar, so the shape placeholder is 1x1 NaN
        placeholder = np.ones((1, 1))
        return np.nan * placeholder
[docs]
def min(child: pybamm.Symbol):
    """
    Return the min function of ``child``.

    Not to be confused with :meth:`pybamm.minimum`, which returns the smaller
    of two objects.
    """
    node = Min(child)
    return pybamm.simplify_if_constant(node)
[docs]
def sech(child: pybamm.Symbol):
    """Return the hyperbolic secant of ``child``, built as 1 / cosh."""
    cosh_of_child = simplified_function(Cosh, child)
    return 1 / cosh_of_child
[docs]
class Sin(SpecificFunction):
    """Sine function (wraps ``np.sin``)."""

    def __init__(self, child):
        super().__init__(np.sin, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.sin
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return cos(argument)
[docs]
def sin(child: pybamm.Symbol):
    """Return the sine of ``child`` as an expression node."""
    result = simplified_function(Sin, child)
    return result
[docs]
class Sinh(SpecificFunction):
    """Hyperbolic sine function (wraps ``np.sinh``)."""

    def __init__(self, child):
        super().__init__(np.sinh, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.sinh
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return cosh(argument)
[docs]
def sinh(child: pybamm.Symbol):
    """Return the hyperbolic sine of ``child`` as an expression node."""
    result = simplified_function(Sinh, child)
    return result
[docs]
class Sqrt(SpecificFunction):
    """Square root function (wraps ``np.sqrt``)."""

    def __init__(self, child):
        super().__init__(np.sqrt, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.sqrt
        return super()._from_json(snippet)

    def _function_evaluate(self, evaluated_children):
        """Evaluate ``np.sqrt`` with numpy's "invalid value" warnings silenced."""
        # don't raise RuntimeWarning for NaNs
        with np.errstate(invalid="ignore"):
            value = np.sqrt(*evaluated_children)
        return value

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return 1 / (2 * sqrt(argument))
[docs]
def sqrt(child: pybamm.Symbol):
    """Return the square root of ``child`` as an expression node."""
    result = simplified_function(Sqrt, child)
    return result
[docs]
class Tanh(SpecificFunction):
    """Hyperbolic tangent function (wraps ``np.tanh``)."""

    def __init__(self, child):
        super().__init__(np.tanh, child)

    @classmethod
    def _from_json(cls, snippet: dict):
        """See :meth:`pybamm.SpecificFunction._from_json()`."""
        # The callable is not read from JSON; inject it before delegating
        snippet["function"] = np.tanh
        return super()._from_json(snippet)

    def _function_diff(self, children, idx):
        """See :meth:`pybamm.Function._function_diff()`."""
        argument = children[0]
        return sech(argument) ** 2
[docs]
def tanh(child: pybamm.Symbol):
    """Return the hyperbolic tangent of ``child`` as an expression node."""
    result = simplified_function(Tanh, child)
    return result
[docs]
def normal_pdf(
    x: pybamm.Symbol, mu: pybamm.Symbol | float, sigma: pybamm.Symbol | float
):
    """
    Returns the normal probability density function at x.

    Parameters
    ----------
    x : pybamm.Symbol
        The value at which to evaluate the normal distribution
    mu : pybamm.Symbol or float
        The mean of the normal distribution
    sigma : pybamm.Symbol or float
        The standard deviation of the normal distribution

    Returns
    -------
    pybamm.Symbol
        The value of the normal distribution at x
    """
    # Use this module's symbolic `exp` rather than `np.exp`, so the exponent is
    # wrapped in an expression-tree node without relying on numpy being able to
    # dispatch on a pybamm.Symbol argument.
    return 1 / (np.sqrt(2 * np.pi) * sigma) * exp(-0.5 * ((x - mu) / sigma) ** 2)
[docs]
def normal_cdf(
    x: pybamm.Symbol, mu: pybamm.Symbol | float, sigma: pybamm.Symbol | float
):
    """
    Returns the normal cumulative distribution function at x.

    Parameters
    ----------
    x : pybamm.Symbol
        The value at which to evaluate the normal distribution
    mu : pybamm.Symbol or float
        The mean of the normal distribution
    sigma : pybamm.Symbol or float
        The standard deviation of the normal distribution

    Returns
    -------
    pybamm.Symbol
        The value of the normal cumulative distribution function at x
    """
    # Use this module's symbolic `erf` rather than `scipy.special.erf`, so the
    # argument is wrapped in an expression-tree node without relying on scipy
    # being able to dispatch on a pybamm.Symbol argument.
    return 0.5 * (1 + erf((x - mu) / (sigma * np.sqrt(2))))