Fix wrong InputSpec setting for AST Dy2St (#3176)

pull/3182/head
Nyakku Shigure 2024-07-05 14:05:02 +08:00 committed by GitHub
parent b1ee8f911b
commit 1dcab0a7bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 3 deletions

View File

@ -55,7 +55,7 @@ def build_model(config, mode="train"):
return arch
def apply_to_static(config, model):
def apply_to_static(config, model, is_rec):
support_to_static = config['Global'].get('to_static', False)
if support_to_static:
@ -63,6 +63,8 @@ def apply_to_static(config, model):
if 'image_shape' in config['Global']:
specs = [InputSpec([None] + config['Global']['image_shape'])]
specs[0].stop_gradient = True
if is_rec:
specs.append(InputSpec([None, 1], 'int64', stop_gradient=True))
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(
specs))

View File

@ -226,7 +226,7 @@ class Engine(object):
# build model
self.model = build_model(self.config, self.mode)
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
apply_to_static(self.config, self.model, is_rec=self.is_rec)
# load_pretrain
if self.config["Global"]["pretrained_model"] is not None:

View File

@ -137,7 +137,7 @@ def compute_feature(engine, name="gallery"):
has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
with engine.auto_cast(is_eval=True):
out = engine.model(batch[0])
out = engine.model(batch[0], batch[1])
if "Student" in out:
out = out["Student"]