import torch
class PoissonMultiFrequency:
    r"""
    Multi-frequency Poisson equation on the unit square with a homogeneous
    (zero) Dirichlet boundary condition

    .. math::
        -\Delta u = f \quad \text{for } (x, y) \in [0,1]^2,

    with the boundary condition :math:`u = 0` on :math:`\partial [0,1]^2`.
    Both :math:`f` and :math:`u` are finite sums of the sine modes
    :math:`\sin(\pi i x)\sin(\pi j y)`, :math:`1 \le i, j \le K`, which vanish
    on the boundary, so the exact solution is available in closed form
    (see :meth:`source_term` and :meth:`solution`).

    Parameters
    -----------
    a: torch.Tensor, optional
        3D tensor of shape :math:`[N, K, K]` or 2D tensor of shape :math:`[K, K]`,
        where :math:`N` is the number of samples and :math:`K` is the number of
        frequencies per direction. These are the mode coefficients :math:`a_{ij}`.
        If ``None``, they are drawn as :math:`a \sim \mathrm{Unif}([-1,1]^{K\times K})`.
    K: int, optional
        The number of frequencies per direction, default is :math:`2`.
        If ``a`` is not ``None`` this parameter is ignored and ``K`` is read
        from ``a.shape[-1]``; otherwise it is used to generate the random ``a``.
    r: float, optional
        Exponent parameter of the mode weights: the source term weights mode
        :math:`(i, j)` by :math:`(i^2 + j^2)^{-r}`. Default is :math:`-0.5`.
    """

    # Approximate upper bound (bytes) on the temporary allocated per chunk when
    # evaluating a batched ``a``; see ``_mode_sum``.
    _CHUNK_TARGET_BYTES = 512 * 1024 * 1024

    def __init__(self, a=None, K=2, r=-0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K)).uniform_(-1, 1)
        else:
            # K is inferred from the coefficients; the K argument is ignored.
            K = a.shape[-1]
            assert a.shape[-2:] == (K, K), f"the shape of a should be (N, {K}, {K}) or ({K}, {K}), but got {a.shape}"
        self.K = K
        self.a = a
        self.r = r

    def _mode_sum(self, points, exponent):
        r"""Evaluate the weighted double sine series shared by
        :meth:`source_term` and :meth:`solution`

        .. math::
            S(x, y) = \sum_{i,j=1}^{K} a_{ij} (i^2 + j^2)^{p} \sin(\pi i x) \sin(\pi j y)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`; its device and dtype
            are used for the computation (``a`` is cast to match).
        exponent: float
            The power :math:`p` applied to :math:`i^2 + j^2`.

        Returns
        -------
        torch.Tensor
            Shape :math:`[|\mathcal V|]` if ``a`` is 2D, otherwise
            ``a.shape[:-2] + (|\mathcal V|,)``.
        """
        K = self.K
        device, dtype = points.device, points.dtype
        k_idx = torch.arange(1, K + 1, device=device, dtype=dtype)
        i, j = torch.meshgrid(k_idx, k_idx, indexing="ij")  # [K, K]
        B = self.a.to(device=device, dtype=dtype) * (i * i + j * j) ** exponent  # [..., K, K]
        sinx = torch.sin(torch.pi * points[:, 0].unsqueeze(-1) * k_idx)  # [n_points, K]
        siny = torch.sin(torch.pi * points[:, 1].unsqueeze(-1) * k_idx)  # [n_points, K]

        # For each point p the double sum is the bilinear form sinx[p] @ B @ siny[p];
        # computing sinx @ B first avoids any [..., n_points, K, K] intermediate.
        if B.dim() == 2:
            return ((sinx @ B) * siny).sum(dim=-1)  # [n_points]

        # Batched coefficients: the [c, ..., n_points, K] temporary grows with
        # the batch, so walk the leading dimension in bounded chunks.
        n_points = sinx.shape[0]
        row_bytes = max(1, n_points * K * sinx.element_size())
        chunk = max(1, self._CHUNK_TARGET_BYTES // row_bytes)
        parts = [
            ((sinx @ B[s : s + chunk]) * siny).sum(dim=-1)  # [c, ..., n_points]
            for s in range(0, B.shape[0], chunk)
        ]
        if not parts:
            # Empty batch (N == 0): torch.cat rejects an empty list.
            return sinx.new_zeros((*B.shape[:-2], n_points))
        return torch.cat(parts, dim=0)

    def source_term(self, points, domain="rectangle"):
        r"""Generate the Poisson source function at each point in the domain

        .. math::
            f(x, y) = \frac{\pi}{K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r} \sin(\pi i x) \sin(\pi j y)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all the points must be in :math:`[0,1]^2`
        domain: str, optional
            Domain shape. ``"rectangle"`` (default) enables the :math:`[0,1]^2`
            range check; any other value skips that check but evaluates the
            same formula.

        Returns
        -------
        f: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or 2D tensor of shape
            :math:`[N, |\mathcal V|]`, where :math:`N` is the number of samples

        Raises
        ------
        AssertionError
            If ``points`` does not have 2 columns, or (for ``domain="rectangle"``)
            if any point lies outside :math:`[0,1]^2`.
        """
        # NOTE(review): the mode-weight exponent is -r (not r); with the default
        # r = -0.5 the source weights grow like sqrt(i^2 + j^2). Confirm this
        # sign convention is the intended one.
        assert points.shape[-1] == 2, f"the shape of points must be [n_points, 2], but got {points.shape}"
        if domain == "rectangle":
            assert ((points <= 1) & (points >= 0)).all(), f"the points must be in [0,1]^2, but got {points}"
        return (torch.pi / (self.K * self.K)) * self._mode_sum(points, -self.r)

    def solution(self, points):
        r"""Generate the Poisson solution function at each point in the domain

        .. math::
            u(x, y) = \frac{1}{\pi K^2} \sum_{i,j=1}^{K} a_{ij} \cdot (i^2 + j^2)^{-r-1} \sin(\pi i x) \sin(\pi j y)

        Since :math:`-\Delta[\sin(\pi i x)\sin(\pi j y)] = \pi^2 (i^2 + j^2)\sin(\pi i x)\sin(\pi j y)`,
        this :math:`u` satisfies :math:`-\Delta u = f` for the :math:`f` of
        :meth:`source_term`, and it vanishes on the boundary of :math:`[0,1]^2`.

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 2]`, where :math:`|\mathcal V|`
            is the number of vertices; all the points must be in :math:`[0,1]^2`
            (not checked here).

        Returns
        -------
        u: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or :math:`[N, |\mathcal V|]`,
            where :math:`N` is the number of samples and :math:`|\mathcal V|` is
            the number of vertices
        """
        return (1.0 / (torch.pi * (self.K * self.K))) * self._mode_sum(points, -self.r - 1)
class PoissonMultiFrequency3D:
    r"""
    Multi-frequency Poisson equation in 3D, with :math:`0` boundary condition

    .. math::
        -\Delta u = f \quad (x, y, z) \in [0,1]^3

    with the boundary condition :math:`u = 0` on the boundary of :math:`[0,1]^3`.
    Both :math:`f` and :math:`u` are finite sums of the sine modes
    :math:`\sin(\pi i x)\sin(\pi j y)\sin(\pi k z)`, :math:`1 \le i, j, k \le K`,
    so the exact solution is available in closed form (see :meth:`source_term`
    and :meth:`solution`).

    Parameters
    -----------
    a: torch.Tensor, optional
        4D tensor of shape :math:`[N, K, K, K]` or 3D tensor of shape :math:`[K, K, K]`,
        where :math:`N` is the number of samples and :math:`K` is the number of
        frequencies per direction. These are the mode coefficients :math:`a_{ijk}`.
        If ``None``, they are drawn from :math:`\mathrm{Unif}([-1,1]^{K \times K \times K})`.
    K: int, optional
        The number of frequencies per direction, default is :math:`2`. If ``a``
        is not ``None`` this parameter is ignored and ``K`` is read from
        ``a.shape[-1]``.
    r: float, optional
        Exponent parameter of the mode weights: the source term weights mode
        :math:`(i, j, k)` by :math:`(i^2 + j^2 + k^2)^{-r}`. Default is :math:`-0.5`.
    """

    # Approximate upper bound (bytes) used to size chunks over the batch
    # dimension when evaluating a batched ``a``; see ``_mode_sum``.
    _CHUNK_TARGET_BYTES = 512 * 1024 * 1024

    def __init__(self, a=None, K=2, r=-0.5):
        if a is None:
            assert K is not None, "K should be specified if a is None"
            a = torch.zeros((K, K, K)).uniform_(-1, 1)
        else:
            # K is inferred from the coefficients; the K argument is ignored.
            K = a.shape[-1]
            assert a.shape[-3:] == (K, K, K), f"the shape of a should be (N, {K}, {K}, {K}) or ({K}, {K}, {K}), but got {a.shape}"
        self.K = K
        self.a = a
        self.r = r

    def _mode_sum(self, points, exponent):
        r"""Evaluate the weighted triple sine series shared by
        :meth:`source_term` and :meth:`solution`

        .. math::
            S(x, y, z) = \sum_{i,j,k=1}^{K} a_{ijk} (i^2 + j^2 + k^2)^{p}
                         \sin(\pi i x) \sin(\pi j y) \sin(\pi k z)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 3]`; its device and dtype
            are used for the computation (``a`` is cast to match).
        exponent: float
            The power :math:`p` applied to :math:`i^2 + j^2 + k^2`.

        Returns
        -------
        torch.Tensor
            Shape :math:`[|\mathcal V|]` if ``a`` is 3D, otherwise
            :math:`[N, |\mathcal V|]`.
        """
        K = self.K
        device, dtype = points.device, points.dtype
        k_idx = torch.arange(1, K + 1, device=device, dtype=dtype)
        i, j, k = torch.meshgrid(k_idx, k_idx, k_idx, indexing="ij")  # [K, K, K]
        B = self.a.to(device=device, dtype=dtype) * (i * i + j * j + k * k) ** exponent  # [..., K, K, K]
        sinx = torch.sin(torch.pi * points[:, 0].unsqueeze(-1) * k_idx)  # [n_points, K]
        siny = torch.sin(torch.pi * points[:, 1].unsqueeze(-1) * k_idx)  # [n_points, K]
        sinz = torch.sin(torch.pi * points[:, 2].unsqueeze(-1) * k_idx)  # [n_points, K]

        # Per point p: sum_{ijk} B_ijk * sinx[p, i] * siny[p, j] * sinz[p, k].
        if B.dim() == 3:
            return torch.einsum("pi,pj,pk,ijk->p", sinx, siny, sinz, B)

        # Batched coefficients (expected [N, K, K, K]): chunk over N to keep
        # the einsum temporaries bounded.
        n_points = sinx.shape[0]
        sample_bytes = max(1, n_points * K ** 3 * sinx.element_size())
        chunk = max(1, self._CHUNK_TARGET_BYTES // sample_bytes)
        parts = [
            torch.einsum("pi,pj,pk,nijk->np", sinx, siny, sinz, B[s : s + chunk])  # [c, n_points]
            for s in range(0, B.shape[0], chunk)
        ]
        if not parts:
            # Empty batch (N == 0): torch.cat rejects an empty list.
            return sinx.new_zeros((B.shape[0], n_points))
        return torch.cat(parts, dim=0)

    def source_term(self, points, domain="cube"):
        r"""Generate the Poisson source function at each point in the 3D domain

        .. math::
            f = \frac{\pi}{K^3} \sum_{i,j,k=1}^{K} a_{ijk} \cdot (i^2 + j^2 + k^2)^{-r} \sin(\pi ix) \sin(\pi jy) \sin(\pi kz)

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 3]`, where :math:`|\mathcal V|`
            is the number of vertices; all the points must be in :math:`[0,1]^3`
        domain: str, optional
            Domain shape. ``"cube"`` (default) enables the :math:`[0,1]^3`
            range check; any other value skips that check but evaluates the
            same formula.

        Returns
        -------
        f: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or 2D tensor of shape :math:`[N, |\mathcal V|]`

        Raises
        ------
        AssertionError
            If ``points`` does not have 3 columns, or (for ``domain="cube"``)
            if any point lies outside :math:`[0,1]^3`.
        """
        assert points.shape[-1] == 3, f"the shape of points must be [n_points, 3], but got {points.shape}"
        if domain == "cube":
            assert ((points <= 1) & (points >= 0)).all(), f"the points must be in [0,1]^3, but got min={points.min()}, max={points.max()}"
        return (torch.pi / (self.K ** 3)) * self._mode_sum(points, -self.r)

    def solution(self, points):
        r"""Generate the Poisson solution function at each point in the 3D domain

        .. math::
            u(x, y, z) = \frac{1}{\pi \cdot K^3} \sum_{i,j,k=1}^{K} a_{ijk} \cdot (i^2 + j^2 + k^2)^{-r-1} \sin(\pi ix) \sin(\pi jy) \sin(\pi kz)

        Each mode satisfies :math:`-\Delta[\sin(\pi ix)\sin(\pi jy)\sin(\pi kz)] = \pi^2 (i^2+j^2+k^2)\sin(\pi ix)\sin(\pi jy)\sin(\pi kz)`,
        so this :math:`u` satisfies :math:`-\Delta u = f` for the :math:`f` of
        :meth:`source_term`.

        Parameters
        ----------
        points: torch.Tensor
            2D tensor of shape :math:`[|\mathcal V|, 3]`;
            all the points must be in :math:`[0,1]^3` (not checked here)

        Returns
        -------
        u: torch.Tensor
            1D tensor of shape :math:`[|\mathcal V|]` or :math:`[N, |\mathcal V|]`
        """
        return (1.0 / (torch.pi * (self.K ** 3))) * self._mode_sum(points, -self.r - 1)