from abc import abstractmethod
import inspect
from typing import Callable, Optional, Dict, List
import torch
import torch.nn as nn
from tensormesh.element.element_type import element_type2element
from .projector import ReduceProjector, SparseProjector
from ..element import element_type2dimension, Transformation
from ..nn import BufferList
from ..mesh import Mesh
from ..vmap import vmap
[docs]
class FacetAssembler(nn.Module):
r"""Assemble an integrand over boundary facets of a mesh.
:class:`FacetAssembler` mirrors :class:`NodeAssembler` but integrates over
:math:`\partial \Omega` instead of :math:`\Omega`. Override
:meth:`forward` to define a per-quadrature-point integrand; calling the
assembler returns a flattened tensor of shape :math:`[|\mathcal V|]` or
:math:`[|\mathcal V| \times H]` (vector-valued problems with :math:`H`
DOFs per node).
Typical uses include Neumann tractions, penalty contact, surface
tension, and Robin boundary conditions.
Examples
--------
Constant downward traction on the boundary:
.. code-block:: python
import torch
from tensormesh import Mesh, FacetAssembler
class TractionAssembler(FacetAssembler):
def forward(self, v):
t = torch.tensor([0.0, -1.0], dtype=v.dtype, device=v.device)
return t * v # contribution at one quadrature point
mesh = Mesh.gen_rectangle()
f = TractionAssembler.from_mesh(mesh)(mesh.points)
Attributes
----------
projector : torch.nn.ModuleDict
Maps each ``element_type`` to a
:class:`~tensormesh.assemble.projector.Projector` that scatters
per-facet basis contributions onto the node vector.
facet_mask : torch.nn.ModuleDict
Maps each ``element_type`` to a :class:`~tensormesh.nn.BufferList`
of boolean masks marking which facets of which elements lie on the
selected boundary (one mask per facet type for mixed-facet shapes,
otherwise a list of one).
transformation : torch.nn.ModuleDict
Maps each ``element_type`` to its cached :class:`~tensormesh.Transformation`,
providing ``facet_shape_val``, ``facet_shape_grad``, and ``FxW``.
n_points : int
Number of mesh points (length of the output vector for scalar problems).
dimension : int
Spatial dimension of the mesh, one of ``1``, ``2``, ``3``.
element_types : list[str]
Element-type strings present in the mesh.
"""
projector:nn.ModuleDict # Dict[str, Projector]
facet_mask:nn.ModuleDict # Dict[str, List[torch.Tensor]]
transformation:nn.ModuleDict # Dict[str, Transformation]
dimension:int
element_types:List[str]
n_points:int
__autodoc__ = [
'__call__',
'forward',
'__post_init__',
'from_assembler',
'from_mesh',
]
[docs]
def __init__(self,
facet_mask:nn.ModuleDict,
projector:nn.ModuleDict,
transformation:nn.ModuleDict,
*args,
**kwargs):
super().__init__()
element_types = list(projector.keys())
dimension = element_type2dimension[element_types[0]]
self.projector = projector
self.facet_mask = facet_mask
self.transformation = transformation
self.dimension = dimension
self.element_types = element_types
self.n_points = next(iter(transformation.values())).n_points # type:ignore
self.__post_init__(*args,**kwargs)
@property
def device(self) -> torch.device:
"""Device on which the assembler's buffers live."""
return next(iter(self.transformation.values())).device # type: ignore
@property
def dtype(self) -> torch.dtype:
"""Floating dtype of the assembler's buffers (``float32`` or ``float64``)."""
return next(iter(self.transformation.values())).dtype # type: ignore
[docs]
def type(self, dtype: torch.dtype):
if dtype == torch.float64:
self.double()
elif dtype == torch.float32:
self.float()
else:
raise Exception(f"the dtype {dtype} is not supported")
return self
def __call__(self, points:Optional[torch.Tensor] = None,
func:Optional[Callable] = None,
point_data:Optional[Dict[str,torch.Tensor]] = None,
)->torch.Tensor:
r"""Integrate the facet form and scatter into a global node vector.
Parameters
----------
points : torch.Tensor, optional
Nodal coordinates of shape :math:`[|\mathcal V|, D]`. If ``None``,
the points stored in the cached :class:`Transformation` are used.
func : Callable, optional
Facet integrand to use *in place of* :meth:`forward`.
point_data : dict[str, torch.Tensor], optional
Nodal fields, each of shape :math:`[|\mathcal V|, ...]`. Keys
can appear as ``forward`` parameters and as gradients
(``"grad"+key``).
Returns
-------
torch.Tensor
1D tensor of shape :math:`[|\mathcal V|]` (scalar problems) or
flattened ``[|\mathcal V| \times H]`` (vector problems).
"""
if point_data is None:
point_data = {}
if points is not None:
self = self.type(points.dtype).to(points.device)
for element_type in self.element_types:
trans:Transformation = self.transformation[element_type] # type:ignore
trans.update_points(points)
else:
points = next(iter(self.transformation.values())).points # type:ignore
point_data["x"] = points # type:ignore
for key, value in point_data.items():
assert value.shape[0] == self.n_points, f"the shape of {key} should be [n_point, ...], but got {value.shape}"
fn = self.forward if func is None else func
signature = inspect.signature(fn)
parallel_fn = vmap(vmap(fn))
integral = None
for element_type in self.element_types:
trans:Transformation = self.transformation[element_type] # type: ignore
proj = self.projector[element_type] # type: ignore
m:torch.Tensor = self.facet_mask[element_type] # type:ignore [n_element, n_facet]
ele_point_data = {k:v[trans.elements] for k,v in point_data.items()}
if trans.element.is_mix_facet: # for pyramid and prism
tri_m, quad_m = self.facet_mask[element_type] # type:ignore [n_element, n_facet]
# prepare arguments
tri_args = []
quad_args= []
for key in signature.parameters:
if key in ["u", "v"]:
tri_shape_val, quad_shape_val = trans.facet_shape_val
tri_shape_val = tri_shape_val.repeat(trans.n_elements, 1, 1, 1)[m] # [n_selected_tri_facet, n_quadrature_per_tri_facet, n_basis]
quad_shape_val= quad_shape_val.repeat(trans.n_elements, 1, 1, 1)[m] # [n_selected_quad_facet, n_quadrature_per_quad_facet, n_basis]
tri_args.append(tri_shape_val)
quad_args.append(quad_shape_val)
elif key in ["gradu", "gradv"]:
tri_shape_grad, quad_shape_grad = trans.facet_shape_grad
tri_args.append(tri_shape_grad[tri_m])
quad_args.append(quad_shape_grad[quad_m])
elif key in ele_point_data:
tri_shape_val, quad_shape_val = trans.facet_shape_val
tri_point_data = torch.einsum("eb...,fqb->efq...",ele_point_data[key], tri_shape_val)
quad_point_data= torch.einsum("eb...,fqb->efq...",ele_point_data[key], quad_shape_val)
tri_point_data = tri_point_data[tri_m] # [n_selected_tri_facet, n_quadrature_per_tri_facet, ...]
quad_point_data= quad_point_data[quad_m]# [n_selected_quad_facet, n_quadrature_per_quad_facet, ...]
tri_args.append(tri_point_data)
quad_args.append(quad_point_data)
elif key.startswith("grad") and key[4:] in ele_point_data: # "key"->"gradkey"
tri_shape_grad, quad_shape_grad = trans.facet_shape_grad
tri_grad_data = torch.einsum("eb...,efqbd->efq...d",ele_point_data[key[4:]], tri_shape_grad)
quad_grad_data= torch.einsum("eb...,efqbd->efq...d",ele_point_data[key[4:]], quad_shape_grad)
tri_grad_data = tri_grad_data[tri_m] # [n_selected_tri_facet, n_quadrature_per_tri_facet, ...., n_dim]
quad_grad_data= quad_grad_data[quad_m] # [n_selected_quad_facet, n_quadrature_per_quad_facet, ...., n_dim]
tri_args.append(tri_grad_data)
quad_args.append(quad_grad_data)
else:
raise NotImplementedError(f"key {key} is not implemented")
# parallel dispatch
tri_integral = parallel_fn(*tri_args)
quad_integral= parallel_fn(*quad_args)
# tri_integral [n_selected_tri_facet, n_quadrature_per_tri_facet, n_basis, ...]
# quad_integral [n_selected_quad_facet, n_quadrature_per_quad_facet, n_basis, ...]
tri_jxw, quad_jxw = trans.JxW
tri_jxw = tri_jxw[m]
quad_jxw= quad_jxw[m]
tri_integral = torch.einsum('fqb..., fq->fb...', tri_integral, tri_jxw) # [n_selected_tri_facet, n_basis, ...]
quad_integral= torch.einsum('fqb..., fq->fb...', quad_integral, quad_jxw) # [n_selected_quad_facet, n_basis, ...]
_integral = torch.cat([tri_integral, quad_integral], dim=0) # [n_selected_tri_facet+n_selected_quad_facet, n_basis, ...]
_integral = proj(_integral) # [n_points, ...]
integral = _integral if integral is None else integral + _integral
else: # same facet type
m:torch.Tensor = self.facet_mask[element_type].item() # type:ignore [n_element, n_facet]
# prepare arguments
args = []
for key in signature.parameters:
if key in ["u", "v"]:
args.append(trans.facet_shape_val.repeat(trans.n_elements, 1, 1, 1)[m]) # type:ignore [n_selected_facet, n_quadrature_per_facet, n_basis]
elif key in ["gradu", "gradv"]:
args.append(trans.facet_shape_grad[m]) # [n_selected_facet, n_qudrature_per_facet, n_basis, n_dim]
elif key in ele_point_data:
_ele_point_data = torch.einsum("eb...,fqb->efq...", ele_point_data[key], trans.facet_shape_val)
args.append(_ele_point_data[m]) # [n_selected_facet, n_quadrature_per_facet, ...]
elif key.startswith("grad") and key[4:] in ele_point_data: # "key" -> "gradkey"
_ele_grad_data = torch.einsum("eb...,efqbd->efq...d", ele_point_data[key[4:]], trans.facet_shape_grad)
args.append(_ele_grad_data[m]) # [n_selected_facet, n_quadrature_per_facet, ..., n_dim]
else:
raise NotImplementedError(f"key {key} is not implemented")
# parallel dispatch
_integral = parallel_fn(*args) # [n_selected_facet, n_quadrature_per_facet, n_basis, ...]
_integral = torch.einsum('fqb..., fq->fb...', _integral, trans.FxW[m]) # [n_selected_facet, n_basis, ...]
_integral = proj(_integral) # [n_points, ...]
integral = _integral if integral is None else integral + _integral
return integral.flatten() # type: ignore
def __post_init__(self, *args, **kwargs):
r"""Override this function to precompute some data after the initialization
"""
pass
def __str__(self):
n_element = {k:trans.n_elements for k, trans in self.transformation.items()}
n_basis = {k:trans.n_basis for k, trans in self.transformation.items()}
return (
f"{self.__class__.__name__}(\n"
f" element_types: {self.element_types}\n"
f" n_element: {n_element}\n"
f" n_point: {self.n_points}\n"
f" n_basis: {n_basis}\n"
f" n_dim: {self.dimension}\n"
f" n_quadrature: {' '.join(f'{k}:{v.n_quadrature}' for k, v in self.transformation.items())}\n"
f" forward: \n{inspect.getsource(self.forward)}"
f")"
)
def __repr__(self):
return str(self)
[docs]
@abstractmethod
def forward(self, *args):
r"""Define the facet integrand at a single quadrature point.
Subclasses must override this method. Vmap dispatches the
per-quadrature-point function over all selected facets, so write
it as if evaluating at *one* facet quadrature point. Unlike
:class:`ElementAssembler.forward`, the basis arguments here keep
the basis dimension (the inner ``vmap(vmap(...))`` covers facet +
quadrature only, not basis).
Parameters
----------
u, v : torch.Tensor, optional
Shape value on the facet — 1D tensor of shape ``[B]``.
gradu, gradv : torch.Tensor, optional
Shape gradient in physical coordinates — 2D tensor of shape ``[B, D]``.
x : torch.Tensor, optional
Physical coordinate at the quadrature point — 1D tensor of shape ``[D]``.
gradx : torch.Tensor, optional
Gradient of ``x`` w.r.t. reference coordinates — 2D tensor of shape ``[D, D]``.
**point_data : torch.Tensor
Any key passed to ``__call__`` via ``point_data``: if the nodal
tensor has shape :math:`[|\mathcal V|, ...]`, the value handed
to ``forward`` has the trailing ``[...]`` shape, and its
counterpart ``"grad"+key`` has shape ``[..., D]``.
Returns
-------
torch.Tensor
1D tensor of shape ``[B]`` (scalar problems) or 2D tensor of
shape ``[B, H]`` (vector problems with ``H`` degrees of freedom
per node).
"""
raise NotImplementedError("forward is not implemented")
[docs]
@classmethod
def from_assembler(cls, obj, *args, **kwargs):
r"""Build a :class:`FacetAssembler` sharing topology with ``obj``.
Much faster than :meth:`from_mesh` since the facet mask, projector,
and cached :class:`Transformation` are reused as-is.
Parameters
----------
obj : FacetAssembler
An existing facet assembler whose boundary topology should be reused.
*args, **kwargs
Additional arguments forwarded to ``__post_init__``.
Returns
-------
FacetAssembler
A new assembler sharing the same boundary topology.
"""
err_msg = f"the object {obj} should inheritate from NodeAssembler"
assert isinstance(obj, FacetAssembler), err_msg
return cls(
obj.facet_mask,
obj.projector,
obj.transformation,
*args,**kwargs
)
[docs]
@classmethod
def from_elements(cls, points:torch.Tensor,
elements:Dict[str,torch.Tensor],
boundary_mask:torch.Tensor,
quadrature_order:int = 2,
device:str|torch.device="cpu",
dtype:torch.dtype=torch.float32,
project:str = "reduce",
*args,**kwargs):
r"""Build a :class:`FacetAssembler` from raw connectivity tensors.
Slower than :meth:`from_assembler` because the boundary topology is
rebuilt from scratch.
Parameters
----------
points : torch.Tensor
2D tensor of shape :math:`[|\mathcal V|, D]` listing node coordinates.
elements : dict[str, torch.Tensor]
Connectivity keyed by element-type string, e.g.
``{"triangle": tensor([[0, 1, 2], [1, 2, 3]])}``.
boundary_mask : torch.Tensor
1D boolean tensor of shape :math:`[|\mathcal V|]` marking which
nodes lie on the boundary; a facet is selected iff *all* of its
corner nodes are flagged.
quadrature_order : int, optional
Positive integer; defaults to ``2``.
device : torch.device or str, optional
Device of the assembler; defaults to ``"cpu"``.
dtype : torch.dtype, optional
Floating dtype; defaults to :obj:`torch.float32`.
project : {'reduce', 'sparse'}, optional
Projection backend; defaults to ``"reduce"``.
Returns
-------
FacetAssembler
A new assembler that owns the given boundary topology.
"""
n_points = points.shape[0] # TODO: move transformation to the __call__
projector = {}
facet_mask = {}
trasnformations = {}
# compute the facet_mask -> facet_quadrature
for element_type, value in elements.items(): # type: ignore
element = element_type2element(element_type)
if element.is_mix_facet:
is_boundary_element = boundary_mask[value].any(-1)
boundary_elements = value[is_boundary_element] # [n_boundary_element, n_basis_per_cell]
trans = Transformation(
points,
boundary_elements,
element_type,
quadrature_order)
tri_boundary_facet_candidate, quad_boundary_facet_candidate = trans.facets
# tri_boundary_facet_candidate [n_boundary_element, n_tri_facet, n_basis_per_tri_facet]
# quad_boundary_facet_candidate [n_boundary_element, n_quad_facet, n_basis_per_quad_facet]
is_tri_boundary_facet = boundary_mask[tri_boundary_facet_candidate].all(-1) # [n_boundary_element, n_tri_facet]
is_quad_boundary_facet= boundary_mask[quad_boundary_facet_candidate].all(-1) # [n_boundary_element, n_quad_facet]
n_selected_tri_facet = int(is_tri_boundary_facet.sum().item())
n_selected_quad_facet = int(is_quad_boundary_facet.sum().item())
n_basis = trans.n_basis # n_basis_per_cell
# For each selected facet, fetch the *whole cell* connectivity. This is needed
# because the integrand is computed on cell-wise shape functions; entries that
# do not belong to the facet are exactly zero for Lagrange bases, so scattering
# them via these cell dofs only adds zeros to "off-facet" global dofs.
tri_elem_idx = is_tri_boundary_facet.nonzero(as_tuple=True)[0] # [n_selected_tri_facet]
quad_elem_idx = is_quad_boundary_facet.nonzero(as_tuple=True)[0] # [n_selected_quad_facet]
tri_cell_dofs = boundary_elements[tri_elem_idx] # [n_selected_tri_facet, n_basis_per_cell]
quad_cell_dofs = boundary_elements[quad_elem_idx] # [n_selected_quad_facet, n_basis_per_cell]
elements[element_type] = boundary_elements
trasnformations[element_type] = trans
if project == "reduce":
projector[element_type] = ReduceProjector(
indices = torch.cat([tri_cell_dofs.flatten(), quad_cell_dofs.flatten()]), # [(n_sel_tri + n_sel_quad) * n_basis_per_cell]
from_shape = (n_selected_tri_facet + n_selected_quad_facet, n_basis),
to_shape = (n_points,)
)
elif project == "sparse":
n_entries = (n_selected_tri_facet + n_selected_quad_facet) * n_basis
projector[element_type] = SparseProjector(
from_ = torch.arange(n_entries),
to_ = torch.cat([tri_cell_dofs.flatten(), quad_cell_dofs.flatten()]),
from_shape = (n_selected_tri_facet + n_selected_quad_facet, n_basis),
to_shape = (n_points,)
)
facet_mask[element_type] = BufferList([is_tri_boundary_facet, is_quad_boundary_facet])
else: # same facet type
is_boundary_element = boundary_mask[value].any(-1)
boundary_elements = value[is_boundary_element] # [n_boundary_element, n_basis_per_cell]
trans = Transformation(
points,
boundary_elements,
element_type,
quadrature_order)
boundary_facet_candidate = trans.facets # [n_boundary_element, n_facet, n_basis_per_facet]
is_boundary_facet = boundary_mask[boundary_facet_candidate].all(-1) # [n_boundary_element, n_facet]
n_selected_facet = int(is_boundary_facet.sum().item())
n_basis = trans.n_basis # n_basis_per_cell
# For each selected facet, fetch the *whole cell* connectivity so that the
# projector indices align with the cell-basis dimension of the integrand
# produced in __call__. Lagrange basis functions vanish on facets they do
# not belong to, so scattering the corresponding zero entries is harmless.
selected_elem_idx = is_boundary_facet.nonzero(as_tuple=True)[0] # [n_selected_facet]
cell_dofs_per_facet = boundary_elements[selected_elem_idx] # [n_selected_facet, n_basis_per_cell]
elements[element_type] = boundary_elements
facet_mask[element_type] = BufferList([is_boundary_facet])
trasnformations[element_type] = trans
if project == "reduce":
projector[element_type] = ReduceProjector(
indices = cell_dofs_per_facet.flatten(),
from_shape = (n_selected_facet, n_basis),
to_shape = (n_points,)
)
elif project == "sparse":
projector[element_type] = SparseProjector(
from_ = torch.arange(n_selected_facet * n_basis),
to_ = cell_dofs_per_facet.flatten(),
from_shape = (n_selected_facet, n_basis),
to_shape = (n_points,)
)
else:
raise ValueError(f"project should be either 'reduce' or 'sparse', but got {project}")
facet_mask = nn.ModuleDict(facet_mask)
projector = nn.ModuleDict(projector)
transformation = nn.ModuleDict(trasnformations)
assembler = cls(
facet_mask,
projector,
transformation,
*args,**kwargs)
assembler = assembler.type(dtype).to(device)
return assembler
[docs]
@classmethod
def from_mesh(cls, mesh:Mesh,
boundary_mask:Optional[str|torch.Tensor]=None,
quadrature_order:int=2,
project:str = "reduce",
*args,**kwargs):
r"""Build a :class:`FacetAssembler` from a :class:`~tensormesh.Mesh`.
Slower than :meth:`from_assembler` because the boundary topology is
rebuilt from connectivity.
Parameters
----------
mesh : tensormesh.Mesh
Source mesh; connectivity, points, and (default) boundary mask
are read from it.
boundary_mask : str, torch.Tensor, or None, optional
Boundary selector. ``None`` (default) uses ``mesh.boundary_mask``;
``str`` keys into ``mesh.point_data``; a tensor is used verbatim
and must be 1D boolean of length ``n_points``.
quadrature_order : int, optional
Positive integer; defaults to ``2``.
project : {'reduce', 'sparse'}, optional
Projection backend; defaults to ``"reduce"``.
Returns
-------
FacetAssembler
A new assembler that owns the boundary topology of the mesh.
"""
points:torch.Tensor = mesh.points # type:ignore
elements = mesh.elements()
n_points = points.shape[0]
if isinstance(elements, torch.Tensor):
elements = {mesh.default_element_type: elements} # type:ignore
if boundary_mask is None:
boundary_mask = mesh.boundary_mask
elif isinstance(boundary_mask, str):
boundary_mask = mesh.point_data[boundary_mask]
assert boundary_mask.dim() == 1 and boundary_mask.shape[0] == n_points
return cls.from_elements(points,
elements, # type:ignore
boundary_mask,
quadrature_order,
mesh.device,
mesh.dtype,
project,
*args,**kwargs)
# Reuse torch's documentation for the float-only ``type`` override of FacetAssembler.
setattr(FacetAssembler.type, "__doc__", nn.Module.type.__doc__)