Update classifier.py

This commit is contained in:
Glenn Jocher 2021-02-10 13:29:23 -08:00 committed by GitHub
parent b34e21b97b
commit 04fddf507f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -68,7 +68,7 @@ def train():
# Model
if opt.model.startswith('yolov5'):
# YOLOv5 Classifier
model = torch.hub.load('ultralytics/yolov5', opt.model, pretrained=True)
model = torch.hub.load('ultralytics/yolov5', opt.model, pretrained=True, autoshape=False)
model.model = model.model[:8]
m = model.model[-1] # last layer
ch = m.conv.in_channels if hasattr(m, 'conv') else sum([x.in_channels for x in m.m]) # ch into module