triangulax

JAX-compatible triangular meshes and triangular-mesh-based simulations

Overview

This Python package provides data structures for triangular meshes and geometry-processing tools built on JAX. They are fully compatible with automatic differentiation and, for most tools, with just-in-time compilation.

Use cases

Triangular meshes are ubiquitous in computer graphics and in scientific computing, for example in the finite-element method.

A major motivation for triangulax is simulation in soft-matter physics and biophysics, such as the mechanics of membranes and of two-dimensional tissue sheets (e.g. active tension networks or the self-propelled Voronoi model).

triangulax complements libraries such as the libigl Python bindings by making it easy to implement custom, differentiable simulations and geometry operations in JAX.

Automatic differentiation

The main feature of triangulax is compatibility with (forward- and reverse-mode) automatic differentiation. This makes it possible to compute gradients of mesh-based functions built from its tools. Most tools are also compatible with JAX’s JIT compilation, delivering high performance from high-level Python code (rather than C++).
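As a minimal illustration of forward- and reverse-mode differentiation of a mesh function, the sketch below uses plain JAX arrays (a vertex array and a face array) rather than triangulax's own data structures; `total_area` is a hypothetical helper written only for this example. Reverse mode (`jax.grad`) and forward mode (`jax.jacfwd`) return the same gradient, and `jax.jit` compiles the gradient function.

```python
import jax
import jax.numpy as jnp

# A tiny mesh: the unit square split into two triangles (plain arrays, not a triangulax object).
vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
faces = jnp.array([[0, 1, 2], [0, 2, 3]])

def total_area(vertices, faces):
    """Hypothetical helper: sum of signed triangle areas of a 2D triangle mesh."""
    a, b, c = (vertices[faces[:, i]] for i in range(3))
    ab, ac = b - a, c - a
    # The 2D cross product of two triangle edges is twice the signed triangle area.
    return 0.5 * jnp.sum(ab[:, 0] * ac[:, 1] - ab[:, 1] * ac[:, 0])

# Reverse mode, compiled with XLA via jax.jit (differentiates w.r.t. the first argument).
grad_rev = jax.jit(jax.grad(total_area))(vertices, faces)
# Forward mode gives the same (4, 2) array of derivatives.
grad_fwd = jax.jacfwd(total_area)(vertices, faces)

print(total_area(vertices, faces))
print(grad_rev)
print(jnp.allclose(grad_rev, grad_fwd))
```

For a scalar energy that depends on many vertex coordinates, reverse mode is usually the cheaper choice, since a single backward pass yields the whole gradient.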

For example, consider:

  1. Flattening or deforming 3D models (computer graphics)
  2. Mechanics of thin plates or membranes (mechanics)
  3. Cell-resolved tissue simulations (biophysics)

These tasks revolve around a mesh-based “energy” (the Dirichlet energy, the Helfrich elastic energy, and the area-perimeter energy, respectively). JAX computes their gradients automatically, which makes it easy to minimize energies or to obtain the forces that drive a simulation.
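To make the last point concrete, here is a sketch of the area-perimeter energy of a single polygonal cell, E = K_A (A - A0)^2 + K_P (P - P0)^2, with vertex forces obtained as F = -grad E via `jax.grad`, followed by an overdamped (gradient-descent) relaxation. It uses plain JAX arrays rather than triangulax, and the function names and parameter values (`A0`, `P0`, `K_A`, `K_P`, the time step `DT`) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def polygon_area(x):
    """Signed area of a polygon with vertices x of shape (n, 2); positive if counter-clockwise."""
    x_next = jnp.roll(x, -1, axis=0)
    return 0.5 * jnp.sum(x[:, 0] * x_next[:, 1] - x_next[:, 0] * x[:, 1])

def polygon_perimeter(x):
    """Sum of the edge lengths of the closed polygon."""
    return jnp.sum(jnp.linalg.norm(jnp.roll(x, -1, axis=0) - x, axis=1))

def area_perimeter_energy(x, A0=1.0, P0=3.8, K_A=1.0, K_P=1.0):
    """Hypothetical single-cell energy K_A (A - A0)^2 + K_P (P - P0)^2."""
    return K_A * (polygon_area(x) - A0) ** 2 + K_P * (polygon_perimeter(x) - P0) ** 2

# Forces on the vertices are the negative gradient of the energy.
forces = jax.jit(lambda x: -jax.grad(area_perimeter_energy)(x))

DT = 0.01  # time step of the overdamped dynamics (illustrative value)

@jax.jit
def relax_step(x):
    """One forward-Euler step of the overdamped dynamics dx/dt = F(x)."""
    return x + DT * forces(x)

# Start from a regular hexagon with circumradius 0.5 (vertices ordered counter-clockwise).
angles = jnp.arange(6) * (2 * jnp.pi / 6)
x0 = 0.5 * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)

x = x0
for _ in range(500):
    x = relax_step(x)

e_initial, e_final = area_perimeter_energy(x0), area_perimeter_energy(x)
print("energy before/after relaxation:", e_initial, e_final)
```

Because this energy is translation invariant, the forces on all vertices sum to zero, which is a cheap sanity check for any hand-written energy.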

Gradient-based “meta-optimization” and inverse problems

Since triangulax is fully JAX-compatible, it allows differentiating a simulation w.r.t. its parameters. This means one can apply gradient-based optimization to inverse problems. For example, in the tissue mechanics context, one can ask: what do individual cells need to do so that the tissue as a whole takes on a certain shape?
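A toy version of such an inverse problem, as a sketch: a single hexagonal cell is relaxed under an area-perimeter energy by a fixed number of gradient-descent steps, and `jax.value_and_grad` differentiates the relaxed cell's area through the whole simulation loop with respect to the preferred area `A0`. Plain gradient descent on `A0` then finds the preferred area for which the relaxed cell reaches a target area of 0.8. All names (`simulate`, `loss`) and numbers are hypothetical choices for illustration, not part of triangulax.

```python
import jax
import jax.numpy as jnp

def polygon_area(x):
    """Signed area of a polygon with vertices x of shape (n, 2)."""
    x_next = jnp.roll(x, -1, axis=0)
    return 0.5 * jnp.sum(x[:, 0] * x_next[:, 1] - x_next[:, 0] * x[:, 1])

def polygon_perimeter(x):
    """Sum of the edge lengths of the closed polygon."""
    return jnp.sum(jnp.linalg.norm(jnp.roll(x, -1, axis=0) - x, axis=1))

def energy(x, A0, P0=3.0):
    """Area-perimeter energy of one cell; the preferred area A0 is the tunable parameter."""
    return (polygon_area(x) - A0) ** 2 + (polygon_perimeter(x) - P0) ** 2

def simulate(A0, n_steps=200, dt=0.01):
    """Relax a hexagonal cell by gradient descent on its vertices and return its final area."""
    angles = jnp.arange(6) * (2 * jnp.pi / 6)
    x0 = 0.5 * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)
    step = lambda i, x: x - dt * jax.grad(energy)(x, A0)
    x_final = jax.lax.fori_loop(0, n_steps, step, x0)
    return polygon_area(x_final)

def loss(A0, target_area=0.8):
    """Inverse problem: which preferred area A0 makes the relaxed cell reach target_area?"""
    return (simulate(A0) - target_area) ** 2

# Reverse-mode gradient through all simulation steps, compiled once with jax.jit.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

A0 = 1.0
initial_loss, _ = loss_and_grad(A0)
for _ in range(20):
    _, g = loss_and_grad(A0)
    A0 = A0 - 5.0 * g  # plain gradient descent on the simulation parameter
final_loss, _ = loss_and_grad(A0)
print("loss before/after:", initial_loss, final_loss, "fitted A0:", A0)
```

Reverse-mode differentiation through the loop stores intermediate states for every step, so memory use grows with `n_steps`.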

Developer guide and installation instructions

This package is developed in Jupyter notebooks, which are converted into Python modules using nbdev. See .github/workflows/copilot-instructions.md for details.

Install triangulax in development mode

  1. Clone the GitHub repository
$ git clone https://github.com/nikolas-claussen/triangulax.git
$ cd triangulax
  2. Create a conda environment with all Python dependencies
$ conda env create -n triangulax -f triangulax.yml
$ conda activate triangulax
  3. Install the triangulax package
# install triangulax in development (editable) mode
$ pip install -e .
  4. If necessary, edit the package notebooks and export them
# make changes in the nbs/ directory
# ...

# export the notebooks so the changes apply to the triangulax module
# (nbdev_prepare also runs the notebook tests)
$ nbdev_prepare

Documentation

Documentation is hosted on this repository’s GitHub Pages site. Jupyter notebook tutorials are in the nbs/tutorials/ folder.

Usage

triangulax comprises the following modules:

  • triangular: input/output for triangular meshes
  • trigonometry: trigonometric helper functions
  • mesh: a JAX-compatible half-edge data structure for triangular meshes
  • topology: topological modifications (flip, collapse, and split)
  • adjacency, geometry, linops: geometry processing tools
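For readers new to half-edge meshes, here is a minimal NumPy sketch of how half-edge connectivity can be derived from a face list: each triangle contributes three directed half-edges, `next` cycles through the half-edges of a face, and `twin` pairs each half-edge with its reverse. This is a hypothetical illustration (`build_half_edges` is not a triangulax function); in particular, it marks boundary half-edges with -1 instead of creating explicit boundary half-edges, which may differ from how `mesh.HeMesh` handles boundaries.

```python
import numpy as np

def build_half_edges(faces):
    """Hypothetical sketch: half-edge arrays from an (F, 3) array of vertex indices."""
    faces = np.asarray(faces)
    n_faces = faces.shape[0]
    # Half-edge 3*f + k points from faces[f, k] to faces[f, (k + 1) % 3].
    origin = faces.reshape(-1)
    dest = np.roll(faces, -1, axis=1).reshape(-1)
    face = np.repeat(np.arange(n_faces), 3)
    # next: the following half-edge within the same face (k -> k + 1 mod 3).
    local = np.arange(3 * n_faces) % 3
    nxt = 3 * face + (local + 1) % 3
    # twin: the half-edge with swapped endpoints; -1 marks a boundary half-edge.
    pairs = list(zip(origin.tolist(), dest.tolist()))
    index = {pair: he for he, pair in enumerate(pairs)}
    twin = np.array([index.get((d, o), -1) for o, d in pairs])
    return origin, dest, nxt, face, twin

# The unit square split into two triangles sharing the diagonal 0-2.
origin, dest, nxt, face, twin = build_half_edges([[0, 1, 2], [0, 2, 3]])
print("next:", nxt)   # next == [1, 2, 0, 4, 5, 3]
print("twin:", twin)  # twin == [-1, -1, 3, 2, -1, -1]; half-edges 2 and 3 form the shared diagonal
```

Storing connectivity as flat integer arrays turns mesh operations into array indexing (gathers) and scatter-adds, which JAX can trace, differentiate and compile.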

Minimal example

import igl
import jax
import jax.numpy as jnp
from triangulax import mesh, geometry

# load example mesh and convert to half-edge mesh

vertices, _, _, faces, _, _ = igl.readOBJ("test_meshes/disk.obj")
hemesh = mesh.HeMesh.from_triangles(vertices.shape[0], faces)

# with the half-edge mesh, you can carry out various operations, for example
# compute the coordination number by summing incoming half-edges per vertex

coord_number = jnp.zeros(hemesh.n_vertices)
coord_number = coord_number.at[hemesh.dest].add(jnp.ones(hemesh.n_hes))
print("Mean coordination number:", coord_number.mean())

# Let's define a simple geometric function and compute its gradient with JAX

def mean_voronoi_area(vertices, hemesh: mesh.HeMesh) -> float:
    """Compute the mean Voronoi area per vertex."""
    voronoi_areas = geometry.get_voronoi_areas(vertices, hemesh)
    return jnp.mean(voronoi_areas)

value, gradient = jax.value_and_grad(mean_voronoi_area)(vertices, hemesh)
print("Mean gradient norm:", jnp.linalg.norm(gradient, axis=1).mean())
Output:

Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
Mean coordination number: 5.40458
Mean gradient norm: 0.00036383414
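The coordination-number line in the example relies on JAX's scatter-add: `x.at[idx].add(v)` accumulates contributions at repeated indices rather than overwriting them. Below is a standalone sketch with a hand-written `dest` array for the unit square split into two triangles (a hypothetical stand-in for `hemesh.dest` that lists only the six half-edges inside the two faces). `count_incoming` is an illustrative helper, jit-compiled with the vertex count marked static because array shapes must be known at compile time.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def count_incoming(dest, n_vertices):
    """Number of half-edges pointing into each vertex (scatter-add over dest)."""
    # .at[...].add accumulates repeated indices instead of overwriting them.
    return jnp.zeros(n_vertices).at[dest].add(1.0)

# Destination vertices of the six face half-edges of the two-triangle square.
dest = jnp.array([1, 2, 0, 2, 3, 0])
counts = count_incoming(dest, 4)
print(counts)  # counts == [2, 1, 2, 1]
```

With `.at[dest].set(1.0)` instead, every vertex that appears in `dest` would simply hold 1.0, regardless of how many half-edges point into it.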