check_git_status() improvements (#1916)
* check_online() * Update general.py * update check_git_status() * reverse rev-parse order * fetch * improved responsiveness * comment * comment * remove hyp['giou'] compat warningpull/1917/head
parent
dd03b20ba5
commit
509dd51aca
7
train.py
7
train.py
|
@ -6,7 +6,6 @@ import random
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from warnings import warn
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
def train(hyp, opt, device, tb_writer=None, wandb=None):
|
||||||
logger.info(colorstr('Hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
||||||
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
|
||||||
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
|
||||||
|
|
||||||
|
@ -502,10 +501,6 @@ if __name__ == '__main__':
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
with open(opt.hyp) as f:
|
with open(opt.hyp) as f:
|
||||||
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
||||||
if 'box' not in hyp:
|
|
||||||
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
|
|
||||||
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
|
|
||||||
hyp['box'] = hyp.pop('giou')
|
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
logger.info(opt)
|
logger.info(opt)
|
||||||
|
|
|
@ -4,7 +4,6 @@ import glob
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import platform
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -35,6 +34,7 @@ def set_logging(rank=-1):
|
||||||
|
|
||||||
|
|
||||||
def init_seeds(seed=0):
|
def init_seeds(seed=0):
|
||||||
|
# Initialize random number generator (RNG) seeds
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
init_torch_seeds(seed)
|
init_torch_seeds(seed)
|
||||||
|
@ -46,12 +46,33 @@ def get_latest_run(search_dir='.'):
|
||||||
return max(last_list, key=os.path.getctime) if last_list else ''
|
return max(last_list, key=os.path.getctime) if last_list else ''
|
||||||
|
|
||||||
|
|
||||||
|
def check_online():
|
||||||
|
# Check internet connectivity
|
||||||
|
import socket
|
||||||
|
try:
|
||||||
|
socket.create_connection(("1.1.1.1", 53)) # check host accesability
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_git_status():
|
def check_git_status():
|
||||||
# Suggest 'git pull' if repo is out of date
|
# Suggest 'git pull' if YOLOv5 is out of date
|
||||||
if Path('.git').exists() and platform.system() in ['Linux', 'Darwin'] and not Path('/.dockerenv').is_file():
|
print(colorstr('github: '), end='')
|
||||||
s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
|
try:
|
||||||
if 'Your branch is behind' in s:
|
if Path('.git').exists() and check_online():
|
||||||
print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
|
url = subprocess.check_output(
|
||||||
|
'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
|
||||||
|
n = int(subprocess.check_output(
|
||||||
|
'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True)) # commits behind
|
||||||
|
if n > 0:
|
||||||
|
s = f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commmit'}. " \
|
||||||
|
f"Use 'git pull' to update or 'git clone {url}' to download latest."
|
||||||
|
else:
|
||||||
|
s = f'up to date with {url} ✅'
|
||||||
|
except Exception as e:
|
||||||
|
s = str(e)
|
||||||
|
print(s)
|
||||||
|
|
||||||
|
|
||||||
def check_requirements(file='requirements.txt'):
|
def check_requirements(file='requirements.txt'):
|
||||||
|
|
Loading…
Reference in New Issue