mmselfsup/tests/test_models/test_heads.py
Yuan Liu 399b5a0d6e
Bump version to v0.9.0 (#299)
* [Feature]: MAE pre-training with fp16 (#271)

* [Feature]: MAE pre-training with fp16

* [Fix]: Fix lint

* [Fix]: Fix SimMIM config link, and add SimMIM to model_zoo (#272)

* [Fix]: Fix link error

* [Fix]: Add SimMIM to model zoo

* [Fix]: Fix lint

* [Fix] fix 'no init_cfg' error for pre-trained model backbones (#256)

* [UT] add unit test for apis (#276)

* [UT] add unit test for apis

* ignore pytest log

* [Feature] Add extra dataloader settings in configs. (#264)

* [Feature] support setting validation samples per gpu independently

* set default to be cfg.data.samples_per_gpu

* modify the tools/test.py

* using 'train_dataloader', 'val_dataloader', 'test_dataloader' for specific settings

* test 'evaluation' branch

* [Fix]: Change imgs_per_gpu to samples_per_gpu MAE (#278)

* [Feature]: Add SimMIM 192 pt 224 ft (#280)

* [Feature]: Add SimMIM 192 pt 224 ft

* [Feature]: Add simmim 192 pt 224 ft to readme

* [Fix] fix key error bug when registering custom hooks (#273)

* [UT] remove pytorch1.5 test (#288)

* [Benchmark] rename linear probing config file names (#281)

* [Benchmark] rename linear probing config file names

* update config links

* Avoid GPU memory leak with prefetch dataloader (#277)

* [Feature] barlowtwins (#207)

* [Fix]: Fix mmcls upgrade bug (#235)

* [Feature]: Add multi machine dist_train (#232)

* [Feature]: Add multi machine dist_train

* [Fix]: Change bash to sh

* [Fix]: Fix missing sh suffix

* [Refactor]: Change bash to sh

* [Refactor] Add unit test (#234)

* [Refactor] add unit test

* update workflow

* update

* [Fix] fix lint

* update test

* refactor moco and densecl unit test

* fix lint

* add unit test

* update unit test

* remove modification

* [Feature]: Add MAE metafile (#238)

* [Feature]: Add MAE metafile

* [Fix]: Fix lint

* [Fix]: Change LARS to AdamW in the metafile of MAE

* Add barlowtwins

* Add unit test for barlowtwins

* Adjust training params

* add decorator to pass CI

* adjust params

* add barlowtwins configs

* revise LatentCrossCorrelationHead

* modify ut to save memory

* add metafile

* add barlowtwins results to model zoo

* add barlow twins to homepage

* fix batch size bug

* add algorithm readme

* add type hints

* reorganize the model zoo

* remove one config

* recover the config

* add missing docstring

* revise barlowtwins

* reorganize coco and voc benchmark

* add barlowtwins to index.rst

* revise docstring

Co-authored-by: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>

* [Fix] fix --local-rank (#290)

* [UT] reduce memory usage while running unit test (#291)

* [Feature]: CAE Supported (#284)

* [Feature]: Add mc

* [Feature]: Add dataset of CAE

* [Feature]: Init version of CAE

* [Feature]: Add mc

* [Fix]: Change beta to (0.9, 0.999)

* [Fix]: New feature

* [Fix]: Decouple the qkv bias

* [Feature]: Decouple qkv bias in MultiheadAttention

* [Feature]: New mask generator

* [Fix]: Fix TransformEncoderLayer bug

* [Feature]: Add MAE CAE linear prob

* [Fix]: Fix config

* [Fix]: Delete redundant mc

* [Fix]: Add init value in mim cls vit

* [Fix]: Fix cae ft config

* [Fix]: Delete repeated init_values

* [Fix]: Change bs from 64 to 128 in CAE ft

* [Fix]: Add mc in cae pt

* [Fix]: Fix momentum update bug

* [Fix]: Add no weight_decay for gamma

* [Feature]: Add mc for cae pt

* [Fix]: Delete mc

* [Fix]: Delete redundant files

* [Fix]: Fix lint

* [Feature]: Add docstring to algo, backbone, neck and head

* [Fix]: Fix lint

* [Fix]: network

* [Feature]: Add docstrings for network blocks

* [Feature]: Add docstring to ToTensor

* [Feature]: Add docstring to transform

* [Fix]: Add type hint to BEiTMaskGenerator

* [Fix]: Fix lint

* [Fix]: Add copyright to dalle_e

* [Fix]: Fix BlockwiseMaskGenerator

* [Feature]: Add UT for CAE

* [Fix]: Fix bug when dalle state_dict path does not exist

* [Fix]: Delete file_client_args related code

* [Fix]: Remove redundant code

* [Refactor]: Add fp16 to the name of cae pre-train config

* [Refactor]: Use FFN from mmcv

* [Refactor]: Change network_blocks to transformer_blocks

* [Fix]: Fix mask generator name bug

* [Fix]: Fix cae pre-train config bug

* [Fix]: Fix docstring grammar

* [Fix]: Fix mc related code

* [Fix]: Add object parent to transform

* [Fix]: Delete unnecessary modification

* [Fix]: Change blockwisemask generator to simmim mask generator

* [Refactor]: Change cae mae pretrain vit to cae mae vit

* [Refactor]: Change lamb to lambd

* [Fix]: Remove blank line

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Delete modification to swin

* [Fix]: Fix lint

* [Feature]: Add README and metafile

* [Feature]: Update index.rst

* [Fix]: Update model_zoo

* [Fix]: Change MAE to CAE in algorithm

* [Fix]: Change SimMIMMaskGenerator to CAEMaskGenerator

* [Fix]: Fix model zoo

* [Fix]: Change to dalle_encoder

* [Feature]: Add download link for dalle

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Update metafile

* [Fix]: Change b to base

* [Feature]: Add dalle download link in warning

* [Fix] add arxiv link in readme

Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>

* [Enhance] update SimCLR models and results (#295)

* [Enhance] update simclr models and results

* [Fix] revise comments to indicate settings

* Update version (#296)

* [Feature]: Update to 0.9.0

* [Feature]: Add version constrain for mmcls

* [Fix]: Fix bug

* [Fix]: Fix version bug

* [Feature]: Update version in install.md

* update changelog

* update readme

* [Fix] fix uppercase

* [Fix] fix uppercase

* [Fix] fix uppercase

* update version dependency

* add cae to readme

Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Co-authored-by: Ming Li <73068772+mitming@users.noreply.github.com>
Co-authored-by: xcnick <xcnick0412@gmail.com>
Co-authored-by: fangyixiao18 <fangyx18@hotmail.com>
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
2022-04-29 20:01:30 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F

from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead,
                                    LatentCrossCorrelationHead,
                                    LatentPredictHead, MAEFinetuneHead,
                                    MAEPretrainHead, MultiClsHead, SwAVHead)


def test_cls_head():
    # test ClsHead
    head = ClsHead()
    fake_cls_score = [torch.rand(4, 3)]
    fake_gt_label = torch.randint(0, 2, (4, ))
    loss = head.loss(fake_cls_score, fake_gt_label)
    assert loss['loss'].item() > 0


def test_contrastive_head():
    head = ContrastiveHead()
    fake_pos = torch.rand(32, 1)  # N, 1
    fake_neg = torch.rand(32, 100)  # N, k
    loss = head.forward(fake_pos, fake_neg)
    assert loss['loss'].item() > 0


def test_latent_predict_head():
    predictor = dict(
        type='NonLinearNeck',
        in_channels=64,
        hid_channels=128,
        out_channels=64,
        with_bias=True,
        with_last_bn=True,
        with_avg_pool=False,
        norm_cfg=dict(type='BN1d'))
    head = LatentPredictHead(predictor=predictor)
    fake_input = torch.rand(32, 64)  # N, C
    fake_target = torch.rand(32, 64)  # N, C
    loss = head.forward(fake_input, fake_target)
    # the loss is based on cosine similarity and can be negative,
    # so only its lower bound is checked
    assert loss['loss'].item() > -1


def test_latent_cls_head():
    head = LatentClsHead(64, 10)
    fake_input = torch.rand(32, 64)  # N, C
    fake_target = torch.rand(32, 64)  # N, C
    loss = head.forward(fake_input, fake_target)
    assert loss['loss'].item() > 0


def test_latent_cross_correlation_head():
    head = LatentCrossCorrelationHead(2, 0.0051)
    fake_input = torch.rand(32, 2)  # N, C
    fake_target = torch.rand(32, 2)  # N, C
    loss = head.forward(fake_input, fake_target)
    assert loss['loss'].item() > 0


def test_multi_cls_head():
    head = MultiClsHead(in_indices=(0, 1))
    fake_input = [torch.rand(8, 64, 5, 5), torch.rand(8, 256, 14, 14)]
    out = head.forward(fake_input)
    assert isinstance(out, list)

    fake_cls_score = [torch.rand(4, 3)]
    fake_gt_label = torch.randint(0, 2, (4, ))
    loss = head.loss(fake_cls_score, fake_gt_label)
    for k in loss.keys():
        if 'loss' in k:
            assert loss[k].item() > 0


def test_swav_head():
    head = SwAVHead(feat_dim=128, num_crops=[2, 6])
    fake_input = torch.rand(32, 128)  # N, C
    loss = head.forward(fake_input)
    assert loss['loss'].item() > 0


def test_mae_pretrain_head():
    head = MAEPretrainHead(norm_pix=False, patch_size=16)
    fake_input = torch.rand((2, 3, 224, 224))
    fake_mask = torch.ones((2, 196))  # mask out all 14x14 patches
    fake_pred = torch.rand((2, 196, 768))
    loss = head.forward(fake_input, fake_pred, fake_mask)
    assert loss['loss'].item() > 0

    head_norm_pixel = MAEPretrainHead(norm_pix=True, patch_size=16)
    loss_norm_pixel = head_norm_pixel.forward(fake_input, fake_pred, fake_mask)
    assert loss_norm_pixel['loss'].item() > 0


def test_mae_finetune_head():
    head = MAEFinetuneHead(num_classes=1000, embed_dim=768)
    fake_input = torch.rand((2, 768))
    fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1)
    fake_features = head.forward(fake_input)
    assert list(fake_features[0].shape) == [2, 1000]

    loss = head.loss(fake_features, fake_labels)
    assert loss['loss'].item() > 0
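
This is a standard pytest module: running `pytest tests/test_models/test_heads.py` from the repository root collects and executes all of the `test_*` functions above. Below is a minimal sketch of an equivalent direct entry point, assuming only that pytest is installed; this guard is an illustration, not part of the upstream file:

if __name__ == '__main__':
    # Hypothetical convenience entry point: let pytest collect and run
    # the test_* functions in this file and propagate its exit code.
    import pytest
    raise SystemExit(pytest.main([__file__, '-v']))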