parent
ec7a926163
commit
d5b6416c87
23
detect.py
23
detect.py
|
@ -1,10 +1,19 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
|
from numpy import random
|
||||||
|
|
||||||
from models.experimental import *
|
from models.experimental import attempt_load
|
||||||
from utils.datasets import *
|
from utils.datasets import LoadStreams, LoadImages
|
||||||
from utils.utils import *
|
from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box
|
||||||
|
from utils.torch_utils import select_device, load_classifier, time_synchronized
|
||||||
|
|
||||||
|
|
||||||
def detect(save_img=False):
|
def detect(save_img=False):
|
||||||
|
@ -13,7 +22,7 @@ def detect(save_img=False):
|
||||||
webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')
|
webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')
|
||||||
|
|
||||||
# Initialize
|
# Initialize
|
||||||
device = torch_utils.select_device(opt.device)
|
device = select_device(opt.device)
|
||||||
if os.path.exists(out):
|
if os.path.exists(out):
|
||||||
shutil.rmtree(out) # delete output folder
|
shutil.rmtree(out) # delete output folder
|
||||||
os.makedirs(out) # make new output folder
|
os.makedirs(out) # make new output folder
|
||||||
|
@ -28,7 +37,7 @@ def detect(save_img=False):
|
||||||
# Second-stage classifier
|
# Second-stage classifier
|
||||||
classify = False
|
classify = False
|
||||||
if classify:
|
if classify:
|
||||||
modelc = torch_utils.load_classifier(name='resnet101', n=2) # initialize
|
modelc = load_classifier(name='resnet101', n=2) # initialize
|
||||||
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
|
modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights
|
||||||
modelc.to(device).eval()
|
modelc.to(device).eval()
|
||||||
|
|
||||||
|
@ -58,12 +67,12 @@ def detect(save_img=False):
|
||||||
img = img.unsqueeze(0)
|
img = img.unsqueeze(0)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
t1 = torch_utils.time_synchronized()
|
t1 = time_synchronized()
|
||||||
pred = model(img, augment=opt.augment)[0]
|
pred = model(img, augment=opt.augment)[0]
|
||||||
|
|
||||||
# Apply NMS
|
# Apply NMS
|
||||||
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
|
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
|
||||||
t2 = torch_utils.time_synchronized()
|
t2 = time_synchronized()
|
||||||
|
|
||||||
# Apply Classifier
|
# Apply Classifier
|
||||||
if classify:
|
if classify:
|
||||||
|
|
|
@ -6,13 +6,12 @@ Usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dependencies = ['torch', 'yaml']
|
dependencies = ['torch', 'yaml']
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.yolo import Model
|
from models.yolo import Model
|
||||||
from utils import google_utils
|
from utils.google_utils import attempt_download
|
||||||
|
|
||||||
|
|
||||||
def create(name, pretrained, channels, classes):
|
def create(name, pretrained, channels, classes):
|
||||||
|
@ -32,7 +31,7 @@ def create(name, pretrained, channels, classes):
|
||||||
model = Model(config, channels, classes)
|
model = Model(config, channels, classes)
|
||||||
if pretrained:
|
if pretrained:
|
||||||
ckpt = '%s.pt' % name # checkpoint filename
|
ckpt = '%s.pt' % name # checkpoint filename
|
||||||
google_utils.attempt_download(ckpt) # download if not found locally
|
attempt_download(ckpt) # download if not found locally
|
||||||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
|
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
|
||||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
||||||
model.load_state_dict(state_dict, strict=False) # load
|
model.load_state_dict(state_dict, strict=False) # load
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# This file contains modules common to various models
|
# This file contains modules common to various models
|
||||||
|
import math
|
||||||
|
|
||||||
from utils.utils import *
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def autopad(k, p=None): # kernel, padding
|
def autopad(k, p=None): # kernel, padding
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
# This file contains experimental modules
|
# This file contains experimental modules
|
||||||
|
|
||||||
from models.common import *
|
import numpy as np
|
||||||
from utils import google_utils
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.common import Conv, DWConv
|
||||||
|
from utils.google_utils import attempt_download
|
||||||
|
|
||||||
|
|
||||||
class CrossConv(nn.Module):
|
class CrossConv(nn.Module):
|
||||||
|
@ -129,7 +133,7 @@ def attempt_load(weights, map_location=None):
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
google_utils.attempt_download(w)
|
attempt_download(w)
|
||||||
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
|
||||||
|
|
||||||
if len(model) == 1:
|
if len(model) == 1:
|
||||||
|
|
|
@ -6,8 +6,9 @@ Usage:
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from models.common import *
|
import torch
|
||||||
from utils import google_utils
|
|
||||||
|
from utils.google_utils import attempt_download
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -22,7 +23,7 @@ if __name__ == '__main__':
|
||||||
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection
|
img = torch.zeros((opt.batch_size, 3, *opt.img_size)) # image size(1,3,320,192) iDetection
|
||||||
|
|
||||||
# Load PyTorch model
|
# Load PyTorch model
|
||||||
google_utils.attempt_download(opt.weights)
|
attempt_download(opt.weights)
|
||||||
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
|
model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
|
||||||
model.eval()
|
model.eval()
|
||||||
model.model[-1].export = True # set Detect() layer export=True
|
model.model[-1].export = True # set Detect() layer export=True
|
||||||
|
|
|
@ -1,7 +1,16 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from models.experimental import *
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
|
||||||
|
from models.experimental import MixConv2d, CrossConv, C3
|
||||||
|
from utils.general import check_anchor_order, make_divisible, check_file
|
||||||
|
from utils.torch_utils import (
|
||||||
|
time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)
|
||||||
|
|
||||||
|
|
||||||
class Detect(nn.Module):
|
class Detect(nn.Module):
|
||||||
|
@ -75,7 +84,7 @@ class Model(nn.Module):
|
||||||
# print('Strides: %s' % m.stride.tolist())
|
# print('Strides: %s' % m.stride.tolist())
|
||||||
|
|
||||||
# Init weights, biases
|
# Init weights, biases
|
||||||
torch_utils.initialize_weights(self)
|
initialize_weights(self)
|
||||||
self.info()
|
self.info()
|
||||||
print('')
|
print('')
|
||||||
|
|
||||||
|
@ -86,7 +95,7 @@ class Model(nn.Module):
|
||||||
f = [None, 3, None] # flips (2-ud, 3-lr)
|
f = [None, 3, None] # flips (2-ud, 3-lr)
|
||||||
y = [] # outputs
|
y = [] # outputs
|
||||||
for si, fi in zip(s, f):
|
for si, fi in zip(s, f):
|
||||||
xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
|
xi = scale_img(x.flip(fi) if fi else x, si)
|
||||||
yi = self.forward_once(xi)[0] # forward
|
yi = self.forward_once(xi)[0] # forward
|
||||||
# cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
# cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
||||||
yi[..., :4] /= si # de-scale
|
yi[..., :4] /= si # de-scale
|
||||||
|
@ -111,10 +120,10 @@ class Model(nn.Module):
|
||||||
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
|
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPS
|
||||||
except:
|
except:
|
||||||
o = 0
|
o = 0
|
||||||
t = torch_utils.time_synchronized()
|
t = time_synchronized()
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
_ = m(x)
|
_ = m(x)
|
||||||
dt.append((torch_utils.time_synchronized() - t) * 100)
|
dt.append((time_synchronized() - t) * 100)
|
||||||
print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
|
print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
|
||||||
|
|
||||||
x = m(x) # run
|
x = m(x) # run
|
||||||
|
@ -149,14 +158,14 @@ class Model(nn.Module):
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if type(m) is Conv:
|
if type(m) is Conv:
|
||||||
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatability
|
||||||
m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn) # update conv
|
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||||
m.bn = None # remove batchnorm
|
m.bn = None # remove batchnorm
|
||||||
m.forward = m.fuseforward # update forward
|
m.forward = m.fuseforward # update forward
|
||||||
self.info()
|
self.info()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def info(self): # print model information
|
def info(self): # print model information
|
||||||
torch_utils.model_info(self)
|
model_info(self)
|
||||||
|
|
||||||
|
|
||||||
def parse_model(d, ch): # model_dict, input_channels(3)
|
def parse_model(d, ch): # model_dict, input_channels(3)
|
||||||
|
@ -228,7 +237,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
opt.cfg = check_file(opt.cfg) # check file
|
opt.cfg = check_file(opt.cfg) # check file
|
||||||
device = torch_utils.select_device(opt.device)
|
device = select_device(opt.device)
|
||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
model = Model(opt.cfg).to(device)
|
model = Model(opt.cfg).to(device)
|
||||||
|
|
27
test.py
27
test.py
|
@ -1,8 +1,21 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from models.experimental import *
|
import numpy as np
|
||||||
from utils.datasets import *
|
import torch
|
||||||
|
import yaml
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from models.experimental import attempt_load
|
||||||
|
from utils.datasets import create_dataloader
|
||||||
|
from utils.general import (
|
||||||
|
coco80_to_coco91_class, check_file, check_img_size, compute_loss, non_max_suppression,
|
||||||
|
scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class)
|
||||||
|
from utils.torch_utils import select_device, time_synchronized
|
||||||
|
|
||||||
|
|
||||||
def test(data,
|
def test(data,
|
||||||
|
@ -26,7 +39,7 @@ def test(data,
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
|
|
||||||
else: # called directly
|
else: # called directly
|
||||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
device = select_device(opt.device, batch_size=batch_size)
|
||||||
merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels
|
merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels
|
||||||
if save_txt:
|
if save_txt:
|
||||||
out = Path('inference/output')
|
out = Path('inference/output')
|
||||||
|
@ -85,18 +98,18 @@ def test(data,
|
||||||
# Disable gradients
|
# Disable gradients
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Run model
|
# Run model
|
||||||
t = torch_utils.time_synchronized()
|
t = time_synchronized()
|
||||||
inf_out, train_out = model(img, augment=augment) # inference and training outputs
|
inf_out, train_out = model(img, augment=augment) # inference and training outputs
|
||||||
t0 += torch_utils.time_synchronized() - t
|
t0 += time_synchronized() - t
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
if training: # if model has loss hyperparameters
|
if training: # if model has loss hyperparameters
|
||||||
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
|
loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3] # GIoU, obj, cls
|
||||||
|
|
||||||
# Run NMS
|
# Run NMS
|
||||||
t = torch_utils.time_synchronized()
|
t = time_synchronized()
|
||||||
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
|
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
|
||||||
t1 += torch_utils.time_synchronized() - t
|
t1 += time_synchronized() - t
|
||||||
|
|
||||||
# Statistics per image
|
# Statistics per image
|
||||||
for si, pred in enumerate(output):
|
for si, pred in enumerate(output):
|
||||||
|
|
25
train.py
25
train.py
|
@ -1,19 +1,32 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from random import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
import torch.optim.lr_scheduler as lr_scheduler
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
|
import yaml
|
||||||
from torch.cuda import amp
|
from torch.cuda import amp
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import test # import test.py to get mAP after each epoch
|
import test # import test.py to get mAP after each epoch
|
||||||
from models.yolo import Model
|
from models.yolo import Model
|
||||||
from utils import google_utils
|
from utils.datasets import create_dataloader
|
||||||
from utils.datasets import *
|
from utils.general import (
|
||||||
from utils.utils import *
|
check_img_size, torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors,
|
||||||
|
labels_to_image_weights, compute_loss, plot_images, fitness, strip_optimizer, plot_results,
|
||||||
|
get_latest_run, check_git_status, check_file, increment_dir, print_mutation)
|
||||||
|
from utils.google_utils import attempt_download
|
||||||
|
from utils.torch_utils import init_seeds, ModelEMA, select_device
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
|
hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||||
|
@ -119,7 +132,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
|
|
||||||
# Load Model
|
# Load Model
|
||||||
with torch_distributed_zero_first(rank):
|
with torch_distributed_zero_first(rank):
|
||||||
google_utils.attempt_download(weights)
|
attempt_download(weights)
|
||||||
start_epoch, best_fitness = 0, 0.0
|
start_epoch, best_fitness = 0, 0.0
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||||
|
@ -167,7 +180,7 @@ def train(hyp, opt, device, tb_writer=None):
|
||||||
print('Using SyncBatchNorm()')
|
print('Using SyncBatchNorm()')
|
||||||
|
|
||||||
# Exponential moving average
|
# Exponential moving average
|
||||||
ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None
|
ema = ModelEMA(model) if rank in [-1, 0] else None
|
||||||
|
|
||||||
# DDP mode
|
# DDP mode
|
||||||
if cuda and rank != -1:
|
if cuda and rank != -1:
|
||||||
|
@ -438,7 +451,7 @@ if __name__ == '__main__':
|
||||||
with open(opt.hyp) as f:
|
with open(opt.hyp) as f:
|
||||||
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
|
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
|
||||||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||||
device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
|
device = select_device(opt.device, batch_size=opt.batch_size)
|
||||||
opt.total_batch_size = opt.batch_size
|
opt.total_batch_size = opt.batch_size
|
||||||
opt.world_size = 1
|
opt.world_size = 1
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from PIL import Image, ExifTags
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
|
from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
|
||||||
|
|
||||||
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||||
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
|
||||||
|
|
|
@ -18,10 +18,11 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
import yaml
|
import yaml
|
||||||
|
from scipy.cluster.vq import kmeans
|
||||||
from scipy.signal import butter, filtfilt
|
from scipy.signal import butter, filtfilt
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import torch_utils # torch_utils, google_utils
|
from utils.torch_utils import init_seeds, is_parallel
|
||||||
|
|
||||||
# Set printoptions
|
# Set printoptions
|
||||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||||
|
@ -47,7 +48,7 @@ def torch_distributed_zero_first(local_rank: int):
|
||||||
def init_seeds(seed=0):
|
def init_seeds(seed=0):
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch_utils.init_seeds(seed=seed)
|
init_seeds(seed=seed)
|
||||||
|
|
||||||
|
|
||||||
def get_latest_run(search_dir='./runs'):
|
def get_latest_run(search_dir='./runs'):
|
||||||
|
@ -505,7 +506,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
|
|
||||||
def build_targets(p, targets, model):
|
def build_targets(p, targets, model):
|
||||||
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
|
||||||
det = model.module.model[-1] if torch_utils.is_parallel(model) else model.model[-1] # Detect() module
|
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
||||||
na, nt = det.na, targets.shape[0] # number of anchors, targets
|
na, nt = det.na, targets.shape[0] # number of anchors, targets
|
||||||
tcls, tbox, indices, anch = [], [], [], []
|
tcls, tbox, indices, anch = [], [], [], []
|
||||||
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
|
||||||
|
@ -779,7 +780,6 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
||||||
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
||||||
|
|
||||||
# Kmeans calculation
|
# Kmeans calculation
|
||||||
from scipy.cluster.vq import kmeans
|
|
||||||
print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
|
print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
|
||||||
s = wh.std(0) # sigmas for whitening
|
s = wh.std(0) # sigmas for whitening
|
||||||
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
|
k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
|
Loading…
Reference in New Issue