Coverage for src/nos/data/transmission_loss.py: 72%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import pathlib 

2from typing import ( 

3 Literal, 

4) 

5 

6import numpy as np 

7import pandas as pd 

8import torch 

9from continuiti.data import ( 

10 OperatorDataset, 

11) 

12from continuiti.transforms import ( 

13 Normalize, 

14 Transform, 

15) 

16 

17from nos.transforms import ( 

18 MinMaxScale, 

19 QuantileScaler, 

20) 

21 

# Columns that together identify one unique crystal geometry.
FILTER_COLS = ["radius", "inner_radius", "gap_width"]

23 

24 

def get_tl_from_path(path: pathlib.Path) -> pd.DataFrame:
    """Load transmission-loss data from a CSV file or a directory of CSVs.

    Args:
        path: Either a single ``.csv`` file or a directory that is searched
            recursively for ``*.csv`` files.

    Returns:
        One DataFrame (float32 columns) with all rows concatenated. Empty
        when ``path`` is a directory containing no CSV files.
    """
    if path.is_file():
        return pd.read_csv(path, dtype=np.float32)

    # Collect frames first and concatenate once: calling pd.concat inside the
    # loop re-copies all previously accumulated rows each iteration (O(n^2)).
    frames = [pd.read_csv(file, dtype=np.float32) for file in path.rglob("*.csv")]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)

34 

35 

def get_unique_crystals(df: pd.DataFrame) -> pd.DataFrame:
    """Return the distinct crystal geometries present in *df*.

    Uniqueness is determined by the ``FILTER_COLS`` columns
    (radius, inner radius, gap width).
    """
    geometry = df.loc[:, FILTER_COLS]
    return geometry.drop_duplicates()

38 

39 

def get_n_unique(df: pd.DataFrame, n_samples: int = -1) -> pd.DataFrame:
    """Restrict *df* to the rows of ``n_samples`` randomly chosen crystals.

    Args:
        df: Frame holding one row per (crystal, frequency) pair.
        n_samples: Number of unique crystals to keep; -1 keeps every row.

    Returns:
        The rows of *df* belonging to the sampled crystals (inner merge on
        the geometry columns), or *df* unchanged when ``n_samples == -1``.
    """
    if n_samples == -1:
        return df

    chosen = get_unique_crystals(df).sample(n_samples)
    return pd.merge(df, chosen, on=FILTER_COLS)

48 

49 

def get_tl_frame(path: pathlib.Path, n_samples: int = -1) -> pd.DataFrame:
    """Load transmission-loss CSV data and keep ``n_samples`` random crystals.

    Convenience wrapper around :func:`get_tl_from_path` and
    :func:`get_n_unique` (``n_samples == -1`` keeps everything).
    """
    return get_n_unique(get_tl_from_path(path), n_samples)

53 

54 

def get_min_max_transform(src: torch.Tensor) -> Transform:
    """Build a min-max scaling transform from the per-channel extrema of *src*.

    The observation (first) dimension is folded away so the resulting min/max
    tensors broadcast over single samples coming out of a dataloader.
    """
    flat = src.transpose(0, 1).flatten(1, -1)
    # Shape without the observation dimension, singleton everywhere else.
    target_shape = (src.size(1), *([1] * (src.ndim - 2)))
    lower = flat.min(dim=1).values.reshape(target_shape)
    upper = flat.max(dim=1).values.reshape(target_shape)
    return MinMaxScale(lower, upper)

63 

64 

def get_normalize_transform(src: torch.Tensor) -> Transform:
    """Build a normalization transform from per-channel mean/std of *src*.

    The observation (first) dimension is folded away so the resulting
    mean/std tensors broadcast over single samples from a dataloader.
    """
    flat = src.transpose(0, 1).flatten(1, -1)
    # Shape without the observation dimension, singleton everywhere else.
    target_shape = (src.size(1), *([1] * (src.ndim - 2)))
    mean = flat.mean(dim=1).reshape(target_shape)
    std = flat.std(dim=1).reshape(target_shape)
    return Normalize(mean, std)

73 

74 

