diff --git a/tests/test_apis/test_inferencer.py b/tests/test_apis/test_inferencer.py index 497eae4a0..663680976 100644 --- a/tests/test_apis/test_inferencer.py +++ b/tests/test_apis/test_inferencer.py @@ -5,7 +5,6 @@ import numpy as np import torch import torch.nn as nn from mmengine import ConfigDict -from torch.utils.data import DataLoader, Dataset from mmseg.apis import MMSegInferencer from mmseg.models import EncoderDecoder @@ -46,33 +45,8 @@ class ExampleModel(EncoderDecoder): super().__init__(**kwargs) -class ExampleDataset(Dataset): - - def __init__(self) -> None: - super().__init__() - self.pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations'), - dict(type='PackSegInputs') - ] - - def __getitem__(self, idx): - return dict(img=torch.tensor([1]), img_metas=dict()) - - def __len__(self): - return 1 - - def test_inferencer(): register_all_modules() - test_dataset = ExampleDataset() - data_loader = DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_workers=0, - shuffle=False, - ) visualizer = dict( type='SegLocalVisualizer', @@ -87,7 +61,14 @@ def test_inferencer(): decode_head=dict(type='InferExampleHead'), test_cfg=dict(mode='whole')), visualizer=visualizer, - test_dataloader=data_loader) + test_dataloader=dict( + dataset=dict( + type='ExampleDataset', + pipeline=[ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='PackSegInputs') + ]), )) cfg = ConfigDict(cfg_dict) model = MODELS.build(cfg.model)