tensormesh.mesh
Mesh
- class Mesh(mesh: meshio.Mesh, reorder: bool = False)
Bases: torch.nn.Module
FEM mesh: interpolation-node coordinates, per-element-type connectivity, and point/cell/field data attached to either. Mixed-element meshes are supported via cells being a BufferDict keyed by element type string (e.g. "triangle", "quad", "tetra").
A "point" throughout the API means an interpolation node (degree of freedom): for order=1 this is a corner vertex of an element; for order>=2 it also includes mid-edge, mid-face, and interior nodes.
- Parameters:
mesh (meshio.Mesh) – A meshio mesh object to wrap.
reorder (bool, default=False) – Whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).
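To make the "point" convention concrete, the pure-Python sketch below promotes an order-1 triangle mesh to order 2 by adding one shared mid-edge node per unique edge, so the point count grows from the number of vertices to vertices plus edges. It does not call TensorMesh: the helper name `p1_to_p2_triangles` and the local order of the mid-edge nodes (edges 0-1, 1-2, 2-0) are illustrative assumptions, and TensorMesh's internal ordering may differ.

```python
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, ...]


def p1_to_p2_triangles(
    points: Sequence[Sequence[float]], triangles: Sequence[Sequence[int]]
) -> Tuple[List[Point], List[List[int]]]:
    """Hypothetical helper: add one shared mid-edge node per unique edge."""
    new_points: List[Point] = [tuple(p) for p in points]
    edge_node: Dict[Tuple[int, int], int] = {}  # sorted edge -> index of its mid-edge node
    p2_cells: List[List[int]] = []
    for tri in triangles:
        mids = []
        # Assumed local edge order: (0, 1), (1, 2), (2, 0).
        for a, b in ((tri[0], tri[1]), (tri[1], tri[2]), (tri[2], tri[0])):
            key = (min(a, b), max(a, b))
            if key not in edge_node:  # an edge shared by two triangles gets a single node
                midpoint = tuple((x + y) / 2 for x, y in zip(points[a], points[b]))
                new_points.append(midpoint)
                edge_node[key] = len(new_points) - 1
            mids.append(edge_node[key])
        p2_cells.append(list(tri) + mids)  # 3 corner nodes followed by 3 mid-edge nodes
    return new_points, p2_cells


# Unit square split into two triangles: 4 vertices and 5 unique edges.
square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
pts, cells = p1_to_p2_triangles(square, [[0, 1, 2], [0, 2, 3]])
print(len(pts))  # interpolation nodes at order 2
```

Each order-2 triangle then carries B = 6 basis functions, which is the second dimension of the \([|\mathcal C|, B]\) connectivity stored in cells.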
- points
2D tensor of shape \([|\mathcal V|, D]\), where \(|\mathcal V|\) is the number of interpolation nodes and \(D\) is the spatial dimension. Includes high-order nodes (mid-edge / mid-face / interior) when the mesh uses order >= 2 elements.
- Type:
Tensor
- cells
Each key is an element_type string (see tensormesh.element); the value is a 2D tensor of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions.
- Type:
BufferDict[str, Tensor]
- point_data
Per-point fields, keyed by name.
- Type:
BufferDict[str, Tensor], optional
- cell_data
Per-element fields. The outer key is an element_type; the inner key is the field name.
- Type:
ModuleDict[str, BufferDict[str, Tensor]], optional
- field_data
Global named fields.
- Type:
BufferDict[str, Tensor], optional
- dim2eletyp
Each key is a spatial dimension, and the value is a list of element types of that dimension present in the mesh.
- default_eletyp
The default element type: a single string for homogeneous meshes, a list of strings for mixed-element meshes. Exposed publicly via the default_element_type property.
- __init__(mesh: meshio.Mesh, reorder: bool = False)
Construct the mesh from a meshio.Mesh; the parameters are documented on the class above.
- register_point_data(key: str, value: Tensor)
Register a per-point field on point_data. point_data is a tensormesh.nn.BufferDict, so prefer this method over __setitem__ to make sure the tensor is tracked as a buffer of the underlying torch.nn.Module.
- register_element_data(key: str, value: Dict[str, Tensor] | Tensor)
Register a per-element field on cell_data.
For homogeneous meshes value may be a single tensor; for mixed-element meshes pass a dict keyed by element type with one tensor per type.
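The rule above (a bare tensor for homogeneous meshes, a dict keyed by element type for mixed meshes) can be sketched in plain Python. This is a hypothetical illustration, not TensorMesh's code: `normalize_element_data` and `register_element_data_sketch` are invented names, lists stand in for tensors, nested dicts stand in for cell_data, and the per-type length check is an assumption about what validation might look like.

```python
from typing import Dict, List, Sequence, Union

Field = List[float]  # stand-in for a per-element tensor


def normalize_element_data(
    value: Union[Field, Dict[str, Field]],
    cells: Dict[str, Sequence[Sequence[int]]],
) -> Dict[str, Field]:
    """Hypothetical helper: map `value` to one field per element type."""
    if isinstance(value, dict):
        per_type = dict(value)
    elif len(cells) == 1:
        # Homogeneous mesh: a bare field belongs to the only element type.
        (only_type,) = cells.keys()
        per_type = {only_type: value}
    else:
        raise ValueError("mixed-element mesh: pass a dict keyed by element type")
    for etype, field in per_type.items():
        if etype not in cells:
            raise KeyError(f"unknown element type {etype!r}")
        if len(field) != len(cells[etype]):  # assumed check: one value per element
            raise ValueError(f"{etype}: expected {len(cells[etype])} values, got {len(field)}")
    return per_type


def register_element_data_sketch(cell_data, key, value, cells):
    # Documented layout: outer key is the element type, inner key is the field name.
    for etype, field in normalize_element_data(value, cells).items():
        cell_data.setdefault(etype, {})[key] = field


cells = {"triangle": [[0, 1, 2], [0, 2, 3]], "quad": [[1, 4, 5, 2]]}
cell_data: Dict[str, Dict[str, Field]] = {}
register_element_data_sketch(cell_data, "material", {"triangle": [1.0, 2.0], "quad": [3.0]}, cells)
```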
- to_meshio(reorder: bool = False) → meshio.Mesh
Export this mesh as an in-memory meshio.Mesh.
- Parameters:
reorder (bool, default=False) – If True, convert connectivity from the internal ordering back to Gmsh/VTK ordering before returning (delegates to tensormesh.Element.reorder()).
- Returns:
The meshio mesh object.
- Return type:
meshio.Mesh
- save(file_name: str, file_format: str | None = None)
Write this mesh to disk via meshio.write.
Boolean point/cell/field arrays are cast to float before writing (meshio does not support bool). For .vtk/.vtu outputs, 2-D meshes are padded to 3-D and connectivity is reordered to the Gmsh/VTK convention.
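The two data-preparation steps described above can be sketched in plain Python: boolean arrays become floats (True to 1.0, False to 0.0), and 2-D coordinates gain a zero z-component so VTK receives 3-D points. The helper name `prepare_for_vtk` and the list-based data are illustrative assumptions; the actual method works on tensors and also reorders connectivity, which this sketch omits.

```python
from typing import Dict, List, Sequence, Tuple


def prepare_for_vtk(
    points: Sequence[Sequence[float]], point_data: Dict[str, list]
) -> Tuple[List[List[float]], Dict[str, list]]:
    """Hypothetical helper mirroring the documented pre-write steps."""
    # Pad 2-D coordinates to 3-D by appending z = 0.0; 3-D points pass through unchanged.
    padded = [list(p) + [0.0] * (3 - len(p)) for p in points]
    casted: Dict[str, list] = {}
    for name, values in point_data.items():
        if values and all(isinstance(v, bool) for v in values):
            casted[name] = [float(v) for v in values]  # meshio cannot write bool arrays
        else:
            casted[name] = list(values)
    return padded, casted


pts, data = prepare_for_vtk(
    [(0.0, 0.0), (1.0, 0.5)], {"is_boundary": [True, False], "u": [0.1, 0.2]}
)
```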
- to_file(file_name: str, file_format: str | None = None)
Write this mesh to disk via meshio.write.
Boolean point/cell/field arrays are cast to float before writing (meshio does not support bool). For .vtk/.vtu outputs, 2-D meshes are padded to 3-D and connectivity is reordered to the Gmsh/VTK convention.
- node_adjacency(element_type: str | Iterable[str] | None = None) → SparseMatrix
Get the node adjacency matrix. Within each element the nodes are treated as fully connected, so two nodes are adjacent whenever they belong to a common element.
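A minimal pure-Python sketch of that rule, returning the nonzero pattern as (row, column) pairs instead of a SparseMatrix. The function name is hypothetical, and since whether TensorMesh stores the diagonal is not stated here, the sketch leaves it out. Passing one connectivity list per element type mirrors the Iterable[str] form of element_type.

```python
from itertools import permutations
from typing import List, Sequence, Set, Tuple


def node_adjacency_pattern(
    connectivities: Sequence[Sequence[Sequence[int]]],
) -> List[Tuple[int, int]]:
    """Sparsity pattern (row, col) of a node adjacency matrix, without the diagonal."""
    pairs: Set[Tuple[int, int]] = set()
    for cells in connectivities:  # one connectivity array per element type
        for element in cells:
            # Every ordered pair of distinct nodes in the same element is connected.
            for i, j in permutations(set(element), 2):
                pairs.add((i, j))
    return sorted(pairs)


# Two triangles sharing the edge (1, 2): nodes 0 and 3 never share an element.
pattern = node_adjacency_pattern([[[0, 1, 2], [1, 3, 2]]])
```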
- element_adjacency(element_type: str | None = None) → SparseMatrix
Get the element adjacency matrix. Two elements are considered adjacent only if they share a whole facet (an edge in 2-D, a face in 3-D); sharing a single node is not enough.
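Facet-based adjacency can be sketched by keying every facet on its node set and connecting the elements that own the same facet. This is an illustrative pure-Python version for linear triangles; the local facet table and the function name are assumptions, and for higher-order elements one would key facets on their corner nodes only.

```python
from collections import defaultdict
from typing import Dict, FrozenSet, List, Sequence, Set, Tuple

# Assumed local facet table for a linear triangle: its three edges.
TRIANGLE_FACETS: Tuple[Tuple[int, int], ...] = ((0, 1), (1, 2), (2, 0))


def element_adjacency_pairs(
    cells: Sequence[Sequence[int]], local_facets=TRIANGLE_FACETS
) -> Set[Tuple[int, int]]:
    """Ordered pairs (e1, e2) of distinct elements that share a whole facet."""
    owners: Dict[FrozenSet[int], List[int]] = defaultdict(list)
    for e, element in enumerate(cells):
        for facet in local_facets:
            # Key each facet by its node set so orientation does not matter.
            owners[frozenset(element[k] for k in facet)].append(e)
    pairs: Set[Tuple[int, int]] = set()
    for elems in owners.values():
        for a in elems:
            for b in elems:
                if a != b:
                    pairs.add((a, b))
    return pairs


# T0 and T1 share edge {1, 2}; T1 and T2 share edge {2, 3}; T0 and T2 share only node 2.
adj = element_adjacency_pairs([[0, 1, 2], [1, 3, 2], [3, 4, 2]])
```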
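- partition(n_parts: int, method: str = 'spectral', element_type: str | None = None) → Tensor
Partition the mesh elements into n_parts subdomains.
- Parameters:
n_parts (int) – the number of partitions to create
method (str, default='spectral') – the partitioning algorithm (the same names as in graph_partition(), e.g. 'spectral' or 'metis')
element_type (str or None) – the element type to partition; if None, the default_element_type is used
- Returns:
IntTensor of shape [n_elements] containing the partition ID of each element
- Return type:
Tensor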
- color(element_type: str | None = None) → Tensor
Color the mesh elements such that no adjacent elements share the same color.
- elements(element_type: int | str | Iterable[str] | None = None) → Tensor | Dict[str, Tensor]
Get the element connectivity for specified element types.
Examples
Get elements of default type:
```python
import tensormesh

mesh = tensormesh.Mesh.gen_rectangle()
elements = mesh.elements()  # Returns tensor of shape [n_elements, n_basis]
```
Get elements of specific type:
```python
elements = mesh.elements("tri6")  # Returns tensor for triangle elements
```
Get elements of multiple types:
```python
elements = mesh.elements(["tri6", "quad9"])  # Returns dict of tensors
```
Get all element types:
```python
elements = mesh.elements("all")  # Returns dict of all element tensors
```
Get elements of specific dimension:
```python
# Get all 2D elements (triangles, quads)
elements = mesh.elements(2)  # Returns dict of 2D element tensors

# Get all elements matching mesh dimension
elements = mesh.elements(-1)  # Same as mesh.elements(mesh.dim)
```
- Parameters:
element_type (str or Iterable[str] or int or None) – the type of the elements:
  - if "all", return a dict of all elements
  - if int, return a dict of the elements of that dimension (-1 selects the mesh dimension)
  - if any other str, return the elements of that type
  - if Iterable[str], return the elements of those types
  - if None, use default_eletyp (default)
- Returns:
  - if element_type is a str, the connectivity of the corresponding elements, of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions
  - if element_type is an int, a dict of the elements of that dimension
  - if element_type is an Iterable[str], a dict mapping each type to its connectivity of shape \([|\mathcal C|, B]\)
  - if element_type is None, default_element_type is used and the rules above apply
  - if element_type is "all", all elements as a dict
- Return type:
Tensor or Dict[str, Tensor]
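The dispatch rules listed above can be restated as a small pure-Python function. This is a hypothetical sketch, not TensorMesh's code: lists stand in for connectivity tensors, the element_dim table mapping each type to its dimension is assumed, and the meaning of -1 is taken from the examples. Note that "all" has to be tested before the generic str case.

```python
from typing import Dict, Iterable, List, Union

Conn = List[List[int]]  # stand-in for a connectivity tensor


def select_elements(
    cells: Dict[str, Conn],
    element_dim: Dict[str, int],
    mesh_dim: int,
    default: Union[str, List[str]],
    element_type: Union[int, str, Iterable[str], None] = None,
) -> Union[Conn, Dict[str, Conn]]:
    """Hypothetical re-statement of the documented dispatch rules."""
    if element_type is None:
        element_type = default  # fall back to the default element type(s)
    if isinstance(element_type, int):
        dim = mesh_dim if element_type == -1 else element_type
        return {t: c for t, c in cells.items() if element_dim[t] == dim}
    if element_type == "all":  # must be checked before the generic str branch
        return dict(cells)
    if isinstance(element_type, str):
        return cells[element_type]
    return {t: cells[t] for t in element_type}


cells = {"triangle": [[0, 1, 2]], "quad": [[1, 3, 4, 2]], "line": [[0, 1]]}
element_dim = {"triangle": 2, "quad": 2, "line": 1}
two_d = select_elements(cells, element_dim, mesh_dim=2, default="triangle", element_type=2)
```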
- clone() → Mesh
Return a deep copy of the mesh that preserves the autograd graph.
The connectivity and metadata are reconstructed by round-tripping through meshio, while gradients flowing through points / cell_data are preserved rather than cut off.
- Returns:
The cloned mesh.
- Return type:
Mesh
- plot(values: Dict[str, Tensor] | Dict[str, Iterable[Tensor]] | None = None, save_path: str | None = None, dt: float | None = None, show_mesh: bool = False, fix_clim: bool = False, show: bool = False, **kwargs)
Plot the mesh, optionally overlaying scalar fields or animations.
With no values, only the mesh wireframe is drawn. Passing Dict[str, torch.Tensor] produces a static multi-panel figure; passing Dict[str, List[torch.Tensor]] produces an mp4/gif animation (one frame per list element).
- Parameters:
values (None or Dict[str, Tensor] or Dict[str, List[Tensor]]) – the values to plot, default: None
  - if None, only the mesh is plotted
  - if Dict[str, torch.Tensor], a static figure with one subplot per key is plotted; the key is the name of the subplot and the value has shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes
  - if Dict[str, List[torch.Tensor]], an mp4/gif animation is produced; the key is the name of the subplot and each list item has shape \([|\mathcal V|]\)
save_path (str or None) – the path to save the plot; if None, the plot is not saved. If values is a Dict[str, List[torch.Tensor]], save_path must end with '.mp4' or '.gif'. default: None
dt (float or None) – the time interval between frames; only used when values is a Dict[str, List[torch.Tensor]]. default: None
show_mesh (bool) – whether to overlay the mesh wireframe (and, at order >= 2, the interpolation nodes) on top of the color-filled field. Only takes effect when values is given. default: False
fix_clim (bool) – whether to fix the color limits across all frames; only used when values is a Dict[str, List[torch.Tensor]]. If True, the color limits are set from the global min and max across all frames, giving a consistent colorbar throughout the animation. default: False
show (bool) – whether to display the plot interactively (e.g., via matplotlib.pyplot.show()). default: False
**kwargs – additional keyword arguments passed to the underlying visualization functions
- property n_points: int
Number of interpolation nodes \(|\mathcal V|\) in the mesh.
Equals mesh.points.shape[0]. For order >= 2 this counts high-order nodes (mid-edge, mid-face, interior) as well, not only corner vertices.
- Returns:
the number of interpolation nodes \(|\mathcal V|\)
- Return type:
int
- property n_elements: int
Number of elements \(|\mathcal C|\) of the default_element_type.
For mixed-element meshes this sums the element counts across every type in default_element_type.
- Returns:
the number of elements \(|\mathcal C|\)
- Return type:
int
- property boundary_mask: Tensor
Boolean mask flagging boundary points.
Looked up from point_data under the key "is_boundary" (preferred) or "boundary_mask". Mesh generators in tensormesh.dataset populate this automatically.
- Returns:
1D bool tensor of shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes; requires "is_boundary" or "boundary_mask" to be present in point_data
- Return type:
Tensor
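The property only reads a stored mask. One common way such a mask can be computed (not necessarily what tensormesh.dataset does) is to flag every node lying on a facet owned by exactly one element. The pure-Python sketch below does this for linear triangles; the function name is hypothetical, and for order >= 2 meshes the mid-edge nodes on boundary facets would also need flagging, which this sketch ignores.

```python
from collections import Counter
from typing import List, Sequence


def boundary_mask_from_triangles(n_points: int, triangles: Sequence[Sequence[int]]) -> List[bool]:
    """Hypothetical helper: True for nodes on edges owned by a single triangle."""
    # Count how many triangles own each unordered edge.
    edge_count = Counter(
        frozenset((tri[a], tri[b])) for tri in triangles for a, b in ((0, 1), (1, 2), (2, 0))
    )
    mask = [False] * n_points
    for edge, count in edge_count.items():
        if count == 1:  # an edge with a single owner lies on the boundary
            for node in edge:
                mask[node] = True
    return mask


# Unit square with a centre node 4, split into four triangles.
fan = [[0, 1, 4], [1, 2, 4], [2, 3, 4], [3, 0, 4]]
mask = boundary_mask_from_triangles(5, fan)
```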
- property dtype: torch.dtype
Floating-point dtype of points (and, by convention, of every floating-point buffer in the mesh).
- Returns:
the data type of the points, e.g., torch.float32, torch.float64
- Return type:
torch.dtype
- property device: torch.device
Device on which the mesh tensors live.
- Returns:
the device of the points, e.g., torch.device("cpu"), torch.device("cuda:0")
- Return type:
torch.device
- classmethod from_meshio(mesh: meshio.Mesh, reorder: bool = False)
Build a Mesh from an in-memory meshio.Mesh.
- Parameters:
mesh (meshio.Mesh) – a meshio mesh object
reorder (bool) – whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).
- Returns:
the mesh object
- Return type:
Mesh
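Connectivity reordering amounts to permuting the columns of each connectivity array, one permutation per element type; exporting with to_meshio(reorder=True) conceptually applies the inverse permutation. The pure-Python sketch below shows that round trip. The permutation is HYPOTHETICAL and chosen only for illustration; the actual orderings are handled by tensormesh.Element.reorder().

```python
from typing import List, Sequence


def permute_columns(cells: Sequence[Sequence[int]], perm: Sequence[int]) -> List[List[int]]:
    # New local node i is old local node perm[i].
    return [[row[p] for p in perm] for row in cells]


def inverse_permutation(perm: Sequence[int]) -> List[int]:
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


# HYPOTHETICAL permutation for a 6-node triangle, not TensorMesh's actual table.
perm = [0, 1, 2, 5, 3, 4]
gmsh_cells = [[10, 11, 12, 13, 14, 15]]
internal = permute_columns(gmsh_cells, perm)  # import direction (reorder=True on read)
back = permute_columns(internal, inverse_permutation(perm))  # export direction
```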
- classmethod read(file_name: str, file_format: str | None = None, reorder: bool = False)
Read a mesh from disk via meshio.read.
- Parameters:
file_name (str) – the name of the file
file_format (str or None) – the format of the file, e.g. 'msh', 'vtk', 'obj'; if None, it is inferred from the file extension
reorder (bool) – whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).
- Returns:
the mesh object
- Return type:
Mesh
- classmethod from_file(file_name: str, file_format: str | None = None, reorder: bool = False)
Read a mesh from disk via meshio.read.
- Parameters:
file_name (str) – the name of the file
file_format (str or None) – the format of the file, e.g. 'msh', 'vtk', 'obj'; if None, it is inferred from the file extension
reorder (bool) – whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).
- Returns:
the mesh object
- Return type:
Mesh
- static gen_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 2-D mesh of an axis-aligned rectangle.
Delegates to gen_rectangle(), which calls Gmsh under the hood and caches the result if cache_path is given.
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
element_type (str, optional) – the type of the element, default: "tri"
left (float, optional) – the left boundary of the rectangle, default: 0.0
right (float, optional) – the right boundary of the rectangle, default: 1.0
bottom (float, optional) – the bottom boundary of the rectangle, default: 0.0
top (float, optional) – the top boundary of the rectangle, default: 1.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_rectangle(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_hollow_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 2-D mesh of a rectangle with a rectangular hole cut out.
Delegates to gen_hollow_rectangle().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
element_type (str, optional) – the type of the element, default: "quad"
outer_left (float, optional) – the left boundary of the outer rectangle, default: 0.0
outer_right (float, optional) – the right boundary of the outer rectangle, default: 1.0
outer_bottom (float, optional) – the bottom boundary of the outer rectangle, default: 0.0
outer_top (float, optional) – the top boundary of the outer rectangle, default: 1.0
inner_left (float, optional) – the left boundary of the inner rectangle, default: 0.25
inner_right (float, optional) – the right boundary of the inner rectangle, default: 0.75
inner_bottom (float, optional) – the bottom boundary of the inner rectangle, default: 0.25
inner_top (float, optional) – the top boundary of the inner rectangle, default: 0.75
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_hollow_rectangle(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', cx: float = 0.0, cy: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 2-D mesh of a disk (filled circle).
Delegates to gen_circle().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
element_type (str, optional) – the type of the element, default: "tri"
cx (float, optional) – the x coordinate of the center of the circle, default: 0.0
cy (float, optional) – the y coordinate of the center of the circle, default: 0.0
r (float, optional) – the radius of the circle, default: 1.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_circle(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_hollow_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', cx: float = 0.0, cy: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 2-D mesh of an annulus (disk with a circular hole).
Delegates to gen_hollow_circle().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
element_type (str, optional) – the type of the element, default: "quad"
cx (float, optional) – the x coordinate of the center of the circle, default: 0.0
cy (float, optional) – the y coordinate of the center of the circle, default: 0.0
r_inner (float, optional) – the inner radius of the circle, default: 1.0
r_outer (float, optional) – the outer radius of the circle, default: 2.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_hollow_circle(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_L(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, top_inner: float = 0.5, right_inner: float = 0.5, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 2-D mesh of an L-shaped domain (re-entrant corner benchmark).
Delegates to gen_L().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
element_type (str, optional) – the type of the element, default: "quad"
left (float, optional) – the left boundary of the rectangle, default: 0.0
right (float, optional) – the right boundary of the rectangle, default: 1.0
bottom (float, optional) – the bottom boundary of the rectangle, default: 0.0
top (float, optional) – the top boundary of the rectangle, default: 1.0
top_inner (float, optional) – the top inner boundary of the rectangle, default: 0.5
right_inner (float, optional) – the right inner boundary of the rectangle, default: 0.5
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_L(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_cube(chara_length: float = 0.1, order: int = 1, left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, front: float = 0.0, back: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 3-D mesh of an axis-aligned cuboid.
Delegates to gen_cube().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
left (float, optional) – the left boundary of the cube, default: 0.0
right (float, optional) – the right boundary of the cube, default: 1.0
bottom (float, optional) – the bottom boundary of the cube, default: 0.0
top (float, optional) – the top boundary of the cube, default: 1.0
front (float, optional) – the front boundary of the cube, default: 0.0
back (float, optional) – the back boundary of the cube, default: 1.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_cube(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_hollow_cube(chara_length: float = 0.1, order: int = 1, outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, outer_front: float = 0.0, outer_back: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, inner_front: float = 0.25, inner_back: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 3-D mesh of a cuboid with a cuboidal hole.
Delegates to gen_hollow_cube().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
outer_left (float, optional) – the left boundary of the outer cube, default: 0.0
outer_right (float, optional) – the right boundary of the outer cube, default: 1.0
outer_bottom (float, optional) – the bottom boundary of the outer cube, default: 0.0
outer_top (float, optional) – the top boundary of the outer cube, default: 1.0
outer_front (float, optional) – the front boundary of the outer cube, default: 0.0
outer_back (float, optional) – the back boundary of the outer cube, default: 1.0
inner_left (float, optional) – the left boundary of the inner cube, default: 0.25
inner_right (float, optional) – the right boundary of the inner cube, default: 0.75
inner_bottom (float, optional) – the bottom boundary of the inner cube, default: 0.25
inner_top (float, optional) – the top boundary of the inner cube, default: 0.75
inner_front (float, optional) – the front boundary of the inner cube, default: 0.25
inner_back (float, optional) – the back boundary of the inner cube, default: 0.75
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_hollow_cube(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 3-D mesh of a solid ball (filled sphere).
Delegates to gen_sphere().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
cx (float, optional) – the x coordinate of the center of the sphere, default: 0.0
cy (float, optional) – the y coordinate of the center of the sphere, default: 0.0
cz (float, optional) – the z coordinate of the center of the sphere, default: 0.0
r (float, optional) – the radius of the sphere, default: 1.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_sphere(), default: None
- Returns:
the mesh object
- Return type:
Mesh
- static gen_hollow_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh
Generate a 3-D mesh of a spherical shell (ball with a concentric spherical cavity).
Delegates to gen_hollow_sphere().
- Parameters:
chara_length (float, optional) – the characteristic length of the mesh, default: 0.1
order (int, optional) – the order of the basis functions, default: 1
cx (float, optional) – the x coordinate of the center of the sphere, default: 0.0
cy (float, optional) – the y coordinate of the center of the sphere, default: 0.0
cz (float, optional) – the z coordinate of the center of the sphere, default: 0.0
r_inner (float, optional) – the inner radius of the sphere, default: 1.0
r_outer (float, optional) – the outer radius of the sphere, default: 2.0
visualize (bool, optional) – whether to visualize the mesh, default: False
cache_path (str, optional) – the path to save the mesh; if None, it is decided by gen_hollow_sphere(), default: None
- Returns:
the mesh object
- Return type:
Mesh
Graph algorithms
Helpers used internally by DistributedMesh for race-free parallel assembly and domain decomposition.
- graph_coloring(adjacency: SparseMatrix, max_iter: int = 100) → Tensor
Parallel graph coloring by iterative conflict resolution, designed to run efficiently on GPU.
- Parameters:
adjacency (SparseMatrix) – The adjacency matrix of the graph, of shape [n_nodes, n_nodes].
max_iter (int) – Maximum number of conflict-resolution iterations.
- Returns:
IntTensor of shape [n_nodes] containing the color ID of each node.
- Return type:
Tensor
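A sequential pure-Python sketch of iterative conflict resolution: in each round every uncolored node tentatively takes the smallest color not used by its already-colored neighbors (reading one shared snapshot, as parallel threads would), then, for each pair of neighbors that picked the same color in that round, the higher-indexed node gives it back and retries. The adjacency is a list of neighbor lists rather than a SparseMatrix, and the tie-break by node index is an assumption; TensorMesh's GPU version may use a different priority.

```python
from typing import List, Sequence


def color_graph(adj: Sequence[Sequence[int]], max_iter: int = 100) -> List[int]:
    """Speculative coloring with conflict resolution; -1 marks nodes still uncolored."""
    n = len(adj)
    colors = [-1] * n
    for _ in range(max_iter):
        uncolored = [v for v in range(n) if colors[v] < 0]
        if not uncolored:
            break
        # Phase 1: every uncolored node picks, against the same snapshot, the smallest
        # color not used by its already-colored neighbors.
        snapshot = list(colors)
        for v in uncolored:
            used = {snapshot[u] for u in adj[v]}
            c = 0
            while c in used:
                c += 1
            colors[v] = c
        # Phase 2: neighbors that picked the same color this round conflict; the
        # higher-indexed one is reset and retries in the next round.
        proposed = list(colors)
        for v in uncolored:
            if any(proposed[u] == proposed[v] and u < v for u in adj[v]):
                colors[v] = -1
    return colors


# A 5-cycle is an odd cycle, so it needs three colors.
cycle = [[(i - 1) % 5, (i + 1) % 5] for i in range(5)]
colors = color_graph(cycle)
```

In this sketch the lowest-indexed uncolored node always keeps its pick, so at least one node is finalized per round and max_iter >= n_nodes guarantees a complete coloring.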
- graph_partition(adjacency: SparseMatrix, n_parts: int, method: str = 'spectral') → Tensor
Partition a graph into balanced subdomains using GPU-accelerated algorithms.
This function divides the graph nodes into n_parts groups such that:
  - each partition has approximately the same number of nodes (load balance);
  - the number of edges crossing partitions is minimized (minimal interface).
- Parameters:
adjacency (SparseMatrix) – The adjacency matrix of the graph, of shape [n_nodes, n_nodes]. Should be symmetric for undirected graphs.
n_parts (int) – Number of partitions to create. The spectral method works best when n_parts is a power of 2.
method (str, optional) – Partitioning algorithm to use:
  - 'spectral': Recursive Spectral Bisection using the Fiedler vector, computed via torch.lobpcg on GPU. Default method.
  - 'metis': Uses the pymetis library (requires pymetis to be installed). Falls back to spectral if pymetis is not available.
- Returns:
Integer tensor of shape [n_nodes] containing partition labels in the range [0, n_parts-1].
- Return type:
Tensor
Notes
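The spectral method computes the Fiedler vector (the eigenvector associated with the second-smallest eigenvalue of the graph Laplacian \(L = D - A\), where \(D\) is the degree matrix and \(A\) the adjacency matrix) and recursively bisects the graph at the median of its entries. This preserves locality in the partitioning.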
For small subgraphs (< 2048 nodes), dense eigensolvers are used for stability. For larger graphs, LOBPCG is used for efficiency.
Examples
```python
>>> import tensormesh
>>> from tensormesh.mesh import graph_partition
>>> mesh = tensormesh.Mesh.gen_rectangle()
>>> # Partition element adjacency into 4 parts
>>> adj = mesh.element_adjacency()
>>> labels = graph_partition(adj, n_parts=4, method='spectral')
>>> print(f"Partition sizes: {[(labels == i).sum().item() for i in range(4)]}")
```
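For small graphs, the recursive spectral bisection described in the Notes can be sketched in pure Python. As a stand-in for torch.lobpcg or a dense eigensolver, this sketch uses plain power iteration on the shifted matrix \(cI - L\) with the constant vector projected out, and it omits the pymetis path. The function names are hypothetical and the adjacency is a list of neighbor lists.

```python
import math
from typing import List, Sequence


def fiedler_vector(adj: Sequence[Sequence[int]], iters: int = 2000) -> List[float]:
    """Approximate Fiedler vector of L = D - A by power iteration on (cI - L)."""
    n = len(adj)
    if n < 2:
        return [0.0] * n
    # c > 2 * max degree >= lambda_max(L), so cI - L is positive definite and its
    # dominant eigenvector orthogonal to the constant vector is the Fiedler vector.
    c = 2.0 * max(len(nbrs) for nbrs in adj) + 1.0
    mean = (n - 1) / 2.0
    x = [i - mean for i in range(n)]  # deterministic start, orthogonal to all-ones
    for _ in range(iters):
        # y = (cI - L) x, with (L x)_i = deg_i * x_i - sum of x_j over neighbors j
        y = [c * x[i] - (len(adj[i]) * x[i] - sum(x[j] for j in adj[i])) for i in range(n)]
        shift = sum(y) / n
        y = [v - shift for v in y]  # remove the constant eigenvector (eigenvalue 0 of L)
        norm = math.sqrt(sum(v * v for v in y))
        if norm == 0.0:
            break
        x = [v / norm for v in y]
    return x


def spectral_partition(adj: Sequence[Sequence[int]], n_parts: int) -> List[int]:
    """Recursive spectral bisection; n_parts must be a power of two in this sketch."""
    if n_parts < 1 or n_parts & (n_parts - 1):
        raise ValueError("this sketch only supports n_parts = 1, 2, 4, 8, ...")
    labels = [0] * len(adj)

    def bisect(nodes: List[int], first_label: int, parts: int) -> None:
        if parts == 1:
            for v in nodes:
                labels[v] = first_label
            return
        # Induced subgraph on `nodes`, re-indexed to 0..len(nodes)-1.
        index = {v: k for k, v in enumerate(nodes)}
        sub = [[index[u] for u in adj[v] if u in index] for v in nodes]
        f = fiedler_vector(sub)
        order = sorted(range(len(nodes)), key=lambda k: f[k])
        half = len(nodes) // 2  # median split keeps the two halves balanced
        bisect([nodes[k] for k in order[:half]], first_label, parts // 2)
        bisect([nodes[k] for k in order[half:]], first_label + parts // 2, parts // 2)

    bisect(list(range(len(adj))), 0, n_parts)
    return labels


# Path graph 0-1-2-...-7: each cut should fall in the middle of the current piece.
path = [[j for j in (i - 1, i + 1) if 0 <= j < 8] for i in range(8)]
labels = spectral_partition(path, 2)
```

Because the sign of an eigenvector is arbitrary, which half receives label 0 is not meaningful; only the grouping of nodes is.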