Refactor: make pcb and clas in one sub project

pull/608/head
zuchen.wang 2021-11-22 16:13:08 +08:00
parent 9c4c6018a3
commit e1d069abc7
24 changed files with 73 additions and 725 deletions

View File

@ -13,5 +13,4 @@ from .mgn import MGN
from .moco import MoCo
from .distiller import Distiller
from .metric import Metric
from .pcb import PCB
from .pcb_online import PcbOnline
from .pcb import PCB

View File

@ -1,167 +1,37 @@
# coding: utf-8
"""
Sun, Y. , Zheng, L. , Yang, Y. , Tian, Q. , & Wang, S. . (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
实现和线上一模一样的PCB
"""
from typing import Union
import math
import torch
from torch import nn
import torch.nn.functional as F
from fastreid.config import configurable
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy, contrastive_loss
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
from fastreid.layers import weights_init_classifier
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
@META_ARCH_REGISTRY.register()
class PCB(Baseline):
@configurable
def __init__(
self,
*,
backbone,
heads,
pixel_mean,
pixel_std,
part_num,
part_dim,
embedding_dim,
loss_kwargs=None
):
"""
NOTE: this interface is experimental.
Args:
backbone:
heads:
pixel_mean:
pixel_std:
part_num
"""
super(PCB, self).__init__(
backbone=backbone,
heads=heads,
pixel_mean=pixel_mean,
pixel_std=pixel_std,
loss_kwargs=loss_kwargs
)
self.part_num = part_num
self.part_dim = part_dim
self.embedding_dim = embedding_dim
self.modify_backbone()
self.random_init()
def modify_backbone(self):
self.backbone.avgpool_e = nn.AdaptiveAvgPool2d((1, self.part_num))
# cnn feature
self.resnet_conv = nn.Sequential(
self.backbone.conv1,
self.backbone.bn1,
self.backbone.relu,
self.backbone.maxpool,
self.backbone.layer1,
self.backbone.layer2,
self.backbone.layer3,
self.backbone.layer4,
)
self.layer5 = nn.Sequential(
self.backbone._make_layer(block=self.backbone.layer4[-1].__class__,
planes=512, blocks=1, stride=2,bn_norm='BN', with_se=False),
nn.AdaptiveAvgPool2d((1, 1)),
)
self.pool_e = nn.Sequential(self.backbone.avgpool_e)
# embedding
for i in range(self.part_num):
name = 'embedder' + str(i)
setattr(self, name, nn.Linear(self.embedding_dim, self.part_dim))
@classmethod
def from_config(cls, cfg):
config_dict = super(PCB, cls).from_config(cfg)
config_dict['part_num'] = cfg.MODEL.PCB.PART_NUM
config_dict['part_dim'] = cfg.MODEL.PCB.PART_DIM
config_dict['embedding_dim'] = cfg.MODEL.PCB.EMBEDDING_DIM
return config_dict
def random_init(self) -> None:
for m in self.layer5.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
for i in range(self.part_num):
embedder_i = getattr(self, 'embedder' + str(i))
embedder_i.apply(weights_init_classifier)
def forward(self, batched_inputs):
# preprocess image
images = self.preprocess_image(batched_inputs)
bsz = int(images.size(0) / 2)
feats = self.backbone(images)
feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
feats = F.normalize(feats, p=2.0, dim=-1)
# backbone: extract global features and local features
features = self.resnet_conv(images)
features_full = torch.squeeze(self.layer5(features))
features_part = torch.squeeze(self.pool_e(features))
qf = feats[0: bsz * 2: 2, ...]
xf = feats[1: bsz * 2: 2, ...]
outputs = self.heads({'query': qf, 'gallery': xf})
embeddings_list = []
for i in range(self.part_num):
if self.part_num == 1:
features_i = features_part
else:
features_i = torch.squeeze(features_part[:, :, i])
embedder_i = getattr(self, 'embedder' + str(i))
embedding_i = embedder_i(features_i)
embeddings_list.append(embedding_i)
all_features = {'full': features_full, 'parts': embeddings_list}
outputs['query_feature'] = qf
outputs['gallery_feature'] = xf
outputs['features'] = {}
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"]
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
# may be larger than that in the original dataset, so the circle/arcface will
# throw an error. We just set all the targets to 0 to avoid this problem.
if targets.sum() < 0: targets.zero_()
outputs = self.heads(all_features, targets)
losses = self.losses(outputs, targets) # 损失有问题
targets = batched_inputs['targets']
losses = self.losses(outputs, targets)
return losses
else:
outputs = self.heads(all_features)
return outputs
return outputs
def losses(self, outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
pred_class_logits = outputs['pred_class_logits'].detach()
cls_outputs = outputs['cls_outputs']
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
return loss_dict

View File

@ -1,37 +0,0 @@
# coding: utf-8
"""
Sun, Y. , Zheng, L. , Yang, Y. , Tian, Q. , & Wang, S. . (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
实现和线上一模一样的PCB
"""
import torch
import torch.nn.functional as F
from fastreid.modeling.losses import cross_entropy_loss, log_accuracy, contrastive_loss
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PcbOnline(Baseline):
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
bsz = int(images.size(0) / 2)
feats = self.backbone(images)
feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
feats = F.normalize(feats, p=2.0, dim=-1)
qf = feats[0: bsz * 2: 2, ...]
xf = feats[1: bsz * 2: 2, ...]
outputs = self.heads({'query': qf, 'gallery': xf})
outputs['query_feature'] = qf
outputs['gallery_feature'] = xf
outputs['features'] = {}
if self.training:
targets = batched_inputs['targets']
losses = self.losses(outputs, targets)
return losses
else:
return outputs

View File

@ -1,6 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/8 16:55:17
# @Author : zuchen.wang@vipshop.com
# @File : __init__.py.py
from .pair_dataset import PairDataset
from .excel_dataset import ExcelDataset

View File

@ -1,74 +0,0 @@
# coding: utf-8
import os
import logging
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.data_utils import read_image
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
def __init__(self, img_root, anno_path, transform=None, **kwargs):
self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))
seed_all_rng(12345)
self.img_root = img_root
self.anno_path = anno_path
self.transform = transform
df = pd.read_csv(self.anno_path)
df = df[['内网crop图', '外网crop图', '确认是否撞款']]
df['确认是否撞款'] = df['确认是否撞款'].map({'': 1, '': 0})
self.df = df
def __getitem__(self, idx):
image_inner, image_outer, label = tuple(self.df.loc[idx])
image_inner_path = os.path.join(self.img_root, image_inner)
image_outer_path = os.path.join(self.img_root, image_outer)
img1 = read_image(image_inner_path)
img2 = read_image(image_outer_path)
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return {
'img1': img1,
'img2': img2,
'target': label
}
def __len__(self):
return len(self.df)
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
return 2
def show_test(self):
num_pairs = len(self)
num_images = num_pairs * 2
headers = ['pairs', 'images']
csv_results = [[num_pairs, num_images]]
# tabulate it
table = tabulate(
csv_results,
tablefmt="pipe",
headers=headers,
numalign="left",
)
self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))

