update code for torchscript

commit 7a866b6521 (parent 3b6a1d4d48)
Author: dong-hyun
Date: 2024-08-02 09:58:13 +09:00


@@ -55,17 +55,17 @@ class BlockESE(nn.Module):
 class DenseBlock(nn.Module):
     def __init__(
         self,
-        num_input_features,
-        growth_rate,
-        bottleneck_width_ratio,
-        drop_path_rate,
-        drop_rate=0.0,
-        rand_gather_step_prob=0.0,
-        block_idx=0,
-        block_type="Block",
-        ls_init_value=1e-6,
-        norm_layer="layernorm2d",
-        act_layer="gelu",
+        num_input_features: int = 64,
+        growth_rate: int = 64,
+        bottleneck_width_ratio: float = 4.0,
+        drop_path_rate: float = 0.0,
+        drop_rate: float = 0.0,
+        rand_gather_step_prob: float = 0.0,
+        block_idx: int = 0,
+        block_type: str = "Block",
+        ls_init_value: float = 1e-6,
+        norm_layer: str = "layernorm2d",
+        act_layer: str = "gelu",
     ):
         super().__init__()
         self.drop_rate = drop_rate
@@ -78,8 +78,7 @@ class DenseBlock(nn.Module):
         growth_rate = int(growth_rate)
         inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
 
-        if self.drop_path_rate > 0:
-            self.drop_path = DropPath(drop_path_rate)
+        self.drop_path = DropPath(drop_path_rate)
 
         self.layers = eval(block_type)(
             in_chs=num_input_features,
@@ -89,16 +88,14 @@ class DenseBlock(nn.Module):
             act_layer=act_layer,
         )
 
-    def forward(self, x):
-        if isinstance(x, List):
-            x = torch.cat(x, 1)
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+        x = torch.cat(x, 1)
         x = self.layers(x)
 
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
 
-        if self.drop_path_rate > 0 and self.training:
-            x = self.drop_path(x)
+        x = self.drop_path(x)
         return x
@@ -117,7 +114,7 @@ class DenseStage(nn.Sequential):
         self.add_module(f"dense_block{i}", layer)
         self.num_out_features = num_input_features
 
-    def forward(self, init_feature):
+    def forward(self, init_feature: torch.Tensor) -> torch.Tensor:
         features = [init_feature]
         for module in self:
             new_feature = module(features)
@@ -127,15 +124,15 @@ class DenseStage(nn.Sequential):
 class RDNet(nn.Module):
     def __init__(
         self,
         in_chans: int = 3,  # timm option [--in-chans]
         num_classes: int = 1000,  # timm option [--num-classes]
         global_pool: str = 'avg',  # timm option [--gp]
-        growth_rates: Tuple[int, ...] = (64, 104, 128, 128, 128, 128, 224),
-        num_blocks_list: Tuple[int, ...] = (3, 3, 3, 3, 3, 3, 3),
-        block_type: Tuple[str, ...] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
-        is_downsample_block: Tuple[bool, ...] = (None, True, True, False, False, False, True),
-        bottleneck_width_ratio: int = 4,
+        growth_rates: Union[List[int], Tuple[int]] = (64, 104, 128, 128, 128, 128, 224),
+        num_blocks_list: Union[List[int], Tuple[int]] = (3, 3, 3, 3, 3, 3, 3),
+        block_type: Union[List[str], Tuple[str]] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
+        is_downsample_block: Union[List[bool], Tuple[bool]] = (None, True, True, False, False, False, True),
+        bottleneck_width_ratio: float = 4.0,
         transition_compression_ratio: float = 0.5,
         ls_init_value: float = 1e-6,
         stem_type: str = 'patch',
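
The pattern behind every hunk is the same: torch.jit.script compiles forward() from its type annotations (unannotated arguments default to Tensor), it needs one static signature rather than the old isinstance(x, List) dispatch, and any attribute used in forward (like drop_path) must exist on the module regardless of configuration, hence the unconditional DropPath, which is already a no-op at drop_prob=0 or in eval mode. A minimal export sketch under the assumption that RDNet is importable from this file and builds with its defaults; the import path and input shape here are illustrative, not part of the commit:

    import torch
    from rdnet import RDNet  # assumed import path for this file

    model = RDNet().eval()
    # Scripting compiles the annotated forward() methods; without the
    # List[torch.Tensor] / torch.Tensor hints above it would fail to compile.
    scripted = torch.jit.script(model)
    out = scripted(torch.randn(1, 3, 224, 224))  # illustrative input shape
    torch.jit.save(scripted, "rdnet_torchscript.pt")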