mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import Dataset
|
|
|
|
from mmselfsup.models.utils import Extractor
|
|
|
|
|
|
class DummyDataset(Dataset):
    """In-memory dataset of 12 random two-feature samples labelled with ones.

    The tensors are class attributes, so they are created once when the class
    is defined and shared by every instance.
    """

    METAINFO = {}  # type: ignore
    # 12 samples with 2 features each, drawn from a standard normal.
    data = torch.randn(12, 2)
    # One label per sample, all equal to 1.
    label = torch.ones(12)

    def __len__(self):
        """Return the number of samples (rows of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as an ``inputs``/``data_sample`` dict."""
        sample = self.data[index]
        target = self.label[index]
        return {'inputs': sample, 'data_sample': target}
|
|
|
|
|
|
class ToyModel(nn.Module):
    """Minimal model with a single ``nn.Linear(2, 1)`` layer, used as a stub."""

    def __init__(self):
        super().__init__()
        self.loss_lambda = 0.5
        self.linear = nn.Linear(2, 1)

    def forward(self, data_batch, return_loss=False):
        """Run the linear layer on a batch and return a result dict.

        Args:
            data_batch: Iterable of mappings, each providing tensors under the
                ``'inputs'`` and ``'data_sample'`` keys. The tensors are
                stacked along a new leading dimension.
            return_loss (bool): If True, compute and return the loss.
                Defaults to False.

        Returns:
            dict: ``dict(loss=<tensor>, log_vars=dict(loss=<float>))`` when
            ``return_loss`` is True, otherwise
            ``dict(log_vars=dict(a=1, b=0.5))``.
        """
        # Materialise once so a one-shot iterable can feed both stacks.
        samples = list(data_batch)

        # Use the device the layer actually lives on. Picking 'cuda:0'
        # whenever CUDA is available breaks a model that was never moved
        # there (input/weight device mismatch).
        device = self.linear.weight.device
        inputs = torch.stack([x['inputs'] for x in samples]).to(device)
        labels = torch.stack([x['data_sample'] for x in samples]).to(device)

        # The layer is always evaluated, even when no loss is requested.
        outputs = self.linear(inputs)
        if not return_loss:
            return dict(log_vars=dict(a=1, b=0.5))

        # NOTE(review): with 1-D labels of shape (B,) and outputs of shape
        # (B, 1) this subtraction broadcasts to (B, B); kept as-is -- confirm
        # whether that is intended.
        loss = (labels - outputs).sum()
        return dict(loss=loss, log_vars=dict(loss=loss.item()))
|
|
|
|
|
|
def test_extractor():
    """Check that ``Extractor`` can be constructed with and without pooling."""
    dataset = DummyDataset()
    loader_cfg = dict(
        batch_size=1,
        num_workers=1,
        sampler=dict(type='DefaultSampler', shuffle=False),
        dataset=dataset,
    )

    # Without a pooling config the extractor must not expose a usable
    # ``pool`` attribute (either missing or None).
    extractor = Extractor(
        extract_dataloader=loader_cfg,
        dist_mode=False,
        pool_cfg=None,
    )
    assert getattr(extractor, 'pool', None) is None

    # With a pooling config; this only checks that construction succeeds.
    pool_cfg = dict(type='AvgPool2d', output_size=1)
    extractor = Extractor(
        extract_dataloader=loader_cfg,
        dist_mode=False,
        pool_cfg=pool_cfg,
    )

    # TODO: test runtime
    # The BaseModel is not finalized yet, so the runtime check below stays
    # disabled until it is.

    # # test extractor
    # with tempfile.TemporaryDirectory() as tmpdir:
    #     model = MMDataParallel(ExampleModel())
    #     optimizer = build_optimizer(model, optim_cfg)
    #     runner = build_runner(
    #         runner_cfg,
    #         default_args=dict(
    #             model=model,
    #             optimizer=optimizer,
    #             work_dir=tmpdir,
    #             logger=logging.getLogger()))
    #     features = extractor(runner)
    #     assert features.shape == (1, 1)
|