View File

@ -1,121 +0,0 @@
# -*- coding: utf-8 -*-
import os
import logging
import json
import random
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
assert mode in ('train', 'val', 'test'), self._logger.info(
'''mode should the one of ('train', 'val', 'test')''')
self.mode = mode
if self.mode != 'train':
self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
seed_all_rng(12345)
self.img_root = img_root
self.anno_path = anno_path
self.transform = transform
all_data = json.load(open(self.anno_path))
pos_folders = []
neg_folders = []
for data in all_data:
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
pos_folders.append(data['positive_img_list'])
neg_folders.append(data['negative_img_list'])
assert len(pos_folders) == len(neg_folders), self._logger.error('the len of self.pos_foders should be equal to self.pos_foders')
self.pos_folders = pos_folders
self.neg_folders = neg_folders
def __len__(self):
if self.mode == 'test':
return len(self.pos_folders) * 10
return len(self.pos_folders)
def __getitem__(self, idx):
if self.mode == 'test':
idx = int(idx / 10)
pf = self.pos_folders[idx]
nf = self.neg_folders[idx]
label = 1
if random.random() < 0.5:
# generate positive pair
img_path1, img_path2 = random.sample(pf, 2)
else:
# generate negative pair
label = 0
img_path1, img_path2 = random.choice(pf), random.choice(nf)
img_path1 = os.path.join(self.img_root, img_path1)
img_path2 = os.path.join(self.img_root, img_path2)
img1 = read_image(img_path1)
img2 = read_image(img_path2)
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return {
'img1': img1,
'img2': img2,
'target': label
}
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
return 2
@property
def num_folders(self):
return len(self)
@property
def num_pos_images(self):
return sum([len(x) for x in self.pos_folders])
@property
def num_neg_images(self):
return sum([len(x) for x in self.neg_folders])
def describe(self):
headers = ['subset', 'folders', 'pos images', 'neg images']
csv_results = [[self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]]
# tabulate it
table = tabulate(
csv_results,
tablefmt="pipe",
headers=headers,
numalign="left",
)
self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))
def show_train(self):
return self.describe()
def show_test(self):
return self.describe()

