# Mirror of https://github.com/alibaba/EasyCV.git (tools/extract.py)
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import json
import os
import os.path as osp
import random
import time

import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint

from easycv.apis import set_random_seed
from easycv.datasets import build_dataloader, build_dataset
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models import build_model
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.logger import get_root_logger

class ExtractProcess(object):
    """Run a (DataParallel-wrapped) model in ``extract`` mode over a
    dataloader and collect the produced features.

    Args:
        extract_list: names of model outputs to extract. ``None`` (the
            default) is equivalent to ``['neck']``; a fresh list is built per
            instance so the default is never shared between instances
            (mutable-default-argument fix).
    """

    def __init__(self, extract_list=None):
        # Copy the caller's list / materialize the default so later in-place
        # mutation of one instance's list cannot leak into another instance.
        self.extract_list = ['neck'] if extract_list is None else list(
            extract_list)
        # NOTE(review): pool is created but not used in this class's visible
        # code — presumably consumed by subclasses or callers; confirm.
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def _forward_func(self, model, **kwargs):
        """One forward pass; returns ``{'feat1': ..., 'feat2': ...}`` on CPU.

        NOTE(review): the result dict is built only from ``feats['neck']``
        even when ``extract_list`` names other outputs — confirm that every
        extractable output is routed through the ``'neck'`` key before
        passing a different ``extract_list``.
        """
        # Tell the model which outputs to expose, when it supports that hook.
        if hasattr(model.module, 'update_extract_list'):
            for k in self.extract_list:
                model.module.update_extract_list(k)

        feats = model(mode='extract', **kwargs)

        # Normalize bare tensors to one-element lists so enumeration below
        # works uniformly (isinstance also covers torch.Tensor subclasses,
        # unlike the previous exact-type check).
        for k in self.extract_list:
            if isinstance(feats[k], torch.Tensor):
                feats[k] = [feats[k]]
        feat_dict = {
            'feat{}'.format(i + 1): feat.cpu()
            for i, feat in enumerate(feats['neck'])
        }

        # Carry ground-truth labels through so downstream dumps stay aligned.
        if 'gt_labels' in kwargs:
            feat_dict['label'] = kwargs['gt_labels']
        return feat_dict

    def extract(self, model, data_loader, distributed=False):
        """Extract features for every sample in ``data_loader``.

        Args:
            model: wrapped model exposing ``model(mode='extract', ...)``.
            data_loader: iterable of input batches; must either own a
                ``dataset`` or report ``data_length``.
            distributed: collect results across ranks when True.

        Returns:
            Aggregated results from the forward-collect helpers.
        """
        model.eval()

        # Named function instead of a lambda bound to a name (PEP 8 / E731).
        def func(**x):
            return self._forward_func(model, **x)

        # Some loaders expose a real dataset; others only report a length.
        if hasattr(data_loader, 'dataset'):
            length = len(data_loader.dataset)
        else:
            length = data_loader.data_length

        if distributed:
            rank, world_size = get_dist_info()
            results = dist_forward_collect(func, data_loader, rank, length)
        else:
            results = nondist_forward_collect(func, data_loader, length)
        return results


