[Refactor] Refactor the DeepCluster and ODC hooks and related files

fangyixiao.vendor 2022-05-23 07:26:25 +00:00 committed by fangyixiao18
parent d1b7466d9b
commit 5dc3696df7
5 changed files with 105 additions and 62 deletions

View File

@@ -1,5 +1,5 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
dataset_type = 'DeepClusterImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
@@ -14,7 +14,7 @@ train_pipeline = [
contrast=0.4,
saturation=1.0,
hue=0.5),
dict(type='RandomGrayscale', p=0.2),
dict(type='RandomGrayscale', prob=0.2),
dict(type='PackSelfSupInputs')
]
@@ -37,14 +37,15 @@ train_dataloader = dict(
data_prefix=dict(img='train/'),
pipeline=train_pipeline))
# TODO: refactor the hook and modify the config
num_classes = 10000
custom_hooks = [
dict(
type='DeepClusterHook',
extractor=dict(
samples_per_gpu=128,
workers_per_gpu=8,
extract_dataloader=dict(
batch_size=128,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
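For reference, the new extract_dataloader dict is a standard MMEngine dataloader config and can be materialized with Runner.build_dataloader. A minimal sketch, assuming mmselfsup's registries are imported so 'DeepClusterImageNet' resolves; the seed value is an arbitrary placeholder, not part of this diff:

from mmengine.runner import Runner

# Mirrors the config above; the pipeline is omitted for brevity.
extract_dataloader = dict(
    batch_size=128,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='DeepClusterImageNet',
        data_root='data/imagenet/',
        data_prefix=dict(img='train/')))

# Returns a torch.utils.data.DataLoader built from the dict.
data_loader = Runner.build_dataloader(dataloader=extract_dataloader, seed=0)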

View File

@@ -1,5 +1,5 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
dataset_type = 'DeepClusterImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
@@ -38,14 +38,15 @@ train_dataloader = dict(
data_prefix=dict(img='train/'),
pipeline=train_pipeline))
# TODO: refactor the hook and modify the config
num_classes = 10000
custom_hooks = [
dict(
type='DeepClusterHook',
extractor=dict(
samples_per_gpu=128,
workers_per_gpu=8,
extract_dataloader=dict(
batch_size=128,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
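Since init_memory is marked "# for ODC" in the hook below, the ODC variant of this config presumably enables it so the memory banks are filled from the first feature extraction. A hedged sketch of such a hook entry; the clustering settings (Kmeans, pca_dim) and reweight values are illustrative assumptions, not taken from this diff:

custom_hooks = [
    dict(
        type='DeepClusterHook',
        extract_dataloader=dict(
            batch_size=128,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(type=dataset_type, data_root=data_root)),
        # Assumed clustering config: k-means into num_classes clusters.
        clustering=dict(type='Kmeans', k=num_classes, pca_dim=16),
        unif_sampling=False,
        reweight=True,
        reweight_pow=0.5,
        init_memory=True,  # ODC-specific: initialize memory banks
        initial=True,
        interval=1)
]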

View File

@@ -6,7 +6,6 @@ import torch
import torch.distributed as dist
from mmengine.hooks import Hook
from mmengine.logging import print_log
from torch.utils.data import DataLoader
from mmselfsup.registry import HOOKS
from mmselfsup.utils import Extractor
@@ -35,7 +34,7 @@ class DeepClusterHook(Hook):
def __init__(
self,
extractor: Dict,
extract_dataloader: Dict,
clustering: Dict,
unif_sampling: bool,
reweight: bool,
@@ -43,9 +42,12 @@ class DeepClusterHook(Hook):
init_memory: Optional[bool] = False, # for ODC
initial: Optional[bool] = True,
interval: Optional[int] = 1,
dist_mode: Optional[bool] = True,
data_loaders: Optional[DataLoader] = None) -> None:
self.extractor = Extractor(dist_mode=dist_mode, **extractor)
seed: Optional[int] = None,
dist_mode: Optional[bool] = True) -> None:
self.extractor = Extractor(
extract_dataloader=extract_dataloader,
seed=seed,
dist_mode=dist_mode)
self.clustering_type = clustering.pop('type')
self.clustering_cfg = clustering
self.unif_sampling = unif_sampling
@@ -55,9 +57,9 @@ class DeepClusterHook(Hook):
self.initial = initial
self.interval = interval
self.dist_mode = dist_mode
self.data_loaders = data_loaders
def before_run(self, runner) -> None:
self.data_loader = runner.train_dataloader
if self.initial:
self.deepcluster(runner)
@@ -84,7 +86,7 @@
new_labels)
self.evaluate(runner, new_labels)
else:
new_labels = np.zeros((len(self.data_loaders[0].dataset), ),
new_labels = np.zeros((len(self.data_loader.dataset), ),
dtype=np.int64)
if self.dist_mode:
@@ -94,11 +96,11 @@
new_labels_list = list(new_labels)
# step 3: assign new labels
self.data_loaders[0].dataset.assign_labels(new_labels_list)
self.data_loader.dataset.assign_labels(new_labels_list)
# step 4 (a): set uniform sampler
if self.unif_sampling:
self.data_loaders[0].sampler.set_uniform_indices(
self.data_loader.sampler.set_uniform_indices(
new_labels_list, self.clustering_cfg.k)
# step 4 (b): set loss reweight
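Steps 2 and 3 rely on every rank holding a label buffer of the same length: rank 0 fills it from the clustering result, the other ranks keep the zeros allocated above, and a broadcast then synchronizes them before assign_labels is called. A standalone sketch of that synchronization, assuming a CUDA/NCCL process group; this is not the hook's literal code:

import numpy as np
import torch
import torch.distributed as dist

def sync_new_labels(new_labels: np.ndarray) -> np.ndarray:
    # Move the labels to GPU, broadcast rank 0's copy to all ranks,
    # then return them as a NumPy array for dataset.assign_labels().
    labels = torch.from_numpy(new_labels).cuda()
    dist.broadcast(labels, src=0)
    return labels.cpu().numpy()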

View File

@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Union
from mmcls.datasets import ImageNet
class DeepClusterImageNet(ImageNet):
"""`ImageNet <http://www.image-net.org>`_ Dataset.
This dataset inherits the ImageNet dataset from MMClassification, as
the DeepCluster and Online Deep Clustering algorithms need to
initialize clustering labels and assign them during training.
Args:
ann_file (str, optional): Annotation file path. Defaults to None.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (str | dict, optional): Prefix for training data. Defaults
to None.
**kwargs: Other keyword arguments in :class:`CustomDataset` and
:class:`BaseDataset`.
""" # noqa: E501
def __init__(self,
ann_file: Optional[str] = None,
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: Union[str, dict, None] = None,
**kwargs):
kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
**kwargs)
# init clustering labels
self.clustering_labels = [-1 for _ in range(len(self))]
def assign_labels(self, labels: List) -> None:
"""Assign new labels to `self.clustering_labels`.
Args:
labels (list): The new labels.
"""
assert len(self.clustering_labels) == len(labels), (
f'Inconsistent length of assigned labels, '
f'{len(self.clustering_labels)} vs {len(labels)}')
self.clustering_labels = labels[:]
def prepare_data(self, idx: int) -> Any:
"""Get data processed by ``self.pipeline``.
Args:
idx (int): The index of ``data_info``.
Returns:
Any: Depends on ``self.pipeline``.
"""
data_info = self.get_data_info(idx)
data_info['clustering_label'] = self.clustering_labels[idx]
return self.pipeline(data_info)
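A hedged usage sketch of the new dataset; the paths and pipeline are placeholder assumptions. Clustering labels start as -1, are replaced wholesale by assign_labels, and prepare_data attaches the per-sample label under the 'clustering_label' key:

# Assumes the standard ImageNet folder layout and registered transforms.
dataset = DeepClusterImageNet(
    data_root='data/imagenet/',
    data_prefix=dict(img='train/'),
    pipeline=[dict(type='PackSelfSupInputs')])

print(dataset.clustering_labels[0])        # -1 before any clustering
dataset.assign_labels([0] * len(dataset))  # e.g. labels from k-means
item = dataset.prepare_data(0)             # carries 'clustering_label'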

View File

@@ -1,65 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from torch.utils.data import Dataset
from typing import Dict, Optional
from mmengine.runner import Runner
from mmselfsup.utils import dist_forward_collect, nondist_forward_collect
class Extractor(object):
class Extractor():
"""Feature extractor.
Args:
dataset (Dataset | dict): A PyTorch dataset or dict that indicates
the dataset.
samples_per_gpu (int): Number of images on each GPU, i.e., batch size
of each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
dist_mode (bool): Use distributed extraction or not. Defaults to False.
persistent_workers (bool): If True, the data loader will not shutdown
the worker processes after a dataset has been consumed once.
This allows to maintain the workers Dataset instances alive.
The argument also has effect in PyTorch>=1.7.0.
Defaults to True.
extract_dataloader (dict): A dict to build Dataloader object.
seed (int, optional): Random seed. Defaults to None.
dist_mode (bool, optional): Use distributed extraction or not.
Defaults to False.
"""
def __init__(self,
dataset,
samples_per_gpu,
workers_per_gpu,
dist_mode=False,
persistent_workers=True,
extract_dataloader: Dict,
seed: Optional[int] = None,
dist_mode: bool = False,
**kwargs):
from mmselfsup import datasets
if isinstance(dataset, Dataset):
self.dataset = dataset
elif isinstance(dataset, dict):
self.dataset = datasets.build_dataset(dataset)
else:
raise TypeError(f'dataset must be a Dataset object or a dict, '
f'not {type(dataset)}')
self.data_loader = datasets.build_dataloader(
self.dataset,
samples_per_gpu=samples_per_gpu,
workers_per_gpu=workers_per_gpu,
dist=dist_mode,
shuffle=False,
persistent_workers=persistent_workers,
prefetch=kwargs.get('prefetch', False),
img_norm_cfg=kwargs.get('img_norm_cfg', dict()))
self.data_loader = Runner.build_dataloader(
dataloader=extract_dataloader, seed=seed)
self.dist_mode = dist_mode
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
def _forward_func(self, runner, **x):
backbone_feat = runner.model(mode='extract', **x)
def _forward_func(self, runner, packed_data):
backbone_feat = runner.model(packed_data, extract=True)
last_layer_feat = runner.model.module.neck([backbone_feat[-1]])[0]
last_layer_feat = last_layer_feat.view(last_layer_feat.size(0), -1)
return dict(feature=last_layer_feat.cpu())
def __call__(self, runner):
# the function sent to collect function
def func(**x):
return self._forward_func(runner, **x)
def func(packed_data):
return self._forward_func(runner, packed_data)
if self.dist_mode:
feats = dist_forward_collect(
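The diff is truncated here. For context, a hedged sketch of how DeepClusterHook drives the refactored extractor; the return value is assumed to be the features aggregated from the per-batch dicts produced by _forward_func via the collect helpers imported above:

# Built once in the hook's __init__, called on every clustering interval.
extractor = Extractor(
    extract_dataloader=extract_dataloader,  # dict as in the configs above
    seed=None,
    dist_mode=True)
features = extractor(runner)  # pooled backbone features, one row per sample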