Half-edge meshes

In the TriMesh class, we represent a mesh a list of triangles. However, many common operations are difficult with this data structure. For example, how do you get all the neighbors of a given vertex, or compute the area of a dual cell?

For simulation and geometry processing, we need a different representation of the adjacency information. Typically, this is achieved by a half-edge mesh (HE) data structure. We represent the HE data structure by 3 sets of integer index arrays:

  1. Vertices: 1 \((N_V,)\) matrix, whose entry for vertex \(i\) is an arbitrary HE incident on \(i\)
  2. Edges: 6 \((2N_E,)\) matrices, [origin, dest, nxt, prv, twin, face] for each half-edge (face is undefined for boundary half-edges).
  3. Faces, 1 \((N_F, 1)\) matrix, whose entry for face \(i\) is an arbitrary HE in \(i\). (Not to be confused with the \((N_F, 3)\) matrix of vertex IDs used previously).

Additionally, there are two float arrays for vertex and face positions, as previously. However, we split combinatorial and geometric information - a HeMesh class for the combinatorics, and a couple of regular arrays for the vertex positions, face positions, and vertex/half-edge/face attributes. The latter are packaged into a GeomMesh class. Together, the pair (GeomMesh, HeMesh) describes a mesh (like vertices/faces pair). A named tuple Mesh combines the two.

The first task is to create a helper function to plot mesh connectivity, and to create the half-edge connectivity matrices from the more conventional list-of-triangles format. The latter is somewhat involved.

We follow the notes in notebook 02 to ensure JAX compatibility.


source

label_plot


def label_plot(
    vertices:Float[Array, 'n_vertices 2'], faces:Int[Array, 'n_faces 3'], hemesh:Union=None, vertex_labels:bool=True,
    face_labels:bool=True, ax:Union=None, fontsize:Union=None
)->None:

For debugging purposes. Plot triangular mesh with face/vertex labels in black/blue. If hemesh is not None, the connectivity info from it is used to plot the half-edge labels.

mesh = TriMesh.read_obj("test_meshes/disk.obj")

plt.triplot(*mesh.vertices.T, mesh.faces)
label_plot(mesh.vertices, mesh.faces, fontsize=10)
plt.axis("equal")
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
(np.float64(-1.10003475),
 np.float64(1.09628575),
 np.float64(-1.09934025),
 np.float64(1.09050125))


source

get_half_edge_arrays_vectorized


def get_half_edge_arrays_vectorized(
    n_vertices:int, faces:Int[Array, 'n_faces 3']
)->list:

Get half-edge data structure arrays from faces (vectorized).

Returns: incident, orig, dest, twin, nxt, prv, heface, face_incident

mesh_high_res = TriMesh.read_obj("test_meshes/torus_high_resolution.obj")
mesh = TriMesh.read_obj("test_meshes/disk.obj")
Warning: readOBJ() ignored non-comment line 3:
  o Torus
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
results = get_half_edge_arrays(mesh.vertices.shape[0], mesh.faces)
UsageError: Line magic function `%%timeit` not found.
# test vectorized vs reference implementation for two meshes

mesh = TriMesh.read_obj("test_meshes/disk.obj")
ref = get_half_edge_arrays(mesh.vertices.shape[0], mesh.faces)
fast = get_half_edge_arrays_vectorized(mesh.vertices.shape[0], mesh.faces)

print("Equal?", all([jnp.array_equal(a, b) for a, b in zip(ref, fast)]))

mesh = TriMesh.read_obj("test_meshes/sphere.obj")
ref = get_half_edge_arrays(mesh.vertices.shape[0], mesh.faces)
fast = get_half_edge_arrays_vectorized(mesh.vertices.shape[0], mesh.faces)

print("Equal?", all([jnp.array_equal(a, b) for a, b in zip(ref, fast)]))
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
Warning: readOBJ() ignored non-comment line 3:
  o Icosphere
Equal? True
Equal? True
1.11 ms ± 43 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
CPU times: user 211 ms, sys: 12.6 ms, total: 224 ms
Wall time: 221 ms

source

HeMesh


def HeMesh(
    incident:Int[Array, '*batch n_vertices'], orig:Int[Array, '*batch n_hes'], dest:Int[Array, '*batch n_hes'],
    twin:Int[Array, '*batch n_hes'], nxt:Int[Array, '*batch n_hes'], prv:Int[Array, '*batch n_hes'],
    heface:Int[Array, '*batch n_hes'], face_incident:Int[Array, '*batch n_faces'], inf_vertices:Union=()
)->None:

