import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
# This function creates a 3x3 convolution with padding 1 and an optional stride.
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)


# A helper to initialize model parameters in a Xavier style.
def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
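
# Usage sketch (illustrative, not part of the original file): after constructing a
# model, apply the initializer to every submodule with `model.apply(conv_init)`.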
# This is one "wide block" used in WideResNet:
# a pre-activation residual block with dropout between its two convolutions.
class WideBasic(nn.Module):
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        # 1x1 convolution on the shortcut when the stride or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True)
            )
    def forward(self, x):
        # Classic WRN sequence: BN -> ReLU -> Conv -> Dropout -> BN -> ReLU -> Conv
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        # Add the skip connection from the input x.
        out += self.shortcut(x)
        return out

# Full WideResNet model that stacks multiple WideBasic blocks.
class WideResNet(nn.Module):
    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super().__init__()
        self.in_planes = 16

        # According to the paper, a wide-resnet has depth 6n + 4.
        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor
        print('| Wide-Resnet %dx%d' % (depth, k))

        nStages = [16, 16 * k, 32 * k, 64 * k]

        # Initial conv layer.
        self.conv1 = conv3x3(3, nStages[0])
        # Three groups of blocks; the later groups halve the spatial resolution.
        self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
        # Final BN + linear classifier.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)
    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        # Each "layer" is a sequence of WideBasic blocks; only the first block
        # uses the requested stride, the remaining blocks keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)
    def forward(self, x):
        # Standard WRN forward pass.
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Global average pool; the 8x8 kernel assumes 32x32 inputs (CIFAR-sized).
        out = F.avg_pool2d(out, 8)
        # Flatten.
        out = out.view(out.size(0), -1)
        # Final linear layer.
        out = self.linear(out)
        return out
    def forward_features(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Global average pool, but don't flatten yet, so callers get the pooled feature map.
        out = F.avg_pool2d(out, 8)
        return out
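

# Minimal usage sketch (an assumption added for illustration, not part of the
# original file): build a WRN-28-10 for CIFAR-10-sized inputs (3x32x32) and run
# a forward pass. The 8x8 average pooling in forward() assumes 32x32 inputs.
if __name__ == "__main__":
    net = WideResNet(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
    net.apply(conv_init)
    x = torch.randn(2, 3, 32, 32)        # dummy batch of two CIFAR-sized images
    logits = net(x)                      # shape: (2, 10)
    feats = net.forward_features(x)      # shape: (2, 640, 1, 1) before flattening
    print(logits.shape, feats.shape)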