Source code for tensormesh.assemble.node_assembler

from abc import abstractmethod
import inspect
import math
from typing import Dict, Optional, Callable, Literal

import torch
import torch.nn as nn

from .projector import ReduceProjector, SparseProjector
from ..element import Transformation, element_type2dimension
from ..nn import BufferDict
from ..mesh import Mesh
from ..vmap import vmap


class InputBroadcast:
    """Describe how one ``forward`` argument is treated by each vmap layer.

    For each of the three layers (element, quadrature, v) the instance
    stores the ``in_dims`` entry used for that layer: ``0`` when the
    argument is mapped along its leading axis, ``None`` when it is
    broadcast unchanged.
    """
    element: Optional[int]
    quadrature: Optional[int]
    v: Optional[int]

    def __init__(self, element: bool, quadrature: bool, v: bool):
        # A truthy flag selects axis 0 for that layer; a falsy one broadcasts.
        requested = (("element", element), ("quadrature", quadrature), ("v", v))
        for layer, is_mapped in requested:
            setattr(self, layer, 0 if is_mapped else None)



class NodeAssembler(nn.Module):
    r"""Assemble an element-wise linear form into a global node vector.

    :class:`NodeAssembler` is the linear-form counterpart of
    :class:`ElementAssembler`. Subclasses override :meth:`forward` to define
    the integrand of :math:`l(v) = \int_\Omega f(v)\, \mathrm{d}\Omega`.
    Calling the assembler returns a flattened 1-D :class:`torch.Tensor` of
    length :math:`|\mathcal V|`, or :math:`|\mathcal V| \times H` for
    vector-valued problems with :math:`H` degrees of freedom per node.

    Instances are normally created through one of the class methods:

    * :meth:`from_mesh` -- build from a :class:`~tensormesh.Mesh`.
    * :meth:`from_elements` -- build from raw connectivity tensors.
    * :meth:`from_assembler` -- share topology with another assembler.

    Examples
    --------
    Load vector :math:`f_i = \int_\Omega v_i\, \mathrm{d}\Omega`:

    .. code-block:: python

        import tensormesh

        class OneAssembler(tensormesh.NodeAssembler):
            def forward(self, v):
                return v

        mesh = tensormesh.Mesh.gen_rectangle()
        f = OneAssembler.from_mesh(mesh)(mesh.points)

    Traction-style load
    :math:`f_i = \int_\Omega \mathbf t \cdot v_i\, \mathrm{d}\Omega`:

    .. code-block:: python

        import torch
        import tensormesh
        import tensormesh.functional as F

        class TractionAssembler(tensormesh.NodeAssembler):
            def forward(self, v, t):
                return F.dot(t, v)

        mesh = tensormesh.Mesh.gen_circle()
        t = torch.ones(mesh.n_points, 2)  # unit traction in x, y
        assembler = TractionAssembler.from_mesh(mesh)
        f = assembler(mesh.points, point_data={"t": t})

    Attributes
    ----------
    projector : torch.nn.ModuleDict
        Maps each ``element_type`` to the projector that scatters
        per-element basis contributions of shape
        :math:`[|\mathcal C_e|, B_e]` onto the node vector.
    transformation : torch.nn.ModuleDict
        Maps each ``element_type`` to a :class:`~tensormesh.Transformation`
        providing shape values, shape gradients and ``JxW`` at the
        quadrature points.
    elements : tensormesh.nn.BufferDict
        Maps each ``element_type`` to its connectivity tensor of shape
        :math:`[|\mathcal C|, B]`.
    n_points : int
        Number of mesh points.
    dimension : int
        Spatial dimension associated with the first element type.
    element_types : list[str]
        Element-type strings present in the mesh.
    """
    projector: nn.ModuleDict  # Dict[str, Projector]
    transformation: nn.ModuleDict  # Dict[str, Transformation]
    elements: BufferDict  # Dict[str, torch.Tensor]
    dimension: int
    element_types: list[str]
    n_points: int

    # Members listed for the documentation build.
    __autodoc__ = [
        '__call__',
        'forward',
        '__post_init__',
        'from_assembler',
        'from_mesh',
        'compile',
        'reset_compile',
    ]
