Tutorial: coding in JAX

triangulax aims to provide a triangulation data structure compatible with JAX, a library for automatic differentiation and numerical computing (see the JAX documentation, in particular "JAX - The Sharp Bits"). What does this mean in practice?

  1. Use jnp (=jax.numpy) instead of numpy.
  2. Use a functional programming paradigm (pure functions, no side effects). Avoid dynamically changing array shapes. For example, instead of in-place array modifications, use JAX’s x = x.at[idx].set(y).
  3. Use JAX idioms for control flow.
  4. Register any new classes, so JAX knows how to handle them during gradient computation and just-in-time compilation. See the JAX documentation on pytrees and on extending pytrees.
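As a minimal illustration of point 2, here is how an in-place array update translates to JAX's functional style (the array values are arbitrary, for illustration only):

```python
import jax.numpy as jnp

x = jnp.zeros(4)
# NumPy style (x[1] = 5.0) is forbidden on JAX arrays.
# Instead, .at[...] returns a NEW array; we rebind x rather than mutate it:
x = x.at[1].set(5.0)
x = x.at[2].add(3.0)  # other variants exist: .add, .mul, .min, .max, ...
# x is now [0., 5., 3., 0.]
```

Under jit, JAX optimizes away the copy where it can, so this style is not as wasteful as it looks.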

To provide type signatures for all functions (what are the inputs? What do the array dimensions mean?), we use jaxtyping. Later on, we will also use the equinox library, which adds a few useful tools to JAX.

PyTrees

JAX supports not only arrays as inputs/outputs/intermediate variables, but also pytrees. Pytrees are nested structures (dictionaries, lists-of-lists, etc.) whose leaves are “elementary” objects like arrays. Fortunately, our triangulation data structure classes are already a lot like pytrees - each is essentially a collection of arrays. For JAX to understand this, we need to register our classes as custom pytree nodes via jax.tree_util.register_dataclass.

Sidenote: Neural networks in libraries like Flax or Equinox are built in a very similar way. They are dataclass-like classes which hold all the arrays associated with a NN (the weights, and possibly some hyperparameters), with class methods like __call__ specifying the forward pass through the NN. Equinox automatically registers your NN as a pytree when you inherit from the equinox.Module class.

Control flow

JAX can use just-in-time (JIT) compilation to greatly accelerate your code. However, JIT compilation only works if your code fulfills certain requirements. JAX distinguishes two types of variables: dynamic and static. Control flow cannot depend on the value of dynamic variables, only on their shape.

Upshots:

  1. Replace if with jax.lax.cond / jnp.where (fully autodiff-compatible), and while with jax.lax.while_loop (forward-mode autodiff only).
  2. Mark variables which are not going to change during the simulation as static.
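The points above can be sketched as follows (the functions and values are toy examples, not part of triangulax):

```python
from functools import partial

import jax
import jax.numpy as jnp

def relu_where(x):
    # value-dependent branching, expressed with jnp.where
    # instead of a Python `if` (jit- and autodiff-safe)
    return jnp.where(x > 0, x, 0.0)

@partial(jax.jit, static_argnames="n_steps")
def run(x, n_steps):
    # n_steps is static: the Python for-loop is unrolled at trace time,
    # so its length may not depend on array *values*
    for _ in range(n_steps):
        x = relu_where(x - 1.0)
    return x

y = run(jnp.array([3.0, -1.0]), n_steps=2)  # -> [1., 0.]

# jax.lax.cond selects between two whole branches at runtime:
z = jax.lax.cond(jnp.sum(y) > 0, lambda v: v * 2.0, lambda v: v, y)
```

jnp.where evaluates both branches elementwise, while jax.lax.cond executes only the selected branch; which one is appropriate depends on the cost of the branches.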

Static array shapes

JAX works best if the shapes of arrays do not change during the computation. For this reason, we (first) focus on triangulations where the number of vertices does not change. Topological modifications (like edge flips) are nevertheless possible, as long as they do not change the number of mesh elements (vertices, edges, and faces).
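A common pitfall, sketched below: boolean indexing like x[x > 0] produces an array whose shape depends on the data, which fails under jax.jit. Masking with jnp.where keeps the shape fixed (the function is a toy example):

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 3.0])

# jnp.sum(x[x > 0]) would fail under jit: the intermediate
# array's shape depends on the values of x.
@jax.jit
def sum_positive(x):
    # fixed-shape alternative: mask out the unwanted entries
    return jnp.sum(jnp.where(x > 0, x, 0.0))

s = sum_positive(x)  # -> 4.0
```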

Batching

In simulations, we may want to “batch” over several initial conditions/random seeds/etc. (analogous to batching over training data in normal ML). In JAX, one can efficiently and concisely vectorize operations over such “batch axes” with jax.vmap.
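A small sketch of this idea (the "energy" function is a toy placeholder): we write the computation for a single configuration and let jax.vmap add the batch axis.

```python
import jax
import jax.numpy as jnp

def energy(positions):
    # toy per-configuration "measurement":
    # sum of squared distances to the origin
    return jnp.sum(positions ** 2)

# batch of 5 initial conditions, each a (10, 2) array of positions
batch = jnp.ones((5, 10, 2))

# vmap maps `energy` over the leading axis, producing shape (5,)
energies = jax.vmap(energy)(batch)
```

By default vmap maps over axis 0 of every argument; the in_axes argument controls which axes (if any) of each input are batched.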

To batch over our custom data structures, we need to pull a small trick - convert a list of instances into a single mesh with a batch axis for the various arrays. Luckily, this can be done using JAX’s pytree tools.

Simulation loops with jax.lax.scan

In simulations, we generally start with an initial state (call it init), do a series of time steps (via a function make_step(state)), and record some “measurement” at each time step (via a measure(state) function). As a result, we get a time series of measurements, and the final simulation state. In plain Python, you would do that with a for loop. When working with JAX, we need to replace control-flow operations like for with their JAX counterparts. For for loops, this is jax.lax.scan(f, init, xs), which is equivalent to the Python code

def scan(f, init, xs):
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

In our pattern, xs is the vector of time points timepoints, and the “scanning function” f generally consists of two parts, a time step and a measurement/logging step (above, we logged the energy and the T1 count):

def f(carry, t):
    new_state = make_step(carry, t)
    measurements = measure(new_state)
    return new_state, measurements

The carry variable contains all information about the state of the simulation. Typically, carry is itself composed of multiple pieces (the physical state physical_state, as well as ancillary variables like the ODE solver state solver_state). To keep things organized, it can make sense to define dataclasses for the simulation state and the measurements, like this (schematic) example:

@jax.tree_util.register_dataclass
@dataclass
class SimState:
    physical_state: jax.Array
    solver_state: dict  # or another PyTree
    current_time: jax.Array

@jax.tree_util.register_dataclass
@dataclass
class Log:
    energy: float

def scan_function(carry: SimState, next_time: jax.Array) -> tuple[SimState, Log]:
    physical_state, solver_state = make_step(carry.physical_state, carry.solver_state,
                                             carry.current_time, next_time)
    log = Log(energy=compute_energy(physical_state))
    return SimState(physical_state, solver_state, next_time), log

timepoints = jnp.arange(t0, t1, dt)
init = ... # define initial condition
final_state, measurements = jax.lax.scan(scan_function, init, timepoints)
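Putting the pattern together, here is a complete, runnable toy version. The dynamics (explicit-Euler exponential decay) and the "energy" measurement are placeholder assumptions standing in for a real simulation:

```python
import jax
import jax.numpy as jnp

dt = 0.1

def make_step(state, t):
    # toy dynamics (assumption for illustration):
    # explicit-Euler step of d(state)/dt = -state
    return state - dt * state

def measure(state):
    # toy "energy": squared norm of the state
    return jnp.sum(state ** 2)

def f(carry, t):
    new_state = make_step(carry, t)
    return new_state, measure(new_state)

timepoints = jnp.arange(0.0, 1.0, dt)   # 10 time points
init = jnp.ones(3)                       # initial condition
final_state, measurements = jax.lax.scan(f, init, timepoints)
# measurements has shape (10,): one logged energy per time point
```

jax.lax.scan compiles the loop body once and runs it sequentially, so simulation loops written this way both jit-compile quickly and remain differentiable end to end.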