From ba6f3f974bfc4a2968964dbe5eedea73c9f5efcb Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Fri, 28 May 2021 15:18:44 +0200
Subject: [PATCH] Enable direct `--weights URL` definition (#3373)

* Enable direct `--weights URL` definition

@KalenMike this PR enables direct `--weights URL` definition, i.e. passing a download URL straight to `--weights`. Example use case:
```
python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt
```
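
Under the hood, `attempt_download()` in `utils/google_utils.py` now detects `http://` / `https://` prefixes, downloads the file into the working directory under its basename, and returns the local path, so callers such as `train.py`, `hubconf.py` and `attempt_load()` can accept URLs transparently. A minimal sketch of the resulting call pattern (the URL is a placeholder):
```
from utils.google_utils import attempt_download
import torch

# Placeholder URL; any direct link to a .pt checkpoint should work.
weights = attempt_download('https://storage.googleapis.com/bucket/dir/model.pt')  # returns local path, e.g. 'model.pt'
ckpt = torch.load(weights, map_location='cpu')  # load the downloaded checkpoint
```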

* cleanup

* bug fixes

* weights = attempt_download(weights)

* Update experimental.py

* Update hubconf.py

* return bug fix

* comment mirror

* min_bytes (see the `safe_download` sketch below)
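
The download logic shared by the URL and GitHub-asset paths now lives in a new `safe_download()` helper, which falls back to an optional mirror URL and deletes partial files smaller than `min_bytes`. A rough usage sketch (the asset URL below is illustrative, built from the repo's release naming):
```
from utils.google_utils import safe_download

# Illustrative release-asset URL; error_msg is only printed if both attempts fail.
safe_download(file='yolov5s.pt',
              url='https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt',
              min_bytes=1E5,
              error_msg='yolov5s.pt missing, try downloading from https://github.com/ultralytics/yolov5/releases/')
```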
---
 hubconf.py             |  3 +--
 models/experimental.py |  3 +--
 train.py               |  2 +-
 utils/google_utils.py  | 53 ++++++++++++++++++++++++++----------------
 4 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/hubconf.py b/hubconf.py
index f74e70c8..40bbb1ed 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -41,8 +41,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
             cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
             model = Model(cfg, channels, classes)  # create model
             if pretrained:
-                attempt_download(fname)  # download if not found locally
-                ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
+                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
                 msd = model.state_dict()  # model state_dict
                 csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                 csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
diff --git a/models/experimental.py b/models/experimental.py
index afa78790..d316b183 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        attempt_download(w)
-        ckpt = torch.load(w, map_location=map_location)  # load
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
         model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
 
     # Compatibility updates
diff --git a/train.py b/train.py
index 3e8d5075..b74cdb28 100644
--- a/train.py
+++ b/train.py
@@ -83,7 +83,7 @@ def train(hyp, opt, device, tb_writer=None):
     pretrained = weights.endswith('.pt')
     if pretrained:
         with torch_distributed_zero_first(rank):
-            attempt_download(weights)  # download if not found locally
+            weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
diff --git a/utils/google_utils.py b/utils/google_utils.py
index 63d3e5b2..ac5c54db 100644
--- a/utils/google_utils.py
+++ b/utils/google_utils.py
@@ -16,11 +16,37 @@ def gsutil_getsize(url=''):
     return eval(s.split(' ')[0]) if len(s) else 0  # bytes
 
 
+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+    file = Path(file)
+    try:  # GitHub
+        print(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, str(file))
+        assert file.exists() and file.stat().st_size > min_bytes  # check
+    except Exception as e:  # GCP
+        file.unlink(missing_ok=True)  # remove partial downloads
+        print(f'Download error: {e}\nRe-attempting {url2 or url} to {file}...')
+        os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+    finally:
+        if not file.exists() or file.stat().st_size < min_bytes:  # check
+            file.unlink(missing_ok=True)  # remove partial downloads
+            print(f'ERROR: Download failure: {error_msg or url}')
+        print('')
+
+
 def attempt_download(file, repo='ultralytics/yolov5'):
     # Attempt file download if does not exist
     file = Path(str(file).strip().replace("'", ''))
 
     if not file.exists():
+        # URL specified
+        name = file.name
+        if str(file).startswith(('http:/', 'https:/')):  # download
+            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
+            safe_download(file=name, url=url, min_bytes=1E5)
+            return name
+
+        # GitHub assets
         file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
         try:
             response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
@@ -34,27 +60,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
             except:
                 tag = 'v5.0'  # current release
 
-        name = file.name
         if name in assets:
-            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
-            redundant = False  # second download option
-            try:  # GitHub
-                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
-                print(f'Downloading {url} to {file}...')
-                torch.hub.download_url_to_file(url, file)
-                assert file.exists() and file.stat().st_size > 1E6  # check
-            except Exception as e:  # GCP
-                print(f'Download error: {e}')
-                assert redundant, 'No secondary mirror'
-                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
-                print(f'Downloading {url} to {file}...')
-                os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
-            finally:
-                if not file.exists() or file.stat().st_size < 1E6:  # check
-                    file.unlink(missing_ok=True)  # remove partial downloads
-                    print(f'ERROR: Download failure: {msg}')
-                print('')
-                return
+            safe_download(file,
+                          url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+                          # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}',  # backup url (optional)
+                          min_bytes=1E5,
+                          error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
+
+    return str(file)
 
 
 def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
-- 
GitLab