def __init__(self,
             projector: nn.ModuleDict,
             transformation: nn.ModuleDict,
             elements: BufferDict,
             *args, **kwargs):
    """Store the mesh topology and run the ``__post_init__`` hook.

    Parameters
    ----------
    projector : torch.nn.ModuleDict
        Per-element-type projector; its keys define ``element_types``.
    transformation : torch.nn.ModuleDict
        Per-element-type transformation objects.
    elements : BufferDict
        Per-element-type connectivity tensors.
    *args, **kwargs
        Forwarded unchanged to :meth:`__post_init__`.
    """
    super().__init__()

    element_types = list(projector.keys())
    # The spatial dimension is looked up from the first element type.
    dimension = element_type2dimension[element_types[0]]

    self.projector = projector
    self.transformation = transformation
    self.elements = elements
    self.dimension = dimension
    self.element_types = element_types
    # The connectivity tensors are [n_element, n_basis], so their leading
    # axis counts elements, not nodes. The node count is taken from the
    # points cached on the transformation (the same attribute ``__call__``
    # falls back to when no points are passed).
    self.n_points = next(iter(transformation.values())).points.shape[0]

    # Compile options
    self._compile: bool = False
    self._compile_options: Dict = {}
    self._compiled_call_fn: Optional[Callable] = None

    self.__post_init__(*args, **kwargs)
def _integrate(self, batch_integral, jxw, n_element, n_basis, use_element_parallel):
    """Weight integrand samples by ``jxw`` and sum over quadrature points.

    Parameters
    ----------
    batch_integral : torch.Tensor
        Lifted ``forward`` output. With ``use_element_parallel`` it must be
        ``[n_element, batch, n_basis]`` (optionally with one trailing
        axis); otherwise ``[batch, n_basis]`` (optionally with one trailing
        axis).
    jxw : torch.Tensor
        Quadrature weights contracted as ``[element, quadrature]``.
    n_element, n_basis : int
        Expected sizes used by the shape checks.
    use_element_parallel : bool
        Whether ``batch_integral`` already carries a leading element axis.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``[n_element, n_basis, ...]``.
    """
    rank = batch_integral.dim()

    if use_element_parallel:
        error_msg = f"the shape returned by forward function is {batch_integral.shape} which is not supported, should either be [{n_element},batch_size,{n_basis}] or [{n_element},batch_size,{n_basis}, dof_per_point]"
        assert rank in (3, 4), error_msg
        assert batch_integral.shape[0] == n_element, error_msg
        assert batch_integral.shape[2] == n_basis, error_msg
        # [n_element, n_basis, ...]
        return torch.einsum("eqb...,eq->eb...", batch_integral, jxw)

    error_msg = f"the shape returned by forward function is {batch_integral.shape} which is not supported, should either be [batch_size,{n_basis}] or [batch_size,{n_basis}, dof_per_point]"
    assert rank in (2, 3), error_msg
    assert batch_integral.shape[1] == n_basis, error_msg
    # The same samples are reused for every element: [n_element, n_basis, ...]
    return torch.einsum("qi...,eq->ei...", batch_integral, jxw)


@property
def device(self) -> torch.device:
    """Device reported by the first stored transformation."""
    first_transformation = next(iter(self.transformation.values()))
    return first_transformation.device


@property
def dtype(self) -> torch.dtype:
    """Dtype reported by the first stored transformation."""
    first_transformation = next(iter(self.transformation.values()))
    return first_transformation.dtype
def type(self, dtype: torch.dtype):
    """Cast the assembler to ``dtype`` and return ``self``.

    Parameters
    ----------
    dtype : torch.dtype
        Either ``torch.float32`` or ``torch.float64``.

    Returns
    -------
    NodeAssembler
        ``self``, after calling ``float()`` or ``double()``.

    Raises
    ------
    ValueError
        If ``dtype`` is neither ``torch.float32`` nor ``torch.float64``.
    """
    if dtype == torch.float64:
        self.double()
    elif dtype == torch.float32:
        self.float()
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(f"the dtype {dtype} is not supported")
    return self
