This commit is contained in:
fangyixiao18 2022-06-06 18:16:35 +08:00
parent ad382c2115
commit 5ea54a48c8

View File

@ -1,6 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -8,41 +6,52 @@ from torch.utils.data import Dataset
from mmselfsup.models.utils import Extractor from mmselfsup.models.utils import Extractor
class ExampleDataset(Dataset): class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
def __getitem__(self, idx): data = torch.randn(12, 2)
results = dict(img=torch.tensor([1]), img_metas=dict()) label = torch.ones(12)
return results
def __len__(self): def __len__(self):
return 1 return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
class ExampleModel(nn.Module): class ToyModel(nn.Module):
def __init__(self): def __init__(self):
super(ExampleModel, self).__init__() super().__init__()
self.test_cfg = None self.loss_lambda = 0.5
self.conv = nn.Conv2d(3, 3, 3) self.linear = nn.Linear(2, 1)
self.neck = nn.Identity()
def forward(self, img, test_mode=False, **kwargs): def forward(self, data_batch, return_loss=False):
return img inputs, labels = [], []
for x in data_batch:
inputs.append(x['inputs'])
labels.append(x['data_sample'])
def train_step(self, data_batch, optimizer): device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
loss = self.forward(**data_batch) inputs = torch.stack(inputs).to(device)
return dict(loss=loss) labels = torch.stack(labels).to(device)
outputs = self.linear(inputs)
if return_loss:
loss = (labels - outputs).sum()
outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
return outputs
else:
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
def test_extractor(): def test_extractor():
test_dataset = ExampleDataset() dummy_dataset = DummyDataset()
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
extract_dataloader = dict( extract_dataloader = dict(
batch_size=1, batch_size=1,
num_workers=1, num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False), sampler=dict(type='DefaultSampler', shuffle=False),
dataset=test_dataset) dataset=dummy_dataset)
# test init # test init
extractor = Extractor( extractor = Extractor(
@ -51,7 +60,9 @@ def test_extractor():
# test init # test init
extractor = Extractor( extractor = Extractor(
extract_dataloader=extract_dataloader, dist_mode=False) extract_dataloader=extract_dataloader,
dist_mode=False,
pool_cfg=dict(type='AvgPool2d', output_size=1))
# TODO: test runtime # TODO: test runtime
# As the BaseModel is not defined finally, I will add it later. # As the BaseModel is not defined finally, I will add it later.