Coverage for src/nos/physics/weight_scheduler_lin.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1from typing import ( 

2 Union, 

3) 

4 

5import torch 

6import torch.nn as nn 

7 

8 

class WeightSchedulerLinear(nn.Module):
    """Linear warm-up schedule for the PDE-loss weight.

    For every epoch value the module returns a pair ``(data_weight, pde_weight)``.
    The data weight is always 1. The PDE weight is 0 up to ``pde_loss_start``,
    ramps linearly to ``pde_weight_max`` at ``pde_loss_full``, and stays at
    ``pde_weight_max`` afterwards.

    Args:
        pde_loss_start: Epoch at which the PDE weight starts to grow from 0.
        pde_loss_full: Epoch at which the PDE weight reaches ``pde_weight_max``.
        pde_weight_max: Final PDE weight. Defaults to ``1e-3`` (the value that
            was previously hard-coded).
    """

    def __init__(
        self,
        pde_loss_start: int,
        pde_loss_full: int,
        pde_weight_max: float = 1e-3,
    ):
        super().__init__()
        self.pde_loss_start = pde_loss_start
        self.pde_loss_full = pde_loss_full
        self.pde_delta = pde_loss_full - pde_loss_start
        self.pde_weight_max = pde_weight_max

    @staticmethod
    def _weight_dtype(epoch: torch.Tensor) -> torch.dtype:
        """Return the floating dtype the weights for ``epoch`` should use."""
        # True division promotes integer/bool tensors to the default float dtype,
        # so integer epochs get default-float weights; float epochs keep theirs.
        if epoch.is_floating_point():
            return epoch.dtype
        return torch.get_default_dtype()

    def _get_data_weight(self, epoch: torch.Tensor) -> torch.Tensor:
        """Return ones with ``epoch``'s shape, on ``epoch``'s device."""
        # Matching device and dtype lets torch.stack in forward() combine these
        # with the PDE weights (a CPU float32 tensor broke CUDA/float64 epochs).
        return torch.ones(epoch.shape, dtype=self._weight_dtype(epoch), device=epoch.device)

    def _get_pde_weight(self, epoch: torch.Tensor) -> torch.Tensor:
        """Return the linearly ramped PDE weight for each entry of ``epoch``."""
        if self.pde_delta == 0:
            # Degenerate schedule (start == full): switch the PDE loss on in a
            # single step instead of dividing by zero, which produced NaN at
            # epoch == pde_loss_start.
            ramp = (epoch >= self.pde_loss_full).to(self._weight_dtype(epoch))
        else:
            # Fraction of the warm-up interval covered, limited to [0, 1].
            ramp = ((epoch - self.pde_loss_start) / self.pde_delta).clamp(0.0, 1.0)
        return ramp * self.pde_weight_max

    def forward(self, epoch: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute the loss weights for the given epoch(s).

        Args:
            epoch: A Python number or a tensor of epoch values.

        Returns:
            The data weights and the PDE weights stacked along ``dim=1``: column
            0 holds the data weight and column 1 the PDE weight. A scalar, a 0-d
            tensor, or a 1-D tensor of length N gives shape ``(1, 2)``,
            ``(1, 2)`` or ``(N, 2)`` respectively.
        """
        if not isinstance(epoch, torch.Tensor):
            epoch = epoch * torch.ones(1)
        elif epoch.dim() == 0:
            # torch.stack(dim=1) is out of range for 0-d tensors; treat them
            # like Python scalars.
            epoch = epoch.reshape(1)

        data_weights = self._get_data_weight(epoch)
        pde_weights = self._get_pde_weight(epoch)

        return torch.stack([data_weights, pde_weights], dim=1)