def _build_compiled_call(self, point_data_keys: list, scalar_data_keys: list, element_type: str) -> Callable:
    """Build the broadcast-based (vmap-free) call function for one element type.

    The returned function calls ``forward`` once on whole arrays and relies
    on broadcasting; it is the function that ``__call__`` optionally wraps
    with ``torch.compile``.

    Parameters
    ----------
    point_data_keys : list
        Names of the nodal fields, in the order their tensors will be passed.
    scalar_data_keys : list
        Names of the scalar inputs, in the order their values will be passed.
    element_type : str
        Element type whose transformation, projector and connectivity are used.

    Returns
    -------
    Callable
        ``compiled_call(point_data_tensors, scalar_data_tensors)`` returning
        the flattened node vector.

    Raises
    ------
    ValueError
        If ``forward`` declares a parameter that this path cannot provide.
    """
    trans = self.transformation[element_type]
    proj = self.projector[element_type]
    elements = self.elements[element_type]
    fn = self.forward

    param_keys = list(inspect.signature(fn).parameters.keys())

    # Reject unsupported parameters up front. Skipping them silently would
    # shift the remaining positional arguments passed to ``forward``.
    point_keys = set(point_data_keys)
    scalar_keys = set(scalar_data_keys)
    for key in param_keys:
        supported = (
            key in ("v", "gradv")
            or key in point_keys
            or (key.startswith("grad") and key[4:] in point_keys)
            or key in scalar_keys
        )
        if not supported:
            raise ValueError(
                f"{key} is not supported in compile mode, please use `v`, `gradv` "
                f"or keys provided by point_data or scalar_data"
            )

    # Pre-compute static data.
    # NOTE(review): these tensors are captured once; if a later
    # ``update_points`` changes them, a cached function may keep using the
    # old values -- confirm against Transformation.
    shape_val = trans.batch_shape_val(0, trans.n_quadrature)
    shape_grad, jxw = trans.batch_shape_grad_jxw(
        quadrature_start=0,
        quadrature_batch=trans.n_quadrature,
    )
    # Pre-transpose shape_val for matmul: [n_quad, n_basis] -> [n_basis, n_quad]
    shape_val_T = shape_val.T

    def compiled_call(point_data_tensors: list, scalar_data_tensors: list) -> torch.Tensor:
        """Evaluate ``forward`` on whole arrays and project onto the nodes."""
        # Gather nodal values per element.
        ele_point_data = {k: v[elements] for k, v in zip(point_data_keys, point_data_tensors)}
        scalar_data_dict = dict(zip(scalar_data_keys, scalar_data_tensors))

        args = []
        for key in param_keys:
            if key == "v":
                args.append(shape_val)  # [n_quad, n_basis]
            elif key == "gradv":
                args.append(shape_grad)  # [n_element, n_quad, n_basis, n_dim]
            elif key in ele_point_data:
                # Interpolate to quadrature points:
                # [n_element, n_basis] @ [n_basis, n_quad] -> [n_element, n_quad]
                args.append(torch.matmul(ele_point_data[key], shape_val_T))
            elif key.startswith("grad") and key[4:] in ele_point_data:
                # einsum("eb,eqbd->eqd") written as broadcast + sum.
                nodal = ele_point_data[key[4:]]
                args.append((nodal[:, None, :, None] * shape_grad).sum(dim=2))
            elif key in scalar_data_dict:
                args.append(scalar_data_dict[key])

        # ``forward`` is called once on the full arrays and relies on
        # PyTorch broadcasting.
        batch_integral = fn(*args)  # [n_element, n_quad, n_basis]

        # einsum("eqb,eq->eb") written as broadcast + sum.
        batch_integral = (batch_integral * jxw[:, :, None]).sum(dim=1)  # [n_element, n_basis]

        # Project to nodes
        return proj(batch_integral).flatten()

    return compiled_call


