diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 8182e28a5..160e2c367 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -551,9 +551,9 @@ def test_sep_fcn_head():
         num_classes=19,
         in_index=-1,
         norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01))
-    x = torch.rand(1, 128, 32, 32)
+    x = torch.rand(2, 128, 32, 32)
     output = head(x)
-    assert output.shape == (1, head.num_classes, 32, 32)
+    assert output.shape == (2, head.num_classes, 32, 32)
     assert not head.concat_input
     from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
     assert isinstance(head.convs[0], DepthwiseSeparableConvModule)