mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
110 lines
3.9 KiB
Python
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run classification inference on images

Usage:
    $ python classify/predict.py --weights yolov5s-cls.pt --source im.jpg
"""

import argparse
import os
import sys
from pathlib import Path

import cv2
import torch.nn.functional as F

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from classify.train import imshow_cls
from models.common import DetectMultiBackend
from utils.augmentations import classify_transforms
from utils.general import LOGGER, check_requirements, colorstr, increment_path, print_args
from utils.torch_utils import select_device, smart_inference_mode, time_sync


@smart_inference_mode()
def run(
        weights=ROOT / 'yolov5s-cls.pt',  # model.pt path(s)
        source=ROOT / 'data/images/bus.jpg',  # image file (read with cv2.imread)
        imgsz=224,  # inference size
        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        half=False,  # use FP16 half-precision inference
        dnn=False,  # use OpenCV DNN for ONNX inference
        show=True,  # plot the input image and save it to project/name
        project=ROOT / 'runs/predict-cls',  # save to project/name
        name='exp',  # save to project/name
        exist_ok=False,  # existing project/name ok, do not increment
):
    file = str(source)
    seen, dt = 1, [0.0, 0.0, 0.0]  # images seen, [pre-process, inference, post-process] times (s)
    device = select_device(device)

    # Directories
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    save_dir.mkdir(parents=True, exist_ok=True)  # make dir

    # Transforms
    transforms = classify_transforms(imgsz)

    # Load model
    model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
    model.warmup(imgsz=(1, 3, imgsz, imgsz))  # warmup

    # Image
    t1 = time_sync()
    im = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)  # BGR to RGB
    im = transforms(im).unsqueeze(0).to(device)  # add batch dimension
    im = im.half() if model.fp16 else im.float()
    t2 = time_sync()
    dt[0] += t2 - t1

    # Inference
    results = model(im)
    t3 = time_sync()
    dt[1] += t3 - t2

    p = F.softmax(results, dim=1)  # probabilities
    i = p.argsort(1, descending=True)[:, :5].squeeze()  # top 5 indices
    dt[2] += time_sync() - t3
    LOGGER.info(f"image 1/1 {file}: {imgsz}x{imgsz} {', '.join(f'{model.names[j]} {p[0, j]:.2f}' for j in i)}")

    # Print results
    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image (ms)
    shape = (1, 3, imgsz, imgsz)
    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
    if show:
        imshow_cls(im, f=save_dir / Path(file).name, verbose=True)
    LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
    return p


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
    parser.add_argument('--source', type=str, default=ROOT / 'data/images/bus.jpg', help='file')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='inference size (pixels)')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    parser.add_argument('--project', default=ROOT / 'runs/predict-cls', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
    print_args(vars(opt))
    return opt


def main(opt):
    check_requirements(exclude=('tensorboard', 'thop'))
    run(**vars(opt))


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)