Add suffix checks (#4711)

* Add suffix checks

* Cleanup

* Cleanup2

* Cleanup3
pull/4712/head
Glenn Jocher 2021-09-08 14:36:12 +02:00 committed by GitHub
parent 8e5f9ddbdb
commit a2b3c71636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 41 additions and 21 deletions

View File

@ -21,8 +21,9 @@ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, colorstr, is_ascii, non_max_suppression, \
apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
from utils.general import check_img_size, check_imshow, check_requirements, check_suffix, colorstr, is_ascii, \
non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, \
save_one_box
from utils.plots import Annotator, colors
from utils.torch_utils import select_device, load_classifier, time_sync
@ -68,8 +69,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
# Load model
w = weights[0] if isinstance(weights, list) else weights
classify, suffix = False, Path(w).suffix.lower()
pt, onnx, tflite, pb, saved_model = (suffix == x for x in ['.pt', '.onnx', '.tflite', '.pb', '']) # backend
classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes) # backend booleans
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
if pt:
model = attempt_load(weights, map_location=device) # load FP32 model

View File

@ -53,7 +53,7 @@ from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, C
from models.experimental import MixConv2d, CrossConv, attempt_load
from models.yolo import Detect
from utils.datasets import LoadImages
from utils.general import make_divisible, check_file, check_dataset
from utils.general import check_dataset, check_yaml, make_divisible
logger = logging.getLogger(__name__)
@ -447,7 +447,7 @@ if __name__ == "__main__":
parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
parser.add_argument('--score-thres', type=float, default=0.4, help='score threshold for NMS')
opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file
opt.cfg = check_yaml(opt.cfg) # check YAML
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
@ -534,7 +534,7 @@ if __name__ == "__main__":
if opt.tfl_int8:
# Representative Dataset
if opt.source.endswith('.yaml'):
with open(check_file(opt.source)) as f:
with open(check_yaml(opt.source)) as f:
data = yaml.load(f, Loader=yaml.FullLoader) # data dict
check_dataset(data) # check
opt.source = data['train']

View File

@ -17,10 +17,10 @@ sys.path.append(FILE.parents[1].as_posix()) # add yolov5/ to path
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging
from utils.general import check_yaml, make_divisible, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import time_sync, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr
from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
select_device, time_sync
try:
import thop # for FLOPs computation
@ -281,7 +281,7 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--profile', action='store_true', help='profile model speed')
opt = parser.parse_args()
opt.cfg = check_file(opt.cfg) # check file
opt.cfg = check_yaml(opt.cfg) # check YAML
set_logging()
device = select_device(opt.device)

View File

@ -35,8 +35,8 @@ from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
strip_optimizer, get_latest_run, check_dataset, check_git_status, check_img_size, check_requirements, \
check_yaml, check_suffix, print_mutation, set_logging, one_cycle, colorstr, methods
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
@ -484,7 +484,8 @@ def main(opt, callbacks=Callbacks()):
opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate
LOGGER.info(f'Resuming training from {ckpt}')
else:
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
check_suffix(opt.weights, '.pt') # check weights
opt.data, opt.cfg, opt.hyp = check_yaml(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp) # check YAMLs
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
if opt.evolve:
opt.project = 'runs/evolve'

View File

@ -26,8 +26,8 @@ from torch.utils.data import Dataset
from tqdm import tqdm
from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
xyn2xy, segments2boxes, clean_str
from utils.general import check_dataset, check_requirements, check_yaml, clean_str, segments2boxes, \
xywh2xyxy, xywhn2xyxy, xyxy2xywhn, xyn2xy
from utils.torch_utils import torch_distributed_zero_first
# Parameters
@ -938,7 +938,7 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profil
im.save(im_dir / Path(f).name, quality=75) # save
zipped, data_dir, yaml_path = unzip(Path(path))
with open(check_file(yaml_path), errors='ignore') as f:
with open(check_yaml(yaml_path), errors='ignore') as f:
data = yaml.safe_load(f) # data dict
if zipped:
data['path'] = data_dir # TODO: should this be dir.resolve()?

View File

@ -242,8 +242,23 @@ def check_imshow():
return False
def check_file(file):
def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
# Check file(s) for acceptable suffixes
if any(suffix):
if isinstance(suffix, str):
suffix = [suffix]
for f in file if isinstance(file, (list, tuple)) else [file]:
assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_yaml(file, suffix=('.yaml', '.yml')):
# Check YAML file(s) for acceptable suffixes
return check_file(file, suffix)
def check_file(file, suffix=''):
# Search/download file (if necessary) and return path
check_suffix(file, suffix)
file = str(file) # convert to str()
if Path(file).is_file() or file == '': # exists
return file

8
val.py
View File

@ -22,8 +22,9 @@ sys.path.append(FILE.parents[0].as_posix()) # add yolov5/ to path
from models.experimental import attempt_load
from utils.datasets import create_dataloader
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
from utils.general import coco80_to_coco91_class, check_dataset, check_img_size, check_requirements, \
check_suffix, check_yaml, box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \
increment_path, colorstr
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_sync
@ -116,6 +117,7 @@ def run(data,
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# Load model
check_suffix(weights, '.pt')
model = attempt_load(weights, map_location=device) # load FP32 model
gs = max(int(model.stride.max()), 32) # grid size (max stride)
imgsz = check_img_size(imgsz, s=gs) # check image size
@ -316,7 +318,7 @@ def parse_opt():
opt = parser.parse_args()
opt.save_json |= opt.data.endswith('coco.yaml')
opt.save_txt |= opt.save_hybrid
opt.data = check_file(opt.data) # check file
opt.data = check_yaml(opt.data) # check YAML
return opt