mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
fix ut
This commit is contained in:
parent
ad382c2115
commit
5ea54a48c8
@ -1,6 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
@ -8,41 +6,52 @@ from torch.utils.data import Dataset
|
||||
from mmselfsup.models.utils import Extractor
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
results = dict(img=torch.tensor([1]), img_metas=dict())
|
||||
return results
|
||||
class DummyDataset(Dataset):
|
||||
METAINFO = dict() # type: ignore
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
class ExampleModel(nn.Module):
|
||||
class ToyModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(ExampleModel, self).__init__()
|
||||
self.test_cfg = None
|
||||
self.conv = nn.Conv2d(3, 3, 3)
|
||||
self.neck = nn.Identity()
|
||||
super().__init__()
|
||||
self.loss_lambda = 0.5
|
||||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, img, test_mode=False, **kwargs):
|
||||
return img
|
||||
def forward(self, data_batch, return_loss=False):
|
||||
inputs, labels = [], []
|
||||
for x in data_batch:
|
||||
inputs.append(x['inputs'])
|
||||
labels.append(x['data_sample'])
|
||||
|
||||
def train_step(self, data_batch, optimizer):
|
||||
loss = self.forward(**data_batch)
|
||||
return dict(loss=loss)
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
inputs = torch.stack(inputs).to(device)
|
||||
labels = torch.stack(labels).to(device)
|
||||
outputs = self.linear(inputs)
|
||||
if return_loss:
|
||||
loss = (labels - outputs).sum()
|
||||
outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
|
||||
return outputs
|
||||
else:
|
||||
outputs = dict(log_vars=dict(a=1, b=0.5))
|
||||
return outputs
|
||||
|
||||
|
||||
def test_extractor():
|
||||
test_dataset = ExampleDataset()
|
||||
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||
dummy_dataset = DummyDataset()
|
||||
|
||||
extract_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=test_dataset)
|
||||
dataset=dummy_dataset)
|
||||
|
||||
# test init
|
||||
extractor = Extractor(
|
||||
@ -51,7 +60,9 @@ def test_extractor():
|
||||
|
||||
# test init
|
||||
extractor = Extractor(
|
||||
extract_dataloader=extract_dataloader, dist_mode=False)
|
||||
extract_dataloader=extract_dataloader,
|
||||
dist_mode=False,
|
||||
pool_cfg=dict(type='AvgPool2d', output_size=1))
|
||||
|
||||
# TODO: test runtime
|
||||
# As the BaseModel is not defined finally, I will add it later.
|
||||
|
Loading…
x
Reference in New Issue
Block a user