mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor] refactor DeepCluster and ODC hook and related files
This commit is contained in:
parent
d1b7466d9b
commit
5dc3696df7
@ -1,5 +1,5 @@
|
||||
# dataset settings
|
||||
dataset_type = 'mmcls.ImageNet'
|
||||
dataset_type = 'DeepClusterImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
@ -14,7 +14,7 @@ train_pipeline = [
|
||||
contrast=0.4,
|
||||
saturation=1.0,
|
||||
hue=0.5),
|
||||
dict(type='RandomGrayscale', p=0.2),
|
||||
dict(type='RandomGrayscale', prob=0.2),
|
||||
dict(type='PackSelfSupInputs')
|
||||
]
|
||||
|
||||
@ -37,14 +37,15 @@ train_dataloader = dict(
|
||||
data_prefix=dict(img='train/'),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
# TODO: refactor the hook and modify the config
|
||||
num_classes = 10000
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='DeepClusterHook',
|
||||
extractor=dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=8,
|
||||
extract_dataloader=dict(
|
||||
batch_size=128,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# dataset settings
|
||||
dataset_type = 'mmcls.ImageNet'
|
||||
dataset_type = 'DeepClusterImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
|
||||
@ -38,14 +38,15 @@ train_dataloader = dict(
|
||||
data_prefix=dict(img='train/'),
|
||||
pipeline=train_pipeline))
|
||||
|
||||
# TODO: refactor the hook and modify the config
|
||||
num_classes = 10000
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='DeepClusterHook',
|
||||
extractor=dict(
|
||||
samples_per_gpu=128,
|
||||
workers_per_gpu=8,
|
||||
extract_dataloader=dict(
|
||||
batch_size=128,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
|
@ -6,7 +6,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import print_log
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmselfsup.registry import HOOKS
|
||||
from mmselfsup.utils import Extractor
|
||||
@ -35,7 +34,7 @@ class DeepClusterHook(Hook):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
extractor: Dict,
|
||||
extract_dataloader: Dict,
|
||||
clustering: Dict,
|
||||
unif_sampling: bool,
|
||||
reweight: bool,
|
||||
@ -43,9 +42,12 @@ class DeepClusterHook(Hook):
|
||||
init_memory: Optional[bool] = False, # for ODC
|
||||
initial: Optional[bool] = True,
|
||||
interval: Optional[int] = 1,
|
||||
dist_mode: Optional[bool] = True,
|
||||
data_loaders: Optional[DataLoader] = None) -> None:
|
||||
self.extractor = Extractor(dist_mode=dist_mode, **extractor)
|
||||
seed: Optional[int] = None,
|
||||
dist_mode: Optional[bool] = True) -> None:
|
||||
self.extractor = Extractor(
|
||||
extract_dataloader=extract_dataloader,
|
||||
seed=seed,
|
||||
dist_mode=dist_mode)
|
||||
self.clustering_type = clustering.pop('type')
|
||||
self.clustering_cfg = clustering
|
||||
self.unif_sampling = unif_sampling
|
||||
@ -55,9 +57,9 @@ class DeepClusterHook(Hook):
|
||||
self.initial = initial
|
||||
self.interval = interval
|
||||
self.dist_mode = dist_mode
|
||||
self.data_loaders = data_loaders
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
self.data_loader = runner.train_dataloader
|
||||
if self.initial:
|
||||
self.deepcluster(runner)
|
||||
|
||||
@ -84,7 +86,7 @@ class DeepClusterHook(Hook):
|
||||
new_labels)
|
||||
self.evaluate(runner, new_labels)
|
||||
else:
|
||||
new_labels = np.zeros((len(self.data_loaders[0].dataset), ),
|
||||
new_labels = np.zeros((len(self.data_loader.dataset), ),
|
||||
dtype=np.int64)
|
||||
|
||||
if self.dist_mode:
|
||||
@ -94,11 +96,11 @@ class DeepClusterHook(Hook):
|
||||
new_labels_list = list(new_labels)
|
||||
|
||||
# step 3: assign new labels
|
||||
self.data_loaders[0].dataset.assign_labels(new_labels_list)
|
||||
self.data_loader.dataset.assign_labels(new_labels_list)
|
||||
|
||||
# step 4 (a): set uniform sampler
|
||||
if self.unif_sampling:
|
||||
self.data_loaders[0].sampler.set_uniform_indices(
|
||||
self.data_loader.sampler.set_uniform_indices(
|
||||
new_labels_list, self.clustering_cfg.k)
|
||||
|
||||
# step 4 (b): set loss reweight
|
||||
|
64
mmselfsup/datasets/deepcluster_dataset.py
Normal file
64
mmselfsup/datasets/deepcluster_dataset.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from mmcls.datasets import ImageNet
|
||||
|
||||
|
||||
class DeepClusterImageNet(ImageNet):
|
||||
"""`ImageNet <http://www.image-net.org>`_ Dataset.
|
||||
|
||||
The dataset inherit ImageNet dataset from MMClassification as the
|
||||
DeepCluster and Online Deep Clustering algorithm need to initialize
|
||||
clustering labels and assign them during training.
|
||||
|
||||
Args:
|
||||
ann_file (str, optional): Annotation file path. Defaults to None.
|
||||
metainfo (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (str | dict, optional): Prefix for training data. Defaults
|
||||
to None.
|
||||
**kwargs: Other keyword arguments in :class:`CustomDataset` and
|
||||
:class:`BaseDataset`.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
ann_file: Optional[str] = None,
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: Union[str, dict, None] = None,
|
||||
**kwargs):
|
||||
kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
|
||||
super().__init__(
|
||||
ann_file=ann_file,
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
**kwargs)
|
||||
# init clustering labels
|
||||
self.clustering_labels = [-1 for _ in range(len(self))]
|
||||
|
||||
def assign_labels(self, labels: List) -> None:
|
||||
"""Assign new labels to `self.clustering_labels`.
|
||||
|
||||
Args:
|
||||
labels (list): The new labels.
|
||||
"""
|
||||
assert len(self.clustering_labels) == len(labels), (
|
||||
f'Inconsistent length of assigned labels, '
|
||||
f'{len(self.clustering_labels)} vs {len(labels)}')
|
||||
self.clustering_labels = labels[:]
|
||||
|
||||
def prepare_data(self, idx: int) -> Any:
|
||||
"""Get data processed by ``self.pipeline``.
|
||||
|
||||
Args:
|
||||
idx (int): The index of ``data_info``.
|
||||
|
||||
Returns:
|
||||
Any: Depends on ``self.pipeline``.
|
||||
"""
|
||||
data_info = self.get_data_info(idx)
|
||||
data_info['clustering_label'] = self.clustering_labels[idx]
|
||||
return self.pipeline(data_info)
|
@ -1,65 +1,40 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Dict, Optional
|
||||
|
||||
from mmengine import Runner
|
||||
|
||||
from mmselfsup.utils import dist_forward_collect, nondist_forward_collect
|
||||
|
||||
|
||||
class Extractor(object):
|
||||
class Extractor():
|
||||
"""Feature extractor.
|
||||
|
||||
Args:
|
||||
dataset (Dataset | dict): A PyTorch dataset or dict that indicates
|
||||
the dataset.
|
||||
samples_per_gpu (int): Number of images on each GPU, i.e., batch size
|
||||
of each GPU.
|
||||
workers_per_gpu (int): How many subprocesses to use for data loading
|
||||
for each GPU.
|
||||
dist_mode (bool): Use distributed extraction or not. Defaults to False.
|
||||
persistent_workers (bool): If True, the data loader will not shutdown
|
||||
the worker processes after a dataset has been consumed once.
|
||||
This allows to maintain the workers Dataset instances alive.
|
||||
The argument also has effect in PyTorch>=1.7.0.
|
||||
Defaults to True.
|
||||
extract_dataloader (dict): A dict to build Dataloader object.
|
||||
seed (int, optional): Random seed. Defaults to None.
|
||||
dist_mode (bool, optional): Use distributed extraction or not.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset,
|
||||
samples_per_gpu,
|
||||
workers_per_gpu,
|
||||
dist_mode=False,
|
||||
persistent_workers=True,
|
||||
extract_dataloader: Dict,
|
||||
seed: Optional[int] = None,
|
||||
dist_mode: bool = False,
|
||||
**kwargs):
|
||||
from mmselfsup import datasets
|
||||
if isinstance(dataset, Dataset):
|
||||
self.dataset = dataset
|
||||
elif isinstance(dataset, dict):
|
||||
self.dataset = datasets.build_dataset(dataset)
|
||||
else:
|
||||
raise TypeError(f'dataset must be a Dataset object or a dict, '
|
||||
f'not {type(dataset)}')
|
||||
self.data_loader = datasets.build_dataloader(
|
||||
self.dataset,
|
||||
samples_per_gpu=samples_per_gpu,
|
||||
workers_per_gpu=workers_per_gpu,
|
||||
dist=dist_mode,
|
||||
shuffle=False,
|
||||
persistent_workers=persistent_workers,
|
||||
prefetch=kwargs.get('prefetch', False),
|
||||
img_norm_cfg=kwargs.get('img_norm_cfg', dict()))
|
||||
self.data_loader = Runner.build_dataloader(
|
||||
extract_dataloader=extract_dataloader, seed=seed)
|
||||
self.dist_mode = dist_mode
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
|
||||
def _forward_func(self, runner, **x):
|
||||
backbone_feat = runner.model(mode='extract', **x)
|
||||
def _forward_func(self, runner, packed_data):
|
||||
backbone_feat = runner.model(packed_data, extract=True)
|
||||
last_layer_feat = runner.model.module.neck([backbone_feat[-1]])[0]
|
||||
last_layer_feat = last_layer_feat.view(last_layer_feat.size(0), -1)
|
||||
return dict(feature=last_layer_feat.cpu())
|
||||
|
||||
def __call__(self, runner):
|
||||
# the function sent to collect function
|
||||
def func(**x):
|
||||
return self._forward_func(runner, **x)
|
||||
def func(packed_data):
|
||||
return self._forward_func(runner, packed_data)
|
||||
|
||||
if self.dist_mode:
|
||||
feats = dist_forward_collect(
|
||||
|
Loading…
x
Reference in New Issue
Block a user