fix compatibility for hyper config (#1146)
* fix/hyper * Hyp giou check to train.py * restore general.py * train.py overwrite fix * restore general.py and pep8 update Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/1153/head
parent
4d3680c81d
commit
c67e72200e
11
train.py
11
train.py
|
@ -5,6 +5,7 @@ import random
|
||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -430,9 +431,8 @@ if __name__ == '__main__':
|
||||||
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
||||||
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
|
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
|
||||||
|
|
||||||
device = select_device(opt.device, batch_size=opt.batch_size)
|
|
||||||
|
|
||||||
# DDP mode
|
# DDP mode
|
||||||
|
device = select_device(opt.device, batch_size=opt.batch_size)
|
||||||
if opt.local_rank != -1:
|
if opt.local_rank != -1:
|
||||||
assert torch.cuda.device_count() > opt.local_rank
|
assert torch.cuda.device_count() > opt.local_rank
|
||||||
torch.cuda.set_device(opt.local_rank)
|
torch.cuda.set_device(opt.local_rank)
|
||||||
|
@ -441,11 +441,16 @@ if __name__ == '__main__':
|
||||||
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
||||||
opt.batch_size = opt.total_batch_size // opt.world_size
|
opt.batch_size = opt.total_batch_size // opt.world_size
|
||||||
|
|
||||||
logger.info(opt)
|
# Hyperparameters
|
||||||
with open(opt.hyp) as f:
|
with open(opt.hyp) as f:
|
||||||
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
||||||
|
if 'box' not in hyp:
|
||||||
|
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
|
||||||
|
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
|
||||||
|
hyp['box'] = hyp.pop('giou')
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
|
logger.info(opt)
|
||||||
if not opt.evolve:
|
if not opt.evolve:
|
||||||
tb_writer = None
|
tb_writer = None
|
||||||
if opt.global_rank in [-1, 0]:
|
if opt.global_rank in [-1, 0]:
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import re
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import math
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
Loading…
Reference in New Issue