diff --git a/src/models/wide_resnet.py b/src/models/wide_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..09813e17c3a3190a3fc3c907cd2300bf7dcff2be
--- /dev/null
+++ b/src/models/wide_resnet.py
@@ -0,0 +1,105 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.nn.functional as F
+
+
+# 3x3 convolution with padding and an optional stride.
+def conv3x3(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
+
+
+# Xavier-style parameter initialization, intended for use via model.apply(conv_init).
+def conv_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
+        init.constant_(m.bias, 0)
+    elif classname.find('BatchNorm') != -1:
+        init.constant_(m.weight, 1)
+        init.constant_(m.bias, 0)
+
+
+# One "wide block" used in WideResNet: a pre-activation residual block with dropout.
+class WideBasic(nn.Module):
+    def __init__(self, in_planes, planes, dropout_rate, stride=1):
+        super().__init__()
+
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
+
+        # Shortcut projection for dimension/stride mismatch.
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
+            )
+
+    def forward(self, x):
+        # Classic WRN sequence: BN -> ReLU -> Conv -> Dropout -> BN -> ReLU -> Conv
+        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
+        out = self.conv2(F.relu(self.bn2(out)))
+        # Add the skip connection from the block input x
+        out += self.shortcut(x)
+
+        return out
+
+
+# Full WideResNet model that stacks WideBasic blocks.
+class WideResNet(nn.Module):
+    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
+        super().__init__()
+
+        self.in_planes = 16
+
+        # According to the paper, wide-resnet depth is 6n + 4.
+        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
+        n = (depth - 4) // 6
+        k = widen_factor
+
+        print('| Wide-Resnet %dx%d' % (depth, k))
+        nStages = [16, 16 * k, 32 * k, 64 * k]
+
+        # Initial conv layer
+        self.conv1 = conv3x3(3, nStages[0])
+
+        # Three groups of blocks; the last two downsample with stride 2.
+        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
+        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
+        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
+
+        # Final BN + linear classifier
+        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
+        self.linear = nn.Linear(nStages[3], num_classes)
+
+    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
+        # Each "layer" is a sequence of blocks; only the first block uses the given stride.
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+
+        for s in strides:
+            layers.append(block(self.in_planes, planes, dropout_rate, s))
+            self.in_planes = planes
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        # Standard WRN forward pass:
+        out = self.conv1(x)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+
+        out = F.relu(self.bn1(out))
+        # Global average pool over the 8x8 feature map (CIFAR-sized 32x32 input)
+        out = F.avg_pool2d(out, 8)
+        # Flatten
+        out = out.view(out.size(0), -1)
+        # Final linear layer
+        out = self.linear(out)
+
+        return out
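A minimal usage sketch (not part of the diff): it instantiates the model and runs a CIFAR-sized batch through it to check the forward pass. The depth/widen_factor/dropout values (a common WRN-28-10 setup) and the import path derived from the file location are assumptions for illustration, not fixed by this change.

import torch
from src.models.wide_resnet import WideResNet, conv_init  # assumes src/ is importable

# Hypothetical configuration: WRN-28-10 with 30% dropout for a 10-class dataset.
model = WideResNet(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
model.apply(conv_init)  # Xavier-style init from this file

x = torch.randn(2, 3, 32, 32)  # batch of two 32x32 RGB images
logits = model(x)              # expected shape: (2, 10)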