def parse_args():
    """Parse command-line arguments for batch feature extraction.

    Returns:
        argparse.Namespace with ``config``, ``checkpoint``, ``pretrained``,
        ``work_dir``, ``extract_list``, ``launcher``, ``local_rank`` and
        ``port``. Also mirrors ``--local_rank`` into the ``LOCAL_RANK``
        environment variable (if unset) so torch.distributed launchers and
        this script agree on the rank.
    """
    parser = argparse.ArgumentParser(
        description='EVTORCH batch(use dataloader) extract features of a model'
    )
    parser.add_argument(
        'config', help='config file path', type=str, default=None)
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--pretrained',
        default='random',
        help='pretrained model file, exclusive to --checkpoint')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    # BUGFIX: the original used ``type=list``, which makes argparse split the
    # argument string into characters ('neck' -> ['n','e','c','k']).
    # ``nargs='+'`` accepts one or more space-separated names instead; the
    # default is unchanged. The help text was a copy-paste of --work_dir's.
    parser.add_argument(
        '--extract_list',
        nargs='+',
        type=str,
        default=['neck'],
        help='names of the model outputs to extract')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    args = parser.parse_args()
    # Only seed LOCAL_RANK when a launcher has not already exported it.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    """Entry point: build dataset/model from a config, load weights, extract
    features over the whole dataset and dump them as ``.npy`` splits.

    Side effects: initializes the distributed env (when a launcher is given),
    creates ``cfg.work_dir`` and ``args.work_dir``, writes a log file and the
    per-split feature files.
    """
    args = parse_args()
    # set cudnn_benchmark
    cfg = mmcv_config_fromfile(args.config)

    # Optional OSS (Alibaba object storage) access for config-referenced files.
    if cfg.get('oss_io_config', None):
        io.access_oss(**cfg.oss_io_config)

    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # checkpoint and pretrained are exclusive
    assert args.pretrained == 'random' or args.checkpoint is None, \
        'Checkpoint and pretrained are exclusive.'

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'extract_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # Fixed seed so extraction order/sampling is reproducible across runs.
    datasets = [build_dataset(cfg.data.extract)]
    seed = 0
    set_random_seed(seed)
    data_loader = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus,
            dist=distributed,
            shuffle=False,
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=seed,
            drop_last=getattr(cfg.data, 'drop_last', False)) for ds in datasets
    ]

    # specify pretrained model
    if args.pretrained != 'random':
        assert isinstance(args.pretrained, str)
        cfg.model.pretrained = args.pretrained

    # NOTE(review): this assert runs unconditionally, so a run with only
    # --pretrained (checkpoint is None) raises here before reaching the
    # pretrained/random-init branches below — os.path.exists(None) is a
    # TypeError. Looks contradictory with the 'exclusive' assert above;
    # confirm intended usage before relying on --pretrained alone.
    assert os.path.exists(args.checkpoint), \
        'checkpoint must be set when use extractor!'
    # NOTE(review): torch.load without map_location — presumably run on a
    # GPU host matching the checkpoint's device; confirm for CPU-only use.
    ckpt_meta = torch.load(args.checkpoint).get('meta', None)
    cfg_model = cfg.get('model', None)

    # Prefer the model definition from the config; fall back to the model
    # config embedded in the (exported) checkpoint's meta.
    if cfg_model is not None:
        logger.info('load model scripts from cfg config')
        model = build_model(cfg_model)
    else:
        assert ckpt_meta is not None, 'extract need either cfg model or ckpt with meta!'
        logger.info('load model scripts from ckpt meta')
        ckpt_cfg = json.loads(ckpt_meta['config'])
        if 'model' not in ckpt_cfg:
            # ValueError here is easycv.framework.errors.ValueError (the
            # top-of-file import shadows the builtin).
            raise ValueError(
                'build model from %s, must use model after export' %
                (args.checkpoint))
        model = build_model(ckpt_cfg['model'])

    # build the model and load checkpoint

    if args.checkpoint is not None:
        logger.info('Use checkpoint: {} to extract features'.format(
            args.checkpoint))
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif args.pretrained != 'random':
        logger.info('Use pretrained model: {} to extract features'.format(
            args.pretrained))
    else:
        logger.info('No checkpoint or pretrained is give, use random init.')

    # Wrap for single-GPU or distributed execution.
    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor
    extractor = ExtractProcess(extract_list=args.extract_list)

    # run
    outputs = extractor.extract(model, data_loader[0], distributed=distributed)

    rank, _ = get_dist_info()
    # NOTE(review): uses args.work_dir (may be None when only cfg.work_dir is
    # set, and mkdir_or_exist/format below would then misbehave) while logs
    # above use cfg.work_dir — confirm which directory is intended.
    mmcv.mkdir_or_exist(args.work_dir)

    # Only rank 0 writes results. Each output tensor is sliced at the
    # cfg.split_at boundaries and saved as <work_dir>/<split_name>_<key>.npy.
    if rank == 0:
        for key, val in outputs.items():
            split_num = len(cfg.split_name)
            split_at = cfg.split_at
            print(split_num, split_at)
            for ss in range(split_num):
                output_file = '{}/{}_{}.npy'.format(args.work_dir,
                                                    cfg.split_name[ss], key)
                if ss == 0:
                    # First split: everything before the first boundary.
                    np.save(output_file, val[:split_at[0]])
                elif ss == split_num - 1:
                    # Last split: everything after the last boundary.
                    np.save(output_file, val[split_at[-1]:])
                else:
                    np.save(output_file, val[split_at[ss - 1]:split_at[ss]])


# Script entry point: run feature extraction when executed directly.
if __name__ == '__main__':
    main()