Respect `is_rec` when eval (#3187)

pull/3191/head
Nyakku Shigure 2024-07-10 15:45:37 +08:00 committed by GitHub
parent caa6fdf57d
commit 6aa3047157
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -56,7 +56,10 @@ def classification_eval(engine, epoch_id=0):
# image input # image input
with engine.auto_cast(is_eval=True): with engine.auto_cast(is_eval=True):
out = engine.model(batch[0], batch[1]) if engine.is_rec:
out = engine.model(batch[0], batch[1])
else:
out = engine.model(batch[0])
# just for DistributedBatchSampler issue: repeat sampling # just for DistributedBatchSampler issue: repeat sampling
current_samples = batch_size * paddle.distributed.get_world_size() current_samples = batch_size * paddle.distributed.get_world_size()

View File

@ -137,7 +137,10 @@ def compute_feature(engine, name="gallery"):
has_camera = True has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
with engine.auto_cast(is_eval=True): with engine.auto_cast(is_eval=True):
out = engine.model(batch[0], batch[1]) if engine.is_rec:
out = engine.model(batch[0], batch[1])
else:
out = engine.model(batch[0])
if "Student" in out: if "Student" in out:
out = out["Student"] out = out["Student"]