diff --git a/tests/test_models/test_utils/test_extractor.py b/tests/test_models/test_utils/test_extractor.py
index 4a3c606d..54876960 100644
--- a/tests/test_models/test_utils/test_extractor.py
+++ b/tests/test_models/test_utils/test_extractor.py
@@ -1,6 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from unittest.mock import MagicMock
-
 import torch
 import torch.nn as nn
 from torch.utils.data import Dataset
@@ -8,41 +6,52 @@ from torch.utils.data import Dataset
 from mmselfsup.models.utils import Extractor
 
 
-class ExampleDataset(Dataset):
-
-    def __getitem__(self, idx):
-        results = dict(img=torch.tensor([1]), img_metas=dict())
-        return results
+class DummyDataset(Dataset):
+    METAINFO = dict()  # type: ignore
+    data = torch.randn(12, 2)
+    label = torch.ones(12)
 
     def __len__(self):
-        return 1
+        return self.data.size(0)
+
+    def __getitem__(self, index):
+        return dict(inputs=self.data[index], data_sample=self.label[index])
 
 
-class ExampleModel(nn.Module):
+class ToyModel(nn.Module):
 
     def __init__(self):
-        super(ExampleModel, self).__init__()
-        self.test_cfg = None
-        self.conv = nn.Conv2d(3, 3, 3)
-        self.neck = nn.Identity()
+        super().__init__()
+        self.loss_lambda = 0.5
+        self.linear = nn.Linear(2, 1)
 
-    def forward(self, img, test_mode=False, **kwargs):
-        return img
+    def forward(self, data_batch, return_loss=False):
+        inputs, labels = [], []
+        for x in data_batch:
+            inputs.append(x['inputs'])
+            labels.append(x['data_sample'])
 
-    def train_step(self, data_batch, optimizer):
-        loss = self.forward(**data_batch)
-        return dict(loss=loss)
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        inputs = torch.stack(inputs).to(device)
+        labels = torch.stack(labels).to(device)
+        outputs = self.linear(inputs)
+        if return_loss:
+            loss = (labels - outputs).sum()
+            outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
+            return outputs
+        else:
+            outputs = dict(log_vars=dict(a=1, b=0.5))
+            return outputs
 
 
 def test_extractor():
-    test_dataset = ExampleDataset()
-    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
+    dummy_dataset = DummyDataset()
     extract_dataloader = dict(
         batch_size=1,
         num_workers=1,
         sampler=dict(type='DefaultSampler', shuffle=False),
-        dataset=test_dataset)
+        dataset=dummy_dataset)
 
     # test init
     extractor = Extractor(
@@ -51,7 +60,9 @@ def test_extractor():
 
     # test init
     extractor = Extractor(
-        extract_dataloader=extract_dataloader, dist_mode=False)
+        extract_dataloader=extract_dataloader,
+        dist_mode=False,
+        pool_cfg=dict(type='AvgPool2d', output_size=1))
 
     # TODO: test runtime
     # As the BaseModel is not defined finally, I will add it later.
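A minimal, self-contained sketch (not part of the patch) of how the new fixtures are meant to interact: `ToyModel.forward` consumes a list of per-sample dicts produced by `DummyDataset` rather than a default-collated tensor batch, so the batch is assembled by indexing the dataset directly. The runtime path through `Extractor` is still pending (see the TODO above), so this only smoke-tests the fixtures; it assumes `DummyDataset` and `ToyModel` as defined in the diff are in scope.

```python
import torch

# Hypothetical smoke check for the fixtures added above; assumes DummyDataset
# and ToyModel (as defined in the patch) are importable or pasted in scope.
model = ToyModel()
if torch.cuda.is_available():
    # forward() moves the batch to 'cuda:0', so the weights must live there too.
    model = model.cuda()

dataset = DummyDataset()
# forward() expects a list of per-sample dicts, not a default-collated batch,
# so build the batch by indexing the dataset directly.
data_batch = [dataset[i] for i in range(4)]

outputs = model(data_batch, return_loss=True)
assert 'loss' in outputs and 'log_vars' in outputs
```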