Add colorstr() (#1887)
* Add colorful() * update * newline fix * add git description * --always * update loss scaling * update loss scaling 2 * rename to colorstr()pull/1894/head
parent
3e25f1e9e5
commit
6ab589583c
5
train.py
5
train.py
|
@ -216,8 +216,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
|
|||
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
||||
|
||||
# Model parameters
|
||||
hyp['cls'] *= nc / 80. # scale hyp['cls'] to class count
|
||||
hyp['obj'] *= imgsz ** 2 / 640. ** 2 * 3. / nl # scale hyp['obj'] to image size and output layers
|
||||
hyp['box'] *= 3. / nl # scale to layers
|
||||
hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
|
||||
hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
|
||||
model.nc = nc # attach number of classes to model
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
|
||||
|
|
|
@ -6,6 +6,8 @@ import yaml
|
|||
from scipy.cluster.vq import kmeans
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils.general import colorstr
|
||||
|
||||
|
||||
def check_anchor_order(m):
|
||||
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
|
||||
|
@ -20,7 +22,8 @@ def check_anchor_order(m):
|
|||
|
||||
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
||||
# Check anchor fit to data, recompute if necessary
|
||||
print('\nAnalyzing anchors... ', end='')
|
||||
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
|
||||
print(f'\n{prefix}Analyzing anchors... ', end='')
|
||||
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
||||
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
||||
scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
|
||||
|
@ -35,7 +38,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|||
return bpr, aat
|
||||
|
||||
bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
|
||||
print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
|
||||
print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
|
||||
if bpr < 0.98: # threshold to recompute
|
||||
print('. Attempting to improve anchors, please wait...')
|
||||
na = m.anchor_grid.numel() // 2 # number of anchors
|
||||
|
@ -46,9 +49,9 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|||
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
|
||||
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
||||
check_anchor_order(m)
|
||||
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
||||
print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
||||
else:
|
||||
print('Original anchors better than new anchors. Proceeding with original anchors.')
|
||||
print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
|
||||
print('') # newline
|
||||
|
||||
|
||||
|
@ -70,6 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|||
from utils.autoanchor import *; _ = kmean_anchors()
|
||||
"""
|
||||
thr = 1. / thr
|
||||
prefix = colorstr('blue', 'bold', 'autoanchor') + ': '
|
||||
|
||||
def metric(k, wh): # compute metrics
|
||||
r = wh[:, None] / k[None]
|
||||
|
@ -85,9 +89,9 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|||
k = k[np.argsort(k.prod(1))] # sort small to large
|
||||
x, best = metric(k, wh0)
|
||||
bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
|
||||
print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
|
||||
print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
|
||||
(n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
|
||||
print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
|
||||
print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
|
||||
f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
|
||||
for i, x in enumerate(k):
|
||||
print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
|
||||
return k
|
||||
|
@ -107,13 +111,12 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|||
# Filter
|
||||
i = (wh0 < 3.0).any(1).sum()
|
||||
if i:
|
||||
print('WARNING: Extremely small objects found. '
|
||||
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
|
||||
print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
|
||||
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
||||
# wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
|
||||
|
||||
# Kmeans calculation
|
||||
print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
|
||||
print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
|
||||
s = wh.std(0) # sigmas for whitening
|
||||
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
|
||||
k *= s
|
||||
|
@ -136,7 +139,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|||
# Evolve
|
||||
npr = np.random
|
||||
f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
|
||||
pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
|
||||
pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar
|
||||
for _ in pbar:
|
||||
v = np.ones(sh)
|
||||
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
|
||||
|
@ -145,7 +148,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|||
fg = anchor_fitness(kg)
|
||||
if fg > f:
|
||||
f, k = fg, kg.copy()
|
||||
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
|
||||
pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
|
||||
if verbose:
|
||||
print_results(k)
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ def get_latest_run(search_dir='.'):
|
|||
|
||||
def check_git_status():
|
||||
# Suggest 'git pull' if repo is out of date
|
||||
if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
|
||||
if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
|
||||
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
|
||||
if 'Your branch is behind' in s:
|
||||
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
|
||||
|
@ -115,6 +115,32 @@ def one_cycle(y1=0.0, y2=1.0, steps=100):
|
|||
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
||||
*prefix, str = input # color arguments, string
|
||||
colors = {'black': '\033[30m', # basic colors
|
||||
'red': '\033[31m',
|
||||
'green': '\033[32m',
|
||||
'yellow': '\033[33m',
|
||||
'blue': '\033[34m',
|
||||
'magenta': '\033[35m',
|
||||
'cyan': '\033[36m',
|
||||
'white': '\033[37m',
|
||||
'bright_black': '\033[90m', # bright colors
|
||||
'bright_red': '\033[91m',
|
||||
'bright_green': '\033[92m',
|
||||
'bright_yellow': '\033[93m',
|
||||
'bright_blue': '\033[94m',
|
||||
'bright_magenta': '\033[95m',
|
||||
'bright_cyan': '\033[96m',
|
||||
'bright_white': '\033[97m',
|
||||
'end': '\033[0m', # misc
|
||||
'bold': '\033[1m',
|
||||
'undelrine': '\033[4m'}
|
||||
|
||||
return ''.join(colors[x] for x in prefix) + str + colors['end']
|
||||
|
||||
|
||||
def labels_to_class_weights(labels, nc=80):
|
||||
# Get class weights (inverse frequency) from training labels
|
||||
if labels[0] is None: # no labels loaded
|
||||
|
|
|
@ -105,7 +105,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
|
||||
# Losses
|
||||
nt = 0 # number of targets
|
||||
no = len(p) # number of outputs
|
||||
balance = [4.0, 1.0, 0.3, 0.1, 0.03] # P3-P7
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||
|
@ -138,10 +137,9 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
|
||||
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
|
||||
|
||||
s = 3 / no # output count scaling
|
||||
lbox *= h['box'] * s
|
||||
lbox *= h['box']
|
||||
lobj *= h['obj']
|
||||
lcls *= h['cls'] * s
|
||||
lcls *= h['cls']
|
||||
bs = tobj.shape[0] # batch size
|
||||
|
||||
loss = lbox + lobj + lcls
|
||||
|
|
|
@ -3,9 +3,11 @@
|
|||
import logging
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
@ -41,9 +43,17 @@ def init_torch_seeds(seed=0):
|
|||
cudnn.benchmark, cudnn.deterministic = True, False
|
||||
|
||||
|
||||
def git_describe():
|
||||
# return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
||||
if Path('.git').exists():
|
||||
return subprocess.check_output('git describe --tags --long --always', shell=True).decode('utf-8')[:-1]
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def select_device(device='', batch_size=None):
|
||||
# device = 'cpu' or '0' or '0,1,2,3'
|
||||
s = f'Using torch {torch.__version__} ' # string
|
||||
s = f'YOLOv5 {git_describe()} torch {torch.__version__} ' # string
|
||||
cpu = device.lower() == 'cpu'
|
||||
if cpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
|
@ -61,9 +71,9 @@ def select_device(device='', batch_size=None):
|
|||
p = torch.cuda.get_device_properties(i)
|
||||
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
|
||||
else:
|
||||
s += 'CPU'
|
||||
s += 'CPU\n'
|
||||
|
||||
logger.info(f'{s}\n') # skip a line
|
||||
logger.info(s) # skip a line
|
||||
return torch.device('cuda:0' if cuda else 'cpu')
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue