Add hubconf.py argparser (#8799)

* Add hubconf.py argparser

* Add hubconf.py argparser
pull/8800/head
Glenn Jocher 2022-07-30 21:00:28 +02:00 committed by GitHub
parent e34ae8837b
commit 9111246208
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 6 deletions

View File

@ -106,7 +106,7 @@ jobs:
# Detect
python detect.py --weights $model.pt --device $d
python detect.py --weights $best --device $d
python hubconf.py # hub
python hubconf.py --model $model # hub
# Export
# python models/tf.py --weights $model.pt # build TF model
python models/yolo.py --cfg $model.yaml # build PyTorch model

View File

@ -41,7 +41,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
try:
device = select_device(device)
if pretrained and channels == 3 and classes == 80:
model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model
# model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model
@ -123,10 +122,7 @@ def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, _verbose=T
if __name__ == '__main__':
model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
# model = custom(path='path/to/model.pt') # custom
# Verify inference
import argparse
from pathlib import Path
import numpy as np
@ -134,6 +130,16 @@ if __name__ == '__main__':
from utils.general import cv2
# Argparser
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='yolov5s', help='model name')
opt = parser.parse_args()
# Model
model = _create(name=opt.model, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
# model = custom(path='path/to/model.pt') # custom
# Images
imgs = [
'data/images/zidane.jpg', # filename
Path('data/images/zidane.jpg'), # Path
@ -142,6 +148,9 @@ if __name__ == '__main__':
Image.open('data/images/bus.jpg'), # PIL
np.zeros((320, 640, 3))] # numpy
# Inference
results = model(imgs, size=320) # batched inference
# Results
results.print()
results.save()