From e165897c379cff56d58a8944f0089967c5723a7c Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Thu, 6 May 2021 14:21:54 +0800 Subject: [PATCH] fix drop last for training process (#713) --- ppcls/data/reader.py | 10 ++++++---- tools/eval.py | 4 ++++ tools/train.py | 9 +++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 90bff3589..67a1c0540 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -197,7 +197,7 @@ class CommonDataset(Dataset): def __len__(self): return self.num_samples - + class MultiLabelDataset(Dataset): """ @@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset): labels = label_str.split(',') labels = [int(i) for i in labels] - return (transform(img, self.ops), np.array(labels).astype("float32")) + return (transform(img, self.ops), + np.array(labels).astype("float32")) except Exception as e: - logger.error("data read failed: {}, exception info: {}".format(line, e)) + logger.error("data read failed: {}, exception info: {}".format( + line, e)) return self.__getitem__(random.randint(0, len(self))) def __len__(self): @@ -291,7 +293,7 @@ class Reader: dataset, batch_size=batch_size, shuffle=self.shuffle and is_train, - drop_last=is_train) + drop_last=False) loader = DataLoader( dataset, batch_sampler=batch_sampler, diff --git a/tools/eval.py b/tools/eval.py index 8e0bcf16b..b214e70c6 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -72,6 +72,10 @@ def main(args, return_dict={}): init_model(config, net, optimizer=None) valid_dataloader = Reader(config, 'valid', places=place)() + if len(valid_dataloader) <= 0: + logger.error( + "valid dataloader is empty, please check your data config again!") + sys.exit(-1) net.eval() with paddle.no_grad(): if not multilabel: diff --git a/tools/train.py b/tools/train.py index 48e15676c..b17175427 100644 --- a/tools/train.py +++ b/tools/train.py @@ -88,9 +88,18 @@ def main(args): init_model(config, net, optimizer) train_dataloader = Reader(config, 'train', places=place)() + if len(train_dataloader) <= 0: + logger.error( + "train dataloader is empty, please check your data config again!") + sys.exit(-1) if config.validate: valid_dataloader = Reader(config, 'valid', places=place)() + if len(valid_dataloader) <= 0: + logger.error( + "valid dataloader is empty, please check your data config again!" + ) + sys.exit(-1) last_epoch_id = config.get("last_epoch", -1) best_top1_acc = 0.0 # best top1 acc record