Jiahao Xie fc69e38009
Bump version to v0.6.0 (#199)
* [Feature] Add MoCo v3 (#194)

* [Feature] add position embedding function

* [Feature] modify nonlinear neck for vit backbone

* [Feature] add mocov3 head

* [Feature] modify cls_head for vit backbone

* [Feature] add ViT backbone

* [Feature] add mocov3 algorithm

* [Docs] revise BYOL hook docstring

* [Feature] add mocov3 vit small config files

* [Feature] add mocov3 vit small linear eval config files

* [Fix] solve conflict

* [Fix] add mmcls

* [Fix] fix docstring format

* [Fix] fix isort

* [Fix] add mmcls to runtime requirements

* [Feature] remove duplicated codes

* [Feature] add mocov3 related unit test

* [Feature] revise position embedding function

* [Feature] add UT codes

* [Docs] add README.md

* [Docs] add model links and results to model zoo

* [Docs] fix model links

* [Docs] add metafile

* [Docs] modify install.md and add mmcls requirements

* [Docs] modify description

* [Fix] use specific arch name `mocov3-small` rather than general arch name `small`

* [Fix] add mmcls

* [Fix] fix arch name

* [Feature] change name to `MoCoV3`

* [Fix] fix unit test bug

* [Feature] change `BYOLHook` name to `MomentumUpdateHook`

* [Feature] change name to MoCoV3

* [Docs] modify description

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>

* [Docs] update model zoo results (#195)

* Bump version to v0.6.0 (#198)

* [Docs] update model zoo results

* Bump version to v0.6.0

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
2022-02-02 11:16:06 +08:00

80 lines
2.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel
from mmcv.runner import build_runner, obj_from_dict
from torch.utils.data import DataLoader, Dataset
from mmselfsup.core.hooks import MomentumUpdateHook
class ExampleDataset(Dataset):
    """Minimal constant dataset used to feed the runner in tests.

    It reports a length of one, and every index yields the same sample: a
    one-element integer tensor under ``img`` and an empty metadata dict
    under ``img_metas``.
    """

    def __len__(self):
        # A single sample in total.
        return 1

    def __getitem__(self, idx):
        # ``idx`` is ignored; the sample content never varies.
        sample = {'img': torch.tensor([1]), 'img_metas': {}}
        return sample
class ExampleModel(nn.Module):
    """Toy model exposing the attributes a momentum-update hook works with.

    It owns an ``online_net`` and a ``target_net`` (two identically shaped
    ``Conv2d(3, 3, 3)`` layers), a ``base_momentum`` of 0.96 and a mutable
    ``momentum`` that starts out equal to ``base_momentum``.
    """

    def __init__(self):
        super().__init__()
        self.test_cfg = None
        self.online_net = nn.Conv2d(3, 3, 3)
        self.target_net = nn.Conv2d(3, 3, 3)
        self.base_momentum = 0.96
        self.momentum = self.base_momentum

    def forward(self, img, img_metas, test_mode=False, **kwargs):
        # Identity forward: hand the input image straight back.
        return img

    def train_step(self, data_batch, optimizer):
        # Splat the batch dict into ``forward`` and report its output as the
        # loss; ``optimizer`` is accepted for the runner API but unused here.
        return dict(loss=self.forward(**data_batch))

    @torch.no_grad()
    def _momentum_update(self):
        """Blend the online network's weights into the target network.

        Every target parameter is replaced by
        ``momentum * target + (1 - momentum) * online``.
        """
        keep = self.momentum
        blend = 1. - keep
        pairs = zip(self.online_net.parameters(),
                    self.target_net.parameters())
        for online_p, target_p in pairs:
            target_p.data = target_p.data * keep + online_p.data * blend

    @torch.no_grad()
    def momentum_update(self):
        """Public entry point; delegates to ``_momentum_update``."""
        self._momentum_update()
def test_byol_hook():
    """Train two one-iteration epochs with ``MomentumUpdateHook`` registered.

    The function name is historical: the hook was formerly called
    ``BYOLHook``. The dataset holds one sample and the loader uses
    ``batch_size=1``, so each of the two epochs runs a single iteration.
    Afterwards the test checks the ``momentum`` value the hook left on the
    wrapped model.
    """
    test_dataset = ExampleDataset()
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    runner_cfg = dict(type='EpochBasedRunner', max_epochs=2)
    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)

    # test MomentumUpdateHook
    with tempfile.TemporaryDirectory() as tmpdir:
        model = MMDataParallel(ExampleModel())
        optimizer = obj_from_dict(optim_cfg, torch.optim,
                                  dict(params=model.parameters()))
        momentum_hook = MomentumUpdateHook()
        runner = build_runner(
            runner_cfg,
            default_args=dict(
                model=model,
                optimizer=optimizer,
                work_dir=tmpdir,
                logger=logging.getLogger()))
        runner.register_hook(momentum_hook)
        runner.run([data_loader], [('train', 1)])
        # The momentum is a float computed by the hook (outside this file);
        # compare with a tolerance instead of exact equality so rounding in
        # that computation cannot make the test flaky.
        assert abs(runner.model.module.momentum - 0.98) < 1e-6