mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor] refactor knn (#420)
This commit is contained in:
parent
8f1c35957b
commit
b23765fce7
@ -1,29 +1,38 @@
|
|||||||
data_source = 'ImageNet'
|
custom_imports = dict(imports='mmcls.datasets', allow_failed_imports=False)
|
||||||
dataset_type = 'SingleViewDataset'
|
dataset_type = 'mmcls.ImageNet'
|
||||||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
data_root = 'data/imagenet'
|
||||||
pipeline = [
|
file_client_args = dict(backend='disk')
|
||||||
dict(type='Resize', size=256),
|
|
||||||
dict(type='CenterCrop', size=224),
|
extract_pipeline = [
|
||||||
dict(type='ToTensor'),
|
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||||
dict(type='Normalize', **img_norm_cfg),
|
dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
|
||||||
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
dict(type='PackSelfSupInputs'),
|
||||||
]
|
]
|
||||||
|
|
||||||
data = dict(
|
train_dataloader = dict(
|
||||||
samples_per_gpu=256,
|
batch_size=256,
|
||||||
workers_per_gpu=8,
|
num_workers=8,
|
||||||
train=dict(
|
dataset=dict(
|
||||||
type=dataset_type,
|
type=dataset_type,
|
||||||
data_source=dict(
|
data_root=data_root,
|
||||||
type=data_source,
|
ann_file='meta/train.txt',
|
||||||
data_prefix='data/imagenet/train',
|
data_prefix='train',
|
||||||
ann_file='data/imagenet/meta/train.txt',
|
pipeline=extract_pipeline),
|
||||||
),
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
pipeline=pipeline),
|
)
|
||||||
val=dict(
|
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=256,
|
||||||
|
num_workers=8,
|
||||||
|
dataset=dict(
|
||||||
type=dataset_type,
|
type=dataset_type,
|
||||||
data_source=dict(
|
data_root=data_root,
|
||||||
type=data_source,
|
ann_file='meta/val.txt',
|
||||||
data_prefix='data/imagenet/val',
|
data_prefix='val',
|
||||||
ann_file='data/imagenet/meta/val.txt',
|
pipeline=extract_pipeline),
|
||||||
),
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
pipeline=pipeline))
|
)
|
||||||
|
|
||||||
|
# pooling cfg
|
||||||
|
pool_cfg = dict(type='AvgPool2d')
|
||||||
|
@ -7,10 +7,7 @@ CFG=$1
|
|||||||
EPOCH=$2
|
EPOCH=$2
|
||||||
PY_ARGS=${@:3}
|
PY_ARGS=${@:3}
|
||||||
GPUS=${GPUS:-8}
|
GPUS=${GPUS:-8}
|
||||||
NNODES=${NNODES:-1}
|
|
||||||
NODE_RANK=${NODE_RANK:-0}
|
|
||||||
PORT=${PORT:-29500}
|
PORT=${PORT:-29500}
|
||||||
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
|
|
||||||
|
|
||||||
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
||||||
|
|
||||||
@ -20,9 +17,6 @@ if [ ! -f $WORK_DIR/epoch_${EPOCH}.pth ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
python -m torch.distributed.launch \
|
python -m torch.distributed.launch \
|
||||||
--nnodes=$NNODES \
|
|
||||||
--node_rank=$NODE_RANK \
|
|
||||||
--master_addr=$MASTER_ADDR \
|
|
||||||
--nproc_per_node=$GPUS \
|
--nproc_per_node=$GPUS \
|
||||||
--master_port=$PORT \
|
--master_port=$PORT \
|
||||||
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
||||||
|
@ -7,18 +7,12 @@ CFG=$1
|
|||||||
PRETRAIN=$2 # pretrained model
|
PRETRAIN=$2 # pretrained model
|
||||||
PY_ARGS=${@:3}
|
PY_ARGS=${@:3}
|
||||||
GPUS=${GPUS:-8}
|
GPUS=${GPUS:-8}
|
||||||
NNODES=${NNODES:-1}
|
|
||||||
NODE_RANK=${NODE_RANK:-0}
|
|
||||||
PORT=${PORT:-29500}
|
PORT=${PORT:-29500}
|
||||||
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
|
|
||||||
|
|
||||||
# set work_dir according to config path and pretrained model to distinguish different models
|
# set work_dir according to config path and pretrained model to distinguish different models
|
||||||
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
|
WORK_DIR="$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/$(echo $PRETRAIN | rev | cut -d/ -f 1 | rev)"
|
||||||
|
|
||||||
python -m torch.distributed.launch \
|
python -m torch.distributed.launch \
|
||||||
--nnodes=$NNODES \
|
|
||||||
--node_rank=$NODE_RANK \
|
|
||||||
--master_addr=$MASTER_ADDR \
|
|
||||||
--nproc_per_node=$GPUS \
|
--nproc_per_node=$GPUS \
|
||||||
--master_port=$PORT \
|
--master_port=$PORT \
|
||||||
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
||||||
|
@ -11,7 +11,6 @@ PY_ARGS=${@:5}
|
|||||||
GPUS=${GPUS:-8}
|
GPUS=${GPUS:-8}
|
||||||
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||||
PORT=${PORT:-29500}
|
|
||||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||||
|
|
||||||
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
|
||||||
@ -32,5 +31,4 @@ srun -p ${PARTITION} \
|
|||||||
${SRUN_ARGS} \
|
${SRUN_ARGS} \
|
||||||
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
||||||
--checkpoint $WORK_DIR/epoch_${EPOCH}.pth \
|
--checkpoint $WORK_DIR/epoch_${EPOCH}.pth \
|
||||||
--cfg-options dist_params.port=$PORT \
|
|
||||||
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
|
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
|
||||||
|
@ -11,7 +11,6 @@ PY_ARGS=${@:5}
|
|||||||
GPUS=${GPUS:-8}
|
GPUS=${GPUS:-8}
|
||||||
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
|
||||||
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
|
||||||
PORT=${PORT:-29500}
|
|
||||||
SRUN_ARGS=${SRUN_ARGS:-""}
|
SRUN_ARGS=${SRUN_ARGS:-""}
|
||||||
|
|
||||||
# set work_dir according to config path and pretrained model to distinguish different models
|
# set work_dir according to config path and pretrained model to distinguish different models
|
||||||
@ -29,5 +28,4 @@ srun -p ${PARTITION} \
|
|||||||
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
python -u tools/benchmarks/classification/knn_imagenet/test_knn.py $CFG \
|
||||||
--cfg-options model.backbone.init_cfg.type=Pretrained \
|
--cfg-options model.backbone.init_cfg.type=Pretrained \
|
||||||
model.backbone.init_cfg.checkpoint=$PRETRAIN \
|
model.backbone.init_cfg.checkpoint=$PRETRAIN \
|
||||||
dist_params.port=$PORT \
|
|
||||||
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
|
--work-dir $WORK_DIR --launcher="slurm" ${PY_ARGS}
|
||||||
|
@ -1,19 +1,23 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
|
import copy
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import mmcv
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv import DictAction
|
from mmengine import Runner
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmengine.config import Config, DictAction
|
||||||
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
|
from mmengine.dist import get_rank, init_dist
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
|
from mmengine.model.wrappers import MMDistributedDataParallel, is_model_wrapper
|
||||||
|
from mmengine.runner import load_checkpoint
|
||||||
|
from mmengine.utils import mkdir_or_exist
|
||||||
|
|
||||||
from mmselfsup.datasets import build_dataloader, build_dataset
|
from mmselfsup.evaluation.functional import knn_classifier
|
||||||
from mmselfsup.models import build_algorithm
|
from mmselfsup.models.utils import Extractor
|
||||||
from mmselfsup.models.utils import ExtractProcess, knn_classifier
|
from mmselfsup.registry import MODELS
|
||||||
from mmselfsup.utils import get_root_logger
|
from mmselfsup.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -59,6 +63,7 @@ def parse_args():
|
|||||||
type=bool,
|
type=bool,
|
||||||
help='Store the features on GPU. Set to False if you encounter OOM')
|
help='Store the features on GPU. Set to False if you encounter OOM')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if 'LOCAL_RANK' not in os.environ:
|
if 'LOCAL_RANK' not in os.environ:
|
||||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||||
@ -68,12 +73,18 @@ def parse_args():
|
|||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
cfg = mmcv.Config.fromfile(args.config)
|
# register all modules in mmselfsup into the registries
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
|
# load config
|
||||||
|
cfg = Config.fromfile(args.config)
|
||||||
if args.cfg_options is not None:
|
if args.cfg_options is not None:
|
||||||
cfg.merge_from_dict(args.cfg_options)
|
cfg.merge_from_dict(args.cfg_options)
|
||||||
|
|
||||||
# set cudnn_benchmark
|
# set cudnn_benchmark
|
||||||
if cfg.get('cudnn_benchmark', False):
|
if cfg.env_cfg.get('cudnn_benchmark', False):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||||
if args.work_dir is not None:
|
if args.work_dir is not None:
|
||||||
# update configs according to CLI args if args.work_dir is not None
|
# update configs according to CLI args if args.work_dir is not None
|
||||||
@ -89,47 +100,30 @@ def main():
|
|||||||
distributed = False
|
distributed = False
|
||||||
else:
|
else:
|
||||||
distributed = True
|
distributed = True
|
||||||
init_dist(args.launcher, **cfg.dist_params)
|
init_dist(args.launcher)
|
||||||
|
|
||||||
# create work_dir and init the logger before other steps
|
# create work_dir
|
||||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
|
||||||
knn_work_dir = osp.join(cfg.work_dir, 'knn/')
|
knn_work_dir = osp.join(cfg.work_dir, 'knn/')
|
||||||
mmcv.mkdir_or_exist(osp.abspath(knn_work_dir))
|
mkdir_or_exist(osp.abspath(knn_work_dir))
|
||||||
log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
|
|
||||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
|
||||||
|
|
||||||
# build the dataloader
|
# init the logger before other steps
|
||||||
dataset_cfg = mmcv.Config.fromfile(args.dataset_config)
|
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||||
dataset_train = build_dataset(dataset_cfg.data.train)
|
log_file = osp.join(knn_work_dir, f'knn_{timestamp}.log')
|
||||||
dataset_val = build_dataset(dataset_cfg.data.val)
|
logger = MMLogger.get_instance(
|
||||||
if 'imgs_per_gpu' in cfg.data:
|
'mmselfsup',
|
||||||
logger.warning('"imgs_per_gpu" is deprecated. '
|
logger_name='mmselfsup',
|
||||||
'Please use "samples_per_gpu" instead')
|
log_file=log_file,
|
||||||
if 'samples_per_gpu' in cfg.data:
|
log_level=cfg.log_level)
|
||||||
logger.warning(
|
|
||||||
f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
|
# build dataloader
|
||||||
f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
|
dataset_cfg = Config.fromfile(args.dataset_config)
|
||||||
f'={cfg.data.imgs_per_gpu} is used in this experiments')
|
data_loader_train = Runner.build_dataloader(
|
||||||
else:
|
dataloader=dataset_cfg.train_dataloader, seed=args.seed)
|
||||||
logger.warning(
|
data_loader_val = Runner.build_dataloader(
|
||||||
'Automatically set "samples_per_gpu"="imgs_per_gpu"='
|
dataloader=dataset_cfg.val_dataloader, seed=args.seed)
|
||||||
f'{cfg.data.imgs_per_gpu} in this experiments')
|
|
||||||
cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
|
|
||||||
data_loader_train = build_dataloader(
|
|
||||||
dataset_train,
|
|
||||||
samples_per_gpu=dataset_cfg.data.samples_per_gpu,
|
|
||||||
workers_per_gpu=dataset_cfg.data.workers_per_gpu,
|
|
||||||
dist=distributed,
|
|
||||||
shuffle=False)
|
|
||||||
data_loader_val = build_dataloader(
|
|
||||||
dataset_val,
|
|
||||||
samples_per_gpu=dataset_cfg.data.samples_per_gpu,
|
|
||||||
workers_per_gpu=dataset_cfg.data.workers_per_gpu,
|
|
||||||
dist=distributed,
|
|
||||||
shuffle=False)
|
|
||||||
|
|
||||||
# build the model
|
# build the model
|
||||||
model = build_algorithm(cfg.model)
|
model = MODELS.build(cfg.model)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
|
||||||
# model is determined in this priority: init_cfg > checkpoint > random
|
# model is determined in this priority: init_cfg > checkpoint > random
|
||||||
@ -145,30 +139,41 @@ def main():
|
|||||||
else:
|
else:
|
||||||
logger.info('No pretrained or checkpoint is given, use random init.')
|
logger.info('No pretrained or checkpoint is given, use random init.')
|
||||||
|
|
||||||
if not distributed:
|
if torch.cuda.is_available():
|
||||||
model = MMDataParallel(model, device_ids=[0])
|
model = model.cuda()
|
||||||
else:
|
|
||||||
|
if distributed:
|
||||||
model = MMDistributedDataParallel(
|
model = MMDistributedDataParallel(
|
||||||
model.cuda(),
|
module=model.cuda(),
|
||||||
device_ids=[torch.cuda.current_device()],
|
device_ids=[torch.cuda.current_device()],
|
||||||
broadcast_buffers=False)
|
broadcast_buffers=False)
|
||||||
|
|
||||||
model.eval()
|
if is_model_wrapper(model):
|
||||||
# build extraction processor and run
|
model = model.module
|
||||||
extractor = ExtractProcess()
|
|
||||||
train_feats = extractor.extract(
|
# build extractor and extract features
|
||||||
model, data_loader_train, distributed=distributed)['feat']
|
extractor_train = Extractor(
|
||||||
val_feats = extractor.extract(
|
extract_dataloader=data_loader_train,
|
||||||
model, data_loader_val, distributed=distributed)['feat']
|
seed=args.seed,
|
||||||
|
dist_mode=distributed,
|
||||||
|
pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
|
||||||
|
extractor_val = Extractor(
|
||||||
|
extract_dataloader=data_loader_val,
|
||||||
|
seed=args.seed,
|
||||||
|
dist_mode=distributed,
|
||||||
|
pool_cfg=copy.deepcopy(dataset_cfg.pool_cfg))
|
||||||
|
train_feats = extractor_train(model)['feat5']
|
||||||
|
val_feats = extractor_val(model)['feat5']
|
||||||
|
|
||||||
train_feats = torch.from_numpy(train_feats)
|
train_feats = torch.from_numpy(train_feats)
|
||||||
val_feats = torch.from_numpy(val_feats)
|
val_feats = torch.from_numpy(val_feats)
|
||||||
train_labels = torch.LongTensor(dataset_train.data_source.get_gt_labels())
|
train_labels = torch.LongTensor(data_loader_train.dataset.get_gt_labels())
|
||||||
val_labels = torch.LongTensor(dataset_val.data_source.get_gt_labels())
|
val_labels = torch.LongTensor(data_loader_val.dataset.get_gt_labels())
|
||||||
|
|
||||||
logger.info('Features are extracted! Start k-NN classification...')
|
logger.info('Features are extracted! Start k-NN classification...')
|
||||||
|
|
||||||
rank, _ = get_dist_info()
|
# run knn
|
||||||
|
rank = get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
train_feats = train_feats.cuda()
|
train_feats = train_feats.cuda()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user