from abc import abstractmethod
import inspect
import math
from typing import Optional, Callable, List, Mapping, Union
import numpy as np
import scipy.sparse
import torch
import torch.nn as nn
from .projector import ReduceProjector, SparseProjector, Projector
from ..nn import BufferDict
from ..element import element_type2dimension, Transformation
from ..sparse import SparseMatrix
from ..mesh import Mesh
from ..vmap import vmap
class InputBroadcast:
element:Optional[int]
quadrature:Optional[int]
u:Optional[int]
v:Optional[int]
def __init__(self, element:bool,
quadrature:bool,
u:bool,
v:bool):
self.element = 0 if element else None
self.quadrature = 0 if quadrature else None
self.u = 0 if u else None
self.v = 0 if v else None
[文档]
class ElementAssembler(nn.Module):
r"""Assemble an element-wise bilinear form into a global sparse matrix.
:class:`ElementAssembler` is an :class:`torch:torch.nn.Module` whose
:meth:`forward` defines the integrand of a bilinear form
:math:`a(u, v) = \int_\Omega f(u, v)\, \mathrm{d}\Omega`. Calling the
assembler integrates :math:`f` over every quadrature point of every
element and scatters the result into a :class:`~tensormesh.sparse.SparseMatrix`
of shape :math:`[|\mathcal V|, |\mathcal V|]` (scalar problems) or
:math:`[|\mathcal V| \times H, |\mathcal V| \times H]` (vector problems
with :math:`H` degrees of freedom per node).
Subclasses are usually built from a mesh:
* :meth:`from_mesh` — build from a :class:`~tensormesh.Mesh` (slower; precomputes
the projection tensor :math:`\mathcal P_{\mathcal E}`).
* :meth:`from_assembler` — share the topology of another assembler (fast).
The schematic of one ``__call__`` is
.. math::
K \overset{\text{bsr matrix}}{\leftarrow}\hat K_\text{global},
\qquad
\hat K_{\text{global}}^{nkl} = \mathcal P_{\mathcal E}^{nhij}\, \hat K_{\text{local}}^{hklij},
\qquad
\hat K_{ij} = \int_\Omega f(u, v)\, \mathrm{d}\Omega,
where
* :math:`\hat K_{\text{global}}` are the non-zero entries of the global
Galerkin matrix, :math:`K_{\text{global}} \in \mathbb R^{|\mathcal E|\times d \times d}`;
* :math:`\hat K_{\text{local}}` is the per-element Galerkin block,
:math:`K_{\text{local}} \in \mathbb R^{|\mathcal C|\times h \times h \times d \times d}`;
* :math:`\mathcal P_{\mathcal E} \in \mathbb R_{\text{sparse}}^{|\mathcal E|\times|\mathcal C|\times h \times h}`
is the projection (assemble) tensor from local to global;
* :math:`\mathcal C` indexes elements/cells, :math:`\mathcal E` indexes
edges (the unique ``(i, j)`` pairs of basis indices), :math:`\mathcal V`
indexes nodes/vertices, and :math:`h` is the number of basis functions
per element.
Examples
--------
Mass matrix :math:`M_{ij} = \int_\Omega u_i v_j\, \mathrm{d}\Omega`:
.. code-block:: python
import tensormesh
class MassAssembler(tensormesh.ElementAssembler):
def forward(self, u, v):
return u * v
mesh = tensormesh.Mesh.gen_rectangle()
assembler = MassAssembler.from_mesh(mesh)
M = assembler(mesh.points)
Laplace stiffness :math:`K_{ij} = \int_\Omega \nabla u_i \cdot \nabla v_j\, \mathrm{d}\Omega`:
.. code-block:: python
import tensormesh
import tensormesh.functional as F
class LaplaceAssembler(tensormesh.ElementAssembler):
def forward(self, gradu, gradv):
return F.dot(gradu, gradv)
mesh = tensormesh.Mesh.gen_circle()
assembler = LaplaceAssembler.from_mesh(mesh)
K = assembler(mesh.points)
Attributes
----------
projector : torch.nn.ModuleDict
Maps each ``element_type`` to a
:class:`~tensormesh.assemble.projector.Projector` that scatters
per-element contributions of shape :math:`[|\mathcal C_e|, B_e, B_e]`
into the global edge vector of shape :math:`[|\mathcal E|]`.
transformation : torch.nn.ModuleDict
Maps each ``element_type`` to a :class:`~tensormesh.Transformation`
caching per-element Jacobians, shape values, shape gradients, and
``JxW`` at quadrature points.
elements : tensormesh.nn.BufferDict
Maps each ``element_type`` to its connectivity tensor of shape
:math:`[|\mathcal C|, B]`, e.g. ``{"triangle6": tensor([[0, 1, 2], ...])}``.
edges : torch.Tensor
Long tensor of shape :math:`[2, |\mathcal E|]` listing the unique
``(row, col)`` index pairs of the assembled sparse matrix; produced
by deduplicating all per-element basis-pair indices.
n_points : int
Number of mesh points (size of one matrix dimension).
dimension : int
Spatial dimension of the mesh, one of ``1``, ``2``, ``3``.
element_types : list[str]
Element-type strings present in the mesh, e.g. ``["triangle6", "quad9"]``.
"""
projector:nn.ModuleDict # Dict[str, Projector]
""":no-index:"""
transformation:nn.ModuleDict # Dict[str, Transformation]
""":no-index:"""
elements:BufferDict # Dict[str, torch.Tensor]
""":no-index:"""
edges:torch.Tensor
""":no-index:"""
dimension:int
""":no-index:"""
element_types:List[str]
""":no-index:"""
n_points:int
""":no-index:"""
__autodoc__ = [
'__call__',
'forward',
'__post_init__',
'from_assembler',
'from_mesh',
]
[文档]
def __init__(self,
projector:nn.ModuleDict,
transformation:nn.ModuleDict,
elements:BufferDict,
edges:torch.Tensor,
*args,
**kwargs):
super().__init__()
element_types = list(projector.keys())
dimension = element_type2dimension[element_types[0]]
self.projector = projector
self.transformation = transformation
self.elements = elements
self.register_buffer("edges", edges)
self.dimension = dimension
self.element_types = list(elements.keys())
self.n_points = next(iter(self.transformation.values())).n_points # type: ignore
self.__post_init__(*args,**kwargs)
@property
def device(self) -> torch.device:
r"""Device on which the assembler's buffers live."""
return next(iter(self.transformation.values())).device # type: ignore
@property
def dtype(self) -> torch.dtype:
r"""Floating dtype of the assembler's buffers (``float32`` or ``float64``)."""
return next(iter(self.transformation.values())).dtype # type: ignore
[文档]
def type(self, dtype:torch.dtype):
if dtype == torch.float64:
self.double()
elif dtype == torch.float32:
self.float()
else:
raise Exception(f"the dtype {dtype} is not supported")
return self
def _integrate(self, batch_integral, jxw, n_element, n_basis, use_element_parallel):
if use_element_parallel:
error_msg = f"the shape returned by forward function is {[*batch_integral.shape]} which is not supported, should either be [{n_element}, batch_size, {n_basis},{n_basis}] or [{n_element}, batch_size,{n_basis},{n_basis}, dof_per_point, dof_per_point]"
assert batch_integral.dim() == 4 or batch_integral.dim() == 6, error_msg
assert batch_integral.shape[0] == n_element, error_msg
assert batch_integral.shape[2] == n_basis, error_msg
assert batch_integral.shape[3] == n_basis, error_msg
if batch_integral.dim() == 6:
assert batch_integral.shape[-1] == batch_integral.shape[-2], error_msg
batch_integral = torch.einsum("eqij...,eq->eij...", batch_integral, jxw)
else:
error_msg = f"the shape returned by forward function is {[*batch_integral.shape]} which is not supported, should either be [batch_size, {n_basis},{n_basis}] or [batch_size,{n_basis},{n_basis}, dof_per_point, dof_per_point]"
assert batch_integral.dim() == 3 or batch_integral.dim() == 5, error_msg
assert batch_integral.shape[1] == n_basis, error_msg
assert batch_integral.shape[2] == n_basis, error_msg
if batch_integral.dim() == 5:
assert batch_integral.shape[-1] == batch_integral.shape[-2], error_msg
batch_integral = torch.einsum("qij...,eq->eij...", batch_integral, jxw)
return batch_integral
def _build_output(self, integral):
r"""Wrap the per-edge integral tensor in a :class:`SparseMatrix`.
Parameters
----------
integral : torch.Tensor
Per-edge integral of shape :math:`[|\mathcal E|, ...]`.
Returns
-------
SparseMatrix
COO sparse matrix; block-COO if ``integral`` is 3D (vector problems).
"""
if integral.dim() == 1:
return SparseMatrix(integral, self.edges[0], self.edges[1], shape=(self.n_points, self.n_points))
elif integral.dim() == 3:
return SparseMatrix.from_block_coo(integral, self.edges[0], self.edges[1], shape=(self.n_points, self.n_points))
else:
raise Exception(f"the shape of integral is supposed to be 1D or 3D, but got {integral.shape}")
def __call__(self, points:Optional[torch.Tensor] = None,
func:Optional[Callable] = None,
point_data:Optional[Mapping[str, torch.Tensor]] = None,
element_data:Optional[Union[Mapping[str, Mapping[str,torch.Tensor]], Mapping[str,torch.Tensor]]] = None,
scalar_data:Optional[Mapping[str, torch.Tensor]] = None,
batch_size:int = -1):
r"""Assemble the bilinear form into a global sparse matrix.
Parameters
----------
points : torch.Tensor, optional
Nodal coordinates of shape :math:`[|\mathcal V|, D]`. If ``None``,
the points stored in the cached :class:`Transformation` are used.
func : Callable, optional
Bilinear integrand to use *in place of* :meth:`forward`. Useful
when reusing the same topology with different forms.
point_data : Mapping[str, torch.Tensor], optional
Nodal fields, each of shape :math:`[|\mathcal V|, ...]`. Their
keys can appear as parameters of ``forward`` (e.g. ``"kappa"``)
and as gradients (``"gradkappa"``).
element_data : Mapping[str, ...], optional
Per-element data. Either ``{key: {element_type: tensor}}``
(mixed-element meshes) or ``{key: tensor}`` when only one
element type is present.
scalar_data : Mapping[str, scalar or torch.Tensor], optional
Global scalars passed verbatim to ``forward`` (no broadcasting).
batch_size : int, optional
Batch size for quadrature points. ``-1`` (default) processes all
quadrature points at once; positive values split them into
chunks for memory-bound problems.
Returns
-------
SparseMatrix
Sparse matrix of shape :math:`[|\mathcal V|, |\mathcal V|]`, or
block sparse of shape :math:`[|\mathcal V| \times H, |\mathcal V| \times H]`
with :math:`H` degrees of freedom per node (inferred from the
rank of the ``forward`` return).
"""
assert isinstance(point_data, dict) or point_data is None, f"point_data should be a dict, but got {type(point_data)}. Please pass in extra parameter using key-value pairs"
# make sure point data is Dict[str, torch.Tensor]
if point_data is None:
point_data = {}
# make sure element data is Dict[str,Dict[str, torch.Tensor]]
if element_data is None:
element_data = {element_type:{} for element_type in self.element_types} #
else:
if not isinstance(next(iter(element_data.values())), dict):
assert len(self.element_types) == 1
element_type = self.element_types[0]
element_data = {key:{element_type:value} for key, value in element_data.items()} # type:ignore
for key in element_data:
for element_type in self.element_types:
assert element_data[key][element_type].shape[0] == self.elements[element_type].shape[0], f"the shape of {key} should be [{self.elements[element_type].shape[0]}, ...], but got {element_data[key][element_type].shape[0]}"
else:
for key, value in element_data.items():
for element_type in self.element_types:
assert element_data[key][element_type].shape[0] == self.elements[element_type].shape[0]
# make sure scalar data is Dict[str, torch.Tensor]
if scalar_data is None:
scalar_data = {}
else:
scalar_data = {k:torch.tensor(v) for k,v in scalar_data.items()}
# make sure points is torch.Tensor
if points is None:
points = next(iter(self.transformation.values())).points # type: ignore
else:
for element_type in self.element_types:
assert points.shape[1] == self.transformation[element_type].dim, f"the dimension of points should be {self.transformation[element_type].dimension}, but got {points.shape[1]}"
self.transformation[element_type].update_points(points) # type: ignore
point_data["x"] = points # type:ignore [n_point, n_dim]
self = self.type(points.dtype).to(points.device) # type:ignore
for key, value in point_data.items():
assert value.shape[0] == points.shape[0], f"the shape of {key} should be [n_point, ...], but got {value.shape}"
# parse signature
fn = self.forward if func is None else func
signature = inspect.signature(fn)
broadcast_fns = [
(lambda x: x=="u" , InputBroadcast(False, True, True, False)), # [: , n_quadrature, n_u_basis, :]
(lambda x: x=="v" , InputBroadcast(False, True, False, True)), # [: , n_quadrature, :, n_v_basis]
(lambda x: x=="gradu" , InputBroadcast(True, True, True, False)), # [n_element, n_quadrature, n_u_basis, :, n_dim]
(lambda x: x=="gradv" , InputBroadcast(True, True, False, True)), # [n_element, n_quadrature, :, n_v_basis, n_dim]
(lambda x: x in element_data.keys(),
InputBroadcast(True, False, False, False)), # [n_element, :, :, :]
(lambda x: x in scalar_data.keys(),
InputBroadcast(True, True, True, True )),
(lambda x: x in point_data.keys(),
InputBroadcast(True, True, False, False)), # [n_element, n_quadrature, :, :]
(lambda x: x in {"grad" + key for key in point_data.keys()},
InputBroadcast(True, True, False, False)), # [n_element, n_quadrature, :, :, n_dim]
]
element_dims = []
quadrature_dims = []
u_dims = []
v_dims = []
for key in signature.parameters:
is_match = False
for condition, broadcast in broadcast_fns:
if condition(key):
element_dims.append(broadcast.element)
quadrature_dims.append(broadcast.quadrature)
u_dims.append(broadcast.u)
v_dims.append(broadcast.v)
is_match = True
break
if not is_match:
raise ValueError(f"{key} is not supported, please use `u`, `v`, `gradu`, `gradv` or more keys provided by point_data, element_data or scalar_data")
element_dims = tuple(element_dims)
quadrature_dims = tuple(quadrature_dims)
u_dims = tuple(u_dims)
v_dims = tuple(v_dims)
if all([x is None for x in element_dims]):
# if all is shape_val
parallel_fn = vmap(
vmap(
vmap(
fn,
in_dims = v_dims
),
in_dims = u_dims
),
in_dims = quadrature_dims
)
use_element_parallel = False
else:
parallel_fn = vmap(
vmap(
vmap(
vmap(
fn,
in_dims = v_dims
),
in_dims = u_dims
),
in_dims = quadrature_dims
),
in_dims=element_dims
)
use_element_parallel = True
integral:Optional[torch.Tensor] = None
for element_type in self.element_types:
element_integral = None
trans:Transformation = self.transformation[element_type] # type:ignore
proj:Projector = self.projector[element_type] # type:ignore
n_quadrature = trans.n_quadrature
if batch_size in (-1, None):
n_batch = 1
n_batch_size = n_quadrature
else:
n_batch_size = batch_size
n_batch = math.ceil(n_quadrature / batch_size)
elements:torch.Tensor = self.elements[element_type]
ele_point_data = {k:v[elements] for k,v in point_data.items()}
for i in range(n_batch):
# prepare arguments
shape_val = trans.batch_shape_val(i*n_batch_size, n_batch_size)
shape_grad, jxw = trans.batch_shape_grad_jxw(
quadrature_start = i*n_batch_size,
quadrature_batch = n_batch_size)
# shape_val : [quadrature_batch, n_basis]
# shape_grad: [element_batch, quadrature_batch, n_basis, dim]
# jxw : [element_batch, quadrature_batch]
args = []
for key in signature.parameters:
if key in ["u", "v"]:
args.append(shape_val)
elif key in ["gradu", "gradv"]:
args.append(shape_grad)
elif key in ele_point_data:
args.append(torch.einsum("eb...,qb->eq...",ele_point_data[key], shape_val))
# point data : [element_batch, quadrature_batch, ...]
elif key.startswith("grad") and key[4:] in ele_point_data: # grad point data
args.append(torch.einsum("eb...,eqbd->eq...d",ele_point_data[key[4:]], shape_grad))
# grad point data : [element_batch, quadrature_batch, ..., dim]
elif key in element_data: # type:ignore
args.append(element_data[key][element_type]) # type:ignore
elif key in scalar_data: # type:ignore
args.append(scalar_data[key])
else:
raise NotImplementedError(f"key {key} is not implemented")
# parallel dispatch
batch_integral = parallel_fn(*args) # [n_element, batch_size, n_basis, n_basis, ...] or [n_batch, batch_size, n_basis, ...]
batch_integral = self._integrate(batch_integral, jxw, trans.n_elements, trans.n_basis, use_element_parallel)
element_integral = batch_integral if element_integral is None else element_integral + batch_integral
integral = proj(element_integral) if integral is None else integral + proj(element_integral) # type:ignore [n_edge, ...]
return self._build_output(integral)
[文档]
def energy(self, points:Optional[torch.Tensor] = None,
func:Optional[Callable] = None,
point_data:Optional[Mapping[str, torch.Tensor]] = None,
element_data:Optional[Union[Mapping[str, Mapping[str,torch.Tensor]], Mapping[str,torch.Tensor]]] = None,
scalar_data:Optional[Mapping[str, torch.Tensor]] = None,
batch_size:int = -1):
r"""
Calculates the total potential energy of the system by integrating the element energy density over the domain.
This method is designed for energy-based variational problems (e.g., hyperelasticity, phase field, optimization).
It computes the global integral:
.. math::
E = \int_\Omega \psi(\mathbf{u}, \nabla \mathbf{u}, \dots) d\Omega
where :math:`\psi` is the energy density function (defined in :meth:`element_energy` or passed as `func`).
The method handles efficient parallel integration using quadrature rules and PyTorch's `vmap`.
**How it works (Simplified):**
1. You provide nodal data (like displacement `u`) in `point_data`.
2. You define an `element_energy` function that takes arguments like `u` (value) or `graddisplacement` (gradient) and returns the scalar energy density for a **single quadrature point**.
3. This method automatically matches your arguments, broadcasts them over all elements and quadrature points, computes the gradients if requested, and performs the numerical integration.
Parameters
----------
points : torch.Tensor, optional
The current positions of the nodes (shape: `[N_nodes, Dim]`).
If not provided, uses the mesh's initial points.
func : Callable, optional
A function that computes energy density at a single quadrature point.
Signature: `func(arg1, arg2, ...) -> torch.Tensor (scalar)`.
If `None`, uses `self.element_energy`.
**Supported Arguments for `func`:**
* **`u`**, **`v`**, ... : Values interpolated from `point_data` (e.g., passing `point_data={'u': ...}` allows requesting `u`).
* **`gradu`**, **`gradv`**, ...: Gradients interpolated from `point_data` (e.g., passing `point_data={'u': ...}` allows requesting `gradu`).
* **`key`** in `element_data`: Element-wise constant or quadrature-varying data.
* **`key`** in `scalar_data`: Global scalar constants.
point_data : Dict[str, torch.Tensor], optional
Nodal fields (e.g., displacement, phase field).
Shape: `[N_nodes, ...]` or `[N_nodes, Dim]`.
element_data : Dict[str, Tensor] or Dict[str, Dict[str, Tensor]], optional
Data defined on elements.
Can be:
* **Constant per element**: Shape `[N_elements, ...]`.
* **Varying per quadrature point** (e.g., history variables): Shape `[N_elements, N_quad, ...]`.
scalar_data : Dict[str, Tensor], optional
Global constants (e.g., time step `dt`).
batch_size : int, optional
Batch size for processing quadrature points to save memory. Default is -1 (process all at once).
Returns
-------
torch.Tensor
The total scalar energy (shape: `[]`).
Since this operation is differentiable, you can call `.backward()` on it to compute forces (gradients w.r.t `points` or `point_data`).
Examples
--------
**1. Hyperelasticity (Neo-Hookean Energy)**
.. code-block:: python
class NeoHookean(ElementAssembler):
def element_energy(self, graddisplacement):
# graddisplacement is [Dim, Dim] tensor at one quad point
F = torch.eye(3) + graddisplacement
J = torch.det(F)
# ... compute psi ...
return psi
model = NeoHookean.from_mesh(mesh)
# Compute energy
# Auto-differentiable w.r.t u
u = torch.zeros_like(mesh.points, requires_grad=True)
E = model.energy(point_data={"displacement": u})
# Compute Forces (Internal Force Vector)
E.backward()
F_int = u.grad
"""
assert isinstance(point_data, dict) or point_data is None, f"point_data should be a dict"
if point_data is None: point_data = {}
if element_data is None: element_data = {element_type:{} for element_type in self.element_types}
else:
if not isinstance(next(iter(element_data.values())), dict):
assert len(self.element_types) == 1
element_type = self.element_types[0]
element_data = {key:{element_type:value} for key, value in element_data.items()}
for key in element_data:
for element_type in self.element_types:
assert element_data[key][element_type].shape[0] == self.elements[element_type].shape[0]
else:
for key, value in element_data.items():
for element_type in self.element_types:
assert element_data[key][element_type].shape[0] == self.elements[element_type].shape[0]
if scalar_data is None: scalar_data = {}
else: scalar_data = {k:torch.tensor(v) for k,v in scalar_data.items()}
if points is None:
points = next(iter(self.transformation.values())).points
else:
for element_type in self.element_types:
self.transformation[element_type].update_points(points)
point_data["x"] = points
self = self.type(points.dtype).to(points.device)
fn = self.element_energy if func is None else func
signature = inspect.signature(fn)
# Broadcasting rules are inferred from argument names.
# For element_data we support both:
# - per-element constant: [n_elem, ...]
# - per-(element,quadrature): [n_elem, n_quad, ...] (e.g. history variables)
broadcast_fns = [
(lambda x: x in element_data.keys(), InputBroadcast(True, False, False, False)),
(lambda x: x in scalar_data.keys(), InputBroadcast(True, True, True, True)),
(lambda x: x in point_data.keys(), InputBroadcast(True, True, False, False)),
(lambda x: x in {"grad" + key for key in point_data.keys()}, InputBroadcast(True, True, False, False)),
]
element_dims = []
quadrature_dims = []
for key in signature.parameters:
# Special-case element_data: auto-detect whether it varies per quadrature
if key in element_data:
is_quad = False
for etype in self.element_types:
n_quad = self.transformation[etype].n_quadrature
data = element_data[key][etype]
if data.dim() >= 2 and data.shape[1] == n_quad:
is_quad = True
else:
is_quad = False
break
element_dims.append(0)
quadrature_dims.append(0 if is_quad else None)
continue
is_match = False
for condition, broadcast in broadcast_fns[1:]:
if condition(key):
element_dims.append(broadcast.element)
quadrature_dims.append(broadcast.quadrature)
is_match = True
break
if not is_match:
raise ValueError(f"{key} is not supported for energy calculation.")
element_dims = tuple(element_dims)
quadrature_dims = tuple(quadrature_dims)
parallel_fn = vmap(
vmap(fn, in_dims=quadrature_dims),
in_dims=element_dims
)
total_energy = 0.0
for element_type in self.element_types:
trans = self.transformation[element_type]
n_quadrature = trans.n_quadrature
if batch_size in (-1, None):
n_batch = 1
n_batch_size = n_quadrature
else:
n_batch_size = batch_size
n_batch = math.ceil(n_quadrature / batch_size)
elements = self.elements[element_type]
ele_point_data = {k:v[elements] for k,v in point_data.items()}
for i in range(n_batch):
q_start = i * n_batch_size
q_end = q_start + n_batch_size
# IMPORTANT: energy() needs element-wise shape gradients (include element dim).
# batch_shape_grad_jxw returns gradients without the element dimension for some elements,
# which breaks einsum broadcasting. Use full tensors and slice quadrature instead.
shape_val = trans.shape_val[q_start:q_end, :] # [q, b]
shape_grad = trans.shape_grad[:, q_start:q_end, :, :] # [e, q, b, d]
jxw = trans.JxW[:, q_start:q_end] # [e, q]
args = []
for key in signature.parameters:
if key in ele_point_data:
args.append(torch.einsum("eb...,qb->eq...", ele_point_data[key], shape_val))
elif key.startswith("grad") and key[4:] in ele_point_data:
args.append(torch.einsum("eb...,eqbd->eq...d", ele_point_data[key[4:]], shape_grad))
elif key in element_data:
args.append(element_data[key][element_type])
elif key in scalar_data:
args.append(scalar_data[key])
else:
raise NotImplementedError(f"key {key} not implemented")
batch_energy_density = parallel_fn(*args) # [n_elem, n_quad]
batch_energy = (batch_energy_density * jxw).sum()
total_energy += batch_energy
return total_energy
[文档]
def element_energy(self, **kwargs):
r"""
Override this method to define the energy density at a single quadrature point.
This method is called automatically by :meth:`energy` using `vmap` to parallelize over all quadrature points.
Parameters
----------
**kwargs : torch.Tensor
Arguments matching the variable names requested.
Common arguments:
* **u**, **v**: Value of field at quadrature point.
* **gradu**, **gradv**: Gradient of field at quadrature point.
* **element_data_key**: Value of element data (constant or interpolated).
Returns
-------
torch.Tensor
Scalar energy density (:math:`\psi`) for this quadrature point.
"""
raise NotImplementedError("element_energy is not implemented")
[文档]
@abstractmethod
def forward(self, **kwargs):
r"""Define the integrand of the bilinear form at a single quadrature point.
Subclasses must override this method. The library uses
:func:`torch.vmap` to lift the per-quadrature-point function over
all quadrature points and all elements, so write it as if you were
evaluating at *one* point. Parameters are dispatched by name;
return values are integrated against ``JxW`` and scattered by the
caller (see ``ElementAssembler.__call__``).
Parameters
----------
u, v : torch.Tensor, optional
Shape value at the quadrature point — 0D tensor of shape ``[]``.
gradu, gradv : torch.Tensor, optional
Shape gradient in physical coordinates — 1D tensor of shape ``[D]``.
x : torch.Tensor, optional
Physical coordinate at the quadrature point — 1D tensor of shape ``[D]``.
gradx : torch.Tensor, optional
Gradient of ``x`` w.r.t. reference coordinates — 2D tensor of shape ``[D, D]``.
**point_data : torch.Tensor
Any key passed to ``__call__`` via ``point_data``: if the nodal
tensor has shape :math:`[|\mathcal V|, ...]`, the value handed
to ``forward`` has the trailing ``[...]`` shape, and its
counterpart ``"grad"+key`` has shape ``[..., D]``.
Returns
-------
torch.Tensor
Either a 2D tensor of shape ``[B, B]`` (scalar problems) or a 4D
tensor of shape ``[B, B, H, H]`` (vector problems with ``H``
degrees of freedom per node).
"""
raise NotImplementedError("forward is not implemented")
def __post_init__(self):
r"""Override this function to precompute some data after the initialization
"""
pass
def __str__(self):
return (
f"{self.__class__.__name__}(\n"
f" element_types: {self.element_types}\n"
f" n_element: {' '.join(f'{k}:{v.shape[0]}' for k, v in self.elements.items())}\n"
f" n_point: {self.n_points}\n"
f" n_basis: {' '.join(f'{k}:{v.shape[1]}' for k, v in self.elements.items())}\n"
f" n_dim: {self.dimension}\n"
f" n_quadrature: {' '.join(f'{k}:{v.n_quadrature}' for k, v in self.transformation.items())}\n"
f" forward: \n{inspect.getsource(self.forward)}"
f")"
)
def __repr__(self):
return str(self)
[文档]
@classmethod
def from_assembler(cls, obj, *args, **kwargs):
r"""Build an :class:`ElementAssembler` that shares topology with ``obj``.
Much faster than :meth:`from_mesh` since the projector and the
cached :class:`Transformation` are reused as-is.
Parameters
----------
obj : ElementAssembler
An existing assembler whose mesh topology should be reused.
*args, **kwargs
Additional arguments forwarded to ``__post_init__``.
Returns
-------
ElementAssembler
A new assembler sharing the same mesh.
"""
assert isinstance(obj, ElementAssembler), f"obj must be an instance of ElementAssembler, but got {type(obj)}"
return cls(
obj.projector,
obj.transformation,
obj.elements,
obj.edges,
*args,**kwargs)
[文档]
@classmethod
def from_mesh(cls, mesh:Mesh,
quadrature_order:int = 2,
project:str = 'reduce',
*args,
**kwargs):
r"""Build an :class:`ElementAssembler` from a :class:`~tensormesh.Mesh`.
Slower than :meth:`from_assembler` because the projection tensor
:math:`\mathcal P_{\mathcal E}` is precomputed from the connectivity.
Parameters
----------
mesh : tensormesh.Mesh
Source mesh; both connectivity and points are taken from it.
quadrature_order : int, optional
Positive integer; defaults to ``2``.
project : {'reduce', 'sparse'}, optional
Backend used for the element-to-edge scatter; defaults to ``'reduce'``.
*args, **kwargs
Additional arguments forwarded to ``__post_init__``.
Returns
-------
ElementAssembler
A new assembler that owns the mesh topology.
"""
assert project in ['reduce', 'sparse']
points:torch.Tensor = mesh.points # type: ignore
elements = mesh.elements() # type:ignore
n_points:int = points.shape[0]
if isinstance(elements, torch.Tensor):
elements = {mesh.default_element_type: elements}
projector = {}
transformations = {}
######################
# compute the edges
######################
elem_u, elem_v = [], []
for element_type, value in elements.items():
n_element, n_basis = value.shape
for i in range(n_basis):
for j in range(n_basis):
elem_u.append(value[:, i])
elem_v.append(value[:, j])
elem_u, elem_v = torch.stack(elem_u, -1).flatten(), torch.stack(elem_v, -1).flatten() # [num_elements * num_basis * num_basis]
elem_u, elem_v = elem_u.cpu().numpy().copy(), elem_v.cpu().numpy().copy()
tmp = scipy.sparse.coo_matrix(( # used to remove duplicated edges
np.ones_like(elem_u), # data
(elem_u, elem_v), # (row, col)
), shape = (n_points, n_points)).tocsr().tocoo()
edge_u, edge_v = tmp.row, tmp.col
num_edges = len(edge_u)
eids_csr = scipy.sparse.coo_matrix((
np.arange(num_edges), (edge_u, edge_v)
), shape=(n_points, n_points)).tocsr()
elem_eids = np.array(eids_csr[elem_u, elem_v].copy()).ravel()
ptr = 0
for element_type, value in elements.items():
n_element, n_basis = value.shape
elem_eids = np.array(eids_csr[elem_u[ptr:ptr+n_element*n_basis*n_basis], elem_v[ptr:ptr+n_element*n_basis*n_basis]].copy()).ravel()
ptr += n_element * n_basis * n_basis
if project == "reduce":
projector[element_type] = ReduceProjector( # [n_element, n_basis, n_basis] -> [num_edges]
indices = torch.from_numpy(elem_eids),
from_shape = (n_element, n_basis, n_basis),
to_shape = (num_edges,),
)
elif project == "sparse":
projector[element_type] = SparseProjector( # [n_element, n_basis, n_basis] -> [num_edges]
from_ = torch.arange(len(elem_eids), device=points.device),
to_ = torch.from_numpy(elem_eids),
from_shape = (n_element, n_basis, n_basis),
to_shape = (num_edges,),
)
else:
raise ValueError(f"project must be 'reduce' or 'sparse', but got {project}")
transformations[element_type] = Transformation(
points = points,
elements = value,
element_type = element_type,
quadrature_order= quadrature_order
)
edges = torch.from_numpy(np.stack([edge_u, edge_v], 0))
projector = nn.ModuleDict(projector)
transformations = nn.ModuleDict(transformations)
elements = BufferDict({k:v.long() for k,v in elements.items()})
assembler = cls(
projector,
transformations,
elements,
edges,
*args,**kwargs)
assembler = assembler.type(mesh.dtype).to(mesh.device)
return assembler
ElementAssembler.type.__doc__ = nn.Module.type.__doc__