Coverage for src/nos/transforms/quantile_scaler.py: 100%
84 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1from typing import (
2 Tuple,
3 Union,
4)
6import torch
7import torch.nn as nn
8from continuiti.transforms import (
9 Transform,
10)
class QuantileScaler(Transform):
    """Piecewise-linear quantile scaler.

    A transform for scaling input data to a specified target distribution using quantiles. This is
    particularly useful for normalizing data in a way that is more robust to outliers than standard
    z-score normalization.

    The quantiles of ``src`` are computed per feature (dimension 1 of ``src``) and paired with the
    quantiles of a Gaussian target distribution. Values are mapped by locating the source segment
    they fall into and linearly interpolating inside the matching target segment. Values outside
    the range seen in ``src`` are extrapolated linearly with the first/last segment.

    Args:
        src: Tensor from which the source distribution is drawn. Dimension 0 is treated as the
            batch/observation dimension and dimension 1 as the feature dimension.
        n_quantile_intervals: Controls the resolution; ``n_quantile_intervals + 2`` knots are stored.
        target_mean: Mean of the target Gaussian distribution. Can be a number (all dimensions use
            the same mean), or a tensor broadcastable to ``(n_dim,)`` (different means along
            different dimensions).
        target_std: Std of the target Gaussian distribution. Can be a number (all dimensions use
            the same std), or a tensor broadcastable to ``(n_dim,)`` (different stds along
            different dimensions).
        eps: Small value to bound the target distribution to a finite interval; the target knots
            span the probabilities ``[eps, 1 - eps]``. Must lie in ``[0, 0.5]``. Note that
            ``eps = 0`` evaluates the Gaussian inverse CDF at 0 and 1 (unbounded), and
            ``eps = 0.5`` collapses all target knots onto the mean.

    Raises:
        AssertionError: If ``eps`` is outside ``[0, 0.5]`` or ``n_quantile_intervals`` is not positive.
    """

    def __init__(
        self,
        src: torch.Tensor,
        n_quantile_intervals: int = 1000,
        target_mean: Union[float, torch.Tensor] = 0.0,
        target_std: Union[float, torch.Tensor] = 1.0,
        eps: float = 1e-3,
    ):
        assert eps <= 0.5, "eps must be at most 0.5"
        assert eps >= 0, "eps must be non-negative"

        super().__init__()

        # Plain numbers (ints included) become one-element tensors so that both
        # scalar and tensor targets follow the same code path below.
        if isinstance(target_mean, (int, float)):
            target_mean = float(target_mean) * torch.ones(1)
        if isinstance(target_std, (int, float)):
            target_std = float(target_std) * torch.ones(1)
        self.target_mean = target_mean
        self.target_std = target_std

        assert n_quantile_intervals > 0, "n_quantile_intervals must be positive"
        self.n_quantile_intervals = n_quantile_intervals
        # n_quantile_intervals + 2 knots are stored, i.e. n_quantile_intervals + 1 segments.
        self.n_q_points = n_quantile_intervals + 2

        self.n_dim = src.size(1)  # assumes src[0] is batch/observation dim
        self.batched_n_dim = src.ndim

        # Source "distribution": per-feature quantile knots, shape (n_q_points, n_dim).
        # torch.quantile needs q on the same dtype/device as its input.
        self.quantile_fractions = torch.linspace(
            0, 1, self.n_q_points, dtype=src.dtype, device=src.device
        )
        quantile_points = torch.quantile(
            src.transpose(0, 1).reshape(self.n_dim, -1),
            self.quantile_fractions,
            dim=1,
            interpolation="linear",
        )
        self.quantile_points = nn.Parameter(quantile_points)
        self.deltas = nn.Parameter(quantile_points[1:] - quantile_points[:-1])

        # Target distribution: Gaussian knots on the bounded probability range [eps, 1 - eps].
        self.target_distribution = torch.distributions.normal.Normal(target_mean, target_std)
        self.target_quantile_fractions = torch.linspace(0 + eps, 1 - eps, self.n_q_points)
        # Evaluate the inverse CDF on a (n_q_points, 1) column so that a per-dimension
        # mean/std of shape (n_dim,) broadcasts to (n_q_points, n_dim).
        fractions = self.target_quantile_fractions.to(self.target_distribution.loc.device)
        target_quantile_points = self.target_distribution.icdf(fractions.unsqueeze(1))
        target_quantile_points = (
            target_quantile_points.expand(self.n_q_points, self.n_dim)
            .to(quantile_points.device)
            .contiguous()
        )
        self.target_quantile_points = nn.Parameter(target_quantile_points)
        self.target_deltas = nn.Parameter(target_quantile_points[1:] - target_quantile_points[:-1])

    def _get_scaling_indices(
        self, src: torch.Tensor, quantile_tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Locate, for every element of src, the segment of quantile_tensor it falls into.

        Args:
            src: Input tensor whose last dimension has size ``n_dim``.
            quantile_tensor: Knots of a distribution, shape ``(n_q_points, n_dim)``, sorted along
                dimension 0.

        Returns:
            Tuple ``(rows, cols)`` of flat index tensors with ``src.nelement()`` entries each.
            ``rows`` holds the index of the last knot that is strictly smaller than the value
            (0 if there is none), capped at ``n_quantile_intervals`` so it is always a valid row
            of the delta tensors. ``cols`` holds the matching feature index.
        """
        knots = quantile_tensor.detach()
        values = src.detach().reshape(-1, self.n_dim)

        # searchsorted wants matching dtypes and the sorted axis last.
        work_dtype = torch.promote_types(knots.dtype, values.dtype)
        knots = knots.to(work_dtype).t().contiguous()  # (n_dim, n_q_points)
        values = values.to(work_dtype).t().contiguous()  # (n_dim, n_values)

        # Index of the first knot >= value; the knot before it is the left boundary.
        first_not_below = torch.searchsorted(knots, values)
        # Clamp handles values at/below the first knot (-1 -> 0) and above the last
        # knot (n_q_points - 1 -> n_quantile_intervals).
        rows = (first_not_below - 1).clamp_(min=0, max=self.n_quantile_intervals)

        # prepare for indexing
        return (
            rows.t().reshape(-1),
            torch.arange(self.n_dim, device=src.device).repeat(src.nelement() // self.n_dim),
        )

    def _remap(
        self,
        tensor: torch.Tensor,
        from_points: torch.Tensor,
        from_deltas: torch.Tensor,
        to_points: torch.Tensor,
        to_deltas: torch.Tensor,
    ) -> torch.Tensor:
        """Map tensor piecewise-linearly from the ``from_*`` knots onto the ``to_*`` knots."""
        is_batched = tensor.ndim == self.batched_n_dim
        if not is_batched:
            # tensor is a single observation without observation dim
            tensor = tensor.unsqueeze(0)

        # Move the feature dimension last so it lines up with the knot columns.
        tensor = tensor.transpose(1, -1)
        shape = tensor.shape
        indices = self._get_scaling_indices(tensor, from_points)

        # Relative position inside the originating segment.
        out = (tensor - from_points[indices].view(shape)) / from_deltas[indices].view(shape)
        # Same relative position inside the destination segment.
        out = out * to_deltas[indices].view(shape) + to_points[indices].view(shape)

        out = out.transpose(1, -1)
        if not is_batched:
            return out.squeeze(0)
        return out

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Transforms the input tensor to match the target distribution using quantile scaling.

        Args:
            tensor: The input tensor to transform. Either has the same number of dimensions as
                the ``src`` passed to the constructor, or one fewer (a single observation).

        Returns:
            The transformed tensor, scaled to the target distribution, with the shape of the input.
        """
        return self._remap(
            tensor,
            self.quantile_points,
            self.deltas,
            self.target_quantile_points,
            self.target_deltas,
        )

    def undo(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reverses the transformation applied by the forward method, mapping the tensor back to its original
        distribution.

        Args:
            tensor: The tensor to reverse the transformation on. Either has the same number of
                dimensions as the ``src`` passed to the constructor, or one fewer.

        Returns:
            The tensor with the quantile scaling transformation reversed according to the src distribution.
        """
        return self._remap(
            tensor,
            self.target_quantile_points,
            self.target_deltas,
            self.quantile_points,
            self.deltas,
        )