mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Improve doctstring, add usage examples
This commit is contained in:
parent
76259b15b0
commit
136640eee8
@ -1,5 +1,33 @@
|
||||
# YOLOv5 classifier training
|
||||
# Usage: python classifier.py --model yolov5s --data mnist --epochs 10 --img 128
|
||||
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||
"""
|
||||
Train a YOLOv5 classifier model on a classification dataset
|
||||
|
||||
Usage-train:
|
||||
$ python path/to/classifier.py --model yolov5s --data mnist --epochs 5 --img 128
|
||||
|
||||
Usage-inference:
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Functions
|
||||
resize = torch.nn.Upsample(size=(128, 128), mode='bilinear', align_corners=False)
|
||||
normalize = lambda x, mean=0.5, std=0.25: (x - mean) / std
|
||||
|
||||
# Model
|
||||
model = torch.load('runs/train/exp2/weights/best.pt')['model'].cpu().float()
|
||||
|
||||
# Image
|
||||
im = cv2.imread('../mnist/test/0/10.png')[::-1] # HWC, BGR to RGB
|
||||
im = np.ascontiguousarray(np.asarray(im).transpose((2, 0, 1))) # HWC to CHW
|
||||
im = torch.tensor(im).unsqueeze(0) / 255.0 # to Tensor, to BCWH, rescale
|
||||
im = resize(normalize(im))
|
||||
|
||||
# Inference
|
||||
results = model(im)
|
||||
p = F.softmax(results, dim=1) # probabilities
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
Loading…
x
Reference in New Issue
Block a user