# Source code for pybamm.solvers.idaklu_jax

import pybamm
import numpy as np
import logging
import warnings
import numbers

from typing import Union

from functools import lru_cache

import importlib.util
import importlib

# Named logger for this module; users can attach handlers to
# 'pybamm.solvers.idaklu_jax' to trace the JAX <-> IDAKLU callbacks.
logger = logging.getLogger("pybamm.solvers.idaklu_jax")

# The compiled IDAKLU extension is optional. Locate it without importing it,
# then execute it explicitly so that a broken/missing binary degrades to
# "not available" (idaklu_spec = None) instead of failing at import time.
idaklu_spec = importlib.util.find_spec("pybamm.solvers.idaklu")
if idaklu_spec is not None:
    try:
        idaklu = importlib.util.module_from_spec(idaklu_spec)
        if idaklu_spec.loader:
            idaklu_spec.loader.exec_module(idaklu)
    except ImportError:  # pragma: no cover
        idaklu_spec = None

if pybamm.have_jax():
    import jax
    from jax import lax
    from jax import numpy as jnp
    from jax.interpreters import ad
    from jax.interpreters import mlir
    from jax.interpreters import batching
    from jax.interpreters.mlir import custom_call
    from jax.lib import xla_client
    from jax.tree_util import tree_flatten

