Source code for tensormesh.distributed.assemble

"""
Distributed assembly for multi-GPU FEM computation.

Provides functions to assemble element matrices and node vectors in parallel
across multiple GPUs, then combine results for distributed solving via torch-sla.

Strategy:
  Phase 1 (sequential): create assemblers + pre-warm all lazy properties
           (CUDA lazy init and Transformation caching are NOT thread-safe)
  Phase 2 (parallel threads): run asm() on each device concurrently
           (pure computation, all lazy state already cached)
  Phase 3 (sequential): collect results and merge COO
"""

import threading
import torch
from typing import Type, Optional, Dict, List, Tuple

from ..assemble import ElementAssembler, NodeAssembler
from ..sparse import SparseMatrix

from .mesh import DistributedMesh

try:
    from torch_sla import DSparseTensor
    HAS_DSPARSE = True
except ImportError:
    HAS_DSPARSE = False


# ─── Warmup helpers ─────────────────────────────────────────────────

def _warmup_assembler(asm: ElementAssembler):
    """Force-evaluate all lazy cached properties in Transformation.

    Transformation stores shape_val, shape_grad, JxW etc. as lazy
    ``@property`` that compute-on-first-access and cache in _buffers.
    If these are first triggered inside a thread, PyTorch's internal
    lazy init (CUDA, vmap, etc.) can race.  Calling them here in the
    main thread makes the subsequent threaded asm() purely arithmetic.
    """
    for element_type in asm.element_types:
        trans = asm.transformation[element_type]
        # Access lazy properties to force caching
        _ = trans.shape_val
        _ = trans.shape_grad
        _ = trans.JxW


# ─── Core implementation ────────────────────────────────────────────

