# Source code for pybamm.expression_tree.operations.unpack_symbols
#
# Helper function to unpack a symbol
#
from __future__ import annotations
from typing import TYPE_CHECKING
from collections.abc import Sequence
if TYPE_CHECKING: # pragma: no cover
import pybamm
[docs]
class SymbolUnpacker:
    """
    Helper class to unpack a (set of) symbol(s) to find all instances of a class.
    Uses caching to speed up the process.

    Parameters
    ----------
    classes_to_find : class, or tuple/list of classes
        Class(es) to identify in the equations. A list, set or frozenset of
        classes is converted to a tuple so that it can be handed to
        ``isinstance``.
    unpacked_symbols : dict, optional
        Cache mapping an already-unpacked symbol to the set of matching symbols
        found in its tree. When a dict is given (even an empty one) it is used
        and updated in place, so it can be shared between unpackers. Defaults
        to a new empty dict.
    """

    def __init__(
        self,
        classes_to_find: Sequence[type[pybamm.Symbol]] | type[pybamm.Symbol],
        unpacked_symbols: dict | None = None,
    ):
        # ``isinstance`` only accepts a class or a tuple of classes, so turn the
        # other common containers into a tuple up front.
        if isinstance(classes_to_find, (list, set, frozenset)):
            classes_to_find = tuple(classes_to_find)
        self.classes_to_find = classes_to_find
        # Compare against None rather than relying on truthiness: an empty dict
        # supplied by the caller must be kept so that it receives the cache.
        self._unpacked_symbols: dict = (
            {} if unpacked_symbols is None else unpacked_symbols
        )

    def unpack_list_of_symbols(
        self, list_of_symbols: Sequence[pybamm.Symbol]
    ) -> set[pybamm.Symbol]:
        """
        Unpack a list of symbols. See :meth:`SymbolUnpacker.unpack_symbol`

        Parameters
        ----------
        list_of_symbols : list of :class:`pybamm.Symbol`
            List of symbols to unpack

        Returns
        -------
        set of :class:`pybamm.Symbol`
            Set of unpacked symbols with class in `self.classes_to_find`
        """
        # Always build a fresh set so the cached per-symbol sets are never
        # mutated by the merge.
        all_instances: set = set()
        for symbol in list_of_symbols:
            all_instances.update(self.unpack_symbol(symbol))
        return all_instances

    def unpack_symbol(self, symbol: pybamm.Symbol) -> set[pybamm.Symbol]:
        """
        This function recurses down the tree, unpacking the symbols and saving the ones
        that have a class in `self.classes_to_find`.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The symbol to unpack. It is used as a dictionary key in the cache.

        Returns
        -------
        set of :class:`pybamm.Symbol`
            Set of unpacked symbols with class in `self.classes_to_find`. This
            is the same object that is stored in the cache, so callers should
            not modify it.
        """
        try:
            return self._unpacked_symbols[symbol]
        except KeyError:
            # Not seen before: walk the tree once and remember the result.
            unpacked = self._unpack(symbol)
            self._unpacked_symbols[symbol] = unpacked
            return unpacked

    def _unpack(self, symbol):
        """
        Uncached worker for :meth:`SymbolUnpacker.unpack_symbol`.

        Returns a new set containing `symbol` itself if it is an instance of
        `self.classes_to_find` (its children are then not searched), otherwise
        the union of the results for each of its children.
        """
        # found a symbol of the right class -> return it
        if isinstance(symbol, self.classes_to_find):
            return {symbol}

        # Not the right class: gather matches from the children. With no
        # children the loop does nothing and the empty set is returned.
        found_vars = set()
        for child in symbol.children:
            # call back unpack_symbol to cache values
            found_vars.update(self.unpack_symbol(child))
        return found_vars