# Source path: mmselfsup/tests/test_runtime/test_extract_process.py
# (Repository viewer metadata for the original file: 76 lines, 2.2 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel
from torch.utils.data import DataLoader, Dataset
from mmselfsup.models.utils import ExtractProcess, MultiExtractProcess
class ExampleDataset(Dataset):
    """Minimal one-sample dataset used to drive the extraction tests."""

    def __getitem__(self, idx):
        """Return the same sample regardless of ``idx``.

        Args:
            idx: Requested index; it is not used.

        Returns:
            dict: ``img`` is a tensor holding the single value ``1`` and
            ``img_metas`` is an empty dict.
        """
        return dict(img=torch.tensor([1]), img_metas=dict())

    def __len__(self):
        """Return the dataset length, which is fixed at one sample."""
        return 1
class ExampleModel(nn.Module):
    """Toy model whose forward pass emits three random feature maps."""

    def __init__(self):
        super().__init__()
        # Not used by forward(); presumably present so the wrapped module
        # owns at least one parameter -- TODO confirm.
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, img, test_mode=False, **kwargs):
        """Return three random tensors; every argument is ignored.

        Args:
            img: Input batch; it is not read.
            test_mode: Accepted for interface compatibility; it is not read.
            **kwargs: Extra batch fields; they are not read.

        Returns:
            list[torch.Tensor]: Tensors of shape ``(1, 32, 112, 112)``,
            ``(1, 64, 56, 56)`` and ``(1, 128, 28, 28)``, in that order.
        """
        return [
            torch.rand((1, 32, 112, 112)),
            torch.rand((1, 64, 56, 56)),
            torch.rand((1, 128, 28, 28)),
        ]

    def train_step(self, data_batch, optimizer):
        """Run ``forward`` on the unpacked batch and wrap it as the loss.

        Args:
            data_batch (dict): Keyword arguments forwarded to ``forward``.
            optimizer: It is not used.

        Returns:
            dict: ``loss`` maps to whatever ``forward`` returned.
        """
        loss = self.forward(**data_batch)
        return dict(loss=loss)
def test_extract_process():
    """Check that ``ExtractProcess.extract`` returns a ``feat`` entry.

    A single-sample loader and the toy model wrapped in ``MMDataParallel``
    are passed to ``extract``; the result must contain ``feat`` with shape
    ``(1, 128)``.
    """
    test_dataset = ExampleDataset()
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    model = MMDataParallel(ExampleModel())

    process = ExtractProcess()
    results = process.extract(model, data_loader)

    assert 'feat' in results
    # 128 matches the channel count of the model's last output; presumably
    # it is pooled to 1x1 and flattened by ExtractProcess -- TODO confirm.
    assert results['feat'].shape == (1, 128 * 1 * 1)
def test_multi_extract_process():
    """Check ``MultiExtractProcess`` argument validation and its outputs.

    The first part expects construction with ``layer_indices=(-1, )`` to
    raise ``AssertionError``. The second part extracts from layers
    ``(0, 1, 2)`` and checks the keys and shapes of the returned features.
    """
    # Only the constructor call belongs inside the context manager: code
    # placed after the raising statement in the same block would never run.
    with pytest.raises(AssertionError):
        MultiExtractProcess(
            pool_type='specified', backbone='resnet50', layer_indices=(-1, ))

    test_dataset = ExampleDataset()
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    model = MMDataParallel(ExampleModel())

    process = MultiExtractProcess(
        pool_type='specified', backbone='resnet50', layer_indices=(0, 1, 2))
    results = process.extract(model, data_loader)

    assert 'feat1' in results
    assert 'feat2' in results
    assert 'feat3' in results
    # Channel counts (32, 64, 128) match the toy model's outputs; the spatial
    # sizes presumably come from the 'specified' resnet50 pooling settings in
    # MultiExtractProcess -- TODO confirm.
    assert results['feat1'].shape == (1, 32 * 12 * 12)
    assert results['feat2'].shape == (1, 64 * 6 * 6)
    assert results['feat3'].shape == (1, 128 * 4 * 4)