remove redundant code, fix bugs in lr.step, merge GoodsDataset into Vehicle

pull/2383/head
HydrogenSulfate 2022-10-19 07:18:21 +00:00
parent 30cbb18321
commit 09200a31f4
8 changed files with 59 additions and 137 deletions


@@ -88,7 +88,7 @@ Optimizer:
DataLoader:
  Train:
    dataset:
-      name: GoodsDataset
+      name: VeriWild
      image_root: ./dataset/SOP
      cls_label_path: ./dataset/SOP/train_list.txt
      backend: pil
@@ -117,7 +117,7 @@ DataLoader:
  Eval:
    Gallery:
      dataset:
-        name: GoodsDataset
+        name: VeriWild
        image_root: ./dataset/SOP
        cls_label_path: ./dataset/SOP/test_list.txt
        backend: pil
@@ -141,7 +141,7 @@ DataLoader:
    Query:
      dataset:
-        name: GoodsDataset
+        name: VeriWild
        image_root: ./dataset/SOP
        cls_label_path: ./dataset/SOP/test_list.txt
        backend: pil
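All three SOP splits now instantiate VeriWild. A minimal sketch of what the Train entry above resolves to at runtime — the constructor arguments come from the vehicle_dataset diff below, while the direct-construction call itself is illustrative (the real loader is built from this config by ppcls.data):

from ppcls.data.dataloader.vehicle_dataset import VeriWild

train_dataset = VeriWild(
    image_root="./dataset/SOP",
    cls_label_path="./dataset/SOP/train_list.txt",
    transform_ops=None,   # the real config supplies its transform_ops list
    backend="pil")        # matches `backend: pil` above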


@@ -25,7 +25,6 @@ from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset
from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset
from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
-from ppcls.data.dataloader.goods_dataset import GoodsDataset
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset


@@ -82,26 +82,6 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
        avai_pids = copy.deepcopy(self.pids)
        return batch_idxs_dict, avai_pids, count

-    def __iter__(self):
-        batch_idxs_dict, avai_pids, count = self._prepare_batch()
-        for _ in range(self.max_iters):
-            final_idxs = []
-            if len(avai_pids) < self.num_pids_per_batch:
-                batch_idxs_dict, avai_pids, count = self._prepare_batch()
-            selected_pids = np.random.choice(
-                avai_pids, self.num_pids_per_batch, False, count / count.sum())
-            for pid in selected_pids:
-                batch_idxs = batch_idxs_dict[pid].pop(0)
-                final_idxs.extend(batch_idxs)
-                pid_idx = avai_pids.index(pid)
-                if len(batch_idxs_dict[pid]) == 0:
-                    avai_pids.pop(pid_idx)
-                    count = np.delete(count, pid_idx)
-                else:
-                    count[pid_idx] = len(batch_idxs_dict[pid])
-            yield final_idxs

    def __iter__(self):
        # prepare
        batch_idxs_dict, avai_pids, count = self._prepare_batch()
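This deletes the first of two identical __iter__ definitions; Python keeps only the last definition in a class body, so the first was dead code. For reference, a hypothetical standalone sketch of the P×K selection step both versions perform, with toy data standing in for _prepare_batch() output (not PaddleClas API):

import numpy as np

# hypothetical stand-ins for _prepare_batch() output
batch_idxs_dict = {0: [[0, 1], [2, 3]], 1: [[4, 5]], 2: [[6, 7]]}
avai_pids = [0, 1, 2]
count = np.array([2.0, 1.0, 1.0])   # remaining K-sized chunks per identity

num_pids_per_batch = 2
selected_pids = np.random.choice(
    avai_pids, num_pids_per_batch, False, count / count.sum())
final_idxs = []
for pid in selected_pids:
    final_idxs.extend(batch_idxs_dict[pid].pop(0))  # one chunk of K samples
print(final_idxs)  # e.g. [0, 1, 4, 5]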


@@ -1,95 +0,0 @@
-from __future__ import print_function
-
-import os
-from typing import Callable, List
-
-import numpy as np
-import paddle
-from paddle.io import Dataset
-from PIL import Image
-
-from ppcls.data.preprocess import transform
-from ppcls.utils import logger
-
-from .common_dataset import create_operators
-
-
-class GoodsDataset(Dataset):
-    """Dataset for Goods, such as SOP, Inshop...
-
-    Args:
-        image_root (str): image root
-        cls_label_path (str): path to annotation file
-        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
-        backend (str, optional): pil or cv2. Defaults to "cv2".
-        relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
-    """
-
-    def __init__(self,
-                 image_root: str,
-                 cls_label_path: str,
-                 transform_ops: List[Callable]=None,
-                 backend="cv2",
-                 relabel=False):
-        self._img_root = image_root
-        self._cls_path = cls_label_path
-        if transform_ops:
-            self._transform_ops = create_operators(transform_ops)
-        self.backend = backend
-        self._dtype = paddle.get_default_dtype()
-        self._load_anno(relabel)
-
-    def _load_anno(self, seed=None, relabel=False):
-        assert os.path.exists(
-            self._cls_path), f"path {self._cls_path} does not exist."
-        assert os.path.exists(
-            self._img_root), f"path {self._img_root} does not exist."
-
-        self.images = []
-        self.labels = []
-        self.cameras = []
-
-        with open(self._cls_path) as fd:
-            lines = fd.readlines()
-            if relabel:
-                label_set = set()
-                for line in lines:
-                    line = line.strip().split()
-                    label_set.add(np.int64(line[1]))
-                label_map = {
-                    oldlabel: newlabel
-                    for newlabel, oldlabel in enumerate(label_set)
-                }
-            if seed is not None:
-                np.random.RandomState(seed).shuffle(lines)
-            for line in lines:
-                line = line.strip().split()
-                self.images.append(os.path.join(self._img_root, line[0]))
-                if relabel:
-                    self.labels.append(label_map[np.int64(line[1])])
-                else:
-                    self.labels.append(np.int64(line[1]))
-                self.cameras.append(np.int64(line[2]))
-                assert os.path.exists(self.images[
-                    -1]), f"path {self.images[-1]} does not exist."
-
-    def __getitem__(self, idx):
-        try:
-            img = Image.open(self.images[idx]).convert("RGB")
-            if self.backend == "cv2":
-                img = np.array(img, dtype="float32").astype(np.uint8)
-            if self._transform_ops:
-                img = transform(img, self._transform_ops)
-            if self.backend == "cv2":
-                img = img.transpose((2, 0, 1))
-            return (img, self.labels[idx], self.cameras[idx])
-        except Exception as ex:
-            logger.error("Exception occured when parse line: {} with msg: {}".
-                         format(self.images[idx], ex))
-            rnd_idx = np.random.randint(self.__len__())
-            return self.__getitem__(rnd_idx)
-
-    def __len__(self):
-        return len(self.images)
-
-    @property
-    def class_num(self):
-        return len(set(self.labels))
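With goods_dataset.py removed, anything that previously constructed a GoodsDataset should construct VeriWild instead; as the next diff shows, the merged class gains the same backend and relabel arguments. A hypothetical migration of a call site:

from ppcls.data.dataloader.vehicle_dataset import VeriWild

# before (hypothetical): GoodsDataset(root, anno, backend="pil", relabel=True)
dataset = VeriWild(
    "./dataset/SOP",
    "./dataset/SOP/train_list.txt",
    backend="pil",   # keep PIL decoding, as the SOP configs above do
    relabel=True)    # remap discontinuous labels to a compact 0..N-1 range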


@@ -19,8 +19,7 @@ import paddle
from paddle.io import Dataset
import os
import cv2
-from ppcls.data import preprocess
+from PIL import Image
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators
@@ -89,15 +88,30 @@ class CompCars(Dataset):
class VeriWild(Dataset):
-    def __init__(self, image_root, cls_label_path, transform_ops=None):
+    """Dataset for Vehicle and other similar data structures, such as VeRI-Wild, SOP, Inshop...
+
+    Args:
+        image_root (str): image root
+        cls_label_path (str): path to annotation file
+        transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
+        backend (str, optional): pil or cv2. Defaults to "cv2".
+        relabel (bool, optional): whether to relabel when original labels do not start from 0 or are discontinuous. Defaults to False.
+    """
+
+    def __init__(self,
+                 image_root,
+                 cls_label_path,
+                 transform_ops=None,
+                 backend="cv2",
+                 relabel=False):
        self._img_root = image_root
        self._cls_path = cls_label_path
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
+        self.backend = backend
        self._dtype = paddle.get_default_dtype()
-        self._load_anno()
+        self._load_anno(relabel)

-    def _load_anno(self):
+    def _load_anno(self, relabel):
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
@@ -107,22 +121,40 @@ class VeriWild(Dataset):
        self.cameras = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
+            if relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split()
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
            for line in lines:
                line = line.strip().split()
                self.images.append(os.path.join(self._img_root, line[0]))
-                self.labels.append(np.int64(line[1]))
+                if relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
                if len(line) >= 3:
                    self.cameras.append(np.int64(line[2]))
-                assert os.path.exists(self.images[-1])
+                assert os.path.exists(self.images[-1]), \
+                    f"path {self.images[-1]} does not exist."
        self.has_camera = len(self.cameras) > 0

    def __getitem__(self, idx):
        try:
-            with open(self.images[idx], 'rb') as f:
-                img = f.read()
+            if self.backend == "cv2":
+                with open(self.images[idx], 'rb') as f:
+                    img = f.read()
+            else:
+                img = Image.open(self.images[idx]).convert("RGB")
            if self._transform_ops:
                img = transform(img, self._transform_ops)
-            img = img.transpose((2, 0, 1))
+            if self.backend == "cv2":
+                img = img.transpose((2, 0, 1))
            if self.has_camera:
                return (img, self.labels[idx], self.cameras[idx])
            else:

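The relabel branch above builds a dense remapping before any label is stored. A hypothetical toy run of the same mapping logic (values chosen for illustration):

import numpy as np

raw_labels = [np.int64(x) for x in ("11", "42", "11", "7")]  # like line[1]
label_set = set(raw_labels)
label_map = {old: new for new, old in enumerate(label_set)}
print(sorted(label_map[l] for l in label_set))  # [0, 1, 2] -- compact range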

@@ -42,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
+from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead

@@ -377,7 +378,7 @@ class Engine(object):
            # step lr (by epoch) according to given metric, such as acc
            for i in range(len(self.lr_sch)):
                if getattr(self.lr_sch[i], "by_epoch", False) and \
-                        self.lr_sch[i].__class__.__name__ == "ReduceOnPlateau":
+                        type_name(self.lr_sch[i]) == "ReduceOnPlateau":
                    self.lr_sch[i].step(acc)

            if acc > best_metric["metric"]:
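The point of this branch: Paddle's ReduceOnPlateau must be stepped with the monitored metric, unlike ordinary by-epoch schedulers, so it is stepped here after evaluation. A minimal standalone sketch of the API, assuming accuracy is the monitored metric:

import paddle

sch = paddle.optimizer.lr.ReduceOnPlateau(
    learning_rate=0.1, mode="max", factor=0.5, patience=2)
for epoch, acc in enumerate([0.80, 0.80, 0.80, 0.80, 0.81]):
    sch.step(acc)               # plateau detection needs the metric value
    print(epoch, sch.get_lr())  # lr halves once acc stalls past patience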


@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
import time
import paddle
-from ppcls.engine.train.utils import update_loss, update_metric, log_info
+from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler

@@ -98,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
    # step lr(by epoch)
    for i in range(len(engine.lr_sch)):
-        if getattr(engine.lr_sch[i], "by_epoch", False):
+        if getattr(engine.lr_sch[i], "by_epoch", False) and \
+                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
            engine.lr_sch[i].step()
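Together with the engine.py change above, this guard fixes the lr.step bug named in the commit title: previously train_epoch stepped every by-epoch scheduler, so ReduceOnPlateau was stepped without its metric and then stepped again after eval. A condensed, hypothetical consolidation of the two code paths (in the real code the branches live in train_epoch and in Engine respectively):

from ppcls.engine.train.utils import type_name

def step_by_epoch_schedulers(engine, acc):
    # hypothetical helper merging the two changed call sites
    for sch in engine.lr_sch:
        if not getattr(sch, "by_epoch", False):
            continue
        if type_name(sch) == "ReduceOnPlateau":
            sch.step(acc)   # eval branch: step once, with the monitored metric
        else:
            sch.step()      # train_epoch: ordinary by-epoch schedulers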


@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
    ips_msg = "ips: {:.5f} samples/s".format(
        batch_size / trainer.time_info["batch_cost"].avg)
-    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-                ) * trainer.max_iter - iter_id
-               ) * trainer.time_info["batch_cost"].avg
+    eta_sec = (
+        (trainer.config["Global"]["epochs"] - epoch_id + 1
+         ) * trainer.max_iter - iter_id) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
    logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
        epoch_id, trainer.config["Global"]["epochs"], iter_id,
-        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg,
-        eta_msg))
+        trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))

    for i, lr in enumerate(trainer.lr_sch):
        logger.scaler(
@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)
+
+
+def type_name(object: object) -> str:
+    """get class name of an object"""
+    return object.__class__.__name__
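A quick usage check of the new helper (hypothetical snippet; assumes a Paddle scheduler instance):

import paddle
from ppcls.engine.train.utils import type_name

sch = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=0.1)
assert type_name(sch) == "ReduceOnPlateau"  # branch on class name, no isinstance import needed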