Finite-element gradient and cotan-Laplacian
from triangulax.triangular import TriMesh
Building on the mesh geometry and adjacency-based operators, we can now define two important linear operators that depend on both mesh connectivity and mesh geometry. They are the discrete, triangulation-based equivalents of the gradient and the Laplace-Beltrami operator; the latter is known as the cotan Laplacian.
We now implement the gradient (per-vertex scalar field -> per-face vector field) and the cotan Laplacian (per-vertex -> per-vertex) using gather/scatter ops. In both cases, we start with a scalar field \(u_i\) defined per vertex \(i\) of the triangulation. The finite-element gradient is defined on each face \(ijk\) as \[ (\nabla u)_{ijk} = \sum_{l\in \{i,j,k\}} u_l \nabla\phi_l \] where \(\phi_i\) is a linear finite-element test function (linear Lagrange element) with gradient \[ \nabla\phi_i = \frac{1}{2a_{ijk}} (\mathbf{v}_k-\mathbf{v}_j)^\perp \] plus cyclic permutations. Here, \(a_{ijk}\) is the triangle area, \(\mathbf{v}_i\) are the vertex positions, and \((\cdot)^\perp\) denotes rotation by 90 degrees (in 3D, the rotation is about the triangle normal).
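As a concrete illustration of the formula above, here is a minimal standalone NumPy sketch (not part of the library) that evaluates the finite-element gradient on a single 2D triangle. For the linear field \(u = 2x + 3y\) it recovers the exact gradient \((2, 3)\).

```python
import numpy as np

# Standalone toy: FE gradient on one 2D triangle via hat-function gradients,
# grad(phi_i) = (v_k - v_j)^perp / (2 * area), plus cyclic permutations.

def rot90(e):
    """Rotate a 2D vector by 90 degrees counter-clockwise."""
    return np.array([-e[1], e[0]])

v = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # triangle vertices i, j, k
u = np.array([0.0, 2.0, 3.0])                        # samples of u = 2x + 3y
e1, e2 = v[1] - v[0], v[2] - v[0]
area = 0.5 * (e1[0] * e2[1] - e1[1] * e2[0])         # signed triangle area

# (grad u) = sum_l u_l * grad(phi_l), with indices taken cyclically
grad_u = sum(
    u[l] * rot90(v[(l + 2) % 3] - v[(l + 1) % 3]) / (2 * area)
    for l in range(3)
)
print(grad_u)  # -> [2. 3.]
```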
The cotan Laplacian computes the following per-vertex field: \[ (\Delta u)_i = \frac{1}{2} \sum_{j} (\cot\alpha_j +\cot\beta_j) (u_j-u_i) \] where the sum runs over the vertices \(j\) adjacent to \(i\), and \(\alpha_j, \beta_j\) are the two triangle angles opposite the edge \(ij\).
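The per-edge weights can also be assembled naively, triangle by triangle. The following standalone 2D sketch (a dense reference for intuition, not the library's gather/scatter implementation) builds the cotan Laplacian and checks two basic properties: symmetry and zero row sums.

```python
import numpy as np

# Dense reference assembly (standalone 2D sketch): each triangle contributes
# 0.5 * cot(angle at a) to the weight of the opposite edge (b, c).

def cross2(a, b):
    return a[0] * b[1] - a[1] * b[0]

def cotan_laplacian_dense(vertices, faces):
    n = len(vertices)
    W = np.zeros((n, n))
    for face in faces:
        for a, b, c in [(face[0], face[1], face[2]),
                        (face[1], face[2], face[0]),
                        (face[2], face[0], face[1])]:
            # the angle at vertex a is opposite the edge (b, c)
            e1, e2 = vertices[b] - vertices[a], vertices[c] - vertices[a]
            cot_a = np.dot(e1, e2) / abs(cross2(e1, e2))
            W[b, c] += 0.5 * cot_a
            W[c, b] += 0.5 * cot_a
    return W - np.diag(W.sum(axis=1))  # (L u)_i = sum_j w_ij (u_j - u_i)

# two triangles sharing the diagonal of the unit square
vertices = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
faces = [(0, 1, 2), (0, 2, 3)]
L_ref = cotan_laplacian_dense(vertices, faces)
print(np.allclose(L_ref, L_ref.T), np.allclose(L_ref.sum(axis=1), 0.0))  # True True
```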
To check for correctness, we compare with this libigl tutorial, using the test mesh and some random test fields.
# load test data
mesh = TriMesh.read_obj("../test_meshes/disk.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)
mesh_3d = TriMesh.read_obj("../test_meshes/disk.obj", dim=3)
geommesh_3d = msh.GeomMesh(*hemesh.n_items, mesh_3d.vertices, mesh_3d.face_positions)
Warning: readOBJ() ignored non-comment line 3:
o flat_tri_ecmc
Warning: readOBJ() ignored non-comment line 3:
o flat_tri_ecmc
Sparse matrix utility functions
diag_jsparse
def diag_jsparse(
v:Float[Array, 'N'], k:int=0
)->BCOO:
Construct a diagonal jax.sparse array. Drop-in replacement for np.diag.
bcoo_to_scipy
def bcoo_to_scipy(
A:BCOO, # Input JAX sparse matrix
)->csr_matrix: # Equivalent SciPy sparse matrix
Convert a JAX BCOO sparse matrix to a SciPy CSR sparse matrix.
scipy_to_bcoo
def scipy_to_bcoo(
A, # Input sparse matrix (CSR or CSC recommended)
)->BCOO: # Equivalent JAX sparse matrix
Convert a SciPy sparse matrix (CSC or CSR) to a JAX BCOO sparse matrix without converting to dense.
Cotan-Laplacian
compute_cotan_laplace
def compute_cotan_laplace(
vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_vertices ...']:
Compute cotangent laplacian of a per-vertex field (natural boundary conditions).
cotan_laplace_sparse
def cotan_laplace_sparse(
vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh
)->BCOO:
Assemble cotangent Laplacian as a sparse matrix (BCOO).
# Test against libigl cotmatrix (natural boundary conditions)
key = jax.random.PRNGKey(0)
u = jax.random.normal(key, (hemesh.n_vertices,))
u_vec = jax.random.normal(key, (hemesh.n_vertices, 3))
L = igl.cotmatrix(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))
lap_jax = compute_cotan_laplace(geommesh.vertices, hemesh, u)
lap_igl = L @ np.asarray(u)
rel_err = np.linalg.norm(np.asarray(lap_jax) - lap_igl) / np.linalg.norm(lap_igl)
print("scalar field rel. error:", rel_err)
lap_jax_vec = compute_cotan_laplace(geommesh.vertices, hemesh, u_vec)
lap_igl_vec = L @ np.asarray(u_vec)
rel_err_vec = np.linalg.norm(np.asarray(lap_jax_vec) - lap_igl_vec) / np.linalg.norm(lap_igl_vec)
print("vector field rel. error:", rel_err_vec)
scalar field rel. error: 1.7692000627878292e-16
vector field rel. error: 2.011929541056845e-16
# test sparse cotan Laplacian vs apply function
key = jax.random.PRNGKey(0)
u_test = jax.random.normal(key, (hemesh.n_vertices,))
L_sparse = cotan_laplace_sparse(geommesh.vertices, hemesh)
lap_sparse = L_sparse @ u_test
lap_apply = compute_cotan_laplace(geommesh.vertices, hemesh, u_test)
rel_err_sparse = jnp.linalg.norm(lap_sparse - lap_apply) / jnp.linalg.norm(lap_apply)
print("cotan sparse vs apply rel. error:", rel_err_sparse)
cotan sparse vs apply rel. error: 1.302894564211555e-16
bcoo_to_scipy(L_sparse), (scipy_to_bcoo(bcoo_to_scipy(L_sparse)).todense() == L_sparse.todense()).all()
(<Compressed Sparse Row sparse matrix of dtype 'float64'
with 839 stored elements and shape (131, 131)>,
Array(True, dtype=bool))
Mass matrix (lumped)
The finite-element mass matrix \(M\) appears whenever we discretize a time-dependent PDE. For a lumped (diagonal) mass matrix, \(M_{ii} = A_i\) where \(A_i\) is the Voronoi area associated with vertex \(i\).
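For intuition, here is a standalone sketch of the simpler barycentric lumping, in which each vertex receives one third of the area of every incident triangle. Note that this differs from the Voronoi lumping used by `mass_matrix_sparse`, but both variants sum to the total mesh area.

```python
import numpy as np

# Standalone sketch: barycentric lumped mass (each vertex gets 1/3 of the
# area of every incident triangle). Not the Voronoi areas used by the library.

def lumped_mass_barycentric(vertices, faces):
    areas = np.zeros(len(vertices))
    for i, j, k in faces:
        e1, e2 = vertices[j] - vertices[i], vertices[k] - vertices[i]
        tri_area = 0.5 * abs(e1[0] * e2[1] - e1[1] * e2[0])
        areas[[i, j, k]] += tri_area / 3.0
    return areas

# two triangles tiling the unit square
vertices = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
faces = [(0, 1, 2), (0, 2, 3)]
areas = lumped_mass_barycentric(vertices, faces)
print(areas, areas.sum())  # vertex areas; total equals the square's area, 1.0
```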
mass_matrix_inv_sparse
def mass_matrix_inv_sparse(
vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh
)->BCOO: # Diagonal sparse inverse mass matrix.
Assemble inverse lumped mass matrix as a sparse matrix (BCOO).
mass_matrix_sparse
def mass_matrix_sparse(
vertices:Float[Array, 'n_vertices dim'], # Vertex positions.
hemesh:HeMesh, # Half-edge mesh connectivity.
)->BCOO: # Diagonal sparse mass matrix.
Assemble lumped (diagonal) mass matrix as a sparse matrix (BCOO).
The lumped mass matrix is diagonal with entries equal to the Voronoi dual area of each vertex: \(M_{ii} = A_i\).
# Test mass matrix against igl.massmatrix (Voronoi type)
M_jax = mass_matrix_sparse(geommesh.vertices, hemesh)
M_igl = igl.massmatrix(np.asarray(geommesh.vertices), np.asarray(hemesh.faces), igl.MASSMATRIX_TYPE_VORONOI)
rel_err_mass = np.linalg.norm(M_jax.todense() - M_igl.todense()) / np.linalg.norm(M_igl.todense())
print("mass matrix rel. error:", rel_err_mass)
# Test inverse
M_inv_jax = mass_matrix_inv_sparse(geommesh.vertices, hemesh)
identity_check = M_jax.todense() @ M_inv_jax.todense()
print("M @ M_inv ~ I error:", np.linalg.norm(identity_check - np.eye(hemesh.n_vertices)))
Finite-element gradient
Not to be confused with the discrete-exterior-calculus operators, which only depend on mesh connectivity, not geometry.
compute_gradient_3d
def compute_gradient_3d(
vertices:Float[Array, 'n_vertices 3'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces 3 ...']:
Compute the linear finite-element gradient (constant per face).
compute_gradient_2d
def compute_gradient_2d(
vertices:Float[Array, 'n_vertices 2'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces 2 ...']:
Compute the linear finite-element gradient (constant per face).
gradient_sparse_3d
def gradient_sparse_3d(
vertices:Float[Array, 'n_vertices 3'], hemesh:HeMesh
)->BCOO:
Assemble FE gradient in 3D as a sparse matrix (BCOO).
Returns a matrix G with shape (3*n_faces, n_vertices) such that for a scalar per-vertex field u of shape (n_vertices,), the per-face gradients are obtained via:
g_flat = G @ u                       # (3*n_faces,)
g = g_flat.reshape((3, n_faces)).T   # (n_faces, 3)
This row layout matches libigl’s grad operator convention (component blocks).
gradient_sparse_2d
def gradient_sparse_2d(
vertices:Float[Array, 'n_vertices 2'], hemesh:HeMesh
)->BCOO:
Assemble FE gradient in 2D as a sparse matrix (BCOO).
Returns a matrix G with shape (2*n_faces, n_vertices) such that for a scalar per-vertex field u of shape (n_vertices,), the per-face gradients are obtained via:
g_flat = G @ u                       # (2*n_faces,)
g = g_flat.reshape((2, n_faces)).T   # (n_faces, 2)
This row layout matches libigl’s grad operator convention (component blocks).
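To make the component-block row layout concrete, here is a tiny standalone demonstration of the reshape described above, with made-up numbers standing in for `G @ u`:

```python
import numpy as np

# Component-block layout: rows 0..n_faces-1 hold x-components,
# the next n_faces rows hold y-components, and so on.
n_faces, dim = 4, 2
g_flat = np.arange(dim * n_faces, dtype=float)  # pretend this is G @ u
g = g_flat.reshape((dim, n_faces)).T            # -> (n_faces, dim)
print(g[1])  # face 1: x-component g_flat[1], y-component g_flat[n_faces + 1]
# -> [1. 5.]
```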
reshape_face_gradient
def reshape_face_gradient(
grad_flat:Float[Array, 'dim_n_faces ...'], # Output of `G @ u`, with shape `(dim*n_faces, ...)`.
n_faces:int, # Number of mesh faces.
dim:int, # Spatial dimension (2 or 3).
)->Float[Array, 'n_faces dim ...']: # Reshaped gradient with shape `(n_faces, dim, ...)`, matching the output convention of `compute_gradient_2d/3d`.
Reshape a flattened FE gradient into per-face vectors.
This is meant to be used with gradient_sparse_2d/3d (and any similar operator that stacks components in blocks), where applying the sparse matrix yields an array of shape (dim*n_faces, ...) (for scalar/vector/tensor per-vertex fields).
# here's how to compute the gradient in libigl
grad_matrix = igl.grad(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))
# calculate the gradient of field by matrix multiplication
grad_igl = grad_matrix @ np.asarray(u)
# order='F' copied from igl tutorial
grad_igl = grad_igl.reshape((hemesh.n_faces, geommesh.dim), order='F')
# test jax and libigl implementations
grad_jax = compute_gradient_2d(geommesh.vertices, hemesh, u)
rel_err_grad = np.linalg.norm(np.asarray(grad_jax) - grad_igl) / np.linalg.norm(grad_igl)
print("gradient rel. error:", rel_err_grad)
gradient rel. error: 1.413315746703021e-16
# same test, in 3d
grad_matrix_3d = igl.grad(np.asarray(geommesh_3d.vertices), np.asarray(hemesh.faces))
grad_igl_3d = grad_matrix_3d @ np.asarray(u)
grad_igl_3d = grad_igl_3d.reshape((hemesh.n_faces, geommesh_3d.dim), order='F')
grad_jax_3d = compute_gradient_3d(geommesh_3d.vertices, hemesh, u)
rel_err_grad_3d = np.linalg.norm(np.asarray(grad_jax_3d) - grad_igl_3d) / np.linalg.norm(grad_igl_3d)
print("gradient rel. error:", rel_err_grad_3d)
gradient rel. error: 1.5657863888820882e-16
# Test sparse gradient operators vs apply functions
key = jax.random.PRNGKey(123)
u_test = jax.random.normal(key, (hemesh.n_vertices,))
G2 = gradient_sparse_2d(geommesh.vertices, hemesh)
g2 = reshape_face_gradient(G2 @ u_test, hemesh.n_faces, dim=2)
g2_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_test)
rel_err_g2 = jnp.linalg.norm(g2 - g2_apply) / jnp.linalg.norm(g2_apply)
print("2D grad sparse vs apply rel. error:", rel_err_g2)
G3 = gradient_sparse_3d(geommesh_3d.vertices, hemesh)
g3 = reshape_face_gradient(G3 @ u_test, hemesh.n_faces, dim=3)
g3_apply = compute_gradient_3d(geommesh_3d.vertices, hemesh, u_test)
rel_err_g3 = jnp.linalg.norm(g3 - g3_apply) / jnp.linalg.norm(g3_apply)
print("3D grad sparse vs apply rel. error:", rel_err_g3)
# quick sanity check for vector/tensor fields: u has extra axes
u_vec = jax.random.normal(key, (hemesh.n_vertices, 3))
g2_vec = reshape_face_gradient(G2 @ u_vec, hemesh.n_faces, dim=2)
g2_vec_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_vec)
rel_err_g2_vec = jnp.linalg.norm(g2_vec - g2_vec_apply) / jnp.linalg.norm(g2_vec_apply)
print("2D grad (vector field) sparse vs apply rel. error:", rel_err_g2_vec)
2D grad sparse vs apply rel. error: 8.71017994729607e-17
3D grad sparse vs apply rel. error: 9.602668379845331e-17
2D grad (vector field) sparse vs apply rel. error: 8.285943150518157e-17
Wrapping as linear operators
It’s often useful to think of functions like compute_cotan_laplace() as linear operators on mesh fields. For example, suppose you want to solve the Laplace equation on a mesh with fixed vertex positions and connectivity: you will want to use a linear solver. Luckily, most such solvers only need the action of the operator on an input vector and don’t require an explicit matrix representation.
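As a standalone toy (deliberately not using the mesh Laplacian), this sketch shows the matrix-free idea with jax.scipy.sparse.linalg.cg, which accepts a matvec function directly in place of a matrix:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# cg only needs a function computing A @ x, never an explicit matrix A.
def matvec(x):
    # SPD tridiagonal operator: 2*x_i - x_{i-1} - x_{i+1}, zero-padded ends
    return 2 * x - jnp.pad(x[1:], (0, 1)) - jnp.pad(x[:-1], (1, 0))

b = jnp.ones(16)
x, _ = cg(matvec, b)
print(jnp.linalg.norm(matvec(x) - b))  # small residual
```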
In the JAX ecosystem, the lineax library defines linear solvers. We can wrap compute_cotan_laplace() as a linear operator, which allows us to pass it into iterative linear algebra algorithms.
# "bake in" the connectivity and vertex positions
laplace_op = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
_ = laplace_op(u) # you can apply this to vertex-fields
# define the linear operator
laplace_op_lx = lineax.FunctionLinearOperator(laplace_op, input_structure=jax.eval_shape(laplace_op, u))
# now you can use the linear operator to compute matrix representations, solve linear systems, etc.
mat = laplace_op_lx.as_matrix()
mat.shape
(131, 131)
linear_op_to_sparse
def linear_op_to_sparse(
op:callable, in_shape:tuple, out_shape:tuple, dtype:Union=None, chunk_size:int=256, tol:float=0.0
)->BCOO:
Build a sparse matrix for a linear map using batched one-hot probes.
Note: this function is general, but not necessarily very efficient for large matrix sizes.
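The probing idea can be sketched in a few lines (standalone and simplified: no chunking or sparsification tolerance, and a made-up linear map `op` standing in for the Laplacian):

```python
import jax
import jax.numpy as jnp

# One-hot probing: column j of the matrix is op(e_j); vmap batches the probes.
def op(x):  # a made-up linear map, standing in for e.g. compute_cotan_laplace
    return 2.0 * jnp.flip(x)

n = 5
cols = jax.vmap(op, in_axes=1, out_axes=1)(jnp.eye(n))  # batched one-hot probes
print(jnp.allclose(cols, 2.0 * jnp.flip(jnp.eye(n), axis=0)))  # True
```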
# compare sparse construction to lineax dense matrix (small meshes only)
if hemesh.n_vertices <= 2000:
laplace_op_local = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
laplace_op_lx_local = lineax.FunctionLinearOperator(laplace_op_local,
input_structure=jax.eval_shape(laplace_op_local, u))
sp_mat = linear_op_to_sparse(laplace_op_local, (hemesh.n_vertices,), (hemesh.n_vertices,))
mat_dense = laplace_op_lx_local.as_matrix()
rel_err_sparse = jnp.linalg.norm(sp_mat.todense() - mat_dense) / jnp.linalg.norm(mat_dense)
print("sparse vs lineax rel. error:", rel_err_sparse)
else:
print("Skipping dense comparison for large mesh.")
sparse vs lineax rel. error: 0.0
## now let's try with a large mesh
mesh = TriMesh.read_obj("test_meshes/torus_high_resolution.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)
laplace_op = jax.jit(functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh))
Warning: readOBJ() ignored non-comment line 3:
o Torus
sparse_laplace_op = linear_op_to_sparse(laplace_op, (hemesh.n_vertices,), (hemesh.n_vertices,))