Source code for tensormesh.sparse.solve

"""Sparse linear system solvers for TensorMesh (legacy entry points).

.. deprecated:: 0.x

   The module-level :func:`spsolve` here is a **TensorMesh-internal
   wrapper scheduled for removal**. The canonical solver path has
   migrated to ``torch-sla``: assembly returns a
   :class:`~tensormesh.sparse.SparseMatrix` (a subclass of
   :class:`torch_sla.SparseTensor`), and you simply call ``K.solve(b)``,
   which auto-detects symmetry / positive-definiteness and dispatches
   directly to ``torch_sla.spsolve``. See
   :doc:`/user_guide/linear_solvers`.

The free function below remains only for the niche case where a caller
holds raw ``(edata, row, col, shape, b)`` arrays without a
:class:`SparseMatrix`. With ``torch-sla`` installed (the default and
recommended path) it dispatches to one of the torch-sla backends —
SciPy / native PyTorch / Eigen / cuDSS / CuPy — chosen by the
``backend`` argument, honours ``method`` / ``preconditioner`` /
``is_spd`` hints, and routes batched right-hand sides through SuperLU.

Without ``torch-sla`` (the legacy fallback path) ``spsolve`` still works
but the choice of algorithm collapses: each fallback wrapper picks a
single hard-coded method (direct SuperLU on CPU / CUDA, BiCGSTAB for the
pure-PyTorch path), and ``method`` / ``preconditioner`` / ``is_spd``
become inert. Install ``torch-sla`` for the full feature set.
"""

import warnings

import torch

try:
    import torch_sla  # noqa: F401  (presence-only check)
    is_torch_sla_available = True
except ImportError:
    is_torch_sla_available = False

from ..utils import is_petsc_available, is_cupy_available


