mirror of https://github.com/open-mmlab/mmcv.git
Add type hints in mmcv/ops/active_rotated_filter.py (#2017)
* add typehints in ops-active-rotated-filter * resolve typehints in ops-active-rotated-filterpull/2021/head
parent
19902d897a
commit
1577f40744
|
@ -1,4 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
@ -19,7 +21,8 @@ class ActiveRotatedFilterFunction(Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, indices):
|
||||
def forward(ctx, input: torch.Tensor,
|
||||
indices: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
input (torch.Tensor): Input features with shape
|
||||
|
@ -41,7 +44,7 @@ class ActiveRotatedFilterFunction(Function):
|
|||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_out):
|
||||
def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
||||
"""
|
||||
Args:
|
||||
grad_output (torch.Tensor): The gradiant of output features
|
||||
|
|
Loading…
Reference in New Issue