fix some bugs

pull/2224/head
HydrogenSulfate 2022-08-23 10:59:30 +00:00
parent 1b5e00e82a
commit 5a4874079d
10 changed files with 209 additions and 54 deletions


@@ -12,31 +12,31 @@ def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
         new_conv = Identity()
         return new_conv
 
-    def last_stride_1_function(conv, pattern):
-        new_conv = Conv2D(
-            weight_attr=conv._weight_attr,
-            in_channels=conv._in_channels,
-            out_channels=conv._out_channels,
-            kernel_size=conv._kernel_size,
-            stride=1,
-            padding=conv._padding,
-            groups=conv._groups,
-            bias_attr=conv._bias_attr)
-        return new_conv
+    # def last_stride_function(conv, pattern):
+    #     new_conv = Conv2D(
+    #         weight_attr=conv._param_attr,
+    #         in_channels=conv._in_channels,
+    #         out_channels=conv._out_channels,
+    #         kernel_size=conv._kernel_size,
+    #         stride=1,
+    #         padding=conv._padding,
+    #         groups=conv._groups,
+    #         bias_attr=conv._bias_attr)
+    #     return new_conv
 
     pattern_act = ["act"]
-    pattern_last_stride = [
-        "stages[3][0].dw_conv_list[0].conv",
-        "stages[3][0].dw_conv_list[1].conv",
-        "stages[3][0].dw_conv",
-        "stages[3][0].pw_conv.conv",
-        "stages[3][1].dw_conv_list[0].conv",
-        "stages[3][1].dw_conv_list[1].conv",
-        "stages[3][1].dw_conv_list[2].conv",
-        "stages[3][1].dw_conv",
-        "stages[3][1].pw_conv.conv",
-    ]
-    model.upgrade_sublayer(pattern_last_stride, last_stride_1_function)
+    # pattern_last_stride = [
+    #     "stages[3][0].dw_conv_list[0].conv",
+    #     "stages[3][0].dw_conv_list[1].conv",
+    #     "stages[3][0].dw_conv",
+    #     "stages[3][0].pw_conv.conv",
+    #     "stages[3][1].dw_conv_list[0].conv",
+    #     "stages[3][1].dw_conv_list[1].conv",
+    #     "stages[3][1].dw_conv_list[2].conv",
+    #     "stages[3][1].dw_conv",
+    #     "stages[3][1].pw_conv.conv",
+    # ]
+    # model.upgrade_sublayer(pattern_last_stride, last_stride_function)  # TODO: temporarily commented out because of a bug in TheseusLayer
     model.upgrade_sublayer(pattern_act, remove_ReLU_function)
 
     # load params again after upgrade some layers
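For context on the `_weight_attr` -> `_param_attr` change inside the commented-out helper, here is a minimal sketch (not part of the commit) of the stride-1 clone it is meant to perform once the TheseusLayer issue is resolved. It assumes a recent Paddle where paddle.nn.Conv2D stores its constructor arguments in private attributes and keeps the weight attribute as `_param_attr`:

# Sketch only: clone a conv with stride forced to 1, keeping every other argument,
# which is what last_stride_function intends for the last-stage convs.
# Assumption: Conv2D exposes _in_channels/_out_channels/_kernel_size/_padding/
# _groups/_param_attr/_bias_attr as private attributes.
import paddle.nn as nn

conv = nn.Conv2D(in_channels=8, out_channels=16, kernel_size=3, stride=2)

stride1_conv = nn.Conv2D(
    in_channels=conv._in_channels,
    out_channels=conv._out_channels,
    kernel_size=conv._kernel_size,
    stride=1,
    padding=conv._padding,
    groups=conv._groups,
    weight_attr=conv._param_attr,
    bias_attr=conv._bias_attr)
print(stride1_conv)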


@@ -52,7 +52,7 @@ Arch:
   Head:
     name: FC
     embedding_size: *feat_dim
-    class_num: 192612
+    class_num: 100000
     weight_attr:
       initializer:
         name: Normal
@@ -65,14 +65,14 @@ Loss:
    - CELoss:
        weight: 1.0
        epsilon: 0.1
-    - TripletAngleMarinLoss:
+    - TripletAngularMarginLoss:
        weight: 1.0
+       feature_from: features
        margin: 0.5
        reduction: mean
        add_absolute: True
        absolute_loss_weight: 0.1
        normalize_feature: True
-       feature_from: features
        ap_value: 0.8
        an_value: 0.4
  Eval:
@@ -84,7 +84,7 @@ Optimizer:
   momentum: 0.9
   lr:
     name: Cosine
-    learning_rate: 0.04
+    learning_rate: 0.06 # for 8gpu x 256bs
     warmup_epoch: 5
   regularizer:
     name: L2
@@ -97,6 +97,7 @@ DataLoader:
       name: ImageNetDataset
       image_root: ./dataset/
       cls_label_path: ./dataset/train_reg_all_data.txt
+      relabel: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
@@ -108,8 +109,9 @@ DataLoader:
            backend: cv2
        - RandFlipImage:
            flip_code: 1
-       - Pad_cv2:
+       - Pad:
            padding: 10
+           backend: cv2
        - RandCropImageV2:
            size: [224, 224]
        - RandomRotation:
@@ -127,10 +129,14 @@ DataLoader:
            std: [0.229, 0.224, 0.225]
            order: hwc
    sampler:
-      name: DistributedBatchSampler
-      batch_size: 256
+      name: PKSampler
+      batch_size: 8
+      sample_per_id: 4
      drop_last: False
      shuffle: True
+      sample_method: "id_avg_prob"
+      id_list: [50030, 80700, 92019, 96015]
+      ratio: [4, 4]
    loader:
      num_workers: 4
      use_shared_memory: True
@@ -196,3 +202,4 @@ Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
+   - mAP: {}
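As a reading aid for the new sampler block (not part of the commit): with batch_size 8 and sample_per_id 4, each batch holds batch_size // sample_per_id = 2 identities with 4 images each, and id_list/ratio boost the sampling probability of the inclusive id ranges [50030, 80700] and [92019, 96015] by 4x before renormalization, mirroring the PKSampler change further down in this commit. The uniform prob_list below stands in for the sampler's "id_avg_prob" initialization, and the class count is taken from class_num above.

# Sketch of the sampling math implied by the config values above (assumption:
# prob_list indices line up with label ids, as they do for 0..N-1 labels).
import numpy as np

batch_size, sample_per_id = 8, 4
labels_per_batch = batch_size // sample_per_id   # 2 identities per batch

num_ids = 100000                                 # class_num in this config
prob_list = np.full(num_ids, 1.0 / num_ids)      # "id_avg_prob": uniform per id

id_list, ratio = [50030, 80700, 92019, 96015], [4, 4]
for j in range(len(ratio)):
    lo, hi = id_list[2 * j], id_list[2 * j + 1]
    prob_list[lo:hi + 1] *= ratio[j]             # oversample these id ranges
prob_list /= prob_list.sum()

picked_ids = np.random.choice(
    num_ids, size=labels_per_batch, replace=False, p=prob_list)
print(picked_ids)  # 2 ids; the sampler then draws 4 images for each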


@@ -25,24 +25,50 @@ class ImageNetDataset(CommonDataset):
                  image_root,
                  cls_label_path,
                  transform_ops=None,
-                 delimiter=None):
+                 delimiter=None,
+                 relabel=False):
+        """ImageNetDataset
+
+        Args:
+            image_root (str): _description_
+            cls_label_path (str): _description_
+            transform_ops (list, optional): list of transform op(s). Defaults to None.
+            delimiter (str, optional): delimiter. Defaults to None.
+            relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
+        """
         self.delimiter = delimiter if delimiter is not None else " "
+        self.relabel = relabel
         super(ImageNetDataset, self).__init__(image_root, cls_label_path,
                                               transform_ops)
 
     def _load_anno(self, seed=None):
-        assert os.path.exists(self._cls_path)
-        assert os.path.exists(self._img_root)
+        assert os.path.exists(
+            self._cls_path), f"path {self._cls_path} does not exist."
+        assert os.path.exists(
+            self._img_root), f"path {self._img_root} does not exist."
         self.images = []
         self.labels = []
 
         with open(self._cls_path) as fd:
             lines = fd.readlines()
+            if self.relabel:
+                label_set = set()
+                for line in lines:
+                    line = line.strip().split(self.delimiter)
+                    label_set.add(np.int64(line[1]))
+                label_map = {
+                    oldlabel: newlabel
+                    for newlabel, oldlabel in enumerate(label_set)
+                }
+
             if seed is not None:
                 np.random.RandomState(seed).shuffle(lines)
             for line in lines:
                 line = line.strip().split(self.delimiter)
                 self.images.append(os.path.join(self._img_root, line[0]))
-                self.labels.append(np.int64(line[1]))
+                if self.relabel:
+                    self.labels.append(label_map[np.int64(line[1])])
+                else:
+                    self.labels.append(np.int64(line[1]))
                 assert os.path.exists(self.images[
                     -1]), f"path {self.images[-1]} does not exist."


@@ -32,17 +32,23 @@ class PKSampler(DistributedBatchSampler):
         batch_size (int): batch size
         sample_per_id (int): number of instance(s) within an class
         shuffle (bool, optional): _description_. Defaults to True.
+        id_list(list): list of (start_id, end_id, start_id, end_id) for set of ids to duplicated.
+        ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
     """
     def __init__(self,
                  dataset,
                  batch_size,
                  sample_per_id,
                  shuffle=True,
                  drop_last=True,
+                 id_list=None,
+                 ratio=None,
                  sample_method="sample_avg_prob"):
-        super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(
+            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
             f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
         assert hasattr(self.dataset,
@@ -67,6 +73,16 @@ class PKSampler(DistributedBatchSampler):
             logger.error(
                 "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
                 "but receive {}.".format(self.sample_method))
+        if id_list and ratio:
+            assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
+            for i in range(len(self.prob_list)):
+                for j in range(len(ratio)):
+                    if i >= id_list[j * 2] and i <= id_list[j * 2 + 1]:
+                        self.prob_list[i] = self.prob_list[i] * ratio[j]
+                        break
+            self.prob_list = self.prob_list / sum(self.prob_list)
         diff = np.abs(sum(self.prob_list) - 1)
         if diff > 0.00000001:
             self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
@@ -74,8 +90,8 @@ class PKSampler(DistributedBatchSampler):
                 logger.error("PKSampler prob list error")
             else:
                 logger.info(
-                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff)
-                )
+                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".
+                    format(diff))
 
     def __iter__(self):
         label_per_batch = self.batch_size // self.sample_per_label
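A hypothetical single-card usage sketch of the extended sampler (not from the repo; DummyDataset and the small id ranges are made up, and the import path assumes the PaddleClas root is on PYTHONPATH): id_list is read as consecutive inclusive (start, end) pairs, with one ratio entry per pair.

# Sketch only: the sampler needs a dataset exposing a `labels` attribute.
import numpy as np
import paddle
from ppcls.data.dataloader.pk_sampler import PKSampler

class DummyDataset(paddle.io.Dataset):
    """Toy dataset: 10 identities x 20 images; only `labels` matters to the sampler."""
    def __init__(self, num_ids=10, per_id=20):
        super().__init__()
        self.labels = [i for i in range(num_ids) for _ in range(per_id)]
    def __getitem__(self, idx):
        return np.zeros([3, 224, 224], dtype="float32"), self.labels[idx]
    def __len__(self):
        return len(self.labels)

sampler = PKSampler(
    DummyDataset(),
    batch_size=8,
    sample_per_id=4,               # K images per identity
    sample_method="id_avg_prob",
    id_list=[2, 4, 7, 8],          # two inclusive ranges: ids 2-4 and 7-8
    ratio=[4, 4])                  # both ranges sampled 4x more often
batch_indices = next(iter(sampler))  # expected: 8 indices, 2 identities x 4 images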


@@ -98,8 +98,10 @@ class VeriWild(Dataset):
         self._load_anno()
 
     def _load_anno(self):
-        assert os.path.exists(self._cls_path)
-        assert os.path.exists(self._img_root)
+        assert os.path.exists(
+            self._cls_path), f"path {self._cls_path} does not exist."
+        assert os.path.exists(
+            self._img_root), f"path {self._img_root} does not exist."
         self.images = []
         self.labels = []
         self.cameras = []


@@ -681,11 +681,18 @@ class Pad(object):
     adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
     """
 
-    def __init__(self, padding: int, fill: int=0,
-                 padding_mode: str="constant"):
+    def __init__(self,
+                 padding: int,
+                 fill: int=0,
+                 padding_mode: str="constant",
+                 backend: str="pil"):
         self.padding = padding
         self.fill = fill
         self.padding_mode = padding_mode
+        self.backend = backend
+        assert backend in [
+            "pil", "cv2"
+        ], f"backend must in ['pil', 'cv2'], but got {backend}"
 
     def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"):
         # Process fill color for affine transforms
@@ -720,11 +727,21 @@ class Pad(object):
         return {name: fill}
 
     def __call__(self, img):
-        opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
-        if img.mode == "P":
-            palette = img.getpalette()
-            img = ImageOps.expand(img, border=self.padding, **opts)
-            img.putpalette(palette)
-            return img
-        return ImageOps.expand(img, border=self.padding, **opts)
+        if self.backend == "pil":
+            opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
+            if img.mode == "P":
+                palette = img.getpalette()
+                img = ImageOps.expand(img, border=self.padding, **opts)
+                img.putpalette(palette)
+                return img
+            return ImageOps.expand(img, border=self.padding, **opts)
+        else:
+            img = cv2.copyMakeBorder(
+                img,
+                self.padding,
+                self.padding,
+                self.padding,
+                self.padding,
+                cv2.BORDER_CONSTANT,
+                value=(self.fill, self.fill, self.fill))
+            return img


@@ -0,0 +1,90 @@
+import numpy as np
+import paddle.vision.transforms as T
+import cv2
+
+
+class Pad(object):
+    """
+    Pads the given PIL.Image on all sides with specified padding mode and fill value.
+    adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
+    """
+
+    def __init__(self,
+                 padding: int,
+                 fill: int=0,
+                 padding_mode: str="constant",
+                 backend: str="pil"):
+        self.padding = padding
+        self.fill = fill
+        self.padding_mode = padding_mode
+        self.backend = backend
+        assert backend in [
+            "pil", "cv2"
+        ], f"backend in Pad must in ['pil', 'cv2'], but got {backend}"
+
+    def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"):
+        # Process fill color for affine transforms
+        major_found, minor_found = (int(v)
+                                    for v in PILLOW_VERSION.split('.')[:2])
+        major_required, minor_required = (int(v) for v in
+                                          min_pil_version.split('.')[:2])
+        if major_found < major_required or (major_found == major_required and
+                                            minor_found < minor_required):
+            if fill is None:
+                return {}
+            else:
+                msg = (
+                    "The option to fill background area of the transformed image, "
+                    "requires pillow>={}")
+                raise RuntimeError(msg.format(min_pil_version))
+
+        num_bands = len(img.getbands())
+        if fill is None:
+            fill = 0
+        if isinstance(fill, (int, float)) and num_bands > 1:
+            fill = tuple([fill] * num_bands)
+        if isinstance(fill, (list, tuple)):
+            if len(fill) != num_bands:
+                msg = (
+                    "The number of elements in 'fill' does not match the number of "
+                    "bands of the image ({} != {})")
+                raise ValueError(msg.format(len(fill), num_bands))
+            fill = tuple(fill)
+
+        return {name: fill}
+
+    def __call__(self, img):
+        if self.backend == "pil":
+            opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
+            if img.mode == "P":
+                palette = img.getpalette()
+                img = ImageOps.expand(img, border=self.padding, **opts)
+                img.putpalette(palette)
+                return img
+            return ImageOps.expand(img, border=self.padding, **opts)
+        else:
+            img = cv2.copyMakeBorder(
+                img,
+                self.padding,
+                self.padding,
+                self.padding,
+                self.padding,
+                cv2.BORDER_CONSTANT,
+                value=(self.fill, self.fill, self.fill))
+            return img
+
+
+img = np.random.randint(0, 255, [3, 4, 3], dtype=np.uint8)
+for p in range(0, 10):
+    for v in range(0, 10):
+        img_1 = Pad(p, v, backend="cv2")(img)
+        img_2 = T.Pad(p, (v, v, v))(img)
+        print(f"{p} - {v}", np.allclose(img_1, img_2))
+        if not np.allclose(img_1, img_2):
+            print(img_1[..., 0], "\n", img_2[..., 0])
+            print(img_1[..., 1], "\n", img_2[..., 1])
+            print(img_1[..., 2], "\n", img_2[..., 2])
+            exit(0)


@@ -114,10 +114,7 @@ class Engine(object):
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
         self.config["DataLoader"].update({"class_num": class_num})
-        self.model = build_model(self.config, self.mode)
-        # print(*self.model.state_dict().keys(), sep='\n')
-        print(self.model.backbone.stages[3][0].dw_conv_list[0].conv)
-        exit(0)
 
         # build dataloader
         if self.mode == 'train':
             self.train_dataloader = build_dataloader(


@@ -12,7 +12,7 @@ from .msmloss import MSMLoss
 from .npairsloss import NpairsLoss
 from .trihardloss import TriHardLoss
 from .triplet import TripletLoss, TripletLossV2
-from .tripletangularmarginloss import TTripletAngularMarginLoss
+from .tripletangularmarginloss import TripletAngularMarginLoss
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
 from .dmlloss import DMLLoss


@@ -43,7 +43,7 @@ class TripletAngularMarginLoss(nn.Layer):
                  ap_value=0.9,
                  an_value=0.5,
                  feature_from="features"):
-        super(TripletAngleMarginLoss, self).__init__()
+        super(TripletAngularMarginLoss, self).__init__()
         self.margin = margin
         self.feature_from = feature_from
         self.ranking_loss = paddle.nn.loss.MarginRankingLoss(
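With the two renames above, the loss can be imported under its intended name. A hedged construction sketch follows; the keyword set is inferred from the YAML block earlier in this commit and from the constructor arguments visible in this hunk, so treat it as an assumption rather than the documented signature.

# Assumes the PaddleClas root is importable; keyword names mirror the config above.
from ppcls.loss import TripletAngularMarginLoss

triplet_loss = TripletAngularMarginLoss(
    margin=0.5,
    normalize_feature=True,
    reduction="mean",
    add_absolute=True,
    absolute_loss_weight=0.1,
    ap_value=0.8,
    an_value=0.4,
    feature_from="features")
# At train time the loss is fed the model output dict (read via feature_from)
# together with the labels, per the config's Loss.Train list.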