Coverage for src/nos/physics/helmholtz_residual.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import torch 

2import torch.nn as nn 

3 

4from .laplace import ( 

5 Laplace, 

6) 

7 

8 

class HelmholtzDomainResidual(nn.Module):
    """Residual of the Helmholtz equation: ``laplace(y, v) + k**2 * v``.

    The Laplacian term is delegated to a ``Laplace`` module; this class only
    adds the squared-wavenumber term on top of it.
    """

    def __init__(self):
        super().__init__()
        self.laplace = Laplace()

    def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Evaluate the Helmholtz residual.

        Args:
            y: First argument forwarded to ``Laplace`` (presumably the
                evaluation coordinates -- confirm against ``Laplace``).
            v: Second argument forwarded to ``Laplace`` and the tensor that is
                scaled by ``k**2``. Its leading dimension is used as the batch
                size (assumed to be 3-D so that it broadcasts against a
                ``(batch, 1, 1)`` factor -- TODO confirm).
            k: Wavenumber(s). After squeezing it must hold either a single
                value or one value per entry of ``v``'s leading dimension.

        Returns:
            ``laplace(y, v) + k**2 * v``.

        Raises:
            RuntimeError: Raised by ``expand`` when the number of values in
                ``k`` is neither 1 nor ``v.size(0)``.
        """
        batch_size = v.size(0)
        # One squared wavenumber per batch entry, shaped (batch, 1, 1) so it
        # broadcasts over the remaining dimensions of ``v``. ``expand`` also
        # rejects a ``k`` whose element count does not match the batch.
        k_squared = k.squeeze().pow(2).reshape(-1, 1, 1).expand(batch_size, 1, 1)
        laplacian = self.laplace(y, v)
        return laplacian + k_squared * v

20 

21 

class HelmholtzDomainMSE(nn.Module):
    """Mean of the squared Helmholtz residual.

    Wraps ``HelmholtzDomainResidual`` and reduces its output to a single
    scalar value.
    """

    def __init__(self):
        super().__init__()
        self.pde = HelmholtzDomainResidual()

    def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """Return the mean over all elements of the squared residual.

        Args:
            y: Forwarded unchanged to the residual module.
            v: Forwarded unchanged to the residual module.
            k: Forwarded unchanged to the residual module.

        Returns:
            A 0-dimensional tensor holding ``mean(residual ** 2)``.
        """
        return self.pde(y, v, k).pow(2).mean()