Explicit Imports (#498)

* expand imports

* optimize

* miss

* fix
Jirka Borovec 2020-08-03 00:47:36 +02:00 committed by GitHub
parent ec7a926163
commit d5b6416c87
10 changed files with 93 additions and 43 deletions

detect.py

@@ -1,10 +1,19 @@
 import argparse
-
+import os
+import platform
+import shutil
+import time
+from pathlib import Path
+
+import cv2
+import torch
 import torch.backends.cudnn as cudnn
-
-from models.experimental import *
-from utils.datasets import *
-from utils.utils import *
+from numpy import random
+
+from models.experimental import attempt_load
+from utils.datasets import LoadStreams, LoadImages
+from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box
+from utils.torch_utils import select_device, load_classifier, time_synchronized


 def detect(save_img=False):
@@ -13,7 +22,7 @@ def detect(save_img=False):
     webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')

     # Initialize
-    device = torch_utils.select_device(opt.device)
+    device = select_device(opt.device)
     if os.path.exists(out):
         shutil.rmtree(out)  # delete output folder
     os.makedirs(out)  # make new output folder
@@ -28,7 +37,7 @@ def detect(save_img=False):

     # Second-stage classifier
     classify = False
     if classify:
-        modelc = torch_utils.load_classifier(name='resnet101', n=2)  # initialize
+        modelc = load_classifier(name='resnet101', n=2)  # initialize
         modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
         modelc.to(device).eval()
@@ -58,12 +67,12 @@ def detect(save_img=False):
         img = img.unsqueeze(0)

         # Inference
-        t1 = torch_utils.time_synchronized()
+        t1 = time_synchronized()
         pred = model(img, augment=opt.augment)[0]

         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
-        t2 = torch_utils.time_synchronized()
+        t2 = time_synchronized()

         # Apply Classifier
         if classify:
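
The renamed helpers keep their call signatures, so downstream use is unchanged. A minimal sketch of the new explicit-import style in action, assuming a local yolov5s.pt checkpoint (the filename is an assumption; the calls mirror this diff):

import torch

from models.experimental import attempt_load
from utils.torch_utils import select_device, time_synchronized

device = select_device('')  # '' picks CUDA if available, else CPU
model = attempt_load('yolov5s.pt', map_location=device)  # assumed local checkpoint
img = torch.zeros((1, 3, 640, 640), device=device)  # dummy input at the default 640 size

t1 = time_synchronized()
pred = model(img)[0]  # raw predictions, as in detect()
print('inference %.1fms' % ((time_synchronized() - t1) * 1000))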

hubconf.py

@@ -6,13 +6,12 @@ Usage:
 """

 dependencies = ['torch', 'yaml']

 import os
 import torch

 from models.yolo import Model
-from utils import google_utils
+from utils.google_utils import attempt_download


 def create(name, pretrained, channels, classes):
@@ -32,7 +31,7 @@ def create(name, pretrained, channels, classes):
     model = Model(config, channels, classes)
     if pretrained:
         ckpt = '%s.pt' % name  # checkpoint filename
-        google_utils.attempt_download(ckpt)  # download if not found locally
+        attempt_download(ckpt)  # download if not found locally
         state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict()  # to FP32
         state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape}  # filter
         model.load_state_dict(state_dict, strict=False)  # load
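
With the explicit import, create() still drives the hub entry points unchanged. A hedged sketch of calling it directly, with the signature as shown in the hunk above (the 3-channel/80-class COCO defaults are assumptions, not part of this diff):

from hubconf import create

# create(name, pretrained, channels, classes); '%s.pt' % name is fetched if missing
model = create('yolov5s', pretrained=True, channels=3, classes=80)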

models/common.py

@@ -1,6 +1,8 @@
 # This file contains modules common to various models
+import math

-from utils.utils import *
+import torch
+import torch.nn as nn


 def autopad(k, p=None):  # kernel, padding
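
The only names common.py actually needed from the old wildcard were torch and nn, now imported directly. For context, a sketch of autopad's conventional role, assuming the usual k // 2 'same'-padding rule (the rule itself is not shown in this diff):

import torch.nn as nn

def autopad(k, p=None):  # kernel, padding
    if p is None:  # assumed behaviour: default to 'same' padding
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p

conv = nn.Conv2d(3, 16, kernel_size=3, padding=autopad(3))  # padding=1 preserves H, W at stride 1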

models/experimental.py

@@ -1,7 +1,11 @@
 # This file contains experimental modules
+import numpy as np
+import torch
+import torch.nn as nn

-from models.common import *
-from utils import google_utils
+from models.common import Conv, DWConv
+from utils.google_utils import attempt_download


 class CrossConv(nn.Module):
@@ -129,7 +133,7 @@ def attempt_load(weights, map_location=None):
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
-        google_utils.attempt_download(w)
+        attempt_download(w)
         model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model

     if len(model) == 1:
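
attempt_load is the helper the other files now import by name. Per the comment in the hunk above it accepts a single weights path or a list, returning an Ensemble in the latter case; a short usage sketch (the checkpoint filenames are assumptions):

from models.experimental import attempt_load

model = attempt_load('yolov5s.pt', map_location='cpu')  # single fused FP32 model
ensemble = attempt_load(['yolov5s.pt', 'yolov5m.pt'], map_location='cpu')  # Ensemble of models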

models/export.py

@@ -6,8 +6,9 @@ Usage:
 import argparse

-from models.common import *
-from utils import google_utils
+import torch
+
+from utils.google_utils import attempt_download

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
@@ -22,7 +23,7 @@ if __name__ == '__main__':
     img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection

     # Load PyTorch model
-    google_utils.attempt_download(opt.weights)
+    attempt_download(opt.weights)
     model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
     model.eval()
     model.model[-1].export = True  # set Detect() layer export=True
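
For orientation, the step that typically follows this snippet traces the loaded model to ONNX. A hedged sketch continuing the variables above, using only standard torch.onnx (the opset version, filename convention, and tensor names are assumptions, not part of this diff):

y = model(img)  # dry run through the Detect() layer in export mode
f = opt.weights.replace('.pt', '.onnx')  # assumed output filename convention
torch.onnx.export(model, img, f, verbose=False, opset_version=12,
                  input_names=['images'], output_names=['output'])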

models/yolo.py

@@ -1,7 +1,16 @@
 import argparse
+import math
 from copy import deepcopy
+from pathlib import Path

-from models.experimental import *
+import torch
+import torch.nn as nn
+
+from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
+from models.experimental import MixConv2d, CrossConv, C3
+from utils.general import check_anchor_order, make_divisible, check_file
+from utils.torch_utils import (
+    time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, select_device)


 class Detect(nn.Module):
@@ -75,7 +84,7 @@ class Model(nn.Module):
         # print('Strides: %s' % m.stride.tolist())

         # Init weights, biases
-        torch_utils.initialize_weights(self)
+        initialize_weights(self)
         self.info()
         print('')
@@ -86,7 +95,7 @@ class Model(nn.Module):
             f = [None, 3, None]  # flips (2-ud, 3-lr)
             y = []  # outputs
             for si, fi in zip(s, f):
-                xi = torch_utils.scale_img(x.flip(fi) if fi else x, si)
+                xi = scale_img(x.flip(fi) if fi else x, si)
                 yi = self.forward_once(xi)[0]  # forward
                 # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                 yi[..., :4] /= si  # de-scale
@@ -111,10 +120,10 @@ class Model(nn.Module):
                     o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # FLOPS
                 except:
                     o = 0
-                t = torch_utils.time_synchronized()
+                t = time_synchronized()
                 for _ in range(10):
                     _ = m(x)
-                dt.append((torch_utils.time_synchronized() - t) * 100)
+                dt.append((time_synchronized() - t) * 100)
                 print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

             x = m(x)  # run
@@ -149,14 +158,14 @@ class Model(nn.Module):
         for m in self.model.modules():
             if type(m) is Conv:
                 m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
-                m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 m.bn = None  # remove batchnorm
                 m.forward = m.fuseforward  # update forward
         self.info()
         return self

     def info(self):  # print model information
-        torch_utils.model_info(self)
+        model_info(self)


 def parse_model(d, ch):  # model_dict, input_channels(3)
@@ -228,7 +237,7 @@ if __name__ == '__main__':
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     opt = parser.parse_args()
     opt.cfg = check_file(opt.cfg)  # check file
-    device = torch_utils.select_device(opt.device)
+    device = select_device(opt.device)

     # Create model
     model = Model(opt.cfg).to(device)
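
The fuse() path above now names its torch_utils dependencies explicitly. A sketch of its intended use, folding each BatchNorm into its preceding Conv for faster inference (the .yaml config path is an assumption):

from models.yolo import Model
from utils.torch_utils import select_device

device = select_device('')
model = Model('models/yolov5s.yaml').to(device)  # assumed config path
model.eval()
model.fuse()  # merges Conv+BN via fuse_conv_and_bn, then prints model info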

test.py

@@ -1,8 +1,21 @@
 import argparse
+import glob
 import json
+import os
+import shutil
+from pathlib import Path

-from models.experimental import *
-from utils.datasets import *
+import numpy as np
+import torch
+import yaml
+from tqdm import tqdm
+
+from models.experimental import attempt_load
+from utils.datasets import create_dataloader
+from utils.general import (
+    coco80_to_coco91_class, check_file, check_img_size, compute_loss, non_max_suppression,
+    scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, ap_per_class)
+from utils.torch_utils import select_device, time_synchronized


 def test(data,
@@ -26,7 +39,7 @@ def test(data,
         device = next(model.parameters()).device  # get model device

     else:  # called directly
-        device = torch_utils.select_device(opt.device, batch_size=batch_size)
+        device = select_device(opt.device, batch_size=batch_size)
         merge, save_txt = opt.merge, opt.save_txt  # use Merge NMS, save *.txt labels
         if save_txt:
             out = Path('inference/output')
@@ -85,18 +98,18 @@ def test(data,
         # Disable gradients
         with torch.no_grad():
             # Run model
-            t = torch_utils.time_synchronized()
+            t = time_synchronized()
             inf_out, train_out = model(img, augment=augment)  # inference and training outputs
-            t0 += torch_utils.time_synchronized() - t
+            t0 += time_synchronized() - t

             # Compute loss
             if training:  # if model has loss hyperparameters
                 loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # GIoU, obj, cls

             # Run NMS
-            t = torch_utils.time_synchronized()
+            t = time_synchronized()
             output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
-            t1 += torch_utils.time_synchronized() - t
+            t1 += time_synchronized() - t

             # Statistics per image
             for si, pred in enumerate(output):
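
Most of the names pulled in from utils.general here are small box-format helpers. A quick sanity sketch of two used throughout this file, assuming the xyxy-to-center-xywh conversion their names imply:

import torch

from utils.general import xyxy2xywh, xywh2xyxy

box = torch.tensor([[100., 200., 300., 400.]])  # one box: x1, y1, x2, y2
assert torch.allclose(xywh2xyxy(xyxy2xywh(box)), box)  # round trip is lossless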

train.py

@@ -1,19 +1,32 @@
 import argparse
+import glob
+import math
+import os
+import time
+from pathlib import Path
+from random import random

+import numpy as np
 import torch.distributed as dist
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 import torch.utils.data
+import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm

 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
-from utils import google_utils
-from utils.datasets import *
-from utils.utils import *
+from utils.datasets import create_dataloader
+from utils.general import (
+    check_img_size, torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors,
+    labels_to_image_weights, compute_loss, plot_images, fitness, strip_optimizer, plot_results,
+    get_latest_run, check_git_status, check_file, increment_dir, print_mutation)
+from utils.google_utils import attempt_download
+from utils.torch_utils import init_seeds, ModelEMA, select_device

 # Hyperparameters
 hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
@@ -119,7 +132,7 @@ def train(hyp, opt, device, tb_writer=None):

     # Load Model
     with torch_distributed_zero_first(rank):
-        google_utils.attempt_download(weights)
+        attempt_download(weights)
     start_epoch, best_fitness = 0, 0.0
     if weights.endswith('.pt'):  # pytorch format
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
@@ -167,7 +180,7 @@ def train(hyp, opt, device, tb_writer=None):
         print('Using SyncBatchNorm()')

     # Exponential moving average
-    ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None
+    ema = ModelEMA(model) if rank in [-1, 0] else None

     # DDP mode
     if cuda and rank != -1:
@@ -438,7 +451,7 @@ if __name__ == '__main__':
         with open(opt.hyp) as f:
             hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
-    device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
+    device = select_device(opt.device, batch_size=opt.batch_size)
     opt.total_batch_size = opt.batch_size
     opt.world_size = 1
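
The download guard in the first train() hunk relies on torch_distributed_zero_first, whose signature appears in the utils/general.py diff below. A sketch of the pattern in DDP terms (rank semantics assumed: -1 means no DDP, 0 is the primary process):

from utils.general import torch_distributed_zero_first
from utils.google_utils import attempt_download

def fetch_weights(weights='yolov5s.pt', rank=-1):  # assumed checkpoint name
    with torch_distributed_zero_first(rank):  # non-primary ranks block here
        attempt_download(weights)             # so only one process downloads the checkpoint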

utils/datasets.py

@@ -14,7 +14,7 @@ from PIL import Image, ExifTags
 from torch.utils.data import Dataset
 from tqdm import tqdm

-from utils.utils import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
+from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first

 help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
 img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']

utils/general.py

@@ -18,10 +18,11 @@ import torch
 import torch.nn as nn
 import torchvision
 import yaml
+from scipy.cluster.vq import kmeans
 from scipy.signal import butter, filtfilt
 from tqdm import tqdm

-from . import torch_utils  # torch_utils, google_utils
+from utils.torch_utils import init_seeds, is_parallel

 # Set printoptions
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -47,7 +48,7 @@ def torch_distributed_zero_first(local_rank: int):
 def init_seeds(seed=0):
     random.seed(seed)
     np.random.seed(seed)
-    torch_utils.init_seeds(seed=seed)
+    init_seeds(seed=seed)


 def get_latest_run(search_dir='./runs'):
@@ -505,7 +506,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model

 def build_targets(p, targets, model):
     # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
-    det = model.module.model[-1] if torch_utils.is_parallel(model) else model.model[-1]  # Detect() module
+    det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
     na, nt = det.na, targets.shape[0]  # number of anchors, targets
     tcls, tbox, indices, anch = [], [], [], []
     gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
@@ -779,7 +780,6 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
     wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels

     # Kmeans calculation
-    from scipy.cluster.vq import kmeans
     print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
     s = wh.std(0)  # sigmas for whitening
     k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
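
Hoisting the scipy kmeans import to the top of the file does not change the anchor computation. The whiten-cluster-unwhiten step it feeds, sketched standalone with synthetic data (the random boxes are placeholders):

import numpy as np
from scipy.cluster.vq import kmeans

wh = np.random.uniform(10, 320, size=(1000, 2))  # placeholder width-height pairs in pixels
s = wh.std(0)                                    # sigmas for whitening, as above
k, dist = kmeans(wh / s, 9, iter=30)             # 9 cluster centers in whitened space
anchors = k * s                                  # un-whiten back to pixel units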