Fix wrong InputSpec setting for AST Dy2St (#3176)
parent
b1ee8f911b
commit
1dcab0a7bc
|
@ -55,7 +55,7 @@ def build_model(config, mode="train"):
|
|||
return arch
|
||||
|
||||
|
||||
def apply_to_static(config, model):
|
||||
def apply_to_static(config, model, is_rec):
|
||||
support_to_static = config['Global'].get('to_static', False)
|
||||
|
||||
if support_to_static:
|
||||
|
@ -63,6 +63,8 @@ def apply_to_static(config, model):
|
|||
if 'image_shape' in config['Global']:
|
||||
specs = [InputSpec([None] + config['Global']['image_shape'])]
|
||||
specs[0].stop_gradient = True
|
||||
if is_rec:
|
||||
specs.append(InputSpec([None, 1], 'int64', stop_gradient=True))
|
||||
model = to_static(model, input_spec=specs)
|
||||
logger.info("Successfully to apply @to_static with specs: {}".format(
|
||||
specs))
|
||||
|
|
|
@ -226,7 +226,7 @@ class Engine(object):
|
|||
# build model
|
||||
self.model = build_model(self.config, self.mode)
|
||||
# set @to_static for benchmark, skip this by default.
|
||||
apply_to_static(self.config, self.model)
|
||||
apply_to_static(self.config, self.model, is_rec=self.is_rec)
|
||||
|
||||
# load_pretrain
|
||||
if self.config["Global"]["pretrained_model"] is not None:
|
||||
|
|
|
@ -137,7 +137,7 @@ def compute_feature(engine, name="gallery"):
|
|||
has_camera = True
|
||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||
with engine.auto_cast(is_eval=True):
|
||||
out = engine.model(batch[0])
|
||||
out = engine.model(batch[0], batch[1])
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
|
||||
|
|
Loading…
Reference in New Issue