2022-04-02 20:01:06 +08:00
|
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
import argparse
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import os.path as osp
|
|
|
|
|
import random
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
import mmcv
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
|
|
|
|
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
|
|
|
|
|
|
|
|
|
from easycv.datasets import build_dataloader, build_dataset
|
|
|
|
|
from easycv.file import io
|
|
|
|
|
from easycv.models import build_model
|
2022-04-22 15:22:43 +08:00
|
|
|
|
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
|
2022-04-02 20:01:06 +08:00
|
|
|
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
2022-04-22 15:22:43 +08:00
|
|
|
|
from easycv.utils.logger import get_root_logger
|
2022-04-02 20:01:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_random_seed(seed, deterministic=True):
    """Seed every random number generator used during extraction.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU
    generator and all CUDA generators.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: True.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored by PyTorch when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Note: this overrides any `cudnn.benchmark = True` set earlier by
        # the caller, trading convolution speed for reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExtractProcess(object):
    """Run a model in ``extract`` mode over a data loader and gather features.

    Args:
        extract_list (list[str] | None): Names of the feature groups to
            request from the model. Each name is passed to
            ``model.module.update_extract_list`` when the wrapped model
            defines that method. Defaults to ``['neck']``.
    """

    def __init__(self, extract_list=None):
        # Fresh list per instance: a list literal as the default value would
        # be shared (and mutable) across every ExtractProcess.
        self.extract_list = ['neck'] if extract_list is None else extract_list
        # NOTE(review): not used by the methods of this class; kept in case
        # external code relies on the attribute.
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def _forward_func(self, model, **kwargs):
        """Extract features for a single batch.

        Args:
            model: Wrapped model (its inner model is reached via
                ``model.module``); it is called with ``mode='extract'``
                and the batch fields as keyword arguments.
            **kwargs: Batch fields forwarded to the model.

        Returns:
            dict: ``'feat1'``, ``'feat2'``, ... holding each tensor of the
            model's ``'neck'`` output moved to CPU, plus ``'label'`` (the
            batch's ``gt_labels``) when present in the batch.
        """
        if hasattr(model.module, 'update_extract_list'):
            for name in self.extract_list:
                model.module.update_extract_list(name)

        feats = model(mode='extract', **kwargs)

        # Normalize single-tensor outputs to a list so they can be enumerated.
        for name in self.extract_list:
            if isinstance(feats[name], torch.Tensor):
                feats[name] = [feats[name]]

        # NOTE(review): only the 'neck' group is written out, whatever
        # extract_list contains.
        feat_dict = {
            'feat{}'.format(i + 1): feat.cpu()
            for i, feat in enumerate(feats['neck'])
        }

        if 'gt_labels' in kwargs:
            feat_dict['label'] = kwargs['gt_labels']
        return feat_dict

    def extract(self, model, data_loader, distributed=False):
        """Extract features for every sample produced by ``data_loader``.

        Args:
            model: Wrapped model, switched to eval mode here.
            data_loader: Loader exposing either ``.dataset`` (its length is
                the sample count) or ``.data_length``.
            distributed (bool): Collect results across ranks with
                ``dist_forward_collect`` instead of ``nondist_forward_collect``.

        Returns:
            Whatever the easycv collect helper returns for the per-batch
            dicts built by ``_forward_func`` (``main`` iterates it as a dict).
        """
        model.eval()

        def func(**batch):
            return self._forward_func(model, **batch)

        if hasattr(data_loader, 'dataset'):
            length = len(data_loader.dataset)
        else:
            length = data_loader.data_length

        if distributed:
            rank, _ = get_dist_info()
            return dist_forward_collect(func, data_loader, rank, length)
        return nondist_forward_collect(func, data_loader, length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command line for batch feature extraction.

    Side effect: exports ``--local_rank`` as the ``LOCAL_RANK`` environment
    variable when that variable is not already set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='EVTORCH batch(use dataloader) extract features of a model'
    )
    parser.add_argument(
        'config', help='config file path', type=str, default=None)
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--pretrained',
        default='random',
        help='pretrained model file, exclusive to --checkpoint')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    # `type=list` would turn the token 'neck' into ['n', 'e', 'c', 'k'];
    # accept one or more whitespace-separated names instead.
    parser.add_argument(
        '--extract_list',
        type=str,
        nargs='+',
        default=['neck'],
        help='names of the feature groups to extract (default: neck)')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Newer torch.distributed launchers pass `--local-rank`; accept both
    # spellings (the attribute is still `args.local_rank`).
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_slices(split_num, split_at):
    """Return one ``slice`` per split name, cutting outputs at ``split_at``.

    Split 0 is ``[:split_at[0]]``, split ``i`` is
    ``[split_at[i - 1]:split_at[i]]`` and the last split is
    ``[split_at[-1]:]``. A single split keeps the whole output.
    """
    if split_num == 1:
        return [slice(None, None)]
    slices = []
    for ss in range(split_num):
        if ss == 0:
            slices.append(slice(None, split_at[0]))
        elif ss == split_num - 1:
            slices.append(slice(split_at[-1], None))
        else:
            slices.append(slice(split_at[ss - 1], split_at[ss]))
    return slices


def _save_outputs(outputs, cfg, work_dir, logger):
    """Save each collected output under ``work_dir`` as ``.npy`` file(s).

    With ``cfg.split_name`` set, each output ``key`` is cut by
    ``cfg.split_at`` into ``<work_dir>/<split_name>_<key>.npy`` files;
    otherwise it is saved whole as ``<work_dir>/<key>.npy``.
    """
    mmcv.mkdir_or_exist(work_dir)
    split_name = cfg.get('split_name', None)
    if not split_name:
        for key, val in outputs.items():
            np.save(osp.join(work_dir, '{}.npy'.format(key)), val)
        return

    split_at = cfg.get('split_at', [])
    logger.info('split outputs into %d parts at %s', len(split_name),
                split_at)
    # The cut points do not depend on the output key: compute them once.
    slices = _split_slices(len(split_name), split_at)
    for key, val in outputs.items():
        for name, part in zip(split_name, slices):
            output_file = osp.join(work_dir, '{}_{}.npy'.format(name, key))
            np.save(output_file, val[part])


def main():
    """Build the dataset and model from the config, extract features, and
    save them (rank 0 only) as ``.npy`` files in the work dir."""
    args = parse_args()
    cfg = mmcv_config_fromfile(args.config)

    if cfg.get('oss_io_config', None):
        io.access_oss(**cfg.oss_io_config)

    # set cudnn_benchmark
    cudnn_benchmark = cfg.get('cudnn_benchmark', False)
    if cudnn_benchmark:
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # checkpoint and pretrained are exclusive
    assert args.pretrained == 'random' or args.checkpoint is None, \
        'Checkpoint and pretrained are exclusive.'

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir; cfg.work_dir is the resolved location (CLI overrides
    # config) and is used for both the log and the extracted features.
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'extract_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    datasets = [build_dataset(cfg.data.extract)]
    seed = 0
    # Deterministic mode turns cudnn.benchmark back off, which would silently
    # ignore `cudnn_benchmark` from the config; only force it when unset.
    set_random_seed(seed, deterministic=not cudnn_benchmark)
    data_loader = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=seed,
            drop_last=getattr(cfg.data, 'drop_last', False)) for ds in datasets
    ]

    cfg_model = cfg.get('model', None)

    # specify pretrained model (only possible when the config defines one)
    if args.pretrained != 'random':
        assert isinstance(args.pretrained, str)
        if cfg_model is not None:
            cfg_model.pretrained = args.pretrained

    # Checkpoint metadata is only available when a checkpoint is given;
    # without one, the pretrained / random-init paths below apply.
    ckpt_meta = None
    if args.checkpoint is not None:
        assert os.path.exists(args.checkpoint), \
            'checkpoint {} does not exist!'.format(args.checkpoint)
        # torch.load unpickles the file: only use trusted checkpoints.
        ckpt_meta = torch.load(
            args.checkpoint, map_location='cpu').get('meta', None)

    if cfg_model is not None:
        logger.info('load model scripts from cfg config')
        model = build_model(cfg_model)
    else:
        assert ckpt_meta is not None, 'extract need either cfg model or ckpt with meta!'
        logger.info('load model scripts from ckpt meta')
        ckpt_cfg = json.loads(ckpt_meta['config'])
        if 'model' not in ckpt_cfg:
            raise ValueError(
                'build model from %s, must use model after export' %
                (args.checkpoint))
        model = build_model(ckpt_cfg['model'])

    # load checkpoint weights
    if args.checkpoint is not None:
        logger.info('Use checkpoint: {} to extract features'.format(
            args.checkpoint))
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif args.pretrained != 'random':
        logger.info('Use pretrained model: {} to extract features'.format(
            args.pretrained))
    else:
        logger.info('No checkpoint or pretrained is give, use random init.')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor
    extractor = ExtractProcess(extract_list=args.extract_list)

    # run
    outputs = extractor.extract(model, data_loader[0], distributed=distributed)

    rank, _ = get_dist_info()
    if rank == 0:
        _save_outputs(outputs, cfg, cfg.work_dir, logger)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if '__main__' == __name__:
    # Run the extraction pipeline only when executed as a script, not on import.
    main()
|