75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn as nn
|
|
from torch.utils.data import Dataset
|
|
|
|
from mmselfsup.utils import dist_forward_collect, nondist_forward_collect
|
|
|
|
|
|
class Extractor(object):
|
|
"""Feature extractor.
|
|
|
|
Args:
|
|
dataset (Dataset | dict): A PyTorch dataset or dict that indicates
|
|
the dataset.
|
|
samples_per_gpu (int): Number of images on each GPU, i.e., batch size
|
|
of each GPU.
|
|
workers_per_gpu (int): How many subprocesses to use for data loading
|
|
for each GPU.
|
|
dist_mode (bool): Use distributed extraction or not. Defaults to False.
|
|
persistent_workers (bool): If True, the data loader will not shutdown
|
|
the worker processes after a dataset has been consumed once.
|
|
This allows to maintain the workers Dataset instances alive.
|
|
The argument also has effect in PyTorch>=1.7.0.
|
|
Defaults to True.
|
|
"""
|
|
|
|
def __init__(self,
|
|
dataset,
|
|
samples_per_gpu,
|
|
workers_per_gpu,
|
|
dist_mode=False,
|
|
persistent_workers=True,
|
|
**kwargs):
|
|
from mmselfsup import datasets
|
|
if isinstance(dataset, Dataset):
|
|
self.dataset = dataset
|
|
elif isinstance(dataset, dict):
|
|
self.dataset = datasets.build_dataset(dataset)
|
|
else:
|
|
raise TypeError(f'dataset must be a Dataset object or a dict, '
|
|
f'not {type(dataset)}')
|
|
self.data_loader = datasets.build_dataloader(
|
|
self.dataset,
|
|
samples_per_gpu=samples_per_gpu,
|
|
workers_per_gpu=workers_per_gpu,
|
|
dist=dist_mode,
|
|
shuffle=False,
|
|
persistent_workers=persistent_workers,
|
|
prefetch=kwargs.get('prefetch', False),
|
|
img_norm_cfg=kwargs.get('img_norm_cfg', dict()))
|
|
self.dist_mode = dist_mode
|
|
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
|
|
|
def _forward_func(self, runner, **x):
|
|
backbone_feat = runner.model(mode='extract', **x)
|
|
last_layer_feat = runner.model.module.neck([backbone_feat[-1]])[0]
|
|
last_layer_feat = last_layer_feat.view(last_layer_feat.size(0), -1)
|
|
return dict(feature=last_layer_feat.cpu())
|
|
|
|
def __call__(self, runner):
|
|
# the function sent to collect function
|
|
def func(**x):
|
|
return self._forward_func(runner, **x)
|
|
|
|
if self.dist_mode:
|
|
feats = dist_forward_collect(
|
|
func,
|
|
self.data_loader,
|
|
runner.rank,
|
|
len(self.dataset),
|
|
ret_rank=-1)['feature'] # NxD
|
|
else:
|
|
feats = nondist_forward_collect(func, self.data_loader,
|
|
len(self.dataset))['feature']
|
|
return feats
|