Coverage for src/nos/metrics/operator_metrics.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import time 

2from typing import ( 

3 Dict, 

4) 

5 

6from continuiti.data import ( 

7 OperatorDataset, 

8) 

9from continuiti.operators import ( 

10 Operator, 

11) 

12 

13from .metric import ( 

14 Metric, 

15) 

16 

17 

class NumberOfParameters(Metric):
    """Metric reporting how many trainable parameters an operator has.

    Only parameters whose ``requires_grad`` flag is set contribute to the
    total; frozen parameters are skipped. The ``dataset`` argument exists
    to match the common metric call signature and is not used.
    """

    def __init__(self):
        super().__init__("Number_of_parameters")

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Return the trainable-parameter count as ``{"Value": int, "Unit": "[1]"}``."""
        total = 0
        for param in operator.parameters():
            # Frozen parameters do not count towards the model size here.
            if not param.requires_grad:
                continue
            total += param.numel()
        return {"Value": total, "Unit": "[1]"}

27 

28 

class SpeedOfEvaluation(Metric):
    """Average evaluation time per sample, in milliseconds.

    The operator is evaluated once on the whole dataset in a single batched
    call (``operator(dataset.x, dataset.u, dataset.v)``). The elapsed time
    is then divided by ``len(dataset)``.
    """

    def __init__(self):
        super().__init__("Speed_of_evaluation")

    @staticmethod
    def _synchronize() -> None:
        """Wait for pending CUDA work, but only if CUDA is already in use.

        CUDA kernels launch asynchronously, so without a sync the timer can
        stop before the GPU has finished. Checking ``is_initialized`` avoids
        creating a CUDA context just for a CPU-only measurement.
        """
        import torch

        if torch.cuda.is_initialized():
            torch.cuda.synchronize()

    def __call__(self, operator: Operator, dataset: OperatorDataset) -> Dict:
        """Time one batched forward pass and return the per-sample cost.

        Args:
            operator: Operator to evaluate. It is put in eval mode for the
                measurement, and its previous train/eval mode is restored
                afterwards.
            dataset: Dataset whose ``x``, ``u`` and ``v`` are passed to the
                operator in a single call.

        Returns:
            ``{"Value": <milliseconds per sample>, "Unit": "[ms]"}``.

        Raises:
            ValueError: If ``dataset`` is empty. A per-sample time is
                undefined in that case.
        """
        # Local import: torch is a dependency of the operators being measured,
        # but this module does not import it at file level.
        import torch

        num_samples = len(dataset)
        if num_samples == 0:
            raise ValueError("Cannot measure evaluation speed on an empty dataset.")

        # Remember the caller's mode so the metric has no lasting side effect.
        was_training = operator.training
        operator.eval()
        try:
            # Without no_grad, autograd graph construction would be timed too.
            with torch.no_grad():
                self._synchronize()
                # perf_counter_ns is monotonic; time_ns can jump with clock changes.
                start_ns = time.perf_counter_ns()
                operator(dataset.x, dataset.u, dataset.v)
                self._synchronize()
                end_ns = time.perf_counter_ns()
        finally:
            operator.train(was_training)

        delta_ms = (end_ns - start_ns) * 1e-6  # nanoseconds -> milliseconds
        return {"Value": delta_ms / num_samples, "Unit": "[ms]"}