Unified '/project/name' results saving (#1377)

* Project/name update

* Update ci-testing.yml

* address project with path separator failure mode

* Project/name update

* address project with path separator failure mode

* Update ci-testing.yml

* detect.py default --name bug fix

* missing rstrip PR

* train/exp0 to train/exp
pull/1390/head
Glenn Jocher 2020-11-12 23:37:46 +01:00 committed by GitHub
parent 04081f8102
commit c4addd7761
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 71 additions and 78 deletions

View File

@ -66,10 +66,10 @@ jobs:
python train.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --cfg models/${{ matrix.model }}.yaml --epochs 1 --device $di
# detect
python detect.py --weights weights/${{ matrix.model }}.pt --device $di
python detect.py --weights runs/train/exp0/weights/last.pt --device $di
python detect.py --weights runs/train/exp/weights/last.pt --device $di
# test
python test.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --device $di
python test.py --img 256 --batch 8 --weights runs/train/exp0/weights/last.pt --device $di
python test.py --img 256 --batch 8 --weights runs/train/exp/weights/last.pt --device $di
python models/yolo.py --cfg models/${{ matrix.model }}.yaml # inspect
python models/export.py --img 256 --batch 1 --weights weights/${{ matrix.model }}.pt # export

View File

@ -96,7 +96,7 @@ Fusing layers...
Model Summary: 140 layers, 7.45958e+06 parameters, 0 gradients
image 1/2 data/images/bus.jpg: 640x480 4 persons, 1 buss, 1 skateboards, Done. (0.013s)
image 2/2 data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.013s)
Results saved to runs/detect/exp0
Results saved to runs/detect/exp
Done. (0.124s)
```
<img src="https://user-images.githubusercontent.com/26833433/97107365-685a8d80-16c7-11eb-8c2e-83aac701d8b9.jpeg" width="500">

View File

@ -10,21 +10,18 @@ from numpy import random
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \
plot_one_box, strip_optimizer, set_logging, increment_dir
plot_one_box, strip_optimizer, set_logging, increment_path
from utils.torch_utils import select_device, load_classifier, time_synchronized
def detect(save_img=False):
save_dir, source, weights, view_img, save_txt, imgsz = \
Path(opt.save_dir), opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
webcam = source.isnumeric() or source.endswith('.txt') or \
source.lower().startswith(('rtsp://', 'rtmp://', 'http://'))
# Directories
if save_dir == Path('runs/detect'): # if default
save_dir.mkdir(parents=True, exist_ok=True) # make base
save_dir = Path(increment_dir(save_dir / 'exp', opt.name)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make new dir
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# Initialize
set_logging()
@ -156,12 +153,13 @@ if __name__ == '__main__':
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-dir', type=str, default='runs/detect', help='directory to save results')
parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
opt = parser.parse_args()
print(opt)

16
test.py
View File

@ -13,7 +13,7 @@ from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, \
non_max_suppression, scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, \
ap_per_class, set_logging, increment_dir
ap_per_class, set_logging, increment_path
from utils.torch_utils import select_device, time_synchronized
@ -46,10 +46,8 @@ def test(data,
save_txt = opt.save_txt # save *.txt labels
# Directories
if save_dir == Path('runs/test'): # if default
save_dir.mkdir(parents=True, exist_ok=True) # make base
save_dir = Path(increment_dir(save_dir / 'exp', opt.name)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make new dir
save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# Load model
model = attempt_load(weights, map_location=device) # load FP32 model
@ -279,7 +277,6 @@ if __name__ == '__main__':
parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS')
parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
@ -287,8 +284,10 @@ if __name__ == '__main__':
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results')
parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}')
parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
parser.add_argument('--project', default='runs/test', help='save to project/name')
parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.data = check_file(opt.data) # check file
@ -305,7 +304,6 @@ if __name__ == '__main__':
opt.single_cls,
opt.augment,
opt.verbose,
save_dir=Path(opt.save_dir),
save_txt=opt.save_txt,
save_conf=opt.save_conf,
)

View File

@ -27,7 +27,7 @@ from utils.datasets import create_dataloader
from utils.general import (
torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging, init_seeds)
check_git_status, check_img_size, increment_path, print_mutation, plot_evolution, set_logging, init_seeds)
from utils.google_utils import attempt_download
from utils.torch_utils import ModelEMA, select_device, intersect_dicts
@ -36,19 +36,20 @@ logger = logging.getLogger(__name__)
def train(hyp, opt, device, tb_writer=None, wandb=None):
logger.info(f'Hyperparameters {hyp}')
log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
wdir = log_dir / 'weights' # weights directory
wdir.mkdir(parents=True, exist_ok=True)
save_dir, epochs, batch_size, total_batch_size, weights, rank = \
opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
# Directories
wdir = save_dir / 'weights'
wdir.mkdir(parents=True, exist_ok=True) # make dir
last = wdir / 'last.pt'
best = wdir / 'best.pt'
results_file = log_dir / 'results.txt'
epochs, batch_size, total_batch_size, weights, rank = \
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
results_file = save_dir / 'results.txt'
# Save run settings
with open(log_dir / 'hyp.yaml', 'w') as f:
with open(save_dir / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
with open(log_dir / 'opt.yaml', 'w') as f:
with open(save_dir / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)
# Configure
@ -120,8 +121,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Logging
if wandb and wandb.run is None:
id = ckpt.get('wandb_id') if 'ckpt' in locals() else None
wandb_run = wandb.init(config=opt, resume="allow", project="YOLOv5", name=log_dir.stem, id=id)
wandb_run = wandb.init(config=opt, resume="allow",
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
# Resume
start_epoch, best_fitness = 0, 0.0
@ -188,7 +191,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
plot_labels(labels, save_dir=log_dir)
plot_labels(labels, save_dir=save_dir)
if tb_writer:
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
tb_writer.add_histogram('classes', c, 0)
@ -215,7 +218,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
scaler = amp.GradScaler(enabled=cuda)
logger.info('Image sizes %g train, %g test\n'
'Using %g dataloader workers\nLogging results to %s\n'
'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))
'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
model.train()
@ -296,7 +299,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Plot
if ni < 3:
f = str(log_dir / f'train_batch{ni}.jpg') # filename
f = str(save_dir / f'train_batch{ni}.jpg') # filename
result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
# if tb_writer and result is not None:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
@ -321,7 +324,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=log_dir,
save_dir=save_dir,
plots=epoch == 0 or final_epoch, # plot first and last
log_imgs=opt.log_imgs if wandb else 0)
@ -369,7 +372,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if rank in [-1, 0]:
# Strip optimizers
n = opt.name if opt.name.isnumeric() else ''
fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
if f1.exists():
os.rename(f1, f2) # rename
@ -378,7 +381,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
# Finish
if not opt.evolve:
plot_results(save_dir=log_dir) # save as results.png
plot_results(save_dir=save_dir) # save as results.png
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
dist.destroy_process_group() if rank not in [-1, 0] else None
@ -410,11 +413,11 @@ if __name__ == '__main__':
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--logdir', type=str, default='runs/train', help='logging directory')
parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}')
parser.add_argument('--log-imgs', type=int, default=10, help='number of images for W&B logging, max 100')
parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
parser.add_argument('--project', default='runs/train', help='save to project/name')
parser.add_argument('--name', default='exp', help='save to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
opt = parser.parse_args()
# Set DDP variables
@ -428,19 +431,19 @@ if __name__ == '__main__':
# Resume
if opt.resume: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
log_dir = Path(ckpt).parent.parent # runs/train/exp0
opt.save_dir = Path(ckpt).parent.parent # runs/train/exp
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(log_dir / 'opt.yaml') as f:
with open(opt.save_dir / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
opt.cfg, opt.weights, opt.resume = '', ckpt, True
logger.info('Resuming training from %s' % ckpt)
else:
# opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
opt.name = 'evolve' if opt.evolve else opt.name
opt.save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run
# DDP mode
device = select_device(opt.device, batch_size=opt.batch_size)
@ -466,8 +469,8 @@ if __name__ == '__main__':
tb_writer, wandb = None, None # init loggers
if opt.global_rank in [-1, 0]:
# Tensorboard
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.logdir}", view at http://localhost:6006/')
tb_writer = SummaryWriter(log_dir=log_dir) # runs/train/exp0
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
tb_writer = SummaryWriter(opt.save_dir) # runs/train/exp
# W&B
try:
@ -514,7 +517,7 @@ if __name__ == '__main__':
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
opt.notest, opt.nosave = True, True # only test/save final epoch
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
yaml_file = Path(opt.logdir) / 'evolve' / 'hyp_evolved.yaml' # save best result here
yaml_file = opt.save_dir / 'hyp_evolved.yaml' # save best result here
if opt.bucket:
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists

26
tutorial.ipynb vendored
View File

@ -597,7 +597,7 @@
},
"source": [
"!python detect.py --weights yolov5s.pt --img 640 --conf 0.25 --source data/images/\n",
"Image(filename='runs/detect/exp0/zidane.jpg', width=600)"
"Image(filename='runs/detect/exp/zidane.jpg', width=600)"
],
"execution_count": null,
"outputs": [
@ -611,7 +611,7 @@
"Model Summary: 140 layers, 7.45958e+06 parameters, 0 gradients\n",
"image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 buss, 1 skateboards, Done. (0.012s)\n",
"image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.012s)\n",
"Results saved to runs/detect/exp0\n",
"Results saved to runs/detect/exp\n",
"Done. (0.113s)\n"
],
"name": "stdout"
@ -887,7 +887,7 @@
"source": [
"Train a YOLOv5s model on [COCO128](https://www.kaggle.com/ultralytics/coco128) with `--data coco128.yaml`, starting from pretrained `--weights yolov5s.pt`, or from randomly initialized `--weights '' --cfg yolov5s.yaml`. Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases), and **COCO, COCO128, and VOC datasets are downloaded automatically** on first use.\n",
"\n",
"All training results are saved to `runs/train/` with incrementing run directories, i.e. `runs/train/exp0`, `runs/train/exp1` etc.\n"
"All training results are saved to `runs/train/` with incrementing run directories, i.e. `runs/train/exp2`, `runs/train/exp3` etc.\n"
]
},
{
@ -969,7 +969,7 @@
"Analyzing anchors... anchors/target = 4.26, Best Possible Recall (BPR) = 0.9946\n",
"Image sizes 640 train, 640 test\n",
"Using 2 dataloader workers\n",
"Logging results to runs/train/exp0\n",
"Logging results to runs/train/exp\n",
"Starting training for 3 epochs...\n",
"\n",
" Epoch gpu_mem box obj cls total targets img_size\n",
@ -986,8 +986,8 @@
" 2/2 3.17G 0.04445 0.06545 0.01666 0.1266 149 640: 100% 8/8 [00:01<00:00, 4.33it/s]\n",
" Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 8/8 [00:02<00:00, 2.78it/s]\n",
" all 128 929 0.395 0.766 0.701 0.455\n",
"Optimizer stripped from runs/train/exp0/weights/last.pt, 15.2MB\n",
"Optimizer stripped from runs/train/exp0/weights/best.pt, 15.2MB\n",
"Optimizer stripped from runs/train/exp/weights/last.pt, 15.2MB\n",
"Optimizer stripped from runs/train/exp/weights/best.pt, 15.2MB\n",
"3 epochs completed in 0.005 hours.\n",
"\n"
],
@ -1030,7 +1030,7 @@
"source": [
"## Local Logging\n",
"\n",
"All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp0`, `runs/train/exp1`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)."
"All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp2`, `runs/train/exp3`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)."
]
},
{
@ -1039,9 +1039,9 @@
"id": "riPdhraOTCO0"
},
"source": [
"Image(filename='runs/train/exp0/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n",
"Image(filename='runs/train/exp0/test_batch0_labels.jpg', width=800) # test batch 0 labels\n",
"Image(filename='runs/train/exp0/test_batch0_pred.jpg', width=800) # test batch 0 predictions"
"Image(filename='runs/train/exp/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n",
"Image(filename='runs/train/exp/test_batch0_labels.jpg', width=800) # test batch 0 labels\n",
"Image(filename='runs/train/exp/test_batch0_pred.jpg', width=800) # test batch 0 predictions"
],
"execution_count": null,
"outputs": []
@ -1078,7 +1078,7 @@
},
"source": [
"from utils.utils import plot_results \n",
"plot_results(save_dir='runs/train/exp0') # plot results.txt as results.png\n",
"plot_results(save_dir='runs/train/exp') # plot results.txt as results.png\n",
"Image(filename='results.png', width=800) "
],
"execution_count": null,
@ -1170,9 +1170,9 @@
" for di in 0 cpu # inference devices\n",
" do\n",
" python detect.py --weights $x.pt --device $di # detect official\n",
" python detect.py --weights runs/train/exp0/weights/last.pt --device $di # detect custom\n",
" python detect.py --weights runs/train/exp/weights/last.pt --device $di # detect custom\n",
" python test.py --weights $x.pt --device $di # test official\n",
" python test.py --weights runs/train/exp0/weights/last.pt --device $di # test custom\n",
" python test.py --weights runs/train/exp/weights/last.pt --device $di # test custom\n",
" done\n",
" python models/yolo.py --cfg $x.yaml # inspect\n",
" python models/export.py --weights $x.pt --img 640 --batch 1 # export\n",

