Coverage for src/nos/metrics/error_metrics.py: 71%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1from typing import ( 

2 Dict, 

3) 

4 

5import torch 

6from continuiti.data import ( 

7 OperatorDataset, 

8) 

9from continuiti.operators import ( 

10 Operator, 

11) 

12 

13from .metric import ( 

14 Metric, 

15) 

16 

17 

class Loss(Metric):
    """Error metric obtained by applying a loss function to operator predictions.

    The operator is evaluated on the whole dataset in a single call, and the
    loss is taken between its prediction and the dataset's target values.

    Args:
        name: The name of the metric.
        loss: Callable ``loss(prediction, target)`` returning a scalar tensor,
            e.g. ``torch.nn.L1Loss()``.
    """

    def __init__(self, name: str, loss):
        super().__init__(name)
        self.loss = loss

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Evaluate ``operator`` on ``dataset`` and return the metric value.

        Args:
            operator: Operator under evaluation, called as ``operator(x, u, y)``.
            dataset: Dataset providing ``x`` and ``u`` (operator inputs), ``y``
                (evaluation coordinates) and ``v`` (target values at ``y``).

        Returns:
            Dict with the scalar metric under ``"Value"`` and its unit under
            ``"Unit"``.
        """
        # Remember the caller's mode so evaluating a metric mid-training does not
        # leave dropout / batch-norm layers stuck in eval mode afterwards.
        was_training = operator.training
        operator.eval()
        try:
            # Pure inference: skip building an autograd graph over the dataset.
            with torch.no_grad():
                # The operator maps the inputs (x, u) to predictions at the
                # evaluation coordinates y; v holds the targets to compare against.
                prediction = operator(dataset.x, dataset.u, dataset.y)
                value = self.loss(prediction, dataset.v).item()
        finally:
            operator.train(was_training)

        # NOTE(review): with a mean-reduced loss such as torch.nn.L1Loss() the
        # value above is already an average, so this extra division makes the
        # result scale with 1 / len(dataset). Kept so reported numbers stay
        # comparable with earlier runs -- confirm whether this is intended.
        value /= len(dataset)
        return {
            "Value": value,
            "Unit": "[1]",
        }

40 

41 

class L1Error(Loss):
    """Error metric backed by ``torch.nn.L1Loss`` (absolute error, default
    ``'mean'`` reduction), reported under the name ``"L1_error"``."""

    def __init__(self):
        l1_loss = torch.nn.L1Loss()
        super().__init__(name="L1_error", loss=l1_loss)

47 

48 

class MSError(Loss):
    """Error metric backed by ``torch.nn.MSELoss`` (squared error, default
    ``'mean'`` reduction), reported under the name ``"MS_error"``."""

    def __init__(self):
        squared_loss = torch.nn.MSELoss()
        super().__init__(name="MS_error", loss=squared_loss)