mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
revert for running
This commit is contained in:
parent
9beb154bc3
commit
392b75b1ac
@ -165,7 +165,6 @@ class MobileNet(TheseusLayer):
|
||||
return_stages=return_stages)
|
||||
|
||||
@AMP_forward_decorator
|
||||
@clas_forward_decorator
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.blocks(x)
|
||||
|
@ -67,7 +67,7 @@ class ClassEval(object):
|
||||
if not self.config["Global"].get("use_multilabel", False):
|
||||
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
||||
|
||||
out = self.model(batch)
|
||||
out = self.model(batch[0])
|
||||
|
||||
# just for DistributedBatchSampler issue: repeat sampling
|
||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||
|
@ -41,6 +41,12 @@ class ClassTrainer(object):
|
||||
# gradient accumulation
|
||||
self.update_freq = self.config["Global"].get("update_freq", 1)
|
||||
|
||||
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
|
||||
False):
|
||||
self.is_rec = True
|
||||
else:
|
||||
self.is_rec = False
|
||||
|
||||
# TODO(gaotingquan): mv to build_model
|
||||
# build EMA model
|
||||
self.model_ema = self._build_ema_model()
|
||||
@ -197,7 +203,11 @@ class ClassTrainer(object):
|
||||
batch[1] = batch[1].reshape([batch_size, -1])
|
||||
self.global_step += 1
|
||||
|
||||
out = self.model(batch)
|
||||
if self.is_rec:
|
||||
out = self.model(batch)
|
||||
else:
|
||||
out = self.model(batch[0])
|
||||
|
||||
loss_dict = self.loss_func(out, batch[1])
|
||||
# TODO(gaotingquan): mv update_freq to loss and optimizer
|
||||
loss = loss_dict["loss"] / self.update_freq
|
||||
|
Loading…
x
Reference in New Issue
Block a user