[Fix] fix inferencer ut (#3117)

This commit is contained in:
谢昕辰 2023-06-19 13:08:04 +08:00 committed by GitHub
parent c30d5060f1
commit b2f4b4fe33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,6 @@ import numpy as np
import torch
import torch.nn as nn
from mmengine import ConfigDict
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import MMSegInferencer
from mmseg.models import EncoderDecoder
@ -46,33 +45,8 @@ class ExampleModel(EncoderDecoder):
super().__init__(**kwargs)
class ExampleDataset(Dataset):
def __init__(self) -> None:
super().__init__()
self.pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
def __getitem__(self, idx):
return dict(img=torch.tensor([1]), img_metas=dict())
def __len__(self):
return 1
def test_inferencer():
register_all_modules()
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)
visualizer = dict(
type='SegLocalVisualizer',
@ -87,7 +61,14 @@ def test_inferencer():
decode_head=dict(type='InferExampleHead'),
test_cfg=dict(mode='whole')),
visualizer=visualizer,
test_dataloader=data_loader)
test_dataloader=dict(
dataset=dict(
type='ExampleDataset',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]), ))
cfg = ConfigDict(cfg_dict)
model = MODELS.build(cfg.model)