[Refactor]: refactor extract.py and the configs

parent d390a03a47
commit 775364cf11
@@ -1,22 +1,30 @@
-data_source = 'ImageList'
-dataset_type = 'SingleViewDataset'
+custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
+dataset_type = 'ImageList'
+data_root = 'data/VOCdevkit/VOC2007/'
+file_client_args = dict(backend='disk')
+
 split_at = [5011]
 split_name = ['voc07_trainval', 'voc07_test']
-img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
-data = dict(
-    samples_per_gpu=32,
-    workers_per_gpu=4,
-    extract=dict(
+extract_pipeline = [
+    dict(type='LoadImageFromFile', file_client_args=file_client_args),
+    dict(type='mmcls.ResizeEdge', scale=256),
+    dict(type='Resize', scale=224),
+    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
+]
+
+extract_dataloader = dict(
+    batch_size=32,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    extract_dataset=dict(
         type=dataset_type,
-        data_source=dict(
-            type=data_source,
-            data_prefix='data/VOCdevkit/VOC2007/JPEGImages',
-            ann_file='data/VOCdevkit/VOC2007/Lists/trainvaltest.txt',
-        ),
-        pipeline=[
-            dict(type='Resize', size=256),
-            dict(type='Resize', size=(224, 224)),
-            dict(type='ToTensor'),
-            dict(type='Normalize', **img_norm_cfg),
-        ]))
+        ann_file='Lists/trainvaltest.txt',
+        data_root=data_root,
+        data_prefix='JPEGImages/',
+        pipeline=extract_pipeline))
+
+# pooling cfg
+pool_cfg = dict(
+    type='MultiPooling', pool_type='specified', in_indices=(0, 1, 2, 3, 4))
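
Note: this is the dataset config that the refactored extract.py below reads
through its --dataset_config argument. A minimal, hedged sketch of how the
config resolves once loaded with mmengine; the filename 'voc07_extract.py' is
hypothetical:

    from mmengine.config import Config

    cfg = Config.fromfile('voc07_extract.py')  # hypothetical path to the config above
    ds_cfg = cfg.extract_dataloader.extract_dataset
    print(ds_cfg.type)        # 'ImageList'
    print(ds_cfg.data_root)   # 'data/VOCdevkit/VOC2007/', resolved from the data_root variable
    print(cfg.pool_cfg.type)  # 'MultiPooling'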
mmselfsup/datasets/__init__.py
@@ -2,11 +2,12 @@
 from .builder import DATASETS, build_dataset
 from .dataset_wrappers import ConcatDataset, RepeatDataset
 from .deepcluster_dataset import DeepClusterImageNet
+from .image_list_dataset import ImageList
 from .pipelines import *  # noqa: F401,F403
 from .places205 import Places205
 from .samplers import *  # noqa: F401,F403
 
 __all__ = [
     'DATASETS', 'build_dataset', 'ConcatDataset', 'RepeatDataset', 'Places205',
-    'DeepClusterImageNet'
+    'DeepClusterImageNet', 'ImageList'
 ]
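
Note: after this change the new dataset is both importable from
mmselfsup.datasets and resolvable by name through the registry. A hedged
sanity check, assuming register_all_modules() populates the registries as in
the dev-1.x layout:

    from mmselfsup.datasets import ImageList
    from mmselfsup.registry import DATASETS
    from mmselfsup.utils import register_all_modules

    register_all_modules()
    assert DATASETS.get('ImageList') is ImageList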
mmselfsup/datasets/image_list_dataset.py (new file)
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import numpy as np
+from mmcls.datasets import CustomDataset
+from mmengine import FileClient
+
+from mmselfsup.registry import DATASETS
+
+
+@DATASETS.register_module()
+class ImageList(CustomDataset):
+    """The dataset implementation for loading any image list file.
+
+    The `ImageList` can load an annotation file or a list of files and merge
+    all data records to one list. If data is unlabeled, the gt_label will be
+    set -1.
+
+    An annotation file should be provided, and each line indicates a sample:
+
+    The sample files: ::
+
+        data_prefix/
+        ├── folder_1
+        │   ├── xxx.png
+        │   ├── xxy.png
+        │   └── ...
+        └── folder_2
+            ├── 123.png
+            ├── nsdf3.png
+            └── ...
+
+    1. If data is labeled, the annotation file (the first column is the image
+    path and the second column is the index of category): ::
+
+        folder_1/xxx.png 0
+        folder_1/xxy.png 1
+        folder_2/123.png 5
+        folder_2/nsdf3.png 3
+        ...
+
+    2. If data is unlabeled, the annotation file is: ::
+
+        folder_1/xxx.png
+        folder_1/xxy.png
+        folder_2/123.png
+        folder_2/nsdf3.png
+        ...
+
+    Args:
+        ann_file (str): Annotation file path. Defaults to None.
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Defaults to None.
+        data_root (str): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (str | dict): Prefix for training data. Defaults
+            to None.
+        **kwargs: Other keyword arguments in :class:`CustomDataset` and
+            :class:`BaseDataset`.
+    """  # noqa: E501
+
+    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')
+
+    def __init__(self,
+                 ann_file: str = '',
+                 metainfo: Optional[dict] = None,
+                 data_root: str = '',
+                 data_prefix: Union[str, dict] = '',
+                 **kwargs) -> None:
+        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
+        super().__init__(
+            ann_file=ann_file,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Rewrite load_data_list() function for supporting a list of
+        annotation files and unlabeled data.
+
+        Returns:
+            List[dict]: A list of data information.
+        """
+        if self.img_prefix is not None:
+            file_client = FileClient.infer_client(uri=self.img_prefix)
+
+        assert self.ann_file is not None
+        if not isinstance(self.ann_file, list):
+            self.ann_file = [self.ann_file]
+
+        data_list = []
+        for ann_file in self.ann_file:
+            with open(ann_file, 'r') as f:
+                self.samples = f.readlines()
+            self.has_labels = len(self.samples[0].split()) == 2
+
+            for sample in self.samples:
+                info = {'img_prefix': self.img_prefix}
+                sample = sample.split()
+                info['img_path'] = file_client.join_path(
+                    self.img_prefix, sample[0])
+                info['img_info'] = {'filename': sample[0]}
+                labels = sample[1] if self.has_labels else -1
+                info['gt_label'] = np.array(labels, dtype=np.int64)
+                data_list.append(info)
+        return data_list
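
Note: a hedged usage sketch for ImageList with a throwaway labeled list file;
every path below is made up, and no image is actually decoded because the
pipeline is empty:

    import os
    import tempfile

    from mmselfsup.datasets import ImageList

    root = tempfile.mkdtemp()
    ann = os.path.join(root, 'list.txt')
    with open(ann, 'w') as f:
        f.write('folder_1/xxx.png 0\nfolder_2/123.png 5\n')

    dataset = ImageList(
        ann_file=ann,
        data_prefix=os.path.join(root, 'images'),  # hypothetical image root
        pipeline=[])
    print(len(dataset))            # 2
    print(dataset[0]['gt_label'])  # 0, parsed from the second column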
tools/benchmarks/classification/svm_voc07/extract.py
@@ -3,18 +3,22 @@ import argparse
 import os
 import os.path as osp
 import time
+from functools import partial
+from typing import Optional
 
-import mmcv
 import numpy as np
 import torch
-from mmcv import DictAction
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import get_dist_info, init_dist, load_checkpoint
+from mmengine.config import Config, DictAction
+from mmengine.data import pseudo_collate, worker_init_fn
+from mmengine.dist import get_rank, init_dist
+from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
+from mmengine.runner import load_checkpoint
+from mmengine.utils import mkdir_or_exist
+from torch.utils.data import DataLoader
 
-from mmselfsup.datasets import build_dataloader, build_dataset
-from mmselfsup.models import build_algorithm
-from mmselfsup.models.utils import MultiExtractProcess
-from mmselfsup.utils import get_root_logger
+from mmselfsup.models.utils import Extractor
+from mmselfsup.registry import DATA_SAMPLERS, DATASETS, MODELS
+from mmselfsup.utils import get_root_logger, register_all_modules
 
 
 def parse_args():
@@ -31,7 +35,7 @@ def parse_args():
         type=str,
         help='layer indices, separated by comma, e.g., "0,1,2,3,4"')
     parser.add_argument(
-        '--work_dir',
+        '--work-dir',
         type=str,
         default=None,
         help='the dir to save logs and models')
@@ -41,6 +45,7 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -59,12 +64,19 @@
 def main():
     args = parse_args()
-    cfg = mmcv.Config.fromfile(args.config)
+
+    # register all modules in mmselfsup into the registries
+    register_all_modules()
+
+    # load config
+    cfg = Config.fromfile(args.config)
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
     # set cudnn_benchmark
-    if cfg.get('cudnn_benchmark', False):
+    if cfg.env_cfg.get('cudnn_benchmark', False):
         torch.backends.cudnn.benchmark = True
 
     # work_dir is determined in this priority: CLI > segment in file > filename
     if args.work_dir is not None:
         # update configs according to CLI args if args.work_dir is not None
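
Note: the cudnn flag lookup moved from the top level of the config to the
MMEngine-style env_cfg block. Only cudnn_benchmark is taken from the diff;
the other keys in this sketch are typical defaults and are assumptions here:

    env_cfg = dict(
        cudnn_benchmark=False,                # read via cfg.env_cfg.get('cudnn_benchmark', False)
        mp_cfg=dict(mp_start_method='fork'),  # assumption
        dist_cfg=dict(backend='nccl'),        # assumption
    )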
@@ -75,50 +87,59 @@ def main():
         cfg.work_dir = osp.join('./work_dirs', work_type,
                                 osp.splitext(osp.basename(args.config))[0])
 
-    # get out_indices from args
-    layer_ind = [int(idx) for idx in args.layer_ind.split(',')]
-    cfg.model.backbone.out_indices = layer_ind
-
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
     else:
         distributed = True
-        init_dist(args.launcher, **cfg.dist_params)
+        init_dist(args.launcher)
 
     # create work_dir
-    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    mkdir_or_exist(osp.abspath(cfg.work_dir))
 
     # init the logger before other steps
     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
     log_file = osp.join(cfg.work_dir, f'extract_{timestamp}.log')
     logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
 
-    # build the dataloader
-    dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
-    dataset = build_dataset(dataset_cfg.data.extract)
-    if 'imgs_per_gpu' in cfg.data:
-        logger.warning('"imgs_per_gpu" is deprecated. '
-                       'Please use "samples_per_gpu" instead')
-        if 'samples_per_gpu' in cfg.data:
-            logger.warning(
-                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
-                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
-                f'={cfg.data.imgs_per_gpu} is used in this experiments')
-        else:
-            logger.warning(
-                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
-                f'{cfg.data.imgs_per_gpu} in this experiments')
-        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
-
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=dataset_cfg.data.samples_per_gpu,
-        workers_per_gpu=dataset_cfg.data.workers_per_gpu,
-        dist=distributed,
-        shuffle=False)
+    # build the dataset
+    dataset_cfg = Config.fromfile(args.dataset_config)
+    extract_dataloader_cfg = dataset_cfg.get('extract_dataloader')
+    extract_dataset_cfg = extract_dataloader_cfg.pop('extract_dataset')
+    if isinstance(extract_dataset_cfg, dict):
+        dataset = DATASETS.build(extract_dataset_cfg)
+        if hasattr(dataset, 'full_init'):
+            dataset.full_init()
+
+    # build sampler
+    sampler_cfg = extract_dataloader_cfg.pop('sampler')
+    if isinstance(sampler_cfg, dict):
+        sampler = DATA_SAMPLERS.build(
+            sampler_cfg, default_args=dict(dataset=dataset, seed=args.seed))
+
+    # build dataloader
+    init_fn: Optional[partial]
+    if args.seed is not None:
+        init_fn = partial(
+            worker_init_fn,
+            num_workers=extract_dataloader_cfg.get('num_workers'),
+            rank=get_rank(),
+            seed=args.seed)
+    else:
+        init_fn = None
+
+    data_loader = DataLoader(
+        dataset=dataset,
+        sampler=sampler,
+        collate_fn=pseudo_collate,
+        worker_init_fn=init_fn,
+        **extract_dataloader_cfg)
 
     # build the model
-    model = build_algorithm(cfg.model)
+    # get out_indices from args
+    layer_ind = [int(idx) for idx in args.layer_ind.split(',')]
+    cfg.model.backbone.out_indices = layer_ind
+    model = MODELS.build(cfg.model)
     model.init_weights()
 
     # model is determined in this priority: init_cfg > checkpoint > random
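
Note: the refactor replaces the old build_dataloader helper with manual
assembly from the registries plus a plain torch DataLoader; pseudo_collate
keeps each batch as nested lists instead of stacking samples into batch
tensors, which suits feature extraction. A condensed, hedged restatement of
the flow above (assumes dataset_cfg is already loaded and the registries are
populated; imports follow the mmengine version used in this diff):

    from mmengine.data import pseudo_collate
    from torch.utils.data import DataLoader

    from mmselfsup.registry import DATA_SAMPLERS, DATASETS

    loader_cfg = dataset_cfg.get('extract_dataloader')
    dataset = DATASETS.build(loader_cfg.pop('extract_dataset'))
    sampler = DATA_SAMPLERS.build(
        loader_cfg.pop('sampler'), default_args=dict(dataset=dataset, seed=0))
    data_loader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        collate_fn=pseudo_collate,
        **loader_cfg)  # batch_size, num_workers, persistent_workers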
@@ -134,28 +155,35 @@ def main():
     else:
         logger.info('No pretrained or checkpoint is given, use random init.')
 
-    if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
-    else:
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    if distributed:
         model = MMDistributedDataParallel(
-            model.cuda(),
+            module=model.cuda(),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False)
 
-    # build extraction processor
-    extractor = MultiExtractProcess(
-        pool_type='specified', backbone='resnet50', layer_indices=layer_ind)
+    if is_model_wrapper(model):
+        model = model.module
 
-    # run
-    outputs = extractor.extract(model, data_loader, distributed=distributed)
-    rank, _ = get_dist_info()
-    mmcv.mkdir_or_exist(f'{args.work_dir}/features/')
+    # build extractor and extract features
+    extractor = Extractor(
+        extract_dataloader=data_loader,
+        seed=args.seed,
+        dist_mode=distributed,
+        pool_cfg=dataset_cfg.pool_cfg)
+    outputs = extractor(model)
+
+    rank = get_rank()
+    mkdir_or_exist(f'{cfg.work_dir}/features/')
     if rank == 0:
         for key, val in outputs.items():
             split_num = len(dataset_cfg.split_name)
             split_at = dataset_cfg.split_at
             for ss in range(split_num):
-                output_file = f'{args.work_dir}/features/' \
+                output_file = f'{cfg.work_dir}/features/' \
                               f'{dataset_cfg.split_name[ss]}_{key}.npy'
                 if ss == 0:
                     np.save(output_file, val[:split_at[0]])
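
Note: split_at/split_name cut the single concatenated feature array back into
per-split files (VOC07 trainval has 5011 images and test has 4952, 9963 in
total). A hedged sketch of the effect; the hunk above is truncated after the
ss == 0 branch, so the test-split save is inferred, and the shapes and the
'feat1' key are illustrative only:

    import numpy as np

    split_at = [5011]
    split_name = ['voc07_trainval', 'voc07_test']
    val = np.random.rand(9963, 2048)  # fake features for all 9963 images
    np.save('features/voc07_trainval_feat1.npy', val[:split_at[0]])  # first 5011 rows
    np.save('features/voc07_test_feat1.npy', val[split_at[0]:])      # remaining 4952 rows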
@@ -22,5 +22,5 @@ srun -p ${PARTITION} \
     --kill-on-bad-exit=1 \
     ${SRUN_ARGS} \
     python -u tools/benchmarks/classification/svm_voc07/extract.py ${CFG} \
-    --layer-ind "0,1,2,3,4" --work_dir ${WORK_DIR} \
+    --layer-ind "0,1,2,3,4" --work-dir ${WORK_DIR} \
     --launcher="slurm" ${PY_ARGS}
@@ -69,8 +69,8 @@ def test_svm(opts):
     for cls in range(num_classes):
         cost = costs_list[cls]
         model_file = osp.join(
-            opts.output_path,
-            'cls' + str(cls) + '_cost' + str(cost) + '.pickle')
+            opts.output_path, 'cls' + str(cls) + '_cost' +
+            svm_helper.py2_py3_compatible_cost(cost) + '.pickle')
         with open(model_file, 'rb') as fopen:
             if six.PY2:
                 model = pickle.load(fopen)