mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] fix inferencer ut (#3117)
This commit is contained in:
parent
c30d5060f1
commit
b2f4b4fe33
@ -5,7 +5,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine import ConfigDict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmseg.apis import MMSegInferencer
|
||||
from mmseg.models import EncoderDecoder
|
||||
@ -46,33 +45,8 @@ class ExampleModel(EncoderDecoder):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class ExampleDataset(Dataset):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return dict(img=torch.tensor([1]), img_metas=dict())
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
|
||||
|
||||
def test_inferencer():
|
||||
register_all_modules()
|
||||
test_dataset = ExampleDataset()
|
||||
data_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
sampler=None,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
visualizer = dict(
|
||||
type='SegLocalVisualizer',
|
||||
@ -87,7 +61,14 @@ def test_inferencer():
|
||||
decode_head=dict(type='InferExampleHead'),
|
||||
test_cfg=dict(mode='whole')),
|
||||
visualizer=visualizer,
|
||||
test_dataloader=data_loader)
|
||||
test_dataloader=dict(
|
||||
dataset=dict(
|
||||
type='ExampleDataset',
|
||||
pipeline=[
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]), ))
|
||||
cfg = ConfigDict(cfg_dict)
|
||||
model = MODELS.build(cfg.model)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user