mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
revert for running
This commit is contained in:
parent
9beb154bc3
commit
392b75b1ac
@ -165,7 +165,6 @@ class MobileNet(TheseusLayer):
|
|||||||
return_stages=return_stages)
|
return_stages=return_stages)
|
||||||
|
|
||||||
@AMP_forward_decorator
|
@AMP_forward_decorator
|
||||||
@clas_forward_decorator
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.blocks(x)
|
x = self.blocks(x)
|
||||||
|
@ -67,7 +67,7 @@ class ClassEval(object):
|
|||||||
if not self.config["Global"].get("use_multilabel", False):
|
if not self.config["Global"].get("use_multilabel", False):
|
||||||
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
||||||
|
|
||||||
out = self.model(batch)
|
out = self.model(batch[0])
|
||||||
|
|
||||||
# just for DistributedBatchSampler issue: repeat sampling
|
# just for DistributedBatchSampler issue: repeat sampling
|
||||||
current_samples = batch_size * paddle.distributed.get_world_size()
|
current_samples = batch_size * paddle.distributed.get_world_size()
|
||||||
|
@ -41,6 +41,12 @@ class ClassTrainer(object):
|
|||||||
# gradient accumulation
|
# gradient accumulation
|
||||||
self.update_freq = self.config["Global"].get("update_freq", 1)
|
self.update_freq = self.config["Global"].get("update_freq", 1)
|
||||||
|
|
||||||
|
if "Head" in self.config["Arch"] or self.config["Arch"].get("is_rec",
|
||||||
|
False):
|
||||||
|
self.is_rec = True
|
||||||
|
else:
|
||||||
|
self.is_rec = False
|
||||||
|
|
||||||
# TODO(gaotingquan): mv to build_model
|
# TODO(gaotingquan): mv to build_model
|
||||||
# build EMA model
|
# build EMA model
|
||||||
self.model_ema = self._build_ema_model()
|
self.model_ema = self._build_ema_model()
|
||||||
@ -197,7 +203,11 @@ class ClassTrainer(object):
|
|||||||
batch[1] = batch[1].reshape([batch_size, -1])
|
batch[1] = batch[1].reshape([batch_size, -1])
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
|
||||||
out = self.model(batch)
|
if self.is_rec:
|
||||||
|
out = self.model(batch)
|
||||||
|
else:
|
||||||
|
out = self.model(batch[0])
|
||||||
|
|
||||||
loss_dict = self.loss_func(out, batch[1])
|
loss_dict = self.loss_func(out, batch[1])
|
||||||
# TODO(gaotingquan): mv update_freq to loss and optimizer
|
# TODO(gaotingquan): mv update_freq to loss and optimizer
|
||||||
loss = loss_dict["loss"] / self.update_freq
|
loss = loss_dict["loss"] / self.update_freq
|
||||||
|
Loading…
x
Reference in New Issue
Block a user