from __future__ import annotations
import pybamm
from datetime import datetime
import json
import importlib
import numpy as np
import re
class Serialise:
    """
    Converts a discretised model to and from a JSON file.
    """

    # Prefix of the extra dictionary key that stores a serialised pybamm Symbol
    # which was originally used as a dictionary key (e.g. in a geometry).
    # Used by both ``_deconstruct_pybamm_dicts`` and ``_reconstruct_pybamm_dict``
    # so that writing and reading agree on the format.
    _SYMBOL_PREFIX = "symbol_"

    def __init__(self):
        pass

    class _SymbolEncoder(json.JSONEncoder):
        """Converts PyBaMM symbols into a JSON-serialisable format"""

        def default(self, node: pybamm.Symbol | pybamm.Event):
            # "py/object" is the dotted class path (the text between "<class '"
            # and "'>" in str(type(node))); it is used on load to find the class.
            node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}

            if isinstance(node, pybamm.Symbol):
                node_dict.update(node.to_json())  # this doesn't include children
                node_dict["children"] = [self.default(c) for c in node.children]
                if hasattr(node, "initial_condition"):  # for ExplicitTimeIntegral
                    node_dict["initial_condition"] = self.default(
                        node.initial_condition
                    )
                return node_dict

            if isinstance(node, pybamm.Event):
                node_dict.update(node.to_json())
                node_dict["expression"] = self.default(node._expression)
                return node_dict

            # Unsupported type: the base implementation raises TypeError.
            node_dict["json"] = json.JSONEncoder.default(self, node)  # pragma: no cover
            return node_dict  # pragma: no cover

    class _MeshEncoder(json.JSONEncoder):
        """Converts PyBaMM meshes into a JSON-serialisable format"""

        def default(self, node: pybamm.Mesh | pybamm.SubMesh):
            node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}

            if isinstance(node, pybamm.Mesh):
                node_dict.update(node.to_json())
                # Only single-domain entries are written, and ghost-cell
                # submeshes are skipped.
                node_dict["sub_meshes"] = {
                    k[0]: self.default(v)
                    for k, v in node.items()
                    if len(k) == 1 and "ghost cell" not in k[0]
                }
                return node_dict

            if isinstance(node, pybamm.SubMesh):
                node_dict.update(node.to_json())
                return node_dict

            # Unsupported type: the base implementation raises TypeError.
            node_dict["json"] = json.JSONEncoder.default(self, node)  # pragma: no cover
            return node_dict  # pragma: no cover

    class _Empty:
        """A dummy class to aid deserialisation"""

        pass

    class _EmptyDict(dict):
        """A dummy dictionary class to aid deserialisation"""

        pass

    def save_model(
        self,
        model: pybamm.BaseModel,
        mesh: pybamm.Mesh | None = None,
        variables: pybamm.FuzzyDict | None = None,
        filename: str | None = None,
    ):
        """Saves a discretised model to a JSON file.

        As the model is discretised and ready to solve, only the right hand side,
        algebraic and initial condition variables are saved.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The discretised model to be saved
        mesh : :class:`pybamm.Mesh` (optional)
            The mesh the model has been discretised over. Not necessary to solve
            the model when read in, but required to use pybamm's plotting tools.
        variables: :class:`pybamm.FuzzyDict` (optional)
            The discretised model variables. Not necessary to solve a model, but
            required to use pybamm's plotting tools. The model geometry (if any)
            is only written when variables are provided.
        filename: str (optional)
            The desired name of the JSON file (without the ".json" extension,
            which is always appended). If no name is provided, one will be
            created based on the model name, and the current datetime.

        Raises
        ------
        NotImplementedError
            If the model has not been discretised.
        """
        if not model.is_discretised:
            raise NotImplementedError(
                "PyBaMM can only serialise a discretised, ready-to-solve model."
            )

        # The encoder holds no per-call state, so one instance serves all symbols.
        encoder = self._SymbolEncoder()

        model_json = {
            "py/object": str(type(model))[8:-2],
            "py/id": id(model),
            "pybamm_version": pybamm.__version__,
            "name": model.name,
            "options": model.options,
            "bounds": [bound.tolist() for bound in model.bounds],  # type: ignore[attr-defined]
            "concatenated_rhs": encoder.default(model._concatenated_rhs),
            "concatenated_algebraic": encoder.default(model._concatenated_algebraic),
            "concatenated_initial_conditions": encoder.default(
                model._concatenated_initial_conditions
            ),
            "events": [encoder.default(event) for event in model.events],
            "mass_matrix": encoder.default(model.mass_matrix),
            "mass_matrix_inv": encoder.default(model.mass_matrix_inv),
        }

        if mesh:
            model_json["mesh"] = self._MeshEncoder().default(mesh)

        if variables:
            if model._geometry:
                model_json["geometry"] = self._deconstruct_pybamm_dicts(model._geometry)
            model_json["variables"] = {
                k: encoder.default(v) for k, v in dict(variables).items()
            }

        if filename is None:
            filename = model.name + "_" + datetime.now().strftime("%Y_%m_%d-%p%I_%M")

        with open(filename + ".json", "w") as f:
            json.dump(model_json, f)

    def load_model(
        self, filename: str, battery_model: pybamm.BaseModel | None = None
    ) -> pybamm.BaseModel:
        """
        Loads a discretised, ready to solve model into PyBaMM.

        A new pybamm battery model instance will be created, which can be solved
        and the results plotted as usual.

        Currently only available for pybamm models which have previously been written
        out using the `save_model()` option.

        Warning: This only loads in discretised models. If you wish to make edits to the
        model or initial conditions, a new model will need to be constructed separately.

        Warning: only load files from trusted sources. The class paths stored in
        the file are imported with ``importlib.import_module`` and their
        ``_from_json`` / ``deserialise`` methods are called.

        Parameters
        ----------
        filename: str
            Path to the JSON file containing the serialised model file
        battery_model: :class:`pybamm.BaseModel` (optional)
            PyBaMM model to be created (e.g. pybamm.lithium_ion.SPM), which will
            override any model names within the file. If None, the function will look
            for the saved object path, present if the original model came from PyBaMM.

        Returns
        -------
        :class:`pybamm.BaseModel`
            A PyBaMM model object, of type specified either in the JSON or in
            `battery_model`.

        Raises
        ------
        TypeError
            If no `battery_model` is given and the file stores no object path.
        """
        with open(filename) as f:
            model_data = json.load(f)

        rebuild = self._reconstruct_expression_tree

        recon_model_dict = {
            "name": model_data["name"],
            "options": self._convert_options(model_data["options"]),
            "bounds": tuple(np.array(bound) for bound in model_data["bounds"]),
            "concatenated_rhs": rebuild(model_data["concatenated_rhs"]),
            "concatenated_algebraic": rebuild(model_data["concatenated_algebraic"]),
            "concatenated_initial_conditions": rebuild(
                model_data["concatenated_initial_conditions"]
            ),
            "events": [rebuild(event) for event in model_data["events"]],
            "mass_matrix": rebuild(model_data["mass_matrix"]),
            "mass_matrix_inv": rebuild(model_data["mass_matrix_inv"]),
        }

        # Optional entries: only present when save_model was given them.
        recon_model_dict["geometry"] = (
            self._reconstruct_pybamm_dict(model_data["geometry"])
            if "geometry" in model_data
            else None
        )
        recon_model_dict["mesh"] = (
            self._reconstruct_mesh(model_data["mesh"]) if "mesh" in model_data else None
        )
        recon_model_dict["variables"] = (
            {k: rebuild(v) for k, v in model_data["variables"].items()}
            if "variables" in model_data
            else None
        )

        if battery_model:
            return battery_model.deserialise(recon_model_dict)

        if "py/object" in model_data:
            model_framework = self._get_pybamm_class(model_data)
            return model_framework.deserialise(recon_model_dict)

        raise TypeError("The PyBaMM battery model to use has not been provided.")

    # Helper functions

    def _get_pybamm_class(self, snippet: dict):
        """
        Return an uninitialised instance of the class named by
        ``snippet["py/object"]`` (a dotted ``module.ClassName`` path).

        The instance is created by swapping the ``__class__`` of an empty dummy
        object, so no ``__init__`` is run; callers then use classmethods such as
        ``_from_json`` / ``deserialise`` on it.
        """
        # NOTE: imports whatever module the snippet names - untrusted files can
        # trigger arbitrary imports.
        module_path, _, class_name = snippet["py/object"].rpartition(".")
        module = importlib.import_module(module_path)
        class_ = getattr(module, class_name)

        try:
            empty_class = self._Empty()
            empty_class.__class__ = class_
            return empty_class
        except TypeError:
            # Mesh objects have a different layout (they are dict subclasses)
            empty_dict_class = self._EmptyDict()
            empty_dict_class.__class__ = class_
            return empty_dict_class

    def _deconstruct_pybamm_dicts(self, dct: dict):
        """
        Converts dictionaries which contain pybamm classes as keys
        into a json serialisable format.

        Each dictionary key that is a pybamm Symbol is replaced by its name, and
        a separate key "symbol_<symbol name>" is added alongside it holding the
        serialised symbol needed to reconstruct it. E.g.:

        {'rod':
            {SpatialVariable(name='spat_var'): {"min": 0.0, "max": 2.0}}
        }

        converts to

        {'rod':
            {'symbol_spat_var': {"py/object": "pybamm....", ...},
             'spat_var': {"min": 0.0, "max": 2.0}}
        }

        Dictionaries which don't contain pybamm symbols are returned as a
        shallow copy.
        """
        prefix = self._SYMBOL_PREFIX
        encoder = self._SymbolEncoder()

        def nested_convert(obj):
            if not isinstance(obj, dict):
                return obj
            new_dict = {}
            for k, v in obj.items():
                if isinstance(k, pybamm.Symbol):
                    encoded = encoder.default(k)
                    new_dict[prefix + encoded["name"]] = encoded
                    k = encoded["name"]
                new_dict[k] = nested_convert(v)
            return new_dict

        try:
            json.dumps(dct)
        except TypeError:  # dct must contain pybamm objects
            return nested_convert(dct)
        return dict(dct)

    def _reconstruct_symbol(self, dct: dict):
        """Reconstruct an individual pybamm Symbol (its children are not rebuilt)"""
        symbol_class = self._get_pybamm_class(dct)
        return symbol_class._from_json(dct)

    def _reconstruct_expression_tree(self, node: dict):
        """
        Loop through an expression tree creating pybamm Symbol classes.

        Conducts post-order tree traversal to turn each tree node into a
        `pybamm.Symbol` class, starting from leaf nodes without children and
        working upwards. The serialised sub-trees stored under "children",
        "expression" (events) and "initial_condition" (written by the symbol
        encoder, e.g. for ExplicitTimeIntegral) are all rebuilt.

        Note: ``node`` is modified in place.

        Parameters
        ----------
        node: dict
            A node in an expression tree.
        """
        if "children" in node:
            node["children"] = [
                self._reconstruct_expression_tree(c) for c in node["children"]
            ]
        elif "expression" in node:
            node["expression"] = self._reconstruct_expression_tree(node["expression"])

        # The encoder serialises initial_condition as its own sub-tree; without
        # this the raw dict would be handed to _from_json.
        initial_condition = node.get("initial_condition")
        if isinstance(initial_condition, dict) and "py/object" in initial_condition:
            node["initial_condition"] = self._reconstruct_expression_tree(
                initial_condition
            )

        return self._reconstruct_symbol(node)

    def _reconstruct_mesh(self, node: dict):
        """Reconstructs a Mesh object (``node`` is modified in place)"""
        if "sub_meshes" in node:
            for k, v in node["sub_meshes"].items():
                node["sub_meshes"][k] = self._reconstruct_symbol(v)
        return self._reconstruct_symbol(node)

    def _reconstruct_pybamm_dict(self, obj: dict):
        """
        pybamm.Geometry can contain PyBaMM symbols as dictionary keys.

        Converts

        {"rod":
            {"symbol_spat_var": {"py/object": "pybamm....", ...},
             "spat_var": {"min": 0.0, "max": 2.0}}
        }

        from an exported JSON file to

        {"rod":
            {SpatialVariable(name="spat_var"): {"min": 0.0, "max": 2.0}}
        }
        """
        prefix = self._SYMBOL_PREFIX

        def recurse(obj):
            if not isinstance(obj, dict):
                return obj

            new_dict = {}
            for k, v in obj.items():
                if k.startswith(prefix):
                    new_dict[k] = self._reconstruct_expression_tree(v)
                elif isinstance(v, dict):
                    new_dict[k] = recurse(v)
                else:
                    new_dict[k] = v

            # Rearrange the dictionary to make pybamm objects the dictionary
            # keys: the value stored under "<name>" moves to the Symbol stored
            # under "symbol_<name>", and both string keys are removed.
            # (Slice off the prefix: str.lstrip would strip a character set and
            # mangle names such as "y" or "spat_var".)
            for k in [key for key in new_dict if key.startswith(prefix)]:
                symbol = new_dict.pop(k)
                new_dict[symbol] = new_dict.pop(k[len(prefix) :])
            return new_dict

        return recurse(obj)

    def _convert_options(self, d):
        """
        Converts a dictionary with nested lists to nested tuples,
        used to convert model options back into correct format
        """
        if isinstance(d, dict):
            return {k: self._convert_options(v) for k, v in d.items()}
        if isinstance(d, list):
            return tuple(self._convert_options(item) for item in d)
        return d