Weights on hf hub, bicubic yields slightly better eval

pull/1919/head
Ross Wightman 2023-08-19 16:25:45 -07:00
parent e7f97cb5ce
commit 69e0ca2e36
1 changed files with 33 additions and 16 deletions

View File

@ -111,13 +111,14 @@ class RepGhostModule(nn.Module):
if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0: if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0:
return return
kernel, bias = self.get_equivalent_kernel_bias() kernel, bias = self.get_equivalent_kernel_bias()
self.cheap_operation = nn.Conv2d(in_channels=self.cheap_operation[0].in_channels, self.cheap_operation = nn.Conv2d(
out_channels=self.cheap_operation[0].out_channels, in_channels=self.cheap_operation[0].in_channels,
kernel_size=self.cheap_operation[0].kernel_size, out_channels=self.cheap_operation[0].out_channels,
padding=self.cheap_operation[0].padding, kernel_size=self.cheap_operation[0].kernel_size,
dilation=self.cheap_operation[0].dilation, padding=self.cheap_operation[0].padding,
groups=self.cheap_operation[0].groups, dilation=self.cheap_operation[0].dilation,
bias=True) groups=self.cheap_operation[0].groups,
bias=True)
self.cheap_operation.weight.data = kernel self.cheap_operation.weight.data = kernel
self.cheap_operation.bias.data = bias self.cheap_operation.bias.data = bias
self.__delattr__('fusion_conv') self.__delattr__('fusion_conv')
@ -377,7 +378,7 @@ def _create_repghostnet(variant, width=1.0, pretrained=False, **kwargs):
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
return { return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear', 'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier', 'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs **kwargs
@ -386,21 +387,37 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({ default_cfgs = generate_default_cfgs({
'repghostnet_050.in1k': _cfg( 'repghostnet_050.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_5x_43M_66.95.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_5x_43M_66.95.pth.tar'
),
'repghostnet_058.in1k': _cfg( 'repghostnet_058.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_58x_60M_68.94.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_58x_60M_68.94.pth.tar'
),
'repghostnet_080.in1k': _cfg( 'repghostnet_080.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_8x_96M_72.24.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_8x_96M_72.24.pth.tar'
),
'repghostnet_100.in1k': _cfg( 'repghostnet_100.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_0x_142M_74.22.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_0x_142M_74.22.pth.tar'
),
'repghostnet_111.in1k': _cfg( 'repghostnet_111.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_11x_170M_75.07.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_11x_170M_75.07.pth.tar'
),
'repghostnet_130.in1k': _cfg( 'repghostnet_130.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_3x_231M_76.37.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_3x_231M_76.37.pth.tar'
),
'repghostnet_150.in1k': _cfg( 'repghostnet_150.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_5x_301M_77.45.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_5x_301M_77.45.pth.tar'
),
'repghostnet_200.in1k': _cfg( 'repghostnet_200.in1k': _cfg(
url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_2_0x_516M_78.81.pth.tar'), hf_hub_id='timm/',
# url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_2_0x_516M_78.81.pth.tar'
),
}) })