def __call__(self,
             points: Optional[torch.Tensor] = None,
             func: Optional[Callable] = None,
             point_data: Optional[Dict[str, torch.Tensor]] = None,
             scalar_data: Optional[Dict[str, torch.Tensor]] = None,
             batch_size: int = 1) -> torch.Tensor:
    r"""Assemble the linear form into a global node vector.

    Parameters
    ----------
    points : torch.Tensor, optional
        Nodal coordinates of shape :math:`[|\mathcal V|, D]`. If ``None``,
        the points stored on the first transformation are used; otherwise
        every transformation is updated with ``points``.
    func : Callable, optional
        Integrand to use *in place of* :meth:`forward` (vmap path only).
    point_data : dict[str, torch.Tensor], optional
        Nodal fields, each of shape :math:`[|\mathcal V|, ...]`. A key
        ``"f"`` may appear as a ``forward`` parameter ``f`` and as its
        gradient ``gradf``. The dict passed in is not modified; the
        coordinates are added under ``"x"`` on an internal copy.
    scalar_data : dict, optional
        Additional values passed to ``forward`` by name.
    batch_size : int, optional
        Number of quadrature points processed per batch. Defaults to ``1``;
        ``-1`` (or ``None``) processes all quadrature points at once.

    Returns
    -------
    torch.Tensor
        Flattened 1-D tensor of the assembled node values.

    Raises
    ------
    ValueError
        If the integrand declares a parameter that is not ``v``, ``gradv``
        or a key provided by the inputs.
    """
    # Shallow copy so that inserting "x" below does not leak into the
    # caller's dictionary.
    point_data = {} if point_data is None else dict(point_data)
    if scalar_data is None:
        scalar_data = {}

    if points is None:
        points = next(iter(self.transformation.values())).points  # type:ignore [n_point, n_dim]
    else:
        for element_type in self.element_types:
            trans: Transformation = self.transformation[element_type]  # type:ignore
            # The message reads the same attribute the check uses.
            assert points.shape[1] == trans.dim, f"the dimension of the points should be {trans.dim}, but got {points.shape[1]}"
            trans.update_points(points)  # type:ignore
    point_data["x"] = points  # type:ignore

    self = self.type(points.dtype).to(points.device)  # type:ignore

    for key, value in point_data.items():
        assert value.shape[0] == points.shape[0], f"the shape of {key} should be [n_point, ...], but got {value.shape}"

    # Use fast path if enabled (bypasses vmap, uses direct broadcast)
    if self._compile and len(self.element_types) == 1 and func is None:
        element_type = self.element_types[0]
        point_data_keys = sorted(k for k in point_data if k != "x")
        scalar_data_keys = sorted(scalar_data.keys())

        cache_key = f"call_{element_type}_{tuple(point_data_keys)}_{tuple(scalar_data_keys)}"
        if self._compiled_call_fn is None or getattr(self, '_compiled_cache_key', None) != cache_key:
            # Build fast call function (uses broadcast, not vmap)
            raw_fn = self._build_compiled_call(point_data_keys, scalar_data_keys, element_type)
            # Optionally wrap with torch.compile for additional optimization
            if self._compile_options.get("mode") != "disable":
                self._compiled_call_fn = torch.compile(raw_fn, **self._compile_options)
            else:
                self._compiled_call_fn = raw_fn
            self._compiled_cache_key = cache_key

        point_data_tensors = [point_data[k] for k in point_data_keys]
        scalar_data_tensors = [scalar_data[k] for k in scalar_data_keys]
        return self._compiled_call_fn(point_data_tensors, scalar_data_tensors)

    # Original vmap path
    fn = self.forward if func is None else func
    signature = inspect.signature(fn)

    grad_point_keys = {"grad" + key for key in point_data}
    # First matching rule wins.
    # NOTE(review): scalar_data entries are declared as mapped on all three
    # layers; how the project's ``vmap`` treats them is not visible here.
    broadcast_fns = [
        (lambda x: x == "v", InputBroadcast(False, True, True)),  # [: , n_quadrature, n_v_basis]
        (lambda x: x == "gradv", InputBroadcast(True, True, True)),  # [n_element, n_quadrature, n_v_basis, n_dim]
        (lambda x: x in scalar_data, InputBroadcast(True, True, True)),
        (lambda x: x in point_data, InputBroadcast(True, True, False)),  # [n_element, n_quadrature, :]
        (lambda x: x in grad_point_keys, InputBroadcast(True, True, False)),  # [n_element, n_quadrature, :, n_dim]
    ]

    element_dims = []
    quadrature_dims = []
    v_dims = []
    for key in signature.parameters:
        for condition, broadcast in broadcast_fns:
            if condition(key):
                element_dims.append(broadcast.element)
                quadrature_dims.append(broadcast.quadrature)
                v_dims.append(broadcast.v)
                break
        else:
            raise ValueError(f"{key} is not supported, please use `v`, `gradv` or more keys provided by point_data, element_data or scalar_data")
    element_dims = tuple(element_dims)
    quadrature_dims = tuple(quadrature_dims)
    v_dims = tuple(v_dims)

    # The element layer is only added when at least one argument is
    # mapped per element.
    use_element_parallel = any(dim is not None for dim in element_dims)
    parallel_fn = vmap(vmap(fn, in_dims=v_dims), in_dims=quadrature_dims)
    if use_element_parallel:
        parallel_fn = vmap(parallel_fn, in_dims=element_dims)

    integral: Optional[torch.Tensor] = None  # [n_points, ...]
    for element_type in self.element_types:
        trans: Transformation = self.transformation[element_type]  # type:ignore
        proj: ReduceProjector = self.projector[element_type]  # type:ignore
        element_integral: Optional[torch.Tensor] = None  # [n_element, n_basis, ...]

        if batch_size in (-1, None):
            n_batch = 1
            n_batch_size = trans.n_quadrature
        else:
            n_batch_size = batch_size
            n_batch = math.ceil(trans.n_quadrature / batch_size)

        ele_point_data = {k: v[self.elements[element_type]] for k, v in point_data.items()}

        for i in range(n_batch):
            shape_val = trans.batch_shape_val(i * n_batch_size, n_batch_size)
            shape_grad, jxw = trans.batch_shape_grad_jxw(
                quadrature_start=i * n_batch_size,
                quadrature_batch=n_batch_size,
            )

            args = []
            for key in signature.parameters:
                if key == "v":
                    args.append(shape_val)
                elif key == "gradv":
                    args.append(shape_grad)
                elif key in ele_point_data:
                    # point data : [element_batch, quadrature_batch, ...]
                    args.append(torch.einsum("eb...,qb->eq...", ele_point_data[key], shape_val))
                elif key.startswith("grad") and key[4:] in ele_point_data:
                    # grad point data : [element_batch, quadrature_batch, ..., dim]
                    args.append(torch.einsum("eb...,eqbd->eq...d", ele_point_data[key[4:]], shape_grad))
                elif key in scalar_data:  # type:ignore
                    args.append(scalar_data[key])
                else:
                    raise NotImplementedError(f"key {key} is not implemented")

            # [n_element, batch_size, n_basis, ...] or [batch_size, n_basis, ...]
            batch_integral = parallel_fn(*args)
            batch_integral = self._integrate(batch_integral, jxw, trans.n_elements, trans.n_basis, use_element_parallel)

            element_integral = batch_integral if element_integral is None else element_integral + batch_integral

        assert element_integral is not None
        projected = proj(element_integral)
        integral = projected if integral is None else integral + projected

    assert integral is not None
    return integral.flatten()  # [n_points * n_dim]


