[Refactor] refactor knn (#420)

Jiahao Xie 2022-08-31 13:32:03 +08:00 committed by GitHub
parent 8f1c35957b
commit b23765fce7
6 changed files with 99 additions and 101 deletions


@@ -1,29 +1,38 @@
data_source = 'ImageNet'
dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
pipeline = [
dict(type='Resize', size=256),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet'
file_client_args = dict(backend='disk')
extract_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackSelfSupInputs'),
]
data = dict(
samples_per_gpu=256,
workers_per_gpu=8,
train=dict(
train_dataloader = dict(
batch_size=256,
num_workers=8,
dataset=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt',
),
pipeline=pipeline),
val=dict(
data_root=data_root,
ann_file='meta/train.txt',
data_prefix='train',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = dict(
batch_size=256,
num_workers=8,
dataset=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
),
pipeline=pipeline))
data_root=data_root,
ann_file='meta/val.txt',
data_prefix='val',
pipeline=extract_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
# pooling cfg
pool_cfg = dict(type='AvgPool2d')
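
The refactored config replaces the old monolithic data = dict(...) block with explicit train_dataloader / val_dataloader dicts that mmengine can build directly, plus a pool_cfg consumed by the feature extractor. A minimal sketch of how such a file is consumed, mirroring the calls the updated test_knn.py makes further down; the config path here is a hypothetical placeholder:

from mmengine.config import Config
from mmengine.runner import Runner

from mmselfsup.utils import register_all_modules

# Register mmselfsup's datasets/transforms into the mmengine registries before
# building; the config's custom_imports line pulls in mmcls.datasets so the
# 'mmcls.ImageNet' dataset type can be resolved.
register_all_modules()

# Hypothetical path to the dataset config shown above.
dataset_cfg = Config.fromfile('configs/benchmarks/classification/knn_imagenet.py')

# Runner.build_dataloader turns the plain dicts into torch DataLoaders.
train_loader = Runner.build_dataloader(dataloader=dataset_cfg.train_dataloader, seed=0)
val_loader = Runner.build_dataloader(dataloader=dataset_cfg.val_dataloader, seed=0)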


@@ -7,10 +7,7 @@ CFG=$1
EPOCH=$2
PY_ARGS=${@:3}
GPUS=${GPUS:-8}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
@@ -20,9 +17,6 @@ if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
fi
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \


@@ -7,18 +7,12 @@ CFG=$1
PRETRAIN=$2 # pretrained model
PY_ARGS=${@:3}
GPUS=${GPUS:-8}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
# set work_dir according to config path and pretrained model to distinguish different models
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \


@@ -11,7 +11,6 @@ PY_ARGS=${@:5}
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PORT=${PORT:-29500}
SRUN_ARGS=${SRUN_ARGS:-""}
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
@@ -32,5 +31,4 @@ srun -p ${PARTITION} \
${SRUN_ARGS} \
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--checkpoint $WORK_DIR/epoch_${EPOCH}.pth \
--cfg-options dist_params.port=$PORT \
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}


@@ -11,7 +11,6 @@ PY_ARGS=${@:5}
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
PORT=${PORT:-29500}
SRUN_ARGS=${SRUN_ARGS:-""}
# set work_dir according to config path and pretrained model to distinguish different models
@@ -29,5 +28,4 @@ srun -p ${PARTITION} \
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
--cfg-options model.backbone.init_cfg.type=Pretrained \
model.backbone.init_cfg.checkpoint=$PRETRAIN \
dist_params.port=$PORT \
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
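
Both Slurm wrappers drop the dist_params.port=$PORT override: under mmengine the process group is initialised by init_dist(launcher) and the master address/port are derived from the launcher environment rather than from the config. A hedged sketch of that initialisation path; the 'slurm' value simply matches the --launcher="slurm" flag used above:

from mmengine.dist import get_rank, get_world_size, init_dist

# mmengine derives rank, world size and the master address/port from the
# Slurm (or torchrun) environment, so no dist_params block is needed.
init_dist('slurm', backend='nccl')
print(f'initialised rank {get_rank()} / {get_world_size()}')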


@@ -1,19 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import mmcv
import torch
from mmcv import DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmengine import Runner
from mmengine.config import Config, DictAction
from mmengine.dist import get_rank, init_dist
from mmengine.logging import MMLogger
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmselfsup.datasets import build_dataloader, build_dataset
from mmselfsup.models import build_algorithm
from mmselfsup.models.utils import ExtractProcess, knn_classifier
from mmselfsup.utils import get_root_logger
from mmselfsup.evaluation.functional import knn_classifier
from mmselfsup.models.utils import Extractor
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules
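
The import block swaps the mmcv-era entry points (build_dataloader, build_algorithm, get_root_logger, MMDataParallel) for their mmengine counterparts. A minimal sketch of the model-side flow these imports enable, with placeholder config and checkpoint paths:

from mmengine.config import Config
from mmengine.runner import load_checkpoint

from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules

register_all_modules()  # populate the registries before any cfg-driven build

cfg = Config.fromfile('path/to/selfsup_config.py')  # placeholder path
model = MODELS.build(cfg.model)  # replaces build_algorithm(cfg.model)
model.init_weights()
load_checkpoint(model, 'path/to/checkpoint.pth', map_location='cpu')  # placeholder
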
def parse_args():
@@ -59,6 +63,7 @@ def parse_args():
type=bool,
help='Store the features on GPU. Set to False if you encounter OOM')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -68,12 +73,18 @@ def parse_args():
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
# register all modules in mmselfsup into the registries
register_all_modules()
# load config
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
if cfg.env_cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
@@ -89,47 +100,30 @@ def main():
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
init_dist(args.launcher)
# create work_dir and init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
# create work_dir
knn_work_dir = osp.join(cfg.work_dir, 'knn/')
mmcv.mkdir_or_exist(osp.abspath(knn_work_dir))
log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
mkdir_or_exist(osp.abspath(knn_work_dir))
# build the dataloader
dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
dataset_train = build_dataset(dataset_cfg.data.train)
dataset_val = build_dataset(dataset_cfg.data.val)
if 'imgs_per_gpu' in cfg.data:
logger.warning('"imgs_per_gpu" is deprecated. '
'Please use "samples_per_gpu" instead')
if 'samples_per_gpu' in cfg.data:
logger.warning(
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
f'={cfg.data.imgs_per_gpu} is used in this experiments')
else:
logger.warning(
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
f'{cfg.data.imgs_per_gpu} in this experiments')
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
data_loader_train = build_dataloader(
dataset_train,
samples_per_gpu=dataset_cfg.data.samples_per_gpu,
workers_per_gpu=dataset_cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
data_loader_val = build_dataloader(
dataset_val,
samples_per_gpu=dataset_cfg.data.samples_per_gpu,
workers_per_gpu=dataset_cfg.data.workers_per_gpu,
dist=distributed,
shuffle=False)
# init the logger before other steps
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
logger = MMLogger.get_instance(
'mmselfsup',
logger_name='mmselfsup',
log_file=log_file,
log_level=cfg.log_level)
# build dataloader
dataset_cfg = Config.fromfile(args.dataset_config)
data_loader_train = Runner.build_dataloader(
dataloader=dataset_cfg.train_dataloader, seed=args.seed)
data_loader_val = Runner.build_dataloader(
dataloader=dataset_cfg.val_dataloader, seed=args.seed)
# build the model
model = build_algorithm(cfg.model)
model = MODELS.build(cfg.model)
model.init_weights()
# model is determined in this priority: init_cfg > checkpoint > random
@@ -145,30 +139,41 @@ def main():
else:
logger.info('No pretrained or checkpoint is given, use random init.')
if not distributed:
model = MMDataParallel(model, device_ids=[0])
else:
if torch.cuda.is_available():
model = model.cuda()
if distributed:
model = MMDistributedDataParallel(
model.cuda(),
module=model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False)
model.eval()
# build extraction processor and run
extractor = ExtractProcess()
train_feats = extractor.extract(
model, data_loader_train, distributed=distributed)['feat']
val_feats = extractor.extract(
model, data_loader_val, distributed=distributed)['feat']
if is_model_wrapper(model):
model = model.module
# build extractor and extract features
extractor_train = Extractor(
extract_dataloader=data_loader_train,
seed=args.seed,
dist_mode=distributed,
pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
extractor_val = Extractor(
extract_dataloader=data_loader_val,
seed=args.seed,
dist_mode=distributed,
pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
train_feats = extractor_train(model)['feat5']
val_feats = extractor_val(model)['feat5']
train_feats = torch.from_numpy(train_feats)
val_feats = torch.from_numpy(val_feats)
train_labels = torch.LongTensor(dataset_train.data_source.get_gt_labels())
val_labels = torch.LongTensor(dataset_val.data_source.get_gt_labels())
train_labels = torch.LongTensor(data_loader_train.dataset.get_gt_labels())
val_labels = torch.LongTensor(data_loader_val.dataset.get_gt_labels())
logger.info('Features are extracted! Start k-NN classification...')
rank, _ = get_dist_info()
# run knn
rank = get_rank()
if rank == 0:
if args.use_cuda:
train_feats = train_feats.cuda()
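
The script's final step is the k-NN classification itself via knn_classifier, imported above from mmselfsup.evaluation.functional. A hedged sketch of that call, assuming a DINO-style (train_features, train_labels, test_features, test_labels, k, T) signature; the feature/label tensors and logger are the ones created above, and the k / T values are illustrative:

import torch.nn.functional as F

from mmselfsup.evaluation.functional import knn_classifier

# L2-normalise the extracted features before the cosine-similarity k-NN vote.
train_feats = F.normalize(train_feats, dim=1)
val_feats = F.normalize(val_feats, dim=1)

for k in (10, 20, 100, 200):
    top1, top5 = knn_classifier(train_feats, train_labels, val_feats, val_labels, k, 0.07)
    logger.info(f'{k}-NN: top1 = {top1:.2f}, top5 = {top5:.2f}')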