[Fix] fix inferencer ut (#3117)

This commit is contained in:
谢昕辰 2023-06-19 13:08:04 +08:00 committed by GitHub
parent c30d5060f1
commit b2f4b4fe33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,6 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine import ConfigDict from mmengine import ConfigDict
from torch.utils.data import DataLoader, Dataset
from mmseg.apis import MMSegInferencer from mmseg.apis import MMSegInferencer
from mmseg.models import EncoderDecoder from mmseg.models import EncoderDecoder
@ -46,33 +45,8 @@ class ExampleModel(EncoderDecoder):
super().__init__(**kwargs) super().__init__(**kwargs)
class ExampleDataset(Dataset):
def __init__(self) -> None:
super().__init__()
self.pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
def __getitem__(self, idx):
return dict(img=torch.tensor([1]), img_metas=dict())
def __len__(self):
return 1
def test_inferencer(): def test_inferencer():
register_all_modules() register_all_modules()
test_dataset = ExampleDataset()
data_loader = DataLoader(
test_dataset,
batch_size=1,
sampler=None,
num_workers=0,
shuffle=False,
)
visualizer = dict( visualizer = dict(
type='SegLocalVisualizer', type='SegLocalVisualizer',
@ -87,7 +61,14 @@ def test_inferencer():
decode_head=dict(type='InferExampleHead'), decode_head=dict(type='InferExampleHead'),
test_cfg=dict(mode='whole')), test_cfg=dict(mode='whole')),
visualizer=visualizer, visualizer=visualizer,
test_dataloader=data_loader) test_dataloader=dict(
dataset=dict(
type='ExampleDataset',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]), ))
cfg = ConfigDict(cfg_dict) cfg = ConfigDict(cfg_dict)
model = MODELS.build(cfg.model) model = MODELS.build(cfg.model)