# Source code for pybamm.expression_tree.operations.serialise

from __future__ import annotations

import pybamm
from datetime import datetime
import json
import importlib
import numpy as np
import re


class Serialise:
    """
    Converts a discretised model to and from a JSON file.
    """

    # Prefix of the JSON keys whose value describes a pybamm symbol that was
    # originally used as a dictionary key (see `_deconstruct_pybamm_dicts`).
    _SYMBOL_KEY_PREFIX = "symbol_"

    def __init__(self):
        pass

    class _SymbolEncoder(json.JSONEncoder):
        """Converts PyBaMM symbols into a JSON-serialisable format"""

        def default(self, node: pybamm.Symbol | pybamm.Event):
            """
            Build a JSON-serialisable dictionary describing `node`.

            Symbols are encoded recursively through their children (and their
            `initial_condition` attribute, when present); events are encoded
            together with their expression. Anything else is handed to
            `json.JSONEncoder.default`.
            """
            # str(type(x)) looks like "<class 'pkg.mod.Name'>"; the slice keeps
            # only the dotted path, which is what `_get_pybamm_class` reads back.
            node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}

            if isinstance(node, pybamm.Symbol):
                node_dict.update(node.to_json())  # this doesn't include children
                node_dict["children"] = [self.default(c) for c in node.children]

                if hasattr(node, "initial_condition"):  # for ExplicitTimeIntegral
                    node_dict["initial_condition"] = self.default(
                        node.initial_condition
                    )

                return node_dict

            if isinstance(node, pybamm.Event):
                node_dict.update(node.to_json())
                node_dict["expression"] = self.default(node._expression)
                return node_dict

            node_dict["json"] = json.JSONEncoder.default(self, node)  # pragma: no cover
            return node_dict  # pragma: no cover

    class _MeshEncoder(json.JSONEncoder):
        """Converts PyBaMM meshes into a JSON-serialisable format"""

        def default(self, node: pybamm.Mesh):
            """
            Build a JSON-serialisable dictionary describing a mesh or submesh.

            For a mesh, only entries whose key holds a single domain name that
            does not contain "ghost cell" are written out, under "sub_meshes".
            """
            node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}

            if isinstance(node, pybamm.Mesh):
                node_dict.update(node.to_json())
                node_dict["sub_meshes"] = {
                    k[0]: self.default(v)
                    for k, v in node.items()
                    if len(k) == 1 and "ghost cell" not in k[0]
                }
                return node_dict

            if isinstance(node, pybamm.SubMesh):
                node_dict.update(node.to_json())
                return node_dict

            node_dict["json"] = json.JSONEncoder.default(self, node)  # pragma: no cover
            return node_dict  # pragma: no cover

    class _Empty:
        """A dummy class to aid deserialisation"""

        pass

    class _EmptyDict(dict):
        """A dummy dictionary class to aid deserialisation"""

        pass

    def save_model(
        self,
        model: pybamm.BaseModel,
        mesh: pybamm.Mesh | None = None,
        variables: pybamm.FuzzyDict | None = None,
        filename: str | None = None,
    ):
        """Saves a discretised model to a JSON file.

        As the model is discretised and ready to solve, only the right hand side,
        algebraic and initial condition variables are saved.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The discretised model to be saved
        mesh : :class:`pybamm.Mesh` (optional)
            The mesh the model has been discretised over. Not necessary to solve
            the model when read in, but required to use pybamm's plotting tools.
        variables: :class:`pybamm.FuzzyDict` (optional)
            The discretised model variables. Not necessary to solve a model, but
            required to use pybamm's plotting tools.
        filename: str (optional)
            The desired name of the JSON file, without the ".json" extension
            (it is appended automatically). If no name is provided, one will be
            created based on the model name, and the current datetime.

        Raises
        ------
        NotImplementedError
            If `model.is_discretised` is False.
        """
        if model.is_discretised is False:
            raise NotImplementedError(
                "PyBaMM can only serialise a discretised, ready-to-solve model."
            )

        encoder = self._SymbolEncoder()

        model_json = {
            "py/object": str(type(model))[8:-2],
            "py/id": id(model),
            "pybamm_version": pybamm.__version__,
            "name": model.name,
            "options": model.options,
            "bounds": [bound.tolist() for bound in model.bounds],  # type: ignore[attr-defined]
            "concatenated_rhs": encoder.default(model._concatenated_rhs),
            "concatenated_algebraic": encoder.default(model._concatenated_algebraic),
            "concatenated_initial_conditions": encoder.default(
                model._concatenated_initial_conditions
            ),
            "events": [encoder.default(event) for event in model.events],
            "mass_matrix": encoder.default(model.mass_matrix),
            "mass_matrix_inv": encoder.default(model.mass_matrix_inv),
        }

        if mesh:
            model_json["mesh"] = self._MeshEncoder().default(mesh)

        if variables:
            # The geometry is only written alongside the variables
            if model._geometry:
                model_json["geometry"] = self._deconstruct_pybamm_dicts(model._geometry)
            model_json["variables"] = {
                k: encoder.default(v) for k, v in dict(variables).items()
            }

        if filename is None:
            filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M")

        with open(filename + ".json", "w", encoding="utf-8") as f:
            json.dump(model_json, f)

    def load_model(
        self, filename: str, battery_model: pybamm.BaseModel | None = None
    ) -> pybamm.BaseModel:
        """
        Loads a discretised, ready to solve model into PyBaMM.

        A new pybamm battery model instance will be created, which can be solved
        and the results plotted as usual.

        Currently only available for pybamm models which have previously been
        written out using the `save_model()` option.

        Warning: This only loads in discretised models. If you wish to make edits
        to the model or initial conditions, a new model will need to be
        constructed separately.

        Warning: the class paths stored in the file are imported with
        `importlib`, so only load files from a trusted source.

        Parameters
        ----------
        filename: str
            Path to the JSON file containing the serialised model file
        battery_model: :class:`pybamm.BaseModel` (optional)
            PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will
            override any model names within the file. If None, the function will
            look for the saved object path, present if the original model came
            from PyBaMM.

        Returns
        -------
        :class:`pybamm.BaseModel`
            A PyBaMM model object, of type specified either in the JSON or in
            `battery_model`.

        Raises
        ------
        TypeError
            If no `battery_model` is given and the file holds no "py/object"
            entry to build the model from.
        """
        with open(filename, encoding="utf-8") as f:
            model_data = json.load(f)

        recon_model_dict = {
            "name": model_data["name"],
            "options": self._convert_options(model_data["options"]),
            "bounds": tuple(np.array(bound) for bound in model_data["bounds"]),
            "concatenated_rhs": self._reconstruct_expression_tree(
                model_data["concatenated_rhs"]
            ),
            "concatenated_algebraic": self._reconstruct_expression_tree(
                model_data["concatenated_algebraic"]
            ),
            "concatenated_initial_conditions": self._reconstruct_expression_tree(
                model_data["concatenated_initial_conditions"]
            ),
            "events": [
                self._reconstruct_expression_tree(event)
                for event in model_data["events"]
            ],
            "mass_matrix": self._reconstruct_expression_tree(model_data["mass_matrix"]),
            "mass_matrix_inv": self._reconstruct_expression_tree(
                model_data["mass_matrix_inv"]
            ),
        }

        # Optional sections: only present if they were passed to `save_model`
        recon_model_dict["geometry"] = (
            self._reconstruct_pybamm_dict(model_data["geometry"])
            if "geometry" in model_data
            else None
        )

        recon_model_dict["mesh"] = (
            self._reconstruct_mesh(model_data["mesh"])
            if "mesh" in model_data
            else None
        )

        recon_model_dict["variables"] = (
            {
                k: self._reconstruct_expression_tree(v)
                for k, v in model_data["variables"].items()
            }
            if "variables" in model_data
            else None
        )

        if battery_model:
            return battery_model.deserialise(recon_model_dict)

        if "py/object" in model_data:
            model_framework = self._get_pybamm_class(model_data)
            return model_framework.deserialise(recon_model_dict)

        raise TypeError(
            """
            The PyBaMM battery model to use has not been provided.
            """
        )

    # Helper functions

    def _get_pybamm_class(self, snippet: dict):
        """
        Find a pybamm class to initialise from object path

        Returns an *instance* of the class named by `snippet["py/object"]`,
        created without running that class's `__init__`.
        """
        # NOTE(security): the dotted path comes straight from the JSON file and
        # is imported as-is, so this must only be used on trusted files.
        parts = snippet["py/object"].split(".")
        module = importlib.import_module(".".join(parts[:-1]))

        class_ = getattr(module, parts[-1])

        try:
            # Swap the class of a blank object to bypass the target's __init__
            empty_class = self._Empty()
            empty_class.__class__ = class_
            return empty_class
        except TypeError:
            # Mesh objects have a different layouts
            empty_dict_class = self._EmptyDict()
            empty_dict_class.__class__ = class_
            return empty_dict_class

    def _deconstruct_pybamm_dicts(self, dct: dict):
        """
        Converts dictionaries which contain pybamm classes as keys
        into a json serialisable format.

        A dictionary key that is a pybamm symbol is replaced by the symbol's
        name, and the dictionary required to reconstruct the symbol is stored
        next to it, in the same dictionary, under "symbol_<symbol name>".

        E.G:

        {'rod':
            {SpatialVariable(name='spat_var'): {"min":0.0, "max":2.0} }
            }

        converts to

        {'rod':
            {'symbol_spat_var': {"py/object":pybamm....},
             'spat_var': {"min":0.0, "max":2.0} }
            }

        Dictionaries which don't contain pybamm symbols are returned unchanged.
        """

        def nested_convert(obj):
            if not isinstance(obj, dict):
                return obj

            new_dict = {}
            for k, v in obj.items():
                if isinstance(k, pybamm.Symbol):
                    new_k = self._SymbolEncoder().default(k)
                    new_dict[self._SYMBOL_KEY_PREFIX + new_k["name"]] = new_k
                    k = new_k["name"]
                new_dict[k] = nested_convert(v)
            return new_dict

        try:
            # Probe only: the encoded string itself is not needed
            json.dumps(dct)
        except TypeError:  # dct must contain pybamm objects
            return nested_convert(dct)
        return dict(dct)

    def _reconstruct_symbol(self, dct: dict):
        """Reconstruct an individual pybamm Symbol"""
        symbol_class = self._get_pybamm_class(dct)
        symbol = symbol_class._from_json(dct)
        return symbol

    def _reconstruct_expression_tree(self, node: dict):
        """
        Loop through an expression tree creating pybamm Symbol classes

        Conducts post-order tree traversal to turn each tree node into a
        `pybamm.Symbol` class, starting from leaf nodes without children and
        working upwards.

        Note that `node` is modified in place: its "children" (or "expression")
        entries are replaced by the reconstructed objects.

        Parameters
        ----------
        node: dict
            A node in an expression tree.
        """
        if "children" in node:
            for i, c in enumerate(node["children"]):
                node["children"][i] = self._reconstruct_expression_tree(c)
        elif "expression" in node:
            node["expression"] = self._reconstruct_expression_tree(node["expression"])

        return self._reconstruct_symbol(node)

    def _reconstruct_mesh(self, node: dict):
        """Reconstructs a Mesh object"""
        if "sub_meshes" in node:
            for k, v in node["sub_meshes"].items():
                node["sub_meshes"][k] = self._reconstruct_symbol(v)

        return self._reconstruct_symbol(node)

    def _reconstruct_pybamm_dict(self, obj: dict):
        """
        pybamm.Geometry can contain PyBaMM symbols as dictionary keys.

        Converts
        {"rod":
            {"symbol_spat_var": {"py/object":"pybamm...."},
             "spat_var": {"min":0.0, "max":2.0} }
        }

        from an exported JSON file to

        {"rod":
            {SpatialVariable(name="spat_var"): {"min":0.0, "max":2.0} }
            }
        """
        prefix = self._SYMBOL_KEY_PREFIX

        def recurse(obj):
            if not isinstance(obj, dict):
                return obj

            new_dict = {}
            for k, v in obj.items():
                if k.startswith(prefix):
                    new_dict[k] = self._reconstruct_symbol(v)
                elif isinstance(v, dict):
                    new_dict[k] = recurse(v)
                else:
                    new_dict[k] = v

            symbol_keys = {k: v for k, v in new_dict.items() if k.startswith(prefix)}

            # rearrange the dictionary to make pybamm objects the dictionary keys
            for k, symbol in symbol_keys.items():
                # Slice the prefix off. str.lstrip() would strip any leading run
                # of the *characters* s, y, m, b, o, l, _ and so mangle names
                # such as "y" or "spat_var".
                plain_key = k[len(prefix) :]
                new_dict[symbol] = new_dict.pop(plain_key)
                del new_dict[k]

            return new_dict

        return recurse(obj)

    def _convert_options(self, d):
        """
        Converts a dictionary with nested lists to nested tuples,
        used to convert model options back into correct format
        """
        if isinstance(d, dict):
            return {k: self._convert_options(v) for k, v in d.items()}
        if isinstance(d, list):
            return tuple(self._convert_options(item) for item in d)
        return d