import torch
import numpy as np
import meshio
from typing import Optional, Dict, List, Any, Tuple
from .. import sparse
# ─── Fast coordinate-based partitioning ─────────────────────────────
def _rcb_partition_elements(centroids: torch.Tensor, n_parts: int) -> torch.Tensor:
    """Recursive Coordinate Bisection (RCB) on element centroids.

    The element set is split recursively along the longest axis of its
    bounding box.  Each split is rank-based: elements are sorted along the
    chosen axis and the first ``ceil(count * left_parts / parts)`` of them go
    to the left side, so the sizes stay proportional to the number of parts
    assigned to each side.  This keeps partitions balanced when ``n_parts``
    is not a power of two and when many centroids share the same coordinate.

    Parameters
    ----------
    centroids : torch.Tensor
        Element centroids, shape ``[n_elements, dim]``.
    n_parts : int
        Number of partitions.  Values ``<= 1`` put every element in part 0.

    Returns
    -------
    torch.Tensor
        Integer (``torch.long``) labels of shape ``[n_elements]`` with values
        in ``[0, n_parts-1]``, allocated on ``centroids.device``.  When there
        are fewer elements than parts, some labels are simply unused.
    """
    n = centroids.shape[0]
    device = centroids.device
    # Keep the bookkeeping tensors on the same device as the centroids so
    # that boolean/index selections never mix CPU and accelerator tensors.
    labels = torch.zeros(n, dtype=torch.long, device=device)

    # Work stack of (element indices, first label of the range, parts to make).
    stack: List[Tuple[torch.Tensor, int, int]] = [
        (torch.arange(n, device=device), 0, n_parts)
    ]
    while stack:
        indices, offset, parts_remaining = stack.pop()
        count = indices.numel()
        if parts_remaining <= 1 or count == 0:
            labels[indices] = offset
            continue

        # Split along the longest axis of the bounding box.
        pts = centroids[indices]
        extent = pts.max(dim=0).values - pts.min(dim=0).values
        axis = int(extent.argmax())

        left_parts = parts_remaining // 2
        right_parts = parts_remaining - left_parts

        # Rank-based cut: a stable sort makes ties deterministic (original
        # order is kept) and guarantees the requested sizes even when all
        # coordinates along the axis coincide.
        order = torch.sort(pts[:, axis], stable=True).indices
        n_left = (count * left_parts + parts_remaining - 1) // parts_remaining

        stack.append((indices[order[:n_left]], offset, left_parts))
        stack.append((indices[order[n_left:]], offset + left_parts, right_parts))
    return labels
def _compute_element_centroids(mesh) -> Tuple[torch.Tensor, List[str], List[int]]:
    """Compute the centroid of every top-dimension element of ``mesh``.

    The element types are looked up in ``mesh.dim2eletyp`` under
    ``mesh.max_dim`` (falling back to ``mesh.dim``); a single type name is
    wrapped into a one-element list.  A centroid is the mean of the
    coordinates of the nodes referenced by the element's connectivity row.

    Returns
    -------
    centroids : torch.Tensor
        Centroids of all elements, concatenated type by type in the order of
        ``target_types``; shape ``[total_elements, dim]``.
    target_types : List[str]
        The element type names that were processed.
    cell_counts : List[int]
        Number of elements of each type, aligned with ``target_types``.
    """
    top_dim = getattr(mesh, 'max_dim', mesh.dim)
    target_types = mesh.dim2eletyp[top_dim]
    if isinstance(target_types, str):
        target_types = [target_types]

    nodes = mesh.points  # [n_points, dim]
    cell_counts = []
    per_type_centroids = []
    for etype in target_types:
        connectivity = mesh.cells[etype]  # [n_elem, n_basis]
        cell_counts.append(connectivity.shape[0])
        # Gather node coordinates per element, then average over its nodes.
        per_type_centroids.append(nodes[connectivity].mean(dim=1))  # [n_elem, dim]

    centroids = torch.cat(per_type_centroids, dim=0)  # [total_elements, dim]
    return centroids, target_types, cell_counts
