EasyCV/benchmarks/tools/extract.py

226 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import json
import os
import os.path as osp
import random
import time
import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from easycv.apis import set_random_seed
from easycv.datasets import build_dataloader, build_dataset
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models import build_model
from easycv.utils.collect import dist_forward_collect, nondist_forward_collect
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.logger import get_root_logger
class ExtractProcess(object):
    """Run a model in ``mode='extract'`` over a dataloader and collect features.

    Args:
        extract_list (list[str]): keys to pull out of the dict returned by the
            model's extract forward (e.g. ``['neck']``, ``['backbone']``).
    """

    def __init__(self, extract_list=['neck']):
        self.extract_list = extract_list
        # Kept for callers that pool spatial feature maps; unused here.
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def _forward_func(self, model, **kwargs):
        """Single forward step.

        Returns:
            dict: ``{'feat1': Tensor, 'feat2': Tensor, ...}`` on CPU, plus
            ``'label'`` when ``gt_labels`` is present in the batch.
        """
        if hasattr(model.module, 'update_extract_list'):
            for k in self.extract_list:
                model.module.update_extract_list(k)
        feats = model(mode='extract', **kwargs)

        feat_dict = {}
        idx = 0
        for k in self.extract_list:
            # Normalize a bare tensor to a one-element list of tensors.
            if isinstance(feats[k], torch.Tensor):
                feats[k] = [feats[k]]
            # Bug fix: the original always enumerated feats['neck'] here,
            # which broke (KeyError or wrong features) whenever extract_list
            # was configured with any other key. Iterate the requested keys.
            for feat in feats[k]:
                idx += 1
                feat_dict['feat{}'.format(idx)] = feat.cpu()

        if 'gt_labels' in kwargs:
            feat_dict['label'] = kwargs['gt_labels']
        return feat_dict

    def extract(self, model, data_loader, distributed=False):
        """Extract features for the whole dataloader.

        Args:
            model: wrapped (MMDataParallel / MMDistributedDataParallel) model.
            data_loader: iterable of batches.
            distributed (bool): use the dist collector when True.

        Returns:
            dict: concatenated results from the collect helper.
        """
        model.eval()
        func = lambda **x: self._forward_func(model, **x)
        # Some loaders expose the dataset, others only a precomputed length.
        if hasattr(data_loader, 'dataset'):
            length = len(data_loader.dataset)
        else:
            length = data_loader.data_length
        if distributed:
            rank, world_size = get_dist_info()
            results = dist_forward_collect(func, data_loader, rank, length)
        else:
            results = nondist_forward_collect(func, data_loader, length)
        return results
def parse_args():
    """Parse command-line arguments for the feature-extraction tool."""
    parser = argparse.ArgumentParser(
        description='EVTORCH batchuse dataloader extract features of a model'
    )
    parser.add_argument(
        'config', help='config file path', type=str, default=None)
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument(
        '--pretrained',
        default='random',
        help='pretrained model file, exclusive to --checkpoint')
    parser.add_argument(
        '--work_dir',
        type=str,
        default=None,
        help='the dir to save logs and models')
    # Bug fix: the original used ``type=list``, which makes argparse call
    # list("neck") and yield ['n', 'e', 'c', 'k']. Accept one or more
    # whitespace-separated keys instead.
    parser.add_argument(
        '--extract_list',
        nargs='+',
        type=str,
        default=['neck'],
        help='feature keys to extract, e.g. --extract_list neck backbone')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--port',
        type=int,
        default=29500,
        help='port only works when launcher=="slurm"')
    args = parser.parse_args()
    # Propagate --local_rank to the env var torch.distributed expects.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Build model and dataloader from the config, extract features over the
    ``cfg.data.extract`` dataset, and save them as ``.npy`` splits.

    Output files are written to ``cfg.work_dir`` as
    ``{split_name}_{feat_key}.npy`` according to ``cfg.split_name`` /
    ``cfg.split_at`` (only on rank 0).
    """
    args = parse_args()

    cfg = mmcv_config_fromfile(args.config)
    if cfg.get('oss_io_config', None):
        io.access_oss(**cfg.oss_io_config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir

    # checkpoint and pretrained are exclusive
    assert args.pretrained == 'random' or args.checkpoint is None, \
        'Checkpoint and pretrained are exclusive.'

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        if args.launcher == 'slurm':
            cfg.dist_params['port'] = args.port
        init_dist(args.launcher, **cfg.dist_params)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

    # logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, 'extract_{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    datasets = [build_dataset(cfg.data.extract)]
    seed = 0
    set_random_seed(seed)
    data_loader = [
        build_dataloader(
            ds,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=seed,
            drop_last=getattr(cfg.data, 'drop_last', False)) for ds in datasets
    ]

    # specify pretrained model
    if args.pretrained != 'random':
        assert isinstance(args.pretrained, str)
        cfg.model.pretrained = args.pretrained

    assert os.path.exists(args.checkpoint), \
        'checkpoint must be set when use extractor!'
    # Bug fix: load metadata on CPU so GPU-saved checkpoints also work on
    # CPU-only hosts and don't allocate device memory just to read 'meta'.
    ckpt_meta = torch.load(
        args.checkpoint, map_location='cpu').get('meta', None)

    # Prefer the model definition from the config; fall back to the one
    # serialized into the (exported) checkpoint's meta.
    cfg_model = cfg.get('model', None)
    if cfg_model is not None:
        logger.info('load model scripts from cfg config')
        model = build_model(cfg_model)
    else:
        assert ckpt_meta is not None, 'extract need either cfg model or ckpt with meta!'
        logger.info('load model scripts from ckpt meta')
        ckpt_cfg = json.loads(ckpt_meta['config'])
        if 'model' not in ckpt_cfg:
            raise ValueError(
                'build model from %s, must use model after export' %
                (args.checkpoint))
        model = build_model(ckpt_cfg['model'])

    # build the model and load checkpoint
    if args.checkpoint is not None:
        logger.info('Use checkpoint: {} to extract features'.format(
            args.checkpoint))
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    elif args.pretrained != 'random':
        logger.info('Use pretrained model: {} to extract features'.format(
            args.pretrained))
    else:
        logger.info('No checkpoint or pretrained is give, use random init.')

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)

    # build extraction processor
    extractor = ExtractProcess(extract_list=args.extract_list)

    # run
    outputs = extractor.extract(model, data_loader[0], distributed=distributed)
    rank, _ = get_dist_info()
    # Bug fix: the original used args.work_dir here (and below for the
    # output paths), which is None unless --work_dir was passed and crashed
    # mkdir_or_exist/np.save. cfg.work_dir is always set.
    mmcv.mkdir_or_exist(cfg.work_dir)
    if rank == 0:
        # Split bookkeeping is the same for every feature key; compute once.
        split_num = len(cfg.split_name)
        split_at = cfg.split_at
        print(split_num, split_at)
        for key, val in outputs.items():
            for ss in range(split_num):
                output_file = '{}/{}_{}.npy'.format(cfg.work_dir,
                                                    cfg.split_name[ss], key)
                if ss == 0:
                    np.save(output_file, val[:split_at[0]])
                elif ss == split_num - 1:
                    np.save(output_file, val[split_at[-1]:])
                else:
                    np.save(output_file, val[split_at[ss - 1]:split_at[ss]])
# Script entry point: only run extraction when executed directly.
if __name__ == '__main__':
    main()