From f3c3d2ce5d85ba77336a9d0a87c6a446732cdda6 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Tue, 8 Jun 2021 10:22:10 +0200
Subject: [PATCH] Merge `develop` branch into `master` (#3518)

* update ci-testing.yml (#3322)

* update ci-testing.yml

* update greetings.yml

* bring back os matrix

* update ci-testing.yml (#3322)

* update ci-testing.yml

* update greetings.yml

* bring back os matrix

* Enable direct `--weights URL` definition (#3373)

* Enable direct `--weights URL` definition

@KalenMike this PR will enable direct --weights URL definition. Example use case:
```
python train.py --weights https://storage.googleapis.com/bucket/dir/model.pt
```

* cleanup

* bug fixes

* weights = attempt_download(weights)

* Update experimental.py

* Update hubconf.py

* return bug fix

* comment mirror

* min_bytes

* Update tutorial.ipynb (#3368)

add Open in Kaggle badge

* `cv2.imread(img, -1)` for IMREAD_UNCHANGED (#3379)

* Update datasets.py

* comment

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* COCO evolution fix (#3388)

* COCO evolution fix

* cleanup

* update print

* print fix

* Create `is_pip()` function (#3391)

Returns `True` if file is part of pip package. Useful for contextual behavior modification.

```python
def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).absolute().parts
```

* Revert "`cv2.imread(img, -1)` for IMREAD_UNCHANGED (#3379)" (#3395)

This reverts commit 21a9607e00f1365b21d8c4bd81bdbf5fc0efea24.

* Update FLOPs description (#3422)

* Update README.md

* Changing FLOPS to FLOPs.

Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>

* Parse URL authentication (#3424)

* Parse URL authentication

* urllib.parse.unquote()

* improved error handling

* improved error handling

* remove %3F

* update check_file()

* Add FLOPs title to table (#3453)

* Suppress jit trace warning + graph once (#3454)

* Suppress jit trace warning + graph once

Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0.

* Update train.py

* Update MixUp augmentation `alpha=beta=32.0` (#3455)

Per VOC empirical results https://github.com/ultralytics/yolov5/issues/3380#issuecomment-853001307 by @developer0hye

* Add `timeout()` class (#3460)

* Add `timeout()` class

* rearrange order

* Faster HSV augmentation (#3462)

remove datatype conversion process that can be skipped

* Add `check_git_status()` 5 second timeout (#3464)

* Add check_git_status() 5 second timeout

This should prevent the SSH Git bug that we were discussing @KalenMike

* cleanup

* replace timeout with check_output built-in timeout

* Improved `check_requirements()` offline-handling (#3466)

Improve robustness of `check_requirements()` function to offline environments (do not attempt pip installs when offline).

* Add `output_names` argument for ONNX export with dynamic axes (#3456)

* Add output names & dynamic axes for onnx export

Add output_names and dynamic_axes names for all outputs in torch.onnx.export. The first four outputs of the model will have names output0, output1, output2, output3

* use first output only + cleanup

Co-authored-by: Samridha Shrestha <samridha.shrestha@g42.ai>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Revert FP16 `test.py` and `detect.py` inference to FP32 default (#3423)

* fixed inference bug ,while use half precision

* replace --use-half with --half

* replace space and PEP8 in detect.py

* PEP8 detect.py

* update --half help comment

* Update test.py

* revert space

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* Add additional links/resources to stale.yml message (#3467)

* Update stale.yml

* cleanup

* Update stale.yml

* reformat

* Update stale.yml HUB URL (#3468)

* Stale `github.actor` bug fix (#3483)

* Explicit `model.eval()` call `if opt.train=False` (#3475)

* call model.eval() when opt.train is False

call model.eval() when opt.train is False

* single-line if statement

* cleanup

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

* check_requirements() exclude `opencv-python` (#3495)

Fix for 3rd party or contrib versions of installed OpenCV as in https://github.com/ultralytics/yolov5/issues/3494.

* Earlier `assert` for cpu and half option (#3508)

* early assert for cpu and half option

early assert for cpu and half option

* Modified comment

Modified comment

* Update tutorial.ipynb (#3510)

* Reduce test.py results spacing (#3511)

* Update README.md (#3512)

* Update README.md

Minor modifications

* 850 width

* Update greetings.yml

revert greeting change as PRs will now merge to master.

Co-authored-by: Piotr Skalski <SkalskiP@users.noreply.github.com>
Co-authored-by: SkalskiP <piotr.skalski92@gmail.com>
Co-authored-by: Peretz Cohen <pizzaz93@users.noreply.github.com>
Co-authored-by: tudoulei <34886368+tudoulei@users.noreply.github.com>
Co-authored-by: chocosaj <chocosaj@users.noreply.github.com>
Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Sam_S <SamSamhuns@users.noreply.github.com>
Co-authored-by: Samridha Shrestha <samridha.shrestha@g42.ai>
Co-authored-by: edificewang <609552430@qq.com>
---
 .github/workflows/ci-testing.yml |  6 +--
 .github/workflows/stale.yml      | 22 +++++++++-
 README.md                        | 26 ++++++------
 detect.py                        |  3 +-
 hubconf.py                       |  3 +-
 models/experimental.py           |  3 +-
 models/export.py                 | 18 ++++----
 models/yolo.py                   |  6 +--
 requirements.txt                 |  2 +-
 test.py                          |  6 ++-
 train.py                         | 70 ++++++++++++++++----------------
 tutorial.ipynb                   | 13 +++---
 utils/datasets.py                |  6 +--
 utils/general.py                 | 54 ++++++++++++++++++------
 utils/google_utils.py            | 58 ++++++++++++++++----------
 utils/torch_utils.py             | 12 +++---
 16 files changed, 186 insertions(+), 122 deletions(-)

diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index df508474a..bb8b173cd 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -2,12 +2,10 @@ name: CI CPU testing
 
 on:  # https://help.github.com/en/actions/reference/events-that-trigger-workflows
   push:
-    branches: [ master ]
+    branches: [ master, develop ]
   pull_request:
     # The branches below must be a subset of the branches above
-    branches: [ master ]
-  schedule:
-    - cron: '0 0 * * *'  # Runs at 00:00 UTC every day
+    branches: [ master, develop ]
 
 jobs:
   cpu-tests:
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 0a094e237..a81e4007c 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -10,8 +10,26 @@ jobs:
       - uses: actions/stale@v3
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.'
-          stale-pr-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.'
+          stale-issue-message: |
+            👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.
+
+            Access additional [YOLOv5](https://ultralytics.com/yolov5) 🚀 resources:
+            - **Wiki** – https://github.com/ultralytics/yolov5/wiki
+            - **Tutorials** – https://github.com/ultralytics/yolov5#tutorials
+            - **Docs** – https://docs.ultralytics.com
+
+            Access additional [Ultralytics](https://ultralytics.com) ⚡ resources:
+            - **Ultralytics HUB** – https://ultralytics.com/pricing
+            - **Vision API** – https://ultralytics.com/yolov5
+            - **About Us** – https://ultralytics.com/about
+            - **Join Our Team** – https://ultralytics.com/work
+            - **Contact Us** – https://ultralytics.com/contact
+
+            Feel free to inform us of any other **issues** you discover or **feature requests** that come to mind in the future. Pull Requests (PRs) are also always welcomed!
+
+            Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!
+
+          stale-pr-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions YOLOv5 🚀 and Vision AI ⭐.'
           days-before-stale: 30
           days-before-close: 5
           exempt-issue-labels: 'documentation,tutorial'
diff --git a/README.md b/README.md
index a638657b3..3a785cc85 100755
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 <a align="left" href="https://apps.apple.com/app/id1452689527" target="_blank">
-<img width="800" src="https://user-images.githubusercontent.com/26833433/98699617-a1595a00-2377-11eb-8145-fc674eb9b1a7.jpg"></a>
+<img width="850" src="https://user-images.githubusercontent.com/26833433/121094150-72607500-c7ee-11eb-9f39-1d9e4ce89a9e.jpg"></a>
 &nbsp
 
 <a href="https://github.com/ultralytics/yolov5/actions"><img src="https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg" alt="CI CPU testing"></a>
@@ -30,19 +30,19 @@ This repository represents Ultralytics open-source research into future object d
 
 [assets]: https://github.com/ultralytics/yolov5/releases
 
-Model |size<br><sup>(pixels) |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 |mAP<sup>val<br>0.5 |Speed<br><sup>V100 (ms) | |params<br><sup>(M) |FLOPS<br><sup>640 (B)
----   |---  |---        |---         |---             |---                |---|---              |---
-[YOLOv5s][assets]    |640  |36.7     |36.7     |55.4     |**2.0** | |7.3   |17.0
-[YOLOv5m][assets]    |640  |44.5     |44.5     |63.1     |2.7     | |21.4  |51.3
-[YOLOv5l][assets]    |640  |48.2     |48.2     |66.9     |3.8     | |47.0  |115.4
-[YOLOv5x][assets]    |640  |**50.4** |**50.4** |**68.8** |6.1     | |87.7  |218.8
+|Model |size<br><sup>(pixels) |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 |mAP<sup>val<br>0.5 |Speed<br><sup>V100 (ms) | |params<br><sup>(M) |FLOPs<br><sup>640 (B)
+|---                  |---  |---      |---      |---      |---     |---|--- |---
+|[YOLOv5s][assets]    |640  |36.7     |36.7     |55.4     |**2.0** | |7.3   |17.0
+|[YOLOv5m][assets]    |640  |44.5     |44.5     |63.1     |2.7     | |21.4  |51.3
+|[YOLOv5l][assets]    |640  |48.2     |48.2     |66.9     |3.8     | |47.0  |115.4
+|[YOLOv5x][assets]    |640  |**50.4** |**50.4** |**68.8** |6.1     | |87.7  |218.8
 | | | | | | || |
-[YOLOv5s6][assets]   |1280 |43.3     |43.3     |61.9     |**4.3** | |12.7  |17.4
-[YOLOv5m6][assets]   |1280 |50.5     |50.5     |68.7     |8.4     | |35.9  |52.4
-[YOLOv5l6][assets]   |1280 |53.4     |53.4     |71.1     |12.3    | |77.2  |117.7
-[YOLOv5x6][assets]   |1280 |**54.4** |**54.4** |**72.0** |22.4    | |141.8 |222.9
+|[YOLOv5s6][assets]   |1280 |43.3     |43.3     |61.9     |**4.3** | |12.7  |17.4
+|[YOLOv5m6][assets]   |1280 |50.5     |50.5     |68.7     |8.4     | |35.9  |52.4
+|[YOLOv5l6][assets]   |1280 |53.4     |53.4     |71.1     |12.3    | |77.2  |117.7
+|[YOLOv5x6][assets]   |1280 |**54.4** |**54.4** |**72.0** |22.4    | |141.8 |222.9
 | | | | | | || |
-[YOLOv5x6][assets] TTA |1280 |**55.0** |**55.0** |**72.0** |70.8 | |-  |-
+|[YOLOv5x6][assets] TTA |1280 |**55.0** |**55.0** |**72.0** |70.8 | |-  |-
 
 <details>
   <summary>Table Notes (click to expand)</summary>
@@ -112,7 +112,7 @@ Namespace(agnostic_nms=False, augment=False, classes=None, conf_thres=0.25, devi
 YOLOv5 v4.0-96-g83dc1b4 torch 1.7.0+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)
 
 Fusing layers... 
-Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPS
+Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPs
 image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, Done. (0.010s)
 image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 1 tie, Done. (0.011s)
 Results saved to runs/detect/exp2
diff --git a/detect.py b/detect.py
index c6b76d981..aba87687e 100644
--- a/detect.py
+++ b/detect.py
@@ -28,7 +28,7 @@ def detect(opt):
     # Initialize
     set_logging()
     device = select_device(opt.device)
-    half = device.type != 'cpu'  # half precision only supported on CUDA
+    half = opt.half and device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
     model = attempt_load(weights, map_location=device)  # load FP32 model
@@ -172,6 +172,7 @@ if __name__ == '__main__':
     parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
     parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
     parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
+    parser.add_argument('--half', type=bool, default=False, help='use FP16 half-precision inference')
     opt = parser.parse_args()
     print(opt)
     check_requirements(exclude=('tensorboard', 'pycocotools', 'thop'))
diff --git a/hubconf.py b/hubconf.py
index a52aae9fd..bedbee18f 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -42,8 +42,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
             cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
             model = Model(cfg, channels, classes)  # create model
             if pretrained:
-                attempt_download(fname)  # download if not found locally
-                ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
+                ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu'))  # load
                 msd = model.state_dict()  # model state_dict
                 csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
                 csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape}  # filter
diff --git a/models/experimental.py b/models/experimental.py
index afa787907..d316b1837 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -116,8 +116,7 @@ def attempt_load(weights, map_location=None, inplace=True):
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        attempt_download(w)
-        ckpt = torch.load(w, map_location=map_location)  # load
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
         model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
 
     # Compatibility updates
diff --git a/models/export.py b/models/export.py
index 0d1147938..c03770178 100644
--- a/models/export.py
+++ b/models/export.py
@@ -44,22 +44,19 @@ if __name__ == '__main__':
 
     # Load PyTorch model
     device = select_device(opt.device)
+    assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0'
     model = attempt_load(opt.weights, map_location=device)  # load FP32 model
     labels = model.names
 
-    # Checks
+    # Input
     gs = int(max(model.stride))  # grid size (max stride)
     opt.img_size = [check_img_size(x, gs) for x in opt.img_size]  # verify img_size are gs-multiples
-    assert not (opt.device.lower() == 'cpu' and opt.half), '--half only compatible with GPU export, i.e. use --device 0'
-
-    # Input
     img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device)  # image size(1,3,320,192) iDetection
 
     # Update model
     if opt.half:
         img, model = img.half(), model.half()  # to FP16
-    if opt.train:
-        model.train()  # training mode (no grid construction in Detect layer)
+    model.train() if opt.train else model.eval()  # training mode = no Detect() layer grid construction
     for k, m in model.named_modules():
         m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
         if isinstance(m, models.common.Conv):  # assign export-friendly activations
@@ -96,11 +93,14 @@ if __name__ == '__main__':
 
             print(f'{prefix} starting export with onnx {onnx.__version__}...')
             f = opt.weights.replace('.pt', '.onnx')  # filename
-            torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version, input_names=['images'],
+            torch.onnx.export(model, img, f, verbose=False, opset_version=opt.opset_version,
                               training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
                               do_constant_folding=not opt.train,
-                              dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
-                                            'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
+                              input_names=['images'],
+                              output_names=['output'],
+                              dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # shape(1,3,640,640)
+                                            'output': {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
+                                            } if opt.dynamic else None)
 
             # Checks
             model_onnx = onnx.load(f)  # load onnx model
diff --git a/models/yolo.py b/models/yolo.py
index 2844cd041..1a7be9130 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -21,7 +21,7 @@ from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, s
     select_device, copy_attr
 
 try:
-    import thop  # for FLOPS computation
+    import thop  # for FLOPs computation
 except ImportError:
     thop = None
 
@@ -140,13 +140,13 @@ class Model(nn.Module):
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
 
             if profile:
-                o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPS
+                o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
                 t = time_synchronized()
                 for _ in range(10):
                     _ = m(x)
                 dt.append((time_synchronized() - t) * 100)
                 if m == self.model[0]:
-                    logger.info(f"{'time (ms)':>10s} {'GFLOPS':>10s} {'params':>10s}  {'module'}")
+                    logger.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  {'module'}")
                 logger.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
 
             x = m(x)  # run
diff --git a/requirements.txt b/requirements.txt
index 1c07c6511..a20fb6ad0 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,4 +27,4 @@ pandas
 # extras --------------------------------------
 # Cython  # for pycocotools https://github.com/cocodataset/cocoapi/issues/172
 pycocotools>=2.0  # COCO mAP
-thop  # FLOPS computation
+thop  # FLOPs computation
diff --git a/test.py b/test.py
index 0716c5d8b..12141f71c 100644
--- a/test.py
+++ b/test.py
@@ -95,7 +95,7 @@ def test(data,
     confusion_matrix = ConfusionMatrix(nc=nc)
     names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
     coco91class = coco80_to_coco91_class()
-    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
+    s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
     p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
@@ -228,7 +228,7 @@ def test(data,
         nt = torch.zeros(1)
 
     # Print results
-    pf = '%20s' + '%12i' * 2 + '%12.3g' * 4  # print format
+    pf = '%20s' + '%11i' * 2 + '%11.3g' * 4  # print format
     print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
 
     # Print results per class
@@ -306,6 +306,7 @@ if __name__ == '__main__':
     parser.add_argument('--project', default='runs/test', help='save to project/name')
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--half', type=bool, default=False, help='use FP16 half-precision inference')
     opt = parser.parse_args()
     opt.save_json |= opt.data.endswith('coco.yaml')
     opt.data = check_file(opt.data)  # check file
@@ -326,6 +327,7 @@ if __name__ == '__main__':
              save_txt=opt.save_txt | opt.save_hybrid,
              save_hybrid=opt.save_hybrid,
              save_conf=opt.save_conf,
+             half_precision=opt.half,
              opt=opt
              )
 
diff --git a/train.py b/train.py
index 3e8d5075a..093a6197f 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@ import math
 import os
 import random
 import time
+import warnings
 from copy import deepcopy
 from pathlib import Path
 from threading import Thread
@@ -62,7 +63,6 @@ def train(hyp, opt, device, tb_writer=None):
     init_seeds(2 + rank)
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict
-    is_coco = opt.data.endswith('coco.yaml')
 
     # Logging- Doing this before checking the dataset. Might update data_dict
     loggers = {'wandb': None}  # loggers dict
@@ -78,12 +78,13 @@ def train(hyp, opt, device, tb_writer=None):
     nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
     assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check
+    is_coco = opt.data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
     # Model
     pretrained = weights.endswith('.pt')
     if pretrained:
         with torch_distributed_zero_first(rank):
-            attempt_download(weights)  # download if not found locally
+            weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 s = ('%10s' * 2 + '%10.4g' * 6) % (
-                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
+                    f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
                 pbar.set_description(s)
 
                 # Plot
                 if plots and ni < 3:
                     f = save_dir / f'train_batch{ni}.jpg'  # filename
                     Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
-                    if tb_writer:
-                        tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), [])  # model graph
-                        # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
+                    if tb_writer and ni == 0:
+                        with warnings.catch_warnings():
+                            warnings.simplefilter('ignore')  # suppress jit trace warning
+                            tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), [])  # graph
                 elif plots and ni == 10 and wandb_logger.wandb:
-                    wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
+                    wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
                                                   save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
@@ -358,6 +360,7 @@ def train(hyp, opt, device, tb_writer=None):
                                                  single_cls=opt.single_cls,
                                                  dataloader=testloader,
                                                  save_dir=save_dir,
+                                                 save_json=is_coco and final_epoch,
                                                  verbose=nc < 50 and final_epoch,
                                                  plots=plots and final_epoch,
                                                  wandb_logger=wandb_logger,
@@ -409,41 +412,38 @@ def train(hyp, opt, device, tb_writer=None):
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
     if rank in [-1, 0]:
-        # Plots
+        logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
             if wandb_logger.wandb:
                 files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
                 wandb_logger.log({"Results": [wandb_logger.wandb.Image(str(save_dir / f), caption=f) for f in files
                                               if (save_dir / f).exists()]})
-        # Test best.pt
-        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
-        if opt.data.endswith('coco.yaml') and nc == 80:  # if COCO
-            for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
-                results, _, _ = test.test(opt.data,
-                                          batch_size=batch_size * 2,
-                                          imgsz=imgsz_test,
-                                          conf_thres=0.001,
-                                          iou_thres=0.7,
-                                          model=attempt_load(m, device).half(),
-                                          single_cls=opt.single_cls,
-                                          dataloader=testloader,
-                                          save_dir=save_dir,
-                                          save_json=True,
-                                          plots=False,
-                                          is_coco=is_coco)
 
-        # Strip optimizers
-        final = best if best.exists() else last  # final model
-        for f in last, best:
-            if f.exists():
-                strip_optimizer(f)  # strip optimizers
-        if opt.bucket:
-            os.system(f'gsutil cp {final} gs://{opt.bucket}/weights')  # upload
-        if wandb_logger.wandb and not opt.evolve:  # Log the stripped model
-            wandb_logger.wandb.log_artifact(str(final), type='model',
-                                            name='run_' + wandb_logger.wandb_run.id + '_model',
-                                            aliases=['latest', 'best', 'stripped'])
+        if not opt.evolve:
+            if is_coco:  # COCO dataset
+                for m in [last, best] if best.exists() else [last]:  # speed, mAP tests
+                    results, _, _ = test.test(opt.data,
+                                              batch_size=batch_size * 2,
+                                              imgsz=imgsz_test,
+                                              conf_thres=0.001,
+                                              iou_thres=0.7,
+                                              model=attempt_load(m, device).half(),
+                                              single_cls=opt.single_cls,
+                                              dataloader=testloader,
+                                              save_dir=save_dir,
+                                              save_json=True,
+                                              plots=False,
+                                              is_coco=is_coco)
+
+            # Strip optimizers
+            for f in last, best:
+                if f.exists():
+                    strip_optimizer(f)  # strip optimizers
+            if wandb_logger.wandb:  # Log the stripped model
+                wandb_logger.wandb.log_artifact(str(best if best.exists() else last), type='model',
+                                                name='run_' + wandb_logger.wandb_run.id + '_model',
+                                                aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
     else:
         dist.destroy_process_group()
diff --git a/tutorial.ipynb b/tutorial.ipynb
index 3954feadf..4e760b13b 100644
--- a/tutorial.ipynb
+++ b/tutorial.ipynb
@@ -517,7 +517,8 @@
         "colab_type": "text"
       },
       "source": [
-        "<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+        "<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>",
+        "<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"        
       ]
     },
     {
@@ -529,7 +530,7 @@
         "<img src=\"https://user-images.githubusercontent.com/26833433/98702494-b71c4e80-237a-11eb-87ed-17fcd6b3f066.jpg\">\n",
         "\n",
         "This is the **official YOLOv5 🚀 notebook** authored by **Ultralytics**, and is freely available for redistribution under the [GPL-3.0 license](https://choosealicense.com/licenses/gpl-3.0/). \n",
-        "For more information please visit https://github.com/ultralytics/yolov5 and https://www.ultralytics.com. Thank you!"
+        "For more information please visit https://github.com/ultralytics/yolov5 and https://ultralytics.com. Thank you!"
       ]
     },
     {
@@ -610,7 +611,7 @@
             "YOLOv5 🚀 v5.0-1-g0f395b3 torch 1.8.1+cu101 CUDA:0 (Tesla V100-SXM2-16GB, 16160.5MB)\n",
             "\n",
             "Fusing layers... \n",
-            "Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPS\n",
+            "Model Summary: 224 layers, 7266973 parameters, 0 gradients, 17.0 GFLOPs\n",
             "image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 bus, Done. (0.008s)\n",
             "image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.008s)\n",
             "Results saved to runs/detect/exp\n",
@@ -733,7 +734,7 @@
             "100% 168M/168M [00:05<00:00, 32.3MB/s]\n",
             "\n",
             "Fusing layers... \n",
-            "Model Summary: 476 layers, 87730285 parameters, 0 gradients, 218.8 GFLOPS\n",
+            "Model Summary: 476 layers, 87730285 parameters, 0 gradients, 218.8 GFLOPs\n",
             "\u001b[34m\u001b[1mval: \u001b[0mScanning '../coco/val2017' images and labels... 4952 found, 48 missing, 0 empty, 0 corrupted: 100% 5000/5000 [00:01<00:00, 3102.29it/s]\n",
             "\u001b[34m\u001b[1mval: \u001b[0mNew cache created: ../coco/val2017.cache\n",
             "               Class      Images      Labels           P           R      mAP@.5  mAP@.5:.95: 100% 157/157 [01:23<00:00,  1.87it/s]\n",
@@ -963,7 +964,7 @@
             " 22          [-1, 10]  1         0  models.common.Concat                    [1]                           \n",
             " 23                -1  1   1182720  models.common.C3                        [512, 512, 1, False]          \n",
             " 24      [17, 20, 23]  1    229245  models.yolo.Detect                      [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]\n",
-            "Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPS\n",
+            "Model Summary: 283 layers, 7276605 parameters, 7276605 gradients, 17.1 GFLOPs\n",
             "\n",
             "Transferred 362/362 items from yolov5s.pt\n",
             "Scaled weight_decay = 0.0005\n",
@@ -1260,4 +1261,4 @@
       "outputs": []
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/utils/datasets.py b/utils/datasets.py
index 7dd181400..b6e43b94c 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -535,7 +535,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             # MixUp https://arxiv.org/pdf/1710.09412.pdf
             if random.random() < hyp['mixup']:
                 img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
-                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
+                r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
                 img = (img * r + img2 * (1 - r)).astype(np.uint8)
                 labels = np.concatenate((labels, labels2), 0)
 
@@ -655,12 +655,12 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
     hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
     dtype = img.dtype  # uint8
 
-    x = np.arange(0, 256, dtype=np.int16)
+    x = np.arange(0, 256, dtype=r.dtype)
     lut_hue = ((x * r[0]) % 180).astype(dtype)
     lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
     lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
 
-    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
+    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
     cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
 
 
diff --git a/utils/general.py b/utils/general.py
index 006e64859..a12b0aafb 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -1,5 +1,6 @@
 # YOLOv5 general utils
 
+import contextlib
 import glob
 import logging
 import math
@@ -7,11 +8,13 @@ import os
 import platform
 import random
 import re
-import subprocess
+import signal
 import time
+import urllib
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+from subprocess import check_output
 
 import cv2
 import numpy as np
@@ -33,6 +36,26 @@ cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with Py
 os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
 
 
+class timeout(contextlib.ContextDecorator):
+    # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
+    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
+        self.seconds = int(seconds)
+        self.timeout_message = timeout_msg
+        self.suppress = bool(suppress_timeout_errors)
+
+    def _timeout_handler(self, signum, frame):
+        raise TimeoutError(self.timeout_message)
+
+    def __enter__(self):
+        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
+        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
+        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
+            return True
+
+
 def set_logging(rank=-1, verbose=True):
     logging.basicConfig(
         format="%(message)s",
@@ -53,12 +76,12 @@ def get_latest_run(search_dir='.'):
 
 
 def is_docker():
-    # Is environment a Docker container
+    # Is environment a Docker container?
     return Path('/workspace').exists()  # or Path('/.dockerenv').exists()
 
 
 def is_colab():
-    # Is environment a Google Colab instance
+    # Is environment a Google Colab instance?
     try:
         import google.colab
         return True
@@ -66,6 +89,11 @@ def is_colab():
         return False
 
 
+def is_pip():
+    # Is file in a pip package?
+    return 'site-packages' in Path(__file__).absolute().parts
+
+
 def emojis(str=''):
     # Return platform-dependent emoji-safe version of string
     return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
@@ -80,13 +108,13 @@ def check_online():
     # Check internet connectivity
     import socket
     try:
-        socket.create_connection(("1.1.1.1", 443), 5)  # check host accesability
+        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
         return True
     except OSError:
         return False
 
 
-def check_git_status():
+def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
     # Recommend 'git pull' if code is out of date
     print(colorstr('github: '), end='')
     try:
@@ -95,9 +123,9 @@ def check_git_status():
         assert check_online(), 'skipping check (offline)'
 
         cmd = 'git fetch && git config --get remote.origin.url'
-        url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git')  # github repo url
-        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
-        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
+        url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetch
+        branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
+        n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
         if n > 0:
             s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                 f"Use 'git pull' to update or 'git clone {url}' to download latest."
@@ -105,7 +133,7 @@ def check_git_status():
             s = f'up to date with {url} ✅'
         print(emojis(s))  # emoji-safe
     except Exception as e:
-        print(e)
+        print(f'{e}{err_msg}')
 
 
 def check_python(minimum='3.7.0', required=True):
@@ -135,10 +163,11 @@ def check_requirements(requirements='requirements.txt', exclude=()):
         try:
             pkg.require(r)
         except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
-            n += 1
             print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
             try:
-                print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
+                assert check_online(), f"'pip install {r}' skipped (offline)"
+                print(check_output(f"pip install '{r}'", shell=True).decode())
+                n += 1
             except Exception as e:
                 print(f'{prefix} {e}')
 
@@ -178,7 +207,8 @@ def check_file(file):
     if Path(file).is_file() or file == '':  # exists
         return file
     elif file.startswith(('http://', 'https://')):  # download
-        url, file = file, Path(file).name
+        url, file = file, Path(urllib.parse.unquote(str(file))).name  # url, file (decode '%2F' to '/' etc.)
+        file = file.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
         print(f'Downloading {url} to {file}...')
         torch.hub.download_url_to_file(url, file)
         assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
diff --git a/utils/google_utils.py b/utils/google_utils.py
index 63d3e5b21..aefc7de2d 100644
--- a/utils/google_utils.py
+++ b/utils/google_utils.py
@@ -4,6 +4,7 @@ import os
 import platform
 import subprocess
 import time
+import urllib
 from pathlib import Path
 
 import requests
@@ -16,11 +17,39 @@ def gsutil_getsize(url=''):
     return eval(s.split(' ')[0]) if len(s) else 0  # bytes
 
 
-def attempt_download(file, repo='ultralytics/yolov5'):
+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+    file = Path(file)
+    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
+    try:  # url1
+        print(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, str(file))
+        assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
+    except Exception as e:  # url2
+        file.unlink(missing_ok=True)  # remove partial downloads
+        print(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
+        os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+    finally:
+        if not file.exists() or file.stat().st_size < min_bytes:  # check
+            file.unlink(missing_ok=True)  # remove partial downloads
+            print(f"ERROR: {assert_msg}\n{error_msg}")
+        print('')
+
+
+def attempt_download(file, repo='ultralytics/yolov5'):  # from utils.google_utils import *; attempt_download()
     # Attempt file download if does not exist
     file = Path(str(file).strip().replace("'", ''))
 
     if not file.exists():
+        # URL specified
+        name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
+        if str(file).startswith(('http:/', 'https:/')):  # download
+            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
+            name = name.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
+            safe_download(file=name, url=url, min_bytes=1E5)
+            return name
+
+        # GitHub assets
         file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
         try:
             response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json()  # github api
@@ -34,27 +63,14 @@ def attempt_download(file, repo='ultralytics/yolov5'):
             except:
                 tag = 'v5.0'  # current release
 
-        name = file.name
         if name in assets:
-            msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
-            redundant = False  # second download option
-            try:  # GitHub
-                url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
-                print(f'Downloading {url} to {file}...')
-                torch.hub.download_url_to_file(url, file)
-                assert file.exists() and file.stat().st_size > 1E6  # check
-            except Exception as e:  # GCP
-                print(f'Download error: {e}')
-                assert redundant, 'No secondary mirror'
-                url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
-                print(f'Downloading {url} to {file}...')
-                os.system(f"curl -L '{url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
-            finally:
-                if not file.exists() or file.stat().st_size < 1E6:  # check
-                    file.unlink(missing_ok=True)  # remove partial downloads
-                    print(f'ERROR: Download failure: {msg}')
-                print('')
-                return
+            safe_download(file,
+                          url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+                          # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}',  # backup url (optional)
+                          min_bytes=1E5,
+                          error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
+
+    return str(file)
 
 
 def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index aa54c3cf5..6a7d07634 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 import torchvision
 
 try:
-    import thop  # for FLOPS computation
+    import thop  # for FLOPs computation
 except ImportError:
     thop = None
 logger = logging.getLogger(__name__)
@@ -105,13 +105,13 @@ def profile(x, ops, n=100, device=None):
     x = x.to(device)
     x.requires_grad = True
     print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
-    print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
+    print(f"\n{'Params':>12s}{'GFLOPs':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}")
     for m in ops if isinstance(ops, list) else [ops]:
         m = m.to(device) if hasattr(m, 'to') else m  # device
         m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m  # type
         dtf, dtb, t = 0., 0., [0., 0., 0.]  # dt forward, backward
         try:
-            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPS
+            flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
         except:
             flops = 0
 
@@ -219,13 +219,13 @@ def model_info(model, verbose=False, img_size=640):
             print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
 
-    try:  # FLOPS
+    try:  # FLOPs
         from thop import profile
         stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32
         img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
-        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPS
+        flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
         img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
-        fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPS
+        fs = ', %.1f GFLOPs' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 GFLOPs
     except (ImportError, Exception):
         fs = ''