import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
# Create a 3x3 convolution with an optional stride (padding fixed at 1).
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
# A helper to initialize model parameters: Xavier-style for convolutions,
# identity (weight 1, bias 0) for BatchNorm layers.
def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
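# Usage note (not part of the original listing): this initializer is typically
# applied recursively after the model is built, e.g. net.apply(conv_init);
# see the sketch at the end of this listing.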
# One "wide basic" block used in WideResNet: a pre-activation residual block with dropout.
class WideBasic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        # 1x1 convolution on the shortcut when the stride or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        # Classic WRN sequence: BN -> ReLU -> Conv -> Dropout -> BN -> ReLU -> Conv
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        # Add the skip connection from the input x.
        out += self.shortcut(x)
        return out
# Full WideResNet model that stacks multiple WideBasic blocks.
class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super().__init__()
        self.in_planes = 16

        # According to the paper, wide-resnet depth is 6n + 4.
        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor
        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        # Initial conv layer, then three groups of blocks.
        self.conv1 = conv3x3(3, nStages[0])
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
        # Final BN + linear classifier.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        # Each "layer" is a sequence of WideBasic blocks, controlling the strides.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)
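    # Illustrative note (an assumption for clarity, not from the original listing):
    # with depth=28 the assertion gives n = (28 - 4) // 6 = 4 blocks per group, and a
    # group built with stride=2 uses strides [2, 1, 1, 1], so only its first block downsamples.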
    def forward(self, x):
        # Standard WRN forward pass.
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Global average pool; the 8x8 window matches the final feature-map size
        # for 32x32 inputs such as CIFAR.
        out = F.avg_pool2d(out, 8)
        # Flatten.
        out = out.view(out.size(0), -1)
        # Final linear layer.
        out = self.linear(out)
        return out

    def forward_features(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn1(self.layer3(out) if False else out)) if False else F.relu(self.bn1(self.layer3(out)))
        # Global average pool, but don't flatten yet so the pooled map can be
        # used directly for feature extraction.
        out = F.avg_pool2d(out, 8)
        return out