mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
update code for torchscript
This commit is contained in:
parent
3b6a1d4d48
commit
7a866b6521
@ -55,17 +55,17 @@ class BlockESE(nn.Module):
|
|||||||
class DenseBlock(nn.Module):
|
class DenseBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_input_features,
|
num_input_features: int = 64,
|
||||||
growth_rate,
|
growth_rate: int = 64,
|
||||||
bottleneck_width_ratio,
|
bottleneck_width_ratio: float = 4.0,
|
||||||
drop_path_rate,
|
drop_path_rate: float = 0.0,
|
||||||
drop_rate=0.0,
|
drop_rate: float = 0.0,
|
||||||
rand_gather_step_prob=0.0,
|
rand_gather_step_prob: float = 0.0,
|
||||||
block_idx=0,
|
block_idx: int = 0,
|
||||||
block_type="Block",
|
block_type: str = "Block",
|
||||||
ls_init_value=1e-6,
|
ls_init_value: float = 1e-6,
|
||||||
norm_layer="layernorm2d",
|
norm_layer: str = "layernorm2d",
|
||||||
act_layer="gelu",
|
act_layer: str = "gelu",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.drop_rate = drop_rate
|
self.drop_rate = drop_rate
|
||||||
@ -78,8 +78,7 @@ class DenseBlock(nn.Module):
|
|||||||
growth_rate = int(growth_rate)
|
growth_rate = int(growth_rate)
|
||||||
inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
|
inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
|
||||||
|
|
||||||
if self.drop_path_rate > 0:
|
self.drop_path = DropPath(drop_path_rate)
|
||||||
self.drop_path = DropPath(drop_path_rate)
|
|
||||||
|
|
||||||
self.layers = eval(block_type)(
|
self.layers = eval(block_type)(
|
||||||
in_chs=num_input_features,
|
in_chs=num_input_features,
|
||||||
@ -89,16 +88,14 @@ class DenseBlock(nn.Module):
|
|||||||
act_layer=act_layer,
|
act_layer=act_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
|
||||||
if isinstance(x, List):
|
x = torch.cat(x, 1)
|
||||||
x = torch.cat(x, 1)
|
|
||||||
x = self.layers(x)
|
x = self.layers(x)
|
||||||
|
|
||||||
if self.gamma is not None:
|
if self.gamma is not None:
|
||||||
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
||||||
|
|
||||||
if self.drop_path_rate > 0 and self.training:
|
x = self.drop_path(x)
|
||||||
x = self.drop_path(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -117,7 +114,7 @@ class DenseStage(nn.Sequential):
|
|||||||
self.add_module(f"dense_block{i}", layer)
|
self.add_module(f"dense_block{i}", layer)
|
||||||
self.num_out_features = num_input_features
|
self.num_out_features = num_input_features
|
||||||
|
|
||||||
def forward(self, init_feature):
|
def forward(self, init_feature: torch.Tensor) -> torch.Tensor:
|
||||||
features = [init_feature]
|
features = [init_feature]
|
||||||
for module in self:
|
for module in self:
|
||||||
new_feature = module(features)
|
new_feature = module(features)
|
||||||
@ -127,15 +124,15 @@ class DenseStage(nn.Sequential):
|
|||||||
|
|
||||||
class RDNet(nn.Module):
|
class RDNet(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_chans: int = 3, # timm option [--in-chans]
|
in_chans: int = 3, # timm option [--in-chans]
|
||||||
num_classes: int = 1000, # timm option [--num-classes]
|
num_classes: int = 1000, # timm option [--num-classes]
|
||||||
global_pool: str = 'avg', # timm option [--gp]
|
global_pool: str = 'avg', # timm option [--gp]
|
||||||
growth_rates: Tuple[int, ...] = (64, 104, 128, 128, 128, 128, 224),
|
growth_rates: Union[List[int], Tuple[int]] = (64, 104, 128, 128, 128, 128, 224),
|
||||||
num_blocks_list: Tuple[int, ...] = (3, 3, 3, 3, 3, 3, 3),
|
num_blocks_list: Union[List[int], Tuple[int]] = (3, 3, 3, 3, 3, 3, 3),
|
||||||
block_type: Tuple[str, ...] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
|
block_type: Union[List[int], Tuple[int]] = ("Block", "Block", "BlockESE", "BlockESE", "BlockESE", "BlockESE", "BlockESE"),
|
||||||
is_downsample_block: Tuple[bool, ...] = (None, True, True, False, False, False, True),
|
is_downsample_block: Union[List[bool], Tuple[bool]] = (None, True, True, False, False, False, True),
|
||||||
bottleneck_width_ratio: int = 4,
|
bottleneck_width_ratio: float = 4.0,
|
||||||
transition_compression_ratio: float = 0.5,
|
transition_compression_ratio: float = 0.5,
|
||||||
ls_init_value: float = 1e-6,
|
ls_init_value: float = 1e-6,
|
||||||
stem_type: str = 'patch',
|
stem_type: str = 'patch',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user