"""FEM-flavoured COO sparse matrix.
:class:`SparseMatrix` extends :class:`torch_sla.SparseTensor` with helpers
that come up in finite-element assembly: block-COO construction, layout
hashing for sparsity-pattern equality checks, scipy interoperability, and
block-stacking. Arithmetic and device-conversion methods are overridden
so that the result type is preserved (``A + B`` of two ``SparseMatrix``
returns ``SparseMatrix``, not the parent type).
"""
import numpy as np
import torch
import scipy.sparse
import hashlib
from typing import List, Optional, Tuple, Union
try:
from torch_sla import SparseTensor
except ImportError as e:
raise ImportError(
"torch-sla is required for TensorMesh sparse operations.\n"
"Install with: pip install torch-sla>=0.2.0"
) from e
[文档]
class SparseMatrix(SparseTensor):
"""COO sparse matrix with FEM-flavoured helpers.
Subclass of ``torch_sla.SparseTensor``. Inherits ``@`` (spmm),
``solve``, autograd-aware values, ``.to_dense()`` and friends from
the parent; adds layout hashing, block-COO assembly, scipy
interoperability and type-preserving arithmetic.
Parameters
----------
edata : torch.Tensor
1D float tensor of shape ``[nnz]``: the non-zero values.
row : torch.Tensor
1D integer tensor of shape ``[nnz]``: row indices.
col : torch.Tensor
1D integer tensor of shape ``[nnz]``: column indices.
shape : Tuple[int, int]
``(nrows, ncols)`` of the dense equivalent.
Attributes
----------
layout_hash : str
SHA-256 of ``(row, col)`` concatenated bytes. Two matrices with
the same sparsity pattern share this hash; useful for caching
any quantity that depends only on the topology (Condenser
permutations, AMG hierarchies, etc.).
Examples
--------
.. code-block:: python
import torch
from tensormesh.sparse import SparseMatrix
edata = torch.tensor([1.0, 2.0, 3.0])
row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])
A = SparseMatrix(edata, row, col, shape=(3, 3))
# Inherited from SparseTensor:
y = A @ torch.tensor([1.0, 2.0, 3.0])
x = A.double().solve(torch.ones(3, dtype=torch.float64))
# FEM-specific block-COO assembly: 10 element matrices of size 3x3.
block_data = torch.randn(10, 3, 3)
elem_row = torch.arange(10)
elem_col = torch.arange(10)
K = SparseMatrix.from_block_coo(block_data, elem_row, elem_col, (10, 10))
"""
[文档]
def __init__(self, edata: torch.Tensor, row: torch.Tensor,
col: torch.Tensor, shape: Tuple[int, int]):
# Initialize parent SparseTensor
super().__init__(edata, row.long(), col.long(), shape)
# Compute layout hash for FEM assembly comparison
row_cpu = row.detach().cpu()
col_cpu = col.detach().cpu()
self._layout_hash = hashlib.sha256(
row_cpu.numpy().tobytes() + col_cpu.numpy().tobytes()
).hexdigest()
# ==================== Backward Compatibility Properties ====================
@property
def edata(self) -> torch.Tensor:
"""Alias for ``values`` (legacy TensorMesh API)."""
return self.values
@property
def row(self) -> torch.Tensor:
"""Alias for ``row_indices``."""
return self.row_indices
@property
def col(self) -> torch.Tensor:
"""Alias for ``col_indices``."""
return self.col_indices
@property
def edges(self) -> torch.Tensor:
"""Stacked ``(row, col)`` indices of shape ``[2, nnz]``."""
return torch.stack([self.row_indices, self.col_indices], dim=0)
@property
def layout_hash(self) -> str:
"""SHA-256 of the sparsity pattern; see the class-level ``Attributes`` section."""
return self._layout_hash
@property
def layout_mask(self) -> torch.Tensor:
"""Dense ``[nrows, ncols]`` mask with 1.0 on the sparsity pattern."""
mask = torch.zeros(self.shape, device=self.device, dtype=self.dtype)
mask[self.row_indices, self.col_indices] = 1
return mask
@property
def grad(self) -> Optional['SparseMatrix']:
"""Gradient w.r.t. ``values``, wrapped as a :class:`SparseMatrix`.
Returns ``None`` when no gradient has accumulated on ``values``.
"""
if self.values.grad is None:
return None
return SparseMatrix(
self.values.grad, self.row_indices, self.col_indices, self.shape
)
# ==================== Type-Preserving Helpers ====================
def _wrap(self, result):
"""Wrap a SparseTensor result back into SparseMatrix."""
if isinstance(result, SparseTensor):
return SparseMatrix(result.values, result.row_indices, result.col_indices, result.shape)
return result
# ==================== Arithmetic (preserve SparseMatrix type) ====================
def __add__(self, other):
return self._wrap(super().__add__(other))
def __radd__(self, other):
return self._wrap(super().__radd__(other))
def __sub__(self, other):
return self._wrap(super().__sub__(other))
def __rsub__(self, other):
return self._wrap(super().__rsub__(other))
def __mul__(self, other):
return self._wrap(super().__mul__(other))
def __rmul__(self, other):
return self._wrap(super().__rmul__(other))
def __matmul__(self, other):
result = super().__matmul__(other)
if isinstance(result, SparseTensor):
return self._wrap(result)
return result
# ==================== Device/Dtype Methods (preserve type) ====================
[文档]
def to(self, *args, **kwargs) -> 'SparseMatrix':
return self._wrap(super().to(*args, **kwargs))
[文档]
def cuda(self, device=None) -> 'SparseMatrix':
return self._wrap(super().cuda(device))
[文档]
def cpu(self) -> 'SparseMatrix':
return self._wrap(super().cpu())
[文档]
def float(self) -> 'SparseMatrix':
return self._wrap(super().float())
[文档]
def double(self) -> 'SparseMatrix':
return self._wrap(super().double())
[文档]
def half(self) -> 'SparseMatrix':
return self._wrap(super().half())
[文档]
def detach(self) -> 'SparseMatrix':
return self._wrap(super().detach())
# ==================== FEM-Specific Methods ====================
[文档]
def has_same_layout(self, other: Union[str, 'SparseMatrix']) -> bool:
"""Check whether ``other`` shares this matrix's sparsity pattern.
Parameters
----------
other : str or SparseMatrix
Either another :class:`SparseMatrix` or a ``layout_hash``
string to compare against.
Returns
-------
bool
``True`` iff both have identical ``(row, col)`` arrays.
Raises
------
TypeError
If ``other`` is neither a string nor a :class:`SparseMatrix`.
"""
if isinstance(other, str):
return self._layout_hash == other
elif isinstance(other, SparseMatrix):
return self._layout_hash == other._layout_hash
else:
raise TypeError(f"Expected str or SparseMatrix, got {type(other)}")
[文档]
def degree(self, axis: int = 0) -> torch.Tensor:
"""Non-zero count per row (``axis=0``) or per column (``axis=1``).
Parameters
----------
axis : int, default 0
Aggregate over ``axis``; ``0`` counts nnz per row, ``1`` per
column.
Returns
-------
torch.Tensor
1D int tensor of length ``shape[axis]``.
"""
indices = self.row_indices if axis == 0 else self.col_indices
size = self.shape[0] if axis == 0 else self.shape[1]
return torch.bincount(indices, minlength=size)
[文档]
def transpose(self) -> 'SparseMatrix':
"""Return ``A^T`` as a :class:`SparseMatrix` sharing this matrix's values."""
return SparseMatrix(
self.values, self.col_indices, self.row_indices,
(self.shape[1], self.shape[0])
)
@property
def T(self) -> 'SparseMatrix':
"""Shorthand for :meth:`transpose`."""
return self.transpose()
# ==================== Scipy Interoperability ====================
[文档]
def to_scipy_coo(self) -> scipy.sparse.coo_matrix:
"""Detach and convert to a :class:`scipy.sparse.coo_matrix` on CPU."""
return scipy.sparse.coo_matrix((
self.values.detach().cpu().numpy(),
(self.row_indices.detach().cpu().numpy(),
self.col_indices.detach().cpu().numpy())
), shape=self.shape)
[文档]
def to_sparse_coo(self) -> torch.Tensor:
"""Convert to a :func:`torch.sparse_coo_tensor` (autograd-tracked)."""
return torch.sparse_coo_tensor(
torch.stack([self.row_indices, self.col_indices]),
self.values,
self.shape
)
# ==================== Static Factory Methods ====================
[文档]
@staticmethod
def from_scipy_coo(matrix: scipy.sparse.coo_matrix,
device: str = "cpu",
dtype: torch.dtype = torch.float) -> 'SparseMatrix':
"""Wrap a :class:`scipy.sparse.coo_matrix` as a :class:`SparseMatrix`.
Parameters
----------
matrix : scipy.sparse.coo_matrix
Source matrix.
device : str, default ``"cpu"``
Target torch device.
dtype : torch.dtype, default ``torch.float``
Target value dtype.
Returns
-------
SparseMatrix
"""
edata = torch.from_numpy(matrix.data.astype(np.float32)).to(device).type(dtype)
row = torch.from_numpy(matrix.row.astype(np.int64)).to(device)
col = torch.from_numpy(matrix.col.astype(np.int64)).to(device)
return SparseMatrix(edata, row, col, matrix.shape)
[文档]
@staticmethod
def from_sparse_coo(matrix: torch.Tensor) -> 'SparseMatrix':
"""Wrap a coalesced :func:`torch.sparse_coo_tensor` as a :class:`SparseMatrix`.
Parameters
----------
matrix : torch.Tensor
Sparse COO tensor; will be coalesced in place.
Returns
-------
SparseMatrix
"""
matrix = matrix.coalesce()
return SparseMatrix(
matrix.values(),
matrix.indices()[0],
matrix.indices()[1],
tuple(matrix.shape)
)
[文档]
@staticmethod
def from_dense(tensor: torch.Tensor) -> 'SparseMatrix':
"""Pull non-zero entries out of a dense 2D tensor into COO.
Parameters
----------
tensor : torch.Tensor
2D tensor; entries exactly equal to zero are dropped.
Returns
-------
SparseMatrix
"""
assert tensor.dim() == 2, f"Expected 2D tensor, got {tensor.dim()}D"
rows, cols = torch.where(tensor != 0)
return SparseMatrix(tensor[rows, cols], rows, cols, tuple(tensor.shape))
[文档]
@staticmethod
def from_block_coo(edata: torch.Tensor, row: torch.Tensor,
col: torch.Tensor, shape: Tuple[int, int]) -> 'SparseMatrix':
"""Expand block-COO storage into a flat :class:`SparseMatrix`.
Each block is a small dense matrix attached to a ``(row, col)``
pair in a coarser graph. This is the layout produced by FEM
element assembly when the element data is *vector-valued*
(e.g. linear elasticity), with one ``[dim, dim]`` block per
``(node_i, node_j)`` pair.
Blocks are assumed **square**; the function uses
``edata.shape[1]`` as the block size and ignores ``shape[2]``.
Parameters
----------
edata : torch.Tensor
Block data of shape ``[n_elements, block_size, block_size]``.
row, col : torch.Tensor
Block indices of shape ``[n_elements]``; entries refer to
the coarse graph.
shape : Tuple[int, int]
``(block_rows, block_cols)`` of the coarse graph.
Returns
-------
SparseMatrix
Flat COO matrix of shape
``(shape[0] * block_size, shape[1] * block_size)``.
"""
block_size = edata.shape[1]
assert edata.shape[2] == block_size, (
f"from_block_coo expects square blocks, got "
f"[..., {edata.shape[1]}, {edata.shape[2]}]"
)
edata_flat = edata.flatten()
row_exp = row[:, None].repeat(1, block_size * block_size)
col_exp = col[:, None].repeat(1, block_size * block_size)
i, j = torch.meshgrid(
torch.arange(block_size, device=row.device),
torch.arange(block_size, device=row.device),
indexing='ij'
)
row_final = (row_exp * block_size + i.flatten()).flatten()
col_final = (col_exp * block_size + j.flatten()).flatten()
new_shape = (shape[0] * block_size, shape[1] * block_size)
return SparseMatrix(edata_flat, row_final, col_final, new_shape)
[文档]
@staticmethod
def random(m: int, n: int, density: float = 0.1,
device: str = "cpu", dtype: torch.dtype = torch.float) -> 'SparseMatrix':
"""Random ``m x n`` sparse matrix with the requested density.
Parameters
----------
m, n : int
Dense shape.
density : float, default 0.1
Fraction of entries that are non-zero. Drawn via
:func:`scipy.sparse.random`.
device : str, default ``"cpu"``
dtype : torch.dtype, default ``torch.float``
Returns
-------
SparseMatrix
"""
matrix = scipy.sparse.random(m, n, density, format="coo")
return SparseMatrix.from_scipy_coo(matrix, device=device, dtype=dtype)
[文档]
@staticmethod
def random_layout(m: int, n: int, density: float = 0.1,
device: str = "cpu") -> Tuple[torch.Tensor, torch.Tensor, Tuple[int, int]]:
"""Random ``(row, col, shape)`` triple with no value tensor attached.
Useful when several matrices need to share the same sparsity
pattern; combine with :meth:`random_from_layout`.
Returns
-------
Tuple[torch.Tensor, torch.Tensor, Tuple[int, int]]
``(row, col, (m, n))``.
"""
matrix = scipy.sparse.random(m, n, density, format="coo")
row = torch.from_numpy(matrix.row.astype(np.int64)).to(device)
col = torch.from_numpy(matrix.col.astype(np.int64)).to(device)
return row, col, matrix.shape
[文档]
@staticmethod
def random_from_layout(layout: Tuple[torch.Tensor, torch.Tensor, Tuple[int, int]],
device: str = "cpu",
dtype: torch.dtype = torch.float) -> 'SparseMatrix':
"""Attach random uniform ``[0, 1)`` values to an existing layout.
Parameters
----------
layout : Tuple[torch.Tensor, torch.Tensor, Tuple[int, int]]
``(row, col, shape)`` triple, e.g. from :meth:`random_layout`.
device : str, default ``"cpu"``
dtype : torch.dtype, default ``torch.float``
Returns
-------
SparseMatrix
"""
row, col, shape = layout
edata = torch.rand(row.shape[0], device=device, dtype=dtype)
return SparseMatrix(edata, row.to(device), col.to(device), shape)
[文档]
@staticmethod
def eye(n: int, value: float = 1.,
device: str = "cpu", dtype: torch.dtype = torch.float) -> 'SparseMatrix':
"""``n x n`` sparse identity, optionally scaled by ``value``.
Parameters
----------
n : int
Dimension.
value : float, default 1.0
Diagonal value; ``value * I_n`` is stored.
device : str, default ``"cpu"``
dtype : torch.dtype, default ``torch.float``
Returns
-------
SparseMatrix
"""
indices = torch.arange(n, device=device)
values = torch.ones(n, device=device, dtype=dtype) * value
return SparseMatrix(values, indices, indices.clone(), (n, n))
[文档]
@staticmethod
def full(m: int, n: int, value: float = 1.,
device: str = "cpu", dtype: torch.dtype = torch.float) -> 'SparseMatrix':
"""Constant ``m x n`` matrix stored densely-as-sparse.
Mostly a building block for :meth:`combine_matrix` block layouts;
``value == 0`` is short-circuited to an empty COO tensor.
Parameters
----------
m, n : int
Shape.
value : float, default 1.0
Constant entry.
device : str, default ``"cpu"``
dtype : torch.dtype, default ``torch.float``
Returns
-------
SparseMatrix
"""
if value == 0:
return SparseMatrix(
torch.tensor([], device=device, dtype=dtype),
torch.tensor([], device=device, dtype=torch.int64),
torch.tensor([], device=device, dtype=torch.int64),
(m, n)
)
rows, cols = torch.meshgrid(
torch.arange(m, device=device),
torch.arange(n, device=device),
indexing='ij'
)
edata = torch.ones(m * n, device=device, dtype=dtype) * value
return SparseMatrix(edata, rows.flatten(), cols.flatten(), (m, n))
# ==================== Block Matrix Operations ====================
[文档]
@staticmethod
def combine_vector(matrices: List['SparseMatrix'], axis: int = 0) -> 'SparseMatrix':
"""Stack a list of sparse matrices along ``axis``.
Equivalent of :func:`torch.cat` for :class:`SparseMatrix`. Dense
:class:`torch.Tensor` entries are auto-converted via
:meth:`from_dense`.
Parameters
----------
matrices : List[SparseMatrix or torch.Tensor]
Items must agree on ``shape[1 - axis]``.
axis : int, default 0
``0`` stacks vertically (along rows), ``1`` horizontally.
Returns
-------
SparseMatrix
"""
rows, cols, edatas = [], [], []
offset = 0
fixed_dim = matrices[0].shape[1 - axis]
for mat in matrices:
assert mat.shape[1 - axis] == fixed_dim, "Dimension mismatch"
if isinstance(mat, torch.Tensor):
mat = SparseMatrix.from_dense(mat)
if axis == 0:
rows.append(mat.row_indices + offset)
cols.append(mat.col_indices)
offset += mat.shape[0]
else:
rows.append(mat.row_indices)
cols.append(mat.col_indices + offset)
offset += mat.shape[1]
edatas.append(mat.values)
if axis == 0:
shape = (offset, fixed_dim)
else:
shape = (fixed_dim, offset)
return SparseMatrix(
torch.cat(edatas), torch.cat(rows), torch.cat(cols), shape
)
[文档]
@staticmethod
def combine_matrix(matrices: List[List['SparseMatrix']]) -> 'SparseMatrix':
"""Assemble a 2D block layout of sparse matrices into one matrix.
Each entry of ``matrices[i][j]`` may be:
- ``None`` — treated as a zero block of inferred size;
- ``int`` or ``float`` — expanded via :meth:`full` to a constant
block of the inferred size;
- :class:`torch.Tensor` — auto-converted via :meth:`from_dense`;
- :class:`SparseMatrix` — used as-is.
Block sizes are inferred from the first non-scalar entry in each
row / column; a row or column with only ``None``/scalar entries
will have inferred size zero.
Parameters
----------
matrices : List[List[SparseMatrix or torch.Tensor or int or float or None]]
Rectangular nested list, ``[n_rows][n_cols]``.
Returns
-------
SparseMatrix
Assembled matrix of shape
``(sum(row_sizes), sum(col_sizes))``.
"""
n_rows = len(matrices)
n_cols = len(matrices[0])
# Infer block shapes from the first non-scalar entry in each row/col.
row_sizes = [0] * n_rows
col_sizes = [0] * n_cols
for i in range(n_rows):
for j in range(n_cols):
mat = matrices[i][j]
if mat is not None and not isinstance(mat, (int, float)):
if row_sizes[i] == 0:
row_sizes[i] = mat.shape[0]
if col_sizes[j] == 0:
col_sizes[j] = mat.shape[1]
rows, cols, edatas = [], [], []
row_offset = 0
for i in range(n_rows):
col_offset = 0
for j in range(n_cols):
mat = matrices[i][j]
if mat is not None:
if isinstance(mat, (int, float)):
mat = SparseMatrix.full(row_sizes[i], col_sizes[j], value=mat)
elif isinstance(mat, torch.Tensor):
mat = SparseMatrix.from_dense(mat)
rows.append(mat.row_indices + row_offset)
cols.append(mat.col_indices + col_offset)
edatas.append(mat.values)
col_offset += col_sizes[j]
row_offset += row_sizes[i]
return SparseMatrix(
torch.cat(edatas), torch.cat(rows), torch.cat(cols),
(row_offset, col_offset)
)
[文档]
@staticmethod
def combine(matrices) -> 'SparseMatrix':
"""Dispatch to :meth:`combine_matrix` or :meth:`combine_vector`.
If the first element is a list or tuple, the input is treated as
a 2D block layout (calls :meth:`combine_matrix`); otherwise it is
a 1D stack along axis 0 (calls :meth:`combine_vector`).
"""
if isinstance(matrices[0], (list, tuple)):
return SparseMatrix.combine_matrix(matrices)
else:
return SparseMatrix.combine_vector(matrices, axis=0)
# ==================== String Representation ====================
def __repr__(self) -> str:
return (
f"SparseMatrix(\n"
f" values: {self.values}\n"
f" row : {self.row_indices}\n"
f" col : {self.col_indices}\n"
f" shape : {self.shape}\n"
f" nnz : {self.nnz}\n"
f")"
)