Half-edge mesh data structure for triangular meshes.

A half-edge mesh is described by a set of half-edges and several arrays that specify their connectivity (see markup explanation above). This class serves as a container for multiple arrays. For future compatibility with JAX, after initialization, do not modify these arrays in-place; always return a new HeMesh object. The mesh vertices may live in whatever dimension - this does not affect the connectivity bookkeeping.

Half-edge meshes are initialized from a list of triangles and a number of vertices, and can return the original triangles (e.g., to save as a .obj).

All information and methods are purely “combinatorial”. The HeMesh class does not contain the vertex or face positions. These are saved in the GeomHeMesh class that combines a HeMesh (combinatorics) with a couple of other arrays (geometry).

—Conventions—

For vertices, the incident half-edge points away from the vertex.

To describe the mesh boundary, there are two options: 1. Initialize from a triangulation with a boundary. Half-edges without a face (boundary) are assigned heface=-1. 2. Initialize from a triangulation without boundary, where certain vertices are “at infinity”. They should have coordinates [np.inf, np.inf]. Each infinity vertex corresponds to one boundary. For a single boundary, the vertex at infinity is, by convention, the final one.

Starting from a set of triangles, the half-edges are initialized as follows: The 1st N_edges half-edges are (origin_vertex, destination_vertex), in lexicographic order, with origin_vertex < destination_vertex. The 2nd N_edges are their twins, in the same order.

Attributes

incident : Int[jax.Array, “n_vertices”]

orig : Int[jax.Array, “n_hes”]

dest : Int[jax.Array, “n_hes”]

nxt : Int[jax.Array, “n_hes”]

prv : Int[jax.Array, “n_hes”]

twin : Int[jax.Array, “n_hes”]

heface : Int[jax.Array, “n_hes”]

face_incident : Int[jax.Array, “n_faces”]

inf_vertices : tuple[Int]

Property methods (use like attributes)

n_vertices : int

n_hes : int

n_faces : int

n_items : tuple[int, int, int]

faces : Int[jax.Array, “n_faces 3”]

has_inf_vertex : bool

is_inf_face : Bool[jax.Array, “n_faces”]

is_unique : Bool[jax.Array, “n_hes”]

is_inf_he : Bool[jax.Array, “n_hes”]

is_bdry_he : Bool[jax.Array, “n_hes”]

is_bdry_edge : Bool[jax.Array, “n_hes”]

is_bdry : Bool[jax.Array, “n_vertices”]

Static methods

