Finite-element gradient and cotan-Laplacian

Building on the mesh geometry and adjacency-based operators, we can now define two important linear operators that depend on both mesh connectivity and mesh geometry. They are the discrete (triangulation-based) equivalents of the gradient and the Laplace-Beltrami operator. The latter is known as the cotan Laplacian.

We now implement the gradient (per-vertex scalar field -> per-face vector field) and the cotan Laplacian (vertex -> vertex) using gather/scatter ops. In both cases, we start with a scalar field \(u_i\) defined on each vertex \(i\) of the triangulation. The finite-element gradient is defined on each face \(ijk\) as \[ (\nabla u)_{ijk} = \sum_{l\in \{i,j,k\}} u_l \nabla\phi_l, \] where \(\phi_l\) is the piecewise-linear finite-element basis function (linear Lagrange element, or "hat function") of vertex \(l\). Its gradient is \[ \nabla\phi_i = \frac{1}{2a_{ijk}} (\mathbf{v}_k-\mathbf{v}_j)^\perp \] plus cyclic permutations. Here, \(a_{ijk}\) is the triangle area, \(\mathbf{v}_i\) are the vertex positions, and \((\cdot)^\perp\) denotes counterclockwise rotation by 90 degrees, assuming the triangle \(ijk\) is oriented counterclockwise (in 3D, the rotation is about the triangle normal).
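
As a rough illustration of the gradient formula, here is a minimal NumPy sketch for 2D vertex positions. The name fe_gradient_2d_sketch and the plain (V, F) array inputs are assumptions for this example, not the triangulax API. It loops over the three corners, gathers the corner values, and accumulates \(u_l \nabla\phi_l\); dividing by the signed doubled area makes it work for either triangle orientation.

```python
import numpy as np


def fe_gradient_2d_sketch(V, F, u):
    """Hypothetical per-face gradient of the piecewise-linear interpolant of per-vertex values u."""
    P = V[F]                                   # (n_faces, 3, 2) corner positions
    u_f = np.asarray(u, dtype=float)[F]        # (n_faces, 3, ...) corner values
    e1, e2 = P[:, 1] - P[:, 0], P[:, 2] - P[:, 0]
    twice_area = e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0]  # signed: > 0 for counterclockwise faces
    grad = np.zeros((F.shape[0], 2) + u_f.shape[2:])
    for c in range(3):                         # corner c plays the role of i, with j = c+1, k = c+2
        e = P[:, (c + 2) % 3] - P[:, (c + 1) % 3]            # v_k - v_j
        e_perp = np.stack([-e[:, 1], e[:, 0]], axis=1)       # counterclockwise rotation by 90 degrees
        grad_phi = e_perp / twice_area[:, None]              # gradient of the hat function phi_i
        grad += np.einsum("fd,f...->fd...", grad_phi, u_f[:, c])
    return grad
```

For a linear field such as \(u = 2x - 3y + 1\), every face should receive the gradient \((2, -3)\), independent of how the triangles are oriented.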

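The cotan Laplacian computes the following per-vertex field: \[ (\Delta u)_i = \frac{1}{2} \sum_{j} (\cot\alpha_j +\cot\beta_j) (u_j-u_i) \] The sum is over the vertices \(j\) adjacent to \(i\), and \(\alpha_j, \beta_j\) are the two triangle angles opposite the edge \(ij\). For a boundary edge, only one opposite angle exists, and only that cotangent enters the sum (natural boundary conditions). Note that the sum is not divided by a vertex area; the lumped mass matrix introduced below provides that normalization.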
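
The cotan formula can be evaluated face by face: each corner contributes half the cotangent of its angle to the edge opposite it, and scattering that contribution to both edge endpoints reproduces the \(\cot\alpha_j + \cot\beta_j\) sum once all faces are visited. The following minimal NumPy sketch works under that assumption; cotan_laplace_apply_sketch and its (V, F, u) inputs are hypothetical, and np.add.at stands in for the scatter step.

```python
import numpy as np


def cotan_laplace_apply_sketch(V, F, u):
    """Hypothetical: (Delta u)_i = 1/2 sum_j (cot alpha_j + cot beta_j)(u_j - u_i), assembled per face."""
    u = np.asarray(u, dtype=float)
    out = np.zeros_like(u)
    for c in range(3):
        k, i, j = F[:, c], F[:, (c + 1) % 3], F[:, (c + 2) % 3]  # corner k is opposite edge (i, j)
        a, b = V[i] - V[k], V[j] - V[k]
        dot = np.sum(a * b, axis=1)
        cross = np.sqrt(np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - dot**2)  # |a x b| in 2D or 3D
        w = 0.5 * dot / cross                                 # half the cotangent of the angle at k
        w = w.reshape(w.shape + (1,) * (u.ndim - 1))          # broadcast over extra field axes
        diff = u[j] - u[i]
        np.add.at(out, i, w * diff)    # scatter to i: w (u_j - u_i)
        np.add.at(out, j, -w * diff)   # scatter to j: w (u_i - u_j)
    return out
```

For a single right triangle with vertices (0, 0), (1, 0), (0, 1) and u = (0, 1, 0), this returns (0.5, -0.5, 0.0), and a constant field maps to zero.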

To check correctness, we compare against the corresponding libigl tutorial, using the test mesh and some random test fields.

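import functools

import igl
import jax
import jax.numpy as jnp
import lineax
import numpy as np

from triangulax.triangular import TriMesh
# `msh` is assumed to be the triangulax module providing HeMesh and GeomMesh, imported earlier

# load test data
mesh = TriMesh.read_obj("../test_meshes/disk.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)

# the same disk mesh, loaded again with dim=3 for the 3D operators (same connectivity)
mesh_3d = TriMesh.read_obj("../test_meshes/disk.obj", dim=3)
geommesh_3d = msh.GeomMesh(*hemesh.n_items, mesh_3d.vertices, mesh_3d.face_positions)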
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc

Sparse matrix utility functions


diag_jsparse


def diag_jsparse(
    v:Float[Array, 'N'], k:int=0
)->BCOO:

Construct a diagonal jax.sparse array. Drop-in replacement for np.diag.
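
For illustration, a diagonal BCOO matrix can be built directly from index pairs, without a dense intermediate. This is a hypothetical sketch (diag_bcoo_sketch), not the source of diag_jsparse; it assumes k is a diagonal offset with the same meaning as in np.diag.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO


def diag_bcoo_sketch(v, k=0):
    """Hypothetical: place v on the k-th diagonal of a square BCOO matrix, like np.diag(v, k)."""
    n = v.shape[0]
    size = n + abs(k)
    idx = jnp.arange(n)
    rows = idx + max(-k, 0)   # negative offsets shift entries down
    cols = idx + max(k, 0)    # positive offsets shift entries right
    indices = jnp.stack([rows, cols], axis=1)  # BCOO expects shape (nse, 2)
    return BCOO((v, indices), shape=(size, size))


example = diag_bcoo_sketch(jnp.array([1.0, 2.0, 3.0]), k=1)  # 4 x 4, entries above the diagonal
```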


bcoo_to_scipy


def bcoo_to_scipy(
    A:BCOO, # Input JAX sparse matrix
)->csr_matrix: # Equivalent SciPy sparse matrix

Convert a JAX BCOO sparse matrix to a SciPy CSR sparse matrix.
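
A minimal sketch of such a conversion, assuming a 2D BCOO without batch or dense dimensions (bcoo_to_scipy_sketch is a hypothetical name, not the documented function). Entries with out-of-range indices, which JAX can use as padding, are dropped, and SciPy's COO-to-CSR conversion sums duplicate entries, matching BCOO semantics.

```python
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental.sparse import BCOO


def bcoo_to_scipy_sketch(A):
    """Hypothetical conversion of a 2D BCOO matrix to SciPy CSR."""
    data = np.asarray(A.data)
    rows, cols = np.asarray(A.indices).T               # A.indices has shape (nse, 2)
    keep = (rows < A.shape[0]) & (cols < A.shape[1])   # drop out-of-range padding entries
    coo = sp.coo_matrix((data[keep], (rows[keep], cols[keep])), shape=A.shape)
    return coo.tocsr()                                 # sums duplicate (row, col) entries


csr = bcoo_to_scipy_sketch(BCOO.fromdense(jnp.eye(3)))
```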


scipy_to_bcoo


def scipy_to_bcoo(
    A, # Input sparse matrix (CSR or CSC recommended)
)->BCOO: # Equivalent JAX sparse matrix

Convert a SciPy sparse matrix (CSC or CSR) to a JAX BCOO sparse matrix without converting to dense.
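
The reverse direction can go through SciPy's COO format, which exposes row and column arrays without densifying. Again, this is a hypothetical sketch (scipy_to_bcoo_sketch); JAX also provides BCOO.from_scipy_sparse, which may be preferable in practice.

```python
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from jax.experimental.sparse import BCOO


def scipy_to_bcoo_sketch(S):
    """Hypothetical conversion of any SciPy sparse matrix to a 2D BCOO, without densifying."""
    C = sp.coo_matrix(S)                                          # CSR/CSC -> COO: row, col, data
    indices = np.stack([C.row, C.col], axis=1).astype(np.int32)   # BCOO expects shape (nse, 2)
    return BCOO((jnp.asarray(C.data), jnp.asarray(indices)), shape=C.shape)


example = sp.csr_matrix(np.array([[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]]))
B = scipy_to_bcoo_sketch(example)
```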

Cotan-Laplacian


compute_cotan_laplace


def compute_cotan_laplace(
    vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_vertices ...']:

Compute the cotangent Laplacian of a per-vertex field (natural boundary conditions).


cotan_laplace_sparse


def cotan_laplace_sparse(
    vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh
)->BCOO:

Assemble cotangent Laplacian as a sparse matrix (BCOO).
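
Because the cotan Laplacian is linear in u, its matrix can be assembled from the same per-corner weights: a corner with weight w adds +w to the two off-diagonal entries of its opposite edge and -w to the two corresponding diagonal entries. Below is a minimal SciPy sketch under those assumptions; cotan_matrix_sketch is hypothetical and returns a SciPy CSR matrix rather than a BCOO.

```python
import numpy as np
import scipy.sparse as sp


def cotan_matrix_sketch(V, F):
    """Hypothetical assembly of the cotan Laplacian as a SciPy CSR matrix."""
    n = V.shape[0]
    rows, cols, vals = [], [], []
    for c in range(3):
        k, i, j = F[:, c], F[:, (c + 1) % 3], F[:, (c + 2) % 3]  # corner k is opposite edge (i, j)
        a, b = V[i] - V[k], V[j] - V[k]
        dot = np.sum(a * b, axis=1)
        cross = np.sqrt(np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - dot**2)
        w = 0.5 * dot / cross
        # +w on the off-diagonal pair (i, j), (j, i); -w on the diagonal entries (i, i), (j, j)
        rows += [i, j, i, j]
        cols += [j, i, i, j]
        vals += [w, w, -w, -w]
    L = sp.coo_matrix((np.concatenate(vals), (np.concatenate(rows), np.concatenate(cols))), shape=(n, n))
    return L.tocsr()  # entries from the two faces sharing an interior edge are summed here
```

The summation of duplicate (row, col) pairs in tocsr() is what combines \(\cot\alpha_j\) and \(\cot\beta_j\) for interior edges.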

# Test against libigl cotmatrix (natural boundary conditions)
key = jax.random.PRNGKey(0)
u = jax.random.normal(key, (hemesh.n_vertices,))
u_vec = jax.random.normal(key, (hemesh.n_vertices, 3))

L = igl.cotmatrix(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))

lap_jax = compute_cotan_laplace(geommesh.vertices, hemesh, u)
lap_igl = L @ np.asarray(u)

rel_err = np.linalg.norm(np.asarray(lap_jax) - lap_igl) / np.linalg.norm(lap_igl)
print("scalar field rel. error:", rel_err)

lap_jax_vec = compute_cotan_laplace(geommesh.vertices, hemesh, u_vec)
lap_igl_vec = L @ np.asarray(u_vec)

rel_err_vec = np.linalg.norm(np.asarray(lap_jax_vec) - lap_igl_vec) / np.linalg.norm(lap_igl_vec)
print("vector field rel. error:", rel_err_vec)
scalar field rel. error: 1.7692000627878292e-16
vector field rel. error: 2.011929541056845e-16
# test sparse cotan Laplacian vs apply function
key = jax.random.PRNGKey(0)
u_test = jax.random.normal(key, (hemesh.n_vertices,))

L_sparse = cotan_laplace_sparse(geommesh.vertices, hemesh)
lap_sparse = L_sparse @ u_test
lap_apply = compute_cotan_laplace(geommesh.vertices, hemesh, u_test)

rel_err_sparse = jnp.linalg.norm(lap_sparse - lap_apply) / jnp.linalg.norm(lap_apply)
print("cotan sparse vs apply rel. error:", rel_err_sparse)
cotan sparse vs apply rel. error: 1.302894564211555e-16
bcoo_to_scipy(L_sparse), (scipy_to_bcoo(bcoo_to_scipy(L_sparse)).todense() == L_sparse.todense()).all()
(<Compressed Sparse Row sparse matrix of dtype 'float64'
    with 839 stored elements and shape (131, 131)>,
 Array(True, dtype=bool))

Mass matrix (lumped)

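The finite-element mass matrix \(M\) appears whenever we discretize a time-dependent PDE. More generally, \(M^{-1}\) converts integrated quantities, such as the output of the cotan Laplacian, into pointwise values. For a lumped (diagonal) mass matrix, \(M_{ii} = A_i\), where \(A_i\) is the Voronoi area associated with vertex \(i\).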


mass_matrix_inv_sparse


def mass_matrix_inv_sparse(
    vertices:Float[Array, 'n_vertices dim'], hemesh:HeMesh
)->BCOO: # Diagonal sparse inverse mass matrix.

Assemble inverse lumped mass matrix as a sparse matrix (BCOO).


mass_matrix_sparse


def mass_matrix_sparse(
    vertices:Float[Array, 'n_vertices dim'], # Vertex positions.
    hemesh:HeMesh, # Half-edge mesh connectivity.
)->BCOO: # Diagonal sparse mass matrix.

Assemble lumped (diagonal) mass matrix as a sparse matrix (BCOO).

The lumped mass matrix is diagonal with entries equal to the Voronoi dual area of each vertex: \(M_{ii} = A_i\).
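
For reference, one common way to compute such dual areas robustly is the "mixed Voronoi" rule of Meyer et al.: non-obtuse triangles split their area along circumcentric Voronoi cells, while obtuse triangles give half their area to the obtuse vertex and a quarter to each of the other two. The hypothetical NumPy sketch below (mixed_voronoi_areas_sketch) implements that rule as an illustration; it is not necessarily the rule used by mass_matrix_sparse. The lumped mass matrix would then be scipy.sparse.diags(areas).

```python
import numpy as np


def mixed_voronoi_areas_sketch(V, F):
    """Hypothetical per-vertex mixed Voronoi areas, i.e. the diagonal of a lumped mass matrix."""
    P = V[F]                                  # (n_faces, 3, dim)
    nf = F.shape[0]
    cots = np.empty((nf, 3))                  # cotangent of the angle at each corner
    sq = np.empty((nf, 3))                    # squared length of the edge opposite each corner
    for c in range(3):
        a, b = P[:, (c + 1) % 3] - P[:, c], P[:, (c + 2) % 3] - P[:, c]
        dot = np.sum(a * b, axis=1)
        cross = np.sqrt(np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - dot**2)
        cots[:, c] = dot / cross
        opp = P[:, (c + 2) % 3] - P[:, (c + 1) % 3]
        sq[:, c] = np.sum(opp * opp, axis=1)
    area = 0.5 * cross                        # |a x b| is twice the triangle area at any corner
    # Voronoi share of corner c: (1/8) * sum over its two edges of |e|^2 * cot(opposite angle)
    vor = np.stack([(sq[:, (c + 1) % 3] * cots[:, (c + 1) % 3]
                     + sq[:, (c + 2) % 3] * cots[:, (c + 2) % 3]) / 8.0 for c in range(3)], axis=1)
    obtuse = cots < 0.0                       # angles above 90 degrees have negative cotangent
    face_obtuse = obtuse.any(axis=1, keepdims=True)
    share = np.where(face_obtuse, np.where(obtuse, area[:, None] / 2, area[:, None] / 4), vor)
    areas = np.zeros(V.shape[0])
    np.add.at(areas, F.ravel(), share.ravel())  # scatter corner shares to vertices
    return areas
```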

# Test mass matrix against igl.massmatrix (Voronoi type)
M_jax = mass_matrix_sparse(geommesh.vertices, hemesh)
M_igl = igl.massmatrix(np.asarray(geommesh.vertices), np.asarray(hemesh.faces), igl.MASSMATRIX_TYPE_VORONOI)

rel_err_mass = np.linalg.norm(np.asarray(M_jax.todense()) - M_igl.toarray()) / np.linalg.norm(M_igl.toarray())
print("mass matrix rel. error:", rel_err_mass)

# Test inverse: BCOO @ dense already returns a dense array
M_inv_jax = mass_matrix_inv_sparse(geommesh.vertices, hemesh)
identity_check = M_jax @ M_inv_jax.todense()
print("M @ M_inv ~ I error:", np.linalg.norm(identity_check - np.eye(hemesh.n_vertices)))

Finite-element gradient

Not to be confused with the discrete-exterior-calculus operators, which only depend on mesh connectivity, not geometry.


compute_gradient_3d


def compute_gradient_3d(
    vertices:Float[Array, 'n_vertices 3'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces 3 ...']:

Compute the linear finite-element gradient (constant per face).
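
In 3D, the rotation by 90 degrees about the triangle normal can be written as a cross product with the unit normal, \(\hat{\mathbf{n}} \times (\mathbf{v}_k - \mathbf{v}_j)\). A minimal NumPy sketch of that idea follows; fe_gradient_3d_sketch is a hypothetical name, not the implementation documented here.

```python
import numpy as np


def fe_gradient_3d_sketch(V, F, u):
    """Hypothetical per-face FE gradient for a triangle mesh embedded in 3D."""
    P = V[F]                                               # (n_faces, 3, 3)
    u_f = np.asarray(u, dtype=float)[F]                    # (n_faces, 3, ...)
    n = np.cross(P[:, 1] - P[:, 0], P[:, 2] - P[:, 0])     # face normal; length is twice the area
    twice_area = np.linalg.norm(n, axis=1)
    n_hat = n / twice_area[:, None]
    grad = np.zeros((F.shape[0], 3) + u_f.shape[2:])
    for c in range(3):
        e = P[:, (c + 2) % 3] - P[:, (c + 1) % 3]          # v_k - v_j for corner i = c
        grad_phi = np.cross(n_hat, e) / twice_area[:, None]  # edge rotated 90 degrees about the normal
        grad += np.einsum("fd,f...->fd...", grad_phi, u_f[:, c])
    return grad
```

For a linear function \(u(\mathbf{x}) = \mathbf{g}\cdot\mathbf{x}\), the sketch returns the projection of \(\mathbf{g}\) onto the triangle plane, which equals \(\mathbf{g}\) when \(\mathbf{g}\) is tangent to the triangle.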


compute_gradient_2d


def compute_gradient_2d(
    vertices:Float[Array, 'n_vertices 2'], hemesh:HeMesh, vertex_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces 2 ...']:

Compute the linear finite-element gradient (constant per face).


gradient_sparse_3d


def gradient_sparse_3d(
    vertices:Float[Array, 'n_vertices 3'], hemesh:HeMesh
)->BCOO:

Assemble FE gradient in 3D as a sparse matrix (BCOO).

Returns a matrix G with shape (3*n_faces, n_vertices) such that, for a scalar per-vertex field u of shape (n_vertices,), the per-face gradients are obtained via:

    g_flat = G @ u                        # (3*n_faces,)
    g = g_flat.reshape((3, n_faces)).T    # (n_faces, 3)

This row layout matches libigl’s grad operator convention (component blocks).


gradient_sparse_2d


def gradient_sparse_2d(
    vertices:Float[Array, 'n_vertices 2'], hemesh:HeMesh
)->BCOO:

Assemble FE gradient in 2D as a sparse matrix (BCOO).

Returns a matrix G with shape (2*n_faces, n_vertices) such that, for a scalar per-vertex field u of shape (n_vertices,), the per-face gradients are obtained via:

    g_flat = G @ u                        # (2*n_faces,)
    g = g_flat.reshape((2, n_faces)).T    # (n_faces, 2)

This row layout matches libigl’s grad operator convention (component blocks).
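
A sketch of how such a block-stacked matrix can be assembled in 2D with SciPy, reusing the per-corner gradients \(\nabla\phi_l\): the entry for component d of face f goes into row d*n_faces + f. gradient_matrix_2d_sketch is a hypothetical helper, and its output is a SciPy matrix, not a BCOO.

```python
import numpy as np
import scipy.sparse as sp


def gradient_matrix_2d_sketch(V, F):
    """Hypothetical G of shape (2*n_faces, n_vertices): x-components first, then y-components."""
    nf, nv = F.shape[0], V.shape[0]
    P = V[F]
    e1, e2 = P[:, 1] - P[:, 0], P[:, 2] - P[:, 0]
    twice_area = e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0]  # signed area handles either orientation
    faces = np.arange(nf)
    rows, cols, vals = [], [], []
    for c in range(3):
        e = P[:, (c + 2) % 3] - P[:, (c + 1) % 3]
        grad_phi = np.stack([-e[:, 1], e[:, 0]], axis=1) / twice_area[:, None]
        for d in range(2):
            rows.append(d * nf + faces)   # component d of face f lives in row d*n_faces + f
            cols.append(F[:, c])
            vals.append(grad_phi[:, d])
    return sp.coo_matrix((np.concatenate(vals), (np.concatenate(rows), np.concatenate(cols))),
                         shape=(2 * nf, nv)).tocsr()
```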


reshape_face_gradient


def reshape_face_gradient(
    grad_flat:Float[Array, 'dim_n_faces ...'], # Output of `G @ u`, with shape `(dim*n_faces, ...)`.
    n_faces:int, # Number of mesh faces.
    dim:int, # Spatial dimension (2 or 3).
)->Float[Array, 'n_faces dim ...']: # Reshaped gradient with shape `(n_faces, dim, ...)`, matching the output convention of `compute_gradient_2d/3d`.

Reshape a flattened FE gradient into per-face vectors.

This is meant to be used with gradient_sparse_2d/3d (and any similar operator that stacks components in blocks), where applying the sparse matrix yields an array of shape (dim*n_faces, ...) (for scalar/vector/tensor per-vertex fields).
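
The reshape itself amounts to splitting the leading axis into (dim, n_faces) blocks and moving the component axis after the face axis. A hypothetical NumPy sketch, not the library function:

```python
import numpy as np


def reshape_face_gradient_sketch(grad_flat, n_faces, dim):
    """Hypothetical: (dim*n_faces, ...) component blocks -> (n_faces, dim, ...)."""
    blocks = grad_flat.reshape((dim, n_faces) + grad_flat.shape[1:])  # one block per component
    return np.moveaxis(blocks, 0, 1)                                  # faces first, then components
```

For a 1D input, this is the same as reshape((n_faces, dim), order='F'), the convention used in the libigl comparison.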

# here's how to compute the gradient in libigl

grad_matrix = igl.grad(np.asarray(geommesh.vertices), np.asarray(hemesh.faces))
# calculate the gradient of field by matrix multiplication
grad_igl = grad_matrix @ np.asarray(u)
# order='F' because igl stacks all x-components first, then all y-components (as in the igl tutorial)
grad_igl = grad_igl.reshape((hemesh.n_faces, geommesh.dim), order='F')
# test jax and libigl implementations

grad_jax = compute_gradient_2d(geommesh.vertices, hemesh, u)

rel_err_grad = np.linalg.norm(np.asarray(grad_jax) - grad_igl) / np.linalg.norm(grad_igl)
print("gradient rel. error:", rel_err_grad)
gradient rel. error: 1.413315746703021e-16
# same test, in 3d

grad_matrix_3d = igl.grad(np.asarray(geommesh_3d.vertices), np.asarray(hemesh.faces))
grad_igl_3d = grad_matrix_3d @ np.asarray(u)
grad_igl_3d = grad_igl_3d.reshape((hemesh.n_faces, geommesh_3d.dim), order='F')

grad_jax_3d = compute_gradient_3d(geommesh_3d.vertices, hemesh, u)

rel_err_grad_3d = np.linalg.norm(np.asarray(grad_jax_3d) - grad_igl_3d) / np.linalg.norm(grad_igl_3d)
print("gradient rel. error:", rel_err_grad_3d)
gradient rel. error: 1.5657863888820882e-16
# Test sparse gradient operators vs apply functions
key = jax.random.PRNGKey(123)
u_test = jax.random.normal(key, (hemesh.n_vertices,))

G2 = gradient_sparse_2d(geommesh.vertices, hemesh)
g2 = reshape_face_gradient(G2 @ u_test, hemesh.n_faces, dim=2)
g2_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_test)
rel_err_g2 = jnp.linalg.norm(g2 - g2_apply) / jnp.linalg.norm(g2_apply)
print("2D grad sparse vs apply rel. error:", rel_err_g2)

G3 = gradient_sparse_3d(geommesh_3d.vertices, hemesh)
g3 = reshape_face_gradient(G3 @ u_test, hemesh.n_faces, dim=3)
g3_apply = compute_gradient_3d(geommesh_3d.vertices, hemesh, u_test)
rel_err_g3 = jnp.linalg.norm(g3 - g3_apply) / jnp.linalg.norm(g3_apply)
print("3D grad sparse vs apply rel. error:", rel_err_g3)

# quick sanity check for vector/tensor fields: u has extra axes
u_vec = jax.random.normal(key, (hemesh.n_vertices, 3))
g2_vec = reshape_face_gradient(G2 @ u_vec, hemesh.n_faces, dim=2)
g2_vec_apply = compute_gradient_2d(geommesh.vertices, hemesh, u_vec)
rel_err_g2_vec = jnp.linalg.norm(g2_vec - g2_vec_apply) / jnp.linalg.norm(g2_vec_apply)
print("2D grad (vector field) sparse vs apply rel. error:", rel_err_g2_vec)
2D grad sparse vs apply rel. error: 8.71017994729607e-17
3D grad sparse vs apply rel. error: 9.602668379845331e-17
2D grad (vector field) sparse vs apply rel. error: 8.285943150518157e-17

Wrapping as linear operators

It is often useful to think of functions like compute_cotan_laplace() as linear operators acting on mesh fields. For example, suppose you want to solve the Laplace equation on a mesh with fixed vertex positions and connectivity. This calls for a linear solver. Luckily, iterative solvers (such as conjugate gradient or GMRES) only need to compute the action of the linear operator on an input vector; they don't need an explicit matrix representation.

In the JAX ecosystem, the lineax library provides linear operators and solvers. We can wrap compute_cotan_laplace() as a lineax linear operator, which allows us to pass it to iterative linear-algebra algorithms.
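
To illustrate the matrix-free idea without depending on a mesh, the sketch below uses jax.scipy.sparse.linalg.cg, which accepts a function in place of a matrix, to solve a screened Poisson-type system \((I - tL)x = b\). The operator ring_laplacian is a hypothetical stand-in for compute_cotan_laplace (a unit-weight Laplacian on a cycle graph). Since \(L\) is symmetric negative semi-definite, \(I - tL\) is symmetric positive definite for \(t > 0\), so conjugate gradient applies.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def ring_laplacian(u):
    # hypothetical stand-in for compute_cotan_laplace on a fixed mesh: symmetric, negative semi-definite
    return jnp.roll(u, 1) + jnp.roll(u, -1) - 2.0 * u


t = 0.5


def screened_op(u):
    # (I - t L) u, applied without ever forming a matrix
    return u - t * ring_laplacian(u)


n = 64
b = jnp.sin(2.0 * jnp.pi * jnp.arange(n) / n)
x, _ = cg(screened_op, b)  # cg only calls screened_op(v)
```

For the pure Laplace equation, the cotan Laplacian with natural boundary conditions has constants in its null space, so boundary conditions (or a screening term as above) are needed to make the system uniquely solvable.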

# "bake in" the connectivity and vertex positions

laplace_op = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
_ = laplace_op(u) # you can apply this to vertex-fields

# define the linear operator
laplace_op_lx = lineax.FunctionLinearOperator(laplace_op, input_structure=jax.eval_shape(laplace_op, u))

# now you can use the linear operator to compute matrix representations, solve linear systems, etc.
mat = laplace_op_lx.as_matrix()
mat.shape
(131, 131)

linear_op_to_sparse


def linear_op_to_sparse(
    op:callable, in_shape:tuple, out_shape:tuple, dtype:Union=None, chunk_size:int=256, tol:float=0.0
)->BCOO:

Build a sparse matrix for a linear map using batched one-hot probes.

Note: this function is general, but not necessarily very efficient for large matrix sizes.
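
One way to implement such probing is to apply the operator with jax.vmap to chunks of one-hot vectors (each output is one column of the matrix) and keep the entries whose magnitude exceeds tol. This hypothetical, square-only sketch (probe_to_bcoo_sketch) is not the source of linear_op_to_sparse, whose signature also covers general input/output shapes and dtypes. It costs one operator application per column, which is why this approach slows down for large meshes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO


def probe_to_bcoo_sketch(op, n, chunk_size=4, tol=0.0):
    """Hypothetical: build the BCOO matrix of a linear map R^n -> R^n from one-hot probes."""
    batched_op = jax.vmap(op)  # applies op to every row of a batch of inputs
    rows, cols, vals = [], [], []
    for start in range(0, n, chunk_size):
        idx = np.arange(start, min(start + chunk_size, n))
        out = np.asarray(batched_op(jax.nn.one_hot(idx, n)))  # out[b] = column idx[b] of the matrix
        b_idx, r_idx = np.nonzero(np.abs(out) > tol)           # keep entries above the threshold
        rows.append(r_idx)
        cols.append(idx[b_idx])
        vals.append(out[b_idx, r_idx])
    indices = np.stack([np.concatenate(rows), np.concatenate(cols)], axis=1).astype(np.int32)
    return BCOO((jnp.asarray(np.concatenate(vals)), jnp.asarray(indices)), shape=(n, n))


def ring_laplacian(u):
    # hypothetical stand-in for a mesh operator such as compute_cotan_laplace
    return jnp.roll(u, 1) + jnp.roll(u, -1) - 2.0 * u
```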

# compare sparse construction to lineax dense matrix (small meshes only)
if hemesh.n_vertices <= 2000:
    laplace_op_local = functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh)
    laplace_op_lx_local = lineax.FunctionLinearOperator(laplace_op_local,
                                                        input_structure=jax.eval_shape(laplace_op_local, u))
    sp_mat = linear_op_to_sparse(laplace_op_local, (hemesh.n_vertices,), (hemesh.n_vertices,))
    mat_dense = laplace_op_lx_local.as_matrix()
    rel_err_sparse = jnp.linalg.norm(sp_mat.todense() - mat_dense) / jnp.linalg.norm(mat_dense)
    print("sparse vs lineax rel. error:", rel_err_sparse)
else:
    print("Skipping dense comparison for large mesh.")
sparse vs lineax rel. error: 0.0
## now let's try with a large mesh

mesh = TriMesh.read_obj("../test_meshes/torus_high_resolution.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)

laplace_op = jax.jit(functools.partial(compute_cotan_laplace, geommesh.vertices, hemesh))
Warning: readOBJ() ignored non-comment line 3:
  o Torus
sparse_laplace_op = linear_op_to_sparse(laplace_op, (hemesh.n_vertices,), (hemesh.n_vertices,))