class IDAKLUJax:
    """JAX wrapper for IDAKLU solver

    Objects of this class should be created via an IDAKLUSolver object.

    Log information is available for this module via the named
    'pybamm.solvers.idaklu_jax' logger.

    Parameters
    ----------
    solver : :class:`pybamm.IDAKLUSolver`
        The IDAKLU solver object to be wrapped
    model : :class:`pybamm.BaseModel`
        The model to be solved
    t_eval : numeric type
        The times at which to compute the solution
    output_variables : list of str, optional
        The variables to be returned. If None, the solver's output variables
        are used.
    calculate_sensitivities : bool, optional
        Whether to calculate sensitivities. Default is True.
    """

    def __init__(
        self,
        solver,
        model,
        t_eval,
        output_variables=None,
        calculate_sensitivities=True,
    ):
        if not pybamm.have_jax():
            raise ModuleNotFoundError(
                "Jax or jaxlib is not installed, please see "
                "https://docs.pybamm.org/en/latest/source/user_guide/"
                "installation/gnu-linux-mac.html#optional-jaxsolver"
            )  # pragma: no cover
        if not pybamm.have_idaklu():
            raise ModuleNotFoundError(
                "IDAKLU is not installed, please see "
                "https://docs.pybamm.org/en/latest/source/user_guide/"
                "installation/index.html"
            )  # pragma: no cover
        self.jaxpr = (
            None  # JAX expression representing the IDAKLU-wrapped solver object
        )
        self.idaklu_jax_obj = None  # Low-level IDAKLU-JAX primitives object
        self.solver = solver  # Originating IDAKLU Solver object
        # Input names are seeded from the model; the values are placeholders
        # until the wrapped function is first called with concrete inputs.
        self.jax_inputs = {k: [] for k, v in model.get_parameter_info().items()}
        # JAXify the solver ready for use
        self.jaxify(
            model,
            t_eval,
            output_variables=output_variables,
            calculate_sensitivities=calculate_sensitivities,
        )

    def get_jaxpr(self):
        """Returns a JAX expression representing the IDAKLU-wrapped solver object

        Returns
        -------
        Callable
            A JAX expression with the following call signature:

                f(t, inputs=None)

            where:
                t : float | np.ndarray
                    Time sample or vector of time samples
                inputs : dict, optional
                    dictionary of input values, e.g.
                    {'Current function [A]': 0.222, 'Separator porosity': 0.3}

        Raises
        ------
        pybamm.SolverError
            If no JAX expression has been created yet.
        """
        if self.jaxpr is None:
            raise pybamm.SolverError("jaxify() must be called before get_jaxpr()")
        return self.jaxpr

    def get_var(
        self,
        *args,
    ):
        """Helper function to extract a single variable

        Isolates a single variable from the model output.
        Can be called on a JAX expression (which returns a JAX expression), or on a
        numeric (np.ndarray) object (which returns a slice of the output).

        Example call using default JAX expression, returns a JAX expression::

            f = idaklu_jax.get_var("Voltage [V]")
            data = f(t, inputs=None)

        Example call using a custom function, returns a JAX expression::

            f = idaklu_jax.get_var(jax.jit(f), "Voltage [V]")
            data = f(t, inputs=None)

        Example call to slice a matrix, returns an np.array::

            data = idaklu_jax.get_var(
                jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
                'Voltage [V]'
            )

        Parameters
        ----------
        f : Callable | np.ndarray, optional
            Expression or array from which to extract the target variable
        varname : str
            The name of the variable to extract

        Returns
        -------
        Callable
            If called with a JAX expression, returns a JAX expression with the
            following call signature:

                f(t, inputs=None)

            where:
                t : float | np.ndarray
                    Time sample or vector of time samples
                inputs : dict, optional
                    dictionary of input values, e.g.
                    {'Current function [A]': 0.222, 'Separator porosity': 0.3}
        np.ndarray
            If called with a numeric (np.ndarray) object, returns a slice of the
            output corresponding to the target variable.

        Raises
        ------
        ValueError
            If called with anything other than one or two positional arguments.
        """
        # Identify the call signature
        if len(args) == 1:
            # Called on the default JAX expression
            f = self.jaxpr
            varname = args[0]
        elif len(args) == 2:
            # Called on a custom function
            f = args[0]
            varname = args[1]
        else:
            raise ValueError("Invalid call signature")

        # Utility function to slice the output
        def slice_out(out):
            index = self.jax_output_variables.index(varname)
            if out.ndim == 0:
                return out  # pragma: no cover
            elif out.ndim == 1:
                return out[index]
            else:
                return out[:, index]

        # If the jaxified expression is not a function, return a slice
        if not callable(f):
            return slice_out(f)

        # Otherwise, return a function that slices the output of `f` itself, so
        # that a user-supplied (e.g. jitted) function is actually the one called
        def f_isolated(*args, **kwargs):
            return slice_out(f(*args, **kwargs))

        return f_isolated

    def get_vars(
        self,
        *args,
    ):
        """Helper function to extract a list of variables

        Isolates a list of variables from the model output.
        Can be called on a JAX expression (which returns a JAX expression), or on a
        numeric (np.ndarray) object (which returns a slice of the output).

        Example call using default JAX expression, returns a JAX expression::

            f = idaklu_jax.get_vars(["Voltage [V]", "Current [A]"])
            data = f(t, inputs=None)

        Example call using a custom function, returns a JAX expression::

            f = idaklu_jax.get_vars(jax.jit(f), ["Voltage [V]", "Current [A]"])
            data = f(t, inputs=None)

        Example call to slice a matrix, returns an np.array::

            data = idaklu_jax.get_vars(
                jax.jacfwd(f, argnums=1)(t_eval, inputs)['Current function [A]'],
                ["Voltage [V]", "Current [A]"]
            )

        Parameters
        ----------
        f : Callable | np.ndarray, optional
            Expression or array from which to extract the target variables
        varname : list of str
            The names of the variables to extract

        Returns
        -------
        Callable
            If called with a JAX expression, returns a JAX expression with the
            following call signature:

                f(t, inputs=None)

            where:
                t : float | np.ndarray
                    Time sample or vector of time samples
                inputs : dict, optional
                    dictionary of input values, e.g.
                    {'Current function [A]': 0.222, 'Separator porosity': 0.3}
        np.ndarray
            If called with a numeric (np.ndarray) object, returns a slice of the
            output corresponding to the target variables.

        Raises
        ------
        ValueError
            If called with anything other than one or two positional arguments.
        """
        # Identify the call signature
        if len(args) == 1:
            # Called on the default JAX expression
            f = self.jaxpr
            varnames = args[0]
        elif len(args) == 2:
            # Called on a custom function
            f = args[0]
            varnames = args[1]
        else:
            raise ValueError("Invalid call signature")

        # Utility function to slice the output
        def slice_out(out):
            index = np.array(
                [self.jax_output_variables.index(varname) for varname in varnames]
            )
            if out.ndim == 0:
                return out  # pragma: no cover
            elif out.ndim == 1:
                return out[index]
            else:
                return out[:, index]

        # If the jaxified expression is not a function, return a slice
        if not callable(f):
            return slice_out(f)

        # Otherwise, return a function that slices the output of `f` itself, so
        # that a user-supplied (e.g. jitted) function is actually the one called
        def f_isolated(*args, **kwargs):
            return slice_out(f(*args, **kwargs))

        return f_isolated

    def jax_value(
        self,
        t: np.ndarray = None,
        inputs: Union[dict, None] = None,
        output_variables: Union[list[str], None] = None,
    ):
        """Helper function to compute the value of a jaxified expression

        Evaluates each requested output variable with ``jax.vmap`` over the
        time axis and returns the results keyed by variable name.

        Parameters
        ----------
        t : float | np.ndarray
            Time sample or vector of time samples
        inputs : dict
            dictionary of input values
        output_variables : list of str, optional
            The variables to be returned. If None, the variables in the model
            are used.

        Returns
        -------
        dict
            Mapping of output variable name to its evaluated values.

        Raises
        ------
        pybamm.SolverError
            If no JAX expression has been created yet.
        """
        if self.jaxpr is None:
            raise pybamm.SolverError("jaxify() must be called before get_jaxpr()")
        output_variables = output_variables or self.jax_output_variables
        return {
            outvar: jax.vmap(
                self.get_var(outvar),
                in_axes=(0, None),
            )(t, inputs)
            for outvar in output_variables
        }

    def jax_grad(
        self,
        t: np.ndarray = None,
        inputs: Union[dict, None] = None,
        output_variables: Union[list[str], None] = None,
    ):
        """Helper function to compute the gradient of a jaxified expression

        Applies ``jax.grad`` with respect to the inputs (argument 1) to each
        requested output variable, vectorised over time with ``jax.vmap``.

        Parameters
        ----------
        t : float | np.ndarray
            Time sample or vector of time samples
        inputs : dict
            dictionary of input values
        output_variables : list of str, optional
            The variables to be returned. If None, the variables in the model
            are used.

        Returns
        -------
        dict
            Mapping of output variable name to its gradient with respect to
            the inputs.

        Raises
        ------
        pybamm.SolverError
            If no JAX expression has been created yet.
        """
        if self.jaxpr is None:
            raise pybamm.SolverError("jaxify() must be called before get_jaxpr()")
        output_variables = output_variables or self.jax_output_variables
        return {
            outvar: jax.vmap(
                jax.grad(
                    self.get_var(outvar),
                    argnums=1,
                ),
                in_axes=(0, None),
            )(t, inputs)
            for outvar in output_variables
        }

    class _hashabledict(dict):
        """dict that hashes on its sorted items so it can be an lru_cache key"""

        def __hash__(self):
            return hash(tuple(sorted(self.items())))

    @lru_cache(maxsize=1)  # noqa: B019
    def _cached_solve(self, model, t_hashable, *args, **kwargs):
        """Cache the last solve for reuse

        Note: this is invoked as ``IDAKLUJax._cached_solve(solver, ...)``, so
        ``self`` here is the wrapped solver object, not the IDAKLUJax instance.
        """
        return self.solve(model, t_hashable, *args, **kwargs)

    def _jaxify_solve(self, t, invar, *inputs_values):
        """Solve the model using the IDAKLU solver

        This method is called by the JAX primitive definition and caches the last
        solve for reuse.

        Parameters
        ----------
        t : np.ndarray
            Time sample(s) at which the result is required
        invar : str | int | None
            Name or index of the input to differentiate against. If None, the
            output values are returned rather than sensitivities.
        *inputs_values
            Input values, in the same order as the keys of ``self.jax_inputs``
        """
        # Reconstruct dictionary of inputs
        # (hashable dictionaries are used so that the solve can be cached)
        d = self._hashabledict()
        if self.jax_inputs is not None:
            for key, value in zip(self.jax_inputs.keys(), inputs_values):
                d[key] = value
        # Solver
        logger.debug("_jaxify_solve:")
        logger.debug(f" t_eval: {self.jax_t_eval}")
        logger.debug(f" t: {t}")
        logger.debug(f" invar: {invar}")
        logger.debug(f" inputs: {dict(d)}")
        logger.debug(f" calculate_sensitivities: {invar is not None}")
        sim = IDAKLUJax._cached_solve(
            self.solver,
            self.jax_model,
            tuple(self.jax_t_eval),
            inputs=self._hashabledict(d),
            calculate_sensitivities=self.jax_calculate_sensitivities,
        )
        if invar is not None:
            if isinstance(invar, numbers.Number):
                invar = list(self.jax_inputs.keys())[invar]
            # Provide vector support for time
            if t.ndim == 0:
                t = np.array([t])
            # Sensitivities are stored on the t_eval grid, so pick the nearest
            # grid point for each requested time
            tk = [np.argmin(abs(self.jax_t_eval - tp)) for tp in t]
            out = jnp.array(
                [
                    jnp.array(sim[outvar].sensitivities[invar][tk])
                    for outvar in self.jax_output_variables
                ]
            ).squeeze()
            return out.T
        else:
            return jnp.array(
                [np.array(sim[outvar](t)) for outvar in self.jax_output_variables]
            ).T

    def _jax_solve_array_inputs(self, t, inputs_array):
        """Wrapper for _jax_solve used by IDAKLU callback

        This version assumes all parameters are provided as np.ndarray vectors
        """
        logger.info("jax_solve_array_inputs")
        logger.debug(f" t: {type(t)}, {t}")
        logger.debug(f" inputs_array: {type(inputs_array)}, {inputs_array}")
        inputs = tuple(inputs_array)
        logger.debug(f" inputs: {type(inputs)}, {inputs}")
        return self._jax_solve(t, *inputs)

    def _jax_solve(
        self,
        t: Union[float, np.ndarray],
        *inputs,
    ) -> np.ndarray:
        """Solver implementation used by f-bind"""
        logger.info("jax_solve")
        logger.debug(f" t: {type(t)}, {t}")
        logger.debug(f" inputs: {type(inputs)}, {inputs}")
        # Returns a jax array
        out = self._jaxify_solve(t, None, *inputs)
        # Convert to numpy array
        return np.array(out)

    def _jax_jvp_impl(
        self,
        *args: Union[np.ndarray],
    ):
        """JVP implementation used by f_jvp bind

        ``args`` is the concatenation of the primals ``(t, *inputs)`` and the
        matching tangents; the first half are primals, the second half tangents.
        """
        primals = args[: len(args) // 2]
        tangents = args[len(args) // 2 :]
        t = primals[0]
        inputs = primals[1:]
        inputs_t = tangents[1:]

        if t.ndim == 0:
            y_dot = jnp.zeros_like(t)
        else:
            # This permits direct vector indexing with time for jacfwd
            y_dot = jnp.zeros((len(t), len(self.jax_output_variables)))
        input_names = list(self.jax_inputs.keys())
        for index, value in enumerate(inputs_t):
            # Skipping zero values greatly improves performance. Negative
            # tangents must still contribute, hence `!=` rather than `>`.
            if value != 0.0:
                invar = input_names[index]
                js = self._jaxify_solve(t, invar, *inputs)
                if js.ndim == 0:
                    js = jnp.array([js])
                if js.ndim == 1 and t.ndim > 0:
                    # This permits direct vector indexing with time
                    js = js.reshape((t.shape[0], -1))
                y_dot += value * js
        return np.array(y_dot)

    def _jax_jvp_impl_array_inputs(
        self,
        primal_t,
        primal_inputs,
        tangent_t,
        tangent_inputs,
    ):
        """Wrapper for JVP implementation used by IDAKLU callback

        This version assumes all parameters are provided as np.ndarray vectors
        """
        primals = primal_t, *tuple(primal_inputs)
        tangents = tangent_t, *tuple(tangent_inputs)
        return self._jax_jvp_impl(*primals, *tangents)

    def _jax_vjp_impl(
        self,
        y_bar: np.ndarray,
        invar: Union[str, int],  # index or name of input variable
        *primals: np.ndarray,
    ):
        """VJP implementation used by f_vjp bind

        Parameters
        ----------
        y_bar : np.ndarray
            Cotangent of the output
        invar : str | int
            Index or name of the input variable to differentiate against
        *primals : np.ndarray
            ``(t, *inputs)``
        """
        logger.info("py:f_vjp_p_impl")
        logger.debug(f" py:y_bar: {type(y_bar)}, {y_bar}")
        logger.debug(f" py:invar: {type(invar)}, {invar}")
        logger.debug(f" py:primals: {type(primals)}, {primals}")

        t = primals[0]
        inputs = primals[1:]

        # The index may arrive as a float; convert it back to an int
        if isinstance(invar, float):
            invar = round(invar)
        if isinstance(t, float):
            t = np.array(t)

        if t.ndim == 0 or (t.ndim == 1 and t.shape[0] == 1):
            # scalar time input
            logger.debug("scalar time")
            y_dot = jnp.zeros_like(t)
            js = self._jaxify_solve(t, invar, *inputs)
            if js.ndim == 0:
                js = jnp.array([js])
            for index, value in enumerate(y_bar):
                # Skip zero cotangents only; negative ones must contribute
                if value != 0.0:
                    y_dot += value * js[index]
        else:
            logger.debug("vector time")
            # vector time input
            js = self._jaxify_solve(t, invar, *inputs)
            if len(self.jax_output_variables) == 1 and len(t) > 1:
                js = np.array([js]).T
            y_dot = jnp.zeros(())
            for ix, y_outvar in enumerate(y_bar.T):
                y_dot += jnp.dot(y_outvar, js[:, ix])
        logger.debug(f"_jax_vjp_impl [exit]: {type(y_dot)}, {y_dot}, {y_dot.shape}")
        y_dot = np.array(y_dot)
        return y_dot

    def _jax_vjp_impl_array_inputs(
        self,
        y_bar,
        y_bar_s0,
        y_bar_s1,
        invar,
        primal_t,
        primal_inputs,
    ):
        """Wrapper for VJP implementation used by IDAKLU callback

        This version assumes all parameters are provided as np.ndarray vectors
        """
        # Reshape y_bar
        logger.debug(f"Reshaping y_bar to ({y_bar_s0}, {y_bar_s1})")
        y_bar = y_bar.reshape(y_bar_s0, y_bar_s1)
        logger.debug(f"y_bar is now: {y_bar}")
        primals = primal_t, *tuple(primal_inputs)
        return self._jax_vjp_impl(y_bar, invar, *primals)

    def _register_callbacks(self):
        """Register the solve method with the IDAKLU solver"""
        logger.info("_register_callbacks")
        self.idaklu_jax_obj.register_callbacks(
            self._jax_solve_array_inputs,
            self._jax_jvp_impl_array_inputs,
            self._jax_vjp_impl_array_inputs,
        )

    def _unique_name(self):
        """Return a unique name for this solver object for naming the JAX primitives"""
        return f"{self.idaklu_jax_obj.get_index()}"

    def jaxify(
        self,
        model,
        t_eval,
        *,
        output_variables=None,
        calculate_sensitivities=True,
    ):
        """JAXify the model and solver

        Creates a JAX expression representing the IDAKLU-wrapped solver
        object. A UserWarning is issued if an expression already exists, and
        it is then overwritten.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            The model to be solved
        t_eval : numeric type, optional
            The times at which to compute the solution. If None, the times in the
            model are used.
        output_variables : list of str, optional
            The variables to be returned. If None, the variables in the model are
            used.
        calculate_sensitivities : bool, optional
            Whether to calculate sensitivities. Default is True.

        Returns
        -------
        Callable
            The JAX expression, with call signature ``f(t, inputs=None)``.
        """
        if self.jaxpr is not None:
            warnings.warn(
                "JAX expression has already been created. "
                "Overwriting with new expression.",
                UserWarning,
                stacklevel=2,
            )
        self.jaxpr = self._jaxify(
            model,
            t_eval,
            output_variables=output_variables,
            calculate_sensitivities=calculate_sensitivities,
        )
        return self.jaxpr

    def _jaxify(
        self,
        model,
        t_eval,
        *,
        output_variables=None,
        calculate_sensitivities=True,
    ):
        """JAXify the model and solver

        Stores the solve configuration on the instance, creates the low-level
        IDAKLU-JAX object, registers the python callbacks and custom call
        targets, and defines three JAX primitives (f, its JVP and its VJP)
        together with their implementation, abstract-evaluation, batching and
        CPU lowering rules.

        Returns
        -------
        Callable
            The wrapper function ``f(t, inputs=None)`` around the primitive.

        Raises
        ------
        pybamm.SolverError
            If no output variables are given here or on the solver.
        """

        self.jax_model = model
        self.jax_t_eval = t_eval
        self.jax_output_variables = (
            output_variables if output_variables else self.solver.output_variables
        )
        if not self.jax_output_variables:
            raise pybamm.SolverError("output_variables must be specified")
        self.jax_calculate_sensitivities = calculate_sensitivities

        self.idaklu_jax_obj = idaklu.create_idaklu_jax()  # Create IDAKLU-JAX object
        self._register_callbacks()  # Register python methods as callbacks in IDAKLU-JAX

        for _name, _value in idaklu.registrations().items():
            xla_client.register_custom_call_target(
                f"{_name}_{self._unique_name()}", _value, platform="cpu"
            )

        # --- JAX PRIMITIVE DEFINITION ------------------------------------------------

        logger.debug(f"Creating new primitive: {self._unique_name()}")
        f_p = jax.core.Primitive(f"f_{self._unique_name()}")
        f_p.multiple_results = False  # Returns a single multi-dimensional array

        def f(t, inputs=None):
            """Main function wrapper for the JAX primitive function

            Parameters
            ----------
            t : float | np.ndarray
                Time sample or vector of time samples
            inputs : dict, optional
                dictionary of input values, e.g.
                {'Current function [A]': 0.222, 'Separator porosity': 0.3}
            """
            logger.info("f")
            flatargs, _treedef = tree_flatten((t, inputs))
            # Remember the inputs so the callbacks can map values back to names
            self.jax_inputs = inputs
            out = f_p.bind(*flatargs)
            return out

        self.jaxify_f = f

        @f_p.def_impl
        def f_impl(t, *inputs):
            """Concrete implementation of Primitive (used for non-jitted evaluation)"""
            logger.info("f_impl")
            term_v = self._jaxify_solve(t, None, *inputs)
            logger.debug(f"f_impl [exit]: {type(term_v)}, {term_v}")
            return term_v

        @f_p.def_abstract_eval
        def f_abstract_eval(t, *inputs):
            """Abstract evaluation of Primitive"""
            logger.info("f_abstract_eval")
            dtype = jax.dtypes.canonicalize_dtype(t.dtype)
            # Output has one trailing axis entry per output variable
            y_aval = jax.core.ShapedArray(
                (*t.shape, len(self.jax_output_variables)), dtype
            )
            return y_aval

        def f_batch(args, batch_axes):
            """Batch rule for Primitive

            Takes batched inputs, returns batched outputs and batched axes
            """
            logger.info(f"f_batch: {type(args)}, {type(batch_axes)}")
            t = args[0]
            inputs = args[1:]
            if batch_axes[0] is not None and all([b is None for b in batch_axes[1:]]):
                # Temporal batching
                return jnp.stack([f_p.bind(tp, *inputs) for tp in t]), 0
            else:
                raise NotImplementedError(
                    f"jaxify: batching not implemented for batch_axes = {batch_axes}"
                )

        batching.primitive_batchers[f_p] = f_batch

        def f_lowering_cpu(ctx, t, *inputs):
            """CPU lowering rule for Primitive

            This function calls the IDAKLU-JAX custom call target, which reroutes
            the call to the python callbacks, which call the standard IDAKLU solver.
            """
            logger.info("f_lowering_cpu")

            t_aval = ctx.avals_in[0]
            np_dtype = np.dtype(t_aval.dtype)
            if np_dtype == np.float64:
                op_name = f"cpu_idaklu_f64_{self._unique_name()}"
                op_dtype = mlir.ir.F64Type.get()
            else:
                raise NotImplementedError(
                    f"Unsupported dtype {np_dtype}"
                )  # pragma: no cover

            dtype_t = mlir.ir.RankedTensorType(t.type)
            dims_t = dtype_t.shape
            layout_t = tuple(range(len(dims_t) - 1, -1, -1))
            size_t = np.prod(dims_t).astype(np.int64)

            input_aval = ctx.avals_in[1]
            dtype_input = mlir.ir.RankedTensorType.get(input_aval.shape, op_dtype)
            dims_input = dtype_input.shape
            layout_input = tuple(range(len(dims_input) - 1, -1, -1))

            y_aval = ctx.avals_out[0]
            dtype_out = mlir.ir.RankedTensorType.get(y_aval.shape, op_dtype)
            dims_out = dtype_out.shape
            layout_out = tuple(range(len(dims_out) - 1, -1, -1))

            results = custom_call(
                op_name,
                # Output types
                result_types=[dtype_out],
                # The inputs
                operands=[
                    mlir.ir_constant(
                        self.idaklu_jax_obj.get_index()
                    ),  # solver index reference
                    mlir.ir_constant(size_t),  # 'size' argument
                    mlir.ir_constant(len(self.jax_output_variables)),  # 'vars' argument
                    mlir.ir_constant(len(inputs)),  # number of inputs
                    t,
                    *inputs,
                ],
                # Layout specification
                operand_layouts=[
                    (),  # solver index reference
                    (),  # 'size'
                    (),  # 'vars'
                    (),  # number of inputs
                    layout_t,  # t
                    *([layout_input] * len(inputs)),  # inputs
                ],
                result_layouts=[layout_out],
            )
            return results.results

        mlir.register_lowering(
            f_p,
            f_lowering_cpu,
            platform="cpu",
        )

        # --- JAX PRIMITIVE JVP DEFINITION --------------------------------------------

        def f_jvp(primals, tangents):
            """Main wrapper for the JVP function"""
            logger.info("f_jvp")

            # Deal with Zero tangents
            def make_zero(prim, tan):
                return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan

            zero_mapped_tangents = tuple(
                make_zero(prim, tan) for prim, tan in zip(primals, tangents)
            )

            y = f_p.bind(*primals)
            y_dot = f_jvp_p.bind(
                *primals,
                *zero_mapped_tangents,
            )
            logger.debug(f"f_jvp [exit]: {type(y)}, {y}, {type(y_dot)}, {y_dot}")
            return y, y_dot

        ad.primitive_jvps[f_p] = f_jvp

        f_jvp_p = jax.core.Primitive(f"f_jvp_{self._unique_name()}")

        @f_jvp_p.def_impl
        def f_jvp_eval(*args):
            """Concrete implementation of JVP primitive (for non-jitted evaluation)"""
            logger.info(f"f_jvp_p_eval: {type(args)}")
            return self._jax_jvp_impl(*args)

        def f_jvp_batch(args, batch_axes):
            """Batch rule for JVP primitive"""
            logger.info("f_jvp_batch")
            primals = args[: len(args) // 2]
            tangents = args[len(args) // 2 :]
            batch_primals = batch_axes[: len(batch_axes) // 2]
            batch_tangents = batch_axes[len(batch_axes) // 2 :]

            if (
                batch_primals[0] is not None
                and all([b is None for b in batch_primals[1:]])
                and all([b is None for b in batch_tangents])
            ):
                # Temporal batching (primals) only
                t = primals[0]
                inputs = primals[1:]
                return (
                    jnp.stack([f_jvp_p.bind(tp, *inputs, *tangents) for tp in t]),
                    0,
                )
            elif (
                batch_tangents[0] is not None
                and all([b is None for b in batch_tangents[1:]])
                and all([b is None for b in batch_primals])
            ):
                # Batch over derivates wrt time
                raise NotImplementedError(
                    "Taking the derivative with respect to time is not supported"
                )
            elif (
                batch_tangents[0] is None
                and any([b is not None for b in batch_tangents[1:]])
                and all([b is None for b in batch_primals])
            ):
                # Batch over (some combination of) inputs
                batch_axis_indices = [
                    i for i, b in enumerate(batch_tangents) if b is not None
                ]
                out = []
                for i in range(len(batch_axis_indices)):
                    # Take the i-th slice of every batched tangent
                    tangents_item = list(tangents)
                    for axis_index in batch_axis_indices:
                        tangents_item[axis_index] = tangents[axis_index][i]
                    out.append(f_jvp_p.bind(*primals, *tangents_item))
                return jnp.stack(out), 0
            else:
                raise NotImplementedError(
                    "f_jvp_batch: batching not implemented for batch_axes = "
                    f"{batch_axes}"
                )  # pragma: no cover

        batching.primitive_batchers[f_jvp_p] = f_jvp_batch

        @f_jvp_p.def_abstract_eval
        def f_jvp_abstract_eval(*args):
            """Abstract evaluation of JVP primitive"""
            logger.info("f_jvp_abstract_eval")
            primals = args[: len(args) // 2]
            t = primals[0]
            out = jax.core.ShapedArray(
                (*t.shape, len(self.jax_output_variables)), t.dtype
            )
            logger.info("<- f_jvp_abstract_eval")
            return out

        def f_jvp_transpose(y_bar, *args):
            """Transpose rule for JVP primitive"""
            # Note: y_bar indexes the OUTPUT variable, e.g. [1, 0, 0] is the
            # first of three outputs. The function returns primals and tangents
            # corresponding to how each of the inputs derives that output, e.g.
            #   (..., dout/din1, dout/din2)
            logger.info("f_jvp_transpose")
            primals = args[: len(args) // 2]

            tangents_out = [
                f_vjp(y_bar, invar, *primals) for invar in self.jax_inputs.keys()
            ]

            out = (
                None,
                *([None] * len(tangents_out)),  # primals
                None,
                *tangents_out,  # tangents
            )
            logger.debug("<- f_jvp_transpose")
            return out

        ad.primitive_transposes[f_jvp_p] = f_jvp_transpose

        def f_jvp_lowering_cpu(ctx, *args):
            """CPU lowering rule for JVP primitive"""
            logger.info("f_jvp_lowering_cpu")

            primals = args[: len(args) // 2]
            t_primal = primals[0]
            inputs_primals = primals[1:]

            tangents = args[len(args) // 2 :]
            t_tangent = tangents[0]
            inputs_tangents = tangents[1:]

            t_aval = ctx.avals_in[0]
            np_dtype = np.dtype(t_aval.dtype)
            if np_dtype == np.float64:
                op_name = f"cpu_idaklu_jvp_f64_{self._unique_name()}"
                op_dtype = mlir.ir.F64Type.get()
            else:
                raise NotImplementedError(
                    f"Unsupported dtype {np_dtype}"
                )  # pragma: no cover

            dtype_t = mlir.ir.RankedTensorType(t_primal.type)
            dims_t = dtype_t.shape
            layout_t_primal = tuple(range(len(dims_t) - 1, -1, -1))
            layout_t_tangent = layout_t_primal
            size_t = np.prod(dims_t).astype(np.int64)

            input_aval = ctx.avals_in[1]
            dtype_input = mlir.ir.RankedTensorType.get(input_aval.shape, op_dtype)
            dims_input = dtype_input.shape
            layout_inputs_primals = tuple(range(len(dims_input) - 1, -1, -1))
            layout_inputs_tangents = layout_inputs_primals

            y_aval = ctx.avals_out[0]
            dtype_out = mlir.ir.RankedTensorType.get(y_aval.shape, op_dtype)
            dims_out = dtype_out.shape
            layout_out = tuple(range(len(dims_out) - 1, -1, -1))

            results = custom_call(
                op_name,
                # Output types
                result_types=[dtype_out],
                # The inputs
                operands=[
                    mlir.ir_constant(
                        self.idaklu_jax_obj.get_index()
                    ),  # solver index reference
                    mlir.ir_constant(size_t),  # 'size' argument
                    mlir.ir_constant(len(self.jax_output_variables)),  # 'vars' argument
                    mlir.ir_constant(len(inputs_primals)),  # number of inputs
                    t_primal,  # 't'
                    *inputs_primals,  # inputs
                    t_tangent,  # 't'
                    *inputs_tangents,  # inputs
                ],
                # Layout specification
                operand_layouts=[
                    (),  # solver index reference
                    (),  # 'size'
                    (),  # 'vars'
                    (),  # number of inputs
                    layout_t_primal,  # 't'
                    *([layout_inputs_primals] * len(inputs_primals)),  # inputs
                    layout_t_tangent,  # 't'
                    *([layout_inputs_tangents] * len(inputs_tangents)),  # inputs
                ],
                result_layouts=[layout_out],
            )
            return results.results

        mlir.register_lowering(
            f_jvp_p,
            f_jvp_lowering_cpu,
            platform="cpu",
        )

        # --- JAX PRIMITIVE VJP DEFINITION --------------------------------------------

        f_vjp_p = jax.core.Primitive(f"f_vjp_{self._unique_name()}")

        def f_vjp(y_bar, invar, *primals):
            """Main wrapper for the VJP function"""
            logger.info("f_vjp")
            logger.debug(f" y_bar: {y_bar}, {type(y_bar)}, {y_bar.shape}")
            # The primitive takes the input as a positional index, not a name
            if isinstance(invar, str):
                invar = list(self.jax_inputs.keys()).index(invar)
            return f_vjp_p.bind(y_bar, invar, *primals)

        @f_vjp_p.def_impl
        def f_vjp_impl(y_bar, invar, *primals):
            """Concrete implementation of VJP primitive (for non-jitted evaluation)"""
            logger.info("f_vjp_impl")
            return self._jax_vjp_impl(y_bar, invar, *primals)

        @f_vjp_p.def_abstract_eval
        def f_vjp_abstract_eval(*args):
            """Abstract evaluation of VJP primitive"""
            logger.info("f_vjp_abstract_eval")
            primals = args[: len(args) // 2]
            t = primals[0]
            # The VJP with respect to a single input is a scalar
            out = jax.core.ShapedArray((), t.dtype)
            logger.debug("<- f_vjp_abstract_eval")
            return out

        def f_vjp_batch(args, batch_axes):
            """Batch rule for VJP primitive"""
            logger.info("f_vjp_p_batch")
            y_bars, invar, t, *inputs = args

            if batch_axes[0] is not None and all([b is None for b in batch_axes[1:]]):
                # Batch over y_bar
                out = [f_vjp(yb, invar, t, *inputs) for yb in y_bars]
                return jnp.stack(out), 0
            elif (
                batch_axes[2] is not None
                and all([b is None for b in batch_axes[:2]])
                and all([b is None for b in batch_axes[3:]])
            ):
                # Batch over time
                out = [f_vjp(y_bars, invar, yt, *inputs) for yt in t]
                return jnp.stack(out), 0
            else:
                raise NotImplementedError(
                    f"Batch mode not supported for batch_axes = {batch_axes}"
                )  # pragma: no cover

        batching.primitive_batchers[f_vjp_p] = f_vjp_batch

        def f_vjp_lowering_cpu(ctx, y_bar, invar, *primals):
            """CPU lowering rule for VJP primitive"""
            logger.info("f_vjp_lowering_cpu")

            t_primal = primals[0]
            inputs_primals = primals[1:]

            t_aval = ctx.avals_in[2]
            np_dtype = np.dtype(t_aval.dtype)
            if np_dtype == np.float64:
                op_name = f"cpu_idaklu_vjp_f64_{self._unique_name()}"
                op_dtype = mlir.ir.F64Type.get()
            else:
                raise NotImplementedError(
                    f"Unsupported dtype {np_dtype}"
                )  # pragma: no cover

            y_bar_aval = ctx.avals_in[0]
            dtype_y_bar = mlir.ir.RankedTensorType.get(y_bar_aval.shape, op_dtype)
            dims_y_bar = dtype_y_bar.shape
            logger.debug(f" y_bar shape: {dims_y_bar}")
            layout_y_bar = tuple(range(len(dims_y_bar) - 1, -1, -1))

            invar_aval = ctx.avals_in[1]
            dtype_invar = mlir.ir.RankedTensorType.get(invar_aval.shape, op_dtype)
            dims_invar = dtype_invar.shape
            layout_invar = tuple(range(len(dims_invar) - 1, -1, -1))

            dtype_t = mlir.ir.RankedTensorType(t_primal.type)
            dims_t = dtype_t.shape
            layout_t_primal = tuple(range(len(dims_t) - 1, -1, -1))
            size_t = np.prod(dims_t).astype(np.int64)

            input_aval = ctx.avals_in[3]
            dtype_input = mlir.ir.RankedTensorType.get(input_aval.shape, op_dtype)
            dims_input = dtype_input.shape
            layout_inputs_primals = tuple(range(len(dims_input) - 1, -1, -1))

            y_aval = ctx.avals_out[0]
            dtype_out = mlir.ir.RankedTensorType.get(y_aval.shape, op_dtype)
            dims_out = dtype_out.shape
            layout_out = tuple(range(len(dims_out) - 1, -1, -1))

            results = custom_call(
                op_name,
                # Output types
                result_types=[dtype_out],
                # The inputs
                operands=[
                    mlir.ir_constant(
                        self.idaklu_jax_obj.get_index()
                    ),  # solver index reference
                    mlir.ir_constant(size_t),  # 'size' argument
                    mlir.ir_constant(len(self.jax_inputs)),  # number of inputs
                    mlir.ir_constant(dims_y_bar[0]),  # 'y_bar' shape[0]
                    mlir.ir_constant(  # 'y_bar' shape[1]
                        dims_y_bar[1] if len(dims_y_bar) > 1 else -1
                    ),  # 'y_bar' argument
                    y_bar,  # 'y_bar'
                    invar,  # 'invar'
                    t_primal,  # 't'
                    *inputs_primals,  # inputs
                ],
                # Layout specification
                operand_layouts=[
                    (),  # solver index reference
                    (),  # 'size'
                    (),  # number of inputs
                    (),  # 'y_bar' shape[0]
                    (),  # 'y_bar' shape[1]
                    layout_y_bar,  # 'y_bar'
                    layout_invar,  # 'invar'
                    layout_t_primal,  # 't'
                    *([layout_inputs_primals] * len(inputs_primals)),  # inputs
                ],
                result_layouts=[layout_out],
            )
            return results.results

        mlir.register_lowering(
            f_vjp_p,
            f_vjp_lowering_cpu,
            platform="cpu",
        )

        return f