add ghostnet

pull/161/head
wqz960 2020-06-11 18:04:37 +08:00
parent 5a1c4210e1
commit e8c3d72b40
2 changed files with 2 additions and 10 deletions

View File

@ -42,6 +42,7 @@ from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_1
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
from .ghostnet import GhostNet_0_5, GhostNet_1_0, GhostNet_1_3
# distillation model
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd

View File

@ -109,10 +109,6 @@ class GhostNet():
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
#ones = fluid.layers.fill_constant(excitation.shape, "float32", 1)
#zeros = fluid.layers.fill_constant(excitation.shape, "float32", 0)
#excitation = fluid.layers.elementwise_max(excitation, zeros)
# excitation = fluid.layers.elementwise_min(excitation, ones)
excitation = fluid.layers.clip(x=excitation,
min=0,
max=1,
@ -167,10 +163,7 @@ class GhostNet():
name=name+"_cheap_operation",
data_format=data_format)
out = fluid.layers.concat([primary_conv, cheap_operation], axis=1, name=name+"_concat")
# return out[:, :self.oup, :, :]
print(self.oup)
print(out.shape)
return fluid.layers.slice(out, axes=[1], starts=[0], ends=[self.oup])
return out
def GhostBottleneck(self,
inp,
@ -251,8 +244,6 @@ class GhostNet():
for k, exp_size, c, use_se, s in self.cfgs:
output_channel = int(self._make_divisible(c*self.width_mult, 4))
hidden_channel = int(self._make_divisible(exp_size*self.width_mult, 4))
#print(output_channel)
#print(hidden_channel)
x = self.GhostBottleneck(inp=x,
hidden_dim=hidden_channel,
oup=output_channel,