Concepts
TensorMesh is a finite-element library written from the ground up for
PyTorch. A Mesh is a torch.nn.Module,
weak forms are plain forward methods that receive basis tensors,
and every linear solve is a differentiable op. The same code that
solves a Poisson problem on a laptop CPU also runs on a GPU and
backpropagates through to a learnable parameter — without changing
the FEM logic.
This page presents the mental model: how the modules fit together, and the design principles behind them.
The FEM workflow
Solving a PDE in TensorMesh follows one canonical pipeline:
Mesh → Assembler → SparseMatrix → Condenser → Solve
Mesh discretizes the domain into points and cells.
Assembler turns a weak form (a(u, v) or l(v)) into a SparseMatrix or load vector.
Condenser applies Dirichlet boundary conditions by static condensation, producing a reduced system on the free (non-Dirichlet) DOFs.
Solve dispatches the reduced system to a sparse-linear-algebra backend (via the torch-sla package).
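To make the Condenser stage concrete, here is a minimal dense sketch of static condensation in plain PyTorch. It is not the tensormesh.operator.Condenser implementation. The helper name condense_and_solve is hypothetical, and torch.linalg.solve stands in for the torch-sla backend. The DOFs are split into a free set F and a Dirichlet set D with prescribed values g, and only the reduced system K_FF u_F = f_F - K_FD g is solved.

```python
import torch


def condense_and_solve(K, f, dirichlet_dofs, dirichlet_values):
    """Hypothetical helper: solve K u = f subject to u[dirichlet_dofs] = dirichlet_values."""
    n = K.shape[0]
    fixed = torch.as_tensor(dirichlet_dofs, dtype=torch.long)
    g = torch.as_tensor(dirichlet_values, dtype=K.dtype)

    # Free DOFs are every DOF not constrained by a Dirichlet condition
    mask = torch.ones(n, dtype=torch.bool)
    mask[fixed] = False
    free = mask.nonzero(as_tuple=True)[0]

    # Reduced system on the free DOFs: K_FF u_F = f_F - K_FD g
    K_FF = K[free][:, free]
    K_FD = K[free][:, fixed]
    u_F = torch.linalg.solve(K_FF, f[free] - K_FD @ g)

    # Scatter prescribed and computed values back into one full-length vector
    u = torch.empty(n, dtype=K.dtype)
    u[fixed] = g
    u[free] = u_F
    return u


# 1D Laplace on [0, 1] with 5 nodes, u(0) = 0, u(1) = 1, and no source term.
# Only the interior rows of K matter here, because the Dirichlet rows are discarded.
n = 5
K = 2 * torch.eye(n, dtype=torch.float64)
K -= torch.diag(torch.ones(n - 1, dtype=torch.float64), 1)
K -= torch.diag(torch.ones(n - 1, dtype=torch.float64), -1)
f = torch.zeros(n, dtype=torch.float64)
u = condense_and_solve(K, f, [0, n - 1], [0.0, 1.0])
print(u)
```

The result is the linear profile 0, 0.25, 0.5, 0.75, 1, which is the exact solution of this boundary-value problem.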
The Quickstart walks through this pipeline end-to-end in about 30 lines of Python. Each subsequent chapter of this guide zooms in on one stage.
Module map
The library splits cleanly along the pipeline. The arrows show data flow, not import direction:
┌──────────┐   ┌────────────┐   ┌──────────────┐   ┌───────────┐   ┌─────────┐
│   Mesh   │ → │ Assembler  │ → │ SparseMatrix │ → │ Condenser │ → │  Solve  │
│   (nn.   │   │ (Element / │   │ (torch_sla.  │   │ (Dirichlet│   │ (torch- │
│  Module) │   │   Node /   │   │ SparseTensor │   │  static   │   │   sla   │
│          │   │   Facet)   │   │ + spmm / @)  │   │  cond.)   │   │ .solve) │
└──────────┘   └────────────┘   └──────────────┘   └───────────┘   └─────────┘
     ↑               ↑                                                  │
     │               │                                                  ▼
┌──────────┐   ┌────────────┐                                     ┌───────────┐
│ element  │   │ functional │                                     │ Postproc  │
│ (Triangle│   │  (voigt,   │                                     │ visualize │
│  Hex,…)  │   │ strain,…)  │                                     │ ode step  │
└──────────┘   └────────────┘                                     └───────────┘
What lives in each module:
tensormesh.mesh: Mesh and its built-in generators (gen_rectangle, gen_circle, gen_cube, …); meshio I/O; adjacency, partitioning, and graph coloring helpers.
tensormesh.element: reference shapes (Triangle, Quadrilateral, Tetrahedron, Hexahedron, Prism, Pyramid, Line), basis evaluation, quadrature rules, and the Gmsh/VTK ↔ TensorMesh ordering convention.
tensormesh.assemble: the three weak-form base classes ElementAssembler, NodeAssembler, and FacetAssembler, plus built-ins for the most common forms (Laplace, mass, linear elasticity, Neo-Hookean, …).
tensormesh.sparse: SparseMatrix (a subclass of torch_sla.SparseTensor), so linear systems are solved by K.solve(b) and nonlinear systems by K.nonlinear_solve(residual, u0, *params), dispatched through torch-sla. The in-tree spsolve / nonlinear_solve() free functions are legacy entry points scheduled for removal.
tensormesh.operator: Condenser for Dirichlet BCs via static condensation.
tensormesh.ode: explicit and implicit-linear time integrators (Euler, midpoint, Runge-Kutta) for transient problems.
tensormesh.functional: Voigt elasticity helpers, strain / stress, and other tensor utilities used inside forward methods.
tensormesh.dataset: MeshGen and pre-built multi-frequency equation classes for generating training datasets.
tensormesh.material: IsotropicMaterial and library presets (Steel, Aluminum, Rubber, Glass).
tensormesh.optimizer: OCOptimizer (Optimality Criteria) for compliance-based topology optimization.
tensormesh.visualization: matplotlib (2D) and PyVista (3D) backends; lazily imported by plot().
tensormesh.distributed: graph-partitioned distributed assembly across multiple ranks (advanced; see the example gallery).
The sparse-linear-algebra stack (SparseMatrix, .solve /
.nonlinear_solve, gradient-aware adjoint backward) is delegated to
a separate package,
torch-sla, and shared with
other projects in the same ecosystem.
Design principles
PyTorch-native. Mesh extends
torch.nn.Module. Its points and per-element connectivity
are buffers; per-node fields are buffers too. Assemblers are also
nn.Module subclasses. There is no separate “FEM kernel” abstraction layer:
everything is a tensor that flows through familiar PyTorch machinery
(.to(device), .double(), state_dict, autograd, JIT
tracing).
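Because a Mesh is an ordinary module whose data live in buffers, the standard nn.Module machinery applies to it unchanged. The toy class below is a hypothetical stand-in written in plain PyTorch. It is not tensormesh's Mesh, and the buffer names points and cells are assumptions. It shows what registering buffers buys: they move with .to(device), are cast by .double(), and appear in state_dict, yet they are not trainable parameters.

```python
import torch
from torch import nn


class ToyMesh(nn.Module):
    """Hypothetical stand-in for a mesh module: geometry and topology stored as buffers."""

    def __init__(self, points, cells):
        super().__init__()
        # Buffers belong to the module's state but are not returned by parameters()
        self.register_buffer("points", points)  # [num_points, dim], floating point
        self.register_buffer("cells", cells)    # [num_cells, nodes_per_cell], integer


mesh = ToyMesh(
    torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]),
    torch.tensor([[0, 1, 2], [0, 2, 3]]),
)

mesh = mesh.double()              # casts floating-point buffers only; cells stay integer
print(mesh.points.dtype, mesh.cells.dtype)
print(sorted(mesh.state_dict()))  # both buffers are serialized with the module
print(list(mesh.parameters()))    # nothing here is trainable
if torch.cuda.is_available():
    mesh = mesh.to("cuda")        # every buffer moves in one call
```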
Weak forms in pure Python. The only PDE-specific code a user
writes is a forward method that returns the integrand at the
quadrature points:
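from tensormesh.assemble import ElementAssembler

class LaplaceAssembler(ElementAssembler):
    # Integrand of the bilinear form a(u, v) = ∫ ∇u · ∇v dx at the quadrature points
    def forward(self, gradu, gradv):
        return gradu @ gradv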
The library handles reference-element evaluation, geometry,
quadrature weights, and the global assemble-into-sparse step. The
same pattern works for load vectors (NodeAssembler)
and boundary integrals (FacetAssembler).
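As a rough illustration of that division of labor, the plain-PyTorch sketch below integrates a user-style integrand over one linear (P1) triangle. It evaluates the reference basis at quadrature points, scales by the Jacobian determinant of the element map, and applies the quadrature weights. The names p1_basis, mass_integrand, and local_matrix are hypothetical and not part of the TensorMesh API. The three-point edge-midpoint rule is one standard choice; it is exact for the degree-2 integrand of a P1 mass matrix.

```python
import torch

# Edge-midpoint rule on the reference triangle (0,0), (1,0), (0,1):
# exact for polynomials of degree <= 2; the weights sum to the reference area 1/2.
QUAD_POINTS = torch.tensor([[0.5, 0.0], [0.5, 0.5], [0.0, 0.5]], dtype=torch.float64)
QUAD_WEIGHTS = torch.full((3,), 1.0 / 6.0, dtype=torch.float64)


def p1_basis(xi):
    """Reference-element evaluation: P1 basis values at points xi [Q, 2] -> [Q, 3]."""
    return torch.stack([1 - xi[:, 0] - xi[:, 1], xi[:, 0], xi[:, 1]], dim=-1)


def mass_integrand(u, v):
    # Plays the role of a user-written forward: the integrand of a(u, v) = ∫ u v dx
    return u * v


def local_matrix(vertices, integrand):
    """Integrate integrand(phi_i, phi_j) over one triangle with vertices [3, 2]."""
    # Geometry: Jacobian of the affine map from the reference triangle
    J = torch.stack([vertices[1] - vertices[0], vertices[2] - vertices[0]], dim=-1)
    detJ = torch.linalg.det(J).abs()
    phi = p1_basis(QUAD_POINTS)                          # [Q, 3]
    # Integrand for every (i, j) basis pair at every quadrature point: [Q, 3, 3]
    vals = integrand(phi[:, :, None], phi[:, None, :])
    # Quadrature: weighted sum over points, scaled from reference to physical area
    return detJ * torch.einsum("q,qij->ij", QUAD_WEIGHTS, vals)


ref = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)
M = local_matrix(ref, mass_integrand)
```

On the reference triangle, which has area 1/2, this yields (1/24)·[[2, 1, 1], [1, 2, 1], [1, 1, 2]]. That is the familiar P1 mass matrix, area/12 times [[2, 1, 1], [1, 2, 1], [1, 1, 2]].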
Tensorized assembly. There is no Python-level loop over
elements. Inside __call__, basis functions and quadrature points
are evaluated once for the whole mesh; the user’s forward runs
on tensors that already carry element and quadrature dimensions,
ready for broadcasting; the global assemble is a sparse scatter. The
result: assembly is a few batched tensor operations with no
per-element Python overhead, which maps naturally onto a GPU.
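The tensorized pipeline described above can be sketched in plain PyTorch for the simplest case: P1 triangles and the Laplace form. The sketch computes every element's local stiffness matrix in one batched pass, with no Python loop over elements, and then scatters the results into a global sparse matrix. torch.sparse_coo_tensor(...).coalesce() sums the duplicate (row, column) entries that neighboring elements contribute. The names local_stiffness and assemble are hypothetical. This illustrates the technique and is not TensorMesh's assembler.

```python
import torch

# Gradients of the P1 basis on the reference triangle (rows: 1 - x - y, x, y)
REF_GRADS = torch.tensor([[-1.0, -1.0], [1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)


def local_stiffness(points, cells):
    """Laplace stiffness for all P1 triangles at once: returns [E, 3, 3]."""
    x = points[cells]                                                 # [E, 3, 2] vertex coords
    J = torch.stack([x[:, 1] - x[:, 0], x[:, 2] - x[:, 0]], dim=-1)   # [E, 2, 2] element maps
    grads = REF_GRADS @ torch.linalg.inv(J)                           # [E, 3, 2] physical gradients
    area = torch.linalg.det(J).abs() / 2                              # [E]
    # Gradients are constant on P1 elements, so one-point quadrature is exact
    return area[:, None, None] * (grads @ grads.transpose(1, 2))


def assemble(points, cells):
    """Scatter the local matrices into one global sparse matrix; duplicates are summed."""
    Ke = local_stiffness(points, cells)
    k = cells.shape[1]
    rows = cells[:, :, None].expand(-1, k, k)   # global row of entry Ke[e, i, j]
    cols = cells[:, None, :].expand(-1, k, k)   # global column of entry Ke[e, i, j]
    indices = torch.stack([rows.reshape(-1), cols.reshape(-1)])
    n = points.shape[0]
    return torch.sparse_coo_tensor(indices, Ke.reshape(-1), (n, n)).coalesce()


# Unit square split into two triangles along the (0,0)-(1,1) diagonal
points = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]], dtype=torch.float64)
cells = torch.tensor([[0, 1, 2], [0, 2, 3]])
K = assemble(points, cells)
print(K.to_dense())
```

For this mesh the dense result is [[1, -0.5, 0, -0.5], [-0.5, 1, -0.5, 0], [0, -0.5, 1, -0.5], [-0.5, 0, -0.5, 1]]. Every row sums to zero, as expected for a Laplacian with no boundary conditions applied yet.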
Differentiable by construction. SparseMatrix.solve is backed by a
torch.autograd.Function with a custom backward (an adjoint
sparse solve). Gradients therefore flow end-to-end from a loss back
through the linear solve, the assembly, and any parameter that
touched either — be it a material coefficient, a Dirichlet value, or
a neural network’s prediction. See Differentiability.
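The adjoint backward follows from differentiating A x = b. If a loss L depends on x, solve the transposed system Aᵀλ = ∂L/∂x. Then ∂L/∂b = λ and ∂L/∂A = -λxᵀ, so the backward pass costs one extra solve with Aᵀ, which is the same matrix for symmetric problems. The sketch below expresses that rule as a custom torch.autograd.Function. It uses a dense torch.linalg.solve so that it is self-contained, and the class name AdjointSolve is hypothetical. A sparse implementation would also keep ∂L/∂A only on the matrix's nonzero pattern.

```python
import torch


class AdjointSolve(torch.autograd.Function):
    """Hypothetical sketch: x = A^{-1} b with an adjoint-solve backward (dense A, 1-D b)."""

    @staticmethod
    def forward(ctx, A, b):
        x = torch.linalg.solve(A, b)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        A, x = ctx.saved_tensors
        # One adjoint solve: A^T lam = dL/dx
        lam = torch.linalg.solve(A.transpose(-2, -1), grad_x)
        grad_A = -torch.outer(lam, x)  # dL/dA = -lam x^T
        grad_b = lam                   # dL/db = lam
        return grad_A, grad_b


# Gradient of a scalar loss with respect to a material-like coefficient k in A(k) = k * K0 + I
torch.manual_seed(0)
K0 = torch.rand(4, 4, dtype=torch.float64)
K0 = K0 @ K0.T                         # symmetric positive semi-definite
b = torch.rand(4, dtype=torch.float64)
k = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

x = AdjointSolve.apply(k * K0 + torch.eye(4, dtype=torch.float64), b)
loss = x.pow(2).sum()
loss.backward()                        # gradient flows through the solve back to k
print(k.grad)
```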
Modular linear algebra. The solver layer is a separate package,
torch-sla. The same FEM code retargets between SciPy (CPU),
Eigen (CPU), native PyTorch (CPU/GPU), cuDSS (GPU), and CuPy (GPU)
by changing one keyword argument. PETSc and Hypre are on the
torch-sla roadmap; until they ship, fallback paths in
tensormesh.sparse provide best-effort support if those
libraries are already installed locally.
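Retargeting through one keyword argument amounts to a dispatch table keyed by backend name. The sketch below is a hypothetical illustration of that design and not torch-sla's actual interface. The solve function, the backend keyword, and the two toy backends are all assumptions: a dense direct solve and a conjugate-gradient loop, both in plain PyTorch, chosen so the example runs without SciPy, Eigen, cuDSS, or CuPy installed.

```python
from typing import Callable, Dict

import torch


def _dense_direct(A, b):
    # Direct solve of a dense linear system
    return torch.linalg.solve(A, b)


def _conjugate_gradient(A, b, tol=1e-10, max_iter=1000):
    # Textbook CG for a symmetric positive-definite A (dense here for simplicity)
    x = torch.zeros_like(b)
    r = b - A @ x
    p = r.clone()
    rs = r @ r
    for _ in range(max_iter):
        if rs.sqrt() < tol:
            break
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Hypothetical registry: supporting a new backend means adding one entry here
BACKENDS: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
    "dense": _dense_direct,
    "cg": _conjugate_gradient,
}


def solve(A, b, backend="dense"):
    """Hypothetical front end: calling code changes only the backend keyword."""
    try:
        impl = BACKENDS[backend]
    except KeyError:
        raise ValueError(f"unknown backend {backend!r}; choose from {sorted(BACKENDS)}") from None
    return impl(A, b)


A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
b = torch.tensor([1.0, 2.0], dtype=torch.float64)
print(solve(A, b), solve(A, b, backend="cg"))
```

Both calls return the same solution, (1/11, 7/11). The FEM code that builds A and b never needs to know which backend ran.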
What’s next
Meshes: build, inspect, and load meshes; per-node and per-cell data; meshio round-tripping.
Elements and Quadrature: the element zoo and the basis / quadrature interface.
Forms: write your own weak form against the ElementAssembler / NodeAssembler / FacetAssembler contract.
Quickstart: the same pipeline as a complete worked example.