Coverage for src/nos/data/transmission_loss.py: 72%
83 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import pathlib
2from typing import (
3 Literal,
4)
6import numpy as np
7import pandas as pd
8import torch
9from continuiti.data import (
10 OperatorDataset,
11)
12from continuiti.transforms import (
13 Normalize,
14 Transform,
15)
17from nos.transforms import (
18 MinMaxScale,
19 QuantileScaler,
20)
# Geometry columns that uniquely identify one crystal configuration.
FILTER_COLS = ["radius", "inner_radius", "gap_width"]
def get_tl_from_path(path: pathlib.Path) -> pd.DataFrame:
    """Load transmission-loss data from a CSV file or a directory of CSVs.

    Args:
        path: Either a single ``.csv`` file or a directory that is searched
            recursively for ``*.csv`` files.

    Returns:
        A DataFrame with all rows read as ``float32``. Empty if *path* is a
        directory containing no CSV files.
    """
    if path.is_file():
        return pd.read_csv(path, dtype=np.float32)
    # Collect all frames first and concatenate once: repeated
    # pd.concat inside the loop is O(n^2) in the number of files.
    frames = [pd.read_csv(file, dtype=np.float32) for file in path.rglob("*.csv")]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)
def get_unique_crystals(df: pd.DataFrame) -> pd.DataFrame:
    """Return the distinct crystal geometries (FILTER_COLS) present in *df*."""
    geometry = df[FILTER_COLS]
    return geometry.drop_duplicates()
def get_n_unique(df: pd.DataFrame, n_samples: int = -1) -> pd.DataFrame:
    """Restrict *df* to the rows of ``n_samples`` randomly chosen crystals.

    With the default ``n_samples=-1`` the frame is returned unchanged.
    """
    if n_samples != -1:
        chosen = get_unique_crystals(df).sample(n_samples)
        df = pd.merge(df, chosen, on=FILTER_COLS)
    return df
def get_tl_frame(path: pathlib.Path, n_samples: int = -1) -> pd.DataFrame:
    """Load transmission-loss data from *path*, optionally subsampled to
    ``n_samples`` distinct crystals (``-1`` keeps everything)."""
    return get_n_unique(get_tl_from_path(path), n_samples)
def get_min_max_transform(src: torch.Tensor) -> Transform:
    """Build a MinMaxScale transform from the per-channel extrema of *src*.

    The statistics are computed per channel (dim 1) over all observations,
    then reshaped without the observation dimension so the transform
    broadcasts inside a dataloader.
    """
    n_channels = src.size(1)
    per_channel = src.transpose(0, 1).reshape(n_channels, -1)
    stat_shape = (n_channels,) + (1,) * (src.ndim - 2)
    lo = per_channel.min(dim=1).values.reshape(stat_shape)
    hi = per_channel.max(dim=1).values.reshape(stat_shape)
    return MinMaxScale(lo, hi)
def get_normalize_transform(src: torch.Tensor) -> Transform:
    """Build a Normalize transform from the per-channel mean/std of *src*.

    Statistics are taken per channel (dim 1) over all observations and
    reshaped without the observation dimension for dataloader broadcasting.
    """
    n_channels = src.size(1)
    per_channel = src.transpose(0, 1).reshape(n_channels, -1)
    stat_shape = (n_channels,) + (1,) * (src.ndim - 2)
    mean = per_channel.mean(dim=1).reshape(stat_shape)
    std = per_channel.std(dim=1).reshape(stat_shape)
    return Normalize(mean, std)
