Add model inference (#16)

* add model inference on single image

* rm --eval

* revise doc

* add inference tool and demo

* fix linting

* rename inference_image to inference_model

* infer pred_label and pred_score

* fix linting

* add docstr for inference

* add remove_keys

* add doc for inference

* dump results rather than outputs

* add class_names

* add related infer scripts

* add demo image and the first part of colab tutorial

* conduct evaluation in dataset

* return lst in simple_test

* compute topk accuracy with numpy

* return outputs in test api

* merge inference and evaluation tool

* fix typo

* rm gt_labels in test config

* get gt_labels during evaluation

* separate the ipython notebook into another PR

* return tensor for onnx_export

* detach var in simple_test

* rm inference script

* rm inference script

* construct data dict to replace LoadImage

* print first predicted result if args.out is None

* modify test_pipeline in inference

* refactor class_names of imagenet

* set class_to_idx as a property in base dataset

* output pred_class during inference

* remove unused docstr
Lei Yang 2020-09-30 19:00:20 +08:00 committed by GitHub
parent b6774b1224
commit 9547e7b7a5
17 changed files with 1309 additions and 117 deletions


@ -17,8 +17,7 @@ test_pipeline = [
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=32,


@ -17,8 +17,7 @@ test_pipeline = [
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=32,


@ -17,8 +17,7 @@ test_pipeline = [
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,


@ -17,8 +17,7 @@ test_pipeline = [
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,

demo/demo.JPEG 100755 (new binary file, 107 KiB)

demo/image_demo.py 100644 (new file, 24 lines)

@ -0,0 +1,24 @@
from argparse import ArgumentParser
from mmcls.apis import inference_model, init_model
def main():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
# test a single image
result = inference_model(model, args.img)
# print result on terminal
print(result)
if __name__ == '__main__':
main()


@ -42,7 +42,42 @@ For using custom datasets, please refer to [Tutorials 2: Adding New Dataset](tut
## Inference with pretrained models
We provide testing scripts to evaluate a whole dataset (ImageNet, etc.).
We provide scripts to run inference on a single image, run inference on a dataset, and test a dataset (e.g., ImageNet).
### Inference a single image
```shell
python demo/image_demo.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE}
```
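Under the hood, `demo/image_demo.py` uses the Python API added in this commit. A minimal sketch of calling it directly (the config, checkpoint and image paths are placeholders):
```python
from mmcls.apis import inference_model, init_model

# Placeholder paths: substitute your own config, checkpoint and image.
model = init_model('configs/imagenet/resnet50_batch256.py',
                   'checkpoints/resnet50.pth', device='cuda:0')
result = inference_model(model, 'demo/demo.JPEG')
# `result` is a dict with `pred_label`, `pred_score` and `class_name`.
print(result)
```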
### Inference a dataset
- single GPU
- single node multiple GPUs
- multiple nodes
You can use the following commands to infer a dataset.
```shell
# single-gpu inference
python tools/inference.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
# multi-gpu inference
./tools/dist_inference.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}]
```
Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
Examples:
Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.
Infer ResNet-50 on the ImageNet validation set to get the predicted labels and their corresponding scores.
```shell
python tools/inference.py configs/imagenet/resnet50_batch256.py \
checkpoints/xxx.pth
```
### Test a dataset


@ -1,6 +1,79 @@
def init_model():
pass
import warnings
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcls.datasets.pipelines import Compose
from mmcls.models import build_classifier
def inference_model():
pass
def init_model(config, checkpoint=None, device='cuda:0'):
"""Initialize a classifier from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
Returns:
nn.Module: The constructed classifier.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
config.model.pretrained = None
model = build_classifier(config.model)
if checkpoint is not None:
map_loc = 'cpu' if device == 'cpu' else None
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets import ImageNet
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
model.CLASSES = ImageNet.CLASSES
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model
def inference_model(model, img):
"""Inference image(s) with the classifier.
Args:
model (nn.Module): The loaded classifier.
img (str/ndarray): The image filename.
Returns:
result (dict): The classification results that contains
`class_name`, `pred_label` and `pred_score`.
"""
cfg = model.cfg
device = next(model.parameters()).device # model device
# build the data pipeline
test_pipeline = Compose(cfg.data.test.pipeline)
# prepare data
data = dict(img_info=dict(filename=img), img_prefix=None)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
# forward the model
with torch.no_grad():
scores = model(return_loss=False, **data)
pred_score = np.max(scores, axis=1)[0]
pred_label = np.argmax(scores, axis=1)[0]
result = {'pred_label': pred_label, 'pred_score': pred_score}
result['class_name'] = model.CLASSES[result['pred_label']]
return result


@ -17,7 +17,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
prog_bar = mmcv.ProgressBar(len(dataset))
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=True, **data)
result = model(return_loss=False, **data)
results.append(result)
if show or out_dir:
@ -57,8 +57,11 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
time.sleep(2) # This line can prevent deadlock problem in some cases.
for i, data in enumerate(data_loader):
with torch.no_grad():
result = model(return_loss=True, **data)
results.append(result)
result = model(return_loss=False, **data)
if isinstance(result, list):
results.extend(result)
else:
results.append(result)
if rank == 0:
batch_size = data['img'].size(0)


@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod
import numpy as np
from torch.utils.data import Dataset
from mmcls.models.losses import accuracy
from .pipelines import Compose
@ -35,6 +36,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
def load_annotations(self):
pass
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.CLASSES)}
def get_gt_labels(self):
gt_labels = np.array([data['gt_label'] for data in self.data_infos])
return gt_labels
def prepare_data(self, idx):
results = copy.deepcopy(self.data_infos[idx])
return self.pipeline(results)
@ -45,7 +54,11 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
def __getitem__(self, idx):
return self.prepare_data(idx)
def evaluate(self, results, metric='accuracy', logger=None):
def evaluate(self,
results,
metric='accuracy',
metric_options={'topk': (1, 5)},
logger=None):
"""Evaluate the dataset.
Args:
@ -63,16 +76,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
allowed_metrics = ['accuracy']
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
eval_results = {}
if metric == 'accuracy':
nums = []
for result in results:
nums.append(result['num_samples'].item())
for topk, v in result['accuracy'].items():
if topk not in eval_results:
eval_results[topk] = []
eval_results[topk].append(v.item())
assert sum(nums) == len(self.data_infos)
for topk, accs in eval_results.items():
eval_results[topk] = np.average(accs, weights=nums)
topk = metric_options.get('topk')
results = np.vstack(results)
gt_labels = self.get_gt_labels()
num_imgs = len(results)
assert len(gt_labels) == num_imgs
acc = accuracy(results, gt_labels, topk)
eval_results = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
return eval_results
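For context, a hedged sketch of driving the reworked `evaluate` directly. Here `build_dataset`, the data paths and the random scores are stand-ins; in practice the per-image score vectors come from the test APIs above:
```python
import numpy as np
from mmcls.datasets import build_dataset

# Assumed ImageNet validation split; adapt the paths to your setup.
dataset = build_dataset(dict(
    type='ImageNet',
    data_prefix='data/imagenet/val',
    ann_file='data/imagenet/meta/val.txt',
    pipeline=[]))

# One class-score vector per image (random here, just to show the shapes).
outputs = [np.random.rand(1000) for _ in range(len(dataset))]
print(dataset.evaluate(outputs, metric='accuracy',
                       metric_options={'topk': (1, 5)}))
# -> {'top-1': ..., 'top-5': ...}
```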


@ -89,10 +89,6 @@ class CIFAR10(BaseDataset):
with open(path, 'rb') as infile:
data = pickle.load(infile, encoding='latin1')
self.CLASSES = data[self.meta['key']]
self.class_to_idx = {
_class: i
for i, _class in enumerate(self.CLASSES)
}
def _check_integrity(self):
root = self.data_prefix

File diff suppressed because it is too large


@ -35,10 +35,6 @@ class MNIST(BaseDataset):
'6 - six', '7 - seven', '8 - eight', '9 - nine'
]
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.CLASSES)}
def load_annotations(self):
train_image_file = osp.join(
self.data_prefix, rm_suffix(self.resources['train_image_file'][0]))


@ -38,7 +38,6 @@ class ClsHead(BaseHead):
assert len(acc) == len(self.topk)
losses['loss'] = loss
losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)}
losses['num_samples'] = loss.new(1).fill_(num_samples)
return losses
def forward_train(self, cls_score, gt_label):


@ -1,4 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from ..builder import HEADS
@ -37,6 +39,17 @@ class LinearClsHead(ClsHead):
def init_weights(self):
normal_init(self.fc, mean=0, std=0.01, bias=0)
def simple_test(self, img):
"""Test without augmentation."""
cls_score = self.fc(img)
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
pred = list(pred.detach().cpu().numpy())
return pred
def forward_train(self, x, gt_label):
cls_score = self.fc(x)
losses = self.loss(cls_score, gt_label)
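For context, a standalone sketch of the post-processing the new `simple_test` performs: softmax over the logits, detach, and conversion to a Python list holding one numpy score vector per image (which is why `multi_gpu_test` now extends the result list instead of appending):
```python
import torch
import torch.nn.functional as F

cls_score = torch.randn(4, 1000)           # fake logits: batch of 4, 1000 classes
pred = F.softmax(cls_score, dim=1)
pred = list(pred.detach().cpu().numpy())   # len(pred) == 4, each entry has shape (1000,)
```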


@ -1,12 +1,41 @@
import numpy as np
import torch
import torch.nn as nn
def accuracy_numpy(pred, target, topk):
res = []
maxk = max(topk)
num = pred.shape[0]
pred_label = pred.argsort(axis=1)[:, -maxk:][:, ::-1]
for k in topk:
correct_k = np.logical_or.reduce(
pred_label[:, :k] == target.reshape(-1, 1), axis=1)
res.append(correct_k.sum() * 100. / num)
return res
def accuracy_torch(pred, target, topk=1):
res = []
maxk = max(topk)
num = pred.size(0)
_, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100. / num))
return res
def accuracy(pred, target, topk=1):
"""Calculate accuracy according to the prediction and target
Args:
pred (torch.Tensor): The model prediction.
target (torch.Tensor): The target of each prediction
pred (torch.Tensor | np.array): The model prediction.
target (torch.Tensor | np.array): The target of each prediction
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
@ -25,15 +54,14 @@ def accuracy(pred, target, topk=1):
else:
return_single = False
maxk = max(topk)
_, pred_label = pred.topk(maxk, dim=1)
pred_label = pred_label.t()
correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
res = accuracy_torch(pred, target, topk)
elif isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
res = accuracy_numpy(pred, target, topk)
else:
raise TypeError('pred and target should both be'
'torch.Tensor or np.ndarray')
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
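A quick sanity check of the new numpy path with a hand-made toy prediction matrix (illustrative values only):
```python
import numpy as np
from mmcls.models.losses import accuracy

pred = np.array([[0.1, 0.7, 0.2],    # top-1 = class 1, top-2 = {1, 2}
                 [0.5, 0.3, 0.2]])   # top-1 = class 0, top-2 = {0, 1}
target = np.array([2, 0])
print(accuracy(pred, target, topk=(1, 2)))
# [50.0, 100.0]: top-1 hits only the second sample, top-2 hits both.
```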


@ -1,5 +1,6 @@
import argparse
import os
import warnings
import mmcv
import numpy as np
@ -19,7 +20,7 @@ def parse_args():
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--out', help='output result file')
parser.add_argument(
'--eval', type=str, nargs='+', choices=['accuracy'], help='eval types')
'--metric', type=str, default='accuracy', help='evaluation metric')
parser.add_argument(
'--gpu_collect',
action='store_true',
@ -69,7 +70,7 @@ def main():
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
if not distributed:
model = MMDataParallel(model, device_ids=[0])
@ -84,21 +85,37 @@ def main():
rank, _ = get_dist_info()
if rank == 0:
nums = []
results = {}
for output in outputs:
nums.append(output['num_samples'].item())
for topk, v in output['accuracy'].items():
if topk not in results:
results[topk] = []
results[topk].append(v.item())
assert sum(nums) == len(dataset)
for topk, accs in results.items():
avg_acc = np.average(accs, weights=nums)
print(f'\n{topk} accuracy: {avg_acc:.2f}')
if args.metric != '':
results = dataset.evaluate(outputs, args.metric)
for topk, acc in results.items():
print(f'\n{topk} accuracy: {acc:.2f}')
else:
scores = np.vstack(outputs)
pred_score = np.max(scores, axis=1)
pred_label = np.argmax(scores, axis=1)
if 'CLASSES' in checkpoint['meta']:
CLASSES = checkpoint['meta']['CLASSES']
else:
from mmcls.datasets import ImageNet
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use imagenet by default.')
CLASSES = ImageNet.CLASSES
pred_class = [CLASSES[lb] for lb in pred_label]
results = {
'pred_score': pred_score,
'pred_label': pred_label,
'pred_class': pred_class
}
if not args.out:
print('\nthe predicted result for the first element is '
f'pred_score = {pred_score[0]:.2f}, '
f'pred_label = {pred_label[0]} '
f'and pred_class = {pred_class[0]}. '
'Specify --out to save all results to files.')
if args.out and rank == 0:
print(f'\nwriting results to {args.out}')
mmcv.dump(outputs, args.out)
mmcv.dump(results, args.out)
if __name__ == '__main__':
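With the inference and evaluation tools merged, `--out` now dumps the prediction dict rather than the raw outputs. A minimal sketch of reading it back (the filename is a placeholder for whatever was passed to `--out`):
```python
import mmcv

results = mmcv.load('result.pkl')   # placeholder for the --out path
print(results.keys())               # dict_keys(['pred_score', 'pred_label', 'pred_class'])
print(results['pred_class'][:5])    # first five predicted class names
```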