Coverage for src/nos/metrics/error_metrics.py: 71%
21 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1from typing import (
2 Dict,
3)
5import torch
6from continuiti.data import (
7 OperatorDataset,
8)
9from continuiti.operators import (
10 Operator,
11)
13from .metric import (
14 Metric,
15)
class Loss(Metric):
    """Metric that scores an operator on a dataset using a loss function.

    Args:
        name: The name of the metric.
        loss: Callable invoked as ``loss(prediction, target)`` that returns a
            single-element tensor (e.g. a ``torch.nn`` loss module).
    """

    def __init__(self, name: str, loss):
        super().__init__(name)
        self.loss = loss

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Evaluate ``self.loss`` for ``operator`` on the whole ``dataset``.

        Args:
            operator: Operator to evaluate. It is switched to eval mode and
                left in that mode.
            dataset: Dataset supplying ``x``, ``u``, ``y`` and ``v``. This
                assumes the continuiti layout (sensor positions, input values,
                evaluation coordinates, target values); confirm against the
                dataset class.

        Returns:
            Dict with the scalar metric under ``"Value"`` and its unit under
            ``"Unit"``.
        """
        operator.eval()
        # Pure evaluation: skip autograd bookkeeping. Otherwise a graph (with
        # activations) for the entire dataset would be kept in memory.
        with torch.no_grad():
            # The operator is queried at the evaluation coordinates ``y``.
            # ``v`` holds the targets and is only used in the comparison below.
            prediction = operator(dataset.x, dataset.u, dataset.y)
            value = self.loss(prediction, dataset.v).item()
        # NOTE(review): torch.nn losses default to reduction="mean", so this
        # extra division by the sample count rescales an already-averaged
        # value. Kept so reported numbers stay comparable; confirm the intent.
        value /= len(dataset)
        return {
            "Value": value,
            "Unit": "[1]",
        }
class L1Error(Loss):
    """L1 error metric, using ``torch.nn.L1Loss`` (mean absolute error) as the loss."""

    def __init__(self):
        criterion = torch.nn.L1Loss()
        super().__init__(name="L1_error", loss=criterion)
class MSError(Loss):
    """Squared (L2) error metric, using ``torch.nn.MSELoss`` (mean squared error) as the loss."""

    def __init__(self):
        criterion = torch.nn.MSELoss()
        super().__init__(name="MS_error", loss=criterion)