# Source code for tensormesh.assemble.projector

from typing import Union, Sequence

import numpy as np
import torch
import torch.nn as nn

# Array-like inputs accepted wherever the projectors take index data:
# either a torch tensor or a NumPy array.
Tensor = Union[
    torch.Tensor,
    np.ndarray,
]

# Accepted spellings of a leading shape: a sequence of ints, a single int,
# a NumPy array, or a ``torch.Size``.
Shape = Union[
    Sequence[int],
    int,
    np.ndarray,
    torch.Size,
]


class Projector(nn.Module):
    """Abstract base for the element-to-global scatter operators.

    A :class:`Projector` consumes a tensor with leading shape ``from_shape``
    (per-element / per-facet quantities) and returns a tensor with leading
    shape ``to_shape`` (global edge / node indexing), summing duplicates.

    The two concrete implementations are ``ReduceProjector`` (uses
    :meth:`torch.Tensor.index_add_`) and ``SparseProjector`` (uses a sparse
    mat-vec product). Subclasses implement :meth:`forward`, so instances are
    invoked like any other :class:`torch.nn.Module`: ``projector(x)``.
    """


def _as_shape(shape: "Shape", name: str) -> tuple:
    """Normalize a ``Shape`` argument to a tuple of Python ints.

    The input-shape check compares against ``x.shape[:k]`` (a ``torch.Size``,
    i.e. a tuple). A list or NumPy array would never compare equal to it, so
    every accepted spelling is converted to a plain tuple here.
    """
    if isinstance(shape, (int, np.integer)):
        return (int(shape),)
    if isinstance(shape, np.ndarray):
        assert shape.ndim == 1, f"{name} must be 1D, but got {shape.ndim}"
    return tuple(int(s) for s in shape)


