fix: unable to export rep net
parent
03142ea32b
commit
f04b52343f
ppcls
arch/backbone
legendary_models
model_zoo
engine
|
@ -188,7 +188,7 @@ class RepDepthwiseSeparable(TheseusLayer):
|
|||
def forward(self, x):
|
||||
if self.use_rep:
|
||||
input_x = x
|
||||
if not self.training:
|
||||
if self.is_repped:
|
||||
x = self.act(self.dw_conv(x))
|
||||
else:
|
||||
y = self.dw_conv_list[0](x)
|
||||
|
@ -209,14 +209,12 @@ class RepDepthwiseSeparable(TheseusLayer):
|
|||
x = x + input_x
|
||||
return x
|
||||
|
||||
def eval(self):
|
||||
def rep(self):
|
||||
if self.use_rep:
|
||||
self.is_repped = True
|
||||
kernel, bias = self._get_equivalent_kernel_bias()
|
||||
self.dw_conv.weight.set_value(kernel)
|
||||
self.dw_conv.bias.set_value(bias)
|
||||
self.training = False
|
||||
for layer in self.sublayers():
|
||||
layer.eval()
|
||||
|
||||
def _get_equivalent_kernel_bias(self):
|
||||
kernel_sum = 0
|
||||
|
|
|
@ -124,13 +124,7 @@ class RepVGGBlock(nn.Layer):
|
|||
groups=groups)
|
||||
|
||||
def forward(self, inputs):
|
||||
if not self.training and not self.is_repped:
|
||||
self.rep()
|
||||
self.is_repped = True
|
||||
if self.training and self.is_repped:
|
||||
self.is_repped = False
|
||||
|
||||
if not self.training:
|
||||
if self.is_repped:
|
||||
return self.nonlinearity(self.rbr_reparam(inputs))
|
||||
|
||||
if self.rbr_identity is None:
|
||||
|
@ -154,6 +148,7 @@ class RepVGGBlock(nn.Layer):
|
|||
kernel, bias = self.get_equivalent_kernel_bias()
|
||||
self.rbr_reparam.weight.set_value(kernel)
|
||||
self.rbr_reparam.bias.set_value(bias)
|
||||
self.is_repped = True
|
||||
|
||||
def get_equivalent_kernel_bias(self):
|
||||
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
|
||||
|
|
|
@ -452,6 +452,12 @@ class Engine(object):
|
|||
self.config["Global"]["pretrained_model"])
|
||||
|
||||
model.eval()
|
||||
|
||||
# for rep nets
|
||||
for layer in self.model.sublayers():
|
||||
if hasattr(layer, "rep"):
|
||||
layer.rep()
|
||||
|
||||
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
|
||||
"inference")
|
||||
if model.quanter:
|
||||
|
|
Loading…
Reference in New Issue