diff --git a/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py index ebde9af82..73c76993d 100644 --- a/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py +++ b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py @@ -12,31 +12,31 @@ def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs): new_conv = Identity() return new_conv - def last_stride_1_function(conv, pattern): - new_conv = Conv2D( - weight_attr=conv._weight_attr, - in_channels=conv._in_channels, - out_channels=conv._out_channels, - kernel_size=conv._kernel_size, - stride=1, - padding=conv._padding, - groups=conv._groups, - bias_attr=conv._bias_attr) - return new_conv + # def last_stride_function(conv, pattern): + # new_conv = Conv2D( + # weight_attr=conv._param_attr, + # in_channels=conv._in_channels, + # out_channels=conv._out_channels, + # kernel_size=conv._kernel_size, + # stride=1, + # padding=conv._padding, + # groups=conv._groups, + # bias_attr=conv._bias_attr) + # return new_conv pattern_act = ["act"] - pattern_last_stride = [ - "stages[3][0].dw_conv_list[0].conv", - "stages[3][0].dw_conv_list[1].conv", - "stages[3][0].dw_conv", - "stages[3][0].pw_conv.conv", - "stages[3][1].dw_conv_list[0].conv", - "stages[3][1].dw_conv_list[1].conv", - "stages[3][1].dw_conv_list[2].conv", - "stages[3][1].dw_conv", - "stages[3][1].pw_conv.conv", - ] - model.upgrade_sublayer(pattern_last_stride, last_stride_1_function) + # pattern_last_stride = [ + # "stages[3][0].dw_conv_list[0].conv", + # "stages[3][0].dw_conv_list[1].conv", + # "stages[3][0].dw_conv", + # "stages[3][0].pw_conv.conv", + # "stages[3][1].dw_conv_list[0].conv", + # "stages[3][1].dw_conv_list[1].conv", + # "stages[3][1].dw_conv_list[2].conv", + # "stages[3][1].dw_conv", + # "stages[3][1].pw_conv.conv", + # ] + # model.upgrade_sublayer(pattern_last_stride, last_stride_function) # TODO: theseuslayer有BUG,暂时注释掉 model.upgrade_sublayer(pattern_act, remove_ReLU_function) # load params again after upgrade some layers diff --git a/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml index b3babb255..9bc1d85cd 100644 --- a/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml +++ b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml @@ -52,7 +52,7 @@ Arch: Head: name: FC embedding_size: *feat_dim - class_num: 192612 + class_num: 100000 weight_attr: initializer: name: Normal @@ -65,14 +65,14 @@ Loss: - CELoss: weight: 1.0 epsilon: 0.1 - - TripletAngleMarinLoss: + - TripletAngularMarginLoss: weight: 1.0 + feature_from: features margin: 0.5 reduction: mean add_absolute: True absolute_loss_weight: 0.1 normalize_feature: True - feature_from: features ap_value: 0.8 an_value: 0.4 Eval: @@ -84,7 +84,7 @@ Optimizer: momentum: 0.9 lr: name: Cosine - learning_rate: 0.04 + learning_rate: 0.06 # for 8gpu x 256bs warmup_epoch: 5 regularizer: name: L2 @@ -97,6 +97,7 @@ DataLoader: name: ImageNetDataset image_root: ./dataset/ cls_label_path: ./dataset/train_reg_all_data.txt + relabel: True transform_ops: - DecodeImage: to_rgb: True @@ -108,8 +109,9 @@ DataLoader: backend: cv2 - RandFlipImage: flip_code: 1 - - Pad_cv2: + - Pad: padding: 10 + backend: cv2 - RandCropImageV2: size: [224, 224] - RandomRotation: @@ -127,10 +129,14 @@ DataLoader: std: [0.229, 0.224, 0.225] order: hwc sampler: - name: DistributedBatchSampler - batch_size: 256 + name: PKSampler + batch_size: 8 + sample_per_id: 4 drop_last: False shuffle: True + sample_method: "id_avg_prob" + id_list: [50030, 80700, 92019, 96015] + ratio: [4, 4] loader: num_workers: 4 use_shared_memory: True @@ -196,3 +202,4 @@ Metric: Eval: - Recallk: topk: [1, 5] + - mAP: {} diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index 87188c160..542291caa 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -25,24 +25,50 @@ class ImageNetDataset(CommonDataset): image_root, cls_label_path, transform_ops=None, - delimiter=None): + delimiter=None, + relabel=False): + """ImageNetDataset + + Args: + image_root (str): _description_ + cls_label_path (str): _description_ + transform_ops (list, optional): list of transform op(s). Defaults to None. + delimiter (str, optional): delimiter. Defaults to None. + relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False. + """ self.delimiter = delimiter if delimiter is not None else " " + self.relabel = relabel super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops) def _load_anno(self, seed=None): - assert os.path.exists(self._cls_path) - assert os.path.exists(self._img_root) + assert os.path.exists( + self._cls_path), f"path {self._cls_path} does not exist." + assert os.path.exists( + self._img_root), f"path {self._img_root} does not exist." self.images = [] self.labels = [] with open(self._cls_path) as fd: lines = fd.readlines() + if self.relabel: + label_set = set() + for line in lines: + line = line.strip().split(self.delimiter) + label_set.add(np.int64(line[1])) + label_map = { + oldlabel: newlabel + for newlabel, oldlabel in enumerate(label_set) + } + if seed is not None: np.random.RandomState(seed).shuffle(lines) for line in lines: line = line.strip().split(self.delimiter) self.images.append(os.path.join(self._img_root, line[0])) - self.labels.append(np.int64(line[1])) + if self.relabel: + self.labels.append(label_map[np.int64(line[1])]) + else: + self.labels.append(np.int64(line[1])) assert os.path.exists(self.images[ -1]), f"path {self.images[-1]} does not exist." diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index 69d1a7c83..a4081b5c3 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -32,17 +32,23 @@ class PKSampler(DistributedBatchSampler): batch_size (int): batch size sample_per_id (int): number of instance(s) within an class shuffle (bool, optional): _description_. Defaults to True. + id_list(list): list of (start_id, end_id, start_id, end_id) for set of ids to duplicated. + ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list. drop_last (bool, optional): whether to discard the data at the end. Defaults to True. sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob". """ + def __init__(self, dataset, batch_size, sample_per_id, shuffle=True, drop_last=True, + id_list=None, + ratio=None, sample_method="sample_avg_prob"): - super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last) + super().__init__( + dataset, batch_size, shuffle=shuffle, drop_last=drop_last) assert batch_size % sample_per_id == 0, \ f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})." assert hasattr(self.dataset, @@ -67,6 +73,16 @@ class PKSampler(DistributedBatchSampler): logger.error( "PKSampler only support id_avg_prob and sample_avg_prob sample method, " "but receive {}.".format(self.sample_method)) + + if id_list and ratio: + assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2 + for i in range(len(self.prob_list)): + for j in range(len(ratio)): + if i >= id_list[j * 2] and i <= id_list[j * 2 + 1]: + self.prob_list[i] = self.prob_list[i] * ratio[j] + break + self.prob_list = self.prob_list / sum(self.prob_list) + diff = np.abs(sum(self.prob_list) - 1) if diff > 0.00000001: self.prob_list[-1] = 1 - sum(self.prob_list[:-1]) @@ -74,8 +90,8 @@ class PKSampler(DistributedBatchSampler): logger.error("PKSampler prob list error") else: logger.info( - "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff) - ) + "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob". + format(diff)) def __iter__(self): label_per_batch = self.batch_size // self.sample_per_label diff --git a/ppcls/data/dataloader/vehicle_dataset.py b/ppcls/data/dataloader/vehicle_dataset.py index 8c89e3822..e4fbcad6a 100644 --- a/ppcls/data/dataloader/vehicle_dataset.py +++ b/ppcls/data/dataloader/vehicle_dataset.py @@ -98,8 +98,10 @@ class VeriWild(Dataset): self._load_anno() def _load_anno(self): - assert os.path.exists(self._cls_path) - assert os.path.exists(self._img_root) + assert os.path.exists( + self._cls_path), f"path {self._cls_path} does not exist." + assert os.path.exists( + self._img_root), f"path {self._img_root} does not exist." self.images = [] self.labels = [] self.cameras = [] diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index b36319bcb..c70b9cb72 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -681,11 +681,18 @@ class Pad(object): adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad """ - def __init__(self, padding: int, fill: int=0, - padding_mode: str="constant"): + def __init__(self, + padding: int, + fill: int=0, + padding_mode: str="constant", + backend: str="pil"): self.padding = padding self.fill = fill self.padding_mode = padding_mode + self.backend = backend + assert backend in [ + "pil", "cv2" + ], f"backend must in ['pil', 'cv2'], but got {backend}" def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"): # Process fill color for affine transforms @@ -720,11 +727,21 @@ class Pad(object): return {name: fill} def __call__(self, img): - opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") - if img.mode == "P": - palette = img.getpalette() - img = ImageOps.expand(img, border=self.padding, **opts) - img.putpalette(palette) + if self.backend == "pil": + opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") + if img.mode == "P": + palette = img.getpalette() + img = ImageOps.expand(img, border=self.padding, **opts) + img.putpalette(palette) + return img + return ImageOps.expand(img, border=self.padding, **opts) + else: + img = cv2.copyMakeBorder( + img, + self.padding, + self.padding, + self.padding, + self.padding, + cv2.BORDER_CONSTANT, + value=(self.fill, self.fill, self.fill)) return img - - return ImageOps.expand(img, border=self.padding, **opts) diff --git a/ppcls/data/preprocess/ops/test_pad.py b/ppcls/data/preprocess/ops/test_pad.py new file mode 100644 index 000000000..56ecb20a4 --- /dev/null +++ b/ppcls/data/preprocess/ops/test_pad.py @@ -0,0 +1,90 @@ +import numpy as np + +import paddle.vision.transforms as T +import cv2 + + +class Pad(object): + """ + Pads the given PIL.Image on all sides with specified padding mode and fill value. + adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad + """ + + def __init__(self, + padding: int, + fill: int=0, + padding_mode: str="constant", + backend: str="pil"): + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + self.backend = backend + assert backend in [ + "pil", "cv2" + ], f"backend in Pad must in ['pil', 'cv2'], but got {backend}" + + def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"): + # Process fill color for affine transforms + major_found, minor_found = (int(v) + for v in PILLOW_VERSION.split('.')[:2]) + major_required, minor_required = (int(v) for v in + min_pil_version.split('.')[:2]) + if major_found < major_required or (major_found == major_required and + minor_found < minor_required): + if fill is None: + return {} + else: + msg = ( + "The option to fill background area of the transformed image, " + "requires pillow>={}") + raise RuntimeError(msg.format(min_pil_version)) + + num_bands = len(img.getbands()) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_bands > 1: + fill = tuple([fill] * num_bands) + if isinstance(fill, (list, tuple)): + if len(fill) != num_bands: + msg = ( + "The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + fill = tuple(fill) + + return {name: fill} + + def __call__(self, img): + if self.backend == "pil": + opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") + if img.mode == "P": + palette = img.getpalette() + img = ImageOps.expand(img, border=self.padding, **opts) + img.putpalette(palette) + return img + return ImageOps.expand(img, border=self.padding, **opts) + else: + img = cv2.copyMakeBorder( + img, + self.padding, + self.padding, + self.padding, + self.padding, + cv2.BORDER_CONSTANT, + value=(self.fill, self.fill, self.fill)) + return img + + +img = np.random.randint(0, 255, [3, 4, 3], dtype=np.uint8) + +for p in range(0, 10): + for v in range(0, 10): + img_1 = Pad(p, v, backend="cv2")(img) + img_2 = T.Pad(p, (v, v, v))(img) + print(f"{p} - {v}", np.allclose(img_1, img_2)) + if not np.allclose(img_1, img_2): + print(img_1[..., 0], "\n", img_2[..., 0]) + print(img_1[..., 1], "\n", img_2[..., 1]) + print(img_1[..., 2], "\n", img_2[..., 2]) + exit(0) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index c740eb201..5a3fe20a2 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -114,10 +114,7 @@ class Engine(object): #TODO(gaotingquan): support rec class_num = config["Arch"].get("class_num", None) self.config["DataLoader"].update({"class_num": class_num}) - self.model = build_model(self.config, self.mode) - # print(*self.model.state_dict().keys(), sep='\n') - print(self.model.backbone.stages[3][0].dw_conv_list[0].conv) - exit(0) + # build dataloader if self.mode == 'train': self.train_dataloader = build_dataloader( diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index d4d548819..4f973eb09 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -12,7 +12,7 @@ from .msmloss import MSMLoss from .npairsloss import NpairsLoss from .trihardloss import TriHardLoss from .triplet import TripletLoss, TripletLossV2 -from .tripletangularmarginloss import TTripletAngularMarginLoss +from .tripletangularmarginloss import TripletAngularMarginLoss from .supconloss import SupConLoss from .pairwisecosface import PairwiseCosface from .dmlloss import DMLLoss diff --git a/ppcls/loss/tripletangularmarginloss.py b/ppcls/loss/tripletangularmarginloss.py index fa32a197b..3a91d2d49 100644 --- a/ppcls/loss/tripletangularmarginloss.py +++ b/ppcls/loss/tripletangularmarginloss.py @@ -43,7 +43,7 @@ class TripletAngularMarginLoss(nn.Layer): ap_value=0.9, an_value=0.5, feature_from="features"): - super(TripletAngleMarginLoss, self).__init__() + super(TripletAngularMarginLoss, self).__init__() self.margin = margin self.feature_from = feature_from self.ranking_loss = paddle.nn.loss.MarginRankingLoss(