mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup re-use of Dropout modules in Mlp modules after some twitter feedback :p
This commit is contained in:
parent 71f00bfe9e
commit f658a72e72
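The change below is the same in every MLP variant touched: each class previously created a single nn.Dropout and called it at two different points in forward(), tying both call sites to one probability. The commit gives each call site its own module (drop1/drop2, or drop1/drop3 in SpatialMlp), with to_2tuple normalizing a scalar drop into a pair so existing callers keep the old behavior. A minimal standalone sketch of the new wiring (simplified; not the full timm class):

    import torch
    from torch import nn as nn

    class TinyMlp(nn.Module):
        """Reduced model of the new pattern: one Dropout module per call site."""
        def __init__(self, dim, hidden, drop=0.):
            super().__init__()
            p1, p2 = drop if isinstance(drop, (tuple, list)) else (drop, drop)  # stand-in for to_2tuple
            self.fc1 = nn.Linear(dim, hidden)
            self.act = nn.GELU()
            self.drop1 = nn.Dropout(p1)   # after the activation
            self.fc2 = nn.Linear(hidden, dim)
            self.drop2 = nn.Dropout(p2)   # after the output projection

        def forward(self, x):
            x = self.drop1(self.act(self.fc1(x)))
            return self.drop2(self.fc2(x))

    y = TinyMlp(16, 32, drop=(0.1, 0.0))(torch.randn(2, 16))  # a scalar drop still works too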
timm/models/layers/mlp.py
@@ -4,6 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 from torch import nn as nn
 
+from .helpers import to_2tuple
+
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
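For reference, to_2tuple comes from timm's layer helpers and simply normalizes its argument to a pair: scalars are repeated, iterables pass through. A functionally equivalent sketch (not the exact helper):

    import collections.abc
    from itertools import repeat

    def to_2tuple(x):
        # non-string iterables pass through; scalars are repeated into a pair
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, 2))

    assert to_2tuple(0.1) == (0.1, 0.1)
    assert to_2tuple((0.1, 0.0)) == (0.1, 0.0)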
@@ -12,17 +14,20 @@ class Mlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
 
 
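With drop1 and drop2 split out, a caller can keep dropout after the activation while disabling it after the projection, or vice versa. A usage sketch, assuming Mlp is exported from timm.models.layers as it was in this era of the repo:

    import torch
    from timm.models.layers import Mlp  # import path is an assumption

    mlp = Mlp(in_features=768, hidden_features=3072, drop=(0.1, 0.0))
    y = mlp(torch.randn(2, 196, 768))   # drop1 applies p=0.1 after the GELU; drop2 is a no-op
    assert y.shape == (2, 196, 768)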
@@ -35,10 +40,13 @@ class GluMlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         assert hidden_features % 2 == 0
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features // 2, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
@@ -50,9 +58,9 @@ class GluMlp(nn.Module):
         x = self.fc1(x)
         x, gates = x.chunk(2, dim=-1)
         x = x * self.act(gates)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
 
 
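The chunk/gate pair above is also why fc2 takes hidden_features // 2: fc1 projects to the full hidden width, then one half gates the other (GluMlp defaults to a sigmoid activation). A standalone illustration of that step:

    import torch

    h = torch.randn(2, 8)                 # stand-in for fc1 output, hidden_features=8
    x, gates = h.chunk(2, dim=-1)         # two (2, 4) halves
    out = x * torch.sigmoid(gates)        # value half modulated by gate half
    print(out.shape)                      # torch.Size([2, 4]) -> matches fc2's hidden_features // 2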
@@ -64,8 +72,11 @@ class GatedMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         if gate_layer is not None:
             assert hidden_features % 2 == 0
             self.gate = gate_layer(hidden_features)
@@ -73,15 +84,15 @@ class GatedMlp(nn.Module):
         else:
             self.gate = nn.Identity()
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.gate(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
 
 
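GatedMlp slots an arbitrary gate between drop1 and fc2; with gate_layer=None it degenerates into the plain Mlp above. In timm the gate used in practice (gMLP's SpatialGatingUnit) splits the hidden channels and returns half of them, which is what the even-width assert guards. A hypothetical minimal gate with that contract:

    import torch
    from torch import nn

    class HalfGate(nn.Module):
        """Hypothetical gate_layer stand-in: consumes (..., C), emits (..., C // 2)."""
        def __init__(self, dim):
            super().__init__()
            self.norm = nn.LayerNorm(dim // 2)

        def forward(self, x):
            u, v = x.chunk(2, dim=-1)   # value half / gate half
            return u * self.norm(v)     # gating halves the channel width

    print(HalfGate(512)(torch.randn(2, 64, 512)).shape)  # torch.Size([2, 64, 256])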
timm/models/visformer.py
@@ -45,6 +45,8 @@ class SpatialMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.in_features = in_features
         self.out_features = out_features
         self.spatial_conv = spatial_conv
@@ -55,9 +57,9 @@ class SpatialMlp(nn.Module):
                 hidden_features = in_features * 2
         self.hidden_features = hidden_features
         self.group = group
-        self.drop = nn.Dropout(drop)
         self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
         self.act1 = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         if self.spatial_conv:
             self.conv2 = nn.Conv2d(
                 hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
@@ -66,16 +68,17 @@ class SpatialMlp(nn.Module):
             self.conv2 = None
             self.act2 = None
         self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
+        self.drop3 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
         x = self.conv1(x)
         x = self.act1(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         if self.conv2 is not None:
             x = self.conv2(x)
             x = self.act2(x)
         x = self.conv3(x)
-        x = self.drop(x)
+        x = self.drop3(x)
         return x
 
 
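SpatialMlp (Visformer's conv-based MLP) gets the same treatment, with drop3 named after the conv it follows; note there is deliberately no dropout between conv2 and conv3. A hedged usage sketch, treating the import path, defaults, and exact signature as assumptions:

    import torch
    from timm.models.visformer import SpatialMlp  # import path is an assumption

    mlp = SpatialMlp(in_features=192, hidden_features=384, spatial_conv=True, drop=(0.1, 0.0))
    y = mlp(torch.randn(2, 192, 14, 14))  # NCHW in, NCHW out
    assert y.shape == (2, 192, 14, 14)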