mirror of https://github.com/open-mmlab/mmyolo.git
Fix typehint in YOLOv6 Head (#415)
parent
cdc359c2de
commit
13de22dfd2
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Tuple, Union
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -129,7 +129,7 @@ class YOLOv6HeadModule(BaseModule):
|
|||
conv.bias.data.fill_(1.0)
|
||||
conv.weight.data.fill_(0.)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
||||
"""Forward features from the upstream network.
|
||||
|
||||
Args:
|
||||
|
@ -143,10 +143,9 @@ class YOLOv6HeadModule(BaseModule):
|
|||
return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
|
||||
self.cls_preds, self.reg_convs, self.reg_preds)
|
||||
|
||||
def forward_single(self, x: Tensor, stem: nn.ModuleList,
|
||||
cls_conv: nn.ModuleList, cls_pred: nn.ModuleList,
|
||||
reg_conv: nn.ModuleList,
|
||||
reg_pred: nn.ModuleList) -> Tuple[Tensor, Tensor]:
|
||||
def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
|
||||
cls_pred: nn.Module, reg_conv: nn.Module,
|
||||
reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
|
||||
"""Forward feature of a single scale level."""
|
||||
y = stem(x)
|
||||
cls_x = y
|
||||
|
|
Loading…
Reference in New Issue