# Source code for tensormesh.dataset.equation.poisson

import torch 


class PoissonMultiFrequency:
    r"""
    Multi-frequency Poisson equation on the unit square, with :math:`0` boundary condition

    .. math::
        -\Delta u = f \quad (x, y) \in [0,1]^2

    with the boundary condition :math:`u = 0` on the boundary of :math:`[0,1]^2`.

    The source term and the solution are sine series that share the same
    coefficients :math:`a_{ij}`:

    .. math::
        f = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{-r}
            \sin(\pi i x) \sin(\pi j y), \qquad
        u = \frac{1}{\pi K^2} \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{-r-1}
            \sin(\pi i x) \sin(\pi j y)

    Parameters
    -----------
    a: torch.Tensor, optional
        3D tensor of shape :math:`[N, K, K]` or 2D tensor of shape :math:`[K, K]`,
        where :math:`N` is the number of samples, :math:`K` is the dimension of the frequencies.
        The coefficients of the sine series; if ``None``, they are randomly
        generated by :math:`a \sim Unif([-1,1]^{K\times K})`
    K: int, optional
        the dimension of the frequencies; if ``a`` is not ``None``, this parameter
        is ignored and taken from ``a.shape[-1]``; if ``a`` is ``None``, it is
        used to generate the random ``a``
    r: float, optional
        the decay exponent: mode :math:`(i, j)` is weighted by
        :math:`(i^2+j^2)^{-r}` in the source term and by
        :math:`(i^2+j^2)^{-r-1}` in the solution, default is :math:`-0.5`
    """

    # Byte budget used to size the per-chunk temporary of the batched evaluation.
    _TARGET_CHUNK_BYTES = 512 * 1024 * 1024

    def __init__(self, a=None, K=2, r=-0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K)).uniform_(-1, 1)
        else:
            K = a.shape[-1]
            assert a.shape[-2:] == (K, K), f"the shape of a should be (N, {K}, {K}) or ({K}, {K}), but got {a.shape}"
        self.K = K
        self.a = a
        self.r = r

    def _sine_series(self, points, exponent):
        r"""Evaluate :math:`\sum_{i,j} a_{ij} (i^2+j^2)^{exponent} \sin(\pi i x)\sin(\pi j y)`.

        Shared by :meth:`source_term` and :meth:`solution`, which differ only in
        the exponent and in the scalar prefactor they apply afterwards.

        Returns a tensor of shape ``[n_points]`` when ``self.a`` is 2D, otherwise
        ``[N, n_points]``.
        """
        K = self.K
        device = points.device
        dtype = points.dtype

        k_idx = torch.arange(1, K + 1, device=device, dtype=dtype)
        i, j = torch.meshgrid(k_idx, k_idx, indexing="ij")  # [K, K]
        weights = (i * i + j * j) ** exponent  # [K, K]

        x = points[:, 0]  # [n_points]
        y = points[:, 1]  # [n_points]
        sinx = torch.sin(torch.pi * x[:, None] * k_idx[None, :])  # [n_points, K]
        siny = torch.sin(torch.pi * y[:, None] * k_idx[None, :])  # [n_points, K]

        # B broadcasts to [K, K] for a single sample or [N, K, K] for a batch.
        B = self.a.to(device=device, dtype=dtype) * weights

        # Per point p the double sum is sinx_p^T @ B @ siny_p; computing
        # (sinx @ B) first avoids materializing a [n_points, K, K] intermediate.
        if B.dim() == 2:
            return ((sinx @ B) * siny).sum(dim=-1)  # [n_points]

        n_batch = B.shape[0]
        n_points = sinx.shape[0]
        if n_batch == 0:
            # torch.cat rejects an empty list, so return the empty result directly.
            return B.new_zeros((0, n_points))

        # Chunk over the batch so the [chunk, n_points, K] product stays
        # around the byte budget.
        bytes_per = torch.finfo(dtype).bits // 8
        row_bytes = int(n_points * K * bytes_per)
        chunk = max(1, self._TARGET_CHUNK_BYTES // max(1, row_bytes))

        pieces = []
        for start in range(0, n_batch, chunk):
            partial = torch.matmul(sinx, B[start : start + chunk])  # [c, n_points, K]
            pieces.append((partial * siny[None, :, :]).sum(dim=-1))  # [c, n_points]
        return torch.cat(pieces, dim=0)  # [N, n_points]

    def source_term(self, points, domain="rectangle"):
        r"""Generate the poisson source function at each point in the domain

        .. math::
            f=\frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r} \sin(\pi ix) \sin(\pi jy)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|` is the number of vertices
            all the points must be in :math:`[0,1]^2`
        domain: str, optional
            Domain shape. Analytical solution only supports ``"rectangle"`` (default).
            The :math:`[0,1]^2` bounds check is only performed when this is ``"rectangle"``.

        Returns
        -------
        f: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or 2D Tensor :math:`[N, |\mathcal V|]`,
            where :math:`N` is the number of samples, :math:`|\mathcal V|` is the number of vertices
        """
        assert points.shape[-1] == 2, f"the shape of points must be [n_points, 2], but got {points.shape}"
        if domain == "rectangle":
            assert ((points<=1) & (points>=0)).all(), f"the points must be in [0,1]^2, but got {points}"

        K = self.K
        return (torch.pi / (K * K)) * self._sine_series(points, -self.r)

    def solution(self, points):
        r"""Generate the poisson solution function at each point in the domain

        .. math::
            u(x, y) = \frac{1}{\pi\cdot K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r-1} \sin(\pi ix) \sin(\pi jy)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|` is the number of vertices
            all the points must be in :math:`[0,1]^2`

        Returns
        -------
        u: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or :math:`[N, |\mathcal V|]`,
            where :math:`N` is the number of samples, :math:`|\mathcal V|` is the number of vertices
        """
        K = self.K
        # One extra power of 1/(i^2+j^2) relative to the source term inverts
        # the Laplacian mode by mode (the pi^2 factor is absorbed in the prefactor).
        return (1.0 / (torch.pi * (K * K))) * self._sine_series(points, -self.r - 1)
class PoissonMultiFrequency3D:
    r"""
    Multi-frequency Poisson equation in 3D, with :math:`0` boundary condition

    .. math::
        -\Delta u = f \quad (x, y, z) \in [0,1]^3

    with the boundary condition :math:`u = 0` on the boundary of :math:`[0,1]^3`

    Parameters
    -----------
    a: torch.Tensor, optional
        4D tensor of shape :math:`[N, K, K, K]` or 3D tensor of shape :math:`[K, K, K]`,
        where :math:`N` is the number of samples, :math:`K` is the dimension of the frequencies.
        If ``None``, it will be randomly generated from :math:`\text{Unif}([-1,1]^{K \times K \times K})`
    K: int, optional
        the dimension of the frequencies; if ``a`` is not ``None``, this parameter
        is ignored and taken from ``a.shape[-1]``
    r: float, optional
        the decay exponent: mode :math:`(i, j, k)` is weighted by
        :math:`(i^2+j^2+k^2)^{-r}` in the source term and by
        :math:`(i^2+j^2+k^2)^{-r-1}` in the solution, default is :math:`-0.5`
    """

    # Byte budget used to pick the batch chunk size in the batched evaluation.
    _TARGET_CHUNK_BYTES = 512 * 1024 * 1024

    def __init__(self, a=None, K=2, r=-0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K, K)).uniform_(-1, 1)
        else:
            K = a.shape[-1]
            assert a.shape[-3:] == (K, K, K), f"the shape of a should be (N, {K}, {K}, {K}) or ({K}, {K}, {K}), but got {a.shape}"
        self.K = K
        self.a = a
        self.r = r

    def _sine_series(self, points, exponent):
        r"""Evaluate :math:`\sum_{i,j,k} a_{ijk} (i^2+j^2+k^2)^{exponent}
        \sin(\pi i x)\sin(\pi j y)\sin(\pi k z)`.

        Shared by :meth:`source_term` and :meth:`solution`, which differ only in
        the exponent and in the scalar prefactor they apply afterwards.

        Returns a tensor of shape ``[n_points]`` when ``self.a`` is 3D, otherwise
        ``[N, n_points]``.
        """
        K = self.K
        device = points.device
        dtype = points.dtype

        k_idx = torch.arange(1, K + 1, device=device, dtype=dtype)
        i, j, k = torch.meshgrid(k_idx, k_idx, k_idx, indexing="ij")  # [K, K, K]
        weights = (i * i + j * j + k * k) ** exponent  # [K, K, K]

        x = points[:, 0]  # [n_points]
        y = points[:, 1]  # [n_points]
        z = points[:, 2]  # [n_points]
        sinx = torch.sin(torch.pi * x[:, None] * k_idx[None, :])  # [n_points, K]
        siny = torch.sin(torch.pi * y[:, None] * k_idx[None, :])  # [n_points, K]
        sinz = torch.sin(torch.pi * z[:, None] * k_idx[None, :])  # [n_points, K]

        # B broadcasts to [K, K, K] for a single sample or [N, K, K, K] for a batch.
        B = self.a.to(device=device, dtype=dtype) * weights

        if B.dim() == 3:
            # f_p = sum_{ijk} B_ijk * sinx_{p,i} * siny_{p,j} * sinz_{p,k}
            return torch.einsum('pi,pj,pk,ijk->p', sinx, siny, sinz, B)

        n_batch = B.shape[0]
        n_pts = sinx.shape[0]
        if n_batch == 0:
            # torch.cat rejects an empty list, so return the empty result directly.
            return B.new_zeros((0, n_pts))

        # Chunk over the batch dimension; the chunk size is derived from an
        # n_pts * K^3 element estimate against the byte budget.
        bytes_per = torch.finfo(dtype).bits // 8
        chunk = max(1, self._TARGET_CHUNK_BYTES // max(1, n_pts * K * K * K * bytes_per))

        pieces = [
            torch.einsum('pi,pj,pk,nijk->np', sinx, siny, sinz, B[start : start + chunk])  # [c, n_pts]
            for start in range(0, n_batch, chunk)
        ]
        return torch.cat(pieces, dim=0)  # [N, n_points]

    def source_term(self, points, domain="cube"):
        r"""Generate the Poisson source function at each point in the 3D domain

        .. math::
            f = \frac{\pi}{K^3} \sum_{i,j,k=1}^{K} a_{ijk} \cdot (i^2 + j^2 + k^2)^{-r} \sin(\pi ix) \sin(\pi jy) \sin(\pi kz)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 3]`, where :math:`|\mathcal V|` is the number of vertices
            all the points must be in :math:`[0,1]^3`
        domain: str, optional
            Domain shape. Default: ``"cube"``. The :math:`[0,1]^3` bounds check is
            only performed when this is ``"cube"``.

        Returns
        -------
        f: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or 2D Tensor :math:`[N, |\mathcal V|]`
        """
        assert points.shape[-1] == 3, f"the shape of points must be [n_points, 3], but got {points.shape}"
        if domain == "cube":
            assert ((points <= 1) & (points >= 0)).all(), f"the points must be in [0,1]^3, but got min={points.min()}, max={points.max()}"

        K = self.K
        return (torch.pi / (K ** 3)) * self._sine_series(points, -self.r)

    def solution(self, points):
        r"""Generate the Poisson solution function at each point in the 3D domain

        .. math::
            u(x, y, z) = \frac{1}{\pi \cdot K^3} \sum_{i,j,k=1}^{K} a_{ijk} \cdot (i^2 + j^2 + k^2)^{-r-1} \sin(\pi ix) \sin(\pi jy) \sin(\pi kz)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 3]`
            all the points must be in :math:`[0,1]^3`

        Returns
        -------
        u: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or :math:`[N, |\mathcal V|]`
        """
        K = self.K
        # One extra power of 1/(i^2+j^2+k^2) relative to the source term inverts
        # the Laplacian mode by mode (the pi^2 factor is absorbed in the prefactor).
        return (1.0 / (torch.pi * (K ** 3))) * self._sine_series(points, -self.r - 1)