[docs] def spsolve(edata, row, col, shape, b, backend='auto', method='cg', preconditioner='jacobi', tol=1e-5, max_iter=10000, x0=None, is_spd=True, verbose=False): """Solve the sparse linear system ``A x = b`` (legacy entry point). .. deprecated:: 0.x This free function pre-dates the ``torch-sla`` integration and is **scheduled for removal**. The canonical path is to wrap the data in a :class:`~tensormesh.sparse.SparseMatrix` and call its ``solve`` method (inherited from ``torch_sla.SparseTensor``), which auto-detects symmetry / positive-definiteness and routes to ``torch_sla.spsolve`` directly. See :doc:`/user_guide/linear_solvers`. Low-level entry point: takes raw COO arrays instead of a :class:`~tensormesh.sparse.SparseMatrix` object. With ``torch-sla`` installed, dispatches to a differentiable sparse-linear-algebra backend; without it, falls back to a curated mini-stack of SciPy / SuperLU / CuPy / PETSc wrappers. Parameters ---------- edata : torch.Tensor 1D tensor of shape ``[nnz]``: non-zero values of ``A``. row : torch.Tensor 1D int tensor of shape ``[nnz]``: row indices of ``A``. col : torch.Tensor 1D int tensor of shape ``[nnz]``: column indices of ``A``. shape : Tuple[int, int] Dense shape ``(m, n)`` of ``A``. b : torch.Tensor Right-hand side. Shape ``[n]`` for a single RHS, or ``[n, n_batch]`` for batched RHS (auto-routed through SuperLU). backend : str, default ``"auto"`` Torch-sla path: ``"auto"`` (CPU → ``"scipy"``, CUDA → ``"pytorch"``), ``"scipy"``, ``"pytorch"``, ``"eigen"``, ``"cudss"``, ``"cupy"``. Fallback path (no torch-sla): ``"auto"``, ``"petsc"``, ``"cupy"`` — others are accepted but the method/preconditioner hints below are ignored. method : str, default ``"cg"`` Iterative algorithm — ``"cg"``, ``"bicgstab"``, ``"minres"``, ``"gmres"``, ``"lgmres"`` — or one of the direct factorizations ``"lu"``, ``"umfpack"``, ``"cholesky"``, ``"ldlt"``. See the installed ``torch_sla.spsolve`` signature for the canonical list. **Honoured only on the torch-sla path.** On the fallback path, each wrapper uses a fixed algorithm. preconditioner : str, default ``"jacobi"`` ``"jacobi"``, ``"ilu"``, or ``"none"``. Same caveat as ``method`` — torch-sla path only. tol : float, default ``1e-5`` Convergence tolerance (iterative methods). max_iter : int, default ``10000`` Iteration budget (iterative methods). x0 : torch.Tensor, optional Initial guess. Currently consumed only by some fallback wrappers and ignored by torch-sla. is_spd : bool, default ``True`` Hint to the torch-sla path that ``A`` is symmetric positive definite. Picks CG as the default ``method``; set ``False`` for indefinite / non-symmetric ``A`` and combine with ``method="bicgstab"`` or ``"gmres"``. verbose : bool, default ``False`` Print which backend/method was picked. Returns ------- torch.Tensor Solution ``x``, same shape and dtype as ``b``. Notes ----- Both paths are autograd-aware: gradients of ``x`` flow back into ``edata`` and ``b`` via an adjoint sparse solve. On the torch-sla path this is built in to the library; on the fallback path each wrapper supplies its own :class:`torch.autograd.Function` backward. Examples -------- >>> from tensormesh.sparse import spsolve >>> x = spsolve(edata, row, col, (n, n), b) # auto >>> x = spsolve(edata, row, col, (n, n), b, method="lu") # direct >>> x = spsolve(edata, row, col, (n, n), b, backend="cudss") # GPU direct >>> x = spsolve(edata, row, col, (n, n), b, ... is_spd=False, method="bicgstab") # non-SPD """ # Validate inputs assert edata.device == row.device == col.device == b.device, \ f"All inputs must be on same device, got {edata.device}, {row.device}, {col.device}, {b.device}" if edata.dtype != torch.float64: warnings.warn("float64 recommended for better accuracy in spsolve") # Handle batched RHS is_batched = len(b.shape) == 2 # Use torch-sla if available (preferred) if is_torch_sla_available: return _solve_torch_sla( edata, row, col, shape, b, backend=backend, method=method, preconditioner=preconditioner, tol=tol, max_iter=max_iter, x0=x0, is_spd=is_spd, is_batched=is_batched, verbose=verbose ) # Fallback to legacy solvers warnings.warn( "torch-sla not available, using fallback solver. " "Install torch-sla for better performance: pip install torch-sla" ) return _solve_fallback( edata, row, col, shape, b, backend=backend, tol=tol, max_iter=max_iter, x0=x0, is_batched=is_batched, verbose=verbose )
def _solve_torch_sla(edata, row, col, shape, b, backend, method, preconditioner, tol, max_iter, x0, is_spd, is_batched, verbose): """Solve via torch-sla; honours ``method`` / ``preconditioner`` / ``is_spd``. Batched RHS (``b.ndim == 2``) auto-routes to an LU factorization when the user did not request a direct method, since a single factorization + ``n_batch`` back-substitutions beats running an iterative solver per column. """ from .torch_sla_solve import SparseSolveTorchSLA # Map 'auto' to appropriate torch-sla backend. if backend == 'auto': if edata.device.type == 'cuda': backend = 'pytorch' else: backend = 'scipy' # For batched solve, route iterative requests to a direct factorization. _DIRECT_METHODS = ('lu', 'umfpack', 'cholesky', 'ldlt') if is_batched and method not in _DIRECT_METHODS: if verbose: print(f"Using LU for batched solve (batch_size={b.shape[1]})") method = 'lu' if verbose: print(f"Solving with torch-sla: backend={backend}, method={method}, preconditioner={preconditioner}") return SparseSolveTorchSLA.apply( edata, row, col, shape, b, x0, tol, max_iter, backend, method, preconditioner, is_spd ) def _solve_fallback(edata, row, col, shape, b, backend, tol, max_iter, x0, is_batched, verbose): """Fallback dispatcher when torch-sla is unavailable. Each branch picks a fixed algorithm: - CPU + non-batched + ``backend="petsc"`` → PETSc BiCGSTAB + ILU; - CPU + non-batched + otherwise → SciPy ``spsolve`` (direct); - CPU + batched → SciPy SuperLU; - CUDA + non-batched + CuPy available → CuPy ``spsolve`` (direct); - CUDA + batched + CuPy available → CuPy SuperLU; - CUDA + CuPy missing → pure-PyTorch BiCGSTAB (only path that consults ``tol`` / ``max_iter`` / ``x0``). ``method`` and ``preconditioner`` from :func:`spsolve` do not reach this function; they are torch-sla-only knobs. """ # Import fallback solvers from .scipy_solve import SparseSolveScipy, SparseLUSolveScipy from .torch_solve import SparseSolveTorch device = edata.device if device.type == 'cuda': # CUDA fallback if is_cupy_available: from .cupy_solve import SparseSolveCupy, SparseLUSolveCupy if is_batched: return SparseLUSolveCupy.apply(edata, row, col, shape, b) else: return SparseSolveCupy.apply(edata, row, col, shape, b) else: # Use torch sparse solver return SparseSolveTorch.apply(edata, row, col, shape, b, x0, tol, max_iter) else: # CPU fallback if is_batched: return SparseLUSolveScipy.apply(edata, row, col, shape, b) else: if backend == 'petsc' and is_petsc_available: from .petsc_solve import SparseSolvePETSc return SparseSolvePETSc.apply(edata, row, col, shape, b) else: return SparseSolveScipy.apply(edata, row, col, shape, b)