rm op black list in amp
the op flatten_contiguous_range and greater_than has supported amp mode since paddle 2.4pull/2636/merge
parent
a7ba6eabd2
commit
f37cb543b1
|
@ -504,11 +504,7 @@ class Engine(object):
|
|||
batch_tensor = paddle.to_tensor(batch_data)
|
||||
|
||||
if self.amp and self.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=self.amp_level):
|
||||
with paddle.amp.auto_cast(level=self.amp_level):
|
||||
out = self.model(batch_tensor)
|
||||
else:
|
||||
out = self.model(batch_tensor)
|
||||
|
|
|
@ -56,11 +56,7 @@ def classification_eval(engine, epoch_id=0):
|
|||
|
||||
# image input
|
||||
if engine.amp and engine.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=engine.amp_level):
|
||||
with paddle.amp.auto_cast(level=engine.amp_level):
|
||||
out = engine.model(batch[0])
|
||||
else:
|
||||
out = engine.model(batch[0])
|
||||
|
@ -114,11 +110,7 @@ def classification_eval(engine, epoch_id=0):
|
|||
# calc loss
|
||||
if engine.eval_loss_func is not None:
|
||||
if engine.amp and engine.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=engine.amp_level):
|
||||
with paddle.amp.auto_cast(level=engine.amp_level):
|
||||
loss_dict = engine.eval_loss_func(preds, labels)
|
||||
else:
|
||||
loss_dict = engine.eval_loss_func(preds, labels)
|
||||
|
|
|
@ -137,11 +137,7 @@ def compute_feature(engine, name="gallery"):
|
|||
has_camera = True
|
||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||
if engine.amp and engine.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=engine.amp_level):
|
||||
with paddle.amp.auto_cast(level=engine.amp_level):
|
||||
out = engine.model(batch[0])
|
||||
else:
|
||||
out = engine.model(batch[0])
|
||||
|
|
|
@ -50,11 +50,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
|
|||
# image input
|
||||
if engine.amp:
|
||||
amp_level = engine.config["AMP"].get("level", "O1").upper()
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=amp_level):
|
||||
with paddle.amp.auto_cast(level=amp_level):
|
||||
out = forward(engine, batch)
|
||||
loss_dict = engine.train_loss_func(out, batch[1])
|
||||
else:
|
||||
|
|
|
@ -64,11 +64,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
|
|||
# image input
|
||||
if engine.amp:
|
||||
amp_level = engine.config['AMP'].get("level", "O1").upper()
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=amp_level):
|
||||
with paddle.amp.auto_cast(level=amp_level):
|
||||
loss_dict, logits_label = get_loss(
|
||||
engine, inputs, batch_size_label, temperture, threshold,
|
||||
targets_x)
|
||||
|
|
|
@ -191,11 +191,7 @@ def forward(engine, batch, loss_func):
|
|||
batch_info = {"label": batch[1], "domain": batch[2]}
|
||||
if engine.amp:
|
||||
amp_level = engine.config["AMP"].get("level", "O1").upper()
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=amp_level):
|
||||
with paddle.amp.auto_cast(level=amp_level):
|
||||
out = engine.model(batch[0], batch[1])
|
||||
loss_dict = loss_func(out, batch_info)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue