mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
* Update hubconf.py with check_requirements() Dependency checks have been missing from YOLOv5 PyTorch Hub model loading, causing errors in some cases when users are attempting to import hub models in unsupported environments. This should examine the YOLOv5 requirements.txt file and pip install any missing or version-conflict packages encountered. This is highly experimental (!), please let us know if this creates problems in your custom workflows. * Update hubconf.py
150 lines
5.4 KiB
Python
150 lines
5.4 KiB
Python
"""File for accessing YOLOv5 models via PyTorch Hub https://pytorch.org/hub/ultralytics_yolov5/
|
|
|
|
Usage:
|
|
import torch
|
|
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from models.yolo import Model
|
|
from utils.general import check_requirements, set_logging
|
|
from utils.google_utils import attempt_download
|
|
from utils.torch_utils import select_device
|
|
|
|
# PyTorch Hub convention: package names that must be importable before entry points in this file are loaded
dependencies = ['torch', 'yaml']

# Packages left out of the requirements check below
_REQUIREMENTS_EXCLUDE = ('pycocotools', 'thop')

# Verify (and, per the upstream change note, pip install) missing or version-conflicting requirements at import time
# NOTE(review): confirm check_requirements() still finds requirements.txt when the working directory is not the repo
check_requirements(exclude=_REQUIREMENTS_EXCLUDE)
set_logging()
|
|
|
|
|
|
def create(name, pretrained, channels, classes, autoshape):
    """Creates a specified YOLOv5 model

    Arguments:
        name (str): name of model, i.e. 'yolov5s'
        pretrained (bool): load pretrained weights into the model
        channels (int): number of input channels
        classes (int): number of model classes
        autoshape (bool): wrap the model with .autoshape() for file/URI/PIL/cv2/np inputs and NMS

    Returns:
        pytorch model, moved to the device picked by select_device ('0' if CUDA is available, else 'cpu')

    Raises:
        Exception: on any failure while building or loading the model; chained to the original error and
            carrying a hint to retry with force_reload=True
    """
    config = Path(__file__).parent / 'models' / f'{name}.yaml'  # model.yaml path
    try:
        model = Model(config, channels, classes)
        if pretrained:
            fname = f'{name}.pt'  # checkpoint filename
            attempt_download(fname)  # download if not found locally
            ckpt = torch.load(fname, map_location=torch.device('cpu'))  # load
            ckpt_state = ckpt['model'].float().state_dict()  # checkpoint weights as FP32
            model_state = model.state_dict()  # built once instead of once per checkpoint key
            # Keep only tensors whose shape matches the freshly built model (e.g. heads differ when classes != 80).
            # A checkpoint key that the model lacks still raises KeyError here and is reported by the handler below.
            state_dict = {k: v for k, v in ckpt_state.items() if model_state[k].shape == v.shape}  # filter
            model.load_state_dict(state_dict, strict=False)  # load
            if len(ckpt['model'].names) == classes:
                model.names = ckpt['model'].names  # set class names attribute
            # NOTE(review): nesting reconstructed from flattened source - confirm autoshape is meant to be
            # applied only when pretrained weights were loaded
            if autoshape:
                model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
        device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
        return model.to(device)

    except Exception as e:
        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
        s = f'Cache may be out of date, try force_reload=True. See {help_url} for help.'
        raise Exception(s) from e
|
|
|
|
|
|
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-small model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): forwarded to create() to request the .autoshape() wrapper, default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5s',
                  pretrained=pretrained,
                  channels=channels,
                  classes=classes,
                  autoshape=autoshape)
|
|
|
|
|
|
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-medium model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): forwarded to create() to request the .autoshape() wrapper, default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5m',
                  pretrained=pretrained,
                  channels=channels,
                  classes=classes,
                  autoshape=autoshape)
|
|
|
|
|
|
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-large model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): forwarded to create() to request the .autoshape() wrapper, default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5l',
                  pretrained=pretrained,
                  channels=channels,
                  classes=classes,
                  autoshape=autoshape)
|
|
|
|
|
|
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True):
    """YOLOv5-xlarge model from https://github.com/ultralytics/yolov5

    Arguments:
        pretrained (bool): load pretrained weights into the model, default=True
        channels (int): number of input channels, default=3
        classes (int): number of model classes, default=80
        autoshape (bool): forwarded to create() to request the .autoshape() wrapper, default=True

    Returns:
        pytorch model
    """
    return create(name='yolov5x',
                  pretrained=pretrained,
                  channels=channels,
                  classes=classes,
                  autoshape=autoshape)
|
|
|
|
|
|
def custom(path_or_model='path/to/model.pt', autoshape=True):
    """YOLOv5-custom model from https://github.com/ultralytics/yolov5

    Arguments (4 options for path_or_model):
        path_or_model (str): 'path/to/model.pt'
        path_or_model (pathlib.Path): Path('path/to/model.pt')
        path_or_model (dict): torch.load('path/to/model.pt')
        path_or_model (nn.Module): torch.load('path/to/model.pt')['model']
        autoshape (bool): return the model wrapped by .autoshape() instead of the bare model, default=True

    Returns:
        pytorch model
    """
    # Anything path-like is read from disk; dicts and modules are used as passed in
    is_path = isinstance(path_or_model, (str, Path))
    model = torch.load(path_or_model) if is_path else path_or_model  # load checkpoint
    if isinstance(model, dict):
        # use the 'ema' entry when the checkpoint holds a non-empty one, otherwise fall back to 'model'
        model = model['ema' if model.get('ema') else 'model']  # load model

    source_device = next(model.parameters()).device  # keep the new model on the same device as the loaded one
    hub_model = Model(model.yaml).to(source_device)  # create
    hub_model.load_state_dict(model.float().state_dict())  # load state_dict (source model is cast to FP32)
    hub_model.names = model.names  # class names
    return hub_model.autoshape() if autoshape else hub_model
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a hub model, then run one batched inference over every supported input kind
    model = create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True)  # pretrained example
    # model = custom(path_or_model='path/to/model.pt')  # custom example

    # Verify inference
    import numpy as np
    from PIL import Image

    pil_input = Image.open('data/images/bus.jpg')  # PIL
    file_input = 'data/images/zidane.jpg'  # filename
    uri_input = 'https://github.com/ultralytics/yolov5/raw/master/data/images/bus.jpg'  # URI
    array_input = np.zeros((640, 480, 3))  # numpy
    imgs = [pil_input, file_input, uri_input, array_input]

    results = model(imgs)  # batched inference
    results.print()
    results.save()
|