Coverage for src/nos/networks/residual.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-19 11:29 +0000

1import torch 

2import torch.nn as nn 

3 

4 

class ResBlock(nn.Module):
    """Fully-connected residual block computing ``net(x) + x``.

    ``net`` contains ``depth`` linear layers mapping ``width -> width``. Every
    linear layer is followed by ``LayerNorm(width)`` and ``act``, except that
    when ``is_last`` is true the final linear layer is left bare. If
    ``dropout_p > 0`` a single ``Dropout`` is appended after all layers, i.e.
    before the residual addition.

    Args:
        width: Size of the last input dimension; every layer preserves it so
            the residual addition is well-defined.
        depth: Number of linear layers in the block; must be >= 1.
        act: Activation module. The same instance is registered after every
            normalised linear layer (it is not copied), so a parametric
            activation would share its parameters across those positions.
        dropout_p: Dropout probability; no dropout layer is added unless it
            is strictly positive.
        is_last: If true, the block ends with a bare linear layer instead of
            a ``Linear -> LayerNorm -> act`` group.

    Raises:
        ValueError: If ``depth < 1``.
    """

    def __init__(
        self,
        width: int,
        depth: int,
        act: nn.Module,
        dropout_p: float = 0.0,
        is_last: bool = False,
    ):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.net = nn.Sequential()
        # Number of full Linear -> LayerNorm -> act groups; a last block keeps
        # one linear layer back for the bare output projection below.
        n_groups = depth - 1 if is_last else depth
        for i in range(n_groups):
            self.net.add_module(f"linear_{i}", nn.Linear(width, width))
            self.net.add_module(f"norm_{i}", nn.LayerNorm(width))
            self.net.add_module(f"Act_{i}", act)

        if is_last:
            # Must not reuse a group's "linear_{i}" name: add_module replaces an
            # existing entry in place instead of appending, which silently
            # dropped this layer for depth >= 2. "linear_-1" is the key the
            # block has always used when depth == 1, so those state_dicts still
            # load unchanged.
            self.net.add_module("linear_-1", nn.Linear(width, width))

        if dropout_p > 0.0:
            self.net.add_module("Dropout", nn.Dropout(dropout_p))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.net(x) + x``."""
        return self.net(x) + x

32 

33 

class ResNet(nn.Module):
    """Sequential stack of ``depth // stride`` residual blocks of constant width.

    Each block is a ``ResBlock`` built with ``depth=stride``. Every block but
    the last uses ``dropout_p``; the last block is built with ``is_last=True``
    and no dropout.

    Args:
        width: Feature size passed to every ``ResBlock``.
        depth: Total layer budget; must be a positive multiple of ``stride``.
        act: Activation module. The same instance is passed to every block
            (it is not copied).
        stride: ``depth`` argument of each ``ResBlock``; must be >= 1.
        dropout_p: Dropout probability for every block except the last.

    Raises:
        ValueError: If ``stride < 1`` or ``depth < stride``.
        AssertionError: If ``depth`` is not a multiple of ``stride``
            (not checked under ``python -O``).
    """

    def __init__(self, width: int, depth: int, act: nn.Module, stride: int = 1, dropout_p: float = 0.0):
        super().__init__()

        if stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")
        assert depth % stride == 0, f"depth ({depth}) must be a multiple of stride ({stride})"
        n_blocks = depth // stride
        if n_blocks < 1:
            # Without this, depth <= 0 still produced a single "ResBlock_-1".
            raise ValueError(f"depth must be >= stride, got depth={depth}, stride={stride}")

        self.net = nn.Sequential()
        for i in range(n_blocks):
            is_last = i == n_blocks - 1
            self.net.add_module(
                f"ResBlock_{i}",
                ResBlock(
                    width=width,
                    depth=stride,
                    act=act,
                    # The final block is built without dropout and with is_last.
                    dropout_p=0.0 if is_last else dropout_p,
                    is_last=is_last,
                ),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual blocks in registration order."""
        return self.net(x)