Coverage for src/nos/trainers/trainer.py: 86%
107 statements
coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
import json
import pathlib
import shutil
import time

import mlflow
import pandas as pd
import torch.optim.lr_scheduler as sched
import torch.utils.data
from continuiti.data import (
    OperatorDataset,
)
from continuiti.operators import (
    Operator,
)
from torch.utils.data import (
    DataLoader,
    random_split,
)
from tqdm import (
    tqdm,
)

from nos.utils import (
    UniqueId,
)

from .util import (
    save_checkpoint,
)

class Trainer:
    def __init__(
        self,
        operator: Operator,
        criterion,
        optimizer,
        lr_scheduler: sched.LRScheduler = None,
        max_epochs: int = 1000,
        batch_size: int = 16,
        max_n_logs: int = 200,
        out_dir: pathlib.Path = None,
    ):
        self.operator = operator
        self.criterion = criterion
        self.optimizer = optimizer
        if lr_scheduler is None:
            self.lr_scheduler = sched.ConstantLR(self.optimizer, factor=1.0)
        else:
            self.lr_scheduler = lr_scheduler

        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.test_val_split = 0.9  # fraction of the samples used for training

        # logging and model serialization
        if out_dir is None:
            uid = UniqueId()
            self.out_dir = pathlib.Path.cwd().joinpath("run", str(uid))
        else:
            self.out_dir = out_dir
        self.out_dir.mkdir(parents=True, exist_ok=True)

        # epochs at which metrics are logged to mlflow (at most max_n_logs of them)
        log_epochs = torch.round(torch.linspace(0, max_epochs, max_n_logs))
        log_epochs = log_epochs.tolist()
        self.log_epochs = [int(epoch) for epoch in log_epochs]

    def __call__(self, data_set: OperatorDataset, run_name: str = None) -> Operator:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        data_set.u = data_set.u.to(device)
        data_set.y = data_set.y.to(device)
        data_set.x = data_set.x.to(device)
        data_set.v = data_set.v.to(device)

        for trf in data_set.transform.keys():
            data_set.transform[trf] = data_set.transform[trf].to(device)

        # data
        train_set, val_set = random_split(data_set, [self.test_val_split, 1 - self.test_val_split])
        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=self.batch_size)

        training_config = {
            "val_indices": val_set.indices,
            "val_size": len(val_set),
            "train_indices": train_set.indices,
            "train_size": len(train_set),
        }
        with open(self.out_dir.joinpath("training_config.json"), "w") as file_handle:
            json.dump(training_config, file_handle)

        # setup training
        self.operator.to(device)
        self.criterion.to(device)

        best_val_loss = float("inf")
        val_losses = []
        train_losses = []
        lrs = []
        times = []

        pbar = tqdm(range(self.max_epochs))
        train_loss = torch.inf
        val_loss = torch.inf

        start = time.time()

        with mlflow.start_run():
            if run_name is not None:
                mlflow.set_tag("mlflow.runName", run_name)
            for epoch in pbar:
                pbar.set_description(
                    f"Train Loss: {train_loss: .6f},\t Val Loss: {val_loss: .6f}, Lr: {self.optimizer.param_groups[0]['lr']}"
                )
                train_loss = self.train(train_loader, self.operator, epoch, device)
                val_loss = self.eval(val_loader, self.operator, epoch, device)
                self.lr_scheduler.step(epoch)

                # update training parameters
                lrs.append(self.optimizer.param_groups[0]["lr"])
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                times.append(time.time() - start)

                # log metrics
                if epoch in self.log_epochs:
                    mlflow.log_metric("Val loss", val_loss, step=epoch)
                    mlflow.log_metric("Train loss", train_loss, step=epoch)
                    mlflow.log_metric("LR", self.optimizer.param_groups[0]["lr"], step=epoch)

                # save best model
                if val_loss < best_val_loss:
                    best_dir = self.out_dir.joinpath("best")
                    if best_dir.is_dir():
                        shutil.rmtree(best_dir)
                    best_dir.mkdir(exist_ok=True, parents=True)

                    save_checkpoint(
                        self.operator,
                        val_loss,
                        train_loss,
                        epoch,
                        start,
                        self.batch_size,
                        train_set,
                        val_set,
                        best_dir,
                    )
                    best_val_loss = val_loss

        save_checkpoint(
            self.operator,
            val_loss,
            train_loss,
            self.max_epochs,
            start,
            self.batch_size,
            train_set,
            val_set,
            self.out_dir,
        )

        training_curves = pd.DataFrame(
            {
                "Epochs": torch.arange(0, self.max_epochs).tolist(),
                "Val_loss": val_losses,
                "Train_loss": train_losses,
                "Lr": lrs,
                "time": times,
            }
        )
        training_curves.to_csv(self.out_dir.joinpath("training.csv"))
        return self.operator

    def train(self, loader, model, epoch, device):
        # switch to train mode
        model.train()
        losses = []
        for x, u, y, v in loader:
            x, u, y, v = x.to(device), u.to(device), y.to(device), v.to(device)

            # compute output
            output = model(x, u, y)
            loss = self.criterion(output, v)

            # compute gradient
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            losses.append(loss.item())

        return torch.mean(torch.tensor(losses)).item()

    def eval(self, loader, model, epoch, device):
        # switch to evaluation mode
        model.eval()

        losses = []
        with torch.no_grad():  # no gradients are needed for validation
            for x, u, y, v in loader:
                x, u, y, v = x.to(device), u.to(device), y.to(device), v.to(device)

                # compute output
                output = model(x, u, y)
                loss = self.criterion(output, v)

                # update metrics
                losses.append(loss.item())

        return torch.mean(torch.tensor(losses)).item()
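
For orientation, below is a minimal usage sketch of this trainer. It is not part of the covered module: the toy tensor shapes, the hypothetical ToyOperator stand-in, the OperatorDataset construction, the transform guard, and the output directory are all assumptions made for illustration. The trainer itself only requires that the model be callable as model(x, u, y), that the dataset expose x, u, y, v tensors and a transform mapping (as used in __call__ above), and that save_checkpoint accept the passed objects; a real run would pass a continuiti Operator such as a DeepONet.

import pathlib

import torch
from continuiti.data import OperatorDataset

from nos.trainers.trainer import Trainer

# Hypothetical toy data: 64 observations, 1 channel, 16 sensor and evaluation points.
# The (observations, channels, points) tensor layout is an assumption of this sketch.
n_obs, n_pts = 64, 16
x = torch.rand(n_obs, 1, n_pts)  # sensor locations
u = torch.rand(n_obs, 1, n_pts)  # input function values at the sensors
y = torch.rand(n_obs, 1, n_pts)  # evaluation locations
v = torch.rand(n_obs, 1, n_pts)  # target function values at the evaluation locations
data_set = OperatorDataset(x, u, y, v)
if not hasattr(data_set, "transform"):
    data_set.transform = {}  # Trainer.__call__ iterates this mapping when moving data to the device


class ToyOperator(torch.nn.Module):
    """Hypothetical stand-in; real runs would pass a continuiti Operator such as a DeepONet."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2 * n_pts, n_pts)

    def forward(self, x, u, y):
        # ignore x and map (u, y) to one prediction per evaluation point,
        # matching the model(x, u, y) call used in Trainer.train and Trainer.eval
        z = torch.cat([u.flatten(1), y.flatten(1)], dim=-1)
        return self.layer(z).unsqueeze(1)


operator = ToyOperator()
trainer = Trainer(
    operator=operator,
    criterion=torch.nn.MSELoss(),
    optimizer=torch.optim.Adam(operator.parameters(), lr=1e-3),
    max_epochs=10,
    batch_size=8,
    out_dir=pathlib.Path("runs/toy-example"),
)
trained_operator = trainer(data_set, run_name="toy-example")

Running the sketch writes training_config.json, per-epoch checkpoints, and training.csv into out_dir, and logs metrics to the active mlflow tracking backend (a local ./mlruns directory by default).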