add docstring of ema and yolo_bricks(#55)

pull/41/head^2
wanghonglie 2022-09-21 15:44:52 +08:00 committed by GitHub
parent 4f4f00c602
commit 7b7c807eb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 5 deletions

View File

@ -75,6 +75,11 @@ class ExpMomentumEMA(MMDET_ExpMomentumEMA):
averaged_param.lerp_(source_param, momentum)
def update_parameters(self, model: nn.Module):
"""Update the parameters after each training step.
Args:
model (nn.Module): The model of the parameter needs to be updated.
"""
if self.steps == 0:
for k, p_avg in self.avg_parameters.items():
p_avg.data.copy_(self.src_parameters[k].data)

View File

@ -16,6 +16,7 @@ if digit_version(torch.__version__) >= digit_version('1.7.0'):
else:
class SiLU(nn.Module):
"""Sigmoid Weighted Liner Unit."""
def __init__(self, inplace=True):
super().__init__()
@ -83,7 +84,10 @@ class SPPFBottleneck(BaseModule):
act_cfg=act_cfg)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward process."""
"""Forward process
Args:
x (Tensor): The input tensor.
"""
x = self.conv1(x)
if isinstance(self.kernel_sizes, int):
y1 = self.poolings(x)
@ -183,7 +187,13 @@ class RepVGGBlock(nn.Module):
act_cfg=None)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Forward process."""
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
if hasattr(self, 'rbr_reparam'):
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
@ -196,7 +206,11 @@ class RepVGGBlock(nn.Module):
self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
def get_equivalent_kernel_bias(self):
"""Derives the equivalent kernel and bias in a differentiable way."""
"""Derives the equivalent kernel and bias in a differentiable way.
Returns:
tuple: Equivalent kernel and bias
"""
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
@ -204,7 +218,13 @@ class RepVGGBlock(nn.Module):
kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
"""Pad 1x1 tensor to 3x3."""
"""Pad 1x1 tensor to 3x3.
Args:
kernel1x1 (Tensor): The input 1x1 kernel need to be padded.
Returns:
Tensor: 3x3 kernel after padded.
"""
if kernel1x1 is None:
return 0
else:
@ -278,7 +298,14 @@ class RepVGGBlock(nn.Module):
class RepStageBlock(nn.Module):
"""RepStageBlock is a stage block with rep-style basic block."""
"""RepStageBlock is a stage block with rep-style basic block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
n (int, tuple[int]): Number of blocks. Defaults to 1.
block (nn.Module): Basic unit of RepStage. Defaults to RepVGGBlock.
"""
def __init__(self,
in_channels: int,
@ -291,6 +318,13 @@ class RepStageBlock(nn.Module):
for _ in range(n - 1))) if n > 1 else None
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward process.
Args:
inputs (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
x = self.conv1(x)
if self.block is not None:
x = self.block(x)