yolov5/models/yolo.py


# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
YOLO-specific modules

Usage:
    $ python path/to/models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import contextlib
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                               time_sync)

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

    def _make_grid(self, nx=20, ny=20, i=0):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(y, x, indexing='ij')
        else:
            yv, xv = torch.meshgrid(y, x)
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid


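The box decoding in `Detect.forward` can be checked by hand for a single anchor at a single grid cell. The block below is a minimal scalar sketch, not part of the module: `sigmoid` and `decode_cell` are hypothetical helpers, and `anchor_wh_px` is assumed to be the anchor size in pixels (what `anchor_grid` holds after `anchors * stride`).

```python
import math


def sigmoid(v: float) -> float:
    """Logistic function, matching x[i].sigmoid() for one scalar."""
    return 1.0 / (1.0 + math.exp(-v))


def decode_cell(raw_xywh, cell_xy, stride, anchor_wh_px):
    """Decode raw (tx, ty, tw, th) at grid cell (cx, cy) into pixel-space xywh.

    Hypothetical helper for illustration only; it mirrors the xy/wh formulas
    in Detect.forward for one anchor and one cell.
    """
    sx, sy, sw, sh = (sigmoid(v) for v in raw_xywh)
    cx, cy = cell_xy
    # _make_grid stores (cell index - 0.5), so the offset is applied here explicitly.
    x = (sx * 2 + (cx - 0.5)) * stride
    y = (sy * 2 + (cy - 0.5)) * stride
    # Width and height scale the anchor by a factor in the open interval (0, 4).
    w = (sw * 2) ** 2 * anchor_wh_px[0]
    h = (sh * 2) ** 2 * anchor_wh_px[1]
    return x, y, w, h


# Zero logits: the centre lands mid-cell and the box equals the anchor.
box = decode_cell((0.0, 0.0, 0.0, 0.0), cell_xy=(3, 4), stride=8, anchor_wh_px=(10, 13))
print(box)
```

With all-zero logits every sigmoid is 0.5, so the centre offset is exactly half a cell and the width/height multiplier is 1. More generally the centre can move between -0.5 and 1.5 cells from the cell origin, and each side can grow up to 4 times the anchor.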
class Model(nn.Module):
    # YOLOv5 model
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, 1), None  # augmented inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p

    def _clip_augmented(self, y):
        # Clip YOLOv5 augmented inference tails
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4 ** x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
        y[0] = y[0][:, :-i]  # large
        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][:, i:]  # small
        return y

    def _profile_one_layer(self, m, x, dt):
        c = isinstance(m, Detect)  # is final layer, copy input as inplace fix
        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1).detach()  # conv.bias(255) to (3,85)
            b[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            LOGGER.info(
                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        LOGGER.info('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
        self.info()
        return self

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self


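The offsets that `Model._initialize_biases` adds when `cf` is `None` can be reproduced with plain arithmetic. The sketch below is illustrative only: the two function names are hypothetical, and the strides 8, 16 and 32 are assumed example values rather than something this file fixes.

```python
import math


def objectness_bias_offset(stride: float, img_size: int = 640, objects_per_image: int = 8) -> float:
    """Offset added to b[:, 4]: log of expected objects per grid cell."""
    cells = (img_size / stride) ** 2  # number of grid cells at this stride
    return math.log(objects_per_image / cells)


def class_bias_offset(nc: int) -> float:
    """Offset added to b[:, 5:] when no class frequencies are supplied."""
    return math.log(0.6 / (nc - 0.999999))


# Assumed example strides; finer grids get a more negative objectness prior.
offsets = {s: objectness_bias_offset(s) for s in (8, 16, 32)}
for s, v in offsets.items():
    print(f'stride {s:>2}: obj offset {v:.3f}')
print(f'nc=80: cls offset {class_bias_offset(80):.3f}')
```

Because sigmoid(log p) = p / (1 + p), a bias of log p gives an initial probability close to p when p is small, which is how the offsets encode a prior of roughly 8 objects spread over the cells of a 640-pixel image.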
def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)


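The depth and width scaling that `parse_model` applies to each layer row can be tried in isolation. This is a sketch under stated assumptions: `make_divisible_sketch` is a hypothetical stand-in for the imported `make_divisible`, assumed here to round up to the nearest multiple of the divisor, and the multipliers 0.33 and 0.50 are illustrative values for `depth_multiple` and `width_multiple`.

```python
import math


def make_divisible_sketch(x: float, divisor: int) -> int:
    # Assumption: stand-in for utils.general.make_divisible, taken here to
    # round up to the nearest multiple of `divisor`.
    return math.ceil(x / divisor) * divisor


def scale_depth(n: int, gd: float) -> int:
    """Repeat count after depth gain; single layers (n == 1) are never scaled."""
    return max(round(n * gd), 1) if n > 1 else n


def scale_width(c2: int, gw: float, divisor: int = 8) -> int:
    """Output channels after width gain, kept a multiple of `divisor`."""
    return make_divisible_sketch(c2 * gw, divisor)


# Illustrative multipliers (assumed, not read from any YAML here).
gd, gw = 0.33, 0.50
for n in (1, 3, 9):
    print(f'n={n} -> {scale_depth(n, gd)}')
for c in (64, 100, 1024):
    print(f'c2={c} -> {scale_width(c, gw)}')
```

Note that `parse_model` skips the width scaling when `c2 == no`, so a layer whose channel count already equals the detection output size is left unchanged.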
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    print_args(vars(opt))
    device = select_device(opt.device)

    # Create model
    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
    model = Model(opt.cfg).to(device)

    # Options
    if opt.line_profile:  # profile layer by layer
        _ = model(im, profile=True)

    elif opt.profile:  # profile forward-backward
        results = profile(input=im, ops=[model], n=3)

    elif opt.test:  # test all models
        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
            try:
                _ = Model(cfg)
            except Exception as e:
                print(f'Error in {cfg}: {e}')

    else:  # report fused model summary
        model.fuse()
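
The index arithmetic in `Model._clip_augmented` is easier to follow with concrete counts. The sketch below is a hedged illustration: `num_predictions` and `clip_counts` are hypothetical helpers, and it assumes square inputs, three detection layers with strides 8, 16 and 32 concatenated finest first, three anchors per cell, and 448 as the size of the smallest augmented image.

```python
def num_predictions(img_size: int, strides=(8, 16, 32), na: int = 3) -> int:
    """Total predictions Detect emits for a square image (assumed strides and anchors)."""
    return na * sum((img_size // s) ** 2 for s in strides)


def clip_counts(num_preds_large: int, num_preds_small: int, nl: int = 3, e: int = 1):
    """Reproduce the two slice sizes computed in _clip_augmented."""
    g = sum(4 ** x for x in range(nl))  # relative grid points: 1 + 4 + 16
    drop_tail_large = (num_preds_large // g) * sum(4 ** x for x in range(e))
    drop_head_small = (num_preds_small // g) * sum(4 ** (nl - 1 - x) for x in range(e))
    return drop_tail_large, drop_head_small


large = num_predictions(640)  # full-scale pass
small = num_predictions(448)  # assumed size of the 0.67-scale pass
tail, head = clip_counts(large, small)
print(large, small, tail, head)
```

Under these assumptions the tail dropped from the full-scale pass is exactly the coarsest layer (3 x 20 x 20 predictions), and the head dropped from the smallest pass is exactly the finest layer (3 x 56 x 56 predictions).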