View File

@ -1,90 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 9:47:53
# @Author : zuchen.wang@vipshop.com
# @File : trainer.py
import logging
import os
import torch
from fastreid.data.build import _root
from fastreid.engine import DefaultTrainer
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.utils import comm
from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation import PairScoreEvaluator, PairDistanceEvaluator
# ensure custom datasets are registered
import projects.FastShoe.fastshoe.data
class PairTrainer(DefaultTrainer):
_logger = logging.getLogger('fastreid.fastshoe')
@classmethod
def build_train_loader(cls, cfg):
cls._logger.info("Prepare training set")
transforms = build_transforms(cfg, is_train=True)
img_root = os.path.join(_root, '20211115/rotate_shoe_crop_images')
anno_path = os.path.join(_root, 'labels/1115/train_1115.json')
cls._logger.info('Loading {} with {}.'.format(img_root, anno_path))
datasets = []
for d in cfg.DATASETS.NAMES:
dataset = DATASET_REGISTRY.get(d)(img_root=img_root, anno_path=anno_path, transform=transforms, mode='train')
if comm.is_main_process():
dataset.show_train()
datasets.append(dataset)
train_set = datasets[0] if len(datasets) == 1 else torch.utils.data.ConcatDataset(datasets)
data_loader = build_reid_train_loader(cfg, train_set=train_set)
return data_loader
@classmethod
def build_test_loader(cls, cfg, dataset_name):
cls._logger.info("Prepare {} set".format('test' if cfg.eval_only else 'validation'))
transforms = build_transforms(cfg, is_train=False)
if dataset_name == 'PairDataset':
img_root = os.path.join(_root, '20211115/rotate_shoe_crop_images')
val_json = os.path.join(_root, 'labels/1115/val_1115.json')
test_json = os.path.join(_root, 'labels/1115/val_1115.json')
anno_path, mode = (test_json, 'test') if cfg.eval_only else (val_json, 'val')
cls._logger.info('Loading {} with {} for {}.'.format(img_root, anno_path, mode))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root, anno_path=anno_path, transform=transforms, mode=mode)
test_set.show_test()
elif dataset_name == 'ExcelDataset':
img_root_0830 = os.path.join(_root, 'excel/0830/rotate_shoe_crop_images')
test_csv_0830 = os.path.join(_root, 'excel/0830/excel_pair_crop.csv')
img_root_0908 = os.path.join(_root, 'excel/0908/rotate_shoe_crop_images')
val_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_val.csv')
test_csv_0908 = os.path.join(_root, 'excel/0908/excel_pair_crop_test.csv')
if cfg.eval_only:
cls._logger.info('Loading {} with {} for test.'.format(img_root_0830, test_csv_0830))
test_set_0830 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0830, anno_path=test_csv_0830, transform=transforms)
test_set_0830.show_test()
cls._logger.info('Loading {} with {} for test.'.format(img_root_0908, test_csv_0908))
test_set_0908 = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=test_csv_0908, transform=transforms)
test_set_0908.show_test()
test_set = torch.utils.data.ConcatDataset((test_set_0830, test_set_0908))
else:
cls._logger.info('Loading {} with {} for validation.'.format(img_root_0908, val_csv_0908))
test_set = DATASET_REGISTRY.get(dataset_name)(img_root=img_root_0908, anno_path=val_csv_0908, transform=transforms)
test_set.show_test()
else:
raise ValueError("Undefined Dataset!!!")
data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
return data_loader
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_dir=None):
data_loader = cls.build_test_loader(cfg, dataset_name)
return data_loader, PairScoreEvaluator(cfg, output_dir)

