remove redundant code, fix bugs in lr.step, merge GoodsDataset into Vehicle
parent 30cbb18321
commit 09200a31f4
@@ -88,7 +88,7 @@ Optimizer:
 DataLoader:
   Train:
     dataset:
-      name: GoodsDataset
+      name: VeriWild
       image_root: ./dataset/SOP
       cls_label_path: ./dataset/SOP/train_list.txt
       backend: pil

@@ -117,7 +117,7 @@ DataLoader:
   Eval:
     Gallery:
       dataset:
-        name: GoodsDataset
+        name: VeriWild
         image_root: ./dataset/SOP
         cls_label_path: ./dataset/SOP/test_list.txt
         backend: pil

@@ -141,7 +141,7 @@ DataLoader:
 
     Query:
       dataset:
-        name: GoodsDataset
+        name: VeriWild
         image_root: ./dataset/SOP
         cls_label_path: ./dataset/SOP/test_list.txt
         backend: pil
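
For context on the three config hunks above: the `name` field under each `dataset:` block is resolved to a dataset class exported from ppcls.data, so once GoodsDataset is deleted the SOP configs must point at VeriWild. A minimal sketch of that name-to-class lookup, assuming a simple registry; the DATASETS dict and build_dataset helper below are hypothetical stand-ins for the real builder, and the paths assume ./dataset/SOP exists:

    from ppcls.data.dataloader.vehicle_dataset import VeriWild

    # Hypothetical registry for illustration; the real builder resolves the
    # name against the symbols imported in ppcls/data/__init__.py.
    DATASETS = {"VeriWild": VeriWild}

    def build_dataset(cfg: dict):
        cfg = dict(cfg)                        # avoid mutating the caller's config
        dataset_cls = DATASETS[cfg.pop("name")]
        return dataset_cls(**cfg)              # remaining keys become kwargs

    dataset = build_dataset({
        "name": "VeriWild",
        "image_root": "./dataset/SOP",
        "cls_label_path": "./dataset/SOP/train_list.txt",
        "backend": "pil",
    })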

@@ -25,7 +25,6 @@ from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
 from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
 from ppcls.data.dataloader.common_dataset import create_operators
 from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
-from ppcls.data.dataloader.goods_dataset import GoodsDataset
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset

@@ -82,26 +82,6 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
         avai_pids = copy.deepcopy(self.pids)
         return batch_idxs_dict, avai_pids, count
 
-    def __iter__(self):
-        batch_idxs_dict, avai_pids, count = self._prepare_batch()
-        for _ in range(self.max_iters):
-            final_idxs = []
-            if len(avai_pids) < self.num_pids_per_batch:
-                batch_idxs_dict, avai_pids, count = self._prepare_batch()
-
-            selected_pids = np.random.choice(
-                avai_pids, self.num_pids_per_batch, False, count / count.sum())
-            for pid in selected_pids:
-                batch_idxs = batch_idxs_dict[pid].pop(0)
-                final_idxs.extend(batch_idxs)
-                pid_idx = avai_pids.index(pid)
-                if len(batch_idxs_dict[pid]) == 0:
-                    avai_pids.pop(pid_idx)
-                    count = np.delete(count, pid_idx)
-                else:
-                    count[pid_idx] = len(batch_idxs_dict[pid])
-            yield final_idxs
-
     def __iter__(self):
         # prepare
         batch_idxs_dict, avai_pids, count = self._prepare_batch()
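
For readers of the surviving __iter__ (the removed copy above was a redundant duplicate, identical apart from the comment), the per-step logic is: draw num_pids_per_batch identities without replacement, weighted by how many index chunks each identity still has, then consume one chunk per drawn identity. A standalone sketch with toy data shaped like what _prepare_batch returns:

    import numpy as np

    # Toy stand-ins for what _prepare_batch() builds: each pid maps to
    # chunks of `num_instances` sample indices; count holds chunks left.
    batch_idxs_dict = {0: [[0, 1], [2, 3]], 1: [[4, 5]], 2: [[6, 7], [8, 9]]}
    avai_pids = [0, 1, 2]
    count = np.array([2, 1, 2], dtype="float64")

    num_pids_per_batch = 2
    # Identities with more unconsumed chunks are proportionally more likely.
    selected_pids = np.random.choice(avai_pids, num_pids_per_batch, False,
                                     count / count.sum())
    final_idxs = []
    for pid in selected_pids:
        final_idxs.extend(batch_idxs_dict[pid].pop(0))  # consume one chunk
        pid_idx = avai_pids.index(pid)
        if len(batch_idxs_dict[pid]) == 0:   # identity exhausted: drop it
            avai_pids.pop(pid_idx)
            count = np.delete(count, pid_idx)
        else:                                # otherwise refresh its weight
            count[pid_idx] = len(batch_idxs_dict[pid])
    print(final_idxs)  # e.g. [0, 1, 6, 7]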

@@ -1,95 +0,0 @@
-from __future__ import print_function
-
-import os
-from typing import Callable, List
-
-import numpy as np
-import paddle
-from paddle.io import Dataset
-from PIL import Image
-from ppcls.data.preprocess import transform
-from ppcls.utils import logger
-
-from .common_dataset import create_operators
-
-
-class GoodsDataset(Dataset):
-    """Dataset for Goods, such as SOP, Inshop...
-
-    Args:
-        image_root (str): image root
-        cls_label_path (str): path to annotation file
-        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
-        backend (str, optional): pil or cv2. Defaults to "cv2".
-        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
-    """
-
-    def __init__(self,
-                 image_root: str,
-                 cls_label_path: str,
-                 transform_ops: List[Callable]=None,
-                 backend="cv2",
-                 relabel=False):
-        self._img_root = image_root
-        self._cls_path = cls_label_path
-        if transform_ops:
-            self._transform_ops = create_operators(transform_ops)
-        self.backend = backend
-        self._dtype = paddle.get_default_dtype()
-        self._load_anno(relabel)
-
-    def _load_anno(self, seed=None, relabel=False):
-        assert os.path.exists(
-            self._cls_path), f"path {self._cls_path} does not exist."
-        assert os.path.exists(
-            self._img_root), f"path {self._img_root} does not exist."
-        self.images = []
-        self.labels = []
-        self.cameras = []
-        with open(self._cls_path) as fd:
-            lines = fd.readlines()
-            if relabel:
-                label_set = set()
-                for line in lines:
-                    line = line.strip().split()
-                    label_set.add(np.int64(line[1]))
-                label_map = {
-                    oldlabel: newlabel
-                    for newlabel, oldlabel in enumerate(label_set)
-                }
-
-            if seed is not None:
-                np.random.RandomState(seed).shuffle(lines)
-            for line in lines:
-                line = line.strip().split()
-                self.images.append(os.path.join(self._img_root, line[0]))
-                if relabel:
-                    self.labels.append(label_map[np.int64(line[1])])
-                else:
-                    self.labels.append(np.int64(line[1]))
-                self.cameras.append(np.int64(line[2]))
-                assert os.path.exists(self.images[
-                    -1]), f"path {self.images[-1]} does not exist."
-
-    def __getitem__(self, idx):
-        try:
-            img = Image.open(self.images[idx]).convert("RGB")
-            if self.backend == "cv2":
-                img = np.array(img, dtype="float32").astype(np.uint8)
-            if self._transform_ops:
-                img = transform(img, self._transform_ops)
-            if self.backend == "cv2":
-                img = img.transpose((2, 0, 1))
-            return (img, self.labels[idx], self.cameras[idx])
-        except Exception as ex:
-            logger.error("Exception occured when parse line: {} with msg: {}".
-                         format(self.images[idx], ex))
-            rnd_idx = np.random.randint(self.__len__())
-            return self.__getitem__(rnd_idx)
-
-    def __len__(self):
-        return len(self.images)
-
-    @property
-    def class_num(self):
-        return len(set(self.labels))
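
The deleted file is ppcls/data/dataloader/goods_dataset.py (matching the import removed earlier). The deletion is safe because the class body is functionally equivalent to what VeriWild now provides, as the vehicle_dataset.py hunks below show. A hedged migration sketch with placeholder paths, assuming annotation lines carry the third per-line column that SOP lists use:

    from ppcls.data.dataloader.vehicle_dataset import VeriWild

    # Previously: GoodsDataset(image_root=..., cls_label_path=..., backend="pil")
    # Paths are placeholders; construction asserts they exist on disk.
    dataset = VeriWild(
        image_root="./dataset/SOP",
        cls_label_path="./dataset/SOP/train_list.txt",
        transform_ops=None,   # or a list of preprocess op configs
        backend="pil",        # "pil" decodes to RGB; "cv2" passes raw bytes
        relabel=False)        # True compacts non-contiguous label ids

When the third column is present, items come back as (img, label, camera) tuples, as before. One behavioral difference: GoodsDataset read line[2] unconditionally, while VeriWild appends cameras only when len(line) >= 3, so two-column lists now load instead of raising an IndexError.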

@@ -19,8 +19,7 @@ import paddle
 from paddle.io import Dataset
 import os
-import cv2
-
 from ppcls.data import preprocess
+from PIL import Image
 from ppcls.data.preprocess import transform
 from ppcls.utils import logger
 from .common_dataset import create_operators

@@ -89,15 +88,30 @@ class CompCars(Dataset):
 
 
 class VeriWild(Dataset):
-    def __init__(self, image_root, cls_label_path, transform_ops=None):
+    """Dataset for Vehicle and other similar data structure, such as VeRI-Wild, SOP, Inshop...
+    Args:
+        image_root (str): image root
+        cls_label_path (str): path to annotation file
+        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
+        backend (str, optional): pil or cv2. Defaults to "cv2".
+        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
+    """
+
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 transform_ops=None,
+                 backend="cv2",
+                 relabel=False):
         self._img_root = image_root
         self._cls_path = cls_label_path
         if transform_ops:
             self._transform_ops = create_operators(transform_ops)
+        self.backend = backend
         self._dtype = paddle.get_default_dtype()
-        self._load_anno()
+        self._load_anno(relabel)
 
-    def _load_anno(self):
+    def _load_anno(self, relabel):
         assert os.path.exists(
             self._cls_path), f"path {self._cls_path} does not exist."
         assert os.path.exists(

@@ -107,22 +121,40 @@ class VeriWild(Dataset):
         self.cameras = []
         with open(self._cls_path) as fd:
             lines = fd.readlines()
+            if relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split()
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
             for line in lines:
                 line = line.strip().split()
                 self.images.append(os.path.join(self._img_root, line[0]))
-                self.labels.append(np.int64(line[1]))
+                if relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
                 if len(line) >= 3:
                     self.cameras.append(np.int64(line[2]))
-                assert os.path.exists(self.images[-1])
+                assert os.path.exists(self.images[-1]), \
+                    f"path {self.images[-1]} does not exist."
+
         self.has_camera = len(self.cameras) > 0
 
     def __getitem__(self, idx):
         try:
-            with open(self.images[idx], 'rb') as f:
-                img = f.read()
+            if self.backend == "cv2":
+                with open(self.images[idx], 'rb') as f:
+                    img = f.read()
+            else:
+                img = Image.open(self.images[idx]).convert("RGB")
             if self._transform_ops:
                 img = transform(img, self._transform_ops)
-            img = img.transpose((2, 0, 1))
+            if self.backend == "cv2":
+                img = img.transpose((2, 0, 1))
             if self.has_camera:
                 return (img, self.labels[idx], self.cameras[idx])
             else:
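
Two behavioral notes on the merged loader above. First, `relabel` compacts arbitrary label ids into 0..N-1; because the mapping enumerates a Python set, it is deterministic within a run but not guaranteed stable across runs. Second, the "cv2" backend hands raw file bytes to the transform pipeline (so a decode op is expected there), and only that backend needs the final HWC-to-CHW transpose, while "pil" decodes to RGB up front. A small self-contained sketch of the relabel mapping, mirroring the branch above with toy annotation rows:

    import numpy as np

    lines = ["a.jpg 1001 0", "b.jpg 57 1", "c.jpg 1001 2"]  # toy rows

    label_set = set()
    for line in lines:
        label_set.add(np.int64(line.strip().split()[1]))
    # Map each original id to a compact index, as the relabel branch does.
    label_map = {old: new for new, old in enumerate(label_set)}

    labels = [label_map[np.int64(l.split()[1])] for l in lines]
    print(labels)  # e.g. [0, 1, 0]: ids 1001/57 compacted into 0..1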

@@ -42,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
 from ppcls.engine.train import train_epoch
+from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead
 

@@ -377,7 +378,7 @@ class Engine(object):
                 # step lr (by epoch) according to given metric, such as acc
                 for i in range(len(self.lr_sch)):
                     if getattr(self.lr_sch[i], "by_epoch", False) and \
-                            self.lr_sch[i].__class__.__name__ == "ReduceOnPlateau":
+                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
                         self.lr_sch[i].step(acc)
 
                 if acc > best_metric["metric"]:

@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
 
 import time
 import paddle
-from ppcls.engine.train.utils import update_loss, update_metric, log_info
+from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
 from ppcls.utils import profiler
 
 

@@ -98,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
     # step lr(by epoch)
     for i in range(len(engine.lr_sch)):
-        if getattr(engine.lr_sch[i], "by_epoch", False):
+        if getattr(engine.lr_sch[i], "by_epoch", False) and \
+                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
             engine.lr_sch[i].step()
 
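
The engine.py and train.py hunks are two halves of the lr.step fix named in the title: the per-epoch loop in train_epoch must skip ReduceOnPlateau (stepping it without a metric is wrong), and Engine steps it with the eval metric instead. A minimal standalone illustration with Paddle schedulers; the two-scheduler setup is hypothetical and step_by_epoch only mimics the split shown above:

    import paddle

    # Hypothetical setup; values are illustrative only.
    plateau = paddle.optimizer.lr.ReduceOnPlateau(
        learning_rate=0.1, mode="max", factor=0.5, patience=2)
    cosine = paddle.optimizer.lr.CosineAnnealingDecay(
        learning_rate=0.1, T_max=100)

    def step_by_epoch(schedulers, acc=None):
        """Step epoch-wise schedulers; ReduceOnPlateau needs the metric."""
        for sch in schedulers:
            if sch.__class__.__name__ == "ReduceOnPlateau":
                if acc is not None:
                    sch.step(acc)  # decays lr only when acc plateaus
            else:
                sch.step()  # metric-free schedulers advance unconditionally

    step_by_epoch([plateau, cosine])           # end of train_epoch: plateau skipped
    step_by_epoch([plateau, cosine], acc=0.8)  # after eval: plateau sees the metric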

@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
 
     ips_msg = "ips: {:.5f} samples/s".format(
         batch_size / trainer.time_info["batch_cost"].avg)
-    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-               ) * trainer.max_iter - iter_id
-               ) * trainer.time_info["batch_cost"].avg
+    eta_sec = (
+        (trainer.config["Global"]["epochs"] - epoch_id + 1
+         ) * trainer.max_iter - iter_id) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
         epoch_id, trainer.config["Global"]["epochs"], iter_id,
-        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg,
-        eta_msg))
+        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
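
The eta_sec rewrite above is purely cosmetic; the value is still (remaining iterations) times (average batch cost). A quick worked check with toy numbers (not from any real run):

    import datetime

    epochs, epoch_id = 100, 10   # epoch_id is 1-based
    max_iter, iter_id = 500, 100
    batch_cost_avg = 0.2         # average seconds per iteration

    eta_sec = ((epochs - epoch_id + 1) * max_iter - iter_id) * batch_cost_avg
    print(datetime.timedelta(seconds=int(eta_sec)))  # 2:31:20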

@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
             value=trainer.output_info[key].avg,
             step=trainer.global_step,
             writer=trainer.vdl_writer)
+
+
+def type_name(object: object) -> str:
+    """get class name of an object"""
+    return object.__class__.__name__
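
The new type_name helper just centralizes the `__class__.__name__` access that engine.py previously inlined, so scheduler-kind checks stay plain string comparisons. A usage sketch with a stand-in class:

    class ReduceOnPlateau:  # stand-in for paddle.optimizer.lr.ReduceOnPlateau
        pass

    sch = ReduceOnPlateau()
    assert type_name(sch) == "ReduceOnPlateau"  # same as sch.__class__.__name__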