def __post_init__(self, *args, **kwargs):
    r"""Hook called at the end of ``__init__``; override to precompute data.

    Receives the extra positional and keyword arguments given to the
    constructor. The default implementation does nothing.
    """
    pass
def compile(self,
            mode: Literal["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "disable"] = "disable",
            fullgraph: bool = False,
            dynamic: Optional[bool] = None,
            backend: str = "inductor",
            **kwargs) -> "NodeAssembler":
    r"""Switch the assembler to its broadcast-based fast path.

    Once enabled, ``__call__`` may bypass vmap and evaluate ``forward``
    with direct broadcast operations. With the default ``mode="disable"``
    only the vmap bypass is used; any other mode additionally wraps the
    fast function with ``torch.compile`` using the stored options.

    Examples
    --------
    .. code-block:: python

        # Fast path only, no torch.compile
        assembler = MassAssembler.from_mesh(mesh).compile()

        # Fast path wrapped with torch.compile
        assembler = MassAssembler.from_mesh(mesh).compile(mode="default")

        # Use normally
        result = assembler(point_data={'phi': phi, 'f': f})

        # Go back to the vmap path (e.g. to set breakpoints in forward)
        assembler.reset_compile()

    Parameters
    ----------
    mode : str, optional
        ``"disable"`` (default) skips ``torch.compile``; ``"default"``,
        ``"reduce-overhead"``, ``"max-autotune"`` and
        ``"max-autotune-no-cudagraphs"`` are stored and forwarded to
        ``torch.compile``.
    fullgraph : bool, optional
        Stored and forwarded to ``torch.compile``. Default ``False``.
    dynamic : bool or None, optional
        Stored and forwarded to ``torch.compile``. Default ``None``.
    backend : str, optional
        Stored and forwarded to ``torch.compile``. Default ``"inductor"``.
    **kwargs : dict
        Extra options stored alongside the ones above.

    Returns
    -------
    NodeAssembler
        ``self``, for method chaining.

    See Also
    --------
    reset_compile : Disable fast mode and use the vmap path.
    is_compiled : Check if fast mode is enabled.
    """
    options: Dict = dict(mode=mode, fullgraph=fullgraph, dynamic=dynamic, backend=backend)
    options.update(kwargs)

    # Drop any previously built function so the next call rebuilds it.
    self._compiled_call_fn = None
    self._compile_options = options
    self._compile = True
    return self
def flat_mode(self) -> "NodeAssembler":
    r"""Turn on the broadcast-based fast path without ``torch.compile``.

    Shorthand for ``compile(mode="disable")``.

    Returns
    -------
    NodeAssembler
        Whatever ``compile`` returns (``self``), for method chaining.
    """
    configured = self.compile(mode="disable")
    return configured
def reset_compile(self) -> "NodeAssembler":
    r"""Leave fast mode and forget any cached fast-call function.

    After this call the assembler uses the vmap path again, which is
    convenient when stepping through ``forward`` in a debugger.

    Examples
    --------
    .. code-block:: python

        assembler.reset_compile()
        # breakpoints inside forward() are hit again

    Returns
    -------
    NodeAssembler
        ``self``, for method chaining.
    """
    self._compile, self._compile_options, self._compiled_call_fn = False, {}, None
    return self
@property
def is_compiled(self) -> bool:
    r"""Whether fast (compile) mode is currently enabled.

    Returns
    -------
    bool
        ``True`` after :meth:`compile`, ``False`` after :meth:`reset_compile`.
    """
    return self._compile


def __str__(self):
    """Return a multi-line summary of the topology and the ``forward`` source.

    If the source of ``forward`` cannot be retrieved (for example when the
    class was defined interactively), a placeholder is printed instead of
    raising.
    """
    try:
        forward_source = inspect.getsource(self.forward)
    except (OSError, TypeError):
        # Printing an object should not fail just because its source file
        # is unavailable.
        forward_source = "<source unavailable>\n"

    n_element = ' '.join(f'{k}:{v.shape[0]}' for k, v in self.elements.items())
    n_basis = ' '.join(f'{k}:{v.shape[1]}' for k, v in self.elements.items())
    n_quadrature = ' '.join(f'{k}:{v.n_quadrature}' for k, v in self.transformation.items())

    return (
        f"{self.__class__.__name__}(\n"
        f" element_types: {self.element_types}\n"
        f" n_element: {n_element}\n"
        f" n_point: {self.n_points}\n"
        f" n_basis: {n_basis}\n"
        f" n_dim: {self.dimension}\n"
        f" n_quadrature: {n_quadrature}\n"
        f" forward: \n{forward_source}"
        f")"
    )


def __repr__(self):
    """Same text as :meth:`__str__`."""
    return str(self)
@abstractmethod
def forward(self, *args):
    r"""Integrand of the linear form; must be overridden by subclasses.

    In the default (vmap) call path the function is lifted over test
    functions, quadrature points and, when needed, elements, so it is
    written as if evaluated at a single quadrature point for a single
    test function. Parameters are selected by *name*:

    Parameters
    ----------
    v : torch.Tensor, optional
        Value of the test (shape) function at the quadrature point.
    gradv : torch.Tensor, optional
        Gradient of the test function; the last axis is the spatial one.
    x : torch.Tensor, optional
        Coordinates interpolated from the mesh points.
    **point_data : torch.Tensor
        Any key passed to ``__call__`` via ``point_data``, interpolated to
        the quadrature point; it keeps the trailing shape of the nodal
        tensor. ``"grad" + key`` gives its gradient, with one extra
        trailing spatial axis.
    **scalar_data
        Any key passed to ``__call__`` via ``scalar_data``.

    Returns
    -------
    torch.Tensor
        The integrand value. After lifting, the result is checked to have
        a basis axis optionally followed by one axis of degrees of freedom
        per node.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError("forward is not implemented")
@classmethod
def from_assembler(cls, obj, *args, **kwargs):
    r"""Build a :class:`NodeAssembler` that shares topology with ``obj``.

    The projector, transformation and connectivity of ``obj`` are passed
    to the constructor as-is, so nothing is rebuilt.

    Parameters
    ----------
    obj : NodeAssembler
        An existing node assembler whose mesh topology should be reused.
    *args, **kwargs
        Additional arguments forwarded to the constructor (and from there
        to ``__post_init__``).

    Returns
    -------
    NodeAssembler
        A new instance of ``cls`` referencing the same topology objects.
    """
    # The message is only formatted when the check fails. Formatting ``obj``
    # goes through ``__str__``, which reads the source of ``forward``; doing
    # that on every successful call is wasteful and can itself fail.
    assert isinstance(obj, NodeAssembler), f"the object {obj} should inheritate from NodeAssembler"
    return cls(
        obj.projector,
        obj.transformation,
        obj.elements,
        *args,
        **kwargs,
    )
@classmethod
def from_elements(cls,
                  points: torch.Tensor,
                  elements: Dict[str, torch.Tensor],
                  quadrature_order: int = 2,
                  device: torch.device | str = "cpu",
                  dtype: torch.dtype = torch.float32,
                  project: str = "reduce",
                  *args, **kwargs):
    r"""Build a :class:`NodeAssembler` from raw connectivity tensors.

    A projector and a :class:`Transformation` are created for every
    element type, which makes this slower than :meth:`from_assembler`.

    Parameters
    ----------
    points : torch.Tensor
        2D tensor of shape :math:`[|\mathcal V|, D]` with node coordinates.
    elements : dict[str, torch.Tensor]
        Connectivity keyed by element-type string, e.g.
        ``{"triangle": tensor([[0, 1, 2], [1, 2, 3]])}``; each value is
        ``[n_element, n_basis]``.
    quadrature_order : int, optional
        Passed to :class:`Transformation`; defaults to ``2``.
    device : torch.device or str, optional
        Device the assembler is moved to; defaults to ``"cpu"``.
    dtype : torch.dtype, optional
        Floating dtype the assembler is cast to; defaults to
        :obj:`torch.float32`.
    project : {'reduce', 'sparse'}, optional
        Selects :class:`ReduceProjector` (default) or
        :class:`SparseProjector` as the projection backend.
    *args, **kwargs
        Additional arguments forwarded to the constructor.

    Returns
    -------
    NodeAssembler
        A new assembler built from the given topology.

    Raises
    ------
    ValueError
        If ``project`` is neither ``"reduce"`` nor ``"sparse"``.
    """
    n_points = points.shape[0]
    projectors = {}
    transformations = {}

    for element_type, connectivity in elements.items():
        n_element, n_basis = connectivity.shape
        local_shape = (n_element, n_basis)

        if project == "reduce":
            projectors[element_type] = ReduceProjector(
                indices=connectivity.flatten(),
                from_shape=local_shape,
                to_shape=(n_points,),
            )
        elif project == "sparse":
            # Source indices enumerate the (element, basis) slots in order.
            projectors[element_type] = SparseProjector(
                from_=torch.arange(n_element * n_basis, device=connectivity.device),
                to_=connectivity.flatten(),
                from_shape=local_shape,
                to_shape=(n_points,),
            )
        else:
            raise ValueError(f"project should be either 'reduce' or 'sparse', but got {project}")

        transformations[element_type] = Transformation(
            points, connectivity, element_type, quadrature_order
        )

    assembler = cls(
        nn.ModuleDict(projectors),
        nn.ModuleDict(transformations),
        BufferDict(elements),  # type:ignore
        *args,
        **kwargs,
    )  # type:ignore
    return assembler.type(dtype).to(device)
@classmethod
def from_mesh(cls, mesh: Mesh, quadrature_order: int = 2, project: str = "reduce", *args, **kwargs):
    r"""Build a :class:`NodeAssembler` from a :class:`~tensormesh.Mesh`.

    Connectivity, points, device and dtype are read from ``mesh`` and
    handed to :meth:`from_elements`.

    Parameters
    ----------
    mesh : tensormesh.Mesh
        Source mesh.
    quadrature_order : int, optional
        Passed on to :meth:`from_elements`; defaults to ``2``.
    project : {'reduce', 'sparse'}, optional
        Projection backend; defaults to ``"reduce"``.
    *args, **kwargs
        Additional arguments forwarded to :meth:`from_elements`.

    Returns
    -------
    NodeAssembler
        A new assembler built from the mesh topology.
    """
    connectivity = mesh.elements()
    points = mesh.points
    assert isinstance(points, torch.Tensor)

    # A bare tensor means a single element type; key it by the mesh default.
    if isinstance(connectivity, torch.Tensor):
        connectivity = {mesh.default_element_type: connectivity}

    return cls.from_elements(
        points,
        connectivity,
        quadrature_order,
        mesh.device,
        mesh.dtype,
        project,
        *args,
        **kwargs,
    )
# Reuse the docstring of ``torch.nn.Module.type`` for the overriding
# ``NodeAssembler.type`` method.
NodeAssembler.type.__doc__ = nn.Module.type.__doc__