From d5f1525334e1b111e4bfdf59fcd38eb9f8c9d3de Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Mon, 30 Oct 2023 02:37:50 +0530 Subject: [PATCH] include suggestions from review Co-Authored-By: Ross Wightman --- timm/layers/typing.py | 8 +++----- timm/models/mobilenetv3.py | 9 ++++----- timm/models/resnet.py | 11 +++++------ 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/timm/layers/typing.py b/timm/layers/typing.py index 35aa9f88..593fa5cc 100644 --- a/timm/layers/typing.py +++ b/timm/layers/typing.py @@ -1,9 +1,7 @@ -import functools -import types -from typing import Tuple, Union +from typing import Callable, Tuple, Type, Union -import torch.nn +import torch -LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module] +LayerType = Union[str, Callable, Type[torch.nn.Module]] PadType = Union[str, int, Tuple[int, int]] diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index f6cd8e08..21f0d3f1 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -12,7 +12,6 @@ from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -151,7 +150,7 @@ class MobileNetV3(nn.Module): self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x: Tensor) -> Tensor: + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_stem(x) x = self.bn1(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -160,7 +159,7 @@ class MobileNetV3(nn.Module): x = self.blocks(x) return x - def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor: + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) @@ -171,7 +170,7 @@ class MobileNetV3(nn.Module): x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x @@ -262,7 +261,7 @@ class MobileNetV3Features(nn.Module): def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable - def forward(self, x: Tensor) -> List[Tensor]: + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 429a54c4..2549eb15 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ @@ -112,7 +111,7 @@ class BasicBlock(nn.Module): if getattr(self.bn2, 'weight', None) is not None: nn.init.zeros_(self.bn2.weight) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.conv1(x) @@ -212,7 +211,7 @@ class Bottleneck(nn.Module): if getattr(self.bn3, 'weight', None) is not None: nn.init.zeros_(self.bn3.weight) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.conv1(x) @@ -554,7 +553,7 @@ class ResNet(nn.Module): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - def forward_features(self, x: Tensor) -> Tensor: + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) x = self.bn1(x) x = self.act1(x) @@ -569,13 +568,13 @@ class ResNet(nn.Module): x = self.layer4(x) return x - def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor: + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) return x if pre_logits else self.fc(x) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x