Coverage for src/nos/physics/weight_scheduler_lin.py: 100%
22 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1from typing import (
2 Union,
3)
5import torch
6import torch.nn as nn
class WeightSchedulerLinear(nn.Module):
    """Schedule the relative weights of the data loss and the PDE loss.

    The data-loss weight is always 1. The PDE-loss weight is 0 before
    ``pde_loss_start``. It ramps linearly up to ``pde_weight_scale`` at
    ``pde_loss_full`` and stays there afterwards.

    Args:
        pde_loss_start: Epoch at which the PDE weight starts to increase.
        pde_loss_full: Epoch at which the PDE weight reaches its maximum.
            Must be ``>= pde_loss_start``. If both are equal, the PDE weight
            switches on as a step at that epoch.
        pde_weight_scale: Maximum PDE weight. Defaults to ``1e-3``, the value
            that was previously hard-coded.

    Raises:
        ValueError: If ``pde_loss_full < pde_loss_start``.
    """

    def __init__(
        self,
        pde_loss_start: int,
        pde_loss_full: int,
        pde_weight_scale: float = 1e-3,
    ):
        super().__init__()
        if pde_loss_full < pde_loss_start:
            # A negative ramp length would silently invert the schedule.
            raise ValueError(
                f"pde_loss_full ({pde_loss_full}) must be >= "
                f"pde_loss_start ({pde_loss_start})"
            )
        self.pde_loss_start = pde_loss_start
        self.pde_loss_full = pde_loss_full
        self.pde_delta = pde_loss_full - pde_loss_start
        self.pde_weight_scale = pde_weight_scale

    @staticmethod
    def _weight_dtype(epoch: torch.Tensor) -> torch.dtype:
        """Return the floating dtype the weights should use for ``epoch``."""
        # Integer epochs divide to the default float dtype; floating epochs
        # keep their own dtype. Both weight tensors must agree for stacking.
        if epoch.is_floating_point():
            return epoch.dtype
        return torch.get_default_dtype()

    def _get_data_weight(self, epoch: torch.Tensor) -> torch.Tensor:
        """Return all-ones data weights with ``epoch``'s shape and device."""
        return torch.ones(
            epoch.shape,
            dtype=self._weight_dtype(epoch),
            device=epoch.device,
        )

    def _get_pde_weight(self, epoch: torch.Tensor) -> torch.Tensor:
        """Return the linearly ramped PDE weight for each entry of ``epoch``."""
        if self.pde_delta == 0:
            # A zero-length ramp would divide by zero (nan at the boundary);
            # treat it as a step that is fully on from pde_loss_start.
            ramp = (epoch >= self.pde_loss_start).to(self._weight_dtype(epoch))
        else:
            ramp = ((epoch - self.pde_loss_start) / self.pde_delta).clamp(0.0, 1.0)
        return ramp * self.pde_weight_scale

    def forward(self, epoch: Union[float, torch.Tensor]) -> torch.Tensor:
        """Compute ``[data_weight, pde_weight]`` for the given epoch(s).

        Args:
            epoch: A Python number, a 0-d tensor, or a 1-D tensor of epochs.

        Returns:
            A tensor of shape ``(N, 2)`` for 1-D input of length ``N``, or
            ``(1, 2)`` for scalar input. Column 0 is the data weight and
            column 1 is the PDE weight.
        """
        if not isinstance(epoch, torch.Tensor):
            epoch = torch.as_tensor(epoch, dtype=torch.get_default_dtype())
        if epoch.dim() == 0:
            # stack(..., dim=1) needs at least 1-D inputs.
            epoch = epoch.reshape(1)

        data_weights = self._get_data_weight(epoch)
        pde_weights = self._get_pde_weight(epoch)
        return torch.stack([data_weights, pde_weights], dim=1)