View File

@ -1,7 +1,7 @@
_BASE_: base.yaml
MODEL:
META_ARCHITECTURE: PcbOnline
META_ARCHITECTURE: PCB
PCB:
PART_NUM: 3

View File

@ -15,7 +15,7 @@ from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class ExcelDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
_logger = logging.getLogger('fastreid.shoe.data.excel')
def __init__(self, img_root, anno_path, transform=None, **kwargs):
self._logger.info('set with {} random seed: 12345'.format(self.__class__.__name__))

View File

@ -23,7 +23,7 @@ from .augment import augment_pos_image, augment_neg_image
class ShoeDataset(ImageDataset):
_logger = logging.getLogger('fastreid.shoe.dataset')
_logger = logging.getLogger('fastreid.shoe.data')
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
if mode not in ('train', 'val', 'test'):

View File

@ -15,12 +15,12 @@ from fastreid.data.transforms import build_transforms
from fastreid.data.build import build_reid_train_loader, build_reid_test_loader
from fastreid.evaluation import ShoeScoreEvaluator, ShoeDistanceEvaluator
# ensure custom datasets are registered
import projects.ShoeClas.shoeclas.data
import projects.Shoe.shoe.data
class PairTrainer(DefaultTrainer):
_logger = logging.getLogger('fastreid.shoeclas')
_logger = logging.getLogger('fastreid.shoe')
@classmethod
def build_train_loader(cls, cfg):

View File

@ -13,7 +13,7 @@ from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils import bughook
from fastshoe import PairTrainer
from shoe import PairTrainer

View File

@ -0,0 +1,50 @@
# coding: utf-8
import os
from collections import defaultdict
import sys
import shutil
sys.path.append('')
from fastreid.utils.env import seed_all_rng
from fastreid.data.datasets import DATASET_REGISTRY
import projects.Shoe.shoe.data
seed_all_rng(0)
save_root = 'debug/neg_aug'
if os.path.exists(save_root):
shutil.rmtree(save_root)
os.mkdir(save_root)
root = '/data97/bijia/shoe/'
img_root=os.path.join(root, 'shoe_crop_all_images')
anno_path=os.path.join(root, 'labels/1102/train_1102.json')
dataset = DATASET_REGISTRY.get('PairDataset')(img_root=img_root, anno_path=anno_path, transform=None, mode='train')
pos_imgs = []
neg_imgs = []
for i in range(100):
data = dataset[100]
img1 = data['img1']
img2 = data['img2']
target = data['target']
if target == 0:
pos_imgs.append(img1)
neg_imgs.append(img2)
else:
pos_imgs.append(img1)
pos_imgs.append(img2)
pos_dict = defaultdict(list)
for img in pos_imgs:
pos_dict[img.size].append(img)
for i, k in enumerate(pos_dict.keys()):
img = pos_dict[k][0]
img.save(os.path.join(save_root, 'p-' + str(i) + '.jpg'))
for i, img in enumerate(neg_imgs):
img.save(os.path.join(save_root, 'n-' + str(i) + '.jpg'))

View File

@ -1,63 +0,0 @@
INPUT:
SIZE_TRAIN: [0, 0] # resize size
SIZE_TEST: [0, 0]
CROP:
ENABLED: False
FLIP:
ENABLED: False
PADDING:
ENABLED: False
CJ:
ENABLED: False
AFFINE:
ENABLED: False
AUTOAUG:
ENABLED: False
AUGMIX:
ENABLED: False
REA:
ENABLED: False
RPT:
ENABLED: False
DATALOADER:
SAMPLER_TRAIN: TrainingSampler
NUM_WORKERS: 8
SOLVER:
MAX_EPOCH: 1000
AMP:
ENABLED: True
OPT: SGD
SCHED: CosineAnnealingLR
BASE_LR: 0.001
MOMENTUM: 0.9
NESTEROV: False
BIAS_LR_FACTOR: 1.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.
IMS_PER_BATCH: 16
ETA_MIN_LR: 0.00003
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 100
CHECKPOINT_PERIOD: 1
TEST:
EVAL_PERIOD: 1
IMS_PER_BATCH: 32

