Update multilabel
parent
a90881c99f
commit
fe6f614680
|
@ -46,8 +46,8 @@ DataLoader:
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: MultiLabelDataset
|
name: MultiLabelDataset
|
||||||
image_root: ./dataset/NUS-SCENE-dataset/images/
|
image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
|
||||||
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_train_list.txt
|
cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_train_list.txt
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- DecodeImage:
|
- DecodeImage:
|
||||||
to_rgb: True
|
to_rgb: True
|
||||||
|
@ -74,8 +74,8 @@ DataLoader:
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: MultiLabelDataset
|
name: MultiLabelDataset
|
||||||
image_root: ./dataset/NUS-SCENE-dataset/images/
|
image_root: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/images/
|
||||||
cls_label_path: ./dataset/NUS-SCENE-dataset/multilabel_test_list.txt
|
cls_label_path: ./dataset/NUS-WIDE-SCENE/NUS-SCENE-dataset/multilabel_test_list.txt
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- DecodeImage:
|
- DecodeImage:
|
||||||
to_rgb: True
|
to_rgb: True
|
||||||
|
|
|
@ -50,7 +50,7 @@ def classification_eval(engine, epoch_id=0):
|
||||||
time_info["reader_cost"].update(time.time() - tic)
|
time_info["reader_cost"].update(time.time() - tic)
|
||||||
batch_size = batch[0].shape[0]
|
batch_size = batch[0].shape[0]
|
||||||
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
|
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
|
||||||
if not evaler.config["Global"].get("use_multilabel", False):
|
if not engine.config["Global"].get("use_multilabel", False):
|
||||||
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
|
||||||
# image input
|
# image input
|
||||||
out = engine.model(batch[0])
|
out = engine.model(batch[0])
|
||||||
|
|
|
@ -76,8 +76,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
|
|
||||||
def forward(trainer, batch):
|
def forward(engine, batch):
|
||||||
if not trainer.is_rec:
|
if not engine.is_rec:
|
||||||
return trainer.model(batch[0])
|
return engine.model(batch[0])
|
||||||
else:
|
else:
|
||||||
return trainer.model(batch[0], batch[1])
|
return engine.model(batch[0], batch[1])
|
||||||
|
|
Loading…
Reference in New Issue