mirror of https://github.com/JDAI-CV/fast-reid.git
Refactor: make pcb and clas in one sub project
parent 9c4c6018a3
commit e1d069abc7
@@ -13,5 +13,4 @@ from .mgn import MGN
from .moco import MoCo
from .distiller import Distiller
from .metric import Metric
from .pcb import PCB
from .pcb_online import PcbOnline
from .pcb import PCB
@@ -1,167 +1,37 @@
# coding: utf-8
"""
Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: Person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
Implements the exact same PCB model as the online (production) version.
"""
from typing import Union
import math

import torch
from torch import nn
import torch.nn.functional as F

from fastreid.config import configurable
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy, contrastive_loss
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
from fastreid.layers import weights_init_classifier
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy


@META_ARCH_REGISTRY.register()
class PCB(Baseline):

    @configurable
    def __init__(
            self,
            *,
            backbone,
            heads,
            pixel_mean,
            pixel_std,
            part_num,
            part_dim,
            embedding_dim,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone:
            heads:
            pixel_mean:
            pixel_std:
            part_num:
        """
        super(PCB, self).__init__(
            backbone=backbone,
            heads=heads,
            pixel_mean=pixel_mean,
            pixel_std=pixel_std,
            loss_kwargs=loss_kwargs
        )
        self.part_num = part_num
        self.part_dim = part_dim
        self.embedding_dim = embedding_dim
        self.modify_backbone()
        self.random_init()

    def modify_backbone(self):
        self.backbone.avgpool_e = nn.AdaptiveAvgPool2d((1, self.part_num))

        # cnn feature
        self.resnet_conv = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
            self.backbone.layer3,
            self.backbone.layer4,
        )
        self.layer5 = nn.Sequential(
            self.backbone._make_layer(block=self.backbone.layer4[-1].__class__,
                                      planes=512, blocks=1, stride=2, bn_norm='BN', with_se=False),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.pool_e = nn.Sequential(self.backbone.avgpool_e)

        # embedding
        for i in range(self.part_num):
            name = 'embedder' + str(i)
            setattr(self, name, nn.Linear(self.embedding_dim, self.part_dim))

    @classmethod
    def from_config(cls, cfg):
        config_dict = super(PCB, cls).from_config(cfg)
        config_dict['part_num'] = cfg.MODEL.PCB.PART_NUM
        config_dict['part_dim'] = cfg.MODEL.PCB.PART_DIM
        config_dict['embedding_dim'] = cfg.MODEL.PCB.EMBEDDING_DIM
        return config_dict

    def random_init(self) -> None:
        for m in self.layer5.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        for i in range(self.part_num):
            embedder_i = getattr(self, 'embedder' + str(i))
            embedder_i.apply(weights_init_classifier)

    def forward(self, batched_inputs):
        # preprocess image
        images = self.preprocess_image(batched_inputs)
        bsz = int(images.size(0) / 2)
        feats = self.backbone(images)
        feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
        feats = F.normalize(feats, p=2.0, dim=-1)

        # backbone: extract global features and local features
        features = self.resnet_conv(images)
        features_full = torch.squeeze(self.layer5(features))
        features_part = torch.squeeze(self.pool_e(features))
        qf = feats[0: bsz * 2: 2, ...]
        xf = feats[1: bsz * 2: 2, ...]
        outputs = self.heads({'query': qf, 'gallery': xf})

        embeddings_list = []
        for i in range(self.part_num):
            if self.part_num == 1:
                features_i = features_part
            else:
                features_i = torch.squeeze(features_part[:, :, i])

            embedder_i = getattr(self, 'embedder' + str(i))
            embedding_i = embedder_i(features_i)
            embeddings_list.append(embedding_i)

        all_features = {'full': features_full, 'parts': embeddings_list}
        outputs['query_feature'] = qf
        outputs['gallery_feature'] = xf
        outputs['features'] = {}
        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"]

            # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
            # may be larger than that in the original dataset, so the circle/arcface will
            # throw an error. We just set all the targets to 0 to avoid this problem.
            if targets.sum() < 0: targets.zero_()

            outputs = self.heads(all_features, targets)
            losses = self.losses(outputs, targets)  # the loss here is problematic
            targets = batched_inputs['targets']
            losses = self.losses(outputs, targets)
            return losses
        else:
            outputs = self.heads(all_features)
            return outputs
            return outputs

    def losses(self, outputs, gt_labels):
        """
        Compute loss from modeling's outputs, the loss function input arguments
        must be the same as the outputs of the model forwarding.
        """
        # model predictions
        pred_class_logits = outputs['pred_class_logits'].detach()
        cls_outputs = outputs['cls_outputs']

        # Log prediction accuracy
        log_accuracy(pred_class_logits, gt_labels)

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if 'CrossEntropyLoss' in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            loss_dict['loss_cls'] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                ce_kwargs.get('eps'),
                ce_kwargs.get('alpha')
            ) * ce_kwargs.get('scale')

        return loss_dict
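For reference, a minimal self-contained sketch of the part-pooling idea the PCB head above relies on: pool the backbone feature map into part_num stripes with AdaptiveAvgPool2d((1, part_num)), then project each stripe with its own linear embedder. The tensor sizes below are illustrative only, not the project's actual configuration.

import torch
from torch import nn

# Illustrative sizes only: 2048-channel feature map, 3 parts, 256-dim part embeddings.
N, C, H, W = 4, 2048, 16, 16
part_num, part_dim = 3, 256

feature_map = torch.randn(N, C, H, W)        # stand-in for the backbone output
pool = nn.AdaptiveAvgPool2d((1, part_num))   # -> [N, C, 1, part_num]
embedders = nn.ModuleList([nn.Linear(C, part_dim) for _ in range(part_num)])

parts = pool(feature_map).squeeze(2)         # [N, C, part_num]
part_embeddings = [embedders[i](parts[:, :, i]) for i in range(part_num)]
print([tuple(e.shape) for e in part_embeddings])   # three tensors of shape (4, 256)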
@@ -1,37 +0,0 @@
# coding: utf-8
"""
Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: Person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
Implements the exact same PCB model as the online (production) version.
"""
import torch
import torch.nn.functional as F

from fastreid.modeling.losses import cross_entropy_loss, log_accuracy, contrastive_loss
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class PcbOnline(Baseline):

    def forward(self, batched_inputs):
        images = self.preprocess_image(batched_inputs)
        bsz = int(images.size(0) / 2)
        feats = self.backbone(images)
        feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
        feats = F.normalize(feats, p=2.0, dim=-1)

        qf = feats[0: bsz * 2: 2, ...]
        xf = feats[1: bsz * 2: 2, ...]
        outputs = self.heads({'query': qf, 'gallery': xf})

        outputs['query_feature'] = qf
        outputs['gallery_feature'] = xf
        outputs['features'] = {}
        if self.training:
            targets = batched_inputs['targets']
            losses = self.losses(outputs, targets)
            return losses
        else:
            return outputs
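The qf/xf slicing used in both meta-architectures assumes each batch interleaves the query and gallery image of every pair. A small sketch of that convention with made-up features:

import torch

pairs = 3
feats = torch.randn(pairs * 2, 8)   # rows are [q0, g0, q1, g1, q2, g2]; 8-dim features for illustration

qf = feats[0::2]                    # query features:   rows 0, 2, 4
xf = feats[1::2]                    # gallery features: rows 1, 3, 5
assert qf.shape == xf.shape == (pairs, 8)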
@@ -1,6 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/8 16:55:17
# @Author : zuchen.wang@vipshop.com
# @File : __init__.py.py
from .pair_dataset import PairDataset
from .excel_dataset import ExcelDataset
@@ -1,74 +0,0 @@
# coding: utf-8
import os
import logging

import pandas as pd
from tabulate import tabulate
from termcolor import colored

from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.utils.env import seed_all_rng


@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):

    _logger = logging.getLogger('fastreid.fastshoe')

    def __init__(self, img_root, anno_path, transform=None, **kwargs):
        self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
        seed_all_rng(12345)

        self.img_root = img_root
        self.anno_path = anno_path
        self.transform = transform

        df = pd.read_csv(self.anno_path)
        df = df[['内网crop图', '外网crop图', '确认是否撞款']]
        df['确认是否撞款'] = df['确认是否撞款'].map({'是': 1, '否': 0})
        self.df = df

    def __getitem__(self, idx):
        image_inner, image_outer, label = tuple(self.df.loc[idx])

        image_inner_path = os.path.join(self.img_root, image_inner)
        image_outer_path = os.path.join(self.img_root, image_outer)

        img1 = read_image(image_inner_path)
        img2 = read_image(image_outer_path)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return {
            'img1': img1,
            'img2': img2,
            'target': label
        }

    def __len__(self):
        return len(self.df)

    # ------------- auxiliary information below ------------- #
    @property
    def num_classes(self):
        return 2

    def show_test(self):
        num_pairs = len(self)
        num_images = num_pairs * 2

        headers = ['pairs', 'images']
        csv_results = [[num_pairs, num_images]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )
        self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))
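To make the annotation format concrete, a self-contained sketch of the label mapping done in __init__ above, with made-up rows. The column names follow the CSV: '内网crop图' is the inner-site crop path, '外网crop图' the outer-site crop path, and '确认是否撞款' whether the pair is confirmed as the same design ('是' yes / '否' no).

import pandas as pd

df = pd.DataFrame({
    '内网crop图': ['a/1.jpg', 'a/2.jpg'],   # hypothetical image paths
    '外网crop图': ['b/1.jpg', 'b/2.jpg'],
    '确认是否撞款': ['是', '否'],
})
df['确认是否撞款'] = df['确认是否撞款'].map({'是': 1, '否': 0})
print(df)   # the label column becomes 1 (same-design pair) / 0 (different-design pair)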
@@ -1,121 +0,0 @@
# -*- coding: utf-8 -*-

import os
import logging
import json
import random

import pandas as pd
from tabulate import tabulate
from termcolor import colored

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng


@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):

    _logger = logging.getLogger('fastreid.fastshoe')

    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):

        assert mode in ('train', 'val', 'test'), self._logger.info(
            '''mode should be one of ('train', 'val', 'test')''')
        self.mode = mode
        if self.mode != 'train':
            self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
            seed_all_rng(12345)

        self.img_root = img_root
        self.anno_path = anno_path
        self.transform = transform

        all_data = json.load(open(self.anno_path))
        pos_folders = []
        neg_folders = []
        for data in all_data:
            if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
                pos_folders.append(data['positive_img_list'])
                neg_folders.append(data['negative_img_list'])

        assert len(pos_folders) == len(neg_folders), self._logger.error('len(self.pos_folders) should be equal to len(self.neg_folders)')
        self.pos_folders = pos_folders
        self.neg_folders = neg_folders

    def __len__(self):
        if self.mode == 'test':
            return len(self.pos_folders) * 10

        return len(self.pos_folders)

    def __getitem__(self, idx):
        if self.mode == 'test':
            idx = int(idx / 10)

        pf = self.pos_folders[idx]
        nf = self.neg_folders[idx]

        label = 1
        if random.random() < 0.5:
            # generate positive pair
            img_path1, img_path2 = random.sample(pf, 2)
        else:
            # generate negative pair
            label = 0
            img_path1, img_path2 = random.choice(pf), random.choice(nf)

        img_path1 = os.path.join(self.img_root, img_path1)
        img_path2 = os.path.join(self.img_root, img_path2)

        img1 = read_image(img_path1)
        img2 = read_image(img_path2)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return {
            'img1': img1,
            'img2': img2,
            'target': label
        }

    # ------------- auxiliary information below ------------- #
    @property
    def num_classes(self):
        return 2

    @property
    def num_folders(self):
        return len(self)

    @property
    def num_pos_images(self):
        return sum([len(x) for x in self.pos_folders])

    @property
    def num_neg_images(self):
        return sum([len(x) for x in self.neg_folders])

    def describe(self):
        headers = ['subset', 'folders', 'pos images', 'neg images']
        csv_results = [[self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )

        self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))

    def show_train(self):
        return self.describe()

    def show_test(self):
        return self.describe()
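A self-contained sketch of the 50/50 pair-sampling rule in __getitem__ above: a positive pair takes two distinct images from the positive folder, a negative pair takes one image from each folder. Folder contents are made up.

import random

def sample_pair(pos_folder, neg_folder):
    """Return (path1, path2, label): 1 for a positive pair, 0 for a negative pair."""
    if random.random() < 0.5:
        img1, img2 = random.sample(pos_folder, 2)                    # two distinct positives
        return img1, img2, 1
    return random.choice(pos_folder), random.choice(neg_folder), 0   # one positive vs. one negative

pos = ['shoe_a/1.jpg', 'shoe_a/2.jpg', 'shoe_a/3.jpg']   # illustrative paths
neg = ['shoe_a_counterfeit/1.jpg']
print(sample_pair(pos, neg))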
@@ -1,90 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 9:47:53
# @Author : zuchen.wang@vipshop.com
# @File : trainer.py
import logging
import os

import torch

from fastreid.data.build import _root
from fastreid.engine import DefaultTrainer
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.utils import comm
from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation import PairScoreEvaluator, PairDistanceEvaluator
# ensure custom datasets are registered
import projects.FastShoe.fastshoe.data


class PairTrainer(DefaultTrainer):

    _logger = logging.getLogger('fastreid.fastshoe')

    @classmethod
    def build_train_loader(cls, cfg):
        cls._logger.info("Prepare training set")

        transforms = build_transforms(cfg, is_train=True)
        img_root = os.path.join(_root, '20211115/rotate_shoe_crop_images')
        anno_path = os.path.join(_root, 'labels/1115/train_1115.json')
        cls._logger.info('Loading {} with {}.'.format(img_root, anno_path))

        datasets = []
        for d in cfg.DATASETS.NAMES:
            dataset = DATASET_REGISTRY.get(d)(img_root=img_root, anno_path=anno_path, transform=transforms, mode='train')
            if comm.is_main_process():
                dataset.show_train()
            datasets.append(dataset)

        train_set = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
        data_loader = build_reid_train_loader(cfg, train_set=train_set)
        return data_loader

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        cls._logger.info("Prepare {} set".format('test' if cfg.eval_only else 'validation'))

        transforms = build_transforms(cfg, is_train=False)
        if dataset_name == 'PairDataset':
            img_root = os.path.join(_root, '20211115/rotate_shoe_crop_images')
            val_json = os.path.join(_root, 'labels/1115/val_1115.json')
            test_json = os.path.join(_root, 'labels/1115/val_1115.json')

            anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
            cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
            test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
            test_set.show_test()

        elif dataset_name == 'ExcelDataset':
            img_root_0830 = os.path.join(_root, 'excel/0830/rotate_shoe_crop_images')
            test_csv_0830 = os.path.join(_root, 'excel/0830/excel_pair_crop.csv')

            img_root_0908 = os.path.join(_root, 'excel/0908/rotate_shoe_crop_images')
            val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
            test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
            if cfg.eval_only:
                cls._logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
                test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
                test_set_0830.show_test()

                cls._logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
                test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
                test_set_0908.show_test()

                test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
            else:
                cls._logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
                test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
                test_set.show_test()
        else:
            raise ValueError("Undefined Dataset!!!")

        data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
        return data_loader

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        data_loader = cls.build_test_loader(cfg, dataset_name)
        return data_loader, PairScoreEvaluator(cfg, output_dir)
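Both loader builders above fall back to torch.utils.data.ConcatDataset when more than one dataset is involved; a minimal illustration of how indexing spans the concatenated parts:

import torch
from torch.utils.data import ConcatDataset, TensorDataset

ds_a = TensorDataset(torch.arange(3))         # 3 samples
ds_b = TensorDataset(torch.arange(100, 105))  # 5 samples
combined = ConcatDataset([ds_a, ds_b])

print(len(combined))   # 8
print(combined[0])     # (tensor(0),)   -> from ds_a
print(combined[3])     # (tensor(100),) -> first sample of ds_b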
@@ -1,7 +1,7 @@
_BASE_: base.yaml

MODEL:
  META_ARCHITECTURE: PcbOnline
  META_ARCHITECTURE: PCB

  PCB:
    PART_NUM: 3
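For this yaml to merge cleanly, the MODEL.PCB keys that PCB.from_config reads must already exist on the config tree. A hedged sketch, assuming a yacs-style CfgNode as fastreid's config is built on; the default values are placeholders and would normally be registered on get_cfg() before merge_from_file:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.PCB = CN()
cfg.MODEL.PCB.PART_NUM = 3          # placeholder defaults, overridden by the yaml above
cfg.MODEL.PCB.PART_DIM = 256
cfg.MODEL.PCB.EMBEDDING_DIM = 2048
print(cfg.MODEL.PCB.PART_NUM)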
@@ -15,7 +15,7 @@ from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):

    _logger = logging.getLogger('fastreid.fastshoe')
    _logger = logging.getLogger('fastreid.shoe.data.excel')

    def __init__(self, img_root, anno_path, transform=None, **kwargs):
        self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
@@ -23,7 +23,7 @@ from .augment import augment_pos_image, augment_neg_image

class ShoeDataset(ImageDataset):

    _logger = logging.getLogger('fastreid.shoe.dataset')
    _logger = logging.getLogger('fastreid.shoe.data')

    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
        if mode not in ('train', 'val', 'test'):
@@ -15,12 +15,12 @@ from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation import ShoeScoreEvaluator, ShoeDistanceEvaluator
# ensure custom datasets are registered
import projects.ShoeClas.shoeclas.data
import projects.Shoe.shoe.data


class PairTrainer(DefaultTrainer):

    _logger = logging.getLogger('fastreid.shoeclas')
    _logger = logging.getLogger('fastreid.shoe')

    @classmethod
    def build_train_loader(cls, cfg):
@@ -13,7 +13,7 @@ from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils import bughook
from fastshoe import PairTrainer
from shoe import PairTrainer

@@ -0,0 +1,50 @@
# coding: utf-8
import os
from collections import defaultdict
import sys
import shutil

sys.path.append('')

from fastreid.utils.env import seed_all_rng
from fastreid.data.datasets import DATASET_REGISTRY

import projects.Shoe.shoe.data

seed_all_rng(0)

save_root = 'debug/neg_aug'
if os.path.exists(save_root):
    shutil.rmtree(save_root)
os.mkdir(save_root)

root = '/data97/bijia/shoe/'
img_root = os.path.join(root, 'shoe_crop_all_images')
anno_path = os.path.join(root, 'labels/1102/train_1102.json')
dataset = DATASET_REGISTRY.get('PairDataset')(img_root=img_root, anno_path=anno_path, transform=None, mode='train')

pos_imgs = []
neg_imgs = []
for i in range(100):
    data = dataset[100]
    img1 = data['img1']
    img2 = data['img2']
    target = data['target']

    if target == 0:
        pos_imgs.append(img1)
        neg_imgs.append(img2)
    else:
        pos_imgs.append(img1)
        pos_imgs.append(img2)

pos_dict = defaultdict(list)
for img in pos_imgs:
    pos_dict[img.size].append(img)

for i, k in enumerate(pos_dict.keys()):
    img = pos_dict[k][0]
    img.save(os.path.join(save_root, 'p-' + str(i) + '.jpg'))

for i, img in enumerate(neg_imgs):
    img.save(os.path.join(save_root, 'n-' + str(i) + '.jpg'))
@@ -1,63 +0,0 @@
INPUT:
  SIZE_TRAIN: [0, 0] # resize size
  SIZE_TEST: [0, 0]

  CROP:
    ENABLED: False

  FLIP:
    ENABLED: False

  PADDING:
    ENABLED: False

  CJ:
    ENABLED: False

  AFFINE:
    ENABLED: False

  AUTOAUG:
    ENABLED: False

  AUGMIX:
    ENABLED: False

  REA:
    ENABLED: False

  RPT:
    ENABLED: False


DATALOADER:
  SAMPLER_TRAIN: TrainingSampler
  NUM_WORKERS: 8

SOLVER:
  MAX_EPOCH: 1000
  AMP:
    ENABLED: True

  OPT: SGD
  SCHED: CosineAnnealingLR

  BASE_LR: 0.001
  MOMENTUM: 0.9
  NESTEROV: False

  BIAS_LR_FACTOR: 1.
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.
  IMS_PER_BATCH: 16

  ETA_MIN_LR: 0.00003

  WARMUP_FACTOR: 0.1
  WARMUP_ITERS: 100

  CHECKPOINT_PERIOD: 1

TEST:
  EVAL_PERIOD: 1
  IMS_PER_BATCH: 32
@@ -1,5 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 10:10:47
# @Author : zuchen.wang@vipshop.com
# @File : __init__.py.py
from .trainer import PairTrainer
@@ -1,115 +0,0 @@
# -*- coding: utf-8 -*-

import os
import logging
import json
import random

import pandas as pd
from tabulate import tabulate
from termcolor import colored

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng


@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):

    _logger = logging.getLogger('fastreid.fastshoe')

    def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):

        assert mode in ('train', 'val', 'test'), self._logger.info(
            '''mode should be one of ('train', 'val', 'test')''')
        self.mode = mode
        if self.mode != 'train':
            self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
            seed_all_rng(12345)

        self.img_root = img_root
        self.anno_path = anno_path
        self.transform = transform

        all_data = json.load(open(self.anno_path))
        pos_folders = []
        neg_folders = []
        for data in all_data:
            if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
                pos_folders.append(data['positive_img_list'])
                neg_folders.append(data['negative_img_list'])

        assert len(pos_folders) == len(neg_folders), self._logger.error('len(self.pos_folders) should be equal to len(self.neg_folders)')
        self.pos_folders = pos_folders
        self.neg_folders = neg_folders

    def __len__(self):
        return len(self.pos_folders)

    def __getitem__(self, idx):
        pf = self.pos_folders[idx]
        nf = self.neg_folders[idx]

        label = 1
        if random.random() < 0.5:
            # generate positive pair
            img_path1, img_path2 = random.sample(pf, 2)
        else:
            # generate negative pair
            label = 0
            img_path1, img_path2 = random.choice(pf), random.choice(nf)

        img_path1 = os.path.join(self.img_root, img_path1)
        img_path2 = os.path.join(self.img_root, img_path2)

        img1 = read_image(img_path1)
        img2 = read_image(img_path2)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return {
            'img1': img1,
            'img2': img2,
            'target': label
        }

    # ------------- auxiliary information below ------------- #
    @property
    def num_classes(self):
        return 2

    @property
    def num_folders(self):
        return len(self)

    @property
    def num_pos_images(self):
        return sum([len(x) for x in self.pos_folders])

    @property
    def num_neg_images(self):
        return sum([len(x) for x in self.neg_folders])

    def describe(self):
        headers = ['subset', 'folders', 'pos images', 'neg images']
        csv_results = [[self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]]

        # tabulate it
        table = tabulate(
            csv_results,
            tablefmt="pipe",
            headers=headers,
            numalign="left",
        )

        self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))

    def show_train(self):
        return self.describe()

    def show_test(self):
        return self.describe()
@@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 10:12:28
# @Author : zuchen.wang@vipshop.com
# @File : train_net.py.py

import json
import os
import sys

sys.path.append('.')

from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils import bughook
from shoeclas import PairTrainer


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    setattr(cfg, 'eval_only', args.eval_only)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)

    if args.eval_only:
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False

        model = PairTrainer.build_model(cfg)
        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model
        res = PairTrainer.test(cfg, model)
        return res
    else:
        trainer = PairTrainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url='auto',
        args=(args,),
    )