diff --git a/configs/_base_/datasets/imagenet_bs32.py b/configs/_base_/datasets/imagenet_bs32.py index 80320b1f..8a546590 100644 --- a/configs/_base_/datasets/imagenet_bs32.py +++ b/configs/_base_/datasets/imagenet_bs32.py @@ -17,8 +17,7 @@ test_pipeline = [ dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), - dict(type='ToTensor', keys=['gt_label']), - dict(type='Collect', keys=['img', 'gt_label']) + dict(type='Collect', keys=['img']) ] data = dict( samples_per_gpu=32, diff --git a/configs/_base_/datasets/imagenet_bs32_pil_resize.py b/configs/_base_/datasets/imagenet_bs32_pil_resize.py index 529d2fb8..22b74f76 100644 --- a/configs/_base_/datasets/imagenet_bs32_pil_resize.py +++ b/configs/_base_/datasets/imagenet_bs32_pil_resize.py @@ -17,8 +17,7 @@ test_pipeline = [ dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), - dict(type='ToTensor', keys=['gt_label']), - dict(type='Collect', keys=['img', 'gt_label']) + dict(type='Collect', keys=['img']) ] data = dict( samples_per_gpu=32, diff --git a/configs/_base_/datasets/imagenet_bs64.py b/configs/_base_/datasets/imagenet_bs64.py index 5abddca4..b9f866a4 100644 --- a/configs/_base_/datasets/imagenet_bs64.py +++ b/configs/_base_/datasets/imagenet_bs64.py @@ -17,8 +17,7 @@ test_pipeline = [ dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), - dict(type='ToTensor', keys=['gt_label']), - dict(type='Collect', keys=['img', 'gt_label']) + dict(type='Collect', keys=['img']) ] data = dict( samples_per_gpu=64, diff --git a/configs/_base_/datasets/imagenet_bs64_pil_resize.py b/configs/_base_/datasets/imagenet_bs64_pil_resize.py index 55b5ca2b..95d0e1f2 100644 --- a/configs/_base_/datasets/imagenet_bs64_pil_resize.py +++ b/configs/_base_/datasets/imagenet_bs64_pil_resize.py @@ -17,8 +17,7 @@ test_pipeline = [ dict(type='CenterCrop', crop_size=224), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), - dict(type='ToTensor', keys=['gt_label']), - dict(type='Collect', keys=['img', 'gt_label']) + dict(type='Collect', keys=['img']) ] data = dict( samples_per_gpu=64, diff --git a/demo/demo.JPEG b/demo/demo.JPEG new file mode 100755 index 00000000..fd3a93f5 Binary files /dev/null and b/demo/demo.JPEG differ diff --git a/demo/image_demo.py b/demo/image_demo.py new file mode 100644 index 00000000..5c208f6c --- /dev/null +++ b/demo/image_demo.py @@ -0,0 +1,24 @@ +from argparse import ArgumentParser + +from mmcls.apis import inference_model, init_model + + +def main(): + parser = ArgumentParser() + parser.add_argument('img', help='Image file') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + args = parser.parse_args() + + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint, device=args.device) + # test a single image + result = inference_model(model, args.img) + # print result on terminal + print(result) + + +if __name__ == '__main__': + main() diff --git a/docs/getting_started.md b/docs/getting_started.md index a937af78..c53430b5 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -42,7 +42,42 @@ For using custom datasets, please refer to [Tutorials 2: Adding New Dataset](tut ## Inference with pretrained models -We provide testing scripts to evaluate a whole dataset (ImageNet, etc.). +We provide scripts to inference a single image, inference a dataset and test a dataset (e.g., ImageNet). + +### Inference a single image + +```shell +python demo/image_demo.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} +``` + +### Inference a dataset + +- single GPU +- single node multiple GPU +- multiple node + +You can use the following commands to infer a dataset. + +```shell +# single-gpu inference +python tools/inference.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] + +# multi-gpu inference +./tools/dist_inference.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} [--out ${RESULT_FILE}] +``` + +Optional arguments: +- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file. + +Examples: + +Assume that you have already downloaded the checkpoints to the directory `checkpoints/`. +Infer ResNet-50 on ImageNet validation set to get predicted labels and their corresponding predicted scores. + +```shell +python tools/inference.py configs/imagenet/resnet50_batch256.py \ + checkpoints/xxx.pth +``` ### Test a dataset diff --git a/mmcls/apis/inference.py b/mmcls/apis/inference.py index dfc64e8d..83edcdaa 100644 --- a/mmcls/apis/inference.py +++ b/mmcls/apis/inference.py @@ -1,6 +1,79 @@ -def init_model(): - pass +import warnings + +import mmcv +import numpy as np +import torch +from mmcv.parallel import collate, scatter +from mmcv.runner import load_checkpoint + +from mmcls.datasets.pipelines import Compose +from mmcls.models import build_classifier -def inference_model(): - pass +def init_model(config, checkpoint=None, device='cuda:0'): + """Initialize a classifier from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + + Returns: + nn.Module: The constructed classifier. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(config)}') + config.model.pretrained = None + model = build_classifier(config.model) + if checkpoint is not None: + map_loc = 'cpu' if device == 'cpu' else None + checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) + if 'CLASSES' in checkpoint['meta']: + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + from mmcls.datasets import ImageNet + warnings.simplefilter('once') + warnings.warn('Class names are not saved in the checkpoint\'s ' + 'meta data, use imagenet by default.') + model.CLASSES = ImageNet.CLASSES + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +def inference_model(model, img): + """Inference image(s) with the classifier. + + Args: + model (nn.Module): The loaded classifier. + img (str/ndarray): The image filename. + + Returns: + result (dict): The classification results that contains + `class_name`, `pred_label` and `pred_score`. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = Compose(cfg.data.test.pipeline) + # prepare data + data = dict(img_info=dict(filename=img), img_prefix=None) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + + # forward the model + with torch.no_grad(): + scores = model(return_loss=False, **data) + pred_score = np.max(scores, axis=1)[0] + pred_label = np.argmax(scores, axis=1)[0] + result = {'pred_label': pred_label, 'pred_score': pred_score} + result['class_name'] = model.CLASSES[result['pred_label']] + return result diff --git a/mmcls/apis/test.py b/mmcls/apis/test.py index a5070543..dd773827 100644 --- a/mmcls/apis/test.py +++ b/mmcls/apis/test.py @@ -17,7 +17,7 @@ def single_gpu_test(model, data_loader, show=False, out_dir=None): prog_bar = mmcv.ProgressBar(len(dataset)) for i, data in enumerate(data_loader): with torch.no_grad(): - result = model(return_loss=True, **data) + result = model(return_loss=False, **data) results.append(result) if show or out_dir: @@ -57,8 +57,11 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): time.sleep(2) # This line can prevent deadlock problem in some cases. for i, data in enumerate(data_loader): with torch.no_grad(): - result = model(return_loss=True, **data) - results.append(result) + result = model(return_loss=False, **data) + if isinstance(result, list): + results.extend(result) + else: + results.append(result) if rank == 0: batch_size = data['img'].size(0) diff --git a/mmcls/datasets/base_dataset.py b/mmcls/datasets/base_dataset.py index 8bae46e0..30154c00 100644 --- a/mmcls/datasets/base_dataset.py +++ b/mmcls/datasets/base_dataset.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod import numpy as np from torch.utils.data import Dataset +from mmcls.models.losses import accuracy from .pipelines import Compose @@ -35,6 +36,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta): def load_annotations(self): pass + @property + def class_to_idx(self): + return {_class: i for i, _class in enumerate(self.CLASSES)} + + def get_gt_labels(self): + gt_labels = np.array([data['gt_label'] for data in self.data_infos]) + return gt_labels + def prepare_data(self, idx): results = copy.deepcopy(self.data_infos[idx]) return self.pipeline(results) @@ -45,7 +54,11 @@ class BaseDataset(Dataset, metaclass=ABCMeta): def __getitem__(self, idx): return self.prepare_data(idx) - def evaluate(self, results, metric='accuracy', logger=None): + def evaluate(self, + results, + metric='accuracy', + metric_options={'topk': (1, 5)}, + logger=None): """Evaluate the dataset. Args: @@ -63,16 +76,14 @@ class BaseDataset(Dataset, metaclass=ABCMeta): allowed_metrics = ['accuracy'] if metric not in allowed_metrics: raise KeyError(f'metric {metric} is not supported') + eval_results = {} if metric == 'accuracy': - nums = [] - for result in results: - nums.append(result['num_samples'].item()) - for topk, v in result['accuracy'].items(): - if topk not in eval_results: - eval_results[topk] = [] - eval_results[topk].append(v.item()) - assert sum(nums) == len(self.data_infos) - for topk, accs in eval_results.items(): - eval_results[topk] = np.average(accs, weights=nums) + topk = metric_options.get('topk') + results = np.vstack(results) + gt_labels = self.get_gt_labels() + num_imgs = len(results) + assert len(gt_labels) == num_imgs + acc = accuracy(results, gt_labels, topk) + eval_results = {f'top-{k}': a.item() for k, a in zip(topk, acc)} return eval_results diff --git a/mmcls/datasets/cifar.py b/mmcls/datasets/cifar.py index ba7e7fc9..27a5d747 100644 --- a/mmcls/datasets/cifar.py +++ b/mmcls/datasets/cifar.py @@ -89,10 +89,6 @@ class CIFAR10(BaseDataset): with open(path, 'rb') as infile: data = pickle.load(infile, encoding='latin1') self.CLASSES = data[self.meta['key']] - self.class_to_idx = { - _class: i - for i, _class in enumerate(self.CLASSES) - } def _check_integrity(self): root = self.data_prefix diff --git a/mmcls/datasets/imagenet.py b/mmcls/datasets/imagenet.py index bd8641bf..86ac5c5c 100644 --- a/mmcls/datasets/imagenet.py +++ b/mmcls/datasets/imagenet.py @@ -6,45 +6,6 @@ from .base_dataset import BaseDataset from .builder import DATASETS -@DATASETS.register_module() -class ImageNet(BaseDataset): - """`ImageNet `_ Dataset. - - This implementation is modified from - https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501 - """ - - IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') - - def load_annotations(self): - if self.ann_file is None: - classes, class_to_idx = find_classes(self.data_prefix) - samples = make_dataset( - self.data_prefix, class_to_idx, extensions=self.IMG_EXTENSIONS) - if len(samples) == 0: - raise (RuntimeError('Found 0 files in subfolders of: ' - f'{self.data_prefix}. ' - 'Supported extensions are: ' - f'{",".join(self.IMG_EXTENSIONS)}')) - - self.CLASSES = classes - self.class_to_idx = class_to_idx - elif isinstance(self.ann_file, str): - with open(self.ann_file) as f: - samples = [x.strip().split(' ') for x in f.readlines()] - else: - raise TypeError('ann_file must be a str or None') - self.samples = samples - - data_infos = [] - for filename, gt_label in self.samples: - info = {'img_prefix': self.data_prefix} - info['img_info'] = {'filename': filename} - info['gt_label'] = np.array(gt_label, dtype=np.int64) - data_infos.append(info) - return data_infos - - def has_file_allowed_extension(filename, extensions): """Checks if a file is an allowed extension. @@ -58,47 +19,1087 @@ def has_file_allowed_extension(filename, extensions): return any(filename_lower.endswith(ext) for ext in extensions) -def find_classes(root): +def find_folders(root): """Find classes by folders under a root. Args: root (string): root directory of folders Returns: - classes (list): a list of class names - class_to_idx (dict): the map from class name to class idx + folder_to_idx (dict): the map from folder name to class idx """ - classes = [ + folders = [ d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)) ] - classes.sort() - class_to_idx = {classes[i]: i for i in range(len(classes))} - return classes, class_to_idx + folders.sort() + folder_to_idx = {folders[i]: i for i in range(len(folders))} + return folder_to_idx -def make_dataset(root, class_to_idx, extensions): +def get_samples(root, folder_to_idx, extensions): """Make dataset by walking all images under a root. Args: root (string): root directory of folders - class_to_idx (dict): the map from class name to class idx + folder_to_idx (dict): the map from class name to class idx extensions (tuple): allowed extensions Returns: - images (list): a list of tuple where each element is (image, label) + samples (list): a list of tuple where each element is (image, label) """ - images = [] + samples = [] root = os.path.expanduser(root) - for class_name in sorted(os.listdir(root)): - _dir = os.path.join(root, class_name) + for folder_name in sorted(os.listdir(root)): + _dir = os.path.join(root, folder_name) if not os.path.isdir(_dir): continue for _, _, fns in sorted(os.walk(_dir)): for fn in sorted(fns): if has_file_allowed_extension(fn, extensions): - path = os.path.join(class_name, fn) - item = (path, class_to_idx[class_name]) - images.append(item) + path = os.path.join(folder_name, fn) + item = (path, folder_to_idx[folder_name]) + samples.append(item) + return samples - return images + +@DATASETS.register_module() +class ImageNet(BaseDataset): + """`ImageNet `_ Dataset. + + This implementation is modified from + https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py # noqa: E501 + """ + + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif') + CLASSES = [ + 'tench, Tinca tinca', + 'goldfish, Carassius auratus', + 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', # noqa: E501 + 'tiger shark, Galeocerdo cuvieri', + 'hammerhead, hammerhead shark', + 'electric ray, crampfish, numbfish, torpedo', + 'stingray', + 'cock', + 'hen', + 'ostrich, Struthio camelus', + 'brambling, Fringilla montifringilla', + 'goldfinch, Carduelis carduelis', + 'house finch, linnet, Carpodacus mexicanus', + 'junco, snowbird', + 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', + 'robin, American robin, Turdus migratorius', + 'bulbul', + 'jay', + 'magpie', + 'chickadee', + 'water ouzel, dipper', + 'kite', + 'bald eagle, American eagle, Haliaeetus leucocephalus', + 'vulture', + 'great grey owl, great gray owl, Strix nebulosa', + 'European fire salamander, Salamandra salamandra', + 'common newt, Triturus vulgaris', + 'eft', + 'spotted salamander, Ambystoma maculatum', + 'axolotl, mud puppy, Ambystoma mexicanum', + 'bullfrog, Rana catesbeiana', + 'tree frog, tree-frog', + 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', + 'loggerhead, loggerhead turtle, Caretta caretta', + 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea', # noqa: E501 + 'mud turtle', + 'terrapin', + 'box turtle, box tortoise', + 'banded gecko', + 'common iguana, iguana, Iguana iguana', + 'American chameleon, anole, Anolis carolinensis', + 'whiptail, whiptail lizard', + 'agama', + 'frilled lizard, Chlamydosaurus kingi', + 'alligator lizard', + 'Gila monster, Heloderma suspectum', + 'green lizard, Lacerta viridis', + 'African chameleon, Chamaeleo chamaeleon', + 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', # noqa: E501 + 'African crocodile, Nile crocodile, Crocodylus niloticus', + 'American alligator, Alligator mississipiensis', + 'triceratops', + 'thunder snake, worm snake, Carphophis amoenus', + 'ringneck snake, ring-necked snake, ring snake', + 'hognose snake, puff adder, sand viper', + 'green snake, grass snake', + 'king snake, kingsnake', + 'garter snake, grass snake', + 'water snake', + 'vine snake', + 'night snake, Hypsiglena torquata', + 'boa constrictor, Constrictor constrictor', + 'rock python, rock snake, Python sebae', + 'Indian cobra, Naja naja', + 'green mamba', + 'sea snake', + 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', + 'diamondback, diamondback rattlesnake, Crotalus adamanteus', + 'sidewinder, horned rattlesnake, Crotalus cerastes', + 'trilobite', + 'harvestman, daddy longlegs, Phalangium opilio', + 'scorpion', + 'black and gold garden spider, Argiope aurantia', + 'barn spider, Araneus cavaticus', + 'garden spider, Aranea diademata', + 'black widow, Latrodectus mactans', + 'tarantula', + 'wolf spider, hunting spider', + 'tick', + 'centipede', + 'black grouse', + 'ptarmigan', + 'ruffed grouse, partridge, Bonasa umbellus', + 'prairie chicken, prairie grouse, prairie fowl', + 'peacock', + 'quail', + 'partridge', + 'African grey, African gray, Psittacus erithacus', + 'macaw', + 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', + 'lorikeet', + 'coucal', + 'bee eater', + 'hornbill', + 'hummingbird', + 'jacamar', + 'toucan', + 'drake', + 'red-breasted merganser, Mergus serrator', + 'goose', + 'black swan, Cygnus atratus', + 'tusker', + 'echidna, spiny anteater, anteater', + 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', # noqa: E501 + 'wallaby, brush kangaroo', + 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', # noqa: E501 + 'wombat', + 'jellyfish', + 'sea anemone, anemone', + 'brain coral', + 'flatworm, platyhelminth', + 'nematode, nematode worm, roundworm', + 'conch', + 'snail', + 'slug', + 'sea slug, nudibranch', + 'chiton, coat-of-mail shell, sea cradle, polyplacophore', + 'chambered nautilus, pearly nautilus, nautilus', + 'Dungeness crab, Cancer magister', + 'rock crab, Cancer irroratus', + 'fiddler crab', + 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', # noqa: E501 + 'American lobster, Northern lobster, Maine lobster, Homarus americanus', # noqa: E501 + 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', # noqa: E501 + 'crayfish, crawfish, crawdad, crawdaddy', + 'hermit crab', + 'isopod', + 'white stork, Ciconia ciconia', + 'black stork, Ciconia nigra', + 'spoonbill', + 'flamingo', + 'little blue heron, Egretta caerulea', + 'American egret, great white heron, Egretta albus', + 'bittern', + 'crane', + 'limpkin, Aramus pictus', + 'European gallinule, Porphyrio porphyrio', + 'American coot, marsh hen, mud hen, water hen, Fulica americana', + 'bustard', + 'ruddy turnstone, Arenaria interpres', + 'red-backed sandpiper, dunlin, Erolia alpina', + 'redshank, Tringa totanus', + 'dowitcher', + 'oystercatcher, oyster catcher', + 'pelican', + 'king penguin, Aptenodytes patagonica', + 'albatross, mollymawk', + 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', # noqa: E501 + 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', + 'dugong, Dugong dugon', + 'sea lion', + 'Chihuahua', + 'Japanese spaniel', + 'Maltese dog, Maltese terrier, Maltese', + 'Pekinese, Pekingese, Peke', + 'Shih-Tzu', + 'Blenheim spaniel', + 'papillon', + 'toy terrier', + 'Rhodesian ridgeback', + 'Afghan hound, Afghan', + 'basset, basset hound', + 'beagle', + 'bloodhound, sleuthhound', + 'bluetick', + 'black-and-tan coonhound', + 'Walker hound, Walker foxhound', + 'English foxhound', + 'redbone', + 'borzoi, Russian wolfhound', + 'Irish wolfhound', + 'Italian greyhound', + 'whippet', + 'Ibizan hound, Ibizan Podenco', + 'Norwegian elkhound, elkhound', + 'otterhound, otter hound', + 'Saluki, gazelle hound', + 'Scottish deerhound, deerhound', + 'Weimaraner', + 'Staffordshire bullterrier, Staffordshire bull terrier', + 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', # noqa: E501 + 'Bedlington terrier', + 'Border terrier', + 'Kerry blue terrier', + 'Irish terrier', + 'Norfolk terrier', + 'Norwich terrier', + 'Yorkshire terrier', + 'wire-haired fox terrier', + 'Lakeland terrier', + 'Sealyham terrier, Sealyham', + 'Airedale, Airedale terrier', + 'cairn, cairn terrier', + 'Australian terrier', + 'Dandie Dinmont, Dandie Dinmont terrier', + 'Boston bull, Boston terrier', + 'miniature schnauzer', + 'giant schnauzer', + 'standard schnauzer', + 'Scotch terrier, Scottish terrier, Scottie', + 'Tibetan terrier, chrysanthemum dog', + 'silky terrier, Sydney silky', + 'soft-coated wheaten terrier', + 'West Highland white terrier', + 'Lhasa, Lhasa apso', + 'flat-coated retriever', + 'curly-coated retriever', + 'golden retriever', + 'Labrador retriever', + 'Chesapeake Bay retriever', + 'German short-haired pointer', + 'vizsla, Hungarian pointer', + 'English setter', + 'Irish setter, red setter', + 'Gordon setter', + 'Brittany spaniel', + 'clumber, clumber spaniel', + 'English springer, English springer spaniel', + 'Welsh springer spaniel', + 'cocker spaniel, English cocker spaniel, cocker', + 'Sussex spaniel', + 'Irish water spaniel', + 'kuvasz', + 'schipperke', + 'groenendael', + 'malinois', + 'briard', + 'kelpie', + 'komondor', + 'Old English sheepdog, bobtail', + 'Shetland sheepdog, Shetland sheep dog, Shetland', + 'collie', + 'Border collie', + 'Bouvier des Flandres, Bouviers des Flandres', + 'Rottweiler', + 'German shepherd, German shepherd dog, German police dog, alsatian', + 'Doberman, Doberman pinscher', + 'miniature pinscher', + 'Greater Swiss Mountain dog', + 'Bernese mountain dog', + 'Appenzeller', + 'EntleBucher', + 'boxer', + 'bull mastiff', + 'Tibetan mastiff', + 'French bulldog', + 'Great Dane', + 'Saint Bernard, St Bernard', + 'Eskimo dog, husky', + 'malamute, malemute, Alaskan malamute', + 'Siberian husky', + 'dalmatian, coach dog, carriage dog', + 'affenpinscher, monkey pinscher, monkey dog', + 'basenji', + 'pug, pug-dog', + 'Leonberg', + 'Newfoundland, Newfoundland dog', + 'Great Pyrenees', + 'Samoyed, Samoyede', + 'Pomeranian', + 'chow, chow chow', + 'keeshond', + 'Brabancon griffon', + 'Pembroke, Pembroke Welsh corgi', + 'Cardigan, Cardigan Welsh corgi', + 'toy poodle', + 'miniature poodle', + 'standard poodle', + 'Mexican hairless', + 'timber wolf, grey wolf, gray wolf, Canis lupus', + 'white wolf, Arctic wolf, Canis lupus tundrarum', + 'red wolf, maned wolf, Canis rufus, Canis niger', + 'coyote, prairie wolf, brush wolf, Canis latrans', + 'dingo, warrigal, warragal, Canis dingo', + 'dhole, Cuon alpinus', + 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', + 'hyena, hyaena', + 'red fox, Vulpes vulpes', + 'kit fox, Vulpes macrotis', + 'Arctic fox, white fox, Alopex lagopus', + 'grey fox, gray fox, Urocyon cinereoargenteus', + 'tabby, tabby cat', + 'tiger cat', + 'Persian cat', + 'Siamese cat, Siamese', + 'Egyptian cat', + 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', # noqa: E501 + 'lynx, catamount', + 'leopard, Panthera pardus', + 'snow leopard, ounce, Panthera uncia', + 'jaguar, panther, Panthera onca, Felis onca', + 'lion, king of beasts, Panthera leo', + 'tiger, Panthera tigris', + 'cheetah, chetah, Acinonyx jubatus', + 'brown bear, bruin, Ursus arctos', + 'American black bear, black bear, Ursus americanus, Euarctos americanus', # noqa: E501 + 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', + 'sloth bear, Melursus ursinus, Ursus ursinus', + 'mongoose', + 'meerkat, mierkat', + 'tiger beetle', + 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', + 'ground beetle, carabid beetle', + 'long-horned beetle, longicorn, longicorn beetle', + 'leaf beetle, chrysomelid', + 'dung beetle', + 'rhinoceros beetle', + 'weevil', + 'fly', + 'bee', + 'ant, emmet, pismire', + 'grasshopper, hopper', + 'cricket', + 'walking stick, walkingstick, stick insect', + 'cockroach, roach', + 'mantis, mantid', + 'cicada, cicala', + 'leafhopper', + 'lacewing, lacewing fly', + "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501 + 'damselfly', + 'admiral', + 'ringlet, ringlet butterfly', + 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', + 'cabbage butterfly', + 'sulphur butterfly, sulfur butterfly', + 'lycaenid, lycaenid butterfly', + 'starfish, sea star', + 'sea urchin', + 'sea cucumber, holothurian', + 'wood rabbit, cottontail, cottontail rabbit', + 'hare', + 'Angora, Angora rabbit', + 'hamster', + 'porcupine, hedgehog', + 'fox squirrel, eastern fox squirrel, Sciurus niger', + 'marmot', + 'beaver', + 'guinea pig, Cavia cobaya', + 'sorrel', + 'zebra', + 'hog, pig, grunter, squealer, Sus scrofa', + 'wild boar, boar, Sus scrofa', + 'warthog', + 'hippopotamus, hippo, river horse, Hippopotamus amphibius', + 'ox', + 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', + 'bison', + 'ram, tup', + 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', # noqa: E501 + 'ibex, Capra ibex', + 'hartebeest', + 'impala, Aepyceros melampus', + 'gazelle', + 'Arabian camel, dromedary, Camelus dromedarius', + 'llama', + 'weasel', + 'mink', + 'polecat, fitch, foulmart, foumart, Mustela putorius', + 'black-footed ferret, ferret, Mustela nigripes', + 'otter', + 'skunk, polecat, wood pussy', + 'badger', + 'armadillo', + 'three-toed sloth, ai, Bradypus tridactylus', + 'orangutan, orang, orangutang, Pongo pygmaeus', + 'gorilla, Gorilla gorilla', + 'chimpanzee, chimp, Pan troglodytes', + 'gibbon, Hylobates lar', + 'siamang, Hylobates syndactylus, Symphalangus syndactylus', + 'guenon, guenon monkey', + 'patas, hussar monkey, Erythrocebus patas', + 'baboon', + 'macaque', + 'langur', + 'colobus, colobus monkey', + 'proboscis monkey, Nasalis larvatus', + 'marmoset', + 'capuchin, ringtail, Cebus capucinus', + 'howler monkey, howler', + 'titi, titi monkey', + 'spider monkey, Ateles geoffroyi', + 'squirrel monkey, Saimiri sciureus', + 'Madagascar cat, ring-tailed lemur, Lemur catta', + 'indri, indris, Indri indri, Indri brevicaudatus', + 'Indian elephant, Elephas maximus', + 'African elephant, Loxodonta africana', + 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', + 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', + 'barracouta, snoek', + 'eel', + 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', # noqa: E501 + 'rock beauty, Holocanthus tricolor', + 'anemone fish', + 'sturgeon', + 'gar, garfish, garpike, billfish, Lepisosteus osseus', + 'lionfish', + 'puffer, pufferfish, blowfish, globefish', + 'abacus', + 'abaya', + "academic gown, academic robe, judge's robe", + 'accordion, piano accordion, squeeze box', + 'acoustic guitar', + 'aircraft carrier, carrier, flattop, attack aircraft carrier', + 'airliner', + 'airship, dirigible', + 'altar', + 'ambulance', + 'amphibian, amphibious vehicle', + 'analog clock', + 'apiary, bee house', + 'apron', + 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', # noqa: E501 + 'assault rifle, assault gun', + 'backpack, back pack, knapsack, packsack, rucksack, haversack', + 'bakery, bakeshop, bakehouse', + 'balance beam, beam', + 'balloon', + 'ballpoint, ballpoint pen, ballpen, Biro', + 'Band Aid', + 'banjo', + 'bannister, banister, balustrade, balusters, handrail', + 'barbell', + 'barber chair', + 'barbershop', + 'barn', + 'barometer', + 'barrel, cask', + 'barrow, garden cart, lawn cart, wheelbarrow', + 'baseball', + 'basketball', + 'bassinet', + 'bassoon', + 'bathing cap, swimming cap', + 'bath towel', + 'bathtub, bathing tub, bath, tub', + 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', # noqa: E501 + 'beacon, lighthouse, beacon light, pharos', + 'beaker', + 'bearskin, busby, shako', + 'beer bottle', + 'beer glass', + 'bell cote, bell cot', + 'bib', + 'bicycle-built-for-two, tandem bicycle, tandem', + 'bikini, two-piece', + 'binder, ring-binder', + 'binoculars, field glasses, opera glasses', + 'birdhouse', + 'boathouse', + 'bobsled, bobsleigh, bob', + 'bolo tie, bolo, bola tie, bola', + 'bonnet, poke bonnet', + 'bookcase', + 'bookshop, bookstore, bookstall', + 'bottlecap', + 'bow', + 'bow tie, bow-tie, bowtie', + 'brass, memorial tablet, plaque', + 'brassiere, bra, bandeau', + 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', + 'breastplate, aegis, egis', + 'broom', + 'bucket, pail', + 'buckle', + 'bulletproof vest', + 'bullet train, bullet', + 'butcher shop, meat market', + 'cab, hack, taxi, taxicab', + 'caldron, cauldron', + 'candle, taper, wax light', + 'cannon', + 'canoe', + 'can opener, tin opener', + 'cardigan', + 'car mirror', + 'carousel, carrousel, merry-go-round, roundabout, whirligig', + "carpenter's kit, tool kit", + 'carton', + 'car wheel', + 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', # noqa: E501 + 'cassette', + 'cassette player', + 'castle', + 'catamaran', + 'CD player', + 'cello, violoncello', + 'cellular telephone, cellular phone, cellphone, cell, mobile phone', + 'chain', + 'chainlink fence', + 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', # noqa: E501 + 'chain saw, chainsaw', + 'chest', + 'chiffonier, commode', + 'chime, bell, gong', + 'china cabinet, china closet', + 'Christmas stocking', + 'church, church building', + 'cinema, movie theater, movie theatre, movie house, picture palace', + 'cleaver, meat cleaver, chopper', + 'cliff dwelling', + 'cloak', + 'clog, geta, patten, sabot', + 'cocktail shaker', + 'coffee mug', + 'coffeepot', + 'coil, spiral, volute, whorl, helix', + 'combination lock', + 'computer keyboard, keypad', + 'confectionery, confectionary, candy store', + 'container ship, containership, container vessel', + 'convertible', + 'corkscrew, bottle screw', + 'cornet, horn, trumpet, trump', + 'cowboy boot', + 'cowboy hat, ten-gallon hat', + 'cradle', + 'crane', + 'crash helmet', + 'crate', + 'crib, cot', + 'Crock Pot', + 'croquet ball', + 'crutch', + 'cuirass', + 'dam, dike, dyke', + 'desk', + 'desktop computer', + 'dial telephone, dial phone', + 'diaper, nappy, napkin', + 'digital clock', + 'digital watch', + 'dining table, board', + 'dishrag, dishcloth', + 'dishwasher, dish washer, dishwashing machine', + 'disk brake, disc brake', + 'dock, dockage, docking facility', + 'dogsled, dog sled, dog sleigh', + 'dome', + 'doormat, welcome mat', + 'drilling platform, offshore rig', + 'drum, membranophone, tympan', + 'drumstick', + 'dumbbell', + 'Dutch oven', + 'electric fan, blower', + 'electric guitar', + 'electric locomotive', + 'entertainment center', + 'envelope', + 'espresso maker', + 'face powder', + 'feather boa, boa', + 'file, file cabinet, filing cabinet', + 'fireboat', + 'fire engine, fire truck', + 'fire screen, fireguard', + 'flagpole, flagstaff', + 'flute, transverse flute', + 'folding chair', + 'football helmet', + 'forklift', + 'fountain', + 'fountain pen', + 'four-poster', + 'freight car', + 'French horn, horn', + 'frying pan, frypan, skillet', + 'fur coat', + 'garbage truck, dustcart', + 'gasmask, respirator, gas helmet', + 'gas pump, gasoline pump, petrol pump, island dispenser', + 'goblet', + 'go-kart', + 'golf ball', + 'golfcart, golf cart', + 'gondola', + 'gong, tam-tam', + 'gown', + 'grand piano, grand', + 'greenhouse, nursery, glasshouse', + 'grille, radiator grille', + 'grocery store, grocery, food market, market', + 'guillotine', + 'hair slide', + 'hair spray', + 'half track', + 'hammer', + 'hamper', + 'hand blower, blow dryer, blow drier, hair dryer, hair drier', + 'hand-held computer, hand-held microcomputer', + 'handkerchief, hankie, hanky, hankey', + 'hard disc, hard disk, fixed disk', + 'harmonica, mouth organ, harp, mouth harp', + 'harp', + 'harvester, reaper', + 'hatchet', + 'holster', + 'home theater, home theatre', + 'honeycomb', + 'hook, claw', + 'hoopskirt, crinoline', + 'horizontal bar, high bar', + 'horse cart, horse-cart', + 'hourglass', + 'iPod', + 'iron, smoothing iron', + "jack-o'-lantern", + 'jean, blue jean, denim', + 'jeep, landrover', + 'jersey, T-shirt, tee shirt', + 'jigsaw puzzle', + 'jinrikisha, ricksha, rickshaw', + 'joystick', + 'kimono', + 'knee pad', + 'knot', + 'lab coat, laboratory coat', + 'ladle', + 'lampshade, lamp shade', + 'laptop, laptop computer', + 'lawn mower, mower', + 'lens cap, lens cover', + 'letter opener, paper knife, paperknife', + 'library', + 'lifeboat', + 'lighter, light, igniter, ignitor', + 'limousine, limo', + 'liner, ocean liner', + 'lipstick, lip rouge', + 'Loafer', + 'lotion', + 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', # noqa: E501 + "loupe, jeweler's loupe", + 'lumbermill, sawmill', + 'magnetic compass', + 'mailbag, postbag', + 'mailbox, letter box', + 'maillot', + 'maillot, tank suit', + 'manhole cover', + 'maraca', + 'marimba, xylophone', + 'mask', + 'matchstick', + 'maypole', + 'maze, labyrinth', + 'measuring cup', + 'medicine chest, medicine cabinet', + 'megalith, megalithic structure', + 'microphone, mike', + 'microwave, microwave oven', + 'military uniform', + 'milk can', + 'minibus', + 'miniskirt, mini', + 'minivan', + 'missile', + 'mitten', + 'mixing bowl', + 'mobile home, manufactured home', + 'Model T', + 'modem', + 'monastery', + 'monitor', + 'moped', + 'mortar', + 'mortarboard', + 'mosque', + 'mosquito net', + 'motor scooter, scooter', + 'mountain bike, all-terrain bike, off-roader', + 'mountain tent', + 'mouse, computer mouse', + 'mousetrap', + 'moving van', + 'muzzle', + 'nail', + 'neck brace', + 'necklace', + 'nipple', + 'notebook, notebook computer', + 'obelisk', + 'oboe, hautboy, hautbois', + 'ocarina, sweet potato', + 'odometer, hodometer, mileometer, milometer', + 'oil filter', + 'organ, pipe organ', + 'oscilloscope, scope, cathode-ray oscilloscope, CRO', + 'overskirt', + 'oxcart', + 'oxygen mask', + 'packet', + 'paddle, boat paddle', + 'paddlewheel, paddle wheel', + 'padlock', + 'paintbrush', + "pajama, pyjama, pj's, jammies", + 'palace', + 'panpipe, pandean pipe, syrinx', + 'paper towel', + 'parachute, chute', + 'parallel bars, bars', + 'park bench', + 'parking meter', + 'passenger car, coach, carriage', + 'patio, terrace', + 'pay-phone, pay-station', + 'pedestal, plinth, footstall', + 'pencil box, pencil case', + 'pencil sharpener', + 'perfume, essence', + 'Petri dish', + 'photocopier', + 'pick, plectrum, plectron', + 'pickelhaube', + 'picket fence, paling', + 'pickup, pickup truck', + 'pier', + 'piggy bank, penny bank', + 'pill bottle', + 'pillow', + 'ping-pong ball', + 'pinwheel', + 'pirate, pirate ship', + 'pitcher, ewer', + "plane, carpenter's plane, woodworking plane", + 'planetarium', + 'plastic bag', + 'plate rack', + 'plow, plough', + "plunger, plumber's helper", + 'Polaroid camera, Polaroid Land camera', + 'pole', + 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', # noqa: E501 + 'poncho', + 'pool table, billiard table, snooker table', + 'pop bottle, soda bottle', + 'pot, flowerpot', + "potter's wheel", + 'power drill', + 'prayer rug, prayer mat', + 'printer', + 'prison, prison house', + 'projectile, missile', + 'projector', + 'puck, hockey puck', + 'punching bag, punch bag, punching ball, punchball', + 'purse', + 'quill, quill pen', + 'quilt, comforter, comfort, puff', + 'racer, race car, racing car', + 'racket, racquet', + 'radiator', + 'radio, wireless', + 'radio telescope, radio reflector', + 'rain barrel', + 'recreational vehicle, RV, R.V.', + 'reel', + 'reflex camera', + 'refrigerator, icebox', + 'remote control, remote', + 'restaurant, eating house, eating place, eatery', + 'revolver, six-gun, six-shooter', + 'rifle', + 'rocking chair, rocker', + 'rotisserie', + 'rubber eraser, rubber, pencil eraser', + 'rugby ball', + 'rule, ruler', + 'running shoe', + 'safe', + 'safety pin', + 'saltshaker, salt shaker', + 'sandal', + 'sarong', + 'sax, saxophone', + 'scabbard', + 'scale, weighing machine', + 'school bus', + 'schooner', + 'scoreboard', + 'screen, CRT screen', + 'screw', + 'screwdriver', + 'seat belt, seatbelt', + 'sewing machine', + 'shield, buckler', + 'shoe shop, shoe-shop, shoe store', + 'shoji', + 'shopping basket', + 'shopping cart', + 'shovel', + 'shower cap', + 'shower curtain', + 'ski', + 'ski mask', + 'sleeping bag', + 'slide rule, slipstick', + 'sliding door', + 'slot, one-armed bandit', + 'snorkel', + 'snowmobile', + 'snowplow, snowplough', + 'soap dispenser', + 'soccer ball', + 'sock', + 'solar dish, solar collector, solar furnace', + 'sombrero', + 'soup bowl', + 'space bar', + 'space heater', + 'space shuttle', + 'spatula', + 'speedboat', + "spider web, spider's web", + 'spindle', + 'sports car, sport car', + 'spotlight, spot', + 'stage', + 'steam locomotive', + 'steel arch bridge', + 'steel drum', + 'stethoscope', + 'stole', + 'stone wall', + 'stopwatch, stop watch', + 'stove', + 'strainer', + 'streetcar, tram, tramcar, trolley, trolley car', + 'stretcher', + 'studio couch, day bed', + 'stupa, tope', + 'submarine, pigboat, sub, U-boat', + 'suit, suit of clothes', + 'sundial', + 'sunglass', + 'sunglasses, dark glasses, shades', + 'sunscreen, sunblock, sun blocker', + 'suspension bridge', + 'swab, swob, mop', + 'sweatshirt', + 'swimming trunks, bathing trunks', + 'swing', + 'switch, electric switch, electrical switch', + 'syringe', + 'table lamp', + 'tank, army tank, armored combat vehicle, armoured combat vehicle', + 'tape player', + 'teapot', + 'teddy, teddy bear', + 'television, television system', + 'tennis ball', + 'thatch, thatched roof', + 'theater curtain, theatre curtain', + 'thimble', + 'thresher, thrasher, threshing machine', + 'throne', + 'tile roof', + 'toaster', + 'tobacco shop, tobacconist shop, tobacconist', + 'toilet seat', + 'torch', + 'totem pole', + 'tow truck, tow car, wrecker', + 'toyshop', + 'tractor', + 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', # noqa: E501 + 'tray', + 'trench coat', + 'tricycle, trike, velocipede', + 'trimaran', + 'tripod', + 'triumphal arch', + 'trolleybus, trolley coach, trackless trolley', + 'trombone', + 'tub, vat', + 'turnstile', + 'typewriter keyboard', + 'umbrella', + 'unicycle, monocycle', + 'upright, upright piano', + 'vacuum, vacuum cleaner', + 'vase', + 'vault', + 'velvet', + 'vending machine', + 'vestment', + 'viaduct', + 'violin, fiddle', + 'volleyball', + 'waffle iron', + 'wall clock', + 'wallet, billfold, notecase, pocketbook', + 'wardrobe, closet, press', + 'warplane, military plane', + 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', + 'washer, automatic washer, washing machine', + 'water bottle', + 'water jug', + 'water tower', + 'whiskey jug', + 'whistle', + 'wig', + 'window screen', + 'window shade', + 'Windsor tie', + 'wine bottle', + 'wing', + 'wok', + 'wooden spoon', + 'wool, woolen, woollen', + 'worm fence, snake fence, snake-rail fence, Virginia fence', + 'wreck', + 'yawl', + 'yurt', + 'web site, website, internet site, site', + 'comic book', + 'crossword puzzle, crossword', + 'street sign', + 'traffic light, traffic signal, stoplight', + 'book jacket, dust cover, dust jacket, dust wrapper', + 'menu', + 'plate', + 'guacamole', + 'consomme', + 'hot pot, hotpot', + 'trifle', + 'ice cream, icecream', + 'ice lolly, lolly, lollipop, popsicle', + 'French loaf', + 'bagel, beigel', + 'pretzel', + 'cheeseburger', + 'hotdog, hot dog, red hot', + 'mashed potato', + 'head cabbage', + 'broccoli', + 'cauliflower', + 'zucchini, courgette', + 'spaghetti squash', + 'acorn squash', + 'butternut squash', + 'cucumber, cuke', + 'artichoke, globe artichoke', + 'bell pepper', + 'cardoon', + 'mushroom', + 'Granny Smith', + 'strawberry', + 'orange', + 'lemon', + 'fig', + 'pineapple, ananas', + 'banana', + 'jackfruit, jak, jack', + 'custard apple', + 'pomegranate', + 'hay', + 'carbonara', + 'chocolate sauce, chocolate syrup', + 'dough', + 'meat loaf, meatloaf', + 'pizza, pizza pie', + 'potpie', + 'burrito', + 'red wine', + 'espresso', + 'cup', + 'eggnog', + 'alp', + 'bubble', + 'cliff, drop, drop-off', + 'coral reef', + 'geyser', + 'lakeside, lakeshore', + 'promontory, headland, head, foreland', + 'sandbar, sand bar', + 'seashore, coast, seacoast, sea-coast', + 'valley, vale', + 'volcano', + 'ballplayer, baseball player', + 'groom, bridegroom', + 'scuba diver', + 'rapeseed', + 'daisy', + "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501 + 'corn', + 'acorn', + 'hip, rose hip, rosehip', + 'buckeye, horse chestnut, conker', + 'coral fungus', + 'agaric', + 'gyromitra', + 'stinkhorn, carrion fungus', + 'earthstar', + 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', # noqa: E501 + 'bolete', + 'ear, spike, capitulum', + 'toilet tissue, toilet paper, bathroom tissue' + ] + + def load_annotations(self): + if self.ann_file is None: + folder_to_idx = find_folders(self.data_prefix) + samples = get_samples( + self.data_prefix, + folder_to_idx, + extensions=self.IMG_EXTENSIONS) + if len(samples) == 0: + raise (RuntimeError('Found 0 files in subfolders of: ' + f'{self.data_prefix}. ' + 'Supported extensions are: ' + f'{",".join(self.IMG_EXTENSIONS)}')) + + self.folder_to_idx = folder_to_idx + elif isinstance(self.ann_file, str): + with open(self.ann_file) as f: + samples = [x.strip().split(' ') for x in f.readlines()] + else: + raise TypeError('ann_file must be a str or None') + self.samples = samples + + data_infos = [] + for filename, gt_label in self.samples: + info = {'img_prefix': self.data_prefix} + info['img_info'] = {'filename': filename} + info['gt_label'] = np.array(gt_label, dtype=np.int64) + data_infos.append(info) + return data_infos diff --git a/mmcls/datasets/mnist.py b/mmcls/datasets/mnist.py index 822a338e..fadc2f71 100644 --- a/mmcls/datasets/mnist.py +++ b/mmcls/datasets/mnist.py @@ -35,10 +35,6 @@ class MNIST(BaseDataset): '6 - six', '7 - seven', '8 - eight', '9 - nine' ] - @property - def class_to_idx(self): - return {_class: i for i, _class in enumerate(self.CLASSES)} - def load_annotations(self): train_image_file = osp.join( self.data_prefix, rm_suffix(self.resources['train_image_file'][0])) diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index 739d588a..f4fd71b6 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -38,7 +38,6 @@ class ClsHead(BaseHead): assert len(acc) == len(self.topk) losses['loss'] = loss losses['accuracy'] = {f'top-{k}': a for k, a in zip(self.topk, acc)} - losses['num_samples'] = loss.new(1).fill_(num_samples) return losses def forward_train(self, cls_score, gt_label): diff --git a/mmcls/models/heads/linear_head.py b/mmcls/models/heads/linear_head.py index 7d6c7c48..12c4671a 100644 --- a/mmcls/models/heads/linear_head.py +++ b/mmcls/models/heads/linear_head.py @@ -1,4 +1,6 @@ +import torch import torch.nn as nn +import torch.nn.functional as F from mmcv.cnn import normal_init from ..builder import HEADS @@ -37,6 +39,17 @@ class LinearClsHead(ClsHead): def init_weights(self): normal_init(self.fc, mean=0, std=0.01, bias=0) + def simple_test(self, img): + """Test without augmentation.""" + cls_score = self.fc(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + if torch.onnx.is_in_onnx_export(): + return pred + pred = list(pred.detach().cpu().numpy()) + return pred + def forward_train(self, x, gt_label): cls_score = self.fc(x) losses = self.loss(cls_score, gt_label) diff --git a/mmcls/models/losses/accuracy.py b/mmcls/models/losses/accuracy.py index 0932722a..2111ff43 100644 --- a/mmcls/models/losses/accuracy.py +++ b/mmcls/models/losses/accuracy.py @@ -1,12 +1,41 @@ +import numpy as np +import torch import torch.nn as nn +def accuracy_numpy(pred, target, topk): + res = [] + maxk = max(topk) + num = pred.shape[0] + pred_label = pred.argsort(axis=1)[:, -maxk:][:, ::-1] + + for k in topk: + correct_k = np.logical_or.reduce( + pred_label[:, :k] == target.reshape(-1, 1), axis=1) + res.append(correct_k.sum() * 100. / num) + return res + + +def accuracy_torch(pred, target, topk=1): + res = [] + maxk = max(topk) + num = pred.size(0) + _, pred_label = pred.topk(maxk, dim=1) + pred_label = pred_label.t() + correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) + + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100. / num)) + return res + + def accuracy(pred, target, topk=1): """Calculate accuracy according to the prediction and target Args: - pred (torch.Tensor): The model prediction. - target (torch.Tensor): The target of each prediction + pred (torch.Tensor | np.array): The model prediction. + target (torch.Tensor | np.array): The target of each prediction topk (int | tuple[int], optional): If the predictions in ``topk`` matches the target, the predictions will be regarded as correct ones. Defaults to 1. @@ -25,15 +54,14 @@ def accuracy(pred, target, topk=1): else: return_single = False - maxk = max(topk) - _, pred_label = pred.topk(maxk, dim=1) - pred_label = pred_label.t() - correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) + if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): + res = accuracy_torch(pred, target, topk) + elif isinstance(pred, np.ndarray) and isinstance(target, np.ndarray): + res = accuracy_numpy(pred, target, topk) + else: + raise TypeError('pred and target should both be' + 'torch.Tensor or np.ndarray') - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / pred.size(0))) return res[0] if return_single else res diff --git a/tools/test.py b/tools/test.py index 5dedbdcb..72d08708 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,5 +1,6 @@ import argparse import os +import warnings import mmcv import numpy as np @@ -19,7 +20,7 @@ def parse_args(): parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('--out', help='output result file') parser.add_argument( - '--eval', type=str, nargs='+', choices=['accuracy'], help='eval types') + '--metric', type=str, default='accuracy', help='evaluation metric') parser.add_argument( '--gpu_collect', action='store_true', @@ -69,7 +70,7 @@ def main(): fp16_cfg = cfg.get('fp16', None) if fp16_cfg is not None: wrap_fp16_model(model) - _ = load_checkpoint(model, args.checkpoint, map_location='cpu') + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') if not distributed: model = MMDataParallel(model, device_ids=[0]) @@ -84,21 +85,37 @@ def main(): rank, _ = get_dist_info() if rank == 0: - nums = [] - results = {} - for output in outputs: - nums.append(output['num_samples'].item()) - for topk, v in output['accuracy'].items(): - if topk not in results: - results[topk] = [] - results[topk].append(v.item()) - assert sum(nums) == len(dataset) - for topk, accs in results.items(): - avg_acc = np.average(accs, weights=nums) - print(f'\n{topk} accuracy: {avg_acc:.2f}') + if args.metric != '': + results = dataset.evaluate(outputs, args.metric) + for topk, acc in results.items(): + print(f'\n{topk} accuracy: {acc:.2f}') + else: + scores = np.vstack(outputs) + pred_score = np.max(scores, axis=1) + pred_label = np.argmax(scores, axis=1) + if 'CLASSES' in checkpoint['meta']: + CLASSES = checkpoint['meta']['CLASSES'] + else: + from mmcls.datasets import ImageNet + warnings.simplefilter('once') + warnings.warn('Class names are not saved in the checkpoint\'s ' + 'meta data, use imagenet by default.') + CLASSES = ImageNet.CLASSES + pred_class = [CLASSES[lb] for lb in pred_label] + results = { + 'pred_score': pred_score, + 'pred_label': pred_label, + 'pred_class': pred_class + } + if not args.out: + print('\nthe predicted result for the first element is ' + f'pred_score = {pred_score[0]:.2f}, ' + f'pred_label = {pred_label[0]} ' + f'and pred_class = {pred_class[0]}. ' + 'Specify --out to save all results to files.') if args.out and rank == 0: print(f'\nwriting results to {args.out}') - mmcv.dump(outputs, args.out) + mmcv.dump(results, args.out) if __name__ == '__main__':