from_triangles : tuple[int, Int[jax.Array, “n_faces 3”], Int[jax.Array, “n_boundaries”] -> HeMesh

Class methods

iterate_around_vertex : int -> Int[jax.Array, “n_neighbors”]

save : str -> None:

Static methods

load : str -> HeMesh


source

test_mesh_validity


def test_mesh_validity(
    h:HeMesh
):

Test if a mesh is valid. Returns True if valid, fails otherwise.

mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
test_mesh_validity(hemesh)
True
# hemeshes can be compared for equality and are registered as pytrees

leafs, ts = jax.tree_util.tree_flatten(hemesh)

hemesh, hemesh == hemesh
(HeMesh(N_V=131, N_HE=708, N_F=224), True)
# test iteration around vertex
hemesh.dest[hemesh.iterate_around_vertex(69)], hemesh.orig[hemesh.iterate_around_vertex(56)]
(Array([80, 68, 56, 46], dtype=int64),
 Array([56, 56, 56, 56, 56, 56, 56], dtype=int64))
# boundary in cc-wise order
(hemesh.orig[187], hemesh.dest[187]), hemesh.heface[187], hemesh.is_bdry_he[187],
((Array(58, dtype=int64), Array(70, dtype=int64)),
 Array(-1, dtype=int64),
 Array(True, dtype=bool))
hemesh.is_bdry_he[187], hemesh.is_bdry_he[541], hemesh.heface[541]
(Array(True, dtype=bool), Array(False, dtype=bool), Array(145, dtype=int64))
# to model mesh boundaries, we can add an "infinity" vertex. Not done here, see below
hemesh.has_inf_vertex, hemesh.inf_vertices
(False, ())
fig = plt.figure(figsize=(14,14))

plt.triplot(*mesh.vertices.T, hemesh.faces)
label_plot(mesh.vertices, hemesh.faces, fontsize=10, hemesh=hemesh, face_labels=False)
plt.axis("equal")
(np.float64(-1.10003475),
 np.float64(1.09628575),
 np.float64(-1.09934025),
 np.float64(1.09050125))

# here is how you would do mesh traversal with jax.lax. The issues is that the output size needs to be fixed
# ahead of time, so 

self = hemesh
max_valence = 10
v = 10

initial = jnp.hstack([jnp.array([self.incident[v]]), -1*jnp.ones(max_valence-1, dtype=int)])
jax.lax.fori_loop(1, max_valence, lambda i, x: x.at[i].set(self.twin[x[i-1]]), initial)
Array([ 47, 401,  47, 401,  47, 401,  47, 401,  47, 401], dtype=int64)

Computing with half-edge meshes

By using the arrays of a half-edge meshes to index vertex- or face-positions (in increasingly complex ways), we can compute all sorts of quantities of interests associated with a mesh, for example the edge lengths.

edges = mesh.vertices[hemesh.orig]-mesh.vertices[hemesh.dest]
lengths = jnp.linalg.norm(edges, axis=-1)

Boundary and the vertex at infinity

So far, our mesh representations TriMesh and HeMesh work for triangular meshes with and without boundary. In the HeMesh class, boundary half-edges are assigned to a fictitious -1 face. This convention has a downside. It is not possible to modify the boundary loop of the mesh by edge flips - doing so would result in an invalid state. In a simulation, this artificially limits the mesh’s ability to deform. Instead, we can add “vertices at infinity” and connect all edges in a given boundary to \(\infty\). This turns the mesh into a topological sphere. Now, one can flip boundary edges without the overall number of half-edges changing (so the array shape stays the same). Multiple boundaries are also supported. Each boundary corresponds to a distinct \(\infty\)-vertex (for example, 2 for a cylinder).

The coordinates of the fictitious vertices are set to [np.inf, np.inf] by convention. The boundary is found by iterating around \(\infty\). By convention, \(\infty\)-vertices, if they exist, are the final vertices of the mesh (don’t rely on this - implementation detail).

We generally assume that the mesh has only a single connected component.

The HeMesh class can deal with both the -1-face and the \(\infty\)-vertices conventions. The latter are listed in the inf_vertices attribute of a HeMesh.


source

connect_boundary_to_infinity


def connect_boundary_to_infinity(
    vertices:Float[Array, 'n_vertices 2'], # Vertex positions.
    faces:Int[Array, 'n_faces 3'], # Faces (triangles) as list of vertex indices.
)->tuple: # Vertex positions with infinity vertices appended.
One infinity vertex per boundary loop.

Connect boundary loop(s) to infinity.

New vertices are appeneded to the end of vertex array and have coordinates [np.inf, np.inf].

mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
new_vertices, new_faces, infinity_vertices = connect_boundary_to_infinity(mesh.vertices, mesh.faces)
mesh_infty = TriMesh(vertices=new_vertices, faces=new_faces)
hemesh_infty = HeMesh.from_triangles(mesh_infty.vertices.shape[0], mesh_infty.faces,
                                     inf_vertices=infinity_vertices)
# to get back the original faces/vertices, do this:

_ = hemesh_infty.faces[~hemesh_infty.is_inf_face]

# if you want to re-index the triangles so they only refer to non-infinity vertices:
_ = igl.remove_unreferenced(new_vertices, np.asarray(hemesh_infty.faces[~hemesh_infty.is_inf_face]) )
test_mesh_validity(hemesh_infty), (hemesh.is_bdry == (hemesh_infty.is_bdry[:-1] >0)).all()
(True, Array(True, dtype=bool))

Mesh geometry and per-mesh variables

Mesh geometry (vertex and face positions) and per-mesh-item (per-face, per-half-edge, per-vertex) variables are combined into a second data class, the GeomMesh.


source

GeomMesh


def GeomMesh(
    n_vertices:int, n_hes:int, n_faces:int, vertices:Float[Array, '*batch n_vertices dim'],
    face_positions:Float[Array, '*batch n_faces 2']=<factory>, vertex_attribs:dict=<factory>,
    he_attribs:dict=<factory>, face_attribs:dict=<factory>
)->None:

Data class for holding mesh geometry and mesh-associated variables. To be combined with a HeMesh to specify the connectivity.

One array (for vertex positions) must always be present. A second, but optional, standard entry is a set of positions for each face. The mesh coordinates can live in 2d or 3d.

Optionally, vertices, half-edges, and faces can have attributes (stored as dictionaries). The keys of the dictionary should be taken from a suitable ‘enum’. The values are ndarrays, whose 0th axis is (vertices/edges/faces). These attribute dicts are initialized empty and can be set afterwards.

See documentation on HeMesh

Attributes

vertices : Float[jax.Array, “n_vertices 2”]

face_positions : Float[jax.Array, “n_faces 2”]

vertex_attribs : dict[IntEnum, Float[jax.Array, “n_vertices *”]]

he_attribs : dict[IntEnum, Float[jax.Array, “n_hes *”]]

face_attribs : dict[IntEnum, Float[jax.Array, “n_faces *”]]

Property methods (use like attributes)

n_items : tuple[int, int, int]

dim : int

Class methods

validate_dimensions : bool

Static methods

load : str -> GeomHeMesh


source

Mesh


def Mesh(
    args:VAR_POSITIONAL, kwargs:VAR_KEYWORD
):

Combine geometric and connectivity info into a single object.

mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommesh = GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)
combined_mesh = Mesh(geommesh, hemesh)
combined_mesh
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
Mesh(geommesh=GeomMesh(D=2,N_V=131, N_HE=708, N_F=224), hemesh=HeMesh(N_V=131, N_HE=708, N_F=224))
leafs, ts = jax.tree_util.tree_flatten(geommesh) # also a pytree
ts
PyTreeDef(CustomNode(GeomMesh[(131, 708, 224)], [*, *, {}, {}, {}]))
geommesh, geommesh.n_vertices, geommesh.vertices.shape, geommesh.check_compatibility(hemesh), geommesh == geommesh
(GeomMesh(D=2,N_V=131, N_HE=708, N_F=224), 131, (131, 2), True, True)

