[Refactor] Refactor hooks to support non-dist training (#397)

* apply mmengine.dist instead of torch.dist
* apply get_model to densecl_hook
* fix bug to pass ut
* update typehint

parent a703ba2fcb
commit 0ea07c0750
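
The refactor targets two things that break once training is not distributed: direct `torch.distributed` calls, which require an initialized process group, and `runner.model.module` accesses, which assume the model is wrapped (e.g. by DDP). The `mmengine.dist` helpers adopted below fall back to single-process defaults instead; a minimal sketch, assuming no process group has been initialized:

```python
# Single-process fallbacks of the mmengine.dist helpers used in this commit.
from mmengine.dist import get_rank, get_world_size, is_distributed

print(is_distributed())   # False -- torch.distributed is not initialized
print(get_rank())         # 0
print(get_world_size())   # 1
# The old torch.distributed.get_rank()/get_world_size() calls would instead
# raise unless a process group had been set up first.
```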
@@ -4,6 +4,7 @@ from typing import Optional, Sequence
 from mmengine.hooks import Hook
 
 from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
 
 
 @HOOKS.register_module()
@@ -23,19 +24,19 @@ class DenseCLHook(Hook):
 
     def before_train(self, runner) -> None:
         """Obtain ``loss_lambda`` from algorithm."""
-        assert hasattr(runner.model.module, 'loss_lambda'), \
+        assert hasattr(get_model(runner.model), 'loss_lambda'), \
             "The runner must have attribute \"loss_lambda\" in DenseCL."
-        self.loss_lambda = runner.model.module.loss_lambda
+        self.loss_lambda = get_model(runner.model).loss_lambda
 
     def before_train_iter(self,
                           runner,
                           batch_idx: int,
                           data_batch: Optional[Sequence[dict]] = None) -> None:
         """Adjust ``loss_lambda`` every train iter."""
-        assert hasattr(runner.model.module, 'loss_lambda'), \
+        assert hasattr(get_model(runner.model), 'loss_lambda'), \
            "The runner must have attribute \"loss_lambda\" in DenseCL."
         cur_iter = runner.iter
         if cur_iter >= self.start_iters:
-            runner.model.module.loss_lambda = self.loss_lambda
+            get_model(runner.model).loss_lambda = self.loss_lambda
         else:
-            runner.model.module.loss_lambda = 0.
+            get_model(runner.model).loss_lambda = 0.
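With the `get_model` indirection above, `DenseCLHook` no longer assumes a wrapped model, so the same hook configuration works for single-device and distributed runs alike. A hedged usage sketch; `start_iters=1000` is an illustrative value (the updated unit test uses `start_iters=1`):

```python
# Sketch: registering the hook in a training config (illustrative values).
custom_hooks = [
    dict(type='DenseCLHook', start_iters=1000),
]
# Before `start_iters` the hook forces the algorithm's `loss_lambda` to 0.;
# from `start_iters` on it restores the originally configured value,
# regardless of whether the model is wrapped.
```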
@@ -3,11 +3,12 @@ import os.path as osp
 from typing import Dict, List, Optional, Sequence
 
 import torch
-import torch.distributed as dist
+from mmengine.dist import get_rank, get_world_size, is_distributed
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger
 
 from mmselfsup.registry import HOOKS
+from mmselfsup.utils import get_model
 
 
 @HOOKS.register_module()
@@ -45,8 +46,7 @@ class SwAVHook(Hook):
                  interval: Optional[int] = 1,
                  frozen_layers_cfg: Optional[Dict] = dict()
                  ) -> None:
-        self.batch_size = batch_size * dist.get_world_size()\
-            if dist.is_initialized() else batch_size
+        self.batch_size = batch_size * get_world_size()
         self.epoch_queue_starts = epoch_queue_starts
         self.crops_for_assign = crops_for_assign
         self.feat_dim = feat_dim
@@ -58,16 +58,16 @@ class SwAVHook(Hook):
 
     def before_run(self, runner) -> None:
         """Check whether the queues exist locally or not."""
-        if dist.is_initialized():
+        if is_distributed():
             self.queue_path = osp.join(runner.work_dir,
-                                       'queue' + str(dist.get_rank()) + '.pth')
+                                       'queue' + str(get_rank()) + '.pth')
         else:
             self.queue_path = osp.join(runner.work_dir, 'queue.pth')
 
         # load the queues if queues exist locally
         if osp.isfile(self.queue_path):
             self.queue = torch.load(self.queue_path)['queue']
-            runner.model.module.head.loss.queue = self.queue
+            get_model(runner.model).head.loss.queue = self.queue
             MMLogger.get_current_instance().info(
                 f'Load queue from file: {self.queue_path}')
 
@@ -82,12 +82,12 @@ class SwAVHook(Hook):
         for layer, frozen_iters in self.frozen_layers_cfg.items():
             if runner.iter < frozen_iters and self.requires_grad:
                 self.requires_grad = False
-                for name, p in runner.model.module.named_parameters():
+                for name, p in get_model(runner.model).named_parameters():
                     if layer in name:
                         p.requires_grad = False
             elif runner.iter >= frozen_iters and not self.requires_grad:
                 self.requires_grad = True
-                for name, p in runner.model.module.named_parameters():
+                for name, p in get_model(runner.model).named_parameters():
                     if layer in name:
                         p.requires_grad = True
 
@@ -104,12 +104,12 @@ class SwAVHook(Hook):
             ).cuda()
 
         # set the boolean type of use_the_queue
-        runner.model.module.head.loss.queue = self.queue
-        runner.model.module.head.loss.use_queue = False
+        get_model(runner.model).head.loss.queue = self.queue
+        get_model(runner.model).head.loss.use_queue = False
 
     def after_train_epoch(self, runner) -> None:
         """Save the queues locally."""
-        self.queue = runner.model.module.head.loss.queue
+        self.queue = get_model(runner.model).head.loss.queue
 
         if self.queue is not None and self.every_n_epochs(
                 runner, self.interval):
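In a single-process run the helpers above collapse to sensible defaults, which is why the `dist.is_initialized()` guard around the batch size could be dropped. A hedged sketch of what the refactored `SwAVHook` effectively computes without a process group; `batch_size=32` and the work dir are illustrative:

```python
# Illustrative single-process values (no torch.distributed initialization).
import os.path as osp

from mmengine.dist import get_rank, get_world_size, is_distributed

batch_size = 32
effective_batch_size = batch_size * get_world_size()   # 32 * 1 == 32
work_dir = './work_dirs/swav'                           # hypothetical path
queue_path = (osp.join(work_dir, 'queue' + str(get_rank()) + '.pth')
              if is_distributed() else osp.join(work_dir, 'queue.pth'))
print(effective_batch_size, queue_path)   # 32 ./work_dirs/swav/queue.pth
```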
@@ -82,8 +82,9 @@ class SelfSupDataPreprocessor(ImgDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
@@ -125,8 +126,9 @@ class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
@@ -180,8 +182,9 @@ class RotationPredDataPreprocessor(SelfSupDataPreprocessor):
         # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views
         # for an image for some algorithms, e.g. SimCLR, each item in inputs
         # is a list, containing multi-views for an image.
-        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
-                  for _input in inputs]
+        if self._enable_normalize:
+            inputs = [[(img_ - self.mean) / self.std for img_ in _input]
+                      for _input in inputs]
 
         batch_inputs = []
         for i in range(len(inputs[0])):
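The data preprocessor hunks guard normalization with `self._enable_normalize`, which `mmengine`'s `ImgDataPreprocessor` only enables when `mean`/`std` are configured; previously the subtraction ran unconditionally and failed when they were omitted. A hedged config sketch with the usual ImageNet statistics, shown purely as an example:

```python
# Illustrative preprocessor config: with mean/std given, _enable_normalize is
# True and the per-view normalization above runs; omitting them now skips
# normalization instead of erroring out.
data_preprocessor = dict(
    type='SelfSupDataPreprocessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375])
```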
@@ -5,11 +5,13 @@ from .collect import dist_forward_collect, nondist_forward_collect
 from .collect_env import collect_env
 from .distributed_sinkhorn import distributed_sinkhorn
 from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
+from .misc import get_model
 from .setup_env import register_all_modules
 
 __all__ = [
     'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
     'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
     'sync_random_seed', 'distributed_sinkhorn', 'concat_all_gather',
-    'gather_tensors', 'gather_tensors_batch', 'register_all_modules'
+    'gather_tensors', 'gather_tensors_batch', 'register_all_modules',
+    'get_model'
 ]
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmengine.model import BaseModel, is_model_wrapper
+
+
+def get_model(model: nn.Module) -> BaseModel:
+    """Get the model if the input model is a model wrapper.
+
+    Args:
+        model (nn.Module): A model that may be a model wrapper.
+
+    Returns:
+        BaseModel: The model without the model wrapper.
+    """
+    if is_model_wrapper(model):
+        return model.module
+    else:
+        return model
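A hedged usage sketch covering both branches of the new helper; `MMDistributedDataParallel` is mmengine's standard wrapper and appears here only to illustrate the wrapped case (it needs an initialized process group to actually construct):

```python
import torch.nn as nn

from mmselfsup.utils import get_model

plain = nn.Linear(4, 4)
assert get_model(plain) is plain          # non-dist: returned unchanged

# In a distributed run the runner wraps the model, roughly:
#   from mmengine.model import MMDistributedDataParallel
#   wrapped = MMDistributedDataParallel(module=plain.cuda(), device_ids=[0])
#   assert get_model(wrapped) is plain    # dist: `.module` is unwrapped
```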
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from mmengine import Runner
 from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
 from torch.utils.data import Dataset
@@ -15,6 +14,7 @@ from mmselfsup.engine import DenseCLHook
 from mmselfsup.models.algorithms import BaseModel
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
 
 
 class DummyDataset(Dataset):
@@ -33,7 +33,7 @@ class DummyDataset(Dataset):
         data_sample = SelfSupDataSample()
         gt_label = LabelData(value=self.label[index])
         setattr(data_sample, 'gt_label', gt_label)
-        return dict(inputs=self.data[index], data_sample=data_sample)
+        return dict(inputs=[self.data[index]], data_sample=data_sample)
 
 
 @MODELS.register_module()
@@ -58,7 +58,7 @@ class ToyModel(BaseModel):
         for x in data_samples:
             labels.append(x.gt_label.value)
         labels = torch.stack(labels)
-        outputs = self.backbone(batch_inputs)
+        outputs = self.backbone(batch_inputs[0])
         loss = (labels - outputs).sum()
         outputs = dict(loss=loss)
         return outputs
@@ -78,18 +78,9 @@ class TestDenseCLHook(TestCase):
         toy_model = ToyModel().to(device)
         densecl_hook = DenseCLHook(start_iters=1)
 
-        class DummyWrapper(EngineBaseModel):
-
-            def __init__(self, model):
-                super().__init__()
-                self.module = model
-
-            def forward(self, *args, **kwargs):
-                return self.module(*args, **kwargs)
-
         # test DenseCLHook with model wrapper
         runner = Runner(
-            model=DummyWrapper(toy_model),
+            model=toy_model,
             work_dir=self.temp_dir.name,
             train_dataloader=dict(
                 dataset=dummy_dataset,
@@ -108,6 +99,6 @@ class TestDenseCLHook(TestCase):
         runner.train()
 
         if runner.iter >= 1:
-            assert runner.model.module.loss_lambda == 0.5
+            assert get_model(runner.model).loss_lambda == 0.5
         else:
-            assert runner.model.module.loss_lambda == 0.
+            assert get_model(runner.model).loss_lambda == 0.
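Both hook tests drop the local `DummyWrapper`: since the hooks now unwrap via `get_model`, passing an unwrapped `ToyModel` exercises the same code path. They also switch the dummy dataset to the list-of-views input format that `SelfSupDataPreprocessor` expects, which is why `ToyModel` reads `batch_inputs[0]`. A hedged sketch of the item format the updated `DummyDataset.__getitem__` returns; shapes and the label value are illustrative:

```python
# Illustrative dataset item in the updated multi-view format (single view).
import torch
from mmengine.data import LabelData

from mmselfsup.structures import SelfSupDataSample

data_sample = SelfSupDataSample()
data_sample.gt_label = LabelData(value=torch.tensor([1]))
item = dict(
    inputs=[torch.rand(3, 2, 2)],   # a list of views, not a bare tensor
    data_sample=data_sample)
```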
@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 from mmengine import Runner
 from mmengine.data import LabelData
-from mmengine.model import BaseModel as EngineBaseModel
 from mmengine.model import BaseModule
 from mmengine.optim import OptimWrapper
 from torch.utils.data import Dataset
@@ -16,6 +15,7 @@ from mmselfsup.models.algorithms import BaseModel
 from mmselfsup.models.heads import SwAVHead
 from mmselfsup.registry import MODELS
 from mmselfsup.structures import SelfSupDataSample
+from mmselfsup.utils import get_model
 
 
 class DummyDataset(Dataset):
@@ -34,7 +34,7 @@ class DummyDataset(Dataset):
         data_sample = SelfSupDataSample()
         gt_label = LabelData(value=self.label[index])
         setattr(data_sample, 'gt_label', gt_label)
-        return dict(inputs=self.data[index], data_sample=data_sample)
+        return dict(inputs=[self.data[index]], data_sample=data_sample)
 
 
 @MODELS.register_module()
@@ -65,7 +65,7 @@ class ToyModel(BaseModel):
         for x in data_samples:
             labels.append(x.gt_label.value)
         labels = torch.stack(labels)
-        outputs = self.backbone(batch_inputs)
+        outputs = self.backbone(batch_inputs[0])
         loss = (labels - outputs).sum()
         outputs = dict(loss=loss)
         return outputs
@@ -91,18 +91,9 @@ class TestSwAVHook(TestCase):
             queue_length=300,
             frozen_layers_cfg=dict(prototypes=2))
 
-        class DummyWrapper(EngineBaseModel):
-
-            def __init__(self, model):
-                super().__init__()
-                self.module = model
-
-            def forward(self, *args, **kwargs):
-                return self.module(*args, **kwargs)
-
         # test SwAVHook
         runner = Runner(
-            model=DummyWrapper(toy_model),
+            model=toy_model,
             work_dir=self.temp_dir.name,
             train_dataloader=dict(
                 dataset=dummy_dataset,
@@ -124,4 +115,4 @@ class TestSwAVHook(TestCase):
         if isinstance(hook, SwAVHook):
             assert hook.queue_length == 300
 
-        assert runner.model.module.head.loss.use_queue is False
+        assert get_model(runner.model).head.loss.use_queue is False