View File

@ -1,5 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 10:10:47
# @Author : zuchen.wang@vipshop.com
# @File : __init__.py.py
from .trainer import PairTrainer

View File

@ -1,115 +0,0 @@
# -*- coding: utf-8 -*-
import os
import logging
import json
import random
import pandas as pd
from tabulate import tabulate
from termcolor import colored
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
from fastreid.data.data_utils import read_image
from fastreid.utils.env import seed_all_rng
@DATASET_REGISTRY.register()
class PairDataset(ImageDataset):
_logger = logging.getLogger('fastreid.fastshoe')
def __init__(self, img_root: str, anno_path: str, transform=None, mode: str = 'train'):
assert mode in ('train', 'val', 'test'), self._logger.info(
'''mode should the one of ('train', 'val', 'test')''')
self.mode = mode
if self.mode != 'train':
self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
seed_all_rng(12345)
self.img_root = img_root
self.anno_path = anno_path
self.transform = transform
all_data = json.load(open(self.anno_path))
pos_folders = []
neg_folders = []
for data in all_data:
if len(data['positive_img_list']) >= 2 and len(data['negative_img_list']) >= 1:
pos_folders.append(data['positive_img_list'])
neg_folders.append(data['negative_img_list'])
assert len(pos_folders) == len(neg_folders), self._logger.error('the len of self.pos_foders should be equal to self.pos_foders')
self.pos_folders = pos_folders
self.neg_folders = neg_folders
def __len__(self):
return len(self.pos_folders)
def __getitem__(self, idx):
pf = self.pos_folders[idx]
nf = self.neg_folders[idx]
label = 1
if random.random() < 0.5:
# generate positive pair
img_path1, img_path2 = random.sample(pf, 2)
else:
# generate negative pair
label = 0
img_path1, img_path2 = random.choice(pf), random.choice(nf)
img_path1 = os.path.join(self.img_root, img_path1)
img_path2 = os.path.join(self.img_root, img_path2)
img1 = read_image(img_path1)
img2 = read_image(img_path2)
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
return {
'img1': img1,
'img2': img2,
'target': label
}
#-------------下面是辅助信息------------------#
@property
def num_classes(self):
return 2
@property
def num_folders(self):
return len(self)
@property
def num_pos_images(self):
return sum([len(x) for x in self.pos_folders])
@property
def num_neg_images(self):
return sum([len(x) for x in self.neg_folders])
def describe(self):
headers = ['subset', 'folders', 'pos images', 'neg images']
csv_results = [[self.mode, self.num_folders, self.num_pos_images, self.num_neg_images]]
# tabulate it
table = tabulate(
csv_results,
tablefmt="pipe",
headers=headers,
numalign="left",
)
self._logger.info(f"=> Loaded {self.__class__.__name__}: \n" + colored(table, "cyan"))
def show_train(self):
return self.describe()
def show_test(self):
return self.describe()

View File

@ -1,60 +0,0 @@
# -*- coding: utf-8 -*-
# @Time : 2021/10/11 10:12:28
# @Author : zuchen.wang@vipshop.com
# @File : train_net.py.py
import json
import os
import sys
sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils import bughook
from shoeclas import PairTrainer
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
setattr(cfg, 'eval_only', args.eval_only)
cfg.freeze()
default_setup(cfg, args)
return cfg
def main(args):
cfg = setup(args)
if args.eval_only:
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
model = PairTrainer.build_model(cfg)
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
res = PairTrainer.test(cfg, model)
return res
else:
trainer = PairTrainer(cfg)
trainer.resume_or_load(resume=args.resume)
return trainer.train()
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url='auto',
args=(args,),
)