Add model inference (#16)
* add model inference on single image
* rm --eval
* revise doc
* add inference tool and demo
* fix linting
* rename inference_image to inference_model
* infer pred_label and pred_score
* fix linting
* add docstr for inference
* add remove_keys
* add doc for inference
* dump results rather than outputs
* add class_names
* add related infer scripts
* add demo image and the first part of colab tutorial
* conduct evaluation in dataset
* return lst in simple_test
* compute topk accuracy with numpy
* return outputs in test api
* merge inference and evaluation tool
* fix typo
* rm gt_labels in test config
* get gt_labels during evaluation
* separate the ipython notebook to another PR
* return tensor for onnx_export
* detach var in simple_test
* rm inference script
* rm inference script
* construct data dict to replace LoadImage
* print first predicted result if args.out is None
* modify test_pipeline in inference
* refactor class_names of imagenet
* set class_to_idx as a property in base dataset
* output pred_class during inference
* remove unused docstr
parent b6774b1224
commit 9547e7b7a5
@@ -17,8 +17,7 @@ test_pipeline = [
     dict(type='CenterCrop', crop_size=224),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='ImageToTensor', keys=['img']),
-    dict(type='ToTensor', keys=['gt_label']),
-    dict(type='Collect', keys=['img', 'gt_label'])
+    dict(type='Collect', keys=['img'])
 ]
 data = dict(
     samples_per_gpu=32,
@@ -17,8 +17,7 @@ test_pipeline = [
     dict(type='CenterCrop', crop_size=224),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='ImageToTensor', keys=['img']),
-    dict(type='ToTensor', keys=['gt_label']),
-    dict(type='Collect', keys=['img', 'gt_label'])
+    dict(type='Collect', keys=['img'])
 ]
 data = dict(
     samples_per_gpu=32,
@@ -17,8 +17,7 @@ test_pipeline = [
     dict(type='CenterCrop', crop_size=224),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='ImageToTensor', keys=['img']),
-    dict(type='ToTensor', keys=['gt_label']),
-    dict(type='Collect', keys=['img', 'gt_label'])
+    dict(type='Collect', keys=['img'])
 ]
 data = dict(
     samples_per_gpu=64,
@@ -17,8 +17,7 @@ test_pipeline = [
     dict(type='CenterCrop', crop_size=224),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='ImageToTensor', keys=['img']),
-    dict(type='ToTensor', keys=['gt_label']),
-    dict(type='Collect', keys=['img', 'gt_label'])
+    dict(type='Collect', keys=['img'])
 ]
 data = dict(
     samples_per_gpu=64,
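The four config hunks above drop `gt_label` from test-time collection, so evaluation no longer depends on labels flowing through the data pipeline. For reference, a full `test_pipeline` after this change looks roughly like the sketch below; the `LoadImageFromFile` and `Resize` entries and the normalization constants are assumptions for illustration and are not part of this diff:

```python
# Sketch of a test_pipeline after this change; only the lines shown in the
# hunks above are taken from the diff, the rest is assumed for context.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])  # gt_label is no longer collected at test time
]
```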
(new binary file added: demo image, 107 KiB; preview not shown)
@@ -0,0 +1,24 @@
+from argparse import ArgumentParser
+
+from mmcls.apis import inference_model, init_model
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('img', help='Image file')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    args = parser.parse_args()
+
+    # build the model from a config file and a checkpoint file
+    model = init_model(args.config, args.checkpoint, device=args.device)
+    # test a single image
+    result = inference_model(model, args.img)
+    # print result on terminal
+    print(result)
+
+
+if __name__ == '__main__':
+    main()
@@ -42,7 +42,42 @@ For using custom datasets, please refer to [Tutorials 2: Adding New Dataset](tut
 
 ## Inference with pretrained models
 
-We provide testing scripts to evaluate a whole dataset (ImageNet, etc.).
+We provide scripts to run inference on a single image, run inference on a dataset, and test a dataset (e.g., ImageNet).
+
+### Inference on a single image
+
+```shell
+python demo/image_demo.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE}
+```
+
+### Inference on a dataset
+
+- single GPU
+- single node multiple GPU
+- multiple node
+
+You can use the following commands to run inference on a dataset.
+
+```shell
+# single-gpu inference
+python tools/inference.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}]
+
+# multi-gpu inference
+./tools/dist_inference.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}]
+```
+
+Optional arguments:
+- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
+
+Examples:
+
+Assume that you have already downloaded the checkpoints to the directory `checkpoints/`.
+Infer ResNet-50 on the ImageNet validation set to get predicted labels and their corresponding predicted scores.
+
+```shell
+python tools/inference.py configs/imagenet/resnet50_batch256.py \
+  checkpoints/xxx.pth
+```
+
 ### Test a dataset
@@ -1,6 +1,79 @@
-def init_model():
-    pass
-
-
-def inference_model():
-    pass
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import collate, scatter
+from mmcv.runner import load_checkpoint
+
+from mmcls.datasets.pipelines import Compose
+from mmcls.models import build_classifier
+
+
+def init_model(config, checkpoint=None, device='cuda:0'):
+    """Initialize a classifier from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+
+    Returns:
+        nn.Module: The constructed classifier.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        f'but got {type(config)}')
+    config.model.pretrained = None
+    model = build_classifier(config.model)
+    if checkpoint is not None:
+        map_loc = 'cpu' if device == 'cpu' else None
+        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
+        if 'CLASSES' in checkpoint['meta']:
+            model.CLASSES = checkpoint['meta']['CLASSES']
+        else:
+            from mmcls.datasets import ImageNet
+            warnings.simplefilter('once')
+            warnings.warn('Class names are not saved in the checkpoint\'s '
+                          'meta data, use imagenet by default.')
+            model.CLASSES = ImageNet.CLASSES
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def inference_model(model, img):
+    """Inference image(s) with the classifier.
+
+    Args:
+        model (nn.Module): The loaded classifier.
+        img (str/ndarray): The image filename.
+
+    Returns:
+        result (dict): The classification results that contains
+            `class_name`, `pred_label` and `pred_score`.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = Compose(cfg.data.test.pipeline)
+    # prepare data
+    data = dict(img_info=dict(filename=img), img_prefix=None)
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+    if next(model.parameters()).is_cuda:
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+
+    # forward the model
+    with torch.no_grad():
+        scores = model(return_loss=False, **data)
+        pred_score = np.max(scores, axis=1)[0]
+        pred_label = np.argmax(scores, axis=1)[0]
+        result = {'pred_label': pred_label, 'pred_score': pred_score}
+    result['class_name'] = model.CLASSES[result['pred_label']]
+    return result
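A minimal sketch of how the new `init_model` / `inference_model` API is meant to be used from Python; the config, checkpoint, and image paths below are placeholders rather than files guaranteed by this commit:

```python
from mmcls.apis import inference_model, init_model

# Placeholder paths; substitute a real config, a downloaded checkpoint and an image.
model = init_model('configs/imagenet/resnet50_batch256.py',
                   'checkpoints/resnet50.pth', device='cuda:0')
result = inference_model(model, 'demo/demo.JPEG')
# result is a dict with 'pred_label', 'pred_score' and 'class_name'
print(result['class_name'], result['pred_score'])
```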
@@ -17,7 +17,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None):
     prog_bar = mmcv.ProgressBar(len(dataset))
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=True, **data)
+            result = model(return_loss=False, **data)
         results.append(result)
 
         if show or out_dir:
@@ -57,8 +57,11 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
     time.sleep(2)  # This line can prevent deadlock problem in some cases.
     for i, data in enumerate(data_loader):
         with torch.no_grad():
-            result = model(return_loss=True, **data)
-        results.append(result)
+            result = model(return_loss=False, **data)
+        if isinstance(result, list):
+            results.extend(result)
+        else:
+            results.append(result)
 
         if rank == 0:
             batch_size = data['img'].size(0)
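With `return_loss=False`, the forward pass in these test loops returns class scores instead of a loss dict, and `simple_test` may return a list with one score vector per image; the new `isinstance` branch keeps the collected results flat. A toy sketch of that bookkeeping (batch size and class count are assumptions):

```python
import numpy as np

results = []
# Pretend one batch of 4 images over 10 classes, returned as a per-image list.
batch_result = [np.random.rand(10) for _ in range(4)]
if isinstance(batch_result, list):
    results.extend(batch_result)   # one entry per image
else:
    results.append(batch_result)   # fallback for non-list outputs
assert len(results) == 4
```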
@@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod
 import numpy as np
 from torch.utils.data import Dataset
 
+from mmcls.models.losses import accuracy
 from .pipelines import Compose
 
 
@@ -35,6 +36,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     def load_annotations(self):
         pass
 
+    @property
+    def class_to_idx(self):
+        return {_class: i for i, _class in enumerate(self.CLASSES)}
+
+    def get_gt_labels(self):
+        gt_labels = np.array([data['gt_label'] for data in self.data_infos])
+        return gt_labels
+
     def prepare_data(self, idx):
         results = copy.deepcopy(self.data_infos[idx])
         return self.pipeline(results)
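These helpers give every dataset a `CLASSES`-derived name-to-index mapping and a flat array of ground-truth labels for evaluation. A toy illustration of both (class names and labels are made up):

```python
import numpy as np

CLASSES = ['cat', 'dog', 'ship']
data_infos = [{'gt_label': np.array(2)}, {'gt_label': np.array(0)}]

class_to_idx = {_class: i for i, _class in enumerate(CLASSES)}
gt_labels = np.array([info['gt_label'] for info in data_infos])
assert class_to_idx['ship'] == 2
assert gt_labels.tolist() == [2, 0]
```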
@@ -45,7 +54,11 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
     def __getitem__(self, idx):
         return self.prepare_data(idx)
 
-    def evaluate(self, results, metric='accuracy', logger=None):
+    def evaluate(self,
+                 results,
+                 metric='accuracy',
+                 metric_options={'topk': (1, 5)},
+                 logger=None):
         """Evaluate the dataset.
 
         Args:
@@ -63,16 +76,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
         allowed_metrics = ['accuracy']
         if metric not in allowed_metrics:
             raise KeyError(f'metric {metric} is not supported')
 
         eval_results = {}
         if metric == 'accuracy':
-            nums = []
-            for result in results:
-                nums.append(result['num_samples'].item())
-                for topk, v in result['accuracy'].items():
-                    if topk not in eval_results:
-                        eval_results[topk] = []
-                    eval_results[topk].append(v.item())
-            assert sum(nums) == len(self.data_infos)
-            for topk, accs in eval_results.items():
-                eval_results[topk] = np.average(accs, weights=nums)
+            topk = metric_options.get('topk')
+            results = np.vstack(results)
+            gt_labels = self.get_gt_labels()
+            num_imgs = len(results)
+            assert len(gt_labels) == num_imgs
+            acc = accuracy(results, gt_labels, topk)
+            eval_results = {f'top-{k}': a.item() for k, a in zip(topk, acc)}
         return eval_results
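`evaluate` now stacks the per-image score vectors produced by the test loop and scores them against `get_gt_labels()` in one pass, instead of averaging per-batch accuracies. A self-contained numpy sketch of that flow; scores are random, shapes are assumptions, and `topk_from_scores` is a hypothetical stand-in for the imported `accuracy`:

```python
import numpy as np

def topk_from_scores(scores, gt_labels, topk=(1, 5)):
    # scores: (num_imgs, num_classes); gt_labels: (num_imgs,)
    maxk = max(topk)
    # indices of the maxk highest-scoring classes, best first
    pred_label = scores.argsort(axis=1)[:, -maxk:][:, ::-1]
    return {
        f'top-{k}': np.logical_or.reduce(
            pred_label[:, :k] == gt_labels.reshape(-1, 1), axis=1).mean() * 100
        for k in topk
    }

results = [np.random.rand(10) for _ in range(8)]   # one score vector per image
gt_labels = np.random.randint(0, 10, size=8)        # what get_gt_labels() returns
scores = np.vstack(results)                          # (num_imgs, num_classes)
assert len(gt_labels) == len(scores)
print(topk_from_scores(scores, gt_labels))
```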
@@ -89,10 +89,6 @@ class CIFAR10(BaseDataset):
         with open(path, 'rb') as infile:
             data = pickle.load(infile, encoding='latin1')
             self.CLASSES = data[self.meta['key']]
-            self.class_to_idx = {
-                _class: i
-                for i, _class in enumerate(self.CLASSES)
-            }
 
     def _check_integrity(self):
         root = self.data_prefix
(One file's diff is suppressed because it is too large.)
@@ -35,10 +35,6 @@ class MNIST(BaseDataset):
         '6 - six', '7 - seven', '8 - eight', '9 - nine'
     ]
 
-    @property
-    def class_to_idx(self):
-        return {_class: i for i, _class in enumerate(self.CLASSES)}
-
     def load_annotations(self):
         train_image_file = osp.join(
             self.data_prefix, rm_suffix(self.resources['train_image_file'][0]))
@@ -38,7 +38,6 @@ class ClsHead(BaseHead):
         assert len(acc) == len(self.topk)
         losses['loss'] = loss
         losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)}
-        losses['num_samples'] = loss.new(1).fill_(num_samples)
         return losses
 
     def forward_train(self, cls_score, gt_label):
@@ -1,4 +1,6 @@
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from mmcv.cnn import normal_init
 
 from ..builder import HEADS
@@ -37,6 +39,17 @@ class LinearClsHead(ClsHead):
     def init_weights(self):
         normal_init(self.fc, mean=0, std=0.01, bias=0)
 
+    def simple_test(self, img):
+        """Test without augmentation."""
+        cls_score = self.fc(img)
+        if isinstance(cls_score, list):
+            cls_score = sum(cls_score) / float(len(cls_score))
+        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
+        if torch.onnx.is_in_onnx_export():
+            return pred
+        pred = list(pred.detach().cpu().numpy())
+        return pred
+
     def forward_train(self, x, gt_label):
         cls_score = self.fc(x)
         losses = self.loss(cls_score, gt_label)
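`simple_test` converts fc logits to softmax probabilities and, except during ONNX export, returns a list of per-image numpy vectors, which the test loop then extends into `results`. A standalone sketch of that output contract with a toy tensor (batch size and class count are assumptions):

```python
import torch
import torch.nn.functional as F

# Toy stand-in for the head's fc output: a batch of 4 images, 10 classes.
cls_score = torch.randn(4, 10)
pred = F.softmax(cls_score, dim=1)
pred = list(pred.detach().cpu().numpy())
assert len(pred) == 4 and pred[0].shape == (10,)
assert abs(float(pred[0].sum()) - 1.0) < 1e-5  # per-image probabilities sum to 1
```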
@@ -1,12 +1,41 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
 
+def accuracy_numpy(pred, target, topk):
+    res = []
+    maxk = max(topk)
+    num = pred.shape[0]
+    pred_label = pred.argsort(axis=1)[:, -maxk:][:, ::-1]
+
+    for k in topk:
+        correct_k = np.logical_or.reduce(
+            pred_label[:, :k] == target.reshape(-1, 1), axis=1)
+        res.append(correct_k.sum() * 100. / num)
+    return res
+
+
+def accuracy_torch(pred, target, topk=1):
+    res = []
+    maxk = max(topk)
+    num = pred.size(0)
+    _, pred_label = pred.topk(maxk, dim=1)
+    pred_label = pred_label.t()
+    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100. / num))
+    return res
+
+
 def accuracy(pred, target, topk=1):
     """Calculate accuracy according to the prediction and target
 
     Args:
-        pred (torch.Tensor): The model prediction.
-        target (torch.Tensor): The target of each prediction
+        pred (torch.Tensor | np.array): The model prediction.
+        target (torch.Tensor | np.array): The target of each prediction
         topk (int | tuple[int], optional): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
@@ -25,15 +54,14 @@ def accuracy(pred, target, topk=1):
     else:
         return_single = False
 
-    maxk = max(topk)
-    _, pred_label = pred.topk(maxk, dim=1)
-    pred_label = pred_label.t()
-    correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+    if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
+        res = accuracy_torch(pred, target, topk)
+    elif isinstance(pred, np.ndarray) and isinstance(target, np.ndarray):
+        res = accuracy_numpy(pred, target, topk)
+    else:
+        raise TypeError('pred and target should both be '
+                        'torch.Tensor or np.ndarray')
 
-    res = []
-    for k in topk:
-        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
-        res.append(correct_k.mul_(100.0 / pred.size(0)))
     return res[0] if return_single else res
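`accuracy` now dispatches on input type, so numpy score matrices from the evaluation path and torch tensors from the training heads go through matching implementations. A quick consistency check on toy data, under the assumption that both backends see identical inputs:

```python
import numpy as np
import torch

from mmcls.models.losses import accuracy

scores = np.random.rand(16, 10).astype(np.float32)
labels = np.random.randint(0, 10, size=16)

acc_np = accuracy(scores, labels, topk=(1, 5))
acc_pt = accuracy(torch.from_numpy(scores), torch.from_numpy(labels), topk=(1, 5))
# Both paths should report the same top-1/top-5 percentages (up to float error).
for a, b in zip(acc_np, acc_pt):
    assert abs(float(a) - float(b)) < 1e-4
```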
@@ -1,5 +1,6 @@
 import argparse
 import os
+import warnings
 
 import mmcv
 import numpy as np
@@ -19,7 +20,7 @@ def parse_args():
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument('--out', help='output result file')
     parser.add_argument(
-        '--eval', type=str, nargs='+', choices=['accuracy'], help='eval types')
+        '--metric', type=str, default='accuracy', help='evaluation metric')
     parser.add_argument(
         '--gpu_collect',
         action='store_true',
@@ -69,7 +70,7 @@ def main():
     fp16_cfg = cfg.get('fp16', None)
     if fp16_cfg is not None:
         wrap_fp16_model(model)
-    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
+    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
 
     if not distributed:
         model = MMDataParallel(model, device_ids=[0])
@@ -84,21 +85,37 @@ def main():
 
     rank, _ = get_dist_info()
     if rank == 0:
-        nums = []
-        results = {}
-        for output in outputs:
-            nums.append(output['num_samples'].item())
-            for topk, v in output['accuracy'].items():
-                if topk not in results:
-                    results[topk] = []
-                results[topk].append(v.item())
-        assert sum(nums) == len(dataset)
-        for topk, accs in results.items():
-            avg_acc = np.average(accs, weights=nums)
-            print(f'\n{topk} accuracy: {avg_acc:.2f}')
+        if args.metric != '':
+            results = dataset.evaluate(outputs, args.metric)
+            for topk, acc in results.items():
+                print(f'\n{topk} accuracy: {acc:.2f}')
+        else:
+            scores = np.vstack(outputs)
+            pred_score = np.max(scores, axis=1)
+            pred_label = np.argmax(scores, axis=1)
+            if 'CLASSES' in checkpoint['meta']:
+                CLASSES = checkpoint['meta']['CLASSES']
+            else:
+                from mmcls.datasets import ImageNet
+                warnings.simplefilter('once')
+                warnings.warn('Class names are not saved in the checkpoint\'s '
+                              'meta data, use imagenet by default.')
+                CLASSES = ImageNet.CLASSES
+            pred_class = [CLASSES[lb] for lb in pred_label]
+            results = {
+                'pred_score': pred_score,
+                'pred_label': pred_label,
+                'pred_class': pred_class
+            }
+            if not args.out:
+                print('\nthe predicted result for the first element is '
+                      f'pred_score = {pred_score[0]:.2f}, '
+                      f'pred_label = {pred_label[0]} '
+                      f'and pred_class = {pred_class[0]}. '
+                      'Specify --out to save all results to files.')
         if args.out and rank == 0:
             print(f'\nwriting results to {args.out}')
-            mmcv.dump(outputs, args.out)
+            mmcv.dump(results, args.out)
 
 
 if __name__ == '__main__':
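When `--out` is set, the dict of `pred_score` / `pred_label` / `pred_class` built above is written with `mmcv.dump`; it can be loaded back for inspection, for example (the filename below is a placeholder for whatever was passed via `--out`):

```python
import mmcv

# 'results.pkl' stands in for the file passed via --out.
results = mmcv.load('results.pkl')
print(results['pred_class'][:5])   # first few predicted class names
print(results['pred_score'][:5])   # and their predicted scores
```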