fix: wrong base class & simplify train func
parent
f02a4630b2
commit
a9b8432597
|
@ -22,7 +22,7 @@ import paddle.nn.functional as F
|
|||
from ..utils import get_param_attr_dict
|
||||
|
||||
|
||||
class MetaBN1D(nn.BatchNorm2D):
|
||||
class MetaBN1D(nn.BatchNorm1D):
|
||||
def forward(self, inputs, opt={}):
|
||||
mode = opt.get("bn_mode", "general") if self.training else "eval"
|
||||
if mode == "general": # update, but not apply running_mean/var
|
||||
|
|
|
@ -141,15 +141,15 @@ def setup_opt(engine, stage):
|
|||
opt["bn_mode"] = "hold"
|
||||
opt["enable_inside_update"] = True
|
||||
opt["lr_gate"] = norm_lr * cyclic_lr
|
||||
for name, layer in engine.model.backbone.named_sublayers():
|
||||
if "bn" == name.split('.')[-1]:
|
||||
for layer in engine.model.backbone.sublayers():
|
||||
if type_name(layer) == "MetaBIN":
|
||||
layer.setup_opt(opt)
|
||||
engine.model.neck.setup_opt(opt)
|
||||
|
||||
|
||||
def reset_opt(model):
|
||||
for name, layer in model.backbone.named_sublayers():
|
||||
if "bn" == name.split('.')[-1]:
|
||||
for layer in model.backbone.sublayers():
|
||||
if type_name(layer) == "MetaBIN":
|
||||
layer.reset_opt()
|
||||
model.neck.reset_opt()
|
||||
|
||||
|
|
Loading…
Reference in New Issue