check_git_status() improvements (#1916)

* check_online()

* Update general.py

* update check_git_status()

* reverse rev-parse order

* fetch

* improved responsiveness

* comment

* comment

* remove hyp['giou'] compat warning
pull/1917/head
Glenn Jocher 2021-01-12 21:51:49 -08:00 committed by GitHub
parent dd03b20ba5
commit 509dd51aca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 12 deletions

View File

@ -6,7 +6,6 @@ import random
import time import time
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from warnings import warn
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
def train(hyp, opt, device, tb_writer=None, wandb=None): def train(hyp, opt, device, tb_writer=None, wandb=None):
logger.info(colorstr('Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank = \ save_dir, epochs, batch_size, total_batch_size, weights, rank = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
@ -502,10 +501,6 @@ if __name__ == '__main__':
# Hyperparameters # Hyperparameters
with open(opt.hyp) as f: with open(opt.hyp) as f:
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
if 'box' not in hyp:
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
hyp['box'] = hyp.pop('giou')
# Train # Train
logger.info(opt) logger.info(opt)

View File

@ -4,7 +4,6 @@ import glob
import logging import logging
import math import math
import os import os
import platform
import random import random
import re import re
import subprocess import subprocess
@ -35,6 +34,7 @@ def set_logging(rank=-1):
def init_seeds(seed=0): def init_seeds(seed=0):
# Initialize random number generator (RNG) seeds
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
init_torch_seeds(seed) init_torch_seeds(seed)
@ -46,12 +46,33 @@ def get_latest_run(search_dir='.'):
return max(last_list, key=os.path.getctime) if last_list else '' return max(last_list, key=os.path.getctime) if last_list else ''
def check_online():
# Check internet connectivity
import socket
try:
socket.create_connection(("1.1.1.1", 53)) # check host accesability
return True
except OSError:
return False
def check_git_status(): def check_git_status():
# Suggest 'git pull' if repo is out of date # Suggest 'git pull' if YOLOv5 is out of date
if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file(): print(colorstr('github: '), end='')
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') try:
if 'Your branch is behind' in s: if Path('.git').exists() and check_online():
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') url = subprocess.check_output(
'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
n = int(subprocess.check_output(
'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True)) # commits behind
if n > 0:
s = f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commmit'}. " \
f"Use 'git pull' to update or 'git clone {url}' to download latest."
else:
s = f'up to date with {url}'
except Exception as e:
s = str(e)
print(s)
def check_requirements(file='requirements.txt'): def check_requirements(file='requirements.txt'):