Source code for pybamm.expression_tree.operations.unpack_symbols

# Helper class to unpack a symbol

class SymbolUnpacker:
    """
    Helper class to unpack a (set of) symbol(s) to find all instances of a class.
    Uses caching to speed up the process.

    Parameters
    ----------
    classes_to_find : class, or tuple/list/set of classes
        Classes to identify in the equations. Anything accepted by
        :func:`isinstance` works; a list or set is converted to a tuple
        when it is used.
    unpacked_symbols : dict, optional
        Cache mapping an already-unpacked symbol to the set of matching
        symbols found in its tree. The dict is used (and filled) in place, so
        it can be shared between unpackers. A new empty cache is created when
        omitted.
    """

    def __init__(self, classes_to_find, unpacked_symbols=None):
        self.classes_to_find = classes_to_find
        # Test against None rather than truthiness: an empty dict passed in by
        # the caller is a valid cache and must be kept so the caller sees the
        # entries added to it.
        self._unpacked_symbols = {} if unpacked_symbols is None else unpacked_symbols

    def unpack_list_of_symbols(self, list_of_symbols):
        """
        Unpack a list of symbols. See :meth:`SymbolUnpacker.unpack_symbol()`

        Parameters
        ----------
        list_of_symbols : list of :class:`pybamm.Symbol`
            List of symbols to unpack

        Returns
        -------
        set of :class:`pybamm.Symbol`
            Newly created set of unpacked symbols with class in
            `self.classes_to_find` (empty if `list_of_symbols` is empty)
        """
        all_instances = set()
        for symbol in list_of_symbols:
            all_instances.update(self.unpack_symbol(symbol))
        return all_instances

    def unpack_symbol(self, symbol):
        """
        This function recurses down the tree, unpacking the symbols and saving
        the ones that have a class in `self.classes_to_find`.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The symbol to unpack. It is used as a key of the cache, so it must
            be hashable.

        Returns
        -------
        set of :class:`pybamm.Symbol`
            Set of unpacked symbols with class in `self.classes_to_find`.
            This is the object stored in the cache, not a copy, so callers
            should not mutate it.
        """
        try:
            return self._unpacked_symbols[symbol]
        except KeyError:
            # First time this symbol is seen: unpack it and remember the result
            unpacked = self._unpack(symbol)
            self._unpacked_symbols[symbol] = unpacked
            return unpacked

    def _unpack(self, symbol):
        """See :meth:`SymbolUnpacker.unpack_symbol()`."""
        classes = self.classes_to_find
        # isinstance() rejects lists and sets, so convert those to a tuple
        if isinstance(classes, (list, set, frozenset)):
            classes = tuple(classes)

        # found a symbol of the right class -> return it (children not searched)
        if isinstance(symbol, classes):
            return {symbol}

        # not the right class: collect matches from all children; with no
        # children the loop does nothing and the empty set is returned
        found_vars = set()
        for child in symbol.children:
            # call back unpack_symbol to cache values
            found_vars.update(self.unpack_symbol(child))
        return found_vars