yolov5/classifier.py

313 lines
13 KiB
Python
Raw Normal View History

2021-10-09 15:22:12 -07:00
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Train a YOLOv5 classifier model on a classification dataset

Usage-train:
    $ python path/to/classifier.py --model yolov5s --data mnist --epochs 5 --img 128 --adam

Usage-inference:
    from classifier import *

    model = torch.load('path/to/best.pt', map_location=torch.device('cpu'))['model'].float()
    files = Path('../datasets/mnist/test/7').glob('*.png')  # images from dir
    for f in list(files)[:10]:  # first 10 images
        classify(model, size=128, file=f)
"""
2021-03-09 15:04:04 -08:00
2020-12-07 05:27:01 +01:00
import argparse
2021-03-09 15:00:22 -08:00
import math
2020-12-07 05:27:01 +01:00
import os
2021-12-11 13:00:42 +01:00
import sys
2021-03-09 15:00:22 -08:00
from copy import deepcopy
2021-12-19 18:17:52 +01:00
from datetime import datetime
2020-12-07 05:27:01 +01:00
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.transforms as T
from torch.cuda import amp
from tqdm import tqdm
2021-12-11 13:00:42 +01:00
# Locate this script on disk and make its directory importable, so that the
# `models` and `utils` packages that live next to it can be imported below.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory (the directory containing this file)
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))  # add ROOT to PATH
# Re-express ROOT relative to the current working directory
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
2021-12-08 16:15:31 +01:00
from models.common import Classify, DetectMultiBackend
2021-12-19 16:34:56 +01:00
from utils.general import NUM_THREADS, download, check_file, increment_path, check_git_status, check_requirements, \
colorstr
2021-12-19 18:17:52 +01:00
from utils.torch_utils import model_info, select_device, de_parallel
# Functions
def normalize(x, mean=0.5, std=0.25):
    """Return ``x`` shifted by ``mean`` and divided by ``std``."""
    return (x - mean) / std


def denormalize(x, mean=0.5, std=0.25):
    """Return ``x`` multiplied by ``std`` and shifted by ``mean`` (inverse of ``normalize``)."""
    return x * std + mean
2020-12-07 05:27:01 +01:00
def train():
    """Train a classification model and save checkpoints plus example image grids.

    Takes no arguments. Configuration is read from the module-level ``opt`` namespace
    (``save_dir``, ``data``, ``batch_size``, ``epochs``, ``workers``, ``img_size``, ``model``,
    ``adam``, ``nosave``); tensors are placed on the module-level ``device`` and batches are
    resized with the module-level ``resize`` callable. All three are created in the
    ``__main__`` section of this file.

    Side effects:
        * creates ``<save_dir>/weights`` and writes ``last.pt`` / ``best.pt`` checkpoints there
        * downloads ``<data>.zip`` next to the dataset directory when that directory is missing
        * writes ``train_images.jpg`` and, after the final epoch, ``test_images.jpg`` to ``save_dir``
    """
    save_dir = Path(opt.save_dir)
    data, bs, epochs, imgsz = opt.data, opt.batch_size, opt.epochs, opt.img_size
    nw = min(NUM_THREADS, opt.workers)  # number of dataloader workers

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last, best = wdir / 'last.pt', wdir / 'best.pt'

    # Download Dataset
    data_dir = FILE.parents[1] / 'datasets' / data
    if not data_dir.is_dir():
        url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
        download(url, dir=data_dir.parent)

    # Transforms
    trainform = T.Compose([T.RandomGrayscale(p=0.01),
                           T.RandomHorizontalFlip(p=0.5),
                           T.RandomAffine(degrees=1, translate=(.2, .2), scale=(1 / 1.5, 1.5),
                                          shear=(-1, 1, -1, 1), fill=(114, 114, 114)),
                           # T.Resize([imgsz, imgsz]),  # very slow
                           T.ToTensor(),
                           T.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))])  # PILImage from [0, 1] to [-2, 2]
    testform = T.Compose(trainform.transforms[-2:])  # ToTensor + Normalize only, no augmentation

    # Dataloaders
    trainset = torchvision.datasets.ImageFolder(root=data_dir / 'train', transform=trainform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=nw)
    testset = torchvision.datasets.ImageFolder(root=data_dir / 'test', transform=testform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=bs, shuffle=True, num_workers=nw)
    names = trainset.classes
    nc = len(names)
    print(f'Training {opt.model} on {data} dataset with {nc} classes...')

    # Show images
    # next(iter(...)) is the portable way to pull one batch; DataLoader iterators have no .next() method
    images, labels = next(iter(trainloader))
    imshow(denormalize(images[:64]), labels[:64], names=names, f=save_dir / 'train_images.jpg')

    # Model
    if opt.model.startswith('yolov5'):
        # YOLOv5 Classifier
        model = torch.hub.load('ultralytics/yolov5', opt.model, pretrained=True, autoshape=False)
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:8]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, 'conv') else sum(x.in_channels for x in m.m)  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = c  # replace
        for p in model.parameters():
            p.requires_grad = True  # for training
    elif opt.model in torch.hub.list('rwightman/gen-efficientnet-pytorch'):  # i.e. efficientnet_b0
        model = torch.hub.load('rwightman/gen-efficientnet-pytorch', opt.model, pretrained=True)
        model.classifier = nn.Linear(model.classifier.in_features, nc)
    else:  # try torchvision
        model = torchvision.models.__dict__[opt.model](pretrained=True)
        model.fc = nn.Linear(model.fc.weight.shape[1], nc)
    # print(model)  # debug
    model_info(model)

    # Optimizer
    lr0 = 0.0001 * bs  # initial lr
    lrf = 0.01  # final lr (fraction of lr0)
    if opt.adam:
        optimizer = optim.Adam(model.parameters(), lr=lr0 / 10)
    else:
        optimizer = optim.SGD(model.parameters(), lr=lr0, momentum=0.9, nesterov=True)

    # Scheduler
    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1,
    #                                    final_div_factor=1 / 25 / lrf)

    # Train
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()  # loss function
    best_fitness = 0.0
    # scaler = amp.GradScaler(enabled=cuda)
    print(f'Image sizes {imgsz} train, {imgsz} test\n'
          f'Using {nw} dataloader workers\n'
          f"Logging results to {colorstr('bold', save_dir)}\n"
          f'Starting training for {epochs} epochs...\n\n'
          f"{'epoch':10s}{'gpu_mem':10s}{'train_loss':12s}{'val_loss':12s}{'accuracy':12s}")
    for epoch in range(epochs):  # loop over the dataset multiple times
        mloss = 0.0  # running sum of batch losses (divided by batch count for display)
        progress = f'{epoch + 1}/{epochs}'
        model.train()
        pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
        for i, (images, labels) in pbar:  # progress bar
            images, labels = resize(images.to(device)), labels.to(device)

            # Forward
            with amp.autocast(enabled=False):  # stability issues when enabled
                loss = criterion(model(images), labels)

            # Backward
            loss.backward()  # scaler.scale(loss).backward()

            # Optimize
            optimizer.step()  # scaler.step(optimizer); scaler.update()
            optimizer.zero_grad()

            # Print
            mloss += loss.item()
            mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
            pbar.desc = f"{progress:10s}{mem:10s}{mloss / (i + 1):<12.3g}"

            # Test
            # Run on the last batch, inside the loop, so test() can append its columns to the
            # progress-bar description before the bar is closed.
            if i == len(pbar) - 1:
                fitness = test(model, testloader, names, criterion, pbar=pbar)  # test

        # Scheduler
        scheduler.step()

        # Best fitness
        if fitness > best_fitness:
            best_fitness = fitness

        # Save model
        final_epoch = epoch + 1 == epochs
        if (not opt.nosave) or final_epoch:
            ckpt = {'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'optimizer': None,  # optimizer.state_dict()
                    'date': datetime.now().isoformat()}

            # Save last, best and delete
            torch.save(ckpt, last)
            if best_fitness == fitness:
                torch.save(ckpt, best)
            del ckpt

        # Train complete
        if final_epoch:
            print(f'Training complete. Results saved to {save_dir}.')

            # Show predictions
            images, labels = next(iter(testloader))
            images = resize(images.to(device))
            model.eval()
            with torch.no_grad():  # inference only, no autograd graph needed
                pred = torch.max(model(images), 1)[1].cpu()
            imshow(denormalize(images), labels, pred, names, verbose=True, f=save_dir / 'test_images.jpg')
2020-12-07 05:27:01 +01:00
def test(model, dataloader, names, criterion=None, verbose=False, pbar=None):
    """Evaluate ``model`` on ``dataloader`` and return its top-1 accuracy.

    Batches are moved to the module-level ``device`` and resized with the module-level
    ``resize`` callable before the forward pass. The model is left in eval mode on return.

    Args:
        model: classification model returning per-class scores of shape (batch, classes).
        dataloader: iterable of ``(images, labels)`` batches.
        names: sequence of class names; index ``i`` names label ``i`` (used for verbose output).
        criterion: optional loss function; when given, the mean batch loss is reported on ``pbar``.
        verbose: when True, print overall and per-class sample counts and accuracies.
        pbar: optional progress bar whose ``desc`` is extended with the loss and accuracy columns.

    Returns:
        Overall accuracy as a Python float in [0, 1].
    """
    model.eval()
    pred, targets, loss = [], [], 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = resize(images.to(device)), labels.to(device)
            y = model(images)
            pred.append(torch.max(y, 1)[1])  # index of the highest score per sample
            targets.append(labels)
            if criterion is not None:
                loss += criterion(y, labels).item()

    pred, targets = torch.cat(pred), torch.cat(targets)
    correct = (targets == pred).float()
    accuracy = correct.mean().item()

    if pbar is not None:
        pbar.desc += f"{loss / len(dataloader):<12.3g}{accuracy:<12.3g}"

    if verbose:  # all classes
        # Counts are ints, so they need a numeric format code ('s' raises ValueError on int)
        print(f"{'class':10s}{'number':10s}{'accuracy':10s}")
        print(f"{'all':10s}{correct.shape[0]:<10d}{accuracy:<10.5g}")
        for i, c in enumerate(names):
            t = correct[targets == i]  # correctness flags for samples whose true label is i
            print(f"{c:10s}{t.shape[0]:<10d}{t.mean().item():<10.5g}")

    return accuracy
2020-12-07 05:27:01 +01:00
2021-12-19 17:55:16 +01:00
def classify(model, size=128, file='../datasets/mnist/test/3/30.png', plot=False):
    """Run YOLOv5 classification model inference on a single image file.

    The image is read with OpenCV, converted from BGR to RGB, scaled to [0, 1], normalized with
    the module-level ``normalize`` and bilinearly resized to ``size`` x ``size`` before being
    passed to ``model``. The predicted class index and its probability are printed.

    Args:
        model: classification model returning per-class scores of shape (1, classes).
        size: side length (pixels) the image is resized to.
        file: path of the image to classify.
        plot: when True, also save the preprocessed image via ``imshow`` under the file's own name.

    Returns:
        Tensor of softmax class probabilities with shape (1, classes).

    Raises:
        FileNotFoundError: if ``file`` cannot be read as an image.
    """
    import cv2
    import numpy as np
    import torch.nn.functional as F

    resize = torch.nn.Upsample(size=(size, size), mode='bilinear', align_corners=False)  # image resize

    # Image
    im = cv2.imread(str(file))  # returns None (no exception) when the file is missing or unreadable
    if im is None:
        raise FileNotFoundError(f'Image not found or unreadable: {file}')
    im = np.ascontiguousarray(im[..., ::-1].transpose((2, 0, 1)))  # BGR to RGB, HWC to CHW
    im = torch.tensor(im).float().unsqueeze(0) / 255.0  # to Tensor, to BCHW, rescale
    im = resize(normalize(im))

    # Inference
    with torch.no_grad():  # no autograd graph needed for prediction
        results = model(im)
    p = F.softmax(results, dim=1)  # probabilities
    i = p.argmax()  # max index
    print(f'{file} prediction: {i} ({p[0, i]:.2f})')

    # Plot
    if plot:
        imshow(denormalize(im), f=Path(file).name)

    return p
2021-12-19 17:05:21 +01:00
def imshow(img, labels=None, pred=None, names=None, nmax=64, verbose=False, f=Path('images.jpg')):
    """Save a grid of classification images with optional label and prediction titles.

    Args:
        img: batch of images; the first dimension is the batch and each item is indexed as CHW.
            Values are plotted as given, so callers pass denormalized images.
        labels: optional per-image true class indices; when given, each plot is titled.
        pred: optional per-image predicted class indices, appended to the title after a dash.
        names: optional class names indexed by class id; defaults to ``class0`` .. ``class999``.
        nmax: maximum number of images to plot.
        verbose: when True, also print the true / predicted class names of the plotted images.
        f: output path passed to ``plt.savefig``.

    Returns:
        None. Nothing is saved when there are no images to plot.
    """
    import matplotlib.pyplot as plt

    names = names or [f'class{i}' for i in range(1000)]
    n = min(len(img), nmax)  # number of plots
    if n < 1:
        print(colorstr('imshow: ') + 'no images to plot')
        return
    images = img.detach().cpu()[:n]

    m = min(8, round(n ** 0.5))  # grid columns, 8 x 8 default
    # squeeze=False always returns a 2D axes array, so ravel() is valid for every grid shape
    # (including a single column of several rows, e.g. n == 2)
    fig, ax = plt.subplots(math.ceil(n / m), m, tight_layout=True, squeeze=False)
    ax = ax.ravel()
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for a in ax:
        a.axis('off')  # also hides unused trailing cells of the grid
    for i in range(n):
        im = images[i].permute(1, 2, 0)  # CHW to HWC
        if im.shape[-1] == 1:
            im = im[..., 0]  # single-channel image, plot as 2D
        ax[i].imshow(im.numpy(), cmap='gray')  # cmap is ignored for RGB data
        if labels is not None:
            s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
            ax[i].set_title(s)

    plt.savefig(f, dpi=300, bbox_inches='tight')
    plt.close()
    print(colorstr('imshow: ') + f"examples saved to {f}")

    # Report only the images that were actually plotted
    if verbose and labels is not None:
        print('True:     ', ' '.join(f'{names[i]:3s}' for i in labels[:n]))
    if verbose and pred is not None:
        print('Predicted:', ' '.join(f'{names[i]:3s}' for i in pred[:n]))
2021-12-11 13:00:42 +01:00
2020-12-07 05:27:01 +01:00
if __name__ == '__main__':
    # Command-line options. --hyp, --evolve and --cache-images are parsed but not read by train().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='yolov5s',
                        help='model name, i.e. yolov5s, efficientnet_b0 or a torchvision model such as resnet50')
    parser.add_argument('--data', type=str, default='mnist', help='cifar10, cifar100, mnist or mnist-fashion')
    parser.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=20, help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=256, help='total batch size for all GPUs')
    parser.add_argument('--img-size', type=int, default=128, help='train, test image sizes (pixels)')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()

    # Checks
    check_git_status()
    check_requirements()

    # Parameters
    # opt, device and resize are module-level globals read by train() and test()
    device = select_device(opt.device, batch_size=opt.batch_size)
    cuda = device.type != 'cpu'
    opt.hyp = check_file(opt.hyp)  # check files
    opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)  # increment run
    resize = torch.nn.Upsample(size=(opt.img_size, opt.img_size), mode='bilinear', align_corners=False)  # image resize

    # Train
    train()