Coverage for src/nos/physics/helmholtz_residual.py: 100%
21 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import torch
2import torch.nn as nn
4from .laplace import (
5 Laplace,
6)
class HelmholtzDomainResidual(nn.Module):
    """Pointwise Helmholtz residual ``laplace(y, v) + k**2 * v``.

    The Laplacian term is delegated to ``Laplace`` (imported from
    ``.laplace``). The call order ``self.laplace(y, v)`` suggests ``y`` holds
    the input coordinates and ``v`` the field being differentiated.
    NOTE(review): confirm against ``Laplace``.
    """

    def __init__(self):
        super().__init__()
        self.laplace = Laplace()

    def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return ``laplace(y, v) + k**2 * v``.

        Args:
            y: First argument forwarded to ``self.laplace`` (presumably the
                coordinates ``v`` is evaluated at -- verify).
            v: Field values; dimension 0 is treated as the batch dimension.
            k: Wavenumber(s). After ``squeeze()`` this must hold either a
                single value or one value per batch entry of ``v``.

        Returns:
            The residual tensor, the broadcast sum of ``laplace(y, v)`` and
            ``k**2 * v``.
        """
        # Shape k**2 as (batch, 1, ..., 1) with the same rank as ``v`` so each
        # sample's wavenumber scales only its own entries. A fixed (batch, 1, 1)
        # shape would broadcast a rank-2 ``v`` to (batch, batch, N).
        trailing = (1,) * (v.dim() - 1)
        ks = (k.squeeze() ** 2).reshape(-1, *trailing).expand(v.size(0), *trailing)
        lpl = self.laplace(y, v)
        return lpl + ks * v
class HelmholtzDomainMSE(nn.Module):
    """Mean-squared Helmholtz residual, reduced to a single scalar.

    Wraps ``HelmholtzDomainResidual`` and averages the squared residual
    uniformly over every element of its output.
    """

    def __init__(self):
        super().__init__()
        self.pde = HelmholtzDomainResidual()

    def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return ``mean(residual(y, v, k) ** 2)`` as a 0-dim tensor.

        ``y``, ``v`` and ``k`` are passed unchanged to the wrapped residual.
        """
        squared_error = self.pde(y, v, k).pow(2)
        return squared_error.mean()