mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Respect is_rec
when eval (#3187)
This commit is contained in:
parent
caa6fdf57d
commit
6aa3047157
@ -56,7 +56,10 @@ def classification_eval(engine, epoch_id=0):
|
|||||||
|
|
||||||
# image input
|
# image input
|
||||||
with engine.auto_cast(is_eval=True):
|
with engine.auto_cast(is_eval=True):
|
||||||
|
if engine.is_rec:
|
||||||
out = engine.model(batch[0], batch[1])
|
out = engine.model(batch[0], batch[1])
|
||||||
|
else:
|
||||||
|
out = engine.model(batch[0])
|
||||||
|
|
||||||
# just for DistributedBatchSampler issue: repeat sampling
|
# just for DistributedBatchSampler issue: repeat sampling
|
||||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||||
|
@ -137,7 +137,10 @@ def compute_feature(engine, name="gallery"):
|
|||||||
has_camera = True
|
has_camera = True
|
||||||
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
|
||||||
with engine.auto_cast(is_eval=True):
|
with engine.auto_cast(is_eval=True):
|
||||||
|
if engine.is_rec:
|
||||||
out = engine.model(batch[0], batch[1])
|
out = engine.model(batch[0], batch[1])
|
||||||
|
else:
|
||||||
|
out = engine.model(batch[0])
|
||||||
if "Student" in out:
|
if "Student" in out:
|
||||||
out = out["Student"]
|
out = out["Student"]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user