tensormesh.mesh.coloring 源代码
import torch
import numpy as np
from .. import sparse
[文档]
def graph_coloring(adjacency: sparse.SparseMatrix, max_iter: int = 100) -> torch.Tensor:
"""
Parallel graph coloring algorithm (Iterative Conflict Resolution).
Runs efficiently on GPU.
Parameters
----------
adjacency : SparseMatrix
The adjacency matrix of the graph. shape: [n_nodes, n_nodes]
max_iter : int
Maximum number of conflict resolution iterations.
Returns
-------
torch.Tensor
IntTensor of shape [n_nodes] containing the color ID for each node.
"""
n_nodes = adjacency.shape[0]
device = adjacency.device
# Random weights for conflict resolution (fixed throughout)
# If two neighbors pick the same color, the one with lower weight yields.
node_weights = torch.rand(n_nodes, device=device)
# Initialize colors (0 for all)
colors = torch.zeros(n_nodes, dtype=torch.long, device=device)
# Get edge list for efficient neighbor lookup
# edges: [2, n_edges]
edges = adjacency.edges
u, v = edges[0], edges[1]
# Remove self-loops if any
mask = u != v
u, v = u[mask], v[mask]
# Main loop
for i in range(max_iter):
# 1. Detect conflicts
# Check edges where endpoints have same color
color_u = colors[u]
color_v = colors[v]
conflict_mask = (color_u == color_v)
if not conflict_mask.any():
break
# Get conflicting edges
conf_u = u[conflict_mask]
conf_v = v[conflict_mask]
# 2. Resolve conflicts
# In a conflict pair (u, v), the one with lower weight must change
# We identify nodes that need to change
weight_u = node_weights[conf_u]
weight_v = node_weights[conf_v]
# Nodes to update: u where weight[u] < weight[v]
update_mask_u = weight_u < weight_v
nodes_to_update = conf_u[update_mask_u]
# Also handle the other side (v < u) implicitly?
# Since edges are symmetric in sparse matrix usually, (u,v) and (v,u) exist.
# But if adjacency stores symmetric edges explicitly, we process both.
# If adjacency is strictly upper/lower triangular, we need logic.
# Assuming adjacency is symmetric (contains both u->v and v->u).
if nodes_to_update.numel() == 0:
# Should not happen if conflicts exist and weights are random floats (low collision prob)
break
unique_nodes_to_update = torch.unique(nodes_to_update)
# 3. Re-color conflicting nodes from a small random palette.
# The palette [0, limit) widens slowly with i to guarantee convergence
# while keeping the total color count low (~5-8 for planar graphs).
# We only need *a* valid coloring for FEM assembly; minimizing color
# count is secondary (though fewer colors → better GPU occupancy).
limit = 6 + (i // 5)
new_colors = torch.randint(0, limit, (unique_nodes_to_update.shape[0],), device=device)
colors[unique_nodes_to_update] = new_colors
return colors