fix compatibility for hyper config (#1146)
* fix/hyper * Hyp giou check to train.py * restore general.py * train.py overwrite fix * restore general.py and pep8 update Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/1153/head
parent
4d3680c81d
commit
c67e72200e
utils
11
train.py
11
train.py
|
@ -5,6 +5,7 @@ import random
|
|||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
@ -430,9 +431,8 @@ if __name__ == '__main__':
|
|||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
|
||||
|
||||
device = select_device(opt.device, batch_size=opt.batch_size)
|
||||
|
||||
# DDP mode
|
||||
device = select_device(opt.device, batch_size=opt.batch_size)
|
||||
if opt.local_rank != -1:
|
||||
assert torch.cuda.device_count() > opt.local_rank
|
||||
torch.cuda.set_device(opt.local_rank)
|
||||
|
@ -441,11 +441,16 @@ if __name__ == '__main__':
|
|||
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
||||
opt.batch_size = opt.total_batch_size // opt.world_size
|
||||
|
||||
logger.info(opt)
|
||||
# Hyperparameters
|
||||
with open(opt.hyp) as f:
|
||||
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
||||
if 'box' not in hyp:
|
||||
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
|
||||
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
|
||||
hyp['box'] = hyp.pop('giou')
|
||||
|
||||
# Train
|
||||
logger.info(opt)
|
||||
if not opt.evolve:
|
||||
tb_writer = None
|
||||
if opt.global_rank in [-1, 0]:
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import math
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
|
Loading…
Reference in New Issue