update code for torchscript

This commit is contained in:
dong-hyun 2024-08-02 09:58:13 +09:00
parent 3b6a1d4d48
commit 7a866b6521

View File

@ -55,17 +55,17 @@ class BlockESE(nn.Module):
class DenseBlock(nn.Module):
def __init__(
self,
num_input_features,
growth_rate,
bottleneck_width_ratio,
drop_path_rate,
drop_rate=0.0,
rand_gather_step_prob=0.0,
block_idx=0,
block_type="Block",
ls_init_value=1e-6,
norm_layer="layernorm2d",
act_layer="gelu",
num_input_features: int = 64,
growth_rate: int = 64,
bottleneck_width_ratio: float = 4.0,
drop_path_rate: float = 0.0,
drop_rate: float = 0.0,
rand_gather_step_prob: float = 0.0,
block_idx: int = 0,
block_type: str = "Block",
ls_init_value: float = 1e-6,
norm_layer: str = "layernorm2d",
act_layer: str = "gelu",
):
super().__init__()
self.drop_rate = drop_rate
@ -78,8 +78,7 @@ class DenseBlock(nn.Module):
growth_rate = int(growth_rate)
inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
if self.drop_path_rate > 0:
self.drop_path = DropPath(drop_path_rate)
self.drop_path = DropPath(drop_path_rate)
self.layers = eval(block_type)(
in_chs=num_input_features,
@ -89,16 +88,14 @@ class DenseBlock(nn.Module):
act_layer=act_layer,
)
def forward(self, x):
if isinstance(x, List):
x = torch.cat(x, 1)
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
x = torch.cat(x, 1)
x = self.layers(x)
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
if self.drop_path_rate > 0 and self.training:
x = self.drop_path(x)
x = self.drop_path(x)
return x
@ -117,7 +114,7 @@ class DenseStage(nn.Sequential):
self.add_module(f"dense_block{i}", layer)
self.num_out_features = num_input_features
def forward(self, init_feature):
def forward(self, init_feature: torch.Tensor) -> torch.Tensor:
features = [init_feature]
for module in self:
new_feature = module(features)
@ -127,15 +124,15 @@ class DenseStage(nn.Sequential):
class RDNet(nn.Module):
def __init__(
self,
self,
in_chans: int = 3, # timm option [--in-chans]
num_classes: int = 1000, # timm option [--num-classes]
global_pool: str = 'avg', # timm option [--gp]
growth_rates: Tuple[int, ...] = (64, 104, 128, 128, 128, 128, 224),
num_blocks_list: Tuple[int, ...] = (3, 3, 3, 3, 3, 3, 3),
block_type: Tuple[str, ...] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
is_downsample_block: Tuple[bool, ...] = (None, True, True, False, False, False, True),
bottleneck_width_ratio: int = 4,
growth_rates: Union[List[int], Tuple[int]] = (64, 104, 128, 128, 128, 128, 224),
num_blocks_list: Union[List[int], Tuple[int]] = (3, 3, 3, 3, 3, 3, 3),
block_type: Union[List[int], Tuple[int]] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
is_downsample_block: Union[List[bool], Tuple[bool]] = (None, True, True, False, False, False, True),
bottleneck_width_ratio: float = 4.0,
transition_compression_ratio: float = 0.5,
ls_init_value: float = 1e-6,
stem_type: str = 'patch',