Adjacency-like operators on half-edge meshes
from triangulax.triangular import TriMesh
The HeMesh data structure lets us efficiently “traverse” the mesh. With such traversals, many adjacency-based linear operators can be expressed, for example:
- Sum over all half-edges “incoming” to a vertex (special case: count the incoming edges, i.e., compute the coordination number)
- Compute the finite-element gradient of a function defined on vertices
These operations can be done efficiently using a “gather/scatter” approach, see jax.numpy.ndarray.at. There is no need to explicitly instantiate a matrix for the operators.
All operators defined in this notebook depend only on the mesh topology, not on its geometry (vertex/face positions).
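To illustrate the gather/scatter idea, here is a minimal sketch on a toy set of four half-edges. The `dest` index array is a hypothetical stand-in for the connectivity stored in a HeMesh, not one of its actual attributes. Scattering with `.at[...].add` sums half-edge values onto vertices, plain indexing gathers vertex values back onto half-edges, and both agree with multiplication by an explicit incidence matrix that the scatter version never has to build.

```python
import jax.numpy as jnp

# Hypothetical toy connectivity (not the HeMesh internals):
# destination vertex of each of 4 half-edges, on a mesh with 3 vertices.
n_vertices = 3
dest = jnp.array([1, 2, 0, 2])
he_field = jnp.array([1.0, 2.0, 3.0, 4.0])  # one value per half-edge

# Scatter: add each half-edge value into the slot of its destination vertex.
# Repeated indices (vertex 2 appears twice) are accumulated, not overwritten.
v_sum = jnp.zeros(n_vertices).at[dest].add(he_field)

# Gather: read a per-vertex field onto half-edges by plain indexing.
v_field = jnp.array([10.0, 20.0, 30.0])
he_from_v = v_field[dest]

# The same scatter as an explicit (n_vertices, n_hes) incidence matrix.
n_hes = dest.shape[0]
incidence = jnp.zeros((n_vertices, n_hes)).at[dest, jnp.arange(n_hes)].set(1.0)
assert jnp.allclose(v_sum, incidence @ he_field)       # scatter = matrix product
assert jnp.allclose(he_from_v, incidence.T @ v_field)  # gather = transpose

# Trailing dimensions are carried along, so (n_hes, d) fields work unchanged.
v_vec_sum = jnp.zeros((n_vertices, 2)).at[dest].add(jnp.ones((n_hes, 2)))
print(v_sum, he_from_v, v_vec_sum)
```

The gather is the transpose of the scatter, so a single index array is enough to apply both an operator and its adjoint.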
# load test data
mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = msh.HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = msh.GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)
mesh_3d = TriMesh.read_obj("test_meshes/disk.obj", dim=3)
geommesh_3d = msh.GeomMesh(*hemesh.n_items, mesh_3d.vertices, mesh_3d.face_positions)
Discrete derivative
On a triangular mesh, there are two natural “derivatives”: for a per-vertex field, the difference of its values at the two endpoints of each half-edge; for a per-half-edge field, the sum (circulation) around each face. This is the basis of discrete exterior calculus.
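A minimal sketch of both derivatives on two triangles, (0,1,2) and (0,2,3). The `orig`, `dest` and `face` arrays are hypothetical stand-ins for the HeMesh connectivity, boundary half-edges are left out, and the sign convention (destination minus origin) is an assumption, not necessarily the one used by `get_exterior_gradient`.

```python
import jax
import jax.numpy as jnp

# Hypothetical face half-edges of triangles (0,1,2) and (0,2,3);
# boundary half-edges are omitted in this sketch.
orig = jnp.array([0, 1, 2, 0, 2, 3])  # origin vertex of each half-edge
dest = jnp.array([1, 2, 0, 2, 3, 0])  # destination vertex of each half-edge
face = jnp.array([0, 0, 0, 1, 1, 1])  # face each half-edge belongs to
n_faces = 2

def exterior_gradient(v_field):
    # d0: vertex field -> half-edge field (assumed convention: dest minus origin)
    return v_field[dest] - v_field[orig]

def exterior_circulation(he_field):
    # d1: half-edge field -> face field, summing the half-edges of each face
    out = jnp.zeros((n_faces,) + he_field.shape[1:])
    return out.at[face].add(he_field)

v_field = jax.random.uniform(jax.random.PRNGKey(0), (4,))
# Around a closed face loop every vertex value enters once with + and once
# with -, so the circulation of a gradient vanishes (up to rounding).
f_circ = exterior_circulation(exterior_gradient(v_field))
# A half-edge field that is not a gradient generally has nonzero circulation.
f_circ_ones = exterior_circulation(jnp.ones(6))
print(f_circ, f_circ_ones)
```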
get_exterior_circulation
def get_exterior_circulation(
hemesh:HeMesh, he_field:Float[Array, 'n_hes ...']
)->Float[Array, 'n_faces ...']:
get_exterior_gradient
def get_exterior_gradient(
hemesh:HeMesh, v_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_hes ...']:
# define a random scalar field on vertices, compute its gradient on half-edges, and check that the circulation of a gradient vanishes
v_field = jax.random.uniform(jax.random.PRNGKey(0), (hemesh.n_vertices,))
he_gradient = get_exterior_gradient(hemesh, v_field)
f_circulation = get_exterior_circulation(hemesh, he_gradient)
hemesh, he_gradient.shape, f_circulation.shape, jnp.allclose(f_circulation, 0)
(HeMesh(N_V=131, N_HE=708, N_F=224), (708,), (224,), Array(True, dtype=bool))
Summing over adjacent mesh elements
A second important class of operations is summing over adjacent mesh elements. For example, to get the coordination number of a vertex, sum the value \(1\) over all incoming half-edges. For computing quantities such as cell areas, it is also useful to sum over the half-edges opposite to a vertex, i.e., the edges of its incident triangles that do not touch it.
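Both kinds of sum reduce to a scatter-add onto vertices; they differ only in which vertex each half-edge is sent to. The sketch below uses the two triangles (0,1,2) and (0,2,3) with hypothetical `dest` and `nxt` arrays (not HeMesh attributes) and hypothetical helper names. It assumes that the opposite vertex of a face half-edge `h` is the destination of the next half-edge in the same face.

```python
import jax.numpy as jnp

# Hypothetical half-edges of triangles (0,1,2) and (0,2,3):
# 0-5 belong to faces, 6-9 are boundary half-edges (no face).
dest = jnp.array([1, 2, 0, 2, 3, 0,   0, 1, 2, 3])
nxt = jnp.array([1, 2, 0, 4, 5, 3])  # next half-edge in the face (face half-edges only)
n_vertices = 4

def sum_incoming(he_field):
    # scatter each half-edge value onto its destination vertex
    out = jnp.zeros((n_vertices,) + he_field.shape[1:])
    return out.at[dest].add(he_field)

def sum_opposite(face_he_field):
    # the vertex opposite face half-edge h is the destination of next(h)
    opposite = dest[nxt]
    out = jnp.zeros((n_vertices,) + face_he_field.shape[1:])
    return out.at[opposite].add(face_he_field)

# Summing 1 over incoming half-edges (boundary ones included) gives the
# coordination number, i.e. the number of neighbours of each vertex.
coordination = sum_incoming(jnp.ones(dest.shape[0]))
# Summing 1 over opposite half-edges counts the faces incident on each vertex.
n_incident_faces = sum_opposite(jnp.ones(nxt.shape[0]))
print(coordination, n_incident_faces)
```

Including the boundary half-edges in the incoming sum is what makes the count equal to the vertex degree on boundary vertices as well.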
sum_he_to_vertex_opposite
def sum_he_to_vertex_opposite(
hemesh:HeMesh, he_field:Float[Array, 'n_hes ...']
)->Float[Array, 'n_vertices ...']:
Sum a half-edge field onto opposite vertices.
Attention: can include boundary half-edges!
hemesh: connectivity information
he_field: (n_hes,) or (n_hes, d) array
sum_he_to_vertex_incoming
def sum_he_to_vertex_incoming(
hemesh:HeMesh, he_field:Float[Array, 'n_hes ...']
)->Float[Array, 'n_vertices ...']:
Sum a half-edge field onto destination vertices.
hemesh: connectivity information
he_field: (n_hes,) or (n_hes, d) array
sum_face_to_he
def sum_face_to_he(
hemesh:HeMesh, f_field:Float[Array, 'n_faces ...']
)->Float[Array, 'n_hes ...']:
Sum a face field onto half-edges: each half-edge receives the sum of the values on its own face and on the face of its twin.
sum_he_to_face
def sum_he_to_face(
hemesh:HeMesh, he_field:Float[Array, 'n_hes ...']
)->Float[Array, 'n_faces ...']:
Sum over all half-edges of a face. Alias of get_exterior_circulation.
average_face_to_vertex
def average_face_to_vertex(
hemesh:HeMesh, f_field:Float[Array, 'n_faces ...']
)->Float[Array, 'n_vertices ...']:
Average a face field onto vertices, with uniform weights.
sum_face_to_vertex
def sum_face_to_vertex(
hemesh:HeMesh, f_field:Float[Array, 'n_faces ...']
)->Float[Array, 'n_vertices ...']:
Sum face-field to vertices. Sums over the faces incident on the vertex.
average_vertex_to_face
def average_vertex_to_face(
hemesh:HeMesh, v_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces ...']:
Average vertex-field to faces.
sum_vertex_to_face
def sum_vertex_to_face(
hemesh:HeMesh, v_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_faces ...']:
Sum vertex-field to faces. Sums over the vertices of the face.
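The vertex/face transfers can be sketched with a plain (n_faces, 3) array of vertex indices, mirroring what libigl's `average_onto_faces` and `average_onto_vertices` compute in the tests that follow. The `faces` array and the `*_sketch` helper names are hypothetical, and the averaging step assumes every vertex belongs to at least one face (otherwise it divides by zero).

```python
import jax.numpy as jnp

# Hypothetical triangles (0,1,2) and (0,2,3) as an (n_faces, 3) index array.
faces = jnp.array([[0, 1, 2], [0, 2, 3]])
n_vertices = 4

def sum_vertex_to_face_sketch(v_field):
    # gather the three vertex values of each face and add them
    return v_field[faces].sum(axis=1)

def average_vertex_to_face_sketch(v_field):
    return v_field[faces].mean(axis=1)

def sum_face_to_vertex_sketch(f_field):
    # scatter each face value onto its three corners
    out = jnp.zeros((n_vertices,) + f_field.shape[1:])
    for corner in range(3):
        out = out.at[faces[:, corner]].add(f_field)
    return out

def average_face_to_vertex_sketch(f_field):
    # uniform weights: divide by the number of faces incident on each vertex
    total = sum_face_to_vertex_sketch(f_field)
    n_incident = sum_face_to_vertex_sketch(jnp.ones(faces.shape[0]))
    # reshape the counts so they also broadcast against (n_vertices, d) fields
    return total / n_incident.reshape((-1,) + (1,) * (total.ndim - 1))

v = jnp.array([1.0, 2.0, 3.0, 4.0])
f = jnp.array([3.0, 6.0])
print(sum_vertex_to_face_sketch(v), average_vertex_to_face_sketch(v))
print(sum_face_to_vertex_sketch(f), average_face_to_vertex_sketch(f))
```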
# tests vs libigl
key = jax.random.PRNGKey(123)
u_v = jax.random.normal(key, (hemesh.n_vertices,))
faces_avg_jax = average_vertex_to_face(hemesh, u_v)
faces_avg_igl = igl.average_onto_faces(np.asarray(hemesh.faces), np.asarray(u_v))
rel_err_faces = jnp.linalg.norm(faces_avg_jax - faces_avg_igl) / jnp.linalg.norm(faces_avg_igl)
print("vertex->face rel. error:", rel_err_faces)
vertex->face rel. error: 0.0
u_f = jax.random.normal(key, (hemesh.n_faces,))
verts_avg_jax = average_face_to_vertex(hemesh, u_f)
verts_avg_igl = igl.average_onto_vertices(mesh.vertices, np.asarray(hemesh.faces), np.asarray(u_f))
rel_err_verts = jnp.linalg.norm(verts_avg_jax-verts_avg_igl) / jnp.linalg.norm(verts_avg_igl)
print("face->vertex rel. error:", rel_err_verts)
face->vertex rel. error: 8.339340577730768e-17
# also works for vector fields
u_f = jax.random.normal(key, (hemesh.n_faces, 10))
verts_avg_jax = average_face_to_vertex(hemesh, u_f)
verts_avg_jax.shape
(131, 10)
get_coordination_number
def get_coordination_number(
hemesh:HeMesh
)->Float[Array, 'n_vertices']:
get_coordination_number(hemesh).mean()
Array(5.40458015, dtype=float64)
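The mean coordination number can be checked by hand from the mesh sizes in the HeMesh repr, HeMesh(N_V=131, N_HE=708, N_F=224). Assuming that every edge carries two half-edges (boundary edges included) and that `get_coordination_number` counts all incoming half-edges, the counts sum to N_HE, so the mean is N_HE / N_V. The same assumption gives the Euler characteristic of a disk and the number B of boundary edges; for a triangulated disk, V - E + F = 1 and 3F = 2E - B together give a mean degree of 6 - 2(B + 3)/V, which is why the value stays below 6.

```python
# Sizes from the HeMesh repr of the disk test mesh.
n_vertices, n_hes, n_faces = 131, 708, 224

# Assumption: two half-edges per edge (boundary edges included), and each
# half-edge is "incoming" to exactly one vertex.
mean_coordination = n_hes / n_vertices   # sum of in-degrees / N_V
n_edges = n_hes // 2                     # E = 354
euler = n_vertices - n_edges + n_faces   # V - E + F, equal to 1 for a disk

# Each face has 3 half-edges; interior edges border 2 faces, boundary edges 1,
# so 3 F = 2 E - B gives the number of boundary edges B.
n_boundary = 2 * n_edges - 3 * n_faces
closed_form = 6 - 2 * (n_boundary + 3) / n_vertices

print(round(mean_coordination, 8), euler, n_boundary)  # 5.40458015 1 36
```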
Uniform/graph Laplacian
get_uniform_laplacian
def get_uniform_laplacian(
hemesh:HeMesh
)->BCOO:
Returns the uniform Laplacian matrix as a sparse matrix. Non-normalized, positive semi-definite (constant fields lie in its null space).
compute_uniform_laplacian
def compute_uniform_laplacian(
hemesh:HeMesh, v_field:Float[Array, 'n_vertices ...']
)->Float[Array, 'n_vertices ...']:
Computes the uniform Laplacian of a per-vertex field (scalar or vector-valued). Non-normalized, positive semi-definite.
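A minimal sketch of the uniform Laplacian as a scatter, assuming the convention L = D - A (degree matrix minus adjacency matrix), which is consistent with the “non-normalized, positive” description above. The half-edge arrays are hypothetical and include boundary half-edges, so every edge appears in both directions. The sketch is compared against a dense D - A; it vanishes on constant fields, and its quadratic form equals the sum of squared differences over the edges, hence is non-negative.

```python
import jax.numpy as jnp

# Hypothetical half-edges of triangles (0,1,2) and (0,2,3); the last four are
# boundary half-edges, so every edge appears in both directions.
orig = jnp.array([0, 1, 2, 0, 2, 3,   1, 2, 3, 0])
dest = jnp.array([1, 2, 0, 2, 3, 0,   0, 1, 2, 3])
n_vertices = 4

def uniform_laplacian_sketch(v_field):
    # (L u)_i = sum over neighbours j of (u_i - u_j): scatter the difference
    # u[dest] - u[orig] of every half-edge onto its destination vertex.
    diff = v_field[dest] - v_field[orig]
    out = jnp.zeros((n_vertices,) + v_field.shape[1:])
    return out.at[dest].add(diff)

# Dense reference L = D - A built from the same (symmetric) half-edge list.
A = jnp.zeros((n_vertices, n_vertices)).at[dest, orig].set(1.0)
L_dense = jnp.diag(A.sum(axis=1)) - A

u = jnp.array([0.3, -1.2, 2.0, 0.5])
Lu = uniform_laplacian_sketch(u)
quad = jnp.dot(u, Lu)  # equals the sum of (u_i - u_j)^2 over the 5 edges
eigs = jnp.linalg.eigvalsh(L_dense)  # smallest eigenvalue ~0 (constants)
print(Lu, L_dense @ u, quad, eigs)
```

Written this way, the Laplacian is the incoming sum of a half-edge difference, i.e. a scatter composed with a gather, so no matrix is needed to apply it.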
# test that the matrix and function versions agree, and that the quadratic form is positive for a random (non-constant) field
laplace_mat = get_uniform_laplacian(hemesh)
jnp.allclose(laplace_mat @ v_field, compute_uniform_laplacian(hemesh, v_field)), jnp.dot(laplace_mat @ v_field, v_field) > 0
(Array(True, dtype=bool), Array(True, dtype=bool))