Coverage for src/nos/trainers/trainer.py: 86%

import json
import pathlib
import shutil
import time
from typing import Optional

import mlflow
import pandas as pd
import torch.optim.lr_scheduler as sched
import torch.utils.data
from continuiti.data import (
    OperatorDataset,
)
from continuiti.operators import (
    Operator,
)
from torch.utils.data import (
    DataLoader,
    random_split,
)
from tqdm import (
    tqdm,
)

from nos.utils import (
    UniqueId,
)

from .util import (
    save_checkpoint,
)


class Trainer:
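    """Trains a continuiti Operator on an OperatorDataset.

    Handles the train/validation split, the epoch loop, MLflow metric
    logging, learning-rate scheduling, and checkpointing of the best and
    final model states under `out_dir`.
    """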

    def __init__(
        self,
        operator: Operator,
        criterion,
        optimizer,
        lr_scheduler: Optional[sched.LRScheduler] = None,
        max_epochs: int = 1000,
        batch_size: int = 16,
        max_n_logs: int = 200,
        out_dir: Optional[pathlib.Path] = None,
    ):
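        """Set up the training run.

        Args:
            operator: Operator to train.
            criterion: Loss function comparing predictions to targets.
            optimizer: Optimizer for the operator parameters.
            lr_scheduler: Learning-rate scheduler; a constant schedule is
                used when None.
            max_epochs: Number of training epochs.
            batch_size: Batch size for the train and validation loaders.
            max_n_logs: Maximum number of epochs at which metrics are
                logged to MLflow.
            out_dir: Output directory; a unique run directory is created
                when None.
        """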

        self.operator = operator
        self.criterion = criterion
        self.optimizer = optimizer
        if lr_scheduler is None:
            self.lr_scheduler = sched.ConstantLR(self.optimizer, factor=1.0)
        else:
            self.lr_scheduler = lr_scheduler

        self.max_epochs = max_epochs
        self.batch_size = batch_size
        # fraction of the data used for training; the rest is validation
        self.test_val_split = 0.9

        # logging and model serialization
        if out_dir is None:
            uid = UniqueId()
            self.out_dir = pathlib.Path.cwd().joinpath("run", str(uid))
        else:
            self.out_dir = out_dir
        self.out_dir.mkdir(parents=True, exist_ok=True)

        # log at most max_n_logs epochs, spread evenly over the whole run
        log_epochs = torch.round(torch.linspace(0, max_epochs, max_n_logs))
        self.log_epochs = [int(epoch) for epoch in log_epochs.tolist()]

    def __call__(self, data_set: OperatorDataset, run_name: Optional[str] = None) -> Operator:
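        """Train the operator on `data_set` and return it.

        The dataset is moved to the available device and split 90/10 into
        training and validation subsets; the split indices are written to
        `training_config.json`, metrics are logged to MLflow, and the best
        and final states are checkpointed.
        """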

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # move the dataset tensors and its transformations to the target device
        data_set.u = data_set.u.to(device)
        data_set.y = data_set.y.to(device)
        data_set.x = data_set.x.to(device)
        data_set.v = data_set.v.to(device)
        for trf in data_set.transform.keys():
            data_set.transform[trf] = data_set.transform[trf].to(device)

        # data
        train_set, val_set = random_split(data_set, [self.test_val_split, 1 - self.test_val_split])
        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=self.batch_size)

        # persist the split so the run can be reproduced
        training_config = {
            "val_indices": val_set.indices,
            "val_size": len(val_set),
            "train_indices": train_set.indices,
            "train_size": len(train_set),
        }
        with open(self.out_dir.joinpath("training_config.json"), "w") as file_handle:
            json.dump(training_config, file_handle)

        # setup training
        self.operator.to(device)
        self.criterion.to(device)

        best_val_loss = float("inf")
        val_losses = []
        train_losses = []
        lrs = []
        times = []

        pbar = tqdm(range(self.max_epochs))
        train_loss = torch.inf
        val_loss = torch.inf

        start = time.time()

        with mlflow.start_run():
            if run_name is not None:
                mlflow.set_tag("mlflow.runName", run_name)
            for epoch in pbar:
                pbar.set_description(
                    f"Train Loss: {train_loss: .6f},\t Val Loss: {val_loss: .6f}, Lr: {self.optimizer.param_groups[0]['lr']}"
                )
                train_loss = self.train(train_loader, self.operator, epoch, device)
                val_loss = self.eval(val_loader, self.operator, epoch, device)
                # step once per epoch; the deprecated epoch argument is not needed
                self.lr_scheduler.step()

                # update training parameters
                lrs.append(self.optimizer.param_groups[0]["lr"])
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                times.append(time.time() - start)

                # log metrics
                if epoch in self.log_epochs:
                    mlflow.log_metric("Val loss", val_loss, step=epoch)
                    mlflow.log_metric("Train loss", train_loss, step=epoch)
                    mlflow.log_metric("LR", self.optimizer.param_groups[0]["lr"], step=epoch)

                # save best model
                if val_loss < best_val_loss:
                    best_dir = self.out_dir.joinpath("best")
                    if best_dir.is_dir():
                        shutil.rmtree(best_dir)
                    best_dir.mkdir(exist_ok=True, parents=True)

                    save_checkpoint(
                        self.operator,
                        val_loss,
                        train_loss,
                        epoch,
                        start,
                        self.batch_size,
                        train_set,
                        val_set,
                        best_dir,
                    )
                    best_val_loss = val_loss

        # save the final state alongside the training curves
        save_checkpoint(
            self.operator,
            val_loss,
            train_loss,
            self.max_epochs,
            start,
            self.batch_size,
            train_set,
            val_set,
            self.out_dir,
        )

        training_curves = pd.DataFrame(
            {
                "Epochs": torch.arange(0, self.max_epochs).tolist(),
                "Val_loss": val_losses,
                "Train_loss": train_losses,
                "Lr": lrs,
                "time": times,
            }
        )
        training_curves.to_csv(self.out_dir.joinpath("training.csv"))
        return self.operator

    def train(self, loader, model, epoch, device):
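        """Run one training epoch over `loader` and return the mean batch loss."""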

        # switch to train mode
        model.train()

        losses = []
        for x, u, y, v in loader:
            x, u, y, v = x.to(device), u.to(device), y.to(device), v.to(device)

            # compute output
            output = model(x, u, y)
            loss = self.criterion(output, v)

            # compute gradient and update the parameters
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            losses.append(loss.item())

        return torch.mean(torch.tensor(losses)).item()

    def eval(self, loader, model, epoch, device):
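        """Evaluate `model` on `loader` and return the mean batch loss."""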

        # switch to evaluation mode
        model.eval()

        losses = []
        with torch.no_grad():
            for x, u, y, v in loader:
                x, u, y, v = x.to(device), u.to(device), y.to(device), v.to(device)

                # compute output
                output = model(x, u, y)
                loss = self.criterion(output, v)

                # update metrics
                losses.append(loss.item())

        return torch.mean(torch.tensor(losses)).item()
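

# A minimal usage sketch (illustrative only; `operator` and `dataset` stand in
# for any continuiti Operator / OperatorDataset pair and are not defined here):
#
#   import torch
#
#   trainer = Trainer(
#       operator,
#       criterion=torch.nn.MSELoss(),
#       optimizer=torch.optim.Adam(operator.parameters(), lr=1e-3),
#       max_epochs=100,
#   )
#   trained_operator = trainer(dataset, run_name="example-run")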