def _spectral_bisection_gpu(adjacency: sparse.SparseMatrix, indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Bisect the subgraph induced by ``indices`` using its Fiedler vector.

    The Laplacian :math:`L = D - A` of the induced subgraph is assembled from
    ``adjacency.row`` / ``adjacency.col`` / ``adjacency.edata``.  For fewer
    than 2048 nodes a dense ``torch.linalg.eigh`` is used, otherwise
    ``torch.lobpcg`` on a sparse COO Laplacian.  Nodes are then ranked by
    their Fiedler-vector entry and the ranking is cut in the middle, so the
    two halves differ in size by at most one node even when many entries tie
    (e.g. disconnected or edgeless subgraphs).

    Parameters
    ----------
    adjacency : SparseMatrix
        Graph adjacency; only ``shape``, ``device``, ``row``, ``col`` and
        ``edata`` are read.  Non-floating edge weights are converted to the
        default float dtype because the eigensolvers require floating input.
    indices : torch.Tensor
        1D tensor of global node indices forming the subgraph.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        ``(part0_global, part1_global)`` — two 1D index tensors, each sorted
        ascending, that together cover ``indices``.  ``part0`` receives
        ``ceil(n / 2)`` nodes.  The second part is empty only when the input
        has fewer than two nodes.

    Notes
    -----
    If the eigensolver raises, a ``RuntimeWarning`` is emitted and the nodes
    are split deterministically in index order instead.
    """
    n = indices.shape[0]
    device = adjacency.device
    empty = torch.empty(0, dtype=indices.dtype, device=device)
    if n == 0:
        return empty, empty
    if n == 1:
        return indices, empty

    # Keep only edges whose two endpoints both belong to the subgraph.
    in_subgraph = torch.zeros(adjacency.shape[0], dtype=torch.bool, device=device)
    in_subgraph[indices] = True
    row, col = adjacency.row, adjacency.col
    edge_mask = in_subgraph[row] & in_subgraph[col]
    sub_row = row[edge_mask]
    sub_col = col[edge_mask]
    sub_data = adjacency.edata[edge_mask]
    if not (sub_data.is_floating_point() or sub_data.is_complex()):
        # Integer/bool weights would make eigh/lobpcg raise.
        sub_data = sub_data.to(torch.get_default_dtype())

    # Relabel global node ids to local ids in [0, n-1] via the sorted id list.
    indices_sorted = torch.sort(indices).values
    sub_row_mapped = torch.searchsorted(indices_sorted, sub_row)
    sub_col_mapped = torch.searchsorted(indices_sorted, sub_col)
    sub_indices = torch.stack([sub_row_mapped, sub_col_mapped])

    # Weighted degree = row sum of the sub-adjacency (duplicates accumulate).
    degrees = torch.zeros(n, dtype=sub_data.dtype, device=device)
    degrees.index_add_(0, sub_row_mapped, sub_data)

    try:
        if n < 2048:
            # Small problem: dense symmetric eigensolver for stability.
            A_dense = torch.sparse_coo_tensor(sub_indices, sub_data, (n, n)).coalesce().to_dense()
            L_dense = torch.diag(degrees) - A_dense
            _, vecs = torch.linalg.eigh(L_dense)
        else:
            # Large problem: sparse Laplacian (diagonal + negated adjacency).
            diag_ids = torch.arange(n, device=device)
            L_sparse = torch.sparse_coo_tensor(
                torch.cat([torch.stack([diag_ids, diag_ids]), sub_indices], dim=1),
                torch.cat([degrees, -sub_data]),
                (n, n),
            ).coalesce()
            _, vecs = torch.lobpcg(L_sparse, k=2, largest=False)
        # Second-smallest eigenpair -> Fiedler vector; rank nodes by it.
        order = torch.sort(vecs[:, 1], stable=True).indices
    except Exception as exc:
        # Best-effort: never fail the partitioning because of the eigensolver,
        # but make the degradation visible and keep the result reproducible.
        import warnings
        warnings.warn(
            f"Spectral bisection eigensolver failed ({exc}); "
            "falling back to an index-order split.",
            RuntimeWarning,
            stacklevel=2,
        )
        order = torch.arange(n, device=device)

    # Balanced cut of the ranking; sort each side so outputs are ascending.
    n_first = (n + 1) // 2
    part0_local = torch.sort(order[:n_first]).values
    part1_local = torch.sort(order[n_first:]).values
    return indices_sorted[part0_local], indices_sorted[part1_local]
# [docs]  (Sphinx "view source" link marker left over from HTML extraction)
def graph_partition(adjacency: sparse.SparseMatrix, n_parts: int, method: str = 'spectral') -> torch.Tensor:
    """
    Partition a graph into balanced subdomains using GPU-accelerated algorithms.

    This function divides graph nodes into ``n_parts`` groups such that:

    1. Each partition has approximately equal number of nodes (load balance)
    2. The number of edges crossing partitions is minimized (minimal interface)

    Parameters
    ----------
    adjacency : SparseMatrix
        The adjacency matrix of the graph. Shape: ``[n_nodes, n_nodes]``.
        Should be symmetric for undirected graphs.
    n_parts : int
        Number of partitions to create. For spectral method, works best
        when ``n_parts`` is a power of 2.
    method : str, optional
        Partitioning algorithm to use:

        - ``'spectral'``: Recursive Spectral Bisection using Fiedler vector.
          Computed via ``torch.lobpcg`` on GPU. Default method.
        - ``'metis'``: Uses pymetis library (requires pymetis installation).
          Falls back to spectral if pymetis is not available.

    Returns
    -------
    torch.Tensor
        Integer tensor of shape ``[n_nodes]`` containing partition labels
        in range ``[0, n_parts-1]``. With the spectral method, when
        ``n_parts`` exceeds the number of nodes only as many labels as there
        are nodes are used (a single node cannot be split further).

    Raises
    ------
    ValueError
        If ``method`` is neither ``'spectral'`` nor ``'metis'``.

    Notes
    -----
    The spectral method computes the Fiedler vector (second smallest eigenvector
    of the graph Laplacian :math:`L = D - A`) and recursively bisects based on
    median values. This preserves locality in the partitioning.

    For small subgraphs (< 2048 nodes), dense eigensolvers are used for stability.
    For larger graphs, LOBPCG is used for efficiency.

    Examples
    --------
    >>> from tensormesh.mesh import graph_partition
    >>> # Partition element adjacency into 4 parts
    >>> adj = mesh.element_adjacency()
    >>> labels = graph_partition(adj, n_parts=4, method='spectral')
    >>> print(f"Partition sizes: {[(labels == i).sum().item() for i in range(4)]}")
    """
    n_nodes = adjacency.shape[0]
    device = adjacency.device

    if method == 'metis':
        try:
            import pymetis
            adj_csr = adjacency.to_scipy_coo().tocsr()
            # pymetis' ``adjacency`` argument expects an adjacency *list*;
            # a SciPy CSR matrix is passed through its CSR arrays instead.
            # NOTE(review): METIS expects a graph without self-loops —
            # confirm the adjacency has an empty diagonal.
            _, membership = pymetis.part_graph(
                n_parts,
                xadj=adj_csr.indptr.tolist(),
                adjncy=adj_csr.indices.tolist(),
            )
            return torch.tensor(membership, dtype=torch.long, device=device)
        except ImportError:
            print("pymetis not found, falling back to spectral.")
            method = 'spectral'

    if method == 'spectral':
        labels = torch.zeros(n_nodes, dtype=torch.long, device=device)
        # Current list of disjoint parts, each a 1D tensor of node indices.
        parts = [torch.arange(n_nodes, device=device)]
        while len(parts) < n_parts:
            # Always split the largest part (first one on ties).
            split_idx = max(range(len(parts)), key=lambda i: parts[i].numel())
            if parts[split_idx].numel() < 2:
                # Nothing left that can be split: more parts were requested
                # than there are nodes. Stop instead of looping forever.
                break
            indices_to_split = parts.pop(split_idx)

            part0, part1 = _spectral_bisection_gpu(adjacency, indices_to_split)
            if part0.numel() == 0 or part1.numel() == 0:
                # Degenerate bisection: discard it entirely and halve the
                # node list, so that no node ends up in two parts.
                mid = indices_to_split.numel() // 2
                part0, part1 = indices_to_split[:mid], indices_to_split[mid:]
            parts.append(part0)
            parts.append(part1)

        # Assign one label per part.
        for i, indices in enumerate(parts):
            labels[indices] = i
        return labels

    raise ValueError(f"Unknown method {method}")
def partition_mesh(mesh, n_parts: int, method: str = 'coordinate', ghost_nodes: bool = True) -> List[Any]:
    """
    Split a mesh into ``n_parts`` independent submeshes (element-based).

    Every top-dimension element is assigned to exactly one partition; each
    submesh then contains the elements of its partition together with all
    nodes those elements reference, renumbered locally.

    Parameters
    ----------
    mesh : Mesh
        The mesh to partition. Must have elements of at least one type.
    n_parts : int
        Number of partitions to create.
    method : str, optional
        Partitioning algorithm:

        - ``'coordinate'``: Recursive Coordinate Bisection on element
          centroids. Very fast — O(n log n). Default.
        - ``'spectral'``: Recursive Spectral Bisection using Fiedler vector.
          Better partition quality but much slower.
        - ``'metis'``: Uses pymetis (requires installation).

        Anything other than ``'coordinate'`` is forwarded to
        ``mesh.partition(n_parts, method)``.
    ghost_nodes : bool, optional
        Whether to include ghost nodes (shared boundary nodes) in submeshes.
        Currently only ``True`` is supported for element-based partitioning.
        Default: ``True``.

    Returns
    -------
    List[Mesh]
        A list with one entry per partition. Each non-empty entry is a mesh
        of the same class as ``mesh`` containing:

        - Local nodes and elements with renumbered indices
        - ``point_data['orig_nid']``: Tensor mapping local node indices to
          original global node indices (for data exchange between partitions)

        The entry is ``None`` for a partition that received no elements.

    Raises
    ------
    NotImplementedError
        If ``ghost_nodes`` is ``False``.

    Notes
    -----
    Ghost nodes are nodes shared between partitions (on the interface).
    They are duplicated in each partition that uses them, enabling independent
    local computation. The ``orig_nid`` mapping allows reconstruction of
    global solutions and inter-partition communication.

    Examples
    --------
    >>> from tensormesh.mesh.partition import partition_mesh
    >>> submeshes = partition_mesh(mesh, n_parts=4, method='coordinate')
    >>> for i, sub in enumerate(submeshes):
    ...     if sub is not None:
    ...         print(f"Part {i}: {sub.n_nodes} nodes, {sub.n_elements} elements")
    ...         # Access original node IDs for global assembly
    ...         global_ids = sub.point_data['orig_nid']
    """
    if not ghost_nodes:
        raise NotImplementedError("ghost_nodes=False is not yet supported. Element-based partitioning always implies ghost nodes.")

    # Step 1: one partition label per top-dimension element.
    if method == 'coordinate':
        # Fast path: centroid-based RCB, no element adjacency required.
        centroids, target_types, cell_counts = _compute_element_centroids(mesh)
        element_labels = _rcb_partition_elements(centroids, n_parts)
    else:
        # Slow path: the mesh computes the labels (spectral / metis).
        element_labels = mesh.partition(n_parts, method)
        top_dim = getattr(mesh, 'max_dim', mesh.dim)
        target_types = mesh.dim2eletyp[top_dim]
        if isinstance(target_types, str):
            target_types = [target_types]
        cell_counts = [mesh.cells[etype].shape[0] for etype in target_types]

    # Step 2: the flat label vector is ordered type by type; cut it into one
    # slice per element type.
    labels_per_type = {}
    start = 0
    for etype, count in zip(target_types, cell_counts):
        stop = start + count
        labels_per_type[etype] = element_labels[start:stop]
        start = stop

    # Step 3: build one submesh per partition.
    submeshes = []
    for part in range(n_parts):
        # Connectivity (global node ids) of the elements owned by this part.
        owned_cells = {}
        for etype in target_types:
            selected = labels_per_type[etype] == part
            if selected.any():
                owned_cells[etype] = mesh.cells[etype][selected]  # [n_sub, n_nodes_per_elem]

        if not owned_cells:
            # Empty partition
            submeshes.append(None)
            continue

        # Global ids of every node referenced by this part. ``torch.unique``
        # returns them in ascending order, which ``searchsorted`` relies on.
        used_nodes = torch.unique(
            torch.cat([conn.reshape(-1) for conn in owned_cells.values()])
        )
        local_points = mesh.points[used_nodes]

        # Global -> local renumbering: position of each id inside used_nodes.
        cell_blocks = [
            (etype, torch.searchsorted(used_nodes, conn).cpu().numpy())
            for etype, conn in owned_cells.items()
        ]

        m_io = meshio.Mesh(
            points=local_points.cpu().numpy(),
            cells=cell_blocks
        )

        # Rebuild a mesh of the same class as the input from the meshio object.
        submesh = mesh.__class__(m_io)
        # NOTE(review): the return value of ``to`` is discarded — this assumes
        # ``to`` moves the mesh in place; confirm against the Mesh class.
        submesh.to(mesh.device)
        submesh.point_data['orig_nid'] = used_nodes
        submeshes.append(submesh)

    return submeshes