JAX-compatible triangular meshes and triangular-mesh-based simulations
Overview
This Python package provides data structures for triangular meshes and geometry processing tools based on JAX, fully compatible with just-in-time compilation and automatic differentiation.
Use cases
Triangular meshes are ubiquitous in computer graphics and in scientific computing, for example in the finite-element method.
A major motivation for triangulax is simulation in soft matter and biophysics, such as the mechanics of membranes and two-dimensional tissue sheets (e.g. active tension networks or the self-propelled Voronoi model).
triangulax complements libraries like the libigl Python bindings by making it possible to implement custom simulations and geometry operations.
Automatic differentiation
The main feature of triangulax is compatibility with (forward- and reverse-mode) automatic differentiation. This enables computation of gradients of mesh-based functions written with JAX operations. Most tools are also compatible with JAX’s JIT-compilation, delivering high performance in high-level Python (rather than C++).
For example, consider:
Flattening or deforming 3D models (computer graphics)
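To make the idea of differentiating a mesh-based function concrete, here is a minimal sketch in plain JAX. It does not use the triangulax API: the function `total_area` and the two-triangle square mesh are illustrative assumptions. It computes the total surface area of a triangle mesh, its reverse-mode gradient with respect to the vertex positions (under JIT), and a forward-mode directional derivative.

```python
import jax
import jax.numpy as jnp


def total_area(vertices, faces):
    """Total area of a triangle mesh (hypothetical helper, not part of triangulax).

    vertices: (n_vertices, 3) float array, faces: (n_faces, 3) int array.
    """
    a = vertices[faces[:, 0]]
    b = vertices[faces[:, 1]]
    c = vertices[faces[:, 2]]
    # Each triangle's area is half the norm of the cross product of two edges.
    cross = jnp.cross(b - a, c - a)
    return 0.5 * jnp.sum(jnp.linalg.norm(cross, axis=-1))


# A unit square in the z=0 plane, split into two triangles.
vertices = jnp.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [1.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0]])
faces = jnp.array([[0, 1, 2], [0, 2, 3]])

# Reverse mode: gradient of the area w.r.t. all vertex positions, JIT-compiled.
area, grad = jax.jit(jax.value_and_grad(total_area))(vertices, faces)

# Forward mode: derivative of the area along one direction in vertex space.
# Using the vertex positions themselves as the direction corresponds to
# uniformly scaling the mesh about the origin.
_, d_area = jax.jvp(lambda v: total_area(v, faces), (vertices,), (vertices,))

print("Area:", area)
print("Gradient shape:", grad.shape)
print("Derivative under uniform scaling:", d_area)
```

Because area scales quadratically with length, the derivative under uniform scaling should equal twice the area, which is a convenient sanity check for the two differentiation modes.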
Gradient-based “meta-optimization” and inverse problems
Since triangulax is fully JAX-compatible, it allows differentiating a simulation w.r.t. its parameters. This means one can apply gradient-based optimization to inverse problems. For example, in the tissue mechanics context, one can ask: what do individual cells need to do so that the tissue as a whole takes on a certain shape?
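The following toy example sketches this kind of inverse problem in plain JAX, without using the triangulax API. All names (`elastic_energy`, `simulate`, `loss`) and the setup are assumptions made for illustration: a single triangle of springs relaxes by overdamped gradient descent, and the spring rest length is tuned by gradient descent so that the relaxed triangle reaches a target area. The gradient is obtained by differentiating through the entire unrolled simulation.

```python
import jax
import jax.numpy as jnp

# One triangle whose three edges are springs (toy stand-in for a tissue mesh).
edges = jnp.array([[0, 1], [1, 2], [2, 0]])
x0 = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.5, jnp.sqrt(3.0) / 2.0]])


def elastic_energy(x, rest_length):
    """Sum of harmonic spring energies over all edges."""
    lengths = jnp.linalg.norm(x[edges[:, 0]] - x[edges[:, 1]], axis=1)
    return 0.5 * jnp.sum((lengths - rest_length) ** 2)


def simulate(rest_length, x0, n_steps=50, dt=0.1):
    """Overdamped relaxation: repeated gradient-descent steps on the energy."""
    energy_grad = jax.grad(elastic_energy)  # gradient w.r.t. positions

    def step(x, _):
        return x - dt * energy_grad(x, rest_length), None

    x_final, _ = jax.lax.scan(step, x0, None, length=n_steps)
    return x_final


def triangle_area(x):
    """Signed area of the triangle (positive for counterclockwise vertices)."""
    u, v = x[1] - x[0], x[2] - x[0]
    return 0.5 * (u[0] * v[1] - u[1] * v[0])


def loss(rest_length, x0, target_area):
    """Squared mismatch between the relaxed area and the target area."""
    return (triangle_area(simulate(rest_length, x0)) - target_area) ** 2


# Target: the area of an equilateral triangle with side length 2.
target_area = jnp.sqrt(3.0)
loss_and_grad = jax.jit(jax.value_and_grad(loss))

# Outer loop: plain gradient descent on the simulation parameter.
rest_length = jnp.array(1.0)
for _ in range(50):
    value, grad = loss_and_grad(rest_length, x0, target_area)
    rest_length = rest_length - 0.1 * grad

print("Fitted rest length:", rest_length)
print("Final loss:", value)
```

In this setup the fitted rest length should approach 2, the side length of the target triangle. The same pattern (a differentiable forward simulation wrapped in a loss, then `jax.value_and_grad`) is what one would apply to a full mesh-based model; for long simulations, memory use of reverse-mode differentiation through the unrolled loop becomes a practical consideration.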
Developer guide and installation instructions
This package is developed in Jupyter notebooks, which are converted into Python modules using nbdev. Take a look at .github/workflows/copilot-instructions.md for details.
import igl
import jax
import jax.numpy as jnp
from triangulax import mesh, geometry

# load example mesh and convert to half-edge mesh
vertices, _, _, faces, _, _ = igl.readOBJ("test_meshes/disk.obj")
hemesh = mesh.HeMesh.from_triangles(vertices.shape[0], faces)

# with the half-edge mesh, you can carry out various operations, for example
# compute the coordination number by summing incoming half-edges per vertex
coord_number = jnp.zeros(hemesh.n_vertices)
coord_number = coord_number.at[hemesh.dest].add(jnp.ones(hemesh.n_hes))
print("Mean coordination number:", coord_number.mean())

# Let's define a simple geometric function and compute its gradient with JAX
def mean_voronoi_area(vertices, hemesh: mesh.HeMesh) -> float:
    """Compute the mean Voronoi area per vertex."""
    voronoi_areas = geometry.get_voronoi_areas(vertices, hemesh)
    return jnp.mean(voronoi_areas)

value, gradient = jax.value_and_grad(mean_voronoi_area)(vertices, hemesh)
print("Mean gradient norm:", jnp.linalg.norm(gradient, axis=1).mean())
Warning: readOBJ() ignored non-comment line 3:
o flat_tri_ecmc
Mean coordination number: 5.40458
Mean gradient norm: 0.00036383414