From c67e72200e8b5caab6b2b2fa560cdb647b46f004 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 15 Oct 2020 15:05:58 +0200 Subject: [PATCH] fix compatibility for hyper config (#1146) * fix/hyper * Hyp giou check to train.py * restore general.py * train.py overwrite fix * restore general.py and pep8 update Co-authored-by: Glenn Jocher --- train.py | 11 ++++++++--- utils/general.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 80d9eec1e..42774b880 100644 --- a/train.py +++ b/train.py @@ -5,6 +5,7 @@ import random import shutil import time from pathlib import Path +from warnings import warn import math import numpy as np @@ -430,9 +431,8 @@ if __name__ == '__main__': opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1 - device = select_device(opt.device, batch_size=opt.batch_size) - # DDP mode + device = select_device(opt.device, batch_size=opt.batch_size) if opt.local_rank != -1: assert torch.cuda.device_count() > opt.local_rank torch.cuda.set_device(opt.local_rank) @@ -441,11 +441,16 @@ if __name__ == '__main__': assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count' opt.batch_size = opt.total_batch_size // opt.world_size - logger.info(opt) + # Hyperparameters with open(opt.hyp) as f: hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps + if 'box' not in hyp: + warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' % + (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120')) + hyp['box'] = hyp.pop('giou') # Train + logger.info(opt) if not opt.evolve: tb_writer = None if opt.global_rank in [-1, 0]: diff --git a/utils/general.py b/utils/general.py index 3513d65cb..f8415feef 100755 --- a/utils/general.py +++ b/utils/general.py @@ -1,18 +1,18 @@ import glob import logging -import math import os import platform import random +import re import shutil import subprocess import time -import re from contextlib import contextmanager from copy import copy from pathlib import Path import cv2 +import math import matplotlib import matplotlib.pyplot as plt import numpy as np