tensormesh.mesh

Mesh

class Mesh(mesh: Mesh, reorder: bool = False) [source]

Bases: torch.nn.Module

FEM mesh holding interpolation-node coordinates, per-element-type connectivity, and the point, cell, and field data associated with them. Mixed-element meshes are supported because cells is a BufferDict keyed by element-type string (e.g. "triangle", "quad", "tetra").

Throughout the API, a "point" means an interpolation node (degree of freedom). For order=1 the points are the corner vertices of the elements; for order>=2 they also include mid-edge, mid-face, and interior nodes.
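As a concrete illustration of this distinction, the sketch below counts the points of a small triangle mesh from its corner connectivity alone: at order 1 the points are the distinct corner vertices, and at order 2 (6-node triangles, which have no interior nodes) each distinct edge adds one mid-edge node. It uses plain Python lists instead of tensormesh tensors, and count_interpolation_nodes is a hypothetical helper, not part of the library.

```python
from typing import List, Tuple


def count_interpolation_nodes(triangles: List[Tuple[int, int, int]], order: int) -> int:
    """Count interpolation nodes of a triangle mesh for order 1 or 2 (hypothetical helper)."""
    vertices = {v for tri in triangles for v in tri}
    if order == 1:
        return len(vertices)
    if order == 2:
        # Each distinct edge (stored as a sorted vertex pair) carries one mid-edge node.
        edges = {
            tuple(sorted((tri[i], tri[(i + 1) % 3])))
            for tri in triangles
            for i in range(3)
        }
        return len(vertices) + len(edges)
    raise ValueError("this sketch only handles order 1 and 2")


# Unit square split into two triangles along the diagonal 0-2.
square = [(0, 1, 2), (0, 2, 3)]
print(count_interpolation_nodes(square, order=1))  # 4 corner vertices
print(count_interpolation_nodes(square, order=2))  # 4 vertices + 5 edges = 9
```

The shared diagonal (0, 2) is counted once, which is why the order-2 mesh has 9 points rather than 4 + 6.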

Parameters:
  • mesh (meshio.Mesh) -- A meshio mesh object to wrap.

  • reorder (bool, default=False) -- Whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

points

2D tensor of shape \([|\mathcal V|, D]\), where \(|\mathcal V|\) is the number of interpolation nodes and \(D\) is the spatial dimension. Includes high-order nodes (mid-edge / mid-face / interior) when the mesh uses order >= 2 elements.

Type:

Tensor

cells

Each key is an element_type string (see tensormesh.element); the value is a 2D tensor of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions.

Type:

BufferDict[str, Tensor]

point_data

Per-point fields, keyed by name.

Type:

BufferDict[str, Tensor], optional

cell_data

Per-element fields. The outer key is an element_type; the inner key is the field name.

Type:

ModuleDict[str, BufferDict[str, Tensor]], optional

field_data

Global named fields.

Type:

BufferDict[str, Tensor], optional

cell_sets

Named subsets of cells, kept in meshio's native format.

Type:

dict, optional

dim2eletyp

Each key is a spatial dimension, and the value is a list of element types of that dimension present in the mesh.

Type:

Dict[int, List[str]]

default_eletyp

The default element type — a single string for homogeneous meshes, a list of strings for mixed-element meshes. Exposed publicly via the default_element_type property.

Type:

str or List[str]

__init__(mesh: Mesh, reorder: bool = False) [source]

Initialize the mesh from a meshio.Mesh; mesh and reorder are described in the class parameters above.

register_point_data(key: str, value: Tensor) [source]

Register a per-point field on point_data.

point_data is a tensormesh.nn.BufferDict, so prefer this method over __setitem__ to make sure the tensor is tracked as a buffer of the underlying torch.nn.Module.

Parameters:
  • key (str) -- the name of the field

  • value (Tensor) -- tensor of shape \([|\mathcal V|, ...]\), where \(|\mathcal V|\) is the number of interpolation nodes (mesh.n_points)

Returns:

self

Return type:

Mesh

register_element_data(key: str, value: Dict[str, Tensor] | Tensor) [source]

Register a per-element field on cell_data.

For homogeneous meshes value may be a single tensor; for mixed-element meshes pass a dict keyed by element type with one tensor per type.

to_meshio(reorder: bool = False) → Mesh [source]

Export this mesh as an in-memory meshio.Mesh.

Parameters:

reorder (bool, default=False) -- If True, convert connectivity from the internal ordering back to Gmsh/VTK ordering before returning (delegates to tensormesh.Element.reorder()).

Returns:

The meshio mesh object.

Return type:

meshio.Mesh

save(file_name: str, file_format: str | None = None) [source]

Write this mesh to disk via meshio.write.

Boolean point/cell/field arrays are cast to float before writing (meshio does not support bool). For .vtk / .vtu outputs 2-D meshes are padded to 3-D and connectivity is reordered to the Gmsh/VTK convention.
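The two conversions described above can be sketched with NumPy. This assumes the data are plain NumPy arrays and casts booleans to float64 (the exact float width tensormesh uses is an assumption); prepare_for_vtk is a hypothetical helper, and the connectivity reordering step is omitted because the internal node ordering is not specified here.

```python
import numpy as np


def prepare_for_vtk(points: np.ndarray, point_data: dict) -> tuple:
    """Hypothetical helper mirroring the conversions save() performs before meshio.write."""
    # meshio cannot write bool arrays, so cast them to float.
    converted = {
        name: arr.astype(np.float64) if arr.dtype == np.bool_ else arr
        for name, arr in point_data.items()
    }
    # VTK expects 3-D coordinates: pad 2-D points with a zero z column.
    if points.shape[1] == 2:
        zeros = np.zeros((points.shape[0], 1), dtype=points.dtype)
        points = np.hstack([points, zeros])
    return points, converted


pts, data = prepare_for_vtk(
    np.array([[0.0, 0.0], [1.0, 0.0]]),
    {"is_boundary": np.array([True, False])},
)
print(pts.shape, data["is_boundary"].dtype)  # (2, 3) float64
```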

Parameters:
  • file_name (str) -- the name of the file

  • file_format (str or None) -- the file format, e.g. 'msh', 'vtk', 'obj'; if None (default), it is inferred from the file extension

Returns:

self

Return type:

Mesh

to_file(file_name: str, file_format: str | None = None)

Write this mesh to disk via meshio.write.

Boolean point/cell/field arrays are cast to float before writing (meshio does not support bool). For .vtk / .vtu outputs 2-D meshes are padded to 3-D and connectivity is reordered to the Gmsh/VTK convention.

Parameters:
  • file_name (str) -- the name of the file

  • file_format (str or None) -- the file format, e.g. 'msh', 'vtk', 'obj'; if None (default), it is inferred from the file extension

Returns:

self

Return type:

Mesh

node_adjacency(element_type: str | Iterable[str] | None = None) → SparseMatrix [source]

Get the node adjacency matrix. Within each element, all nodes are treated as fully connected.

Parameters:

element_type (str or Iterable[str] or None) -- the element type(s) to use; if None, default_element_type is used. default: None

Returns:

the point adjacency matrix of shape \([|\mathcal V|, |\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes

Return type:

SparseMatrix
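The rule that nodes are fully connected inside each element can be sketched in plain Python, using a set of (row, col) index pairs to stand in for the nonzero pattern of the SparseMatrix. node_adjacency_pairs is a hypothetical helper; whether tensormesh also stores the diagonal (self-loops) is not stated above, so the sketch leaves it out as an assumption. Because each connectivity row lists all B nodes of an element, high-order nodes are connected the same way as corner vertices.

```python
from itertools import permutations
from typing import List, Set, Tuple


def node_adjacency_pairs(elements: List[List[int]]) -> Set[Tuple[int, int]]:
    """Nonzero (row, col) pattern of a node adjacency matrix (hypothetical helper)."""
    pairs: Set[Tuple[int, int]] = set()
    for element in elements:
        # Every ordered pair of distinct nodes in the same element is connected.
        pairs.update(permutations(element, 2))
    return pairs


pairs = node_adjacency_pairs([[0, 1, 2], [0, 2, 3]])
# Nodes 1 and 3 never share an element, so (1, 3) is absent.
print(sorted(pairs))
```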

element_adjacency(element_type: str | None = None) → SparseMatrix [source]

Get the element adjacency matrix. Two elements are considered connected only if they share a facet (an edge in 2-D, a face in 3-D).

Parameters:

element_type (str or Iterable[str] or None) -- the element type(s); all listed types must have the same dimension. If None, default_element_type is used. default: None

Returns:

the element adjacency matrix of shape \([|\mathcal C|, |\mathcal C|]\), where \(|\mathcal C|\) is the number of elements

Return type:

SparseMatrix
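The facet rule can be sketched for linear triangles, whose facets are edges: record each edge (as a sorted vertex pair) together with the elements that own it, then connect the two owners of every interior edge. This is plain Python, not tensormesh's implementation; triangle_element_adjacency is a hypothetical helper, and it assumes a conforming mesh in which an interior edge belongs to exactly two triangles. For 3-D elements the same idea applies with faces (sorted vertex tuples) as keys.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple


def triangle_element_adjacency(triangles: List[Tuple[int, int, int]]) -> Set[Tuple[int, int]]:
    """(row, col) pattern of the element adjacency of a linear triangle mesh (hypothetical)."""
    owners: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    for e, tri in enumerate(triangles):
        for i in range(3):
            # A facet of a triangle is an edge; sort it so both neighbors produce the same key.
            facet = tuple(sorted((tri[i], tri[(i + 1) % 3])))
            owners[facet].append(e)
    pairs: Set[Tuple[int, int]] = set()
    for elems in owners.values():
        # Interior facets have two owners; boundary facets have one and add nothing.
        if len(elems) == 2:
            a, b = elems
            pairs.update({(a, b), (b, a)})
    return pairs


# Triangles 0 and 1 share edge (0, 2); triangle 2 touches triangle 1 only at vertex 3.
tris = [(0, 1, 2), (0, 2, 3), (3, 4, 5)]
print(sorted(triangle_element_adjacency(tris)))  # [(0, 1), (1, 0)]
```

Unlike node_adjacency, sharing a single vertex is not enough: triangles 1 and 2 stay unconnected.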

partition(n_parts: int, method: str = 'spectral', element_type: str | None = None) → Tensor [source]

Partition the mesh elements into n_parts subdomains.

Parameters:
  • n_parts (int) -- Number of partitions

  • method (str, optional) -- Partition method: 'spectral' or 'metis'. Default is 'spectral'.

  • element_type (str or Iterable[str] or None) -- the element type(s) whose connectivity is used to build the partition. default: None

Returns:

integer tensor of shape [n_elements] containing the partition ID of each element

Return type:

Tensor

color(element_type: str | None = None) → Tensor [source]

Color the mesh elements such that no adjacent elements share the same color.

Parameters:

element_type (str or Iterable[str] or None) -- the element type(s) to color. default: None

Returns:

integer tensor of shape [n_elements] containing the color ID of each element

Return type:

Tensor

elements(element_type: int | str | Iterable[str] | None = None) → Tensor | Dict[str, Tensor] [source]

Get the element connectivity for specified element types.

Examples

  1. Get elements of the default type:

import tensormesh
mesh = tensormesh.Mesh.gen_rectangle()
elements = mesh.elements()  # tensor of shape [n_elements, n_basis]

  2. Get elements of a specific type (the type must be present in the mesh):

elements = mesh.elements("tri6")  # tensor for the "tri6" triangle elements

  3. Get elements of multiple types:

elements = mesh.elements(["tri6", "quad9"])  # dict of tensors keyed by element type

  4. Get all element types:

elements = mesh.elements("all")  # dict of all element tensors

  5. Get elements of a specific dimension:

# Get all 2-D elements (triangles, quads)
elements = mesh.elements(2)  # dict of 2-D element tensors

# Get all elements matching the mesh dimension
elements = mesh.elements(-1)  # same as mesh.elements(mesh.dim)
Parameters:

element_type (str or Iterable[str] or int or None) --

the type of the elements:

  • if "all", return a dict of all elements

  • if int, return a dict of elements of that dimension (-1 selects the mesh dimension)

  • if str, return elements of that type

  • if Iterable[str], return elements of those types

  • if None, use default_eletyp (default)

Returns:

  • if element_type is a str, the corresponding element connectivity of shape \([|\mathcal C|, B]\), where \(|\mathcal C|\) is the number of elements and \(B\) is the number of basis functions

  • if element_type is an int, a dict of the elements of that dimension

  • if element_type is an Iterable[str], a dict mapping each type to its connectivity of shape \([|\mathcal C|, B]\)

  • if element_type is None, default_element_type is used and the rules above apply

  • if element_type is "all", all elements as a dict

Return type:

Tensor or Dict[str, Tensor]
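The selection rules above can be summarized as a small dispatch function. The sketch below is a hypothetical restatement in plain Python, with lists standing in for connectivity tensors and a hand-written type-to-dimension table (CELLS, DIMS, MESH_DIM, DEFAULT and select_elements are illustrative assumptions, not tensormesh attributes). Note that "all" must be tested before the generic string case, and that a list-valued default, as in a mixed-element mesh, falls through to the Iterable[str] branch.

```python
from typing import Dict, Iterable, List, Union

# Hypothetical stand-ins: connectivity per element type and each type's dimension.
CELLS: Dict[str, List[List[int]]] = {
    "tri": [[0, 1, 2], [0, 2, 3]],
    "line": [[0, 1], [1, 2]],
}
DIMS: Dict[str, int] = {"tri": 2, "line": 1}
MESH_DIM = 2
DEFAULT: Union[str, List[str]] = "tri"


def select_elements(element_type: Union[int, str, Iterable[str], None] = None):
    """Sketch of the documented element_type dispatch rules (not tensormesh's code)."""
    if element_type is None:
        element_type = DEFAULT
    if isinstance(element_type, int):
        # -1 means "the dimension of the mesh itself".
        dim = MESH_DIM if element_type == -1 else element_type
        return {t: c for t, c in CELLS.items() if DIMS[t] == dim}
    if element_type == "all":
        return dict(CELLS)
    if isinstance(element_type, str):
        return CELLS[element_type]
    return {t: CELLS[t] for t in element_type}


print(select_elements(1))  # {'line': [[0, 1], [1, 2]]}
```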

clone() → Mesh [source]

Return a deep copy of the mesh that preserves the autograd graph.

Calling torch.Tensor.clone on the underlying buffers detaches them from the computation graph, so gradients flowing through points / cell_data would vanish. This method round-trips through meshio to reconstruct the mesh while keeping the connectivity and metadata intact.

Returns:

The cloned mesh.

Return type:

Mesh

plot(values: Dict[str, Tensor] | Dict[str, Iterable[Tensor]] | None = None, save_path: str | None = None, dt: float | None = None, show_mesh: bool = False, fix_clim: bool = False, show: bool = False, **kwargs) [source]

Plot the mesh, optionally overlaying scalar fields or animations.

With no values only the mesh wireframe is drawn. Passing Dict[str, torch.Tensor] produces a static multi-panel figure; passing Dict[str, List[torch.Tensor]] produces an mp4/gif animation (one frame per list element).

Parameters:
  • values (None or Dict[str, Tensor] or Dict[str, List[Tensor]]) -- the values to plot. If None, only the mesh is plotted. If Dict[str, torch.Tensor], a static figure is drawn with one subplot per key; each value has shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes. If Dict[str, List[torch.Tensor]], an mp4/gif animation is produced with one subplot per key; each list item has shape \([|\mathcal V|]\). default: None

  • save_path (str or None) -- the path to save the plot; if None, the plot is not saved. If values is a Dict[str, List[torch.Tensor]], save_path must end with '.mp4' or '.gif'. default: None

  • dt (float or None) -- the time interval between frames, used only when values is a Dict[str, List[torch.Tensor]]. default: None

  • show_mesh (bool) -- whether to overlay the mesh wireframe (and, for order >= 2, the interpolation nodes) on top of the color-filled field. Only takes effect when values is given. default: False

  • fix_clim (bool) -- whether to fix the color limits across all frames, used only when values is a Dict[str, List[torch.Tensor]]. If True, the limits are set to the global min and max over all frames, giving a consistent colorbar throughout the animation. default: False

  • show (bool) -- whether to display the plot interactively (e.g., via matplotlib.pyplot.show()). default: False

  • **kwargs -- additional keyword arguments passed to the underlying visualization functions
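What fix_clim=True means can be sketched for one animated field, assuming each frame is a flat list of nodal values: a single (vmin, vmax) pair is taken over every frame instead of being recomputed per frame. global_clim is a hypothetical helper; in matplotlib such limits are typically passed as the vmin/vmax arguments of a plotting call such as tripcolor.

```python
from typing import Sequence, Tuple


def global_clim(frames: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Color limits shared by all frames of one animated field (hypothetical helper)."""
    vmin = min(min(frame) for frame in frames)
    vmax = max(max(frame) for frame in frames)
    return vmin, vmax


frames = [[0.0, 0.5], [0.2, 1.5], [-0.3, 0.1]]
print(global_clim(frames))  # (-0.3, 1.5)
```

Without fix_clim each frame would get its own limits, e.g. (0.0, 0.5) for the first frame, so the same color could stand for different values in different frames.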

property n_points: int

Number of interpolation nodes \(|\mathcal V|\) in the mesh.

Equals mesh.points.shape[0]. For order >= 2 this counts high-order nodes (mid-edge, mid-face, interior) as well, not only corner vertices.

Returns:

the number of interpolation nodes \(|\mathcal V|\)

Return type:

int

property n_elements: int

Number of elements \(|\mathcal C|\) of the default_element_type.

For mixed-element meshes this sums element counts across every type in default_element_type.

Returns:

the number of elements \(|\mathcal C|\)

Return type:

int

property boundary_mask: Tensor

Boolean mask flagging boundary points.

Looked up from point_data under the key "is_boundary" (preferred) or "boundary_mask". Mesh generators in tensormesh.dataset populate this automatically.
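The key lookup described above amounts to the following precedence rule, sketched in plain Python with lists in place of tensors. lookup_boundary_mask is a hypothetical helper, and raising KeyError when neither key is present is an assumption about how the missing-data case surfaces.

```python
from typing import Dict, List


def lookup_boundary_mask(point_data: Dict[str, List[bool]]) -> List[bool]:
    """Mirror the documented key preference: "is_boundary" first, then "boundary_mask"."""
    for key in ("is_boundary", "boundary_mask"):
        if key in point_data:
            return point_data[key]
    raise KeyError('point_data has neither "is_boundary" nor "boundary_mask"')


print(lookup_boundary_mask({"boundary_mask": [True, False]}))  # [True, False]
```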

Returns:

1D bool tensor of shape \([|\mathcal V|]\), where \(|\mathcal V|\) is the number of interpolation nodes; requires "is_boundary" or "boundary_mask" to be present in point_data

Return type:

Tensor

property dtype: dtype

Floating-point dtype of points (and, by convention, of every buffer in the mesh).

Returns:

the data type of the points, e.g., torch.float32, torch.float64

Return type:

dtype

property device: device

Device on which the mesh tensors live.

Returns:

the device of the points, e.g., torch.device("cpu"), torch.device("cuda:0")

Return type:

device

classmethod from_meshio(mesh: Mesh, reorder: bool = False) [source]

Build a Mesh from an in-memory meshio.Mesh.

Parameters:
  • mesh (meshio.Mesh) -- a meshio mesh object

  • reorder (bool) -- whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

Returns:

the mesh object

Return type:

Mesh

classmethod read(file_name: str, file_format: str | None = None, reorder: bool = False) [source]

Read a mesh from disk via meshio.read.

Parameters:
  • file_name (str) -- the name of the file

  • file_format (str or None) -- the file format, e.g. 'msh', 'vtk', 'obj'; if None (default), it is inferred from the file extension

  • reorder (bool) -- whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

Returns:

the mesh object

Return type:

Mesh

classmethod from_file(file_name: str, file_format: str | None = None, reorder: bool = False)

Read a mesh from disk via meshio.read.

Parameters:
  • file_name (str) -- the name of the file

  • file_format (str or None) -- the file format, e.g. 'msh', 'vtk', 'obj'; if None (default), it is inferred from the file extension

  • reorder (bool) -- whether to convert connectivity from Gmsh/VTK ordering to TensorMesh internal ordering (delegates to tensormesh.Element.reorder()).

Returns:

the mesh object

Return type:

Mesh

static gen_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 2-D mesh of an axis-aligned rectangle.

Delegates to gen_rectangle(), which calls Gmsh under the hood and caches the result if cache_path is given.

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • element_type (str, optional) -- the type of the element, default: "tri"

  • left (float, optional) -- the left boundary of the rectangle, default: 0.0

  • right (float, optional) -- the right boundary of the rectangle, default: 1.0

  • bottom (float, optional) -- the bottom boundary of the rectangle, default: 0.0

  • top (float, optional) -- the top boundary of the rectangle, default: 1.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_rectangle(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_rectangle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 2-D mesh of a rectangle with a rectangular hole cut out.

Delegates to gen_hollow_rectangle().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • element_type (str, optional) -- the type of the element, default: "quad"

  • outer_left (float, optional) -- the left boundary of the outer rectangle, default: 0.0

  • outer_right (float, optional) -- the right boundary of the outer rectangle, default: 1.0

  • outer_bottom (float, optional) -- the bottom boundary of the outer rectangle, default: 0.0

  • outer_top (float, optional) -- the top boundary of the outer rectangle, default: 1.0

  • inner_left (float, optional) -- the left boundary of the inner rectangle, default: 0.25

  • inner_right (float, optional) -- the right boundary of the inner rectangle, default: 0.75

  • inner_bottom (float, optional) -- the bottom boundary of the inner rectangle, default: 0.25

  • inner_top (float, optional) -- the top boundary of the inner rectangle, default: 0.75

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_hollow_rectangle(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'tri', cx: float = 0.0, cy: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 2-D mesh of a disk (filled circle).

Delegates to gen_circle().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • element_type (str, optional) -- the type of the element, default: "tri"

  • cx (float, optional) -- the x coordinate of the center of the circle, default: 0.0

  • cy (float, optional) -- the y coordinate of the center of the circle, default: 0.0

  • r (float, optional) -- the radius of the circle, default: 1.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_circle(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_circle(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', cx: float = 0.0, cy: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 2-D mesh of an annulus (disk with a circular hole).

Delegates to gen_hollow_circle().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • element_type (str, optional) -- the type of the element, default: "quad"

  • cx (float, optional) -- the x coordinate of the center of the circle, default: 0.0

  • cy (float, optional) -- the y coordinate of the center of the circle, default: 0.0

  • r_inner (float, optional) -- the inner radius of the circle, default: 1.0

  • r_outer (float, optional) -- the outer radius of the circle, default: 2.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_hollow_circle(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_L(chara_length: float = 0.1, order: int = 1, element_type: str = 'quad', left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, top_inner: float = 0.5, right_inner: float = 0.5, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 2-D mesh of an L-shaped domain (re-entrant corner benchmark).

Delegates to gen_L().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • element_type (str, optional) -- the type of the element, default: "quad"

  • left (float, optional) -- the left boundary of the rectangle, default: 0.0

  • right (float, optional) -- the right boundary of the rectangle, default: 1.0

  • bottom (float, optional) -- the bottom boundary of the rectangle, default: 0.0

  • top (float, optional) -- the top boundary of the rectangle, default: 1.0

  • top_inner (float, optional) -- the top inner boundary of the rectangle, default: 0.5

  • right_inner (float, optional) -- the right inner boundary of the rectangle, default: 0.5

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_L(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_cube(chara_length: float = 0.1, order: int = 1, left: float = 0.0, right: float = 1.0, bottom: float = 0.0, top: float = 1.0, front: float = 0.0, back: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 3-D mesh of an axis-aligned cuboid.

Delegates to gen_cube().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • left (float, optional) -- the left boundary of the cube, default: 0.0

  • right (float, optional) -- the right boundary of the cube, default: 1.0

  • bottom (float, optional) -- the bottom boundary of the cube, default: 0.0

  • top (float, optional) -- the top boundary of the cube, default: 1.0

  • front (float, optional) -- the front boundary of the cube, default: 0.0

  • back (float, optional) -- the back boundary of the cube, default: 1.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_cube(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_cube(chara_length: float = 0.1, order: int = 1, outer_left: float = 0.0, outer_right: float = 1.0, outer_bottom: float = 0.0, outer_top: float = 1.0, outer_front: float = 0.0, outer_back: float = 1.0, inner_left: float = 0.25, inner_right: float = 0.75, inner_bottom: float = 0.25, inner_top: float = 0.75, inner_front: float = 0.25, inner_back: float = 0.75, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 3-D mesh of a cuboid with a cuboidal hole.

Delegates to gen_hollow_cube().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • outer_left (float, optional) -- the left boundary of the outer cube, default: 0.0

  • outer_right (float, optional) -- the right boundary of the outer cube, default: 1.0

  • outer_bottom (float, optional) -- the bottom boundary of the outer cube, default: 0.0

  • outer_top (float, optional) -- the top boundary of the outer cube, default: 1.0

  • outer_front (float, optional) -- the front boundary of the outer cube, default: 0.0

  • outer_back (float, optional) -- the back boundary of the outer cube, default: 1.0

  • inner_left (float, optional) -- the left boundary of the inner cube, default: 0.25

  • inner_right (float, optional) -- the right boundary of the inner cube, default: 0.75

  • inner_bottom (float, optional) -- the bottom boundary of the inner cube, default: 0.25

  • inner_top (float, optional) -- the top boundary of the inner cube, default: 0.75

  • inner_front (float, optional) -- the front boundary of the inner cube, default: 0.25

  • inner_back (float, optional) -- the back boundary of the inner cube, default: 0.75

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_hollow_cube(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r: float = 1.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 3-D mesh of a solid ball (filled sphere).

Delegates to gen_sphere().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • cx (float, optional) -- the x coordinate of the center of the sphere, default: 0.0

  • cy (float, optional) -- the y coordinate of the center of the sphere, default: 0.0

  • cz (float, optional) -- the z coordinate of the center of the sphere, default: 0.0

  • r (float, optional) -- the radius of the sphere, default: 1.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_sphere(). default: None

Returns:

the mesh object

Return type:

Mesh

static gen_hollow_sphere(chara_length: float = 0.1, order: int = 1, cx: float = 0.0, cy: float = 0.0, cz: float = 0.0, r_inner: float = 1.0, r_outer: float = 2.0, visualize: bool = False, cache_path: str | None = None) → Mesh [source]

Generate a 3-D mesh of a spherical shell (ball with a concentric spherical cavity).

Delegates to gen_hollow_sphere().

Parameters:
  • chara_length (float, optional) -- the characteristic length of the mesh, default: 0.1

  • order (int, optional) -- the order of the basis function, default: 1

  • cx (float, optional) -- the x coordinate of the center of the sphere, default: 0.0

  • cy (float, optional) -- the y coordinate of the center of the sphere, default: 0.0

  • cz (float, optional) -- the z coordinate of the center of the sphere, default: 0.0

  • r_inner (float, optional) -- the inner radius of the sphere, default: 1.0

  • r_outer (float, optional) -- the outer radius of the sphere, default: 2.0

  • visualize (bool, optional) -- whether to visualize the mesh, default: False

  • cache_path (str, optional) -- the path where the generated mesh is cached; if None, the path is chosen by gen_hollow_sphere(). default: None

Returns:

the mesh object

Return type:

Mesh

Graph algorithms

Helpers used internally by DistributedMesh for race-free parallel assembly and domain decomposition.

graph_coloring(adjacency: SparseMatrix, max_iter: int = 100) → Tensor [source]

Parallel graph coloring using iterative conflict resolution, designed to run efficiently on GPU.

Parameters:
  • adjacency (SparseMatrix) -- The adjacency matrix of the graph. shape: [n_nodes, n_nodes]

  • max_iter (int) -- Maximum number of conflict resolution iterations.

Returns:

integer tensor of shape [n_nodes] containing the color ID of each node.

Return type:

Tensor
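Iterative conflict resolution can be sketched as alternating rounds: every pending node speculatively picks the smallest color not used by its neighbors in a snapshot taken before the round (as if all nodes acted in parallel), and then any pair of neighbors that picked the same color is detected and the higher-indexed node retries. The sketch below simulates those rounds sequentially in plain Python with adjacency lists in place of a SparseMatrix; iterative_coloring is a hypothetical helper, the first-fit color choice and the index-based tie-break are assumptions, and tensormesh's GPU kernel may resolve conflicts differently.

```python
from typing import List


def iterative_coloring(adj: List[List[int]], max_iter: int = 100) -> List[int]:
    """Speculative coloring with iterative conflict resolution (sequential simulation)."""
    n = len(adj)
    color = [-1] * n
    pending = set(range(n))
    for _ in range(max_iter):
        if not pending:
            return color
        # Speculative phase: all pending nodes choose from the same snapshot.
        snapshot = list(color)
        for v in pending:
            used = {snapshot[u] for u in adj[v]}
            c = 0
            while c in used:
                c += 1
            color[v] = c
        # Conflict phase: if two neighbors picked the same color, the higher index retries.
        pending = {v for v in range(n) for u in adj[v] if u < v and color[u] == color[v]}
    if pending:
        raise RuntimeError("coloring did not converge within max_iter")
    return color


triangle = [[1, 2], [0, 2], [0, 1]]
print(iterative_coloring(triangle))  # [0, 1, 2]
```

Because the lowest-indexed pending node can never lose a tie, at least one node is settled per round, so a graph with n nodes converges within n rounds; max_iter only guards against pathological inputs.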

graph_partition(adjacency: SparseMatrix, n_parts: int, method: str = 'spectral') → Tensor [source]

Partition a graph into balanced subdomains using GPU-accelerated algorithms.

This function divides graph nodes into n_parts groups such that:

  1. Each partition has approximately the same number of nodes (load balance)

  2. The number of edges crossing partitions is minimized (minimal interface)
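The two goals above can be measured directly: partition sizes for load balance, and the number of cut edges (edges whose endpoints receive different labels) for the interface size. A minimal plain-Python sketch, assuming an undirected edge list with each edge listed once; partition_quality is a hypothetical helper, not part of tensormesh.

```python
from collections import Counter
from typing import Dict, List, Tuple


def partition_quality(edges: List[Tuple[int, int]], labels: List[int]) -> Tuple[Dict[int, int], int]:
    """Return (partition sizes, number of cut edges) for an undirected edge list."""
    sizes = dict(Counter(labels))
    # An edge is "cut" when its endpoints land in different partitions.
    cut = sum(1 for u, v in edges if labels[u] != labels[v])
    return sizes, cut


# Path 0-1-2-3 split in the middle: balanced parts and a single cut edge.
path_edges = [(0, 1), (1, 2), (2, 3)]
print(partition_quality(path_edges, [0, 0, 1, 1]))  # ({0: 2, 1: 2}, 1)
```

The alternating labeling [0, 1, 0, 1] is just as balanced but cuts all three edges, which is exactly what the second criterion penalizes.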

Parameters:
  • adjacency (SparseMatrix) -- The adjacency matrix of the graph. Shape: [n_nodes, n_nodes]. Should be symmetric for undirected graphs.

  • n_parts (int) -- Number of partitions to create. The spectral method works best when n_parts is a power of 2.

  • method (str, optional) --

    Partitioning algorithm to use:

    • 'spectral': Recursive Spectral Bisection using Fiedler vector. Computed via torch.lobpcg on GPU. Default method.

    • 'metis': Uses pymetis library (requires pymetis installation). Falls back to spectral if pymetis is not available.

Returns:

Integer tensor of shape [n_nodes] containing partition labels in the range [0, n_parts - 1].

Return type:

Tensor

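Notes

The spectral method computes the Fiedler vector (the eigenvector belonging to the second-smallest eigenvalue of the graph Laplacian \(L = D - A\), where \(D\) is the diagonal degree matrix and \(A\) the adjacency matrix) and recursively bisects the graph at the median of its entries, which keeps the two halves balanced and preserves locality in the partitioning.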

For small subgraphs (< 2048 nodes), dense eigensolvers are used for stability. For larger graphs, LOBPCG is used for efficiency.
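The spectral method described above can be sketched with NumPy using the dense path only: build \(L = D - A\) for the induced subgraph, take the eigenvector of the second-smallest eigenvalue with numpy.linalg.eigh, and split at the median by sorting the Fiedler values so the halves differ by at most one node. This is an illustration under stated assumptions (a dense NumPy adjacency matrix, n_parts a power of 2, at least n_parts nodes), not tensormesh's implementation; the LOBPCG path for large graphs is omitted, and both function names are hypothetical.

```python
import numpy as np


def spectral_bisect(adj: np.ndarray, nodes: np.ndarray):
    """Split `nodes` in two at the median of the Fiedler vector of their induced subgraph."""
    sub = adj[np.ix_(nodes, nodes)]
    laplacian = np.diag(sub.sum(axis=1)) - sub  # L = D - A
    _, eigvecs = np.linalg.eigh(laplacian)      # eigenvalues in ascending order
    fiedler = eigvecs[:, 1]                     # eigenvector of the second-smallest eigenvalue
    order = np.argsort(fiedler, kind="stable")
    half = len(nodes) // 2
    return nodes[order[:half]], nodes[order[half:]]


def recursive_spectral_partition(adj: np.ndarray, n_parts: int) -> np.ndarray:
    """Recursive spectral bisection; this sketch assumes n_parts is a power of 2."""
    if n_parts < 1 or n_parts & (n_parts - 1):
        raise ValueError("n_parts must be a power of 2 in this sketch")
    if n_parts > adj.shape[0]:
        raise ValueError("need at least n_parts nodes")
    groups = [np.arange(adj.shape[0])]
    while len(groups) < n_parts:
        groups = [half for g in groups for half in spectral_bisect(adj, g)]
    labels = np.empty(adj.shape[0], dtype=np.int64)
    for label, group in enumerate(groups):
        labels[group] = label
    return labels


# Path graph 0-1-...-7 as a dense symmetric adjacency matrix.
n = 8
adj = np.zeros((n, n))
idx = np.arange(n - 1)
adj[idx, idx + 1] = adj[idx + 1, idx] = 1.0
print(recursive_spectral_partition(adj, 4))
```

On the path graph the Fiedler vector is monotone along the path, so each bisection cuts it in the middle and the four parts come out as consecutive pairs of nodes; the numeric labels may appear in any order because the eigenvector's sign is arbitrary.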

Examples

>>> import tensormesh
>>> from tensormesh.mesh import graph_partition
>>> mesh = tensormesh.Mesh.gen_rectangle()
>>> # Partition the element adjacency graph into 4 parts
>>> adj = mesh.element_adjacency()
>>> labels = graph_partition(adj, n_parts=4, method='spectral')
>>> print(f"Partition sizes: {[(labels == i).sum().item() for i in range(4)]}")