class TLDataset(OperatorDataset):
    """Transmission-loss dataset with one (frequency, loss) pair per sample.

    Each CSV row becomes one observation: the crystal geometry (radius,
    inner radius, gap width) serves as both the sensor positions ``x`` and
    the input-function values ``u``; the evaluation point ``y`` is the
    frequency and the output ``v`` the transmission loss.
    """

    def __init__(
        self,
        path: pathlib.Path,
        n_samples: int = -1,
        v_transform: Literal["quantile", "min_max", "normalize"] = "quantile",
    ):
        """Load the dataset from *path*.

        Args:
            path: CSV file or directory of CSV files.
            n_samples: Number of unique crystals to keep (-1 keeps all).
            v_transform: Name of the output transformation applied to ``v``.

        Raises:
            ValueError: If ``v_transform`` is not a supported name.
        """
        # retrieve data
        df = get_tl_frame(path, n_samples)

        x = torch.stack(
            [
                torch.tensor(df["radius"].tolist()),
                torch.tensor(df["inner_radius"].tolist()),
                torch.tensor(df["gap_width"].tolist()),
            ],
            dim=1,
        ).unsqueeze(-1)
        # The geometry acts as both sensor location and function value.
        u = x
        y = torch.tensor(df["frequency"].tolist()).reshape(-1, 1, 1)
        # The former ".unsqueeze(1)" before the reshape was a no-op (the
        # reshape fully determines the final shape), so it is dropped.
        v = torch.tensor(df["transmission_loss"].tolist()).reshape(-1, 1, 1)

        v_t: Transform
        if v_transform == "quantile":
            v_t = QuantileScaler(v)
        elif v_transform == "min_max":
            v_t = get_min_max_transform(v)
        elif v_transform == "normalize":
            v_t = get_normalize_transform(v)
        else:
            raise ValueError(f"Unknown transformation: {v_transform}.")

        transformations = {
            "x_transform": get_min_max_transform(x),
            "u_transform": get_min_max_transform(u),
            "y_transform": get_min_max_transform(y),
            "v_transform": v_t,
        }

        super().__init__(x, u, y, v, **transformations)

115 

116 

def get_tl_compact(path: pathlib.Path, n_samples: int = -1) -> pd.DataFrame:
    """Compact the per-row transmission-loss data to one row per crystal.

    Args:
        path: CSV file or directory of CSV files.
        n_samples: Number of unique crystals to keep (-1 keeps all).

    Returns:
        A frame with one row per crystal holding its geometry, the full
        lists of frequencies and transmission losses, and their extrema.
    """
    df = get_tl_frame(path, n_samples)
    unique_crystals = get_unique_crystals(df)

    # Build one single-row frame per crystal, then concatenate once at the
    # end: pd.concat inside the loop re-copies the accumulated frame every
    # iteration (O(n^2) in the number of crystals).
    rows = []
    for _, crystal in unique_crystals.iterrows():
        mask = (
            (df["radius"] == crystal["radius"])
            & (df["inner_radius"] == crystal["inner_radius"])
            & (df["gap_width"] == crystal["gap_width"])
        )
        tmp_df = df[mask]
        rows.append(
            pd.DataFrame(
                {
                    "radius": crystal["radius"],
                    "inner_radius": crystal["inner_radius"],
                    "gap_width": crystal["gap_width"],
                    # Lists are wrapped so each becomes a single cell.
                    "frequencies": [tmp_df["frequency"].tolist()],
                    "min_frequency": min(tmp_df["frequency"]),
                    "max_frequency": max(tmp_df["frequency"]),
                    "transmission_losses": [tmp_df["transmission_loss"].tolist()],
                    "min_transmission_loss": min(tmp_df["transmission_loss"]),
                    "max_transmission_loss": max(tmp_df["transmission_loss"]),
                }
            )
        )

    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, ignore_index=True)

144 

145 

class TLDatasetCompact(OperatorDataset):
    """Transmission loss dataset, with bigger evaluation space.

    One observation per crystal: ``y`` and ``v`` hold all frequencies and
    transmission losses of that crystal. NOTE(review): the reshape below
    assumes every crystal has the same number of frequency samples — it
    fails otherwise; confirm against the data source.
    """

    def __init__(
        self,
        path: pathlib.Path,
        n_samples: int = -1,
        v_transform: Literal["quantile", "min_max", "normalize"] = "quantile",
    ):
        """Load the compacted dataset from *path*.

        Args:
            path: CSV file or directory of CSV files.
            n_samples: Number of unique crystals to keep (-1 keeps all).
            v_transform: Name of the output transformation applied to ``v``.

        Raises:
            ValueError: If ``v_transform`` is not a supported name.
        """
        df = get_tl_compact(path, n_samples)

        x = torch.stack(
            [
                torch.tensor(df["radius"].tolist()),
                torch.tensor(df["inner_radius"].tolist()),
                torch.tensor(df["gap_width"].tolist()),
            ],
            dim=1,
        ).unsqueeze(-1)
        # The geometry acts as both sensor location and function value.
        u = x
        y = torch.tensor(df["frequencies"].tolist()).reshape(len(df), 1, -1)
        # Bug fix: the column must be converted with .tolist() (as y is);
        # torch.tensor cannot convert an object-dtype Series of lists.
        # The former ".unsqueeze(1)" before the reshape was also a no-op.
        v = torch.tensor(df["transmission_losses"].tolist()).reshape(len(df), 1, -1)

        v_t: Transform
        if v_transform == "quantile":
            v_t = QuantileScaler(v)
        elif v_transform == "min_max":
            v_t = get_min_max_transform(v)
        elif v_transform == "normalize":
            v_t = get_normalize_transform(v)
        else:
            raise ValueError(f"Unknown transformation: {v_transform}.")

        transformations = {
            "x_transform": get_min_max_transform(x),
            "u_transform": get_min_max_transform(u),
            "y_transform": get_min_max_transform(y),
            "v_transform": v_t,
        }

        super().__init__(x, u, y, v, **transformations)