Coverage for src/nos/transforms/min_max_scale.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import torch 

2import torch.nn as nn 

3from continuiti.transforms import ( 

4 Transform, 

5) 

6 

7 

class MinMaxScale(Transform):
    """Affine min-max scaling of tensors into a fixed target interval.

    ``forward`` maps each element linearly so that ``min_value`` is sent to
    ``target_min`` and ``max_value`` to ``target_max`` (``[-1, 1]`` by
    default); ``undo`` applies the inverse map.

    Entries where ``max_value == min_value`` would cause a division by zero;
    their range is treated as ``1`` instead.

    Inputs to ``forward``/``undo`` may be a single observation (same shape as
    ``min_value``) or carry additional leading (batch) dimensions; the stored
    statistics broadcast over them.

    Args:
        min_value: Element-wise minimum of the data for one observation.
        max_value: Element-wise maximum of the data, broadcastable against
            ``min_value``.
        target_min: Lower bound of the target interval (keyword-only).
        target_max: Upper bound of the target interval (keyword-only).

    Raises:
        ValueError: If ``target_max`` is not strictly greater than
            ``target_min`` (``undo`` would divide by zero).
    """

    def __init__(
        self,
        min_value: torch.Tensor,
        max_value: torch.Tensor,
        *,
        target_min: float = -1.0,
        target_max: float = 1.0,
    ):
        super().__init__()
        # `not >` also rejects NaN bounds.
        if not target_max > target_min:
            raise ValueError(
                f"target_max ({target_max}) must be greater than "
                f"target_min ({target_min})."
            )
        self.ndim = min_value.ndim

        # NOTE(review): nn.Parameter defaults to requires_grad=True, so an
        # optimizer over this module's parameters would update these
        # statistics; confirm that is intended.
        self.min_value = nn.Parameter(min_value)
        self.max_value = nn.Parameter(max_value)

        # Constant entries (zero range) get a unit range to avoid dividing by 0.
        delta = max_value - min_value
        delta = torch.where(delta == 0, torch.ones_like(delta), delta)
        self.delta = nn.Parameter(delta)

        # Build the target tensors with min_value's dtype and device so that
        # forward/undo never mix devices (torch.ones(shape) always allocated
        # on the default device with the default dtype).
        self.target_min = nn.Parameter(torch.full_like(min_value, target_min))
        self.target_delta = nn.Parameter(
            torch.full_like(min_value, target_max - target_min)
        )

    def _is_batched(self, tensor: torch.Tensor) -> bool:
        """Return True if ``tensor`` has exactly one more dim than the statistics.

        Kept for backward compatibility; ``forward``/``undo`` rely on
        broadcasting and no longer need it.
        """
        return tensor.ndim == (self.ndim + 1)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Map ``tensor`` from ``[min_value, max_value]`` to the target interval.

        Broadcasting aligns the stored statistics with the trailing dimensions
        of ``tensor``, which is equivalent to the explicit ``unsqueeze(0)`` a
        batched input would otherwise need.
        """
        return (
            (tensor - self.min_value) / self.delta
        ) * self.target_delta + self.target_min

    def undo(self, tensor: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`forward`, mapping the target interval back to data range."""
        return (
            (tensor - self.target_min) / self.target_delta * self.delta
            + self.min_value
        )