2021-12-15 19:07:01 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import torch
|
2022-03-04 13:43:49 +08:00
|
|
|
import torch.nn.functional as F
|
2021-12-15 19:07:01 +08:00
|
|
|
|
|
|
|
from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead,
|
2022-04-29 20:01:30 +08:00
|
|
|
LatentCrossCorrelationHead,
|
2022-03-04 13:43:49 +08:00
|
|
|
LatentPredictHead, MAEFinetuneHead,
|
2022-09-24 17:48:36 +08:00
|
|
|
MAEPretrainHead, MaskFeatFinetuneHead,
|
|
|
|
MaskFeatPretrainHead, MultiClsHead,
|
|
|
|
SwAVHead)
|
2021-12-15 19:07:01 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_cls_head():
|
|
|
|
# test ClsHead
|
|
|
|
head = ClsHead()
|
|
|
|
fake_cls_score = [torch.rand(4, 3)]
|
|
|
|
fake_gt_label = torch.randint(0, 2, (4, ))
|
|
|
|
|
|
|
|
loss = head.loss(fake_cls_score, fake_gt_label)
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_contrastive_head():
|
|
|
|
head = ContrastiveHead()
|
|
|
|
fake_pos = torch.rand(32, 1) # N, 1
|
|
|
|
fake_neg = torch.rand(32, 100) # N, k
|
|
|
|
|
|
|
|
loss = head.forward(fake_pos, fake_neg)
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_latent_predict_head():
|
|
|
|
predictor = dict(
|
|
|
|
type='NonLinearNeck',
|
|
|
|
in_channels=64,
|
|
|
|
hid_channels=128,
|
|
|
|
out_channels=64,
|
|
|
|
with_bias=True,
|
|
|
|
with_last_bn=True,
|
|
|
|
with_avg_pool=False,
|
|
|
|
norm_cfg=dict(type='BN1d'))
|
|
|
|
head = LatentPredictHead(predictor=predictor)
|
|
|
|
fake_input = torch.rand(32, 64) # N, C
|
|
|
|
fake_traget = torch.rand(32, 64) # N, C
|
|
|
|
|
|
|
|
loss = head.forward(fake_input, fake_traget)
|
|
|
|
assert loss['loss'].item() > -1
|
|
|
|
|
|
|
|
|
|
|
|
def test_latent_cls_head():
|
|
|
|
head = LatentClsHead(64, 10)
|
|
|
|
fake_input = torch.rand(32, 64) # N, C
|
|
|
|
fake_traget = torch.rand(32, 64) # N, C
|
|
|
|
|
|
|
|
loss = head.forward(fake_input, fake_traget)
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
|
2022-04-29 20:01:30 +08:00
|
|
|
def test_latent_cross_correlation_head():
|
|
|
|
head = LatentCrossCorrelationHead(2, 0.0051)
|
|
|
|
fake_input = torch.rand(32, 2) # N, C
|
|
|
|
fake_traget = torch.rand(32, 2) # N, C
|
|
|
|
|
|
|
|
loss = head.forward(fake_input, fake_traget)
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
|
2021-12-15 19:07:01 +08:00
|
|
|
def test_multi_cls_head():
|
|
|
|
head = MultiClsHead(in_indices=(0, 1))
|
|
|
|
fake_input = [torch.rand(8, 64, 5, 5), torch.rand(8, 256, 14, 14)]
|
|
|
|
out = head.forward(fake_input)
|
|
|
|
assert isinstance(out, list)
|
|
|
|
|
|
|
|
fake_cls_score = [torch.rand(4, 3)]
|
|
|
|
fake_gt_label = torch.randint(0, 2, (4, ))
|
|
|
|
|
|
|
|
loss = head.loss(fake_cls_score, fake_gt_label)
|
|
|
|
print(loss.keys())
|
|
|
|
for k in loss.keys():
|
|
|
|
if 'loss' in k:
|
|
|
|
assert loss[k].item() > 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_swav_head():
|
|
|
|
head = SwAVHead(feat_dim=128, num_crops=[2, 6])
|
|
|
|
fake_input = torch.rand(32, 128) # N, C
|
|
|
|
|
|
|
|
loss = head.forward(fake_input)
|
|
|
|
assert loss['loss'].item() > 0
|
2022-03-04 13:43:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_mae_pretrain_head():
|
|
|
|
head = MAEPretrainHead(norm_pix=False, patch_size=16)
|
|
|
|
fake_input = torch.rand((2, 3, 224, 224))
|
|
|
|
fake_mask = torch.ones((2, 196))
|
|
|
|
fake_pred = torch.rand((2, 196, 768))
|
|
|
|
|
|
|
|
loss = head.forward(fake_input, fake_pred, fake_mask)
|
|
|
|
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
head_norm_pixel = MAEPretrainHead(norm_pix=True, patch_size=16)
|
|
|
|
|
|
|
|
loss_norm_pixel = head_norm_pixel.forward(fake_input, fake_pred, fake_mask)
|
|
|
|
|
|
|
|
assert loss_norm_pixel['loss'].item() > 0
|
|
|
|
|
2022-07-27 16:03:57 +08:00
|
|
|
x = torch.rand((1, 4, 16**2 * 3))
|
|
|
|
imgs = head_norm_pixel.unpatchify(x)
|
|
|
|
assert imgs.size() == torch.Size((1, 3, 32, 32))
|
|
|
|
|
2022-03-04 13:43:49 +08:00
|
|
|
|
|
|
|
def test_mae_finetune_head():
|
|
|
|
|
|
|
|
head = MAEFinetuneHead(num_classes=1000, embed_dim=768)
|
|
|
|
fake_input = torch.rand((2, 768))
|
|
|
|
fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1)
|
|
|
|
fake_features = head.forward(fake_input)
|
|
|
|
|
|
|
|
assert list(fake_features[0].shape) == [2, 1000]
|
|
|
|
|
|
|
|
loss = head.loss(fake_features, fake_labels)
|
|
|
|
|
|
|
|
assert loss['loss'].item() > 0
|
2022-09-24 17:48:36 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_maskfeat_pretrain_head():
|
|
|
|
head = MaskFeatPretrainHead(hog_dim=108)
|
|
|
|
fake_mask = torch.ones((2, 14, 14)).bool()
|
|
|
|
fake_pred = torch.rand((2, 197, 768))
|
|
|
|
fake_hog = torch.rand((2, 196, 108))
|
|
|
|
|
|
|
|
loss = head.forward(fake_pred, fake_hog, fake_mask)
|
|
|
|
|
|
|
|
assert loss['loss'].item() > 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_maskfeat_finetune_head():
|
|
|
|
|
|
|
|
head = MaskFeatFinetuneHead(num_classes=1000, embed_dim=768)
|
|
|
|
fake_input = torch.rand((2, 768))
|
|
|
|
fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1)
|
|
|
|
fake_features = head.forward(fake_input)
|
|
|
|
|
|
|
|
assert list(fake_features[0].shape) == [2, 1000]
|
|
|
|
|
|
|
|
loss = head.loss(fake_features, fake_labels)
|
|
|
|
|
|
|
|
assert loss['loss'].item() > 0
|