Coverage for src/nos/networks/residual.py: 97%
29 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-09-19 11:29 +0000
1import torch
2import torch.nn as nn
5class ResBlock(nn.Module):
6 def __init__(
7 self,
8 width: int,
9 depth: int,
10 act: nn.Module,
11 dropout_p: float = 0.0,
12 is_last: bool = False,
13 ):
14 super().__init__()
16 self.net = nn.Sequential()
17 depth = depth if not is_last else depth - 1
18 for i in range(depth):
19 self.net.add_module(f"linear_{i}", torch.nn.Linear(width, width))
20 self.net.add_module(f"norm_{i}", torch.nn.LayerNorm(width))
21 self.net.add_module(f"Act_{i}", act)
23 if is_last:
24 self.net.add_module(f"linear_{depth - 1}", torch.nn.Linear(width, width))
26 if dropout_p > 0.0:
27 self.net.add_module("Dropout", nn.Dropout(dropout_p))
29 def forward(self, x: torch.Tensor):
30 out = self.net(x)
31 return out + x
class ResNet(nn.Module):
    """A chain of ResBlocks totalling `depth` Linear layers, `stride` per block.

    Args:
        width: feature dimension passed through unchanged.
        depth: total number of Linear layers; must divide evenly by `stride`.
        act: activation module instance shared by every block.
        stride: number of Linear layers per ResBlock.
        dropout_p: dropout probability for every block except the last.
    """

    def __init__(self, width: int, depth: int, act: nn.Module, stride: int = 1, dropout_p: float = 0.0):
        super().__init__()

        # Total depth must split evenly into blocks of `stride` layers each.
        assert depth % stride == 0
        num_blocks = depth // stride

        def make_block(is_last: bool) -> ResBlock:
            # The final block skips dropout and ends in a bare Linear.
            return ResBlock(
                width=width,
                depth=stride,
                act=act,
                dropout_p=0.0 if is_last else dropout_p,
                is_last=is_last,
            )

        self.net = nn.Sequential()
        for idx in range(num_blocks - 1):
            self.net.add_module(f"ResBlock_{idx}", make_block(False))
        self.net.add_module(f"ResBlock_{num_blocks - 1}", make_block(True))

    def forward(self, x: torch.Tensor):
        """Propagate `x` through every residual block in order."""
        return self.net(x)