class TLDataset(OperatorDataset):
    """Transmission-loss dataset: crystal geometry -> TL(frequency), one row per sample."""

    def __init__(
        self,
        path: pathlib.Path,
        n_samples: int = -1,
        v_transform: Literal["quantile", "min_max", "normalize"] = "quantile",
    ):
        df = get_tl_frame(path, n_samples)

        # Geometry parameters as (N, 3, 1); the same tensor serves as both
        # sensor positions x and sensor values u.
        x = torch.stack(
            [torch.tensor(df[col].tolist()) for col in ("radius", "inner_radius", "gap_width")],
            dim=1,
        ).unsqueeze(-1)
        u = x
        y = torch.tensor(df["frequency"].tolist()).reshape(-1, 1, 1)
        v = torch.tensor(df["transmission_loss"].tolist()).reshape(-1, 1, 1)

        # Select the output transform by name.
        factories = {
            "quantile": QuantileScaler,
            "min_max": get_min_max_transform,
            "normalize": get_normalize_transform,
        }
        if v_transform not in factories:
            raise ValueError(f"Unknown transformation: {v_transform}.")
        v_t: Transform = factories[v_transform](v)

        super().__init__(
            x,
            u,
            y,
            v,
            x_transform=get_min_max_transform(x),
            u_transform=get_min_max_transform(u),
            y_transform=get_min_max_transform(y),
            v_transform=v_t,
        )
def get_tl_compact(path: pathlib.Path, n_samples: int = -1) -> pd.DataFrame:
    """Aggregate the per-frequency rows into one row per crystal.

    Each output row holds the crystal geometry plus the full frequency and
    transmission-loss curves (as lists) and their min/max values.

    Args:
        path: CSV file or directory passed to :func:`get_tl_frame`.
        n_samples: Number of distinct crystals to keep (``-1`` keeps all).

    Returns:
        A DataFrame with one row per unique crystal; empty if no data.
    """
    df = get_tl_frame(path, n_samples)
    unique_crystals = get_unique_crystals(df)

    # Build one single-row frame per crystal and concatenate once at the
    # end: pd.concat inside the loop is O(n^2) in the number of crystals.
    rows = []
    for _, crystal in unique_crystals.iterrows():
        tmp_df = df[
            (df["radius"] == crystal["radius"])
            & (df["inner_radius"] == crystal["inner_radius"])
            & (df["gap_width"] == crystal["gap_width"])
        ]
        rows.append(
            pd.DataFrame(
                {
                    "radius": crystal["radius"],
                    "inner_radius": crystal["inner_radius"],
                    "gap_width": crystal["gap_width"],
                    "frequencies": [tmp_df["frequency"].tolist()],
                    "min_frequency": tmp_df["frequency"].min(),
                    "max_frequency": tmp_df["frequency"].max(),
                    "transmission_losses": [tmp_df["transmission_loss"].tolist()],
                    "min_transmission_loss": tmp_df["transmission_loss"].min(),
                    "max_transmission_loss": tmp_df["transmission_loss"].max(),
                }
            )
        )
    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, ignore_index=True)
class TLDatasetCompact(OperatorDataset):
    """Transmission loss dataset, with bigger evaluation space.

    One sample per crystal: the evaluation set y/v holds the whole frequency
    sweep of that crystal instead of a single point.
    """

    def __init__(
        self,
        path: pathlib.Path,
        n_samples: int = -1,
        v_transform: Literal["quantile", "min_max", "normalize"] = "quantile",
    ):
        df = get_tl_compact(path, n_samples)

        # Geometry parameters as (N, 3, 1); identical tensor used for x and u.
        x = torch.stack(
            [
                torch.tensor(df["radius"].tolist()),
                torch.tensor(df["inner_radius"].tolist()),
                torch.tensor(df["gap_width"].tolist()),
            ],
            dim=1,
        ).unsqueeze(-1)
        u = x
        y = torch.tensor(df["frequencies"].tolist()).reshape(len(df), 1, -1)
        # BUG FIX: convert the column of Python lists via .tolist() before
        # torch.tensor — passing the object-dtype Series directly fails;
        # this also matches how y (and TLDataset) build their tensors.
        v = torch.tensor(df["transmission_losses"].tolist()).unsqueeze(1).reshape(len(df), 1, -1)

        v_t: Transform
        if v_transform == "quantile":
            v_t = QuantileScaler(v)
        elif v_transform == "min_max":
            v_t = get_min_max_transform(v)
        elif v_transform == "normalize":
            v_t = get_normalize_transform(v)
        else:
            raise ValueError(f"Unknown transformation: {v_transform}.")

        transformations = {
            "x_transform": get_min_max_transform(x),
            "u_transform": get_min_max_transform(u),
            "y_transform": get_min_max_transform(y),
            "v_transform": v_t,
        }

        super().__init__(x, u, y, v, **transformations)