Fix typehint in YOLOv6 Head (#415)

pull/416/head
jason_w 2022-12-29 21:21:44 +08:00 committed by GitHub
parent cdc359c2de
commit 13de22dfd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 6 deletions

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple, Union
from typing import List, Sequence, Tuple, Union
import torch
import torch.nn as nn
@ -129,7 +129,7 @@ class YOLOv6HeadModule(BaseModule):
conv.bias.data.fill_(1.0)
conv.weight.data.fill_(0.)
def forward(self, x: Tensor) -> Tensor:
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
"""Forward features from the upstream network.
Args:
@ -143,10 +143,9 @@ class YOLOv6HeadModule(BaseModule):
return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
self.cls_preds, self.reg_convs, self.reg_preds)
def forward_single(self, x: Tensor, stem: nn.ModuleList,
cls_conv: nn.ModuleList, cls_pred: nn.ModuleList,
reg_conv: nn.ModuleList,
reg_pred: nn.ModuleList) -> Tuple[Tensor, Tensor]:
def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
cls_pred: nn.Module, reg_conv: nn.Module,
reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
"""Forward feature of a single scale level."""
y = stem(x)
cls_x = y