Unverified Commit 01cdb767 authored by Glenn Jocher's avatar Glenn Jocher Committed by GitHub
Browse files

Add `SPPF()` layer (#4420)

* Add `SPPF()` layer

* Cleanup

* Add credit
parent 24bea5e4
...@@ -161,7 +161,7 @@ class C3Ghost(C3): ...@@ -161,7 +161,7 @@ class C3Ghost(C3):
class SPP(nn.Module): class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
def __init__(self, c1, c2, k=(5, 9, 13)): def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__() super().__init__()
c_ = c1 // 2 # hidden channels c_ = c1 // 2 # hidden channels
...@@ -176,6 +176,24 @@ class SPP(nn.Module): ...@@ -176,6 +176,24 @@ class SPP(nn.Module):
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
class SPPF(nn.Module):
# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * 4, c2, 1, 1)
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
class Focus(nn.Module): class Focus(nn.Module):
# Focus wh information into c-space # Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
......
...@@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3) ...@@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
pass pass
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
C3, C3TR, C3SPP, C3Ghost]: BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
if c2 != no: # if not output if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8) c2 = make_divisible(c2 * gw, 8)
...@@ -279,6 +279,7 @@ if __name__ == '__main__': ...@@ -279,6 +279,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml') parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--profile', action='store_true', help='profile model speed')
opt = parser.parse_args() opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file opt.cfg = check_file(opt.cfg) # check file
set_logging() set_logging()
...@@ -289,8 +290,9 @@ if __name__ == '__main__': ...@@ -289,8 +290,9 @@ if __name__ == '__main__':
model.train() model.train()
# Profile # Profile
# img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device) if opt.profile:
# y = model(img, profile=True) img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
y = model(img, profile=True)
# Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898) # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
# from torch.utils.tensorboard import SummaryWriter # from torch.utils.tensorboard import SummaryWriter
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment