fix drop last for training process (#713)
parent
9c0f049603
commit
e165897c37
|
@ -197,7 +197,7 @@ class CommonDataset(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
|
||||
class MultiLabelDataset(Dataset):
|
||||
"""
|
||||
|
@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
|
|||
labels = label_str.split(',')
|
||||
labels = [int(i) for i in labels]
|
||||
|
||||
return (transform(img, self.ops), np.array(labels).astype("float32"))
|
||||
return (transform(img, self.ops),
|
||||
np.array(labels).astype("float32"))
|
||||
except Exception as e:
|
||||
logger.error("data read failed: {}, exception info: {}".format(line, e))
|
||||
logger.error("data read failed: {}, exception info: {}".format(
|
||||
line, e))
|
||||
return self.__getitem__(random.randint(0, len(self)))
|
||||
|
||||
def __len__(self):
|
||||
|
@ -291,7 +293,7 @@ class Reader:
|
|||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.shuffle and is_train,
|
||||
drop_last=is_train)
|
||||
drop_last=False)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
|
|
|
@ -72,6 +72,10 @@ def main(args, return_dict={}):
|
|||
|
||||
init_model(config, net, optimizer=None)
|
||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||
if len(valid_dataloader) <= 0:
|
||||
logger.error(
|
||||
"valid dataloader is empty, please check your data config again!")
|
||||
sys.exit(-1)
|
||||
net.eval()
|
||||
with paddle.no_grad():
|
||||
if not multilabel:
|
||||
|
|
|
@ -88,9 +88,18 @@ def main(args):
|
|||
init_model(config, net, optimizer)
|
||||
|
||||
train_dataloader = Reader(config, 'train', places=place)()
|
||||
if len(train_dataloader) <= 0:
|
||||
logger.error(
|
||||
"train dataloader is empty, please check your data config again!")
|
||||
sys.exit(-1)
|
||||
|
||||
if config.validate:
|
||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||
if len(valid_dataloader) <= 0:
|
||||
logger.error(
|
||||
"valid dataloader is empty, please check your data config again!"
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
last_epoch_id = config.get("last_epoch", -1)
|
||||
best_top1_acc = 0.0 # best top1 acc record
|
||||
|
|
Loading…
Reference in New Issue