fix: support AMP infer
parent
5f88903e6e
commit
683adcda46
|
@ -105,7 +105,6 @@ DataLoader:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
output_fp16: True
|
||||
channel_num: *image_channel
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
|
@ -132,7 +131,6 @@ Infer:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
output_fp16: True
|
||||
channel_num: *image_channel
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
|
|
|
@ -99,7 +99,6 @@ DataLoader:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
output_fp16: True
|
||||
channel_num: *image_channel
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
|
@ -126,7 +125,6 @@ Infer:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
output_fp16: True
|
||||
channel_num: *image_channel
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
|
|
|
@ -239,7 +239,7 @@ class Engine(object):
|
|||
|
||||
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
|
||||
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
|
||||
if self.config["Global"].get(
|
||||
if self.mode == "train" and self.config["Global"].get(
|
||||
"eval_during_train",
|
||||
True) and self.amp_level == "O2" and self.amp_eval == False:
|
||||
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
|
||||
|
@ -269,10 +269,11 @@ class Engine(object):
|
|||
save_dtype='float32')
|
||||
# paddle version >= 2.3.0 or develop
|
||||
else:
|
||||
self.model = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
if self.mode == "train" or self.amp_eval:
|
||||
self.model = paddle.amp.decorate(
|
||||
models=self.model,
|
||||
level=self.amp_level,
|
||||
save_dtype='float32')
|
||||
|
||||
if self.mode == "train" and len(self.train_loss_func.parameters(
|
||||
)) > 0:
|
||||
|
@ -431,7 +432,17 @@ class Engine(object):
|
|||
image_file_list.append(image_file)
|
||||
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
|
||||
batch_tensor = paddle.to_tensor(batch_data)
|
||||
out = self.model(batch_tensor)
|
||||
|
||||
if self.amp and self.amp_eval:
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=self.amp_level):
|
||||
out = self.model(batch_tensor)
|
||||
else:
|
||||
out = self.model(batch_tensor)
|
||||
|
||||
if isinstance(out, list):
|
||||
out = out[0]
|
||||
if isinstance(out, dict) and "logits" in out:
|
||||
|
|
Loading…
Reference in New Issue