Coverage for src/nos/trainers/util.py: 100%

11 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import json 

2import pathlib 

3import time 

4 

5import torch 

6 

7 

def save_checkpoint(
    operator,
    val_loss,
    train_loss,
    epoch,
    start,
    batch_size,
    train_set,
    val_set,
    out_dir: pathlib.Path = None,
) -> pathlib.Path:
    """Persist an operator and its training metadata into ``out_dir``.

    Two files are written into ``out_dir``:

    * ``operator.pt`` -- ``operator`` serialized with :func:`torch.save`.
    * ``checkpoint.json`` -- a JSON object holding the losses, the epoch,
      the elapsed training time, the batch size and the dataset sizes.

    Args:
        operator: Object to serialize with :func:`torch.save`.
        val_loss: Validation loss; must be JSON-serializable.
        train_loss: Training loss; must be JSON-serializable.
        epoch: Epoch at which the checkpoint is taken.
        start: Timestamp (as returned by :func:`time.time`) marking the start
            of training; ``Time_trained`` is recorded as seconds since then.
        batch_size: Batch size used for training.
        train_set: Training dataset; only ``len()`` is used.
        val_set: Validation dataset; only ``len()`` is used.
        out_dir: Directory to write into. It is created (including parents)
            if it does not exist. A ``str`` path is also accepted.

    Returns:
        ``out_dir`` as a :class:`pathlib.Path`.

    Raises:
        ValueError: If ``out_dir`` is ``None``.
    """
    if out_dir is None:
        # The signature permits None, but there is no sensible default
        # location; fail clearly instead of with an AttributeError on None.
        raise ValueError("save_checkpoint requires an output directory (out_dir)")
    out_dir = pathlib.Path(out_dir)

    checkpoint = {
        "val_loss": val_loss,
        "train_loss": train_loss,
        "epoch": epoch,
        "Time_trained": time.time() - start,
        "batch_size": batch_size,
        "train_size": len(train_set),
        "val_size": len(val_set),
    }

    # Neither torch.save nor open() creates missing directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    torch.save(operator, out_dir / "operator.pt")
    checkpoint_path = out_dir / "checkpoint.json"
    with open(checkpoint_path, "w", encoding="utf-8") as file_handle:
        json.dump(checkpoint, file_handle)

    return out_dir