Coverage for src/nos/metrics/operator_metrics.py: 100%
22 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import time
2from typing import (
3 Dict,
4)
6from continuiti.data import (
7 OperatorDataset,
8)
9from continuiti.operators import (
10 Operator,
11)
13from .metric import (
14 Metric,
15)
class NumberOfParameters(Metric):
    """Number of trainable parameters in the operator.

    Only parameter tensors with ``requires_grad`` set are counted; each
    contributes its total number of elements (``numel()``).
    """

    def __init__(self):
        super().__init__("Number_of_parameters")

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Count the operator's trainable parameter entries.

        Args:
            operator: Operator whose ``parameters()`` are inspected.
            dataset: Unused; accepted to match the metric call signature.

        Returns:
            ``{"Value": <count>, "Unit": "[1]"}``.
        """
        trainable = [param for param in operator.parameters() if param.requires_grad]
        total = 0
        for param in trainable:
            total += param.numel()
        return {"Value": total, "Unit": "[1]"}
class SpeedOfEvaluation(Metric):
    """Average evaluation time per sample, in milliseconds.

    The whole dataset (``dataset.x``, ``dataset.u``, ``dataset.v``) is passed
    to the operator in one batched forward call. The elapsed time of that call
    is divided by ``len(dataset)``.
    """

    def __init__(self):
        super().__init__("Speed_of_evaluation")

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Time one batched forward pass of ``operator`` over ``dataset``.

        Args:
            operator: Operator to evaluate. It is put into eval mode for the
                measurement, and its previous training mode is restored afterwards.
            dataset: Dataset whose ``x``, ``u`` and ``v`` are fed to the operator.

        Returns:
            ``{"Value": <milliseconds per sample>, "Unit": "[ms]"}``.

        Raises:
            ZeroDivisionError: If ``dataset`` is empty.
        """
        import torch  # local import: the module does not import torch at top level

        # Remember the caller's mode so this metric has no lasting side effect.
        was_training = operator.training
        operator.eval()
        try:
            # Disable autograd so the timing excludes graph construction.
            with torch.no_grad():
                # perf_counter_ns is monotonic and high resolution; time_ns is
                # wall-clock time and can jump (e.g. NTP adjustments).
                # NOTE(review): on CUDA, kernels run asynchronously; a
                # torch.cuda.synchronize() around the call may be needed for
                # accurate GPU timings.
                start_time = time.perf_counter_ns()
                _ = operator(dataset.x, dataset.u, dataset.v)
                end_time = time.perf_counter_ns()
        finally:
            operator.train(was_training)
        delta_time = (end_time - start_time) * 1e-6  # ns -> ms
        delta_time = delta_time / len(dataset)
        return {"Value": delta_time, "Unit": "[ms]"}