diff --git a/timm/models/rdnet.py b/timm/models/rdnet.py index 34c2101e..7c2e2e27 100644 --- a/timm/models/rdnet.py +++ b/timm/models/rdnet.py @@ -55,17 +55,17 @@ class BlockESE(nn.Module): class DenseBlock(nn.Module): def __init__( self, - num_input_features, - growth_rate, - bottleneck_width_ratio, - drop_path_rate, - drop_rate=0.0, - rand_gather_step_prob=0.0, - block_idx=0, - block_type="Block", - ls_init_value=1e-6, - norm_layer="layernorm2d", - act_layer="gelu", + num_input_features: int = 64, + growth_rate: int = 64, + bottleneck_width_ratio: float = 4.0, + drop_path_rate: float = 0.0, + drop_rate: float = 0.0, + rand_gather_step_prob: float = 0.0, + block_idx: int = 0, + block_type: str = "Block", + ls_init_value: float = 1e-6, + norm_layer: str = "layernorm2d", + act_layer: str = "gelu", ): super().__init__() self.drop_rate = drop_rate @@ -78,8 +78,7 @@ class DenseBlock(nn.Module): growth_rate = int(growth_rate) inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8 - if self.drop_path_rate > 0: - self.drop_path = DropPath(drop_path_rate) + self.drop_path = DropPath(drop_path_rate) self.layers = eval(block_type)( in_chs=num_input_features, @@ -89,16 +88,14 @@ class DenseBlock(nn.Module): act_layer=act_layer, ) - def forward(self, x): - if isinstance(x, List): - x = torch.cat(x, 1) + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + x = torch.cat(x, 1) x = self.layers(x) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) - if self.drop_path_rate > 0 and self.training: - x = self.drop_path(x) + x = self.drop_path(x) return x @@ -117,7 +114,7 @@ class DenseStage(nn.Sequential): self.add_module(f"dense_block{i}", layer) self.num_out_features = num_input_features - def forward(self, init_feature): + def forward(self, init_feature: torch.Tensor) -> torch.Tensor: features = [init_feature] for module in self: new_feature = module(features) @@ -127,15 +124,15 @@ class DenseStage(nn.Sequential): class RDNet(nn.Module): def 
__init__( - self, + self, in_chans: int = 3, # timm option [--in-chans] num_classes: int = 1000, # timm option [--num-classes] global_pool: str = 'avg', # timm option [--gp] - growth_rates: Tuple[int, ...] = (64, 104, 128, 128, 128, 128, 224), - num_blocks_list: Tuple[int, ...] = (3, 3, 3, 3, 3, 3, 3), - block_type: Tuple[str, ...] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"), - is_downsample_block: Tuple[bool, ...] = (None, True, True, False, False, False, True), - bottleneck_width_ratio: int = 4, + growth_rates: Union[List[int], Tuple[int, ...]] = (64, 104, 128, 128, 128, 128, 224), + num_blocks_list: Union[List[int], Tuple[int, ...]] = (3, 3, 3, 3, 3, 3, 3), + block_type: Union[List[str], Tuple[str, ...]] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"), + is_downsample_block: Union[List[bool], Tuple[bool, ...]] = (None, True, True, False, False, False, True), + bottleneck_width_ratio: float = 4.0, transition_compression_ratio: float = 0.5, ls_init_value: float = 1e-6, stem_type: str = 'patch',