Coverage for src/nos/transforms/min_max_scale.py: 100%
40 statements
coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
import torch
import torch.nn as nn
from continuiti.transforms import Transform


class MinMaxScale(Transform):
    """Linearly rescales values from [min_value, max_value] to [-1, 1]."""

    def __init__(self, min_value: torch.Tensor, max_value: torch.Tensor):
        super().__init__()
        self.ndim = min_value.ndim

        self.min_value = nn.Parameter(min_value)
        self.max_value = nn.Parameter(max_value)

        # Guard against division by zero for features with zero range.
        delta = max_value - min_value
        delta[delta == 0] = 1.0
        self.delta = nn.Parameter(delta)

        target_min = -1.0 * torch.ones(min_value.shape)
        self.target_min = nn.Parameter(target_min)

        target_max = 1.0 * torch.ones(min_value.shape)

        self.target_delta = nn.Parameter(target_max - target_min)

    def _is_batched(self, tensor: torch.Tensor) -> bool:
        return tensor.ndim == (self.ndim + 1)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        if self._is_batched(tensor):
            # observation dimension 0
            v_m = self.min_value.unsqueeze(0)
            d = self.delta.unsqueeze(0)
            tgt_d = self.target_delta.unsqueeze(0)
            tgt_m = self.target_min.unsqueeze(0)
        else:
            # single observation
            v_m = self.min_value
            d = self.delta
            tgt_d = self.target_delta
            tgt_m = self.target_min

        return ((tensor - v_m) / d) * tgt_d + tgt_m

    def undo(self, tensor: torch.Tensor) -> torch.Tensor:
        if self._is_batched(tensor):
            tgt_m = self.target_min.unsqueeze(0)
            tgt_d = self.target_delta.unsqueeze(0)
            d = self.delta.unsqueeze(0)
            v_m = self.min_value.unsqueeze(0)
        else:
            tgt_m = self.target_min
            tgt_d = self.target_delta
            d = self.delta
            v_m = self.min_value

        return (tensor - tgt_m) / tgt_d * d + v_m
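
Usage sketch (not part of the covered module): the snippet below assumes the class is importable as nos.transforms.min_max_scale.MinMaxScale, as the coverage header suggests, and uses hypothetical per-feature bounds. It scales a batched tensor into [-1, 1] and checks that undo recovers the original values.

import torch

from nos.transforms.min_max_scale import MinMaxScale  # import path inferred from the coverage header

# Hypothetical per-feature bounds; the last feature has zero range on purpose,
# exercising the delta == 0 guard in __init__.
min_value = torch.tensor([0.0, -2.0, 5.0])
max_value = torch.tensor([1.0, 2.0, 5.0])

scale = MinMaxScale(min_value, max_value)

# Batched input: observation dimension 0, so tensor.ndim == min_value.ndim + 1.
x = torch.tensor([[0.5, 0.0, 5.0],
                  [1.0, -2.0, 5.0]])

y = scale.forward(x)    # per-feature values mapped into [-1, 1]
x_back = scale.undo(y)  # inverse transform

print(torch.allclose(x, x_back))  # True: the round trip recovers the input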