# Tutorial: coding in JAX


`triangulax` aims to create a triangulation data structure compatible
with the JAX library for automatic differentiation and numerical
computing (see [JAX - The Sharp
Bits](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html)).
What does this mean in practice?

1.  Use `jnp` (=`jax.numpy`) instead of `numpy`.
2.  Use a *functional programming* paradigm (pure functions, no side
    effects). Avoid dynamically changing array shapes. For example,
    instead of in-place array modifications, use JAX’s
    `x = x.at[idx].set(y)`.
3.  Use JAX idioms for [control
    flow](https://docs.jax.dev/en/latest/control-flow.html).
4.  *Register* any new classes, so JAX knows how to handle them during
    gradient-computation and just-in-time compilation. See
    [here](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#using-jax-jit-with-class-methods)
    and [here](https://docs.jax.dev/en/latest/custom_pytrees.html).
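
As a small illustration of point 2, the sketch below contrasts a NumPy-style in-place assignment with JAX's functional update. The array names are made up for this example.

``` python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment is not supported: JAX arrays are immutable.
try:
    x[1] = 3.0
    in_place_worked = True
except TypeError:
    in_place_worked = False

# Functional update: returns a new array, `x` itself is unchanged.
y = x.at[1].set(3.0)

# Other updates follow the same pattern, e.g. adding to selected entries.
z = y.at[jnp.array([0, 1])].add(1.0)
```

Each `.at[...]` call returns a fresh array, so the function stays free of side effects.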

To provide type signatures for all functions (what are the inputs? What
do the array dimensions mean?), we use
[jaxtyping](https://docs.kidger.site/jaxtyping). Later on, we will also
use the `equinox` library, which adds a few useful tools to JAX.


### PyTrees

JAX supports not only arrays as inputs/outputs/intermediate variables,
but also [pytrees](https://docs.jax.dev/en/latest/pytrees.html). Pytrees
are nested structures (dictionaries, lists-of-lists, etc.) whose leaves
are “elementary” objects like arrays. Fortunately, our triangulation
data structure classes are already a lot like pytrees: each one is a
collection of arrays. For JAX to understand this, we need to register
our classes as a [custom pytree
node](https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes)
via `jax.tree_util.register_dataclass`.
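
A minimal sketch of such a registration, using a hypothetical `ToyMesh` class (not the actual `triangulax` data structure). The field lists are passed explicitly here; this is one way to call `register_dataclass`, chosen for illustration.

``` python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp

# `ToyMesh` is a made-up stand-in for a triangulation class.
@partial(jax.tree_util.register_dataclass,
         data_fields=["vertices", "faces"], meta_fields=[])
@dataclass
class ToyMesh:
    vertices: jax.Array  # (n_vertices, 2) positions
    faces: jax.Array     # (n_faces, 3) vertex indices

@jax.jit
def translate(mesh: ToyMesh, shift: jax.Array) -> ToyMesh:
    # Because ToyMesh is a pytree, it can be passed into and out of jit.
    return ToyMesh(vertices=mesh.vertices + shift, faces=mesh.faces)

mesh = ToyMesh(vertices=jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
               faces=jnp.array([[0, 1, 2]]))
leaves = jax.tree_util.tree_leaves(mesh)  # the two arrays held by the mesh
moved = translate(mesh, jnp.array([1.0, 1.0]))
```

Once registered, instances can be used wherever JAX accepts a pytree, such as arguments to `jax.jit`, `jax.vmap`, or `jax.lax.scan`.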

Sidenote: Neural networks in libraries like Flax or Equinox work in much
the same way. They are dataclass-like classes that hold all the arrays
associated with a NN (the weights, and possibly some other parameters),
with class methods like `__call__` specifying the forward pass through
the NN. Equinox automatically registers your NN as a pytree when its
class inherits from `equinox.Module`.

### [Control flow](https://docs.jax.dev/en/latest/control-flow.html)

JAX can use just-in-time (JIT) compilation to greatly accelerate your
code. However, JIT compilation only works if your code fulfills certain
requirements. JAX distinguishes two types of variables: dynamic and
static. Python control flow in JIT-compiled code cannot depend on the
*value* of dynamic variables, only on their shape.

Upshots:

1.  Replace `if` with `jax.lax.cond` / `jnp.where` (fully compatible
    with autodiff), and `while` with `jax.lax.while_loop` (forward-mode
    autodiff only).
2.  Mark variables that will not change during the simulation as static.
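
The following sketch shows these replacements on toy functions. All function names are invented for this example.

``` python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def clip_negative(x):
    # Elementwise "if": select between two values per entry.
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def step_or_hold(x):
    # Branch on a dynamic value: both branches must return the same shape/dtype.
    return jax.lax.cond(x.sum() > 0, lambda v: v + 1.0, lambda v: v, x)

@jax.jit
def halve_until_small(x):
    # "while" loop whose trip count depends on a dynamic value.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)

@partial(jax.jit, static_argnames="n_steps")
def repeat_double(x, n_steps):
    # `n_steps` is static, so an ordinary Python loop over it is allowed.
    for _ in range(n_steps):
        x = 2 * x
    return x

clipped = clip_negative(jnp.array([-1.0, 2.0]))
stepped = step_or_hold(jnp.array([1.0, 2.0]))
held = step_or_hold(jnp.array([-1.0, -2.0]))
halved = halve_until_small(jnp.array(8.0))
doubled = repeat_double(jnp.array(1.0), n_steps=3)
```

Note that changing the value of a static argument triggers recompilation, so static arguments are best reserved for quantities that rarely change.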

### Static array shapes

JAX works best if the *shapes* of arrays do not change during the
computation. For this reason, we (first) focus on triangulations where
the number of vertices does not change. Topological modifications (like
edge flips) are nevertheless possible, as long as they do not change the
number of mesh elements (vertices, edges, and faces).
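
To illustrate, here is a toy edge flip on two triangles that share an edge. The `(n_faces, 3)` face array used here is an assumption made for this sketch, not necessarily how `triangulax` stores connectivity.

``` python
import jax.numpy as jnp

# Unit square split into two counter-clockwise triangles along edge (1, 2).
vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
faces = jnp.array([[0, 1, 2], [1, 3, 2]])

# Flip the shared edge (1, 2) to (0, 3): only entries change, not the shape.
flipped_faces = faces.at[jnp.array([0, 1])].set(
    jnp.array([[0, 1, 3], [0, 3, 2]]))

def signed_areas(vertices, faces):
    # Signed area of each triangle; positive means counter-clockwise.
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    u, v = b - a, c - a
    return 0.5 * (u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0])

areas_before = signed_areas(vertices, faces)
areas_after = signed_areas(vertices, flipped_faces)
```

The connectivity changes, but every array keeps its shape, so code that was JIT-compiled for the original mesh can be reused without recompilation.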

### Batching

In simulations, we may want to “batch” over several initial
conditions/random seeds/etc. (analogous to batching over training data
in normal ML). In JAX, one can efficiently and concisely vectorize
operations over such “batch axes” with `jax.vmap`.

To batch over our custom data structures, we need a small trick: convert
a list of instances into a single instance whose arrays each carry a
leading batch axis. Luckily, this can be [done using JAX’s pytree
tools](https://stackoverflow.com/questions/79123001/storing-and-jax-vmap-over-pytrees).
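
A minimal sketch of this trick, again with a hypothetical `ToyMesh` class and a made-up `total_area` function: the leaves of several instances are stacked with `jax.tree_util.tree_map`, and the result is passed to a `jax.vmap`-ed function.

``` python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp

# `ToyMesh` is a made-up stand-in for a triangulation class.
@partial(jax.tree_util.register_dataclass,
         data_fields=["vertices", "faces"], meta_fields=[])
@dataclass
class ToyMesh:
    vertices: jax.Array  # (n_vertices, 2)
    faces: jax.Array     # (n_faces, 3)

def total_area(mesh: ToyMesh) -> jax.Array:
    # Sum of signed triangle areas for a single (unbatched) mesh.
    a, b, c = (mesh.vertices[mesh.faces[:, i]] for i in range(3))
    u, v = b - a, c - a
    return 0.5 * jnp.sum(u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0])

triangle = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
meshes = [ToyMesh(vertices=s * triangle, faces=jnp.array([[0, 1, 2]]))
          for s in (1.0, 2.0, 3.0)]

# Stack matching leaves: one ToyMesh whose arrays have a leading batch axis.
batched = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *meshes)

# vmap maps over the leading axis of every leaf.
areas = jax.vmap(total_area)(batched)
```

This only works if all instances have arrays of identical shapes, which ties in with the static-shape requirement above.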

## Simulation loops with `jax.lax.scan`

In simulations, we generally start with an initial state (call it
`init`), do a series of time steps (via a function `make_step(state)`),
and record some “measurement” at each time step (via a `measure(state)`
function). As a result, we get a time series of measurements, and the
final simulation state. In plain Python, you would do that with a `for`
loop. When working with JAX, we need to [replace control-flow operations
like `for` with their JAX
counterparts](https://docs.jax.dev/en/latest/control-flow.html). For `for`
loops, this is `jax.lax.scan(f, init, xs)`, which is roughly equivalent
to the Python code

``` python
import numpy as np

def scan(f, init, xs):
  carry = init  # state threaded through the loop
  ys = []       # per-step outputs
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```
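
As a quick sanity check of these semantics, here is a running sum written with `jax.lax.scan`. The function name `add_and_record` is made up for this example.

``` python
import jax
import jax.numpy as jnp

def add_and_record(carry, x):
    total = carry + x
    # First output is the new carry, second is the per-step record.
    return total, total

xs = jnp.arange(1.0, 5.0)  # 1, 2, 3, 4
final_total, running_totals = jax.lax.scan(add_and_record, jnp.array(0.0), xs)
```

The carry must keep the same shape and dtype across iterations, which is why the initial value is a floating-point scalar here.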

In our pattern, `xs` is the vector of time points `timepoints`, and the
“scanning function” `f` generally consists of two parts, a time step
and a measurement/logging step (for example, logging the energy and T1
count):

``` python
def f(carry, t):
    new_state = make_step(carry, t)
    measurements = measure(new_state)
    return new_state, measurements
```

The `carry` variable contains all information about the state of the
simulation. Typically, `carry` is itself composed of multiple pieces
(the physical state `physical_state`, as well as ancillary variables
like the ODE solver state `solver_state`). To keep things organized, it
can make sense to define dataclasses for the simulation state and the
measurements, like this (schematic) example:

``` python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

@jax.tree_util.register_dataclass
@dataclass
class SimState:
    physical_state: jax.Array
    solver_state: dict  # or another PyTree
    current_time: jax.Array

@jax.tree_util.register_dataclass
@dataclass
class Log:
    energy: jax.Array  # scalar per step; `scan` stacks these along the time axis

# `make_step` and `compute_energy` stand in for your own time-stepping
# and measurement functions.
def scan_function(carry: SimState, next_time: jax.Array) -> tuple[SimState, Log]:
    physical_state, solver_state = make_step(carry.physical_state, carry.solver_state,
                                             carry.current_time, next_time)
    log = Log(energy=compute_energy(physical_state))
    # Advance the clock so that the next step starts from `next_time`.
    return SimState(physical_state, solver_state, next_time), log

timepoints = jnp.arange(t0, t1, dt)
init = ... # define initial condition
final_state, measurements = jax.lax.scan(scan_function, init, timepoints)
```
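
For a version that runs end to end, here is a toy instance of the same pattern. It is a sketch under simplifying assumptions: the "physics" is explicit Euler integration of dx/dt = -x, there is no solver state, and `ToyState` and `toy_scan_function` are names invented for this example.

``` python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp

# Made-up simulation state: a position vector plus the current time.
@partial(jax.tree_util.register_dataclass,
         data_fields=["position", "current_time"], meta_fields=[])
@dataclass
class ToyState:
    position: jax.Array
    current_time: jax.Array

def toy_scan_function(carry: ToyState, next_time: jax.Array):
    dt = next_time - carry.current_time
    # Explicit Euler step for dx/dt = -x.
    position = carry.position - dt * carry.position
    # Measurement: a quadratic "energy" of the new state.
    energy = 0.5 * jnp.sum(position ** 2)
    return ToyState(position, next_time), energy

timepoints = jnp.linspace(0.1, 1.0, 10)  # end time of each step
init = ToyState(position=jnp.array([1.0, -2.0]), current_time=jnp.array(0.0))
final_state, energies = jax.lax.scan(toy_scan_function, init, timepoints)
```

Here `energies` has one entry per time point, and `final_state` has the same pytree structure as `init`, so it could seed a further call to `jax.lax.scan`.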