class ReduceProjector(Projector):
    r"""Element-to-global scatter backed by :meth:`torch.Tensor.index_add_`.

    More widely compatible than ``SparseProjector`` because it only relies on
    the dense ``index_add_`` kernel that PyTorch ships for every backend.

    Attributes
    ----------
    indices : torch.Tensor
        Long tensor of shape :math:`[\prod \text{from\_shape}]` mapping each
        flat-from index to its flat-to slot.
    from_shape : tuple of int
        Leading shape of accepted inputs (``input.shape[:len(from_shape)]``).
    to_shape : tuple of int
        Leading shape of returned outputs.
    use_fp64 : bool
        If ``True``, the accumulation runs in ``float64`` and is cast back to
        the input dtype on return. This reduces rounding error when many small
        contributions are summed.
    """

    indices: torch.Tensor
    from_shape: tuple
    to_shape: tuple
    use_fp64: bool

    def __init__(self, indices: "Tensor", from_shape: "Shape", to_shape: "Shape",
                 use_fp64: bool = False):
        r"""Wire up the scatter indices and the input/output shapes.

        Parameters
        ----------
        indices : torch.Tensor or np.ndarray
            1D index tensor of length :math:`\prod \text{from\_shape}`. Flat
            input row ``i`` is added into flat output slot ``indices[i]``.
        from_shape : tuple, int, np.ndarray, or torch.Size
            Leading shape of accepted inputs; normalized to a tuple of ints.
        to_shape : tuple, int, np.ndarray, or torch.Size
            Leading shape of returned outputs; normalized to a tuple of ints.
        use_fp64 : bool, optional
            Accumulate in ``float64`` (default ``False``).

        Raises
        ------
        AssertionError
            If ``indices`` is not 1D, if its length differs from
            :math:`\prod \text{from\_shape}`, or if a NumPy shape is not 1D.
        """
        super().__init__()
        if isinstance(indices, np.ndarray):
            indices = torch.from_numpy(indices)
        assert indices.dim() == 1, f"indices must be 1D, but got {indices.dim()}"
        from_shape = _as_shape(from_shape, "from_shape")
        to_shape = _as_shape(to_shape, "to_shape")
        from_size = int(np.prod(from_shape))
        # index_add_ needs exactly one index per flattened input row; reject a
        # mismatch here rather than on the first call.
        assert indices.numel() == from_size, (
            f"indices must have {from_size} entries (prod of from_shape), "
            f"but got {indices.numel()}"
        )
        # Stored as int64, matching the documented "Long tensor" attribute.
        self.register_buffer("indices", indices.long())
        self.from_shape = from_shape
        self.to_shape = to_shape
        self.use_fp64 = use_fp64
        # Pre-compute flat sizes for torch.compile compatibility.
        self._from_size = from_size
        self._to_size = int(np.prod(to_shape))

    @property
    def device(self) -> torch.device:
        """Device of the scatter indices; inputs must live on the same one."""
        return self.indices.device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scatter ``x`` from ``from_shape`` to ``to_shape``, summing duplicates.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``[*from_shape, ...]``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``[*to_shape, ...]`` with the dtype of ``x``.
        """
        assert self.device == x.device, f"the device of x must be {self.device}, but got {x.device}"
        lead = len(self.from_shape)
        assert tuple(x.shape[:lead]) == self.from_shape, f"the shape of x must be [{self.from_shape}, ...], but got {x.shape}"
        dim_shape = x.shape[lead:]
        out_dtype = x.dtype
        x = x.reshape(self._from_size, *dim_shape)
        if self.use_fp64:
            x = x.double()
        # Allocate the accumulator AFTER the optional upcast: index_add_
        # requires self and source to share a dtype.
        o = torch.zeros(self._to_size, *dim_shape, device=x.device, dtype=x.dtype)
        o.index_add_(0, self.indices, x)
        return o.to(out_dtype).reshape(*self.to_shape, *dim_shape)

    def __str__(self):
        return f"{type(self).__name__}({self.from_shape} -> {self.to_shape}, device={self.device})"

    def __repr__(self):
        return str(self)


class SparseProjector(Projector):
    r"""Element-to-global scatter backed by a CSR sparse mat-vec product.

    Faster than ``ReduceProjector`` for large meshes on backends that optimize
    sparse mat-vec, at the cost of materializing the projection matrix.

    Attributes
    ----------
    projection : torch.Tensor
        CSR sparse tensor of shape
        :math:`(\prod \text{to\_shape}, \prod \text{from\_shape})`.
        Duplicate ``(to, from)`` pairs are summed when the matrix is built.
    from_shape : tuple of int
        Leading shape of accepted inputs.
    to_shape : tuple of int
        Leading shape of returned outputs.
    """

    projection: torch.Tensor
    from_shape: tuple
    to_shape: tuple

    def __init__(self, from_: "Tensor", to_: "Tensor", from_shape: "Shape",
                 to_shape: "Shape", dtype=None):
        """Wire up the scatter index pairs and the input/output shapes.

        Parameters
        ----------
        from_ : torch.Tensor or np.ndarray
            1D source index tensor (column indices of the projection).
        to_ : torch.Tensor or np.ndarray
            1D destination index tensor (row indices, same length as ``from_``).
        from_shape : tuple, int, np.ndarray, or torch.Size
            Leading shape of accepted inputs; normalized to a tuple of ints.
        to_shape : tuple, int, np.ndarray, or torch.Size
            Leading shape of returned outputs; normalized to a tuple of ints.
        dtype : torch.dtype, optional
            Value dtype of the projection matrix (default ``torch.float32``).
            Inputs passed to the projector must have this dtype.

        Examples
        --------
        .. code-block:: python

            import scipy.sparse
            m = scipy.sparse.rand(3, 4, 0.5, format="coo")
            p = SparseProjector(m.col, m.row, 4, 3)
        """
        super().__init__()
        from_shape = _as_shape(from_shape, "from_shape")
        to_shape = _as_shape(to_shape, "to_shape")
        if isinstance(from_, np.ndarray):
            from_ = torch.from_numpy(from_)
        if isinstance(to_, np.ndarray):
            to_ = torch.from_numpy(to_)
        assert from_.shape == to_.shape, f"from_ and to_ must have the same shape, but got {from_.shape}, {to_.shape}"
        assert len(from_.shape) == 1, f"from_ and to_ must be 1D, but got {from_.shape}, {to_.shape}"
        if dtype is None:
            dtype = torch.float32
        from_size = int(np.prod(from_shape))
        to_size = int(np.prod(to_shape))
        projection = torch.sparse_coo_tensor(
            torch.stack([to_.long(), from_.long()], 0),
            torch.ones_like(from_, dtype=dtype),
            size=(to_size, from_size),
        ).to_sparse_csr()
        self.register_buffer("projection", projection)
        self.from_shape = from_shape
        self.to_shape = to_shape
        # Pre-compute flat sizes for torch.compile compatibility.
        self._from_size = from_size
        self._to_size = to_size

    def type(self, dtype: torch.dtype):
        """Cast the projection matrix to ``dtype`` in place and return ``self``."""
        if dtype != self.dtype:
            self.projection = self.projection.type(dtype)
        return self

    @property
    def device(self) -> torch.device:
        """Device of the projection matrix; inputs must live on the same one."""
        return self.projection.device

    @property
    def dtype(self) -> torch.dtype:
        """Value dtype of the projection matrix; inputs must match it."""
        return self.projection.dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scatter ``x`` via the cached CSR projection matrix.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape ``[*from_shape, ...]`` with dtype ``self.dtype``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``[*to_shape, ...]``.
        """
        assert self.dtype == x.dtype, f"the dtype of x must be {self.dtype}, but got {x.dtype}"
        assert self.device == x.device, f"the device of x must be {self.device}, but got {x.device}"
        lead = len(self.from_shape)
        assert tuple(x.shape[:lead]) == self.from_shape, f"the shape of x must be [{self.from_shape}, ...], but got {x.shape}"
        dim_shape = x.shape[lead:]
        # Fold all trailing dims into matrix columns so a single sparse @ dense
        # product handles them. The column count is explicit because
        # reshape(n, -1) is ambiguous for zero-element inputs.
        cols = int(np.prod(dim_shape))
        y = self.projection @ x.reshape(self._from_size, cols)
        return y.reshape(*self.to_shape, *dim_shape)

    def __str__(self):
        return f"{type(self).__name__}({self.from_shape} -> {self.to_shape}, device={self.device})"

    def __repr__(self):
        return str(self)