mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
* Update * Logger step fix: Increment step with epochs (#8654) * enhance * revert * allow training from scratch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update --img argument from train.py single line * fix image size from 640 to 128 * suport custom dataloader and augmentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * format * Update dataloaders.py * Single line return, single line comment, remove unused argument * address PR comments * fix spelling * don't augment eval set * use fstring * update augmentations.py * new maning convention for transforms * reverse if statement, inline ops * reverse if statement, inline ops * updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update dataloaders * Remove additional if statement * Remove is_train as redundant * Cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update classifier.py * Update augmentations.py * fix: imshow clip warning * update * Revert ToTensorV2 removal * Update classifier.py * Update normalize values, revert uint8 * normalize image using cv2 * remove dedundant comment * Update classifier.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * replace print with logger * commit steps * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * 
Update * Update * Update * Update * Update * Update * Update * Update * Allow logging models from GenericLogger (#8676) * enhance * revert * allow training from scratch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update --img argument from train.py single line * fix image size from 640 to 128 * suport custom dataloader and augmentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * format * Update dataloaders.py * Single line return, single line comment, remove unused argument * address PR comments * fix spelling * don't augment eval set * use fstring * update augmentations.py * new maning convention for transforms * reverse if statement, inline ops * reverse if statement, inline ops * updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update dataloaders * Remove additional if statement * Remove is_train as redundant * Cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update classifier.py * Update augmentations.py * fix: imshow clip warning * update * Revert ToTensorV2 removal * Update classifier.py * Update normalize values, revert uint8 * normalize image using cv2 * remove dedundant comment * Update classifier.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * replace print with logger * commit steps * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support final model logging * update * update * update * update * remove curses * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update classifier.py * Update __init__.py Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> * Update * Update * Update * Update * Update dataset download * Update dataset download * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Pass imgsz to classify_transforms() * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Update * Cos scheduler * Cos scheduler * Remove unused args * Update * Add seed * Add seed * Update * Update * Add run(), main() * Merge master * Merge master * Update * Update * Update * Update * Update * Update * Update * Create YOLOv5 BaseModel class (#8829) * Create BaseModel * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * Hub load device fix * Update Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Add experiment * Merge master * Attach names * weight decay = 1e-4 * weight decay = 5e-5 * update smart_optimizer console printout * fashion-mnist fix * Merge master * Update Table * Update Table * Remove destroy process group * add kwargs to forward() * fuse fix for resnet50 * nc, names fix for resnet50 * nc, names fix for resnet50 * ONNX CPU inference fix * revert * cuda * if augment or visualize * if augment or visualize * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * New smart_inference_mode() * Update README * Refactor into /classify dir * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reset defaults * reset defaults * fix gpu predict * warmup * ema 
half fix * spacing * remove data * remove cache * remove denormalize * save run settings * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * verbose false on initial plots * new save_yaml() function * Update ci-testing.yml * Path(data) CI fix * Separate classification CI * fix val * fix val * fix val * smartCrossEntropyLoss * skip validation on hub load * autodownload with working dir root * str(data) * Dataset usage example * im_show normalize * im_show normalize * add imagenet simple names to multibackend * Add validation speeds * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 24-space names * Update bash scripts * Update permissions * Add bash script arguments * remove verbose * TRT data fix * names generator fix * optimize if names * update usage * Add local loading * Verbose=False * update names printing * Add Usage examples * Add Usage examples * Add Usage examples * Add Usage examples * named_children * reshape_classifier_outputs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * fix CI * fix incorrect class substitution * fix incorrect class substitution * remove denormalize * ravel fix * cleanup * update opt file printing * update opt file printing * update defaults * add opt to checkpoint * Add warning * Add comment * plot half bug fix * Use NotImplementedError * fix export shape report * Fix TRT load * cleanup CI * profile comment * CI fix * Add cls models * avoid inplace error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix usage examples * Update README * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update README * Update README * Update README * Update README * Update README * Update README * Update README * Update README * Update README * Update README * 
Update README * Update README * Update README * Update README * Update README * Update README Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
108 lines
4.1 KiB
Python
108 lines
4.1 KiB
Python
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
"""
|
|
Experimental modules
|
|
"""
|
|
import math
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from models.common import Conv
|
|
from utils.downloads import attempt_download
|
|
|
|
|
|
class Sum(nn.Module):
    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, n, weight=False):  # n: number of inputs
        """Add together n inputs, optionally scaling every input after the first by a learned weight.

        Args:
            n: number of inputs that forward() will receive.
            weight: when True, learn one weight per input after the first; each weight is
                mapped through 2 * sigmoid(.) in forward(), so it lies in (0, 2).
        """
        super().__init__()
        self.weight = weight  # apply weights boolean
        self.iter = range(n - 1)  # indices of the inputs after the first one
        if weight:
            # initial raw weights -0.5, -1.0, -1.5, ...: later inputs start with smaller effective weights
            init = -torch.arange(1.0, n) / 2
            self.w = nn.Parameter(init, requires_grad=True)  # layer weights

    def forward(self, x):
        """Return x[0] plus every following input, each scaled by 2*sigmoid(w) when weighting is on."""
        total = x[0]  # the first input is never weighted
        if not self.weight:
            for j in self.iter:
                total = total + x[j + 1]
            return total
        scale = torch.sigmoid(self.w) * 2
        for j in self.iter:
            total = total + x[j + 1] * scale[j]
        return total
|
|
|
|
|
|
class MixConv2d(nn.Module):
    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
        """Split c2 output channels across one grouped Conv2d per kernel size in k.

        Args:
            c1: input channels.
            c2: total output channels; the per-kernel groups always sum to exactly c2.
            k: kernel sizes, one convolution per entry (padding k // 2 each).
            s: stride shared by every convolution.
            equal_ch: True -> roughly equal channel count per kernel size;
                False -> channel counts solved so weight.numel() is roughly equal per group.
        """
        super().__init__()
        n = len(k)  # number of convolutions
        if equal_ch:  # equal c_ per group
            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
            c_ = [int((i == g).sum()) for g in range(n)]  # intermediate channels
        else:  # equal weight.numel() per group
            b = [c2] + [0] * n
            a = np.eye(n + 1, n, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            # solve for equal weight indices, ax = b
            c_ = [int(v) for v in np.linalg.lstsq(a, b, rcond=None)[0].round()]
            # Rounding each group on its own can leave sum(c_) != c2 (e.g. k=(1, 3, 5, 7), c2=16 -> 17),
            # which would mismatch BatchNorm2d(c2) below; absorb the residue in the largest group.
            c_[int(np.argmax(c_))] += c2 - sum(c_)

        self.m = nn.ModuleList([
            nn.Conv2d(c1, ch, ks, s, ks // 2, groups=math.gcd(c1, ch), bias=False) for ks, ch in zip(k, c_)])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        """Concatenate all kernel-group outputs along dim 1, then apply BatchNorm and SiLU."""
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
|
|
|
|
|
|
class Ensemble(nn.ModuleList):
    # Ensemble of models
    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        """Run every member on x and concatenate their first outputs along dim 1 (NMS ensemble).

        Returns:
            (concatenated inference output, None) - the second slot stands in for the train output.
        """
        outputs = []
        for member in self:
            outputs.append(member(x, augment, profile, visualize)[0])
        # Alternatives: torch.stack(outputs).max(0)[0] (max ensemble) or torch.stack(outputs).mean(0) (mean ensemble)
        merged = torch.cat(outputs, 1)  # nms ensemble
        return merged, None  # inference, train output
|
|
|
|
|
|
def attempt_load(weights, device=None, inplace=True, fuse=True):
    """Load a single checkpoint, or an ensemble of checkpoints.

    Args:
        weights: one checkpoint path/name (weights=a) or a list of them (weights=[a, b, c]).
        device: passed to `.to()` on every loaded model.
        inplace: value assigned to `.inplace` on activation, Detect and Model modules.
        fuse: call `.fuse()` on each model that provides it.

    Returns:
        The model itself when exactly one checkpoint was loaded; otherwise an `Ensemble` that
        copies `names`, `nc` and `yaml` from its first member and uses the largest member stride.
    """
    from models.yolo import Detect, Model

    model = Ensemble()
    paths = weights if isinstance(weights, list) else [weights]
    for w in paths:
        # NOTE(review): torch.load unpickles arbitrary objects - only load checkpoints from trusted sources
        ckpt = torch.load(attempt_download(w), map_location='cpu')  # load
        net = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
        if not hasattr(net, 'stride'):
            net.stride = torch.tensor([32.])  # compatibility update for ResNet etc.
        can_fuse = fuse and hasattr(net, 'fuse')
        model.append(net.fuse().eval() if can_fuse else net.eval())  # model in eval mode

    # Compatibility updates
    inplace_types = (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model)
    for m in model.modules():
        t = type(m)
        if t in inplace_types:
            m.inplace = inplace  # torch 1.7.0 compatibility
            if t is Detect and not isinstance(m.anchor_grid, list):
                # rebuild anchor_grid as a list of placeholders, one per detection layer
                delattr(m, 'anchor_grid')
                setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility

    if len(model) == 1:
        return model[-1]  # single model: return it directly

    # Detection ensemble
    print(f'Ensemble created with {weights}\n')
    for attr in ('names', 'nc', 'yaml'):
        setattr(model, attr, getattr(model[0], attr))
    member_strides = torch.tensor([m.stride.max() for m in model])
    model.stride = model[torch.argmax(member_strides).int()].stride  # max stride
    assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
    return model
|