From c4addd7761a4627f6fa97604ad9f1124f439b774 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 12 Nov 2020 23:37:46 +0100 Subject: [PATCH] Unified '/project/name' results saving (#1377) * Project/name update * Update ci-testing.yml * address project with path separator failure mode * Project/name update * address project with path separator failure mode * Update ci-testing.yml * detect.py default --name bug fix * missing rstrip PR * train/exp0 to train/exp --- .github/workflows/ci-testing.yml | 4 +-- README.md | 2 +- detect.py | 16 ++++----- test.py | 16 ++++----- train.py | 57 +++++++++++++++++--------------- tutorial.ipynb | 26 +++++++-------- utils/general.py | 28 ++++++---------- 7 files changed, 71 insertions(+), 78 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 3dddc7c5e..faae95ab1 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -66,10 +66,10 @@ jobs: python train.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --cfg models/${{ matrix.model }}.yaml --epochs 1 --device $di # detect python detect.py --weights weights/${{ matrix.model }}.pt --device $di - python detect.py --weights runs/train/exp0/weights/last.pt --device $di + python detect.py --weights runs/train/exp/weights/last.pt --device $di # test python test.py --img 256 --batch 8 --weights weights/${{ matrix.model }}.pt --device $di - python test.py --img 256 --batch 8 --weights runs/train/exp0/weights/last.pt --device $di + python test.py --img 256 --batch 8 --weights runs/train/exp/weights/last.pt --device $di python models/yolo.py --cfg models/${{ matrix.model }}.yaml # inspect python models/export.py --img 256 --batch 1 --weights weights/${{ matrix.model }}.pt # export diff --git a/README.md b/README.md index 95a82ab8c..cc978c1d4 100755 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Fusing layers... Model Summary: 140 layers, 7.45958e+06 parameters, 0 gradients image 1/2 data/images/bus.jpg: 640x480 4 persons, 1 buss, 1 skateboards, Done. (0.013s) image 2/2 data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.013s) -Results saved to runs/detect/exp0 +Results saved to runs/detect/exp Done. (0.124s) ``` diff --git a/detect.py b/detect.py index fed737e35..50e5c3cbd 100644 --- a/detect.py +++ b/detect.py @@ -10,21 +10,18 @@ from numpy import random from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \ - plot_one_box, strip_optimizer, set_logging, increment_dir + plot_one_box, strip_optimizer, set_logging, increment_path from utils.torch_utils import select_device, load_classifier, time_synchronized def detect(save_img=False): - save_dir, source, weights, view_img, save_txt, imgsz = \ - Path(opt.save_dir), opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size + source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size webcam = source.isnumeric() or source.endswith('.txt') or \ source.lower().startswith(('rtsp://', 'rtmp://', 'http://')) # Directories - if save_dir == Path('runs/detect'): # if default - save_dir.mkdir(parents=True, exist_ok=True) # make base - save_dir = Path(increment_dir(save_dir / 'exp', opt.name)) # increment run - (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make new dir + save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Initialize set_logging() @@ -156,12 +153,13 @@ if __name__ == '__main__': parser.add_argument('--view-img', action='store_true', help='display results') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-dir', type=str, default='runs/detect', help='directory to save results') - parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}') parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--augment', action='store_true', help='augmented inference') parser.add_argument('--update', action='store_true', help='update all models') + parser.add_argument('--project', default='runs/detect', help='save results to project/name') + parser.add_argument('--name', default='exp', help='save results to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() print(opt) diff --git a/test.py b/test.py index cef13f7f7..a3c5df8d1 100644 --- a/test.py +++ b/test.py @@ -13,7 +13,7 @@ from models.experimental import attempt_load from utils.datasets import create_dataloader from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, compute_loss, \ non_max_suppression, scale_coords, xyxy2xywh, clip_coords, plot_images, xywh2xyxy, box_iou, output_to_target, \ - ap_per_class, set_logging, increment_dir + ap_per_class, set_logging, increment_path from utils.torch_utils import select_device, time_synchronized @@ -46,10 +46,8 @@ def test(data, save_txt = opt.save_txt # save *.txt labels # Directories - if save_dir == Path('runs/test'): # if default - save_dir.mkdir(parents=True, exist_ok=True) # make base - save_dir = Path(increment_dir(save_dir / 'exp', opt.name)) # increment run - (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make new dir + save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run + (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir # Load model model = attempt_load(weights, map_location=device) # load FP32 model @@ -279,7 +277,6 @@ if __name__ == '__main__': parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold') parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS') - parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') parser.add_argument('--task', default='val', help="'val', 'test', 'study'") parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') @@ -287,8 +284,10 @@ if __name__ == '__main__': parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') - parser.add_argument('--save-dir', type=str, default='runs/test', help='directory to save results') - parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}') + parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') + parser.add_argument('--project', default='runs/test', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() opt.save_json |= opt.data.endswith('coco.yaml') opt.data = check_file(opt.data) # check file @@ -305,7 +304,6 @@ if __name__ == '__main__': opt.single_cls, opt.augment, opt.verbose, - save_dir=Path(opt.save_dir), save_txt=opt.save_txt, save_conf=opt.save_conf, ) diff --git a/train.py b/train.py index bfa1e035d..a2bf59f3c 100644 --- a/train.py +++ b/train.py @@ -27,7 +27,7 @@ from utils.datasets import create_dataloader from utils.general import ( torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights, compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file, - check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging, init_seeds) + check_git_status, check_img_size, increment_path, print_mutation, plot_evolution, set_logging, init_seeds) from utils.google_utils import attempt_download from utils.torch_utils import ModelEMA, select_device, intersect_dicts @@ -36,19 +36,20 @@ logger = logging.getLogger(__name__) def train(hyp, opt, device, tb_writer=None, wandb=None): logger.info(f'Hyperparameters {hyp}') - log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory - wdir = log_dir / 'weights' # weights directory - wdir.mkdir(parents=True, exist_ok=True) + save_dir, epochs, batch_size, total_batch_size, weights, rank = \ + opt.save_dir, opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank + + # Directories + wdir = save_dir / 'weights' + wdir.mkdir(parents=True, exist_ok=True) # make dir last = wdir / 'last.pt' best = wdir / 'best.pt' - results_file = log_dir / 'results.txt' - epochs, batch_size, total_batch_size, weights, rank = \ - opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank + results_file = save_dir / 'results.txt' # Save run settings - with open(log_dir / 'hyp.yaml', 'w') as f: + with open(save_dir / 'hyp.yaml', 'w') as f: yaml.dump(hyp, f, sort_keys=False) - with open(log_dir / 'opt.yaml', 'w') as f: + with open(save_dir / 'opt.yaml', 'w') as f: yaml.dump(vars(opt), f, sort_keys=False) # Configure @@ -120,8 +121,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Logging if wandb and wandb.run is None: - id = ckpt.get('wandb_id') if 'ckpt' in locals() else None - wandb_run = wandb.init(config=opt, resume="allow", project="YOLOv5", name=log_dir.stem, id=id) + wandb_run = wandb.init(config=opt, resume="allow", + project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem, + name=save_dir.stem, + id=ckpt.get('wandb_id') if 'ckpt' in locals() else None) # Resume start_epoch, best_fitness = 0, 0.0 @@ -188,7 +191,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): c = torch.tensor(labels[:, 0]) # classes # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) - plot_labels(labels, save_dir=log_dir) + plot_labels(labels, save_dir=save_dir) if tb_writer: # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384 tb_writer.add_histogram('classes', c, 0) @@ -215,7 +218,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): scaler = amp.GradScaler(enabled=cuda) logger.info('Image sizes %g train, %g test\n' 'Using %g dataloader workers\nLogging results to %s\n' - 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs)) + 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs)) for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ model.train() @@ -296,7 +299,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Plot if ni < 3: - f = str(log_dir / f'train_batch{ni}.jpg') # filename + f = str(save_dir / f'train_batch{ni}.jpg') # filename result = plot_images(images=imgs, targets=targets, paths=paths, fname=f) # if tb_writer and result is not None: # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) @@ -321,7 +324,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): model=ema.ema, single_cls=opt.single_cls, dataloader=testloader, - save_dir=log_dir, + save_dir=save_dir, plots=epoch == 0 or final_epoch, # plot first and last log_imgs=opt.log_imgs if wandb else 0) @@ -369,7 +372,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if rank in [-1, 0]: # Strip optimizers n = opt.name if opt.name.isnumeric() else '' - fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt' + fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt' for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]): if f1.exists(): os.rename(f1, f2) # rename @@ -378,7 +381,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload # Finish if not opt.evolve: - plot_results(save_dir=log_dir) # save as results.png + plot_results(save_dir=save_dir) # save as results.png logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) dist.destroy_process_group() if rank not in [-1, 0] else None @@ -410,11 +413,11 @@ if __name__ == '__main__': parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') - parser.add_argument('--logdir', type=str, default='runs/train', help='logging directory') - parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}') parser.add_argument('--log-imgs', type=int, default=10, help='number of images for W&B logging, max 100') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') - + parser.add_argument('--project', default='runs/train', help='save to project/name') + parser.add_argument('--name', default='exp', help='save to project/name') + parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() # Set DDP variables @@ -428,19 +431,19 @@ if __name__ == '__main__': # Resume if opt.resume: # resume an interrupted run ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path - log_dir = Path(ckpt).parent.parent # runs/train/exp0 + opt.save_dir = Path(ckpt).parent.parent # runs/train/exp assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' - with open(log_dir / 'opt.yaml') as f: + with open(opt.save_dir / 'opt.yaml') as f: opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace opt.cfg, opt.weights, opt.resume = '', ckpt, True logger.info('Resuming training from %s' % ckpt) - else: # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified' opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test) - log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1 + opt.name = 'evolve' if opt.evolve else opt.name + opt.save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run # DDP mode device = select_device(opt.device, batch_size=opt.batch_size) @@ -466,8 +469,8 @@ if __name__ == '__main__': tb_writer, wandb = None, None # init loggers if opt.global_rank in [-1, 0]: # Tensorboard - logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.logdir}", view at http://localhost:6006/') - tb_writer = SummaryWriter(log_dir=log_dir) # runs/train/exp0 + logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/') + tb_writer = SummaryWriter(opt.save_dir) # runs/train/exp # W&B try: @@ -514,7 +517,7 @@ if __name__ == '__main__': assert opt.local_rank == -1, 'DDP mode not implemented for --evolve' opt.notest, opt.nosave = True, True # only test/save final epoch # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices - yaml_file = Path(opt.logdir) / 'evolve' / 'hyp_evolved.yaml' # save best result here + yaml_file = opt.save_dir / 'hyp_evolved.yaml' # save best result here if opt.bucket: os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists diff --git a/tutorial.ipynb b/tutorial.ipynb index d09c9e2e2..42fbf578e 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -597,7 +597,7 @@ }, "source": [ "!python detect.py --weights yolov5s.pt --img 640 --conf 0.25 --source data/images/\n", - "Image(filename='runs/detect/exp0/zidane.jpg', width=600)" + "Image(filename='runs/detect/exp/zidane.jpg', width=600)" ], "execution_count": null, "outputs": [ @@ -611,7 +611,7 @@ "Model Summary: 140 layers, 7.45958e+06 parameters, 0 gradients\n", "image 1/2 /content/yolov5/data/images/bus.jpg: 640x480 4 persons, 1 buss, 1 skateboards, Done. (0.012s)\n", "image 2/2 /content/yolov5/data/images/zidane.jpg: 384x640 2 persons, 2 ties, Done. (0.012s)\n", - "Results saved to runs/detect/exp0\n", + "Results saved to runs/detect/exp\n", "Done. (0.113s)\n" ], "name": "stdout" @@ -887,7 +887,7 @@ "source": [ "Train a YOLOv5s model on [COCO128](https://www.kaggle.com/ultralytics/coco128) with `--data coco128.yaml`, starting from pretrained `--weights yolov5s.pt`, or from randomly initialized `--weights '' --cfg yolov5s.yaml`. Models are downloaded automatically from the [latest YOLOv5 release](https://github.com/ultralytics/yolov5/releases), and **COCO, COCO128, and VOC datasets are downloaded automatically** on first use.\n", "\n", - "All training results are saved to `runs/train/` with incrementing run directories, i.e. `runs/train/exp0`, `runs/train/exp1` etc.\n" + "All training results are saved to `runs/train/` with incrementing run directories, i.e. `runs/train/exp2`, `runs/train/exp3` etc.\n" ] }, { @@ -969,7 +969,7 @@ "Analyzing anchors... anchors/target = 4.26, Best Possible Recall (BPR) = 0.9946\n", "Image sizes 640 train, 640 test\n", "Using 2 dataloader workers\n", - "Logging results to runs/train/exp0\n", + "Logging results to runs/train/exp\n", "Starting training for 3 epochs...\n", "\n", " Epoch gpu_mem box obj cls total targets img_size\n", @@ -986,8 +986,8 @@ " 2/2 3.17G 0.04445 0.06545 0.01666 0.1266 149 640: 100% 8/8 [00:01<00:00, 4.33it/s]\n", " Class Images Targets P R mAP@.5 mAP@.5:.95: 100% 8/8 [00:02<00:00, 2.78it/s]\n", " all 128 929 0.395 0.766 0.701 0.455\n", - "Optimizer stripped from runs/train/exp0/weights/last.pt, 15.2MB\n", - "Optimizer stripped from runs/train/exp0/weights/best.pt, 15.2MB\n", + "Optimizer stripped from runs/train/exp/weights/last.pt, 15.2MB\n", + "Optimizer stripped from runs/train/exp/weights/best.pt, 15.2MB\n", "3 epochs completed in 0.005 hours.\n", "\n" ], @@ -1030,7 +1030,7 @@ "source": [ "## Local Logging\n", "\n", - "All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp0`, `runs/train/exp1`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)." + "All results are logged by default to `runs/train`, with a new experiment directory created for each new training as `runs/train/exp2`, `runs/train/exp3`, etc. View train and test jpgs to see mosaics, labels, predictions and augmentation effects. Note a **Mosaic Dataloader** is used for training (shown below), a new concept developed by Ultralytics and first featured in [YOLOv4](https://arxiv.org/abs/2004.10934)." ] }, { @@ -1039,9 +1039,9 @@ "id": "riPdhraOTCO0" }, "source": [ - "Image(filename='runs/train/exp0/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n", - "Image(filename='runs/train/exp0/test_batch0_labels.jpg', width=800) # test batch 0 labels\n", - "Image(filename='runs/train/exp0/test_batch0_pred.jpg', width=800) # test batch 0 predictions" + "Image(filename='runs/train/exp/train_batch0.jpg', width=800) # train batch 0 mosaics and labels\n", + "Image(filename='runs/train/exp/test_batch0_labels.jpg', width=800) # test batch 0 labels\n", + "Image(filename='runs/train/exp/test_batch0_pred.jpg', width=800) # test batch 0 predictions" ], "execution_count": null, "outputs": [] @@ -1078,7 +1078,7 @@ }, "source": [ "from utils.utils import plot_results \n", - "plot_results(save_dir='runs/train/exp0') # plot results.txt as results.png\n", + "plot_results(save_dir='runs/train/exp') # plot results.txt as results.png\n", "Image(filename='results.png', width=800) " ], "execution_count": null, @@ -1170,9 +1170,9 @@ " for di in 0 cpu # inference devices\n", " do\n", " python detect.py --weights $x.pt --device $di # detect official\n", - " python detect.py --weights runs/train/exp0/weights/last.pt --device $di # detect custom\n", + " python detect.py --weights runs/train/exp/weights/last.pt --device $di # detect custom\n", " python test.py --weights $x.pt --device $di # test official\n", - " python test.py --weights runs/train/exp0/weights/last.pt --device $di # test custom\n", + " python test.py --weights runs/train/exp/weights/last.pt --device $di # test custom\n", " done\n", " python models/yolo.py --cfg $x.yaml # inspect\n", " python models/export.py --weights $x.pt --img 640 --batch 1 # export\n", diff --git a/utils/general.py b/utils/general.py index 499d58285..5b2bbefaa 100755 --- a/utils/general.py +++ b/utils/general.py @@ -60,7 +60,7 @@ def init_seeds(seed=0): init_torch_seeds(seed) -def get_latest_run(search_dir='./runs'): +def get_latest_run(search_dir='.'): # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) return max(last_list, key=os.path.getctime) if last_list else '' @@ -951,23 +951,17 @@ def output_to_target(output, width, height): return np.array(targets) -def increment_dir(dir, comment=''): - # Increments a directory runs/exp1 --> runs/exp2_comment - n = 0 # number - dir = str(Path(dir)) # os-agnostic - if os.path.isdir(dir): - stem = '' - dir += os.sep # removed by Path +def increment_path(path, exist_ok=True, sep=''): + # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc. + path = Path(path) # os-agnostic + if (path.exists() and exist_ok) or (not path.exists()): + return str(path) else: - stem = Path(dir).stem - - dirs = sorted(glob.glob(dir + '*')) # directories - if dirs: - matches = [re.search(r"%s(\d+)" % stem, d) for d in dirs] - idxs = [int(m.groups()[0]) for m in matches if m] - if idxs: - n = max(idxs) + 1 # increment - return dir + str(n) + ('_' + comment if comment else '') + dirs = glob.glob(f"{path}{sep}*") # similar paths + matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] + i = [int(m.groups()[0]) for m in matches if m] # indices + n = max(i) + 1 if i else 2 # increment number + return f"{path}{sep}{n}" # update path # Plotting functions ---------------------------------------------------------------------------------------------------