def _element_assemble_all(
    assembler_cls: Type[ElementAssembler],
    dmesh: DistributedMesh,
    quadrature_order: int,
    project: str,
    assembler_kwargs: dict,
    call_kwargs: dict,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare assemblers sequentially, run assembly in parallel threads.

    Returns merged global COO (values, row, col) on CPU.
    """
    n = dmesh.num_partitions
    assemblers = []  # (asm, orig_nid_cpu, submesh) or None

    # ── Phase 1: sequential setup + warmup (main thread) ──
    for i in range(n):
        submesh = dmesh.submeshes[i]
        if submesh is None:
            assemblers.append(None)
            continue

        device = dmesh.devices[i]
        orig_nid = submesh.point_data['orig_nid'].clone()
        submesh.to(device)

        asm = assembler_cls.from_mesh(
            submesh, quadrature_order=quadrature_order,
            project=project, **assembler_kwargs,
        )
        _warmup_assembler(asm)
        assemblers.append((asm, orig_nid, submesh))

    # ── Phase 2: parallel assembly (threads) ──
    results: List[Optional[Tuple]] = [None] * n
    errors: List[Tuple[int, Exception]] = []

    def _worker(i):
        try:
            asm, orig_nid, submesh = assemblers[i]
            device = asm.device
            if device.type == 'cuda':
                torch.cuda.set_device(device)

            K_local: SparseMatrix = asm(**call_kwargs)

            global_row = orig_nid[K_local.row.cpu()]
            global_col = orig_nid[K_local.col.cpu()]
            values = K_local.edata.cpu()
            results[i] = (values, global_row, global_col)
        except Exception as e:
            errors.append((i, e))

    threads = []
    for i in range(n):
        if assemblers[i] is None:
            results[i] = (
                torch.tensor([], dtype=torch.float64),
                torch.tensor([], dtype=torch.long),
                torch.tensor([], dtype=torch.long),
            )
            continue
        t = threading.Thread(target=_worker, args=(i,))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    # ── Phase 3: collect & merge ──
    # Move submeshes back to CPU
    for item in assemblers:
        if item is not None:
            item[2].cpu()

    if errors:
        msgs = [f"Partition {pid}: {e}" for pid, e in errors]
        raise RuntimeError(
            f"Distributed assembly failed on {len(errors)} partition(s):\n"
            + "\n".join(msgs)
        )

    all_values, all_row, all_col = [], [], []
    for r in results:
        assert r is not None
        vals, row, col = r
        if vals.numel() > 0:
            all_values.append(vals)
            all_row.append(row)
            all_col.append(col)

    if not all_values:
        raise RuntimeError("All partitions produced empty assemblies")

    return torch.cat(all_values), torch.cat(all_row), torch.cat(all_col)


def _node_assemble_all(
    assembler_cls: Type[NodeAssembler],
    dmesh: DistributedMesh,
    quadrature_order: int,
    project: str,
    assembler_kwargs: dict,
    call_kwargs: dict,
) -> torch.Tensor:
    """Prepare node assemblers sequentially, run in parallel threads.

    Returns global vector on CPU.
    """
    n = dmesh.num_partitions
    assemblers = []      # (asm, orig_nid, submesh, local_call_kwargs) or None

    # ── Phase 1: sequential setup + warmup ──
    for i in range(n):
        submesh = dmesh.submeshes[i]
        if submesh is None:
            assemblers.append(None)
            continue

        device = dmesh.devices[i]
        orig_nid = submesh.point_data['orig_nid'].clone()
        submesh.to(device)

        asm = assembler_cls.from_mesh(
            submesh, quadrature_order=quadrature_order,
            project=project, **assembler_kwargs,
        )
        # Warmup lazy properties
        for et in asm.element_types:
            trans = asm.transformation[et]
            _ = trans.shape_val
            _ = trans.shape_grad
            _ = trans.JxW

        # Remap global point_data → local
        lkw = dict(call_kwargs)
        if 'point_data' in lkw and lkw['point_data'] is not None:
            local_pd = {}
            for k, v in lkw['point_data'].items():
                local_pd[k] = v[orig_nid].to(device)
            lkw['point_data'] = local_pd

        assemblers.append((asm, orig_nid, submesh, lkw))

    # ── Phase 2: parallel assembly ──
    node_results: List[Optional[torch.Tensor]] = [None] * n
    errors: List[Tuple[int, Exception]] = []

    def _worker(i):
        try:
            asm, _, submesh, lkw = assemblers[i]
            device = asm.device
            if device.type == 'cuda':
                torch.cuda.set_device(device)
            node_results[i] = asm(**lkw).cpu()
        except Exception as e:
            errors.append((i, e))

    threads = []
    for i in range(n):
        if assemblers[i] is None:
            continue
        t = threading.Thread(target=_worker, args=(i,))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()

    # Move submeshes back
    for item in assemblers:
        if item is not None:
            item[2].cpu()

    if errors:
        msgs = [f"Partition {pid}: {e}" for pid, e in errors]
        raise RuntimeError(
            f"Distributed node assembly failed on {len(errors)} partition(s):\n"
            + "\n".join(msgs)
        )

    # ── Phase 3: gather into global vector ──
    first_i = next(i for i in range(n) if node_results[i] is not None)
    first_result = node_results[first_i]
    local_n_points = assemblers[first_i][2].n_points
    dof_per_point = first_result.shape[0] // local_n_points

    N = dmesh.n_global_points
    f_global = torch.zeros(N * dof_per_point, dtype=first_result.dtype)

    for i in range(n):
        if assemblers[i] is None or node_results[i] is None:
            continue
        orig_nid = assemblers[i][1]
        f_local = node_results[i]

        if dof_per_point == 1:
            f_global.scatter_add_(0, orig_nid, f_local)
        else:
            f_local_2d = f_local.view(-1, dof_per_point)
            for h in range(dof_per_point):
                idx = orig_nid * dof_per_point + h
                f_global.scatter_add_(0, idx, f_local_2d[:, h])

    return f_global


# ─── Public API ─────────────────────────────────────────────────────

[docs] def distributed_element_assemble( assembler_cls: Type[ElementAssembler], dmesh: DistributedMesh, quadrature_order: int = 2, project: str = 'reduce', call_kwargs: Optional[dict] = None, **assembler_kwargs, ) -> "DSparseTensor": """Assemble element matrix in parallel across multiple devices. Assemblers are created sequentially (for CUDA thread-safety), then assembly computation runs in parallel threads on separate GPUs. Parameters ---------- assembler_cls : Type[ElementAssembler] The assembler class (e.g., ``LaplaceElementAssembler``). dmesh : DistributedMesh Partitioned mesh with device assignments. quadrature_order : int, optional Quadrature order for integration. Default: 2. project : str, optional Projection method: ``'reduce'`` or ``'sparse'``. Default: ``'reduce'``. call_kwargs : dict, optional Extra keyword arguments passed to ``assembler.__call__()`` (e.g., ``point_data``, ``scalar_data``). **assembler_kwargs Extra keyword arguments passed to ``assembler_cls.from_mesh()``. Returns ------- DSparseTensor Distributed sparse matrix ready for distributed solve. """ if not HAS_DSPARSE: raise ImportError( "torch-sla with DSparseTensor support is required.\n" "Install with: pip install torch-sla>=0.2.0" ) if call_kwargs is None: call_kwargs = {} global_values, global_row, global_col = _element_assemble_all( assembler_cls, dmesh, quadrature_order, project, assembler_kwargs, call_kwargs, ) N = dmesh.n_global_points if global_values.dim() > 1: global_values, global_row, global_col, N = _expand_block_coo( global_values, global_row, global_col, N ) return DSparseTensor( global_values, global_row, global_col, shape=(N, N), num_partitions=dmesh.num_partitions, coords=dmesh.global_mesh.points.cpu(), partition_method='rcb', )
[docs] def distributed_element_assemble_to_sparse( assembler_cls: Type[ElementAssembler], dmesh: DistributedMesh, quadrature_order: int = 2, project: str = 'reduce', call_kwargs: Optional[dict] = None, **assembler_kwargs, ) -> SparseMatrix: """Assemble element matrix in parallel, returning a global SparseMatrix. Same as :func:`distributed_element_assemble` but returns a standard :class:`~tensormesh.sparse.SparseMatrix` instead of torch-sla's ``DSparseTensor``. """ if call_kwargs is None: call_kwargs = {} global_values, global_row, global_col = _element_assemble_all( assembler_cls, dmesh, quadrature_order, project, assembler_kwargs, call_kwargs, ) N = dmesh.n_global_points if global_values.dim() > 1: return SparseMatrix.from_block_coo( global_values, global_row, global_col, shape=(N, N) ) return SparseMatrix(global_values, global_row, global_col, shape=(N, N))
[docs] def distributed_node_assemble( assembler_cls: Type[NodeAssembler], dmesh: DistributedMesh, quadrature_order: int = 2, project: str = 'reduce', point_data: Optional[Dict[str, torch.Tensor]] = None, call_kwargs: Optional[dict] = None, **assembler_kwargs, ) -> torch.Tensor: """Assemble node vector (RHS) in parallel across multiple devices.""" if call_kwargs is None: call_kwargs = {} if point_data is not None: call_kwargs['point_data'] = point_data return _node_assemble_all( assembler_cls, dmesh, quadrature_order, project, assembler_kwargs, call_kwargs, )
def _expand_block_coo( values: torch.Tensor, row: torch.Tensor, col: torch.Tensor, n_points: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: """Expand block COO [nnz, dof, dof] to scalar COO.""" nnz = values.shape[0] dof = values.shape[1] N = n_points * dof block_offsets_i = torch.arange(dof, device=values.device) block_offsets_j = torch.arange(dof, device=values.device) row_expanded = (row[:, None, None] * dof + block_offsets_i[None, :, None]).expand(nnz, dof, dof) col_expanded = (col[:, None, None] * dof + block_offsets_j[None, None, :]).expand(nnz, dof, dof) return values.reshape(-1), row_expanded.reshape(-1), col_expanded.reshape(-1), N