include suggestions from review

Co-Authored-By: Ross Wightman <rwightman@gmail.com>
commit d5f1525334 (parent 5f14bdd564)
Author: a-r-r-o-w, committed by Ross Wightman
Date: 2023-10-30 02:37:50 +05:30
3 changed files with 12 additions and 16 deletions

timm/layers/typing.py

@@ -1,9 +1,7 @@
-import functools
-import types
-from typing import Tuple, Union
+from typing import Callable, Tuple, Type, Union
 
-import torch.nn
+import torch
 
 
-LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
+LayerType = Union[str, Callable, Type[torch.nn.Module]]
 PadType = Union[str, int, Tuple[int, int]]
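
For reference, the narrowed alias still covers everything callers pass for act_layer / norm_layer style arguments: a layer name, an nn.Module subclass, or any callable (a functools.partial is just a Callable). A minimal sketch of how such a value might be consumed; resolve_layer below is a hypothetical helper for illustration, not timm's actual lookup (which lives in timm.layers, e.g. get_act_layer):

import functools
from typing import Callable, Type, Union

import torch

LayerType = Union[str, Callable, Type[torch.nn.Module]]

def resolve_layer(layer: LayerType) -> torch.nn.Module:
    # Hypothetical resolver: map a name to a class, or instantiate
    # whatever class/factory callable was passed in.
    if isinstance(layer, str):
        return {'relu': torch.nn.ReLU, 'gelu': torch.nn.GELU}[layer]()
    return layer()  # an nn.Module subclass or any factory callable

print(resolve_layer('relu'))         # ReLU()
print(resolve_layer(torch.nn.GELU))  # GELU()
print(resolve_layer(functools.partial(torch.nn.LeakyReLU, negative_slope=0.2)))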

timm/models/mobilenetv3.py

@@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -151,7 +150,7 @@ class MobileNetV3(nn.Module):
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward_features(self, x: Tensor) -> Tensor:
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -160,7 +159,7 @@ class MobileNetV3(nn.Module):
             x = self.blocks(x)
         return x
 
-    def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
         x = self.global_pool(x)
         x = self.conv_head(x)
         x = self.act2(x)
@@ -171,7 +170,7 @@ class MobileNetV3(nn.Module):
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         return self.classifier(x)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
@@ -262,7 +261,7 @@ class MobileNetV3Features(nn.Module):
     def set_grad_checkpointing(self, enable: bool = True):
         self.grad_checkpointing = enable
 
-    def forward(self, x: Tensor) -> List[Tensor]:
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
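
The hunks in this file swap the bare `Tensor` annotation for the fully qualified `torch.Tensor`, which lets the `from torch import Tensor` import go away. A standalone sketch showing the qualified style still scripts cleanly under torch.jit (TinyStem is a toy module for illustration, not part of this commit):

import torch
import torch.nn as nn

class TinyStem(nn.Module):
    # Same annotation style the diff adopts: fully qualified torch.Tensor,
    # so no separate `from torch import Tensor` is needed at module scope.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

scripted = torch.jit.script(TinyStem())  # torch.jit resolves the qualified annotation
print(scripted(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 16, 16])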

timm/models/resnet.py

@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
@@ -112,7 +111,7 @@ class BasicBlock(nn.Module):
         if getattr(self.bn2, 'weight', None) is not None:
             nn.init.zeros_(self.bn2.weight)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x
 
         x = self.conv1(x)
@@ -212,7 +211,7 @@ class Bottleneck(nn.Module):
         if getattr(self.bn3, 'weight', None) is not None:
             nn.init.zeros_(self.bn3.weight)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         shortcut = x
 
         x = self.conv1(x)
@@ -554,7 +553,7 @@ class ResNet(nn.Module):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 
-    def forward_features(self, x: Tensor) -> Tensor:
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.act1(x)
@@ -569,13 +568,13 @@ class ResNet(nn.Module):
         x = self.layer4(x)
         return x
 
-    def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
         x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
         return x if pre_logits else self.fc(x)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
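
Both spellings name the same runtime class, so the annotation rewrite across these files is purely cosmetic; a quick standalone check (not part of the commit):

import typing

import torch
from torch import Tensor

assert Tensor is torch.Tensor  # two spellings, one class

def head(x: torch.Tensor) -> torch.Tensor:
    return x

# Annotations resolve to the identical runtime type either way:
print(typing.get_type_hints(head))  # {'x': <class 'torch.Tensor'>, 'return': <class 'torch.Tensor'>}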