Respect `is_rec` when eval (#3187)
parent
caa6fdf57d
commit
6aa3047157
|
@ -56,7 +56,10 @@ def classification_eval(engine, epoch_id=0):
|
||||||
|
|
||||||
# image input
|
# image input
|
||||||
with engine.auto_cast(is_eval=True):
|
with engine.auto_cast(is_eval=True):
|
||||||
out = engine.model(batch[0], batch[1])
|
if engine.is_rec:
|
||||||
|
out = engine.model(batch[0], batch[1])
|
||||||
|
else:
|
||||||
|
out = engine.model(batch[0])
|
||||||
|
|
||||||
# just for DistributedBatchSampler issue: repeat sampling
|
# just for DistributedBatchSampler issue: repeat sampling
|
||||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||||
|
|
|
@ -137,7 +137,10 @@ def compute_feature(engine, name="gallery"):
|
||||||
has_camera = True
|
has_camera = True
|
||||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||||
with engine.auto_cast(is_eval=True):
|
with engine.auto_cast(is_eval=True):
|
||||||
out = engine.model(batch[0], batch[1])
|
if engine.is_rec:
|
||||||
|
out = engine.model(batch[0], batch[1])
|
||||||
|
else:
|
||||||
|
out = engine.model(batch[0])
|
||||||
if "Student" in out:
|
if "Student" in out:
|
||||||
out = out["Student"]
|
out = out["Student"]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue