tensormesh.nn

BufferUtils

class BufferDict(data: Dict[str, Tensor] | None = None)

Bases: Module

Module-aware dict of tensors stored as buffers (non-trainable).

Use it whenever you need a dict of plain tensors attached to a Module — for example integer connectivity [n_element, n_basis], vector point data [n_point, D], or precomputed quadrature tables — keyed by element type or field name. Tensors registered through BufferDict follow the parent module under .to(device) / .float() / .cuda(), appear in state_dict(), and do not require gradients.
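The buffer behaviour described above is inherited from torch.nn.Module. As a minimal sketch, the snippet below uses plain register_buffer() on a hypothetical Holder module (not part of tensormesh) to show that buffers follow dtype casts, appear in state_dict(), and are not parameters. It runs on CPU and says nothing about how BufferDict is implemented internally.

```python
import torch
from torch import nn


class Holder(nn.Module):
    """Hypothetical module used only for this illustration."""

    def __init__(self):
        super().__init__()
        # Buffers are part of the module state but are not trainable.
        self.register_buffer("points", torch.zeros(4, 2))                   # float32
        self.register_buffer("cells", torch.zeros(3, 3, dtype=torch.long))  # int64


holder = Holder().double()  # casts floating-point buffers only

points_dtype = holder.points.dtype         # torch.float64
cells_dtype = holder.cells.dtype           # torch.int64, integer buffers are left alone
state_keys = sorted(holder.state_dict())   # ['cells', 'points']
n_params = len(list(holder.parameters()))  # 0, buffers are not parameters
```

Note that dtype casts such as .float() or .double() only touch floating-point tensors, so integer connectivity keeps its integer dtype.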

Two behaviours go beyond a plain register_buffer():

  1. Keys that aren’t valid Python identifiers (anything not matching ^[a-zA-Z_][a-zA-Z0-9_]*$, e.g. "123x" or names with dashes) are stored in an internal OrderedDict (_data) instead of being registered through register_buffer(). Their tensors are still moved by _apply(), so .to(device) and related calls still work; they just don’t appear in state_dict().

  2. Buffer ↔ parameter promotion: as_parameter() turns a stored buffer into a trainable Parameter in place, and as_buffer() reverses it. This lets the same container serve both pure-FEM workflows (everything as buffers) and ML workflows where some fields need gradients (e.g. learnable material parameters).
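The key-routing rule in item 1 can be sketched with the regular expression quoted there. The helper name backing_store is hypothetical and exists only for this illustration; it reports which store a key would land in and does not touch any tensors.

```python
import re

# Pattern quoted in the text above.
IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")


def backing_store(key: str) -> str:
    """Hypothetical helper: name the store a key would be routed to."""
    return "buffer" if IDENTIFIER.match(key) else "fallback"


routes = {key: backing_store(key) for key in ["triangle", "_mask", "123x", "my-field"]}
print(routes)
```

Keys such as "123x" and "my-field" fall through to the fallback store, which is why they would be missing from state_dict().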

Parameters:

data (Dict[str, Tensor], optional) – Initial key→tensor mapping. Keys matching ^[a-zA-Z_][a-zA-Z0-9_]*$ are registered as buffers via register_buffer(); the rest are kept in the fallback _data dict. Default: empty.

Examples

>>> import torch
>>> from tensormesh.nn import BufferDict
>>> cells = BufferDict({
...     "triangle": torch.zeros(10, 3, dtype=torch.long),
...     "quad":     torch.zeros(5,  4, dtype=torch.long),
... })
>>> cells = cells.to("cuda")      # both tensors move to GPU
>>> list(cells.keys())
['triangle', 'quad']
>>> cells["triangle"].device.type
'cuda'

__init__(data: Dict[str, Tensor] | None = None)

Initialize the module and store each entry of data as described under Parameters above.

as_parameter(key: str)

Promote the buffer at key to a trainable torch.nn.Parameter in place.

After this call, self[key] is a Parameter: it tracks gradients and appears in parameters(). The key must currently live in _buffers; otherwise KeyError is raised. Reverse with as_buffer().

as_buffer(key: str)

Demote the parameter at key back to a (non-trainable) buffer in place.

Inverse of as_parameter(). The key must currently live in _parameters. The demoted tensor is obtained via detach(), so it shares storage with the former parameter and no longer requires grad.
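The promotion and demotion round trip can be illustrated on a bare torch.nn.Module. This is a sketch of the mechanism under the assumption that promotion moves an entry from _buffers to _parameters and demotion moves it back; it is not the tensormesh source, and the key "stiffness" is made up for the example.

```python
import torch
from torch import nn

module = nn.Module()
module.register_buffer("stiffness", torch.ones(3))

# Promote: remove the buffer, then register the same tensor as a Parameter.
tensor = module._buffers.pop("stiffness")
module.register_parameter("stiffness", nn.Parameter(tensor))
is_trainable = module.stiffness.requires_grad   # True
n_params = len(list(module.parameters()))       # 1

# Demote: detach() returns a tensor that shares storage but tracks no gradient.
param = module._parameters.pop("stiffness")
module.register_buffer("stiffness", param.detach())
shares_storage = module.stiffness.data_ptr() == param.data_ptr()
```

Only floating-point or complex tensors can require gradients in PyTorch, so promoting an integer tensor (for example a connectivity table) with requires_grad=True would fail.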

keys() → Iterable[str]

Iterate over keys across all three backing stores (buffers, parameters, fallback).

items() → Iterable[Tuple[str, Tensor]]

Iterate over (key, tensor) pairs across all three backing stores.

values() → Iterable[Tensor]

Iterate over tensors across all three backing stores.
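Iterating "across all three backing stores" can be pictured as chaining three ordered mappings. The sketch below uses strings as stand-ins for tensors so it has no dependencies. The order shown (buffers, then parameters, then fallback) follows the listing in the keys() description and is an assumption about the actual iteration order.

```python
from collections import OrderedDict
from itertools import chain

# Stand-ins for the three stores; strings replace tensors to keep this dependency-free.
buffers = OrderedDict(triangle="buffer tensor")
parameters = OrderedDict(stiffness="parameter tensor")
fallback = OrderedDict([("123x", "fallback tensor")])


def iter_items():
    """Yield (key, value) pairs from every store, one store after another."""
    return chain(buffers.items(), parameters.items(), fallback.items())


all_keys = [key for key, _ in iter_items()]
print(all_keys)
```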

is_floating_point() → bool

Return True if any stored tensor has a floating-point dtype.

is_complex() → bool

Return True if any stored tensor has a complex dtype.
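Both predicates use "any" semantics, so a container that mixes integer connectivity with floating-point coordinates still reports True for is_floating_point(). A minimal sketch of that aggregation over a plain dict of tensors (the dict and its keys are illustrative, not a BufferDict):

```python
import torch

tensors = {
    "cells": torch.zeros(2, 3, dtype=torch.long),  # integer connectivity
    "points": torch.zeros(4, 2),                   # float32 coordinates
}

# One matching tensor is enough for the aggregate to be True.
any_float = any(t.is_floating_point() for t in tensors.values())
any_complex = any(t.is_complex() for t in tensors.values())
```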

property dtype

torch.dtype of the first registered buffer, taken as representative of the whole container.

property device

torch.device of the first registered buffer, taken as representative of the whole container.

to_dict() → Mapping[str, Tensor | Module]

Return a plain dict view of the contents (no module wiring).

clone() → BufferDict

Return a deep copy: every stored tensor is cloned, then wrapped in a fresh BufferDict.

class BufferList(data: Iterable[Tensor] | None = None)

Bases: Module

Module-aware list of tensors stored as buffers (non-trainable).

The list analogue of BufferDict. Same motivation: PyTorch ships torch.nn.ParameterList and torch.nn.ModuleList but no list of buffers. BufferList provides one — tensors are stored under stringified indices via register_buffer(), so they follow .to(device) and appear in state_dict().
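Storing tensors "under stringified indices" can be sketched on a bare torch.nn.Module. The snippet below only illustrates the naming scheme described above; the variable names are made up and nothing here is taken from the BufferList source.

```python
import torch
from torch import nn

holder = nn.Module()
for index, tensor in enumerate([torch.zeros(3), torch.zeros(4)]):
    # Each list position becomes a buffer named "0", "1", ...
    holder.register_buffer(str(index), tensor)

keys = list(holder.state_dict().keys())
second = getattr(holder, "1")  # numeric names need getattr, not dotted access
```

Because the names are plain strings, the entries show up in state_dict() under "0", "1", and so on, in insertion order.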

Used inside FacetAssembler to hold the per-element-type boundary-facet masks — for mixed-facet elements like prisms and pyramids, each element type contributes more than one mask tensor (e.g. a triangle-facet mask and a quad-facet mask), so a list of buffers per key is the natural shape.

Beyond standard list indexing (int, slice), __getitem__() also accepts a 1D torch.Tensor of indices and returns a fresh BufferList — convenient for gather-style operations.
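The three index kinds can be sketched with a small dispatch function over a plain Python list. The helper select is hypothetical and returns plain lists, whereas the text above says BufferList returns a fresh BufferList for tensor indices.

```python
import torch


def select(entries, index):
    """Hypothetical helper mirroring the three index kinds described above."""
    if isinstance(index, torch.Tensor):
        # Gather-style: pick entries in the order given by the 1D index tensor.
        return [entries[i] for i in index.tolist()]
    return entries[index]  # int -> single tensor, slice -> list of tensors


entries = [torch.zeros(3), torch.zeros(4), torch.zeros(5)]
single = select(entries, 1)
window = select(entries, slice(0, 2))
picked = select(entries, torch.tensor([2, 0]))
```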

Like BufferDict, individual entries can be promoted to trainable parameters via as_parameter() (and demoted back via as_buffer()).

Parameters:

data (Iterable[Tensor], optional) – Initial tensors. Default: empty.

Examples

>>> import torch
>>> from tensormesh.nn import BufferList
>>> bl = BufferList([torch.zeros(3), torch.zeros(4)])
>>> bl.append(torch.zeros(5))
>>> len(bl)
3
>>> bl = bl.to("cuda")      # all entries move to GPU
>>> bl[0].device.type
'cuda'

__init__(data: Iterable[Tensor] | None = None)

Initialize the module and register each tensor in data as a buffer, in order.

as_parameter(key: int)

Promote the buffer at index key to a trainable torch.nn.Parameter in place.

as_buffer(key: int)

Demote the parameter at index key back to a (non-trainable) buffer in place.

Storage is shared via detach(); the result no longer requires grad.

append(value: Tensor)

Append a tensor at the end of the list and register it as a buffer.

insert(index: int, value: Tensor)

Insert value at index, shifting subsequent entries (and their backing keys) right by one.

pop(index: int = -1) → Tensor

Remove and return the tensor at index (default: last), shifting subsequent entries left by one.
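Since entries live under stringified indices, insert() and pop() have to renumber the keys that follow the affected position. The sketch below shows one way to do that on a plain dict, with strings standing in for tensors. The helpers insert_entry and pop_entry are hypothetical, assume a valid index, and assume a non-empty store for pop_entry.

```python
def insert_entry(store: dict, index: int, value) -> None:
    """Insert value at index in a dict keyed by "0", "1", ... (0 <= index <= len)."""
    # Walk from the end so no key is overwritten before it has been moved.
    for i in range(len(store), index, -1):
        store[str(i)] = store[str(i - 1)]
    store[str(index)] = value


def pop_entry(store: dict, index: int = -1):
    """Remove and return the entry at index, closing the gap it leaves."""
    n = len(store)
    index = index % n  # map -1 to the last position
    value = store[str(index)]
    for i in range(index, n - 1):
        store[str(i)] = store[str(i + 1)]
    del store[str(n - 1)]
    return value


store = {"0": "a", "1": "c"}
insert_entry(store, 1, "b")    # keys "0", "1", "2" now hold "a", "b", "c"
removed = pop_entry(store, 0)  # removes "a"; "b" and "c" shift down to "0" and "1"
```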

item() → Tensor

Return the sole tensor when the list has length 1; otherwise an assertion fails.

is_floating_point() → bool

Return True if any stored tensor has a floating-point dtype.

is_complex() → bool

Return True if any stored tensor has a complex dtype.

property dtype

torch.dtype of the first entry, taken as representative of the whole list.

property device

torch.device of the first entry, taken as representative of the whole list.

to_list() → List[Tensor]

Return a plain Python list of the contained tensors (no module wiring).

clone() → BufferList

Return a deep copy: every stored tensor is cloned, then wrapped in a fresh BufferList.