From c8c5ef36c9a19c7843993ee8d51aebb685467eca Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Wed, 28 Oct 2020 15:03:50 +0100
Subject: [PATCH] PyTorch 1.7.0 Compatibility Updates (#1233)

* torch 1.7.0 compatibility updates

* add inference verification
---
 hubconf.py             | 8 ++++++++
 models/experimental.py | 7 +++++++
 models/yolo.py         | 1 -
 utils/torch_utils.py   | 2 +-
 4 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/hubconf.py b/hubconf.py
index cd14863c..cc210528 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -108,3 +108,11 @@ def yolov5x(pretrained=False, channels=3, classes=80):
 
 if __name__ == '__main__':
     model = create(name='yolov5s', pretrained=True, channels=3, classes=80)  # example
+    model = model.fuse().eval().autoshape()  # for autoshaping of PIL/cv2/np inputs and NMS
+
+    # Verify inference
+    from PIL import Image
+
+    img = Image.open('inference/images/zidane.jpg')
+    y = model(img)
+    print(y[0].shape)
diff --git a/models/experimental.py b/models/experimental.py
index 0b61027b..a2908a15 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -136,6 +136,13 @@ def attempt_load(weights, map_location=None):
         attempt_download(w)
         model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
 
+    # Compatibility updates
+    for m in model.modules():
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+            m.inplace = True  # pytorch 1.7.0 compatibility
+        elif type(m) is Conv:
+            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
+
     if len(model) == 1:
         return model[-1]  # return model
     else:
diff --git a/models/yolo.py b/models/yolo.py
index 0d46054e..e1c30baa 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -165,7 +165,6 @@ class Model(nn.Module):
         print('Fusing layers... ')
         for m in self.model.modules():
             if type(m) is Conv and hasattr(m, 'bn'):
-                m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 delattr(m, 'bn')  # remove batchnorm
                 m.forward = m.fuseforward  # update forward
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index f6818238..25eff07f 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -74,7 +74,7 @@ def initialize_weights(model):
         elif t is nn.BatchNorm2d:
             m.eps = 1e-3
             m.momentum = 0.03
-        elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
             m.inplace = True
 
 
-- 
GitLab