diff --git a/mmcv/ops/border_align.py b/mmcv/ops/border_align.py index e5a5bb743..70c6ed5f1 100644 --- a/mmcv/ops/border_align.py +++ b/mmcv/ops/border_align.py @@ -85,11 +85,12 @@ class BorderAlign(nn.Module): (e.g. top, bottom, left, right). """ - def __init__(self, pool_size): + def __init__(self, pool_size: int): super().__init__() self.pool_size = pool_size - def forward(self, input, boxes): + def forward(self, input: torch.Tensor, + boxes: torch.Tensor) -> torch.Tensor: """ Args: input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),