source

cellplot


def cellplot(
    hemesh:HeMesh, face_positions:Float[Array, 'n_faces 2'], cell_colors:Union=None, mpl_polygon_kwargs:Union=None
)->PatchCollection:

Plot a cell tesselation.

cell_colors can be either a single color (for all cells) or a vector of rgba values. Only interior cells are plotted.

plt.triplot(*geommesh.vertices.T, hemesh.faces)
polygons = cellplot(hemesh, geommesh.face_positions,
                    cell_colors=np.array([0,0,1,0.5]), mpl_polygon_kwargs={"lw": 1, "ec": "k"})
ax = plt.gca()
ax.add_collection(polygons)

plt.axis("equal")
(np.float64(-1.10003475),
 np.float64(1.09628575),
 np.float64(-1.09934025),
 np.float64(1.09050125))

Vertex, half-edge, and face properties

In simulations, we will often want to attach extra information to a mesh’s vertices/edges/faces. In the GeomMesh class, these are saved in three dictionaries, vertex_attribs, he_attribs, face_attribs. Each key/value pair represents one property (for example, the cell target area). All values are arrays, and the first axis corresponds to the number of vertices/half-edges/faces, respectively. To keep track of the possible attributes, we use IntEnum’s as keys (this also ensures keys are hashable, as required by JAX).

mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
geommmesh = GeomMesh(*hemesh.n_items, mesh.vertices, mesh.face_positions)
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
# this is how you set up an enum. It is important to use IntEnum, so we can _order_ the enums.
# The precise Enum you will use depends on your application.

class VertexAttribs(IntEnum):
    TARGET_AREA = 1
    TARGET_PERIMETER = 2

class HeAttribs(IntEnum):
    EDGE_TENSION = 1

class FaceAttribs(IntEnum):
    FACE_AREA = 1
# you can iterate over enums, and they are hashable. The latter is essential for JAX!
print([a for a in VertexAttribs])
# there are multiple ways to access enum entries:
hash(VertexAttribs.TARGET_PERIMETER), HeAttribs.EDGE_TENSION, HeAttribs['EDGE_TENSION'], HeAttribs.EDGE_TENSION.name
[<VertexAttribs.TARGET_AREA: 1>, <VertexAttribs.TARGET_PERIMETER: 2>]
(2, <HeAttribs.EDGE_TENSION: 1>, <HeAttribs.EDGE_TENSION: 1>, 'EDGE_TENSION')
# at initialization, a HeMesh's attribute dictionaries are empty
geommmesh.vertex_attribs
{}
# set some attributes

key1 = jax.random.key(0)
_, key2 = jax.random.split(key1)
_, key3 = jax.random.split(key2)

geommmesh = dataclasses.replace(geommmesh, vertex_attribs={VertexAttribs.TARGET_AREA: jax.random.normal(key=key1,shape=geommmesh.n_vertices),
                                                           VertexAttribs.TARGET_PERIMETER: jax.random.normal(key=key2, shape=geommmesh.n_vertices)})
