tensormesh.mesh

Mesh

class Mesh(mesh: meshio.Mesh, reorder: bool = False)

Bases: Module

FEM mesh — interpolation-node coordinates, per-element-type connectivity, and point/cell/field data attached to either. Mixed-element meshes are supported via cells being a BufferDict keyed by element type string (e.g. "triangle", "quad", "tetra").

Throughout the API, a “point” means an interpolation node (degree of freedom). For order=1 elements the points are the element corner vertices; for order>=2 they also include mid-edge, mid-face, and interior nodes.
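The following plain-Python sketch (not tensormesh code) shows how an order-1 triangle mesh gains points when it is promoted to order 2. Every unique edge gets one mid-edge node, and both triangles on that edge share it. The helper name promote_to_p2 is hypothetical. The local node order it produces (corners, then the mid-nodes of edges 0-1, 1-2, 2-0) follows the Gmsh/VTK 6-node triangle convention; TensorMesh's internal ordering may differ (see reorder).

```python
def promote_to_p2(points, triangles):
    """Hypothetical helper: add one mid-edge node per unique edge.

    points: list of [x, y] corner coordinates; triangles: list of 3-node elements.
    Returns the enlarged point list and 6-node triangles (corners, then mid-edges).
    """
    points = [list(p) for p in points]
    midpoint_index = {}  # sorted corner pair -> index of its mid-edge node
    tri6 = []
    for a, b, c in triangles:
        mids = []
        for u, v in ((a, b), (b, c), (c, a)):
            key = (min(u, v), max(u, v))
            if key not in midpoint_index:
                # First time this edge is seen: create its mid-edge node.
                midpoint_index[key] = len(points)
                points.append([(x + y) / 2 for x, y in zip(points[u], points[v])])
            mids.append(midpoint_index[key])
        tri6.append([a, b, c] + mids)
    return points, tri6
```

For a unit square split into two triangles, this turns 4 points into 9: the 4 corners plus one node for each of the 5 unique edges, with the shared diagonal's mid-node appearing in both elements.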

Parameters:
  • mesh (meshio.Mesh) – A meshio mesh object to wrap.

  • reorder (bool, default=False) – Whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

points

2D tensor of shape \([|\mathcal V|, D]\), where \(|\mathcal V|\) is the number of interpolation nodes and \(D\) is the spatial dimension. Includes high-order nodes (mid-edge / mid-face / interior) when the mesh uses order >= 2 elements.

Type:

Tensor

cells

Each key is an element_type string (see tensormesh.element); the value is a 2D tensor of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions.

Type:

BufferDict[str, Tensor]
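The shapes of points and cells can be illustrated with a hypothetical plain-Python helper. It is not part of tensormesh, which builds its meshes with Gmsh; it simply triangulates the unit square on a structured n-by-n grid. The points list holds one [x, y] row per node (shape [|V|, 2]), and the cells dict maps each element-type key to a table with one row of B = 3 node indices per linear triangle (shape [|C|, 3]).

```python
def unit_square_tri_mesh(n):
    """Hypothetical helper: split the unit square into n*n squares, two triangles each."""
    h = 1.0 / n
    # One row per node, ordered row by row: shape [(n + 1) ** 2, 2].
    points = [[i * h, j * h] for j in range(n + 1) for i in range(n + 1)]

    def idx(i, j):
        # Index of the grid node in column i, row j.
        return j * (n + 1) + i

    tris = []
    for j in range(n):
        for i in range(n):
            a, b, c, d = idx(i, j), idx(i + 1, j), idx(i + 1, j + 1), idx(i, j + 1)
            tris += [[a, b, c], [a, c, d]]
    # One connectivity table per element type, shape [2 * n * n, 3].
    cells = {"triangle": tris}
    return points, cells
```

With n=2 this yields 9 points and 8 triangles; in a Mesh the same data would live in the points tensor and in cells["triangle"].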

point_data

Per-point fields, keyed by name.

Type:

BufferDict[str, Tensor], optional

cell_data

Per-element fields. The outer key is an element_type; the inner key is the field name.

Type:

ModuleDict[str, BufferDict[str, Tensor]], optional

field_data

Global named fields.

Type:

BufferDict[str, Tensor], optional

cell_sets

Named subsets of cells, kept in meshio’s native format.

Type:

dict, optional

dim2eletyp

Maps each element dimension (e.g. 2 for triangles and quadrilaterals, 3 for tetrahedra) to the list of element types of that dimension present in the mesh.

Type:

Dict[int, List[str]]

default_eletyp

The default element type — a single string for homogeneous meshes, a list of strings for mixed-element meshes. Exposed publicly via the default_element_type property.

Type:

str or List[str]

__init__(mesh: meshio.Mesh, reorder: bool = False)

Build the mesh from a meshio.Mesh. The parameters mesh and reorder are described in the class documentation above.

register_point_data(key: str, value: Tensor)

Register a per-point field on point_data.

point_data is a tensormesh.nn.BufferDict, so prefer this method over __setitem__ to make sure the tensor is tracked as a buffer of the underlying torch.nn.Module.

Parameters:
  • key (str) – the name of the field

  • value (Tensor) – tensor of shape \([|\mathcal V|, ...]\), where \(|\mathcal V|\) is the number of interpolation nodes (mesh.n_points)

Returns:

self, so calls can be chained

Return type:

Mesh

register_element_data(key: str, value: Dict[str, Tensor] | Tensor)

Register a per-element field on cell_data.

For homogeneous meshes value may be a single tensor; for mixed-element meshes pass a dict keyed by element type with one tensor per type.

to_meshio(reorder: bool = False) → meshio.Mesh

Export this mesh as an in-memory meshio.Mesh.

Parameters:

reorder (bool, default=False) – If True, convert connectivity from the internal ordering back to Gmsh/VTK ordering before returning (delegates to tensormesh.Element.reorder()).

Returns:

The meshio mesh object.

Return type:

meshio.Mesh

save(file_name: str, file_format: str | None = None)

Write this mesh to disk via meshio.write.

Boolean point/cell/field arrays are cast to float before writing (meshio does not support bool). For .vtk / .vtu outputs, 2-D meshes are padded to 3-D and connectivity is reordered to the Gmsh/VTK convention.

Parameters:
  • file_name (str) – the name of the file

  • file_format (str or None) – the format of the file, e.g. ‘msh’, ‘vtk’, ‘obj’. If None (default), the format is inferred from the file extension.

Returns:

self, so calls can be chained

Return type:

Mesh
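The two data conversions described for save() can be sketched with NumPy as follows. The helper prepare_for_vtk is hypothetical, the exact float dtype tensormesh casts to is an assumption (float64 here), and the connectivity reordering step is omitted.

```python
import numpy as np


def prepare_for_vtk(points, point_data):
    """Hypothetical helper: pad 2-D points to 3-D and cast boolean fields to float."""
    if points.shape[1] == 2:
        # VTK stores three coordinates per point, so append a z = 0 column.
        zeros = np.zeros((points.shape[0], 1), dtype=points.dtype)
        points = np.hstack([points, zeros])
    # meshio cannot write bool arrays; float64 is an assumed target dtype.
    point_data = {
        name: arr.astype(np.float64) if arr.dtype == np.bool_ else arr
        for name, arr in point_data.items()
    }
    return points, point_data
```

Non-boolean fields pass through unchanged, so only the flags (e.g. a boundary mask) change type on disk.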

to_file(file_name: str, file_format: str | None = None)

Same as save(): writes this mesh to disk via meshio.write, takes the same parameters (file_name, file_format), applies the same boolean casting and .vtk / .vtu handling, and returns self.

node_adjacency(element_type: str | Iterable[str] | None = None) → SparseMatrix

Get the node adjacency matrix. Nodes that belong to the same element are considered fully connected: every pair of nodes in an element is adjacent.

Parameters:

element_type (str or Iterable[str] or None) – the element type(s) whose connectivity is used. If None, default_element_type is used. default: None

Returns:

the adjacency matrix of points \([|\mathcal V|,|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes

Return type:

SparseMatrix
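The "fully connected inside each element" rule can be sketched in plain Python by collecting the (row, col) index pairs of the sparsity pattern. The helper name node_adjacency_coo is hypothetical, and whether tensormesh's SparseMatrix also stores the diagonal is not stated here; this sketch omits self-loops.

```python
def node_adjacency_coo(cells):
    """Hypothetical helper: COO (rows, cols) of the node adjacency pattern.

    cells: list of elements, each a list of node indices. Every pair of distinct
    nodes in the same element is connected in both directions (symmetric pattern).
    """
    pairs = set()
    for element in cells:
        for i in element:
            for j in element:
                if i != j:  # self-loops are omitted in this sketch
                    pairs.add((i, j))
    ordered = sorted(pairs)
    rows = [i for i, _ in ordered]
    cols = [j for _, j in ordered]
    return rows, cols
```

For two triangles sharing an edge, nodes 0 and 3 (the opposite corners) are not adjacent because no single element contains both.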

element_adjacency(element_type: str | None = None) → SparseMatrix

Get the element adjacency matrix. Two elements are considered adjacent only if they share a facet (an edge in 2-D, a face in 3-D).

Parameters:

element_type (str or Iterable[str] or None) – the element type(s); multiple types must have the same dimension. If None, default_element_type is used. default: None

Returns:

the adjacency matrix of elements \([|\mathcal C|,|\mathcal C|]\), where \(|\mathcal C|\) is the number of elements

Return type:

SparseMatrix
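A minimal sketch of facet-based adjacency, restricted to linear triangles (whose facets are their three edges), maps every facet to the elements that contain it and connects elements that share one. The helper is hypothetical; quadrilaterals or 3-D elements would need their proper facet lists, since all node pairs of a quad include its diagonals.

```python
from collections import defaultdict
from itertools import combinations


def element_adjacency(cells):
    """Hypothetical helper for linear triangles: adjacent iff they share an edge."""
    facet_to_elements = defaultdict(list)
    for e, tri in enumerate(cells):
        # For a triangle, every pair of its nodes is one of its edges (facets).
        for a, b in combinations(tri, 2):
            facet_to_elements[frozenset((a, b))].append(e)
    pairs = set()
    for elements in facet_to_elements.values():
        for e1, e2 in combinations(elements, 2):
            pairs.add((e1, e2))
            pairs.add((e2, e1))
    return sorted(pairs)
```

Triangles that touch only at a vertex produce no entry, which is the difference from node-based adjacency.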

partition(n_parts: int, method: str = 'spectral', element_type: str | None = None) → Tensor

Partition the mesh elements into n_parts subdomains.

Parameters:
  • n_parts (int) – Number of partitions

  • method (str, optional) – Partition method: ‘spectral’ or ‘metis’. Default is ‘spectral’.

  • element_type (str or Iterable[str] or None) – the type(s) of the elements to partition; the partition is computed from their connectivity.

Returns:

IntTensor of shape [n_elements] containing the partition ID of each element

Return type:

Tensor

color(element_type: str | None = None) → Tensor

Color the mesh elements such that no two adjacent elements share the same color.

Parameters:

element_type (str or Iterable[str] or None) – the type(s) of the elements to color.

Returns:

IntTensor of shape [n_elements] containing the color ID of each element

Return type:

Tensor

elements(element_type: int | str | Iterable[str] | None = None) → Tensor | Dict[str, Tensor]

Get the element connectivity for specified element types.

Examples

  1. Get elements of the default type:

import tensormesh
mesh = tensormesh.Mesh.gen_rectangle()
elements = mesh.elements()  # Returns tensor of shape [n_elements, n_basis]

  2. Get elements of a specific type (assuming the mesh contains that type):

elements = mesh.elements("tri6")  # Returns tensor for triangle elements

  3. Get elements of multiple types:

elements = mesh.elements(["tri6", "quad9"])  # Returns dict of tensors

  4. Get all element types:

elements = mesh.elements("all")  # Returns dict of all element tensors

  5. Get elements of a specific dimension:

# Get all 2D elements (triangles, quads)
elements = mesh.elements(2)  # Returns dict of 2D element tensors

# Get all elements matching the mesh dimension
elements = mesh.elements(-1)  # Same as mesh.elements(mesh.dim)
Parameters:

element_type (str or Iterable[str] or int or None) –

the type of the elements:

  • if "all", return a dict of all element types

  • if int, return a dict of the elements of that dimension (-1 selects the mesh dimension)

  • if str, return the elements of that type

  • if Iterable[str], return a dict of the elements of those types

  • if None, use default_eletyp (default)

Returns:

  • if element_type is str, the corresponding element connectivity of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions

  • if element_type is int, a dict of the elements of that dimension

  • if element_type is Iterable[str], a dict mapping each type to its connectivity of shape \([|\mathcal C|, B]\)

  • if element_type is None, default_element_type is used and the rules above apply

  • if element_type is "all", a dict of all elements

Return type:

Tensor or Dict[str, Tensor]
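The dispatch rules above can be summarized in a hypothetical plain-Python helper that works on a dict of connectivity tables. select_elements and its arguments are illustrative stand-ins for the mesh's cells, dim2eletyp, default_eletyp, and dimension; this is not tensormesh's implementation.

```python
def select_elements(cells, dim2eletyp, default_eletyp, mesh_dim, element_type=None):
    """Hypothetical sketch of the element_type dispatch described above."""
    if element_type is None:
        # None falls back to the default type(s), then follows the rules below.
        element_type = default_eletyp
    if element_type == "all":  # checked before the generic str case
        return dict(cells)
    if isinstance(element_type, int):
        dim = mesh_dim if element_type == -1 else element_type
        return {t: cells[t] for t in dim2eletyp.get(dim, [])}
    if isinstance(element_type, str):
        return cells[element_type]  # a single connectivity table
    return {t: cells[t] for t in element_type}  # any iterable of type names
```

Note that a mixed-element default (a list of types) naturally lands in the last branch and returns a dict, consistent with the None rule.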

clone() → Mesh

Return a deep copy of the mesh that preserves the autograd graph.

A naive copy of the module does not keep gradients flowing: copy.deepcopy, for example, only supports tensors that are graph leaves, so it cannot copy points / cell_data that are outputs of differentiable operations. This method instead round-trips through meshio to rebuild the connectivity and metadata, keeping the mesh tensors attached to the computation graph.

Returns:

The cloned mesh.

Return type:

Mesh

plot(values: Dict[str, Tensor] | Dict[str, Iterable[Tensor]] | None = None, save_path: str | None = None, dt: float | None = None, show_mesh: bool = False, fix_clim: bool = False, show: bool = False, **kwargs)

Plot the mesh, optionally overlaying scalar fields or animations.

With no values only the mesh wireframe is drawn. Passing Dict[str, torch.Tensor] produces a static multi-panel figure; passing Dict[str, List[torch.Tensor]] produces an mp4/gif animation (one frame per list element).

Parameters:
  • values (None or Dict[str, Tensor] or Dict[str, List[Tensor]]) – the values to plot. If None, only the mesh is plotted. If Dict[str, torch.Tensor], a static figure is drawn with one subplot per key; each value has shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes. If Dict[str, List[torch.Tensor]], an mp4/gif animation is produced with one subplot per key; each list item is one frame of shape \([|\mathcal V|]\). default: None

  • save_path (str or None) – the path to save the plot to; if None, the plot is not saved. When values is a Dict[str, List[torch.Tensor]], save_path must end with ‘.mp4’ or ‘.gif’. default: None

  • dt (float or None) – the time interval between consecutive frames; only used when values is a Dict[str, List[torch.Tensor]]. default: None

  • show_mesh (bool) – whether to overlay the mesh wireframe (and, at order >= 2, the interpolation nodes) on top of the color-filled field. Only takes effect when values is given. default: False

  • fix_clim (bool) – whether to fix the color limits across all frames, only used when values is passed in as Dict[str, List[torch.Tensor]]. If True, the color limits are determined by the global min and max across all frames, ensuring a consistent colorbar throughout the animation. default: False

  • show (bool) – whether to display the plot interactively (e.g., via matplotlib.pyplot.show()) default: False

  • **kwargs – additional keyword arguments passed to the underlying visualization functions

property n_points: int

Number of interpolation nodes \(|\mathcal V|\) in the mesh.

Equals mesh.points.shape[0]. For order >= 2 this counts high-order nodes (mid-edge, mid-face, interior) as well, not only corner vertices.

Returns:

the number of interpolation nodes \(|\mathcal V|\)

Return type:

int

property n_elements: int

Number of elements \(|\mathcal C|\) of the default_element_type.

For mixed-element meshes this sums element counts across every type in default_element_type.

Returns:

the number of elements \(|\mathcal C|\)

Return type:

int

property boundary_mask: Tensor

Boolean mask flagging boundary points.

Looked up from point_data under the key "is_boundary" (preferred) or "boundary_mask". Mesh generators in tensormesh.dataset populate this automatically.

Returns:

1D bool tensor of shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes; requires "is_boundary" or "boundary_mask" to live in point_data

Return type:

Tensor
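For a mesh read from a file that carries no "is_boundary" field, one way to build such a mask for linear triangles is to mark the nodes of every edge that belongs to exactly one triangle. This is a sketch under that assumption, not necessarily how tensormesh.dataset computes it; the helper name is hypothetical, and the result could then be converted to a bool tensor and registered with register_point_data("is_boundary", ...).

```python
from collections import Counter


def boundary_mask_from_triangles(triangles, n_points):
    """Hypothetical helper: True for nodes on an edge used by only one triangle."""
    edge_count = Counter()
    for a, b, c in triangles:
        for u, v in ((a, b), (b, c), (c, a)):
            edge_count[(min(u, v), max(u, v))] += 1
    mask = [False] * n_points
    for (u, v), count in edge_count.items():
        if count == 1:  # interior edges are shared by exactly two triangles
            mask[u] = mask[v] = True
    return mask
```

For a square fanned into four triangles around a center node, the four corners are flagged and the center is not.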

property dtype: dtype

Floating-point dtype of points (and, by convention, of every buffer in the mesh).

Returns:

the data type of the points, e.g., torch.float32, torch.float64

Return type:

dtype

property device: device

Device on which the mesh tensors live.

Returns:

the device of the points, e.g., torch.device("cpu"), torch.device("cuda:0")

Return type:

device

classmethod from_meshio(mesh: meshio.Mesh, reorder: bool = False)

Build a Mesh from an in-memory meshio.Mesh.

Parameters:
  • mesh (meshio.Mesh) – a meshio mesh object

  • reorder (bool) – whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

Returns:

the mesh object

Return type:

Mesh

classmethod read(file_name: str, file_format: str | None = None, reorder: bool = False)

Read a mesh from disk via meshio.read.

Parameters:
  • file_name (str) – the name of the file

  • file_format (str or None) – the format of the file, e.g. ‘msh’, ‘vtk’, ‘obj’. If None (default), the format is inferred from the file extension.

  • reorder (bool) – whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

Returns:

the mesh object

Return type:

Mesh

classmethod from_file(file_name: str, file_format: str | None = None, reorder: bool = False)

Same as read(): reads a mesh from disk via meshio.read, takes the same parameters (file_name, file_format, reorder), and returns the Mesh.

static gen_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 2-D mesh of an axis-aligned rectangle.

Delegates to gen_rectangle(), which calls Gmsh under the hood and caches the generated mesh on disk: at cache_path if given, otherwise at a location chosen by gen_rectangle().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • element_type (str, optional) – the type of the element, default: "tri"

  • left (float, optional) – the left boundary of the rectangle, default: 0.0

  • right (float, optional) – the right boundary of the rectangle, default: 1.0

  • bottom (float, optional) – the bottom boundary of the rectangle, default: 0.0

  • top (float, optional) – the top boundary of the rectangle, default: 1.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_rectangle(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 2-D mesh of a rectangle with a rectangular hole cut out.

Delegates to gen_hollow_rectangle().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • element_type (str, optional) – the type of the element, default: "quad"

  • outer_left (float, optional) – the left boundary of the outer rectangle, default: 0.0

  • outer_right (float, optional) – the right boundary of the outer rectangle, default: 1.0

  • outer_bottom (float, optional) – the bottom boundary of the outer rectangle, default: 0.0

  • outer_top (float, optional) – the top boundary of the outer rectangle, default: 1.0

  • inner_left (float, optional) – the left boundary of the inner rectangle, default: 0.25

  • inner_right (float, optional) – the right boundary of the inner rectangle, default: 0.75

  • inner_bottom (float, optional) – the bottom boundary of the inner rectangle, default: 0.25

  • inner_top (float, optional) – the top boundary of the inner rectangle, default: 0.75

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_hollow_rectangle(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', cx: float = 0.0, cy: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 2-D mesh of a disk (filled circle).

Delegates to gen_circle().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • element_type (str, optional) – the type of the element, default: "tri"

  • cx (float, optional) – the x coordinate of the center of the circle, default: 0.0

  • cy (float, optional) – the y coordinate of the center of the circle, default: 0.0

  • r (float, optional) – the radius of the circle, default: 1.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_circle(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', cx: float = 0.0, cy: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 2-D mesh of an annulus (disk with a circular hole).

Delegates to gen_hollow_circle().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • element_type (str, optional) – the type of the element, default: "quad"

  • cx (float, optional) – the x coordinate of the center of the circle, default: 0.0

  • cy (float, optional) – the y coordinate of the center of the circle, default: 0.0

  • r_inner (float, optional) – the inner radius of the annulus, default: 1.0

  • r_outer (float, optional) – the outer radius of the annulus, default: 2.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_hollow_circle(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_L(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, top_inner: float = 0.5, right_inner: float = 0.5, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 2-D mesh of an L-shaped domain (re-entrant corner benchmark).

Delegates to gen_L().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • element_type (str, optional) – the type of the element, default: "quad"

  • left (float, optional) – the left boundary of the bounding rectangle, default: 0.0

  • right (float, optional) – the right boundary of the bounding rectangle, default: 1.0

  • bottom (float, optional) – the bottom boundary of the bounding rectangle, default: 0.0

  • top (float, optional) – the top boundary of the bounding rectangle, default: 1.0

  • top_inner (float, optional) – the top inner boundary of the rectangle, default: 0.5

  • right_inner (float, optional) – the right inner boundary of the rectangle, default: 0.5

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_L(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_cube(chara_length: float = 0.1, order: int = 1, left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, front: float = 0.0, back: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 3-D mesh of an axis-aligned cuboid.

Delegates to gen_cube().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • left (float, optional) – the left boundary of the cube, default: 0.0

  • right (float, optional) – the right boundary of the cube, default: 1.0

  • bottom (float, optional) – the bottom boundary of the cube, default: 0.0

  • top (float, optional) – the top boundary of the cube, default: 1.0

  • front (float, optional) – the front boundary of the cube, default: 0.0

  • back (float, optional) – the back boundary of the cube, default: 1.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_cube(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_cube(chara_length: float = 0.1, order: int = 1, outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, outer_front: float = 0.0, outer_back: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, inner_front: float = 0.25, inner_back: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 3-D mesh of a cuboid with a cuboidal hole.

Delegates to gen_hollow_cube().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • outer_left (float, optional) – the left boundary of the outer cube, default: 0.0

  • outer_right (float, optional) – the right boundary of the outer cube, default: 1.0

  • outer_bottom (float, optional) – the bottom boundary of the outer cube, default: 0.0

  • outer_top (float, optional) – the top boundary of the outer cube, default: 1.0

  • outer_front (float, optional) – the front boundary of the outer cube, default: 0.0

  • outer_back (float, optional) – the back boundary of the outer cube, default: 1.0

  • inner_left (float, optional) – the left boundary of the inner cube, default: 0.25

  • inner_right (float, optional) – the right boundary of the inner cube, default: 0.75

  • inner_bottom (float, optional) – the bottom boundary of the inner cube, default: 0.25

  • inner_top (float, optional) – the top boundary of the inner cube, default: 0.75

  • inner_front (float, optional) – the front boundary of the inner cube, default: 0.25

  • inner_back (float, optional) – the back boundary of the inner cube, default: 0.75

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_hollow_cube(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 3-D mesh of a solid ball (filled sphere).

Delegates to gen_sphere().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • cx (float, optional) – the x coordinate of the center of the sphere, default: 0.0

  • cy (float, optional) – the y coordinate of the center of the sphere, default: 0.0

  • cz (float, optional) – the z coordinate of the center of the sphere, default: 0.0

  • r (float, optional) – the radius of the sphere, default: 1.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_sphere(), default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh

Generate a 3-D mesh of a spherical shell (ball with a concentric spherical cavity).

Delegates to gen_hollow_sphere().

Parameters:
  • chara_length (float, optional) – the characteristic length of the mesh, default: 0.1

  • order (int, optional) – the order of the basis function, default: 1

  • cx (float, optional) – the x coordinate of the center of the sphere, default: 0.0

  • cy (float, optional) – the y coordinate of the center of the sphere, default: 0.0

  • cz (float, optional) – the z coordinate of the center of the sphere, default: 0.0

  • r_inner (float, optional) – the inner radius of the spherical shell, default: 1.0

  • r_outer (float, optional) – the outer radius of the spherical shell, default: 2.0

  • visualize (bool, optional) – whether to visualize the mesh, default: False

  • cache_path (str, optional) – the path to save the mesh, if None, it will be decided by gen_hollow_sphere(), default: None

Returns:

the mesh object

Return type:

Mesh

Graph algorithms

Helpers used internally by DistributedMesh for race-free parallel assembly and domain decomposition.

graph_coloring(adjacency: SparseMatrix, max_iter: int = 100) → Tensor

Parallel graph coloring by iterative conflict resolution, designed to run efficiently on the GPU.

Parameters:
  • adjacency (SparseMatrix) – The adjacency matrix of the graph. shape: [n_nodes, n_nodes]

  • max_iter (int) – Maximum number of conflict resolution iterations.

Returns:

IntTensor of shape [n_nodes] containing the color ID for each node.

Return type:

Tensor
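The documentation names the technique but not its details, so the following is a generic plain-Python sketch of iterative conflict resolution, not tensormesh's implementation. The sequential loops stand in for what would be parallel tensor operations on the GPU. Each round, every uncolored node tentatively takes the smallest color unused by its neighbors in the previous state; a node whose tentative color clashes with a lower-indexed neighbor gives it up and retries. The helper name iterative_coloring and the lowest-index tie-break are assumptions.

```python
def iterative_coloring(adj, max_iter=100):
    """Greedy coloring by iterative conflict resolution (illustrative sketch).

    adj: list of neighbour sets (symmetric, no self-loops); node ids are 0..n-1.
    Returns one color per node; -1 marks nodes still uncolored after max_iter rounds.
    """
    color = [-1] * len(adj)
    for _ in range(max_iter):
        pending = [v for v in range(len(adj)) if color[v] < 0]
        if not pending:
            break
        # Phase 1 (parallel in spirit): each pending node picks the smallest
        # color not used by its neighbours in the previous round's state.
        previous = list(color)
        for v in pending:
            used = {previous[u] for u in adj[v]}
            c = 0
            while c in used:
                c += 1
            color[v] = c
        # Phase 2: a pending node clashing with a lower-indexed neighbour
        # (same tentative color) drops its color and retries next round.
        tentative = list(color)
        for v in pending:
            if any(u < v and tentative[u] == tentative[v] for u in adj[v]):
                color[v] = -1
    return color
```

The lowest-indexed pending node never loses a conflict, so at least one node is colored per round, and each node ends up with a color no larger than its degree.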

graph_partition(adjacency: SparseMatrix, n_parts: int, method: str = 'spectral') → Tensor

Partition a graph into balanced subdomains using GPU-accelerated algorithms.

This function divides graph nodes into n_parts groups such that:

  1. Each partition has approximately equal number of nodes (load balance)

  2. The number of edges crossing partitions is minimized (minimal interface)

Parameters:
  • adjacency (SparseMatrix) – The adjacency matrix of the graph. Shape: [n_nodes, n_nodes]. Should be symmetric for undirected graphs.

  • n_parts (int) – Number of partitions to create. The spectral method bisects recursively, so it works best when n_parts is a power of 2.

  • method (str, optional) –

    Partitioning algorithm to use:

    • 'spectral': Recursive Spectral Bisection using Fiedler vector. Computed via torch.lobpcg on GPU. Default method.

    • 'metis': Uses pymetis library (requires pymetis installation). Falls back to spectral if pymetis is not available.

Returns:

Integer tensor of shape [n_nodes] containing partition labels in range [0, n_parts-1].

Return type:

Tensor

Notes

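The spectral method computes the Fiedler vector (the eigenvector associated with the second-smallest eigenvalue of the graph Laplacian \(L = D - A\), where \(D\) is the diagonal degree matrix and \(A\) the adjacency matrix) and recursively bisects the graph at the median of its entries. This preserves locality in the partitioning.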

For small subgraphs (< 2048 nodes), dense eigensolvers are used for stability. For larger graphs, LOBPCG is used for efficiency.
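Recursive spectral bisection can be sketched with NumPy on a dense adjacency matrix as follows. This sketch uses numpy.linalg.eigh at every level (the dense path mentioned above) instead of torch.lobpcg, and splits at the median for two parts or proportionally for other counts. The helper names fiedler_vector and spectral_partition are hypothetical and not tensormesh's API.

```python
import numpy as np


def fiedler_vector(adj):
    """Eigenvector of the second-smallest eigenvalue of L = D - A (dense, symmetric A)."""
    laplacian = np.diag(adj.sum(axis=1)) - adj
    # eigh returns eigenvalues in ascending order, so column 1 is the Fiedler vector.
    _, eigvecs = np.linalg.eigh(laplacian)
    return eigvecs[:, 1]


def spectral_partition(adj, n_parts):
    """Recursive spectral bisection; returns one label in [0, n_parts - 1] per node."""
    labels = np.zeros(adj.shape[0], dtype=np.int64)

    def bisect(nodes, first_label, parts):
        if parts == 1 or len(nodes) < 2:
            labels[nodes] = first_label
            return
        fiedler = fiedler_vector(adj[np.ix_(nodes, nodes)])
        order = np.argsort(fiedler, kind="stable")
        # Median split for two parts; proportional split otherwise.
        n_left = len(nodes) * (parts // 2) // parts
        bisect(nodes[order[:n_left]], first_label, parts // 2)
        bisect(nodes[order[n_left:]], first_label + parts // 2, parts - parts // 2)

    bisect(np.arange(adj.shape[0]), 0, n_parts)
    return labels
```

On a path graph the Fiedler vector is monotonic along the path, so each bisection cuts it into contiguous halves, which is the locality property described above.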

Examples

>>> import tensormesh
>>> from tensormesh.mesh import graph_partition
>>> mesh = tensormesh.Mesh.gen_rectangle()
>>> # Partition the element adjacency graph into 4 parts
>>> adj = mesh.element_adjacency()
>>> labels = graph_partition(adj, n_parts=4, method='spectral')
>>> print(f"Partition sizes: {[(labels == i).sum().item() for i in range(4)]}")