fix compatibility for hyper config ()

* fix/hyper

* Hyp giou check to train.py

* restore general.py

* train.py overwrite fix

* restore general.py and pep8 update

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/1153/head
Jirka Borovec 2020-10-15 15:05:58 +02:00 committed by GitHub
parent 4d3680c81d
commit c67e72200e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import random
import shutil
import time
from pathlib import Path
from warnings import warn
import math
import numpy as np
@ -430,9 +431,8 @@ if __name__ == '__main__':
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
device = select_device(opt.device, batch_size=opt.batch_size)
# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
@ -441,11 +441,16 @@ if __name__ == '__main__':
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
opt.batch_size = opt.total_batch_size // opt.world_size
logger.info(opt)
# Hyperparameters
with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
if 'box' not in hyp:
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
hyp['box'] = hyp.pop('giou')
# Train
logger.info(opt)
if not opt.evolve:
tb_writer = None
if opt.global_rank in [-1, 0]:

View File

@ -1,18 +1,18 @@
import glob
import logging
import math
import os
import platform
import random
import re
import shutil
import subprocess
import time
import re
from contextlib import contextmanager
from copy import copy
from pathlib import Path
import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np