Support training without amp
parent 4553d22cf1
commit f313a6d873
@@ -98,7 +98,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
                                  for key, value in mtest_loss_dict.items()}
         }
         # step lr (by iter)
-        # the last lr_sch is cyclic_lr
         for i in range(len(engine.lr_sch)):
             if not getattr(engine.lr_sch[i], "by_epoch", False):
                 engine.lr_sch[i].step()
@@ -117,7 +116,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
         tic = time.time()
 
     # step lr(by epoch)
-    # the last lr_sch is cyclic_lr
     for i in range(len(engine.lr_sch)):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
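For reference, the by-iter / by-epoch split these two hunks touch can be exercised outside the engine. The snippet below is only a minimal sketch, not the PaddleClas engine: the two schedulers are arbitrary choices and the by_epoch flag is attached by hand for illustration.

import paddle

# Sketch of the stepping pattern above: schedulers without `by_epoch`
# step every iteration, the rest step once per epoch.
warmup = paddle.optimizer.lr.LinearWarmup(
    learning_rate=0.1, warmup_steps=20, start_lr=0.0, end_lr=0.1)
warmup.by_epoch = False   # stepped inside the iteration loop
cosine = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.1, T_max=10)
cosine.by_epoch = True    # stepped once per epoch

lr_sch = [warmup, cosine]
for epoch_id in range(2):
    for iter_id in range(5):
        # ... forward / backward / optimizer.step() would run here ...
        for sch in lr_sch:
            if not getattr(sch, "by_epoch", False):
                sch.step()
    for sch in lr_sch:
        if getattr(sch, "by_epoch", False):
            sch.step()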
@@ -191,10 +189,16 @@ def get_meta_data(meta_dataloader_iter, num_domain):
 def forward(engine, batch, loss_func):
     batch_info = defaultdict()
     batch_info = {"label": batch[1], "domain": batch[2]}
-    amp_level = engine.config["AMP"].get("level", "O1").upper()
-    with paddle.amp.auto_cast(
-            custom_black_list={"flatten_contiguous_range", "greater_than"},
-            level=amp_level):
+    if engine.amp:
+        amp_level = engine.config["AMP"].get("level", "O1").upper()
+        with paddle.amp.auto_cast(
+                custom_black_list={
+                    "flatten_contiguous_range", "greater_than"
+                },
+                level=amp_level):
+            out = engine.model(batch[0], batch[1])
+            loss_dict = loss_func(out, batch_info)
+    else:
         out = engine.model(batch[0], batch[1])
         loss_dict = loss_func(out, batch_info)
     return out, loss_dict
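The new branch only enters paddle.amp.auto_cast when AMP is enabled; otherwise the same two forward lines run in plain FP32. A self-contained sketch of that pattern, with a toy model and a use_amp flag standing in for engine.amp (both assumptions, not the real engine):

import paddle

def forward_with_optional_amp(model, data, use_amp, amp_level="O1"):
    # Same shape as the patched forward(): wrap the pass in auto_cast
    # only when AMP is enabled.
    if use_amp:
        with paddle.amp.auto_cast(
                custom_black_list={"flatten_contiguous_range", "greater_than"},
                level=amp_level):
            return model(data)
    return model(data)

model = paddle.nn.Linear(8, 4)
x = paddle.randn([2, 8])
out = forward_with_optional_amp(model, x, use_amp=False)
print(out.dtype)  # float32 on the non-AMP path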
@@ -202,9 +206,13 @@ def forward(engine, batch, loss_func):
 
 def backward(engine, loss, optimizer):
     optimizer.clear_grad()
-    scaled = engine.scaler.scale(loss)
-    scaled.backward()
-    engine.scaler.minimize(optimizer, scaled)
+    if engine.amp:
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        engine.scaler.minimize(optimizer, scaled)
+    else:
+        loss.backward()
+        optimizer.step()
     for name, layer in engine.model.backbone.named_sublayers():
         if "gate" == name.split('.')[-1]:
             layer.clip_gate()
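Likewise for backward(): with AMP the loss is scaled and stepped through GradScaler.minimize, without it a plain backward()/step() pair is enough. A standalone sketch with a toy model and optimizer (names here are illustrative, not the engine's):

import paddle

def backward_step(loss, optimizer, scaler=None):
    # Mirrors the patched backward(): GradScaler path when AMP is enabled,
    # plain backward() + step() otherwise.
    optimizer.clear_grad()
    if scaler is not None:
        scaled = scaler.scale(loss)
        scaled.backward()
        scaler.minimize(optimizer, scaled)
    else:
        loss.backward()
        optimizer.step()

model = paddle.nn.Linear(4, 2)
opt = paddle.optimizer.Momentum(learning_rate=0.01, parameters=model.parameters())
x = paddle.randn([3, 4])
backward_step(model(x).mean(), opt)   # FP32 path
# With AMP enabled (typically on GPU), the engine would instead pass a scaler:
# backward_step(model(x).mean(), opt, paddle.amp.GradScaler())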