fix drop last for training process (#713)

pull/716/head
littletomatodonkey 2021-05-06 14:21:54 +08:00 committed by GitHub
parent 9c0f049603
commit e165897c37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 4 deletions

View File

@ -197,7 +197,7 @@ class CommonDataset(Dataset):
def __len__(self):
return self.num_samples
class MultiLabelDataset(Dataset):
"""
@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
labels = label_str.split(',')
labels = [int(i) for i in labels]
return (transform(img, self.ops), np.array(labels).astype("float32"))
return (transform(img, self.ops),
np.array(labels).astype("float32"))
except Exception as e:
logger.error("data read failed: {}, exception info: {}".format(line, e))
logger.error("data read failed: {}, exception info: {}".format(
line, e))
return self.__getitem__(random.randint(0, len(self)))
def __len__(self):
@ -291,7 +293,7 @@ class Reader:
dataset,
batch_size=batch_size,
shuffle=self.shuffle and is_train,
drop_last=is_train)
drop_last=False)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,

View File

@ -72,6 +72,10 @@ def main(args, return_dict={}):
init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)()
if len(valid_dataloader) <= 0:
logger.error(
"valid dataloader is empty, please check your data config again!")
sys.exit(-1)
net.eval()
with paddle.no_grad():
if not multilabel:

View File

@ -88,9 +88,18 @@ def main(args):
init_model(config, net, optimizer)
train_dataloader = Reader(config, 'train', places=place)()
if len(train_dataloader) <= 0:
logger.error(
"train dataloader is empty, please check your data config again!")
sys.exit(-1)
if config.validate:
valid_dataloader = Reader(config, 'valid', places=place)()
if len(valid_dataloader) <= 0:
logger.error(
"valid dataloader is empty, please check your data config again!"
)
sys.exit(-1)
last_epoch_id = config.get("last_epoch", -1)
best_top1_acc = 0.0 # best top1 acc record