geommmesh = dataclasses.replace(geommmesh, he_attribs={HeAttribs.EDGE_TENSION: jax.random.normal(key=key3, shape=geommmesh.n_hes)})
geommmesh.he_attribs.keys()
dict_keys([<HeAttribs.EDGE_TENSION: 1>])

Batching

In our simulations, we may want to “batch” over several initial conditions/random seeds/etc. (analogous to batching over training data in normal ML). In JAX, we can efficiently and concisely vectorize operations over such “batch axes” with jax.vmap.

To batch over our custom data structures, we need to pull a small trick - convert a list of HeMesh/GeomMeshe instances into a single mesh with a batch axis for the various arrays. Luckily, this can be done using JAX’s pytree tools. The resulting meshes have an extra “batch” axis in all their array.


source

tree_unstack


def tree_unstack(
    xb:PyTree, axis:int=0
)->list:

Unstack a batched pytree along axis into a list of pytrees.


source

tree_stack


def tree_stack(
    xs:list, axis:int=0
)->PyTree:

Stack a sequence of identical-structure pytrees along a new axis.

## Let us create a bunch of meshes with different initial positions and see if we can batch over them using vmap

key = jax.random.key(0)
sigma = 0.02

batch_geom = []
batch_he = []
for i in range(3):
    key, subkey = jax.random.split(key)
    random_noise = jax.random.normal(subkey, shape=geommmesh.vertices.shape)
    batch_geom.append(dataclasses.replace(geommmesh, vertices=geommmesh.vertices+sigma*random_noise))
    batch_he.append(copy.copy(hemesh))
# define a test function to appy over the batch

def test_function(geommesh: GeomMesh, hemesh: HeMesh) -> Float[jax.Array, " n_vertices"]:
    """Dummy test function."""
    return jnp.ones(geommesh.n_vertices)
# naive batching does not work. JAX needs a "struct-of-arrays", but a list of HeMeshes is an "array-of-structs"
# see https://stackoverflow.com/questions/79123001/storing-and-jax-vmap-over-pytrees

try:
    jax.vmap(test_function)(batch_geom, batch_he)
except ValueError as e:
    print("Expected error:", e)
Expected error: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (21 of them) had size 708, e.g. axis 0 of argument geommesh[0].he_attribs[<HeAttribs.EDGE_TENSION: 1>] of type float64[708];
  * some axes (12 of them) had size 131, e.g. axis 0 of argument geommesh[0].vertices of type float64[131,2];
  * some axes (6 of them) had size 224, e.g. axis 0 of argument geommesh[0].face_positions of type float64[224,2]
# instead, we use a jax.tree.map to "push" the list axis into the underlying arrays.
# the resulting meshes have an extra batch dimension in all of their arrays.

batch_he_array = tree_stack(batch_he)
batch_geom_array = tree_stack(batch_geom)
batch_he_array, batch_geom_array, batch_geom_array.vertices.shape
(HeMesh(N_V=131, N_HE=708, N_F=224),
 GeomMesh(D=2,N_V=131, N_HE=708, N_F=224),
 (3, 131, 2))
# now it works! The result is a single object with batch axis

batch_out =jax.vmap(test_function)(batch_geom_array, batch_he_array)
batch_out.shape
(3, 131)
# we can unpack things again into a list of meshes

isinstance(tree_unstack(batch_out), list)
True

Saving to disk

We save and load TriMesh meshes as standard .obj files (with the hack of using vn lines for the face positions). The HeMesh class is basically a collection of arrays, which we can save to disk using numpy.

from tempfile import TemporaryFile
mesh = TriMesh.read_obj("test_meshes/disk.obj")
hemesh = HeMesh.from_triangles(mesh.vertices.shape[0], mesh.faces)
Warning: readOBJ() ignored non-comment line 3:
  o flat_tri_ecmc
outfile = TemporaryFile()

hemesh.save(outfile)
_ = outfile.seek(0) # simulates closing & reopening file
npzfile = np.load(outfile)
npzfile.files
['incident',
 'orig',
 'dest',
 'twin',
 'nxt',
 'prv',
 'heface',
 'face_incident',
 'inf_vertices']
outfile = TemporaryFile()

hemesh.save(outfile)
_ = outfile.seek(0) # simulates closing & reopening file
reloaded = HeMesh.load(outfile)

np.allclose(reloaded.faces, hemesh.faces)
True