mirror of https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00

[Refactor] refactor DeepCluster and ODC hook and related files

parent d1b7466d9b
commit 5dc3696df7
@@ -1,5 +1,5 @@
 # dataset settings
-dataset_type = 'mmcls.ImageNet'
+dataset_type = 'DeepClusterImageNet'
 data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')
@@ -14,7 +14,7 @@ train_pipeline = [
         contrast=0.4,
         saturation=1.0,
         hue=0.5),
-    dict(type='RandomGrayscale', p=0.2),
+    dict(type='RandomGrayscale', prob=0.2),
     dict(type='PackSelfSupInputs')
 ]
@@ -37,14 +37,15 @@ train_dataloader = dict(
         data_prefix=dict(img='train/'),
         pipeline=train_pipeline))

-# TODO: refactor the hook and modify the config
 num_classes = 10000
 custom_hooks = [
     dict(
         type='DeepClusterHook',
-        extractor=dict(
-            samples_per_gpu=128,
-            workers_per_gpu=8,
+        extract_dataloader=dict(
+            batch_size=128,
+            num_workers=8,
+            persistent_workers=True,
+            sampler=dict(type='DefaultSampler', shuffle=False),
             dataset=dict(
                 type=dataset_type,
                 data_root=data_root,
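For context, this hunk migrates the hook's dataloader settings from the MMCV 1.x keys (samples_per_gpu, workers_per_gpu) to MMEngine-style ones (batch_size, num_workers, plus an explicit sampler and persistent_workers). Below is a minimal sketch of how such a dict becomes an actual DataLoader, assuming MMEngine's Runner.build_dataloader helper; the dataset dict is abbreviated here just as it is truncated in the diff:

# Sketch only: MMEngine builds a torch DataLoader from the dict form above.
from mmengine.runner import Runner

loader_cfg = dict(
    batch_size=128,      # replaces samples_per_gpu
    num_workers=8,       # replaces workers_per_gpu
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(type='DeepClusterImageNet',
                 data_root='data/imagenet/'))  # abbreviated dataset config
data_loader = Runner.build_dataloader(loader_cfg)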
The same change is applied to a second dataset config file:

@@ -1,5 +1,5 @@
 # dataset settings
-dataset_type = 'mmcls.ImageNet'
+dataset_type = 'DeepClusterImageNet'
 data_root = 'data/imagenet/'
 file_client_args = dict(backend='disk')
@@ -38,14 +38,15 @@ train_dataloader = dict(
         data_prefix=dict(img='train/'),
         pipeline=train_pipeline))

-# TODO: refactor the hook and modify the config
 num_classes = 10000
 custom_hooks = [
     dict(
         type='DeepClusterHook',
-        extractor=dict(
-            samples_per_gpu=128,
-            workers_per_gpu=8,
+        extract_dataloader=dict(
+            batch_size=128,
+            num_workers=8,
+            persistent_workers=True,
+            sampler=dict(type='DefaultSampler', shuffle=False),
             dataset=dict(
                 type=dataset_type,
                 data_root=data_root,
In the DeepCluster hook implementation, the standalone DataLoader import is dropped, since the hook no longer receives dataloaders directly:

@@ -6,7 +6,6 @@ import torch
 import torch.distributed as dist
 from mmengine.hooks import Hook
 from mmengine.logging import print_log
-from torch.utils.data import DataLoader

 from mmselfsup.registry import HOOKS
 from mmselfsup.utils import Extractor
@@ -35,7 +34,7 @@ class DeepClusterHook(Hook):

     def __init__(
         self,
-        extractor: Dict,
+        extract_dataloader: Dict,
         clustering: Dict,
         unif_sampling: bool,
         reweight: bool,
@@ -43,9 +42,12 @@ class DeepClusterHook(Hook):
         init_memory: Optional[bool] = False,  # for ODC
         initial: Optional[bool] = True,
         interval: Optional[int] = 1,
-        dist_mode: Optional[bool] = True,
-        data_loaders: Optional[DataLoader] = None) -> None:
-        self.extractor = Extractor(dist_mode=dist_mode, **extractor)
+        seed: Optional[int] = None,
+        dist_mode: Optional[bool] = True) -> None:
+        self.extractor = Extractor(
+            extract_dataloader=extract_dataloader,
+            seed=seed,
+            dist_mode=dist_mode)
         self.clustering_type = clustering.pop('type')
         self.clustering_cfg = clustering
         self.unif_sampling = unif_sampling
@@ -55,9 +57,9 @@ class DeepClusterHook(Hook):
         self.initial = initial
         self.interval = interval
         self.dist_mode = dist_mode
-        self.data_loaders = data_loaders

     def before_run(self, runner) -> None:
+        self.data_loader = runner.train_dataloader
         if self.initial:
             self.deepcluster(runner)
@@ -84,7 +86,7 @@ class DeepClusterHook(Hook):
                 new_labels)
             self.evaluate(runner, new_labels)
         else:
-            new_labels = np.zeros((len(self.data_loaders[0].dataset), ),
+            new_labels = np.zeros((len(self.data_loader.dataset), ),
                                   dtype=np.int64)

         if self.dist_mode:
@@ -94,11 +96,11 @@ class DeepClusterHook(Hook):
         new_labels_list = list(new_labels)

         # step 3: assign new labels
-        self.data_loaders[0].dataset.assign_labels(new_labels_list)
+        self.data_loader.dataset.assign_labels(new_labels_list)

         # step 4 (a): set uniform sampler
         if self.unif_sampling:
-            self.data_loaders[0].sampler.set_uniform_indices(
+            self.data_loader.sampler.set_uniform_indices(
                 new_labels_list, self.clustering_cfg.k)

         # step 4 (b): set loss reweight
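Taken together with steps 1 and 2 (not shown in this hunk), the hook's per-epoch flow now reads roughly as below. This is an illustrative sketch, not code from the commit: the extractor call and its 'feature' key follow the Extractor class later in this diff, and the Kmeans-style clustering object with cluster() and labels is assumed from mmselfsup's clustering utilities.

# Illustrative sketch of DeepClusterHook.deepcluster() after the refactor.
features = self.extractor(runner)['feature']      # step 1: extract features
clustering = Kmeans(k=self.clustering_cfg.k)      # assumed clustering API
clustering.cluster(features.numpy())              # step 2: cluster features
new_labels = clustering.labels.astype(np.int64)
new_labels_list = list(new_labels)
self.data_loader.dataset.assign_labels(new_labels_list)   # step 3
if self.unif_sampling:                            # step 4 (a)
    self.data_loader.sampler.set_uniform_indices(
        new_labels_list, self.clustering_cfg.k)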
New file: mmselfsup/datasets/deepcluster_dataset.py (64 lines)
@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, List, Optional, Union
+
+from mmcls.datasets import ImageNet
+
+
+class DeepClusterImageNet(ImageNet):
+    """`ImageNet <http://www.image-net.org>`_ Dataset.
+
+    The dataset inherits the ImageNet dataset from MMClassification, as the
+    DeepCluster and Online Deep Clustering algorithms need to initialize
+    clustering labels and assign them during training.
+
+    Args:
+        ann_file (str, optional): Annotation file path. Defaults to None.
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (str | dict, optional): Prefix for training data. Defaults
+            to None.
+        **kwargs: Other keyword arguments in :class:`CustomDataset` and
+            :class:`BaseDataset`.
+    """  # noqa: E501
+
+    def __init__(self,
+                 ann_file: Optional[str] = None,
+                 metainfo: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: Union[str, dict, None] = None,
+                 **kwargs):
+        kwargs = {'extensions': self.IMG_EXTENSIONS, **kwargs}
+        super().__init__(
+            ann_file=ann_file,
+            metainfo=metainfo,
+            data_root=data_root,
+            data_prefix=data_prefix,
+            **kwargs)
+        # init clustering labels
+        self.clustering_labels = [-1 for _ in range(len(self))]
+
+    def assign_labels(self, labels: List) -> None:
+        """Assign new labels to `self.clustering_labels`.
+
+        Args:
+            labels (list): The new labels.
+        """
+        assert len(self.clustering_labels) == len(labels), (
+            f'Inconsistent length of assigned labels, '
+            f'{len(self.clustering_labels)} vs {len(labels)}')
+        self.clustering_labels = labels[:]
+
+    def prepare_data(self, idx: int) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+        data_info['clustering_label'] = self.clustering_labels[idx]
+        return self.pipeline(data_info)
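A short usage sketch of the new dataset class. The -1 placeholder and the method names come straight from the file above; the paths are illustrative, and it is assumed the class is registered and exported from mmselfsup.datasets so configs can refer to it by name.

from mmselfsup.datasets import DeepClusterImageNet

dataset = DeepClusterImageNet(
    data_root='data/imagenet/', data_prefix=dict(img='train/'))
print(dataset.clustering_labels[:3])       # [-1, -1, -1] before any clustering
dataset.assign_labels([0] * len(dataset))  # e.g. labels produced by k-means
sample = dataset.prepare_data(0)           # pipeline input gains 'clustering_label'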
Finally, the Extractor utility is rewritten to build its dataloader through MMEngine instead of the removed mmselfsup dataset builders:

@@ -1,65 +1,40 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn as nn
-from torch.utils.data import Dataset
+from typing import Dict, Optional
+
+from mmengine import Runner

 from mmselfsup.utils import dist_forward_collect, nondist_forward_collect


-class Extractor(object):
+class Extractor():
     """Feature extractor.

     Args:
-        dataset (Dataset | dict): A PyTorch dataset or dict that indicates
-            the dataset.
-        samples_per_gpu (int): Number of images on each GPU, i.e., batch size
-            of each GPU.
-        workers_per_gpu (int): How many subprocesses to use for data loading
-            for each GPU.
-        dist_mode (bool): Use distributed extraction or not. Defaults to False.
-        persistent_workers (bool): If True, the data loader will not shutdown
-            the worker processes after a dataset has been consumed once.
-            This allows to maintain the workers Dataset instances alive.
-            The argument also has effect in PyTorch>=1.7.0.
-            Defaults to True.
+        extract_dataloader (dict): A dict to build the Dataloader object.
+        seed (int, optional): Random seed. Defaults to None.
+        dist_mode (bool, optional): Use distributed extraction or not.
+            Defaults to False.
     """

     def __init__(self,
-                 dataset,
-                 samples_per_gpu,
-                 workers_per_gpu,
-                 dist_mode=False,
-                 persistent_workers=True,
+                 extract_dataloader: Dict,
+                 seed: Optional[int] = None,
+                 dist_mode: bool = False,
                  **kwargs):
-        from mmselfsup import datasets
-        if isinstance(dataset, Dataset):
-            self.dataset = dataset
-        elif isinstance(dataset, dict):
-            self.dataset = datasets.build_dataset(dataset)
-        else:
-            raise TypeError(f'dataset must be a Dataset object or a dict, '
-                            f'not {type(dataset)}')
-        self.data_loader = datasets.build_dataloader(
-            self.dataset,
-            samples_per_gpu=samples_per_gpu,
-            workers_per_gpu=workers_per_gpu,
-            dist=dist_mode,
-            shuffle=False,
-            persistent_workers=persistent_workers,
-            prefetch=kwargs.get('prefetch', False),
-            img_norm_cfg=kwargs.get('img_norm_cfg', dict()))
+        self.data_loader = Runner.build_dataloader(
+            extract_dataloader=extract_dataloader, seed=seed)
         self.dist_mode = dist_mode
-        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

-    def _forward_func(self, runner, **x):
-        backbone_feat = runner.model(mode='extract', **x)
+    def _forward_func(self, runner, packed_data):
+        backbone_feat = runner.model(packed_data, extract=True)
         last_layer_feat = runner.model.module.neck([backbone_feat[-1]])[0]
         last_layer_feat = last_layer_feat.view(last_layer_feat.size(0), -1)
         return dict(feature=last_layer_feat.cpu())

     def __call__(self, runner):
         # the function sent to collect function
-        def func(**x):
-            return self._forward_func(runner, **x)
+        def func(packed_data):
+            return self._forward_func(runner, packed_data)

         if self.dist_mode:
             feats = dist_forward_collect(
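To close, a sketch of how the refactored Extractor is driven. Construction mirrors the DeepClusterHook change above; the runner is assumed to be a live MMEngine Runner mid-training, and the 'feature' key follows _forward_func.

# Sketch only: driving the refactored Extractor from a hook.
extractor = Extractor(
    extract_dataloader=dict(
        batch_size=128,
        num_workers=8,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=False),
        dataset=dict(type='DeepClusterImageNet',
                     data_root='data/imagenet/')),  # abbreviated dataset cfg
    seed=0,
    dist_mode=True)
features = extractor(runner)['feature']   # (num_samples, feat_dim), on CPU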