View File

@ -60,7 +60,7 @@ def init_seeds(seed=0):
init_torch_seeds(seed)
def get_latest_run(search_dir='./runs'):
def get_latest_run(search_dir='.'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ''
@ -951,23 +951,17 @@ def output_to_target(output, width, height):
return np.array(targets)
def increment_dir(dir, comment=''):
# Increments a directory runs/exp1 --> runs/exp2_comment
n = 0 # number
dir = str(Path(dir)) # os-agnostic
if os.path.isdir(dir):
stem = ''
dir += os.sep # removed by Path
def increment_path(path, exist_ok=True, sep=''):
# Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
path = Path(path) # os-agnostic
if (path.exists() and exist_ok) or (not path.exists()):
return str(path)
else:
stem = Path(dir).stem
dirs = sorted(glob.glob(dir + '*')) # directories
if dirs:
matches = [re.search(r"%s(\d+)" % stem, d) for d in dirs]
idxs = [int(m.groups()[0]) for m in matches if m]
if idxs:
n = max(idxs) + 1 # increment
return dir + str(n) + ('_' + comment if comment else '')
dirs = glob.glob(f"{path}{sep}*") # similar paths
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
i = [int(m.groups()[0]) for m in matches if m] # indices
n = max(i) + 1 if i else 2 # increment number
return f"{path}{sep}{n}" # update path
# Plotting functions ---------------------------------------------------------------------------------------------------