add support for static training (#6297)

* add support for static training

* fix assert info
pull/6331/head
littletomatodonkey 2022-05-16 16:57:13 +08:00
parent 6d08ff0a70
commit 1bb03b4dfb
2 changed files with 22 additions and 1 deletions

View File

@ -15,10 +15,13 @@
import copy
import importlib
from paddle.jit import to_static
from paddle.static import InputSpec
from .base_model import BaseModel
from .distillation_model import DistillationModel
__all__ = ['build_model']
__all__ = ["build_model", "apply_to_static"]
def build_model(config):
@ -30,3 +33,18 @@ def build_model(config):
mod = importlib.import_module(__name__)
arch = getattr(mod, name)(config)
return arch
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
assert "image_shape" in config[
"Global"], "image_shape must be assigned for static training mode..."
supported_list = ["DB"]
assert config["Architecture"][
"algorithm"] in supported_list, f"algorithms that supports static training must in in {supported_list} but got {config['Architecture']['algorithm']}"
specs = [InputSpec([None] + config["Global"]["image_shape"])]
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model

View File

@ -35,6 +35,7 @@ from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
from ppocr.modeling.architectures import apply_to_static
import tools.program as program
dist.get_world_size()
@ -121,6 +122,8 @@ def main(config, device, logger, vdl_writer):
if config['Global']['distributed']:
model = paddle.DataParallel(model)
model = apply_to_static(model, config, logger)
# build loss
loss_class = build_loss(config['Loss'])