Coverage for src/nos/physics/laplace.py: 100%

11 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

from typing import Optional

import torch
from continuiti.operators import (
    Operator,
)
from continuiti.pde import (
    Grad,
)

8 

9 

class Laplace(Operator):
    """Operator mapping a function ``u`` to its Laplacian with respect to ``x``.

    The Laplacian is assembled from repeated applications of continuiti's
    ``Grad`` operator. NOTE(review): ``Grad`` presumably differentiates via
    autograd, so ``x`` likely has to require gradients and ``u`` must be
    computed from ``x`` — confirm against the continuiti version in use.
    """

    def forward(
        self,
        x: torch.Tensor,
        u: torch.Tensor,
        y: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the sum of the unmixed second derivatives of ``u`` w.r.t. ``x``.

        Args:
            x: Coordinates. The last axis indexes the spatial dimension
                (``x.size(-1)`` is taken as the number of dimensions); the
                layout is assumed to be ``(batch, num_points, dim)`` — TODO
                confirm.
            u: Function values associated with ``x``.
            y: Accepted for ``Operator`` interface compatibility but not
                used; the Laplacian is evaluated at ``x``.

        Returns:
            Tensor holding ``sum_i d^2 u / dx_i^2`` with a trailing axis of
            size 1 (the reduction uses ``keepdim=True``).
        """
        # One Grad instance is enough; the original rebuilt it per dimension.
        grad = Grad()
        first = grad(x, u)
        # For each dimension i, differentiate du/dx_i again and keep only the
        # d/dx_i component — i.e. the diagonal of the Hessian. This assumes
        # Grad places the d/dx_i component at index i of axis 2.
        diagonal = [grad(x, first[:, :, i])[:, :, i] for i in range(x.size(-1))]
        return torch.stack(diagonal, dim=-1).sum(dim=-1, keepdim=True)