Coverage for src/nos/trainers/util.py: 100%
11 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import json
2import pathlib
3import time
5import torch
def save_checkpoint(
    operator,
    val_loss,
    train_loss,
    epoch,
    start,
    batch_size,
    train_set,
    val_set,
    out_dir: pathlib.Path = None,
) -> pathlib.Path:
    """Persist an operator and its training metadata into ``out_dir``.

    Two files are written into ``out_dir``:

    * ``operator.pt`` -- the whole ``operator`` object, serialized with
      ``torch.save`` (i.e. pickled; loading it needs ``torch.load`` with the
      class importable).
    * ``checkpoint.json`` -- validation/training loss, epoch, elapsed training
      time, batch size and the sizes of the training and validation sets.

    Args:
        operator: Object to serialize with ``torch.save`` (presumably a
            ``torch.nn.Module`` -- TODO confirm against callers).
        val_loss: Validation loss. Stored via ``json.dump``, so it must be
            JSON-serializable (e.g. a Python ``float``, not a tensor).
        train_loss: Training loss; same JSON-serializability requirement.
        epoch: Epoch at which the checkpoint is taken.
        start: Start timestamp from ``time.time()``; the stored
            ``"Time_trained"`` is ``time.time() - start`` in seconds.
        batch_size: Batch size recorded in the metadata.
        train_set: Sized collection; only ``len(train_set)`` is recorded.
        val_set: Sized collection; only ``len(val_set)`` is recorded.
        out_dir: Destination directory (``pathlib.Path`` or ``str``). It is
            created, including missing parents, if it does not exist.

    Returns:
        ``out_dir`` as a ``pathlib.Path``.

    Raises:
        ValueError: If ``out_dir`` is not provided.
    """
    # The parameter historically defaults to None, but there is no sensible
    # implicit location; fail loudly instead of with an AttributeError.
    if out_dir is None:
        raise ValueError("save_checkpoint requires an output directory (out_dir).")
    out_dir = pathlib.Path(out_dir)

    # Computed before any I/O so serialization time is not counted as training.
    checkpoint = {
        "val_loss": val_loss,
        "train_loss": train_loss,
        "epoch": epoch,
        "Time_trained": time.time() - start,
        "batch_size": batch_size,
        "train_size": len(train_set),
        "val_size": len(val_set),
    }

    # torch.save does not create intermediate directories.
    out_dir.mkdir(parents=True, exist_ok=True)

    torch.save(operator, out_dir.joinpath("operator.pt"))
    checkpoint_path = out_dir.joinpath("checkpoint.json")
    with open(checkpoint_path, "w", encoding="utf-8") as file_handle:
        json.dump(checkpoint, file_handle)

    return out_dir