[Refactor]: refactor extract.py and the configs

pull/352/head
renqin 2022-07-14 05:25:17 +00:00 committed by fangyixiao18
parent d390a03a47
commit 775364cf11
6 changed files with 217 additions and 73 deletions

View File

@@ -1,22 +1,30 @@
data_source = 'ImageList'
dataset_type = 'SingleViewDataset'
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'ImageList'
data_root = 'data/VOCdevkit/VOC2007/'
file_client_args = dict(backend='disk')
split_at = [5011]
split_name = ['voc07_trainval', 'voc07_test']
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
data = dict(
samples_per_gpu=32,
workers_per_gpu=4,
extract=dict(
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256),
dict(type='Resize', scale=224),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
extract_dataloader = dict(
batch_size=32,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
extract_dataset=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/VOCdevkit/VOC2007/JPEGImages',
ann_file='data/VOCdevkit/VOC2007/Lists/trainvaltest.txt',
),
pipeline=[
dict(type='Resize', size=256),
dict(type='Resize', size=(224, 224)),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
]))
ann_file='Lists/trainvaltest.txt',
data_root=data_root,
data_prefix='JPEGImages/',
pipeline=extract_pipeline))
# pooling cfg
pool_cfg = dict(
type='MultiPooling', pool_type='specified', in_indices=(0, 1, 2, 3, 4))
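
The refactored config above replaces the old `data = dict(extract=...)` block with standalone `extract_pipeline`, `extract_dataloader` and `pool_cfg` entries that the new extract.py reads directly. A minimal sketch of how such a config might be consumed, mirroring what the tool does further down; the config filename is a hypothetical placeholder:

# A sketch only: assumes the config above is saved as 'voc07_extract.py'
# (hypothetical path) and that mmselfsup is installed.
from mmengine.config import Config

from mmselfsup.registry import DATASETS
from mmselfsup.utils import register_all_modules

register_all_modules()  # register mmselfsup datasets and transforms

cfg = Config.fromfile('voc07_extract.py')
dataloader_cfg = cfg.extract_dataloader.copy()
dataset_cfg = dataloader_cfg.pop('extract_dataset')

# Build the ImageList dataset from the registry, as extract.py does.
dataset = DATASETS.build(dataset_cfg)
if hasattr(dataset, 'full_init'):
    dataset.full_init()
print(len(dataset))  # trainval + test samples listed in trainvaltest.txt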

View File

@@ -2,11 +2,12 @@
from .builder import DATASETS, build_dataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .deepcluster_dataset import DeepClusterImageNet
from .image_list_dataset import ImageList
from .pipelines import * # noqa: F401,F403
from .places205 import Places205
from .samplers import * # noqa: F401,F403
__all__ = [
'DATASETS', 'build_dataset', 'ConcatDataset', 'RepeatDataset', 'Places205',
'DeepClusterImageNet'
'DeepClusterImageNet', 'ImageList'
]

View File

@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
from mmcls.datasets import CustomDataset
from mmengine import FileClient
from mmselfsup.registry import DATASETS
@DATASETS.register_module()
class ImageList(CustomDataset):
"""The dataset implementation for loading any image list file.
The `ImageList` dataset can load an annotation file or a list of annotation
files and merge all data records into one list. If the data is unlabeled,
the gt_label will be set to -1.
An annotation file should be provided, and each line indicates a sample:
The sample files: ::
data_prefix/
folder_1
xxx.png
xxy.png
...
folder_2
123.png
nsdf3.png
...
1. If the data is labeled, the annotation file lists the image path in the
first column and the category index in the second column: ::
folder_1/xxx.png 0
folder_1/xxy.png 1
folder_2/123.png 5
folder_2/nsdf3.png 3
...
2. If the data is unlabeled, the annotation file is: ::
folder_1/xxx.png
folder_1/xxy.png
folder_2/123.png
folder_2/nsdf3.png
...
Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for the dataset, such as
class information. Defaults to None.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str | dict): Prefix for training data. Defaults to ''.
**kwargs: Other keyword arguments in :class:`CustomDataset` and
:class:`BaseDataset`.
""" # noqa: E501
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
def __init__(self,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: str = '',
data_prefix: Union[str, dict] = '',
**kwargs) -> None:
kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
**kwargs)
def load_data_list(self) -> List[dict]:
"""Rewrite load_data_list() function for supporting a list of
annotation files and unlabeled data.
Returns:
List[dict]: A list of data information.
"""
if self.img_prefix is not None:
file_client = FileClient.infer_client(uri=self.img_prefix)
assert self.ann_file is not None
if not isinstance(self.ann_file, list):
self.ann_file = [self.ann_file]
data_list = []
for ann_file in self.ann_file:
with open(ann_file, 'r') as f:
self.samples = f.readlines()
self.has_labels = len(self.samples[0].split()) == 2
for sample in self.samples:
info = {'img_prefix': self.img_prefix}
sample = sample.split()
info['img_path'] = file_client.join_path(
self.img_prefix, sample[0])
info['img_info'] = {'filename': sample[0]}
labels = sample[1] if self.has_labels else -1
info['gt_label'] = np.array(labels, dtype=np.int64)
data_list.append(info)
return data_list
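
For reference, a hedged usage sketch of the new ImageList dataset with the VOC07 list file from the config above; it assumes the VOC07 data has already been prepared under data/VOCdevkit/VOC2007/ and only exercises load_data_list():

from mmselfsup.datasets import ImageList

# Paths follow the VOC07 config above; adjust to the local data layout.
dataset = ImageList(
    ann_file='Lists/trainvaltest.txt',
    data_root='data/VOCdevkit/VOC2007/',
    data_prefix='JPEGImages/',
    pipeline=[])  # no transforms needed for this sketch

info = dataset.load_data_list()[0]
print(info['img_path'])  # image prefix joined with the first list entry
print(info['gt_label'])  # -1 when the list file carries no labels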

View File

@@ -3,18 +3,22 @@ import argparse
import os
import os.path as osp
import time
from functools import partial
from typing import Optional
import mmcv
import numpy as np
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmengine.config import Config, DictAction
from mmengine.data import pseudo_collate, worker_init_fn
from mmengine.dist import get_rank, init_dist
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
from torch.utils.data import DataLoader
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import MultiExtractProcess
from mmselfsup.utils import get_root_logger
from mmselfsup.models.utils import Extractor
from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
from mmselfsup.utils import get_root_logger, register_all_modules
def parse_args():
@@ -31,7 +35,7 @@ def parse_args():
type=str,
help='layer indices, separated by comma, e.g., "0,1,2,3,4"')
parser.add_argument(
'--work_dir',
'--work-dir',
type=str,
default=None,
help='the dir to save logs and models')
@@ -41,6 +45,7 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument(
'--cfg-options',
nargs='+',
@@ -59,12 +64,19 @@ def parse_args():
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
# register all modules in mmselfsup into the registries
register_all_modules()
# load config
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
if cfg.env_cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
@@ -75,50 +87,59 @@ def main():
cfg.work_dir = osp.join('./work_dirs', work_type,
osp.splitext(osp.basename(args.config))[0])
# get out_indices from args
layer_ind = [int(idx) for idx in args.layer_ind.split(',')]
cfg.model.backbone.out_indices = layer_ind
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
init_dist(args.launcher)
# create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
mkdir_or_exist(osp.abspath(cfg.work_dir))
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'extract_{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
# build the dataloader
dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
dataset = build_dataset(dataset_cfg.data.extract)
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
# build the dataset
dataset_cfg = Config.fromfile(args.dataset_config)
extract_dataloader_cfg = dataset_cfg.get('extract_dataloader')
extract_dataset_cfg = extract_dataloader_cfg.pop('extract_dataset')
if isinstance(extract_dataset_cfg, dict):
dataset = DATASETS.build(extract_dataset_cfg)
if hasattr(dataset, 'full_init'):
dataset.full_init()
data_loader = build_dataloader(
dataset,
samples_per_gpu=dataset_cfg.data.samples_per_gpu,
workers_per_gpu=dataset_cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# build sampler
sampler_cfg = extract_dataloader_cfg.pop('sampler')
if isinstance(sampler_cfg, dict):
sampler = DATA_SAMPLERS.build(
sampler_cfg, default_args=dict(dataset=dataset, seed=args.seed))
# build dataloader
init_fn: Optional[partial]
if args.seed is not None:
init_fn = partial(
worker_init_fn,
num_workers=extract_dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=args.seed)
else:
init_fn = None
data_loader = DataLoader(
dataset=dataset,
sampler=sampler,
collate_fn=pseudo_collate,
worker_init_fn=init_fn,
**extract_dataloader_cfg)
# build the model
model = build_algorithm(cfg.model)
# get out_indices from args
layer_ind = [int(idx) for idx in args.layer_ind.split(',')]
cfg.model.backbone.out_indices = layer_ind
model = MODELS.build(cfg.model)
model.init_weights()
# model is determined in this priority: init_cfg > checkpoint > random
@@ -134,28 +155,35 @@
else:
logger.info('No pretrained or checkpoint is given, use random init.')
if not distributed:
model = MMDataParallel(model, device_ids=[0])
else:
if torch.cuda.is_available():
model = model.cuda()
if distributed:
model = MMDistributedDataParallel(
model.cuda(),
module=model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
# build extraction processor
extractor = MultiExtractProcess(
pool_type='specified', backbone='resnet50', layer_indices=layer_ind)
if is_model_wrapper(model):
model = model.module
# build extractor and extract features
extractor = Extractor(
extract_dataloader=data_loader,
seed=args.seed,
dist_mode=distributed,
pool_cfg=dataset_cfg.pool_cfg)
outputs = extractor(model)
# run
outputs = extractor.extract(model, data_loader, distributed=distributed)
rank, _ = get_dist_info()
mmcv.mkdir_or_exist(f'{args.work_dir}/features/')
rank = get_rank()
mkdir_or_exist(f'{cfg.work_dir}/features/')
if rank == 0:
for key, val in outputs.items():
split_num = len(dataset_cfg.split_name)
split_at = dataset_cfg.split_at
for ss in range(split_num):
output_file = f'{args.work_dir}/features/' \
output_file = f'{cfg.work_dir}/features/' \
f'{dataset_cfg.split_name[ss]}_{key}.npy'
if ss == 0:
np.save(output_file, val[:split_at[0]])
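
On rank 0 the extracted features are split at `split_at` and written as one .npy file per split and per feature key. A hedged sketch of loading these files afterwards; the work_dir and the 'feat5' key are assumptions for illustration, since the actual key names come from the extractor outputs:

import numpy as np

# Hypothetical work_dir and feature key, following the
# f'{split_name}_{key}.npy' naming scheme used above.
work_dir = 'work_dirs/voc07_extract'
trainval_feats = np.load(f'{work_dir}/features/voc07_trainval_feat5.npy')
test_feats = np.load(f'{work_dir}/features/voc07_test_feat5.npy')

# split_at = [5011] in the dataset config, so the first 5011 samples
# belong to voc07_trainval and the rest to voc07_test.
print(trainval_feats.shape, test_feats.shape)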

View File

@@ -22,5 +22,5 @@ srun -p ${PARTITION} \
--kill-on-bad-exit=1 \
${SRUN_ARGS} \
python -u tools/benchmarks/classification/svm_voc07/extract.py ${CFG} \
--layer-ind "0,1,2,3,4" --work_dir ${WORK_DIR} \
--layer-ind "0,1,2,3,4" --work-dir ${WORK_DIR} \
--launcher="slurm" ${PY_ARGS}

View File

@@ -69,8 +69,8 @@ def test_svm(opts):
for cls in range(num_classes):
cost = costs_list[cls]
model_file = osp.join(
opts.output_path,
'cls' + str(cls) + '_cost' + str(cost) + '.pickle')
opts.output_path, 'cls' + str(cls) + '_cost' +
svm_helper.py2_py3_compatible_cost(cost) + '.pickle')
with open(model_file, 'rb') as fopen:
if six.PY2:
model = pickle.load(fopen)