Skip to content
Snippets Groups Projects 4.53 KiB
Newer Older
Mina Moshfegh's avatar
Mina Moshfegh committed
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable

def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and a learnable bias.

    Maps ``in_planes`` input channels to ``out_planes`` output channels.
    With ``stride=1`` the spatial size is preserved; a larger ``stride``
    downsamples.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )

def conv_init(m):
    """Initialise a single module in place; intended for ``model.apply(conv_init)``.

    * Modules whose class name contains ``"Conv"`` get Xavier-uniform weights
      (gain ``sqrt(2)``) and a zero bias.
    * Modules whose class name contains ``"BatchNorm"`` get weight 1 and bias 0.
    * Every other module is left untouched.

    Parameters that are absent or ``None`` (e.g. ``bias=False`` convolutions,
    ``affine=False`` batch norms, or container modules that merely have
    "Conv" in their class name) are skipped instead of raising.
    """
    classname = m.__class__.__name__
    # getattr(..., None) covers both "attribute missing" and "attribute is None".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            init.xavier_uniform_(weight, gain=np.sqrt(2))
        if bias is not None:
            init.constant_(bias, 0)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.constant_(weight, 1)
        if bias is not None:
            init.constant_(bias, 0)

class WideBasic(nn.Module):
    """Pre-activation wide residual block with dropout.

    Residual branch: BN -> ReLU -> 3x3 conv -> dropout -> BN -> ReLU -> 3x3 conv,
    where only the second conv carries ``stride``. The shortcut is the identity,
    or a strided 1x1 conv when the stride or the channel count changes.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        dropout_rate: dropout probability applied after the first conv.
        stride: stride of the second 3x3 conv and of the projection shortcut.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1, *args, **kwargs):
        # Initialise nn.Module exactly once, before any submodule is assigned.
        super().__init__(*args, **kwargs)

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # Identity shortcut (an empty Sequential) unless the residual branch
        # changes the stride or channel count; then a 1x1 projection matches it.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        """Return ``residual_branch(x) + shortcut(x)``."""
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        # conv2 returns a new tensor, so this in-place add does not modify x.
        out += self.shortcut(x)
        return out

class WideResNet(nn.Module):
    """Wide Residual Network (WRN-depth-k) built from ``WideBasic`` blocks.

    Layout: a 3x3 stem conv from 3 to 16 channels, then three groups of
    ``(depth - 4) // 6`` blocks with ``16k``, ``32k`` and ``64k`` channels
    (the second and third groups start with stride 2), then BatchNorm + ReLU,
    global average pooling and a linear classifier.

    Args:
        depth: total network depth; must have the form ``6n + 4``.
        widen_factor: channel multiplier ``k``.
        dropout_rate: dropout probability used inside every block.
        num_classes: number of output logits.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes, *args, **kwargs):
        # Initialise nn.Module exactly once, before any submodule is assigned.
        super().__init__(*args, **kwargs)

        # Running channel count; _wide_layer advances it as groups are built.
        self.in_planes = 16

        # According to the paper, wide-resnet depth is 6n + 4.
        # NOTE(review): `assert` is stripped under `python -O`.
        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6  # blocks per group
        k = widen_factor

        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        # Stem convolution: 3 input channels -> 16 channels.
        self.conv1 = conv3x3(3, nStages[0])

        # Three groups of wide blocks; groups 2 and 3 halve the spatial size.
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)

        # Final BN + linear classifier.
        # NOTE(review): in PyTorch, `momentum` is the weight of the *new* batch
        # statistic (running = (1 - m) * running + m * batch), so 0.9 makes the
        # running stats track the latest batch closely -- unlike the default 0.1
        # used by every other BatchNorm here. Confirm this is intended.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Advances ``self.in_planes`` to ``planes`` so that the next group
        starts from this group's output width.
        """
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []

        for block_stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, block_stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        out = self.forward_features(x)
        # Flatten (batch, C, 1, 1) -> (batch, C) for the linear classifier.
        out = out.view(out.size(0), -1)
        return self.linear(out)

    def forward_features(self, x):
        """Return pooled, unflattened features of shape ``(batch, 64 * widen_factor, 1, 1)``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        out = F.relu(self.bn1(out))
        # Pool to 1x1 so this is a true global average for any final map size.
        # For 32x32 inputs (8x8 final maps) this equals avg_pool2d(out, 8).
        return F.adaptive_avg_pool2d(out, 1)