mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
include suggestions from review
Co-Authored-By: Ross Wightman <rwightman@gmail.com>
This commit is contained in:
parent
5f14bdd564
commit
d5f1525334
@ -1,9 +1,7 @@
|
||||
import functools
|
||||
import types
|
||||
from typing import Tuple, Union
|
||||
from typing import Callable, Tuple, Type, Union
|
||||
|
||||
import torch.nn
|
||||
import torch
|
||||
|
||||
|
||||
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
|
||||
LayerType = Union[str, Callable, Type[torch.nn.Module]]
|
||||
PadType = Union[str, int, Tuple[int, int]]
|
||||
|
@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
@ -151,7 +150,7 @@ class MobileNetV3(nn.Module):
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x: Tensor) -> Tensor:
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
@ -160,7 +159,7 @@ class MobileNetV3(nn.Module):
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
|
||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
@ -171,7 +170,7 @@ class MobileNetV3(nn.Module):
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
@ -262,7 +261,7 @@ class MobileNetV3Features(nn.Module):
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x: Tensor) -> List[Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
|
@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
|
||||
@ -112,7 +111,7 @@ class BasicBlock(nn.Module):
|
||||
if getattr(self.bn2, 'weight', None) is not None:
|
||||
nn.init.zeros_(self.bn2.weight)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
|
||||
x = self.conv1(x)
|
||||
@ -212,7 +211,7 @@ class Bottleneck(nn.Module):
|
||||
if getattr(self.bn3, 'weight', None) is not None:
|
||||
nn.init.zeros_(self.bn3.weight)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
|
||||
x = self.conv1(x)
|
||||
@ -554,7 +553,7 @@ class ResNet(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x: Tensor) -> Tensor:
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
@ -569,13 +568,13 @@ class ResNet(nn.Module):
|
||||
x = self.layer4(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
|
||||
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate:
|
||||
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||
return x if pre_logits else self.fc(x)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user