[Refactor]: Refactor relative loc
parent
5ba17adb23
commit
bcc4576ace
|
@ -14,16 +14,6 @@ train_pipeline = [
|
|||
pseudo_label_keys=['patch_box', 'patch_label', 'unpatched_img'],
|
||||
meta_keys=['img_path'])
|
||||
]
|
||||
val_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
dict(type='Resize', scale=292),
|
||||
dict(type='CenterCrop', crop_size=255),
|
||||
dict(type='RandomPatchWithLabels'),
|
||||
dict(
|
||||
type='PackSelfSupInputs',
|
||||
pseudo_label_keys=['patch_label'],
|
||||
meta_keys=['img_path'])
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
|
@ -36,14 +26,3 @@ train_dataloader = dict(
|
|||
ann_file='meta/train.txt',
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix=dict(img_path='val/'),
|
||||
pipeline=val_pipeline))
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='RelativeLoc',
|
||||
data_preprocessor=dict(
|
||||
type='mmselfsup.RelativeLocDataPreprocessor',
|
||||
mean=[124, 117, 104],
|
||||
std=[59, 58, 58],
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
@ -14,11 +19,11 @@ model = dict(
|
|||
with_avg_pool=True),
|
||||
head=dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
with_avg_pool=False,
|
||||
in_channels=4096,
|
||||
num_classes=8,
|
||||
init_cfg=[
|
||||
dict(type='Normal', std=0.005, layer='Linear'),
|
||||
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]),
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'))
|
||||
]))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# optimizer
|
||||
# optimizer wrapper
|
||||
optimizer = dict(type='SGD', lr=0.03, weight_decay=1e-4, momentum=0.9)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
|
||||
|
||||
|
|
|
@ -5,19 +5,18 @@ _base_ = [
|
|||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='SGD',
|
||||
lr=0.2,
|
||||
weight_decay=1e-4,
|
||||
momentum=0.9,
|
||||
paramwise_options={
|
||||
'\\Aneck.': dict(weight_decay=5e-4),
|
||||
'\\Ahead.': dict(weight_decay=5e-4)
|
||||
})
|
||||
# optimizer wrapper
|
||||
optimizer = dict(type='SGD', lr=0.2, momentum=0.9, weight_decay=1e-4)
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
paramwise_cfg=dict(custom_keys={
|
||||
'neck': dict(decay_mult=5.0),
|
||||
'head': dict(decay_mult=5.0)
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
scheduler = [
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.1,
|
||||
|
@ -29,8 +28,10 @@ scheduler = [
|
|||
]
|
||||
|
||||
# runtime settings
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=70)
|
||||
# pre-train for 70 epochs
|
||||
train_cfg = dict(max_epochs=70)
|
||||
# max_keep_ckpts controls the max number of ckpt files kept in your work_dirs
|
||||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
|
||||
# it will remove the oldest one to keep the number of total ckpts as 3
|
||||
checkpoint_config = dict(interval=10, max_keep_ckpts=3)
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3))
|
||||
|
|
|
@ -143,12 +143,12 @@ class SelfSupDataSample(BaseDataElement):
|
|||
del self._pred_label
|
||||
|
||||
@property
|
||||
def pseudo_label(self) -> InstanceData:
|
||||
def pseudo_label(self) -> BaseDataElement:
|
||||
return self._pseudo_label
|
||||
|
||||
@pseudo_label.setter
|
||||
def pseudo_label(self, value: InstanceData):
|
||||
self.set_field(value, '_pseudo_label', dtype=InstanceData)
|
||||
def pseudo_label(self, value: BaseDataElement):
|
||||
self.set_field(value, '_pseudo_label', dtype=BaseDataElement)
|
||||
|
||||
@pseudo_label.deleter
|
||||
def pseudo_label(self):
|
||||
|
|
|
@ -1,179 +1,86 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.data import InstanceData
|
||||
from mmengine.data import LabelData
|
||||
|
||||
from mmselfsup.core import SelfSupDataSample
|
||||
from mmselfsup.utils import get_module_device
|
||||
from ..builder import (ALGORITHMS, build_backbone, build_head, build_loss,
|
||||
build_neck)
|
||||
from ..builder import MODELS
|
||||
from .base import BaseModel
|
||||
|
||||
|
||||
@ALGORITHMS.register_module()
|
||||
@MODELS.register_module()
|
||||
class RelativeLoc(BaseModel):
|
||||
"""Relative patch location.
|
||||
|
||||
Implementation of `Unsupervised Visual Representation Learning
|
||||
by Context Prediction <https://arxiv.org/abs/1505.05192>`_.
|
||||
|
||||
Args:
|
||||
backbone (Dict, optional): Config dict for module of backbone.
|
||||
Defaults to None.
|
||||
neck (Dict, optional): Config dict for module of deep features
|
||||
to compact feature vectors. Defaults to None.
|
||||
head (Dict, optional): Config dict for module of head function.
|
||||
Defaults to None.
|
||||
loss (Dict): Config dict for module of loss function. Defaults to None.
|
||||
preprocess_cfg (Dict, optional): Config dict to preprocess images.
|
||||
Defaults to None.
|
||||
init_cfg (Dict or List[Dict], optional): Config dict for weight
|
||||
initialization. Defaults to None.
|
||||
Implementation of `Unsupervised Visual Representation Learning by Context
|
||||
Prediction <https://arxiv.org/abs/1505.05192>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: Optional[Dict] = None,
|
||||
neck: Optional[Dict] = None,
|
||||
head: Optional[Dict] = None,
|
||||
loss: Optional[Dict] = None,
|
||||
preprocess_cfg: Optional[Dict] = None,
|
||||
init_cfg: Optional[Union[Dict, List[Dict]]] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__(preprocess_cfg=preprocess_cfg, init_cfg=init_cfg)
|
||||
assert backbone is not None
|
||||
self.backbone = build_backbone(backbone)
|
||||
assert neck is not None
|
||||
self.neck = build_neck(neck)
|
||||
assert head is not None
|
||||
self.head = build_head(head)
|
||||
assert loss is not None
|
||||
self.loss = build_loss(loss)
|
||||
|
||||
def extract_feat(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
def extract_feat(self, batch_inputs: List[torch.Tensor],
|
||||
**kwargs) -> Tuple[torch.Tensor]:
|
||||
"""Function to extract features from backbone.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: backbone outputs.
|
||||
"""
|
||||
|
||||
x = self.backbone(inputs[0])
|
||||
x = self.backbone(batch_inputs[0])
|
||||
return x
|
||||
|
||||
def forward_train(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
def loss(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""Forward computation during training.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
|
||||
x1 = self.backbone(inputs[0])
|
||||
x2 = self.backbone(inputs[1])
|
||||
x1 = self.backbone(batch_inputs[0])
|
||||
x2 = self.backbone(batch_inputs[1])
|
||||
x = (torch.cat((x1[0], x2[0]), dim=1), )
|
||||
x = self.neck(x)
|
||||
outs = self.head(x)
|
||||
|
||||
patch_label = [
|
||||
data_sample.patch_label.value for data_sample in data_samples
|
||||
data_sample.pseudo_label.patch_label
|
||||
for data_sample in data_samples
|
||||
]
|
||||
|
||||
patch_label = torch.flatten(torch.stack(patch_label, 0))
|
||||
loss = self.loss(outs[0], patch_label)
|
||||
loss = self.head(x, patch_label)
|
||||
losses = dict(loss=loss)
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> List[SelfSupDataSample]:
|
||||
def predict(self, batch_inputs: List[torch.Tensor],
|
||||
data_samples: List[SelfSupDataSample],
|
||||
**kwargs) -> List[SelfSupDataSample]:
|
||||
"""The forward function in testing.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]): The input images.
|
||||
batch_inputs (List[torch.Tensor]): The input images.
|
||||
data_samples (List[SelfSupDataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
List[SelfSupDataSample]: The prediction from model.
|
||||
"""
|
||||
|
||||
x1 = self.backbone(inputs[0])
|
||||
x2 = self.backbone(inputs[1])
|
||||
x1 = self.backbone(batch_inputs[0])
|
||||
x2 = self.backbone(batch_inputs[1])
|
||||
x = (torch.cat((x1[0], x2[0]), dim=1), )
|
||||
x = self.neck(x)
|
||||
outs = self.head(x)
|
||||
outs = self.head.logits(x)
|
||||
keys = [f'head{i}' for i in self.backbone.out_indices]
|
||||
outs = [torch.chunk(out, len(outs[0]) // 8, 0) for out in outs]
|
||||
|
||||
for i in range(len(outs[0])):
|
||||
prediction_data = {key: out[i] for key, out in zip(keys, outs)}
|
||||
prediction = InstanceData(**prediction_data)
|
||||
data_samples[i].prediction = prediction
|
||||
prediction = LabelData(**prediction_data)
|
||||
data_samples[i].pred_label = prediction
|
||||
return data_samples
|
||||
|
||||
def preprocss_data(
|
||||
self,
|
||||
data: List[Dict]) -> Tuple[List[torch.Tensor], SelfSupDataSample]:
|
||||
"""Process input data during training, testing or extracting.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): The data to be processed, which
|
||||
comes from dataloader.
|
||||
|
||||
Returns:
|
||||
tuple: It should contain 2 items.
|
||||
- batch_images (List[torch.Tensor]): The batch image tensor.
|
||||
- data_samples (List[SelfSupDataSample], Optional): The Data
|
||||
Samples. It usually includes information such as
|
||||
`gt_label`. Return None If the input data does not
|
||||
contain `data_sample`.
|
||||
"""
|
||||
|
||||
# data_['inputs'] is a list
|
||||
images = [data_['inputs'] for data_ in data]
|
||||
data_samples = [data_['data_sample'] for data_ in data]
|
||||
|
||||
device = get_module_device(self)
|
||||
data_samples = [data_sample.to(device) for data_sample in data_samples]
|
||||
images = [[img_.to(device) for img_ in img] for img in images]
|
||||
|
||||
# convert images to rgb
|
||||
if self.to_rgb and images[0][0].size(0) == 3:
|
||||
images = [[img_[[2, 1, 0], ...] for img_ in img] for img in images]
|
||||
|
||||
# normalize images
|
||||
images = [[(img_ - self.mean_norm) / self.std_norm for img_ in img]
|
||||
for img in images]
|
||||
|
||||
# reconstruct images into several batches. RelativeLoc needs
|
||||
# nine crops for each image, and this code snippet will convert images
|
||||
# into nine batches, each containing one crop of an image.
|
||||
batch_images = []
|
||||
for i in range(len(images[0])):
|
||||
cur_batch = [img[i] for img in images]
|
||||
batch_images.append(torch.stack(cur_batch))
|
||||
|
||||
img1 = torch.stack(batch_images[1:], 1) # Nx8xCxHxW
|
||||
img1 = img1.view(
|
||||
img1.size(0) * img1.size(1), img1.size(2), img1.size(3),
|
||||
img1.size(4)) # (8N)xCxHxW
|
||||
img2 = torch.unsqueeze(batch_images[0], 1).repeat(1, 8, 1, 1,
|
||||
1) # Nx8xCxHxW
|
||||
img2 = img2.view(
|
||||
img2.size(0) * img2.size(1), img2.size(2), img2.size(3),
|
||||
img2.size(4)) # (8N)xCxHxW
|
||||
batch_images = [img1, img2]
|
||||
|
||||
return batch_images, data_samples
|
||||
|
|
|
@ -5,14 +5,15 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from ..builder import MODELS
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class ClsHead(BaseModule):
|
||||
"""Simplest classifier head, with only one fc layer.
|
||||
|
||||
Args:
|
||||
loss (dict): Config of the loss.
|
||||
with_avg_pool (bool): Whether to apply the average pooling
|
||||
after neck. Defaults to False.
|
||||
in_channels (int): Number of input channels. Defaults to 2048.
|
||||
|
@ -22,6 +23,7 @@ class ClsHead(BaseModule):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
loss: dict,
|
||||
with_avg_pool: Optional[bool] = False,
|
||||
in_channels: Optional[int] = 2048,
|
||||
num_classes: Optional[int] = 1000,
|
||||
|
@ -32,6 +34,7 @@ class ClsHead(BaseModule):
|
|||
]
|
||||
) -> None:
|
||||
super().__init__(init_cfg)
|
||||
self.loss = MODELS.build(loss)
|
||||
self.with_avg_pool = with_avg_pool
|
||||
self.in_channels = in_channels
|
||||
self.num_classes = num_classes
|
||||
|
@ -41,10 +44,12 @@ class ClsHead(BaseModule):
|
|||
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc_cls = nn.Linear(in_channels, num_classes)
|
||||
|
||||
def forward(
|
||||
def logits(
|
||||
self, x: Union[List[torch.Tensor],
|
||||
Tuple[torch.Tensor]]) -> List[torch.Tensor]:
|
||||
"""Forward head.
|
||||
"""Get the logits before the cross_entropy loss.
|
||||
|
||||
This module is used to obtain the logits before the loss.
|
||||
|
||||
Args:
|
||||
x (List[Tensor] | Tuple[Tensor]): Feature maps of backbone,
|
||||
|
@ -64,3 +69,16 @@ class ClsHead(BaseModule):
|
|||
x = x.view(x.size(0), -1)
|
||||
cls_score = self.fc_cls(x)
|
||||
return [cls_score]
|
||||
|
||||
def forward(self, x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
            label: torch.Tensor) -> torch.Tensor:
    """Get the loss.

    Args:
        x (List[Tensor] | Tuple[Tensor]): Feature maps of backbone,
            each tensor has shape (N, C, H, W).
        label (torch.Tensor): The label for cross entropy loss.

    Returns:
        torch.Tensor: The value produced by ``self.loss`` when applied to
        the first logits tensor and ``label``.
    """
    # ``logits`` wraps the classification score in a list; the loss module
    # only needs the tensor itself.
    cls_score = self.logits(x)[0]
    return self.loss(cls_score, label)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dall_e import Encoder
|
||||
from .data_preprocessor import SelfSupDataPreprocessor
|
||||
from .data_preprocessor import (RelativeLocDataPreprocessor,
|
||||
SelfSupDataPreprocessor)
|
||||
from .ema import CosineEMA
|
||||
from .extractor import Extractor
|
||||
from .gather_layer import GatherLayer
|
||||
|
@ -15,5 +16,5 @@ __all__ = [
|
|||
'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes',
|
||||
'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention',
|
||||
'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder',
|
||||
'CosineEMA', 'SelfSupDataPreprocessor'
|
||||
'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor'
|
||||
]
|
||||
|
|
|
@ -91,3 +91,58 @@ class SelfSupDataPreprocessor(ImgDataPreprocessor):
|
|||
batch_inputs.append(torch.stack(cur_batch))
|
||||
|
||||
return batch_inputs, batch_data_samples
|
||||
|
||||
|
||||
@MODELS.register_module()
class RelativeLocDataPreprocessor(SelfSupDataPreprocessor):
    """Image pre-processor for Relative Location."""

    def forward(
            self,
            data: Sequence[dict],
            training: bool = False
    ) -> Tuple[List[torch.Tensor], Optional[list]]:
        """Perform bgr2rgb conversion and normalization, then pair patches.

        Args:
            data (Sequence[dict]): data sampled from dataloader.
            training (bool): Accepted for interface compatibility; this
                implementation does not read it.

        Returns:
            Tuple[List[torch.Tensor], Optional[list]]: A two-element list
            ``[img1, img2]`` and the data samples returned by
            ``self.collate_data``. ``img1`` holds every view after the first
            one, and ``img2`` holds the first view repeated once per
            remaining view; both have their first two dimensions merged.
        """
        inputs, batch_data_samples = self.collate_data(data)

        # channel transform
        if self.channel_conversion:
            inputs = [[img_[[2, 1, 0], ...] for img_ in _input]
                      for _input in inputs]

        # Normalization. Each item in ``inputs`` is itself a list holding
        # the views of one image, so normalize view by view.
        inputs = [[(img_ - self.mean) / self.std for img_ in _input]
                  for _input in inputs]

        # Regroup from "one list of views per image" to "one stacked batch
        # per view index".
        batch_inputs = [
            torch.stack([views[i] for views in inputs])
            for i in range(len(inputs[0]))
        ]

        # This part is unique to Relative Loc.
        # K is the number of views after the first one (8 for the nine-crop
        # pipeline this class was written for -- TODO confirm no other
        # pipeline feeds it).
        img1 = torch.stack(batch_inputs[1:], 1)  # NxKxCxHxW
        num_patches = img1.size(1)
        img1 = img1.flatten(0, 1)  # (KN)xCxHxW

        # Repeat the first view once per remaining view so that img1 and
        # img2 line up row by row.
        img2 = batch_inputs[0].unsqueeze(1).repeat(
            1, num_patches, 1, 1, 1)  # NxKxCxHxW
        img2 = img2.flatten(0, 1)  # (KN)xCxHxW

        return [img1, img2], batch_data_samples
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from mmengine.data import LabelData
|
||||
from mmengine.data import InstanceData
|
||||
|
||||
from mmselfsup.core.data_structures.selfsup_data_sample import \
|
||||
SelfSupDataSample
|
||||
|
@ -23,6 +22,7 @@ neck = dict(
|
|||
with_avg_pool=True)
|
||||
head = dict(
|
||||
type='ClsHead',
|
||||
loss=dict(type='mmcls.CrossEntropyLoss'),
|
||||
with_avg_pool=False,
|
||||
in_channels=32,
|
||||
num_classes=8,
|
||||
|
@ -30,51 +30,22 @@ head = dict(
|
|||
dict(type='Normal', std=0.005, layer='Linear'),
|
||||
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
])
|
||||
loss = dict(type='mmcls.CrossEntropyLoss')
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_relative_loc():
|
||||
preprocess_cfg = {
|
||||
data_preprocessor = {
|
||||
'type': 'mmselfsup.RelativeLocDataPreprocessor',
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
'bgr_to_rgb': True
|
||||
}
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(
|
||||
backbone=backbone,
|
||||
neck=None,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=None,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(
|
||||
backbone=None,
|
||||
neck=neck,
|
||||
head=head,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
with pytest.raises(AssertionError):
|
||||
alg = RelativeLoc(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=None,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
|
||||
alg = RelativeLoc(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
loss=loss,
|
||||
preprocess_cfg=copy.deepcopy(preprocess_cfg))
|
||||
alg.init_weights()
|
||||
data_preprocessor=data_preprocessor)
|
||||
|
||||
batch_size = 5
|
||||
fake_data = [{
|
||||
|
@ -89,19 +60,19 @@ def test_relative_loc():
|
|||
SelfSupDataSample()
|
||||
} for _ in range(batch_size)]
|
||||
|
||||
patch_label = LabelData()
|
||||
patch_label.value = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
pseudo_label = InstanceData()
|
||||
pseudo_label.patch_label = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
for i in range(batch_size):
|
||||
fake_data[i]['data_sample'].patch_label = patch_label
|
||||
fake_data[i]['data_sample'].pseudo_label = pseudo_label
|
||||
|
||||
fake_outputs = alg(fake_data, return_loss=True)
|
||||
fake_batch_inputs, fake_data_samples = alg.data_preprocessor(fake_data)
|
||||
|
||||
fake_outputs = alg(fake_batch_inputs, fake_data_samples, mode='loss')
|
||||
assert isinstance(fake_outputs['loss'].item(), float)
|
||||
|
||||
test_results = alg(fake_data, return_loss=False)
|
||||
test_results = alg(fake_batch_inputs, fake_data_samples, mode='predict')
|
||||
assert len(test_results) == len(fake_data)
|
||||
assert list(test_results[0].prediction.head4.shape) == [8, 8]
|
||||
assert list(test_results[0].pred_label.head4.shape) == [8, 8]
|
||||
|
||||
fake_inputs, fake_data_samples = alg.preprocss_data(fake_data)
|
||||
fake_feat = alg.extract_feat(
|
||||
inputs=fake_inputs, data_samples=fake_data_samples)
|
||||
fake_feat = alg(fake_batch_inputs, fake_data_samples, mode='tensor')
|
||||
assert list(fake_feat[0].shape) == [batch_size * 8, 512, 1, 1]
|
||||
|
|
Loading…
Reference in New Issue