Respect `is_rec` when eval (#3187)
parent
caa6fdf57d
commit
6aa3047157
|
@ -56,7 +56,10 @@ def classification_eval(engine, epoch_id=0):
|
|||
|
||||
# image input
|
||||
with engine.auto_cast(is_eval=True):
|
||||
out = engine.model(batch[0], batch[1])
|
||||
if engine.is_rec:
|
||||
out = engine.model(batch[0], batch[1])
|
||||
else:
|
||||
out = engine.model(batch[0])
|
||||
|
||||
# just for DistributedBatchSampler issue: repeat sampling
|
||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||
|
|
|
@ -137,7 +137,10 @@ def compute_feature(engine, name="gallery"):
|
|||
has_camera = True
|
||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||
with engine.auto_cast(is_eval=True):
|
||||
out = engine.model(batch[0], batch[1])
|
||||
if engine.is_rec:
|
||||
out = engine.model(batch[0], batch[1])
|
||||
else:
|
||||
out = engine.model(batch[0])
|
||||
if "Student" in out:
|
||||
out = out["Student"]
|
||||
|
||||
|
|
Loading…
Reference in New Issue