Mirror of https://github.com/ultralytics/yolov5.git (synced 2025-06-03 14:49:29 +08:00)

Commit 01a67a9742: Merge remote-tracking branch 'origin/master'
@@ -31,7 +31,7 @@ if __name__ == '__main__':
     # TorchScript export
     try:
         print('\nStarting TorchScript export with torch %s...' % torch.__version__)
-        f = opt.weights.replace('.pt', '.torchscript')  # filename
+        f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
         ts = torch.jit.trace(model, img)
         ts.save(f)
         print('TorchScript export success, saved as %s' % f)
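Note: this hunk only renames the TorchScript output from *.torchscript to *.torchscript.pt; the surrounding export flow is unchanged. As a minimal, self-contained sketch of that trace-and-save flow (the SimpleModel stand-in and the yolov5s.pt filename are illustrative assumptions, not repository code):

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    # Stand-in for the loaded YOLOv5 model used in the real script
    def forward(self, x):
        return x * 2

weights = 'yolov5s.pt'                         # hypothetical checkpoint name
model = SimpleModel().eval()
img = torch.zeros(1, 3, 640, 640)              # dummy input (batch, channels, height, width)

f = weights.replace('.pt', '.torchscript.pt')  # output filename, matching the new suffix above
ts = torch.jit.trace(model, img)               # record the forward pass on the dummy input
ts.save(f)                                     # serialized TorchScript module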
@@ -62,7 +62,7 @@ if __name__ == '__main__':

         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
         # convert model from torchscript and apply pixel scaling as per detect.py
-        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])])
+        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
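The CoreML change is purely cosmetic spacing around the division. The scale and bias arguments matter, though: they bake the 0-255 to 0-1 pixel normalization into the converted model so it matches detect.py preprocessing. A hedged sketch of that conversion, assuming coremltools is installed and using a toy convolution in place of the real model:

import coremltools as ct
import torch
import torch.nn as nn

model = nn.Conv2d(3, 16, 3, padding=1).eval()  # toy stand-in for the YOLOv5 model
img = torch.zeros(1, 3, 640, 640)              # dummy image batch
ts = torch.jit.trace(model, img)               # CoreML converts from the traced module

# Inside the CoreML model each input pixel becomes x * scale + bias, i.e. x / 255
mlmodel = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape,
                                              scale=1 / 255.0, bias=[0, 0, 0])])
mlmodel.save('model.mlmodel')                  # illustrative output path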

train.py (13 changed lines)
@@ -44,7 +44,7 @@ hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD

 def train(hyp):
     print(f'Hyperparameters {hyp}')
-    log_dir = tb_writer.log_dir  # run directory
+    log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'  # run directory
     wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory

     os.makedirs(wdir, exist_ok=True)
@@ -387,7 +387,10 @@ if __name__ == '__main__':
     opt.weights = last if opt.resume and not opt.weights else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
-    opt.hyp = check_file(opt.hyp) if opt.hyp else ''  # check file
+    if opt.hyp:  # update hyps
+        opt.hyp = check_file(opt.hyp)  # check file
+        with open(opt.hyp) as f:
+            hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
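The added block means a user-supplied --hyp YAML now overrides the default hyperparameter dict before train() is called. A minimal sketch of that merge, with made-up values and a hypothetical hyp.custom.yaml file:

import yaml

hyp = {'optimizer': 'SGD', 'lr0': 0.01, 'momentum': 0.937}  # defaults (illustrative values)

with open('hyp.custom.yaml') as f:                    # hypothetical file, e.g. containing "lr0: 0.001"
    hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # user values override matching defaults

print(hyp)  # e.g. {'optimizer': 'SGD', 'lr0': 0.001, 'momentum': 0.937}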
@@ -396,12 +399,8 @@ if __name__ == '__main__':

     # Train
     if not opt.evolve:
-        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
-        if opt.hyp:  # update hyps
-            with open(opt.hyp) as f:
-                hyp.update(yaml.load(f, Loader=yaml.FullLoader))
-
+        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         train(hyp)

     # Evolve hyperparameters (optional)
utils/datasets.py

@@ -26,6 +26,11 @@ for orientation in ExifTags.TAGS.keys():
         break


+def get_hash(files):
+    # Returns a single hash value of a list of files
+    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
+
+
 def exif_size(img):
     # Returns exif-corrected PIL size
     s = img.size  # (width, height)
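get_hash() is deliberately cheap: it sums file sizes instead of hashing file contents, so the label cache is rebuilt whenever any image or label file is added, removed, or changes size. A small sketch of how it supports cache invalidation (file names are placeholders):

import os

def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

files = ['images/img1.jpg', 'labels/img1.txt']  # placeholder paths
h = get_hash(files)  # total byte size of the files that exist
# The loader stores h in the cache under 'hash'; if the recomputed value
# differs on the next run, the dataset changed and the cache is rebuilt.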
@@ -280,7 +285,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0):
         try:
-            f = []
+            f = []  # image files
             for p in path if isinstance(path, list) else [path]:
                 p = str(Path(p))  # os-agnostic
                 parent = str(Path(p).parent) + os.sep
@@ -292,7 +297,6 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     f += glob.iglob(p + os.sep + '*.*')
                 else:
                     raise Exception('%s does not exist' % p)
-            path = p  # *.npy dir
             self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
         except Exception as e:
             raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
@@ -314,20 +318,22 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.stride = stride

         # Define labels
-        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
-                            for x in self.img_files]
+        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
+                            self.img_files]

-        # Read image shapes (wh)
-        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
-        try:
-            with open(sp, 'r') as f:  # read existing shapefile
-                s = [x.split() for x in f.read().splitlines()]
-                assert len(s) == n, 'Shapefile out of sync'
-        except:
-            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
-            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)
-
-        self.shapes = np.array(s, dtype=np.float64)
+        # Check cache
+        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
+        if os.path.isfile(cache_path):
+            cache = torch.load(cache_path)  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
+                cache = self.cache_labels(cache_path)  # re-cache
+        else:
+            cache = self.cache_labels(cache_path)  # cache
+
+        # Get labels
+        labels, shapes = zip(*[cache[x] for x in self.img_files])
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.labels = list(labels)

         # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
         if self.rect:
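The cache consumed here is an ordinary dict produced by cache_labels() (shown further down) and saved with torch.save: one entry per image path mapping to [labels, (width, height)], plus a 'hash' key. A sketch of how the __init__ code above unpacks it, with a hand-built cache standing in for a real one:

import numpy as np

# Hand-built stand-in for the dict that cache_labels() would produce
cache = {
    'images/img1.jpg': [np.array([[0, 0.5, 0.5, 0.2, 0.2]], dtype=np.float32), (640, 480)],
    'images/img2.jpg': [np.zeros((0, 5), dtype=np.float32), (1280, 720)],  # image with no objects
    'hash': 123456,  # sum of file sizes when the cache was written
}
img_files = ['images/img1.jpg', 'images/img2.jpg']

labels, shapes = zip(*[cache[x] for x in img_files])  # per-image label arrays and (w, h) pairs
shapes = np.array(shapes, dtype=np.float64)
labels = list(labels)
print(shapes.shape, len(labels))  # (2, 2) 2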
@@ -337,6 +343,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             irect = ar.argsort()
             self.img_files = [self.img_files[i] for i in irect]
             self.label_files = [self.label_files[i] for i in irect]
+            self.labels = [self.labels[i] for i in irect]
             self.shapes = s[irect]  # wh
             ar = ar[irect]

@@ -353,33 +360,11 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

         # Cache labels
-        self.imgs = [None] * n
-        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
-        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
-        if os.path.isfile(np_labels_path):
-            s = np_labels_path  # print string
-            x = np.load(np_labels_path, allow_pickle=True)
-            if len(x) == n:
-                self.labels = x
-                labels_loaded = True
-        else:
-            s = path.replace('images', 'labels')
-
         pbar = tqdm(self.label_files)
         for i, file in enumerate(pbar):
-            if labels_loaded:
-                l = self.labels[i]
-                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
-            else:
-                try:
-                    with open(file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
-                except:
-                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
-                    continue
-
+            l = self.labels[i]  # label
             if l.shape[0]:
                 assert l.shape[1] == 5, '> 5 label columns: %s' % file
                 assert (l >= 0).all(), 'negative labels: %s' % file
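For context on the asserts kept in this hunk: each label file holds one row per object in the normalized YOLO format (class x_center y_center width height), so every row must have exactly 5 non-negative columns; the surrounding loop also checks that coordinates stay within [0, 1]. A tiny sketch parsing a made-up label file body:

import numpy as np

lines = ['0 0.481 0.634 0.690 0.713',   # class x_center y_center width height (normalized)
         '45 0.339 0.418 0.122 0.318']  # made-up rows in YOLO label format
l = np.array([x.split() for x in lines], dtype=np.float32)

assert l.shape[1] == 5, '> 5 label columns'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinates'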
@@ -425,15 +410,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

-            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
-                s, nf, nm, ne, nd, n)
-        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
-        if not labels_loaded and n > 1000:
-            print('Saving labels to %s for faster future loading' % np_labels_path)
-            np.save(np_labels_path, self.labels)  # save for next time
+            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
+                cache_path, nf, nm, ne, nd, n)
+        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)

         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        if cache_images:  # if training
+        self.imgs = [None] * n
+        if cache_images:
             gb = 0  # Gigabytes of cached images
             pbar = tqdm(range(len(self.img_files)), desc='Caching images')
             self.img_hw0, self.img_hw = [None] * n, [None] * n
@@ -442,15 +425,30 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 gb += self.imgs[i].nbytes
                 pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

-        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
-        detect_corrupted_images = False
-        if detect_corrupted_images:
-            from skimage import io  # conda install -c conda-forge scikit-image
-            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
-                try:
-                    _ = io.imread(file)
-                except:
-                    print('Corrupted image detected: %s' % file)
+    def cache_labels(self, path='labels.cache'):
+        # Cache dataset labels, check images and read shapes
+        x = {}  # dict
+        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
+        for (img, label) in pbar:
+            try:
+                l = []
+                image = Image.open(img)
+                image.verify()  # PIL verify
+                # _ = io.imread(img)  # skimage verify (from skimage import io)
+                shape = exif_size(image)  # image size
+                if os.path.isfile(label):
+                    with open(label, 'r') as f:
+                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
+                if len(l) == 0:
+                    l = np.zeros((0, 5), dtype=np.float32)
+                x[img] = [l, shape]
+            except Exception as e:
+                x[img] = None
+                print('WARNING: %s: %s' % (img, e))
+
+        x['hash'] = get_hash(self.label_files + self.img_files)
+        torch.save(x, path)  # save for next time
+        return x

     def __len__(self):
         return len(self.img_files)
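Because the cache is a plain dict written with torch.save, it can also be inspected outside the Dataset class. A hedged sketch (the cache path is an assumption; on disk it sits next to the label folder with a .cache suffix):

import torch

cache_path = '../coco128/labels/train2017.cache'  # hypothetical path produced by cache_labels()
cache = torch.load(cache_path)
n_images = len(cache) - 1  # every key except 'hash' is an image path
print('cached labels for %g images, hash %g' % (n_images, cache['hash']))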
utils/utils.py

@@ -45,7 +45,7 @@ def get_latest_run(search_dir='./runs'):

 def check_git_status():
     # Suggest 'git pull' if repo is out of date
-    if platform in ['linux', 'darwin']:
+    if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'):
         s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
         if 'Your branch is behind' in s:
             print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
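The extra condition skips the update check inside Docker containers, where /.dockerenv conventionally exists at the filesystem root but the repository's .git directory is usually absent from the image. A minimal sketch of the same guard:

import os
from sys import platform

def running_in_docker():
    # Docker creates /.dockerenv at the container root; its presence is a cheap heuristic
    return os.path.isfile('/.dockerenv')

if platform in ['linux', 'darwin'] and not running_in_docker():
    print('would run git fetch / git status here')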
@@ -636,14 +636,12 @@ def strip_optimizer(f='weights/best.pt'):  # from utils.utils import *; strip_optimizer()
     x['optimizer'] = None
     x['model'].half()  # to FP16
     torch.save(x, f)
-    print('Optimizer stripped from %s' % f)
+    print('Optimizer stripped from %s, %.1fMB' % (f, os.path.getsize(f) / 1E6))


 def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from utils.utils import *; create_pretrained()
     # create pretrained checkpoint 's' from 'f' (create_pretrained(x, x) for x in glob.glob('./*.pt'))
-    device = torch.device('cpu')
-    x = torch.load(s, map_location=device)
+    x = torch.load(f, map_location=torch.device('cpu'))

     x['optimizer'] = None
     x['training_results'] = None
     x['epoch'] = -1
@@ -651,7 +649,7 @@ def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from utils.utils import *; create_pretrained()
     for p in x['model'].parameters():
         p.requires_grad = True
     torch.save(x, s)
-    print('%s saved as pretrained checkpoint %s' % (f, s))
+    print('%s saved as pretrained checkpoint %s, %.1fMB' % (f, s, os.path.getsize(s) / 1E6))


 def coco_class_count(path='../coco/labels/train2014/'):
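Both print statements now also report the size of the file just written, which makes the effect of dropping the optimizer state and halving to FP16 visible. A sketch of the pattern on a toy checkpoint (the path and dict keys are illustrative; real YOLOv5 checkpoints also carry keys such as 'epoch' and 'training_results'):

import os
import torch
import torch.nn as nn

f = 'checkpoint.pt'  # hypothetical path
x = {'model': nn.Linear(4, 2), 'optimizer': {'state': 'placeholder'}}  # toy checkpoint dict

x['optimizer'] = None  # drop optimizer state to shrink the file
x['model'].half()      # FP16 weights halve the remaining model size
torch.save(x, f)
print('Optimizer stripped from %s, %.1fMB' % (f, os.path.getsize(f) / 1E6))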