Coverage for src/nos/transforms/quantile_scaler.py: 100%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1from typing import ( 

2 Tuple, 

3 Union, 

4) 

5 

6import torch 

7import torch.nn as nn 

8from continuiti.transforms import ( 

9 Transform, 

10) 

11 

12 

class QuantileScaler(Transform):
    """Quantile Scaler Class.

    A transform for scaling input data to a specified target distribution using quantiles. This is
    particularly useful for normalizing data in a way that is more robust to outliers than standard
    z-score normalization.

    The transformation maps the quantiles of the input data to the quantiles of the target distribution,
    effectively performing a non-linear scaling that preserves the relative distribution of the data.
    Between two neighbouring quantile points the mapping is linear; values outside the range spanned by
    the quantiles of `src` are extrapolated linearly using the outermost intervals.

    Args:
        src: Tensor from which the source distribution is drawn. Dim 0 indexes observations and dim 1 the
            features; all remaining dims are pooled into the per-feature quantile statistics.
        n_quantile_intervals: Resolution of the mapping. Each feature is summarized by
            `n_quantile_intervals + 2` equally spaced quantile points, i.e. `n_quantile_intervals + 1`
            intervals.
        target_mean: Mean of the target Gaussian distribution. Can be a number (all dimensions use the same
            mean), or a tensor broadcastable to (n_dim,) (allows for different means along different
            dimensions).
        target_std: Std of the target Gaussian distribution. Can be a number (all dimensions use the same
            std), or a tensor broadcastable to (n_dim,) (allows for different stds along different dimensions).
        eps: Small value to bound the target distribution to a finite interval: the target quantiles cover
            the fractions [eps, 1 - eps]. Must lie in the open interval (0, 0.5); eps = 0 yields infinite
            target quantiles and eps = 0.5 collapses all target quantiles onto the mean.

    """

    def __init__(
        self,
        src: torch.Tensor,
        n_quantile_intervals: int = 1000,
        target_mean: Union[float, torch.Tensor] = 0.0,
        target_std: Union[float, torch.Tensor] = 1.0,
        eps: float = 1e-3,
    ):
        assert 0.0 < eps < 0.5, "eps must lie in the open interval (0, 0.5)"
        assert n_quantile_intervals > 0, "n_quantile_intervals must be positive"

        super().__init__()

        # Plain numbers (including ints) become single-element tensors on src's dtype/device so that the
        # target quantiles can be combined with the source statistics without device/dtype mismatches.
        if isinstance(target_mean, (int, float)):
            target_mean = torch.full((1,), float(target_mean), dtype=src.dtype, device=src.device)
        if isinstance(target_std, (int, float)):
            target_std = torch.full((1,), float(target_std), dtype=src.dtype, device=src.device)
        self.target_mean = target_mean
        self.target_std = target_std

        self.n_quantile_intervals = n_quantile_intervals
        # n_quantile_intervals + 2 quantile points delimit n_quantile_intervals + 1 intervals.
        self.n_q_points = n_quantile_intervals + 2

        self.n_dim = src.size(1)  # assumes src[0] is batch/observation dim
        self.batched_n_dim = src.ndim

        # source "distribution"
        # torch.quantile requires q to share the input's dtype and device.
        self.quantile_fractions = torch.linspace(0, 1, self.n_q_points, dtype=src.dtype, device=src.device)
        quantile_points = torch.quantile(
            src.transpose(0, 1).reshape(self.n_dim, -1),
            self.quantile_fractions,
            dim=1,
            interpolation="linear",
        )  # (n_q_points, n_dim): the first output dim indexes the requested quantiles
        self.quantile_points = nn.Parameter(quantile_points)
        self.deltas = nn.Parameter(quantile_points[1:] - quantile_points[:-1])

        # target distribution
        self.target_distribution = torch.distributions.normal.Normal(target_mean, target_std)
        self.target_quantile_fractions = torch.linspace(
            eps, 1 - eps, self.n_q_points, dtype=src.dtype, device=src.device
        )  # bounded domain
        # Evaluate the icdf on a (n_q_points, 1) column: a per-dimension mean/std of shape (n_dim,) then
        # broadcasts to (n_q_points, n_dim), while scalar / single-element parameters give (n_q_points, 1)
        # and are expanded across all dimensions. contiguous() materializes the expanded view.
        target_quantile_points = self.target_distribution.icdf(self.target_quantile_fractions.unsqueeze(1))
        target_quantile_points = target_quantile_points.expand(self.n_q_points, self.n_dim).contiguous()
        self.target_quantile_points = nn.Parameter(target_quantile_points)
        self.target_deltas = nn.Parameter(target_quantile_points[1:] - target_quantile_points[:-1])

    def _get_scaling_indices(
        self, src: torch.Tensor, quantile_tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Locate, for every element of src, the quantile interval it falls into.

        Args:
            src: Input tensor whose last dim has size n_dim (feature dim last).
            quantile_tensor: Sorted quantile points of shape (n_q_points, n_dim).

        Returns:
            Tuple (interval_indices, dim_indices) of flat tensors with src.nelement() entries, in row-major
            order of src, suitable for advanced indexing of (n_q_points, n_dim) or (n_q_points - 1, n_dim)
            tables. interval_indices[k] is the index of the last quantile point strictly below the value,
            clamped to [0, n_quantile_intervals] so values outside the quantile range use the outermost
            intervals. With tied quantile points, the last of the ties is chosen, so a value strictly
            inside the quantile range never lands in a zero-width interval.
        """
        dtype = torch.promote_types(src.dtype, quantile_tensor.dtype)
        # One sorted row of boundaries per feature: (n_dim, n_q_points).
        boundaries = quantile_tensor.detach().to(dtype).transpose(0, 1).contiguous()
        # Matching rows of query values: (n_dim, n_values).
        values = src.detach().to(dtype).reshape(-1, self.n_dim).transpose(0, 1).contiguous()
        # searchsorted (right=False) returns the first position i with boundaries[i] >= value, hence i - 1
        # is the last boundary strictly below the value (-1 if none, n_q_points - 1 if all are below).
        indices = torch.searchsorted(boundaries, values) - 1
        indices = indices.clamp(min=0, max=self.n_quantile_intervals)

        # prepare for indexing
        interval_indices = indices.transpose(0, 1).reshape(-1)
        dim_indices = torch.arange(self.n_dim, device=src.device).repeat(interval_indices.numel() // self.n_dim)
        return interval_indices, dim_indices

    @staticmethod
    def _interval_fraction(offset: torch.Tensor, width: torch.Tensor) -> torch.Tensor:
        """Return offset / width, defined as 0 where the interval has zero width (tied quantile points)."""
        zero_width = width == 0
        # Divide by a safe denominator so no inf/NaN is created (also keeps gradients finite).
        safe_width = torch.where(zero_width, torch.ones_like(width), width)
        return torch.where(zero_width, torch.zeros_like(offset), offset / safe_width)

    def _rescale(
        self,
        tensor: torch.Tensor,
        from_points: torch.Tensor,
        from_deltas: torch.Tensor,
        to_points: torch.Tensor,
        to_deltas: torch.Tensor,
    ) -> torch.Tensor:
        """Piecewise-linearly map tensor from one set of quantile points to another.

        Args:
            tensor: Input of shape (batch, n_dim, ...) with the same ndim as src, or a single observation
                without the batch dim.
            from_points: Quantile points (n_q_points, n_dim) of the distribution tensor currently follows.
            from_deltas: Interval widths (n_q_points - 1, n_dim) belonging to from_points.
            to_points: Quantile points (n_q_points, n_dim) of the distribution to map onto.
            to_deltas: Interval widths (n_q_points - 1, n_dim) belonging to to_points.

        Returns:
            The mapped tensor with the same shape as the input.
        """
        is_batched = tensor.ndim == self.batched_n_dim
        if not is_batched:
            # tensor is a single observation without observation dim
            tensor = tensor.unsqueeze(0)

        # Move the feature dim last so it lines up with the (n_q_points, n_dim) quantile tables.
        tensor = tensor.transpose(1, -1)
        indices = self._get_scaling_indices(tensor, from_points)

        # Relative position inside the selected source interval (0 at its left edge, 1 at its right edge).
        lower = from_points[indices].view(tensor.shape)
        width = from_deltas[indices].view(tensor.shape)
        fraction = self._interval_fraction(tensor - lower, width)

        # Same relative position inside the corresponding destination interval.
        to_lower = to_points[indices].view(tensor.shape)
        to_width = to_deltas[indices].view(tensor.shape)
        out = fraction * to_width + to_lower

        out = out.transpose(1, -1)
        return out if is_batched else out.squeeze(0)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Transforms the input tensor to match the target distribution using quantile scaling.

        Args:
            tensor: The input tensor to transform, of shape (batch, n_dim, ...) like src, or a single
                observation without the batch dim.

        Returns:
            The transformed tensor, scaled to the target distribution.
        """
        return self._rescale(
            tensor, self.quantile_points, self.deltas, self.target_quantile_points, self.target_deltas
        )

    def undo(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reverses the transformation applied by the forward method, mapping the tensor back to its original
        distribution.

        Args:
            tensor: The tensor to reverse the transformation on, shaped like the input of forward.

        Returns:
            The tensor with the quantile scaling transformation reversed according to the src distribution.
        """
        return self._rescale(
            tensor, self.target_quantile_points, self.target_deltas, self.quantile_points, self.deltas
        )