mirror of https://github.com/JDAI-CV/fast-reid.git

bugfix for attribute project (#450)

Summary: refactor sample weight in attribute recognition; change all options to False in defaults.py and modify yaml files

branch: pull/456/head
parent: 25cfa88fd9
commit: fb36b23678

changed paths: configs, fastreid/config, projects (FastClas, FastRetri/configs, NAIC20/configs)
@@ -34,9 +34,14 @@ MODEL:
 INPUT:
   SIZE_TRAIN: [ 256, 128 ]
   SIZE_TEST: [ 256, 128 ]

   REA:
     ENABLED: True
     PROB: 0.5

+  FLIP:
+    ENABLED: True
+
+  PADDING:
+    ENABLED: True

@@ -9,7 +9,7 @@ MODEL:
     WITH_IBN: True

   HEADS:
-    POOL_LAYER: gempool
+    POOL_LAYER: GeneralizedMeanPooling

   LOSSES:
     TRI:
@@ -8,7 +8,8 @@ MODEL:
   BACKBONE:
     WITH_IBN: True
   HEADS:
-    POOL_LAYER: gempool
+    POOL_LAYER: GeneralizedMeanPooling

   LOSSES:
     TRI:
+      HARD_MINING: False
@@ -47,7 +47,7 @@ _C.MODEL.BACKBONE.WITH_SE = False
 # If use Non-local block in backbone
 _C.MODEL.BACKBONE.WITH_NL = False
 # If use ImageNet pretrain model
-_C.MODEL.BACKBONE.PRETRAIN = True
+_C.MODEL.BACKBONE.PRETRAIN = False
 # Pretrain model path
 _C.MODEL.BACKBONE.PRETRAIN_PATH = ''

@@ -63,14 +63,14 @@ _C.MODEL.HEADS.NUM_CLASSES = 0
 # Embedding dimension in head
 _C.MODEL.HEADS.EMBEDDING_DIM = 0
 # If use BNneck in embedding
-_C.MODEL.HEADS.WITH_BNNECK = True
+_C.MODEL.HEADS.WITH_BNNECK = False
 # Triplet feature using feature before(after) bnneck
 _C.MODEL.HEADS.NECK_FEAT = "before"  # options: before, after
 # Pooling layer type
-_C.MODEL.HEADS.POOL_LAYER = "avgpool"
+_C.MODEL.HEADS.POOL_LAYER = "GlobalAvgPool"

 # Classification layer type
-_C.MODEL.HEADS.CLS_LAYER = "linear"  # "arcSoftmax" or "circleSoftmax"
+_C.MODEL.HEADS.CLS_LAYER = "Linear"  # "ArcSoftmax" or "CircleSoftmax"

 # Margin and Scale for margin-based classification layer
 _C.MODEL.HEADS.MARGIN = 0.15
@@ -100,7 +100,7 @@ _C.MODEL.LOSSES.FL.SCALE = 1.0
 _C.MODEL.LOSSES.TRI = CN()
 _C.MODEL.LOSSES.TRI.MARGIN = 0.3
 _C.MODEL.LOSSES.TRI.NORM_FEAT = False
-_C.MODEL.LOSSES.TRI.HARD_MINING = True
+_C.MODEL.LOSSES.TRI.HARD_MINING = False
 _C.MODEL.LOSSES.TRI.SCALE = 1.0

 # Circle Loss options
@@ -150,11 +150,11 @@ _C.INPUT.CROP.SCALE = [0.16, 1]
 _C.INPUT.CROP.RATIO = [3./4., 4./3.]

 # Random probability for image horizontal flip
-_C.INPUT.FLIP = CN({"ENABLED": True})
+_C.INPUT.FLIP = CN({"ENABLED": False})
 _C.INPUT.FLIP.PROB = 0.5

 # Value of padding size
-_C.INPUT.PADDING = CN({"ENABLED": True})
+_C.INPUT.PADDING = CN({"ENABLED": False})
 _C.INPUT.PADDING.MODE = 'constant'
 _C.INPUT.PADDING.SIZE = 10

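With these defaults, horizontal flip and padding (like pretraining, hard mining and Nesterov above) stay off unless a config turns them on explicitly, which is why the YAML hunks in this commit add FLIP/PADDING blocks. A minimal sketch of how the new defaults interact with project overrides, assuming fastreid's `get_cfg()` returns the yacs `CfgNode` built in defaults.py; the YAML path below is hypothetical:

```python
from fastreid.config import get_cfg

cfg = get_cfg()
assert cfg.INPUT.FLIP.ENABLED is False      # new default: disabled
assert cfg.INPUT.PADDING.ENABLED is False   # new default: disabled

# A project config must now opt in explicitly, either from a YAML file ...
cfg.merge_from_file("projects/MyProject/configs/example.yml")   # hypothetical path
# ... or from command-line style overrides:
cfg.merge_from_list(["INPUT.FLIP.ENABLED", True, "INPUT.PADDING.ENABLED", True])
```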
@@ -201,7 +201,7 @@ _C.DATASETS.COMBINEALL = False
 # -----------------------------------------------------------------------------
 _C.DATALOADER = CN()
 # Options: TrainingSampler, NaiveIdentitySampler, BalancedIdentitySampler
-_C.DATALOADER.SAMPLER_TRAIN = "NaiveIdentitySampler"
+_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
 # Number of instance for each person
 _C.DATALOADER.NUM_INSTANCE = 4
 _C.DATALOADER.NUM_WORKERS = 8
@@ -224,7 +224,7 @@ _C.SOLVER.BIAS_LR_FACTOR = 1.
 _C.SOLVER.HEADS_LR_FACTOR = 1.

 _C.SOLVER.MOMENTUM = 0.9
-_C.SOLVER.NESTEROV = True
+_C.SOLVER.NESTEROV = False

 _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
@@ -14,8 +14,8 @@ MODEL:
   HEADS:
     NAME: AttrHead
     WITH_BNNECK: True
-    POOL_LAYER: fastavgpool
-    CLS_LAYER: linear
+    POOL_LAYER: FastGlobalAvgPool
+    CLS_LAYER: Linear
     NUM_CLASSES: 26

   LOSSES:
@@ -28,11 +28,15 @@ MODEL:
 INPUT:
   SIZE_TRAIN: [ 256, 192 ]
   SIZE_TEST: [ 256, 192 ]
   REA:
     ENABLED: False
-  DO_PAD: True
+  FLIP:
+    ENABLED: True
+
+  PADDING:
+    ENABLED: True

 DATALOADER:
   SAMPLER_TRAIN: TrainingSampler
   NUM_WORKERS: 8

 SOLVER:
@@ -11,13 +11,16 @@ from .bce_loss import cross_entropy_sigmoid_loss

 @META_ARCH_REGISTRY.register()
 class AttrBaseline(Baseline):
-    def __init__(self, cfg, sample_weights):
-        super(AttrBaseline, self).__init__(cfg)
-        bce_weight_enabled = cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED
-        if bce_weight_enabled:
-            self.register_buffer("sample_weight", sample_weights)
-        else:
-            self.sample_weights = None
+
+    @classmethod
+    def from_config(cls, cfg):
+        base_res = Baseline.from_config(cfg)
+        base_res["loss_kwargs"].update({
+            'bce': {
+                'scale': cfg.MODEL.LOSSES.BCE.SCALE
+            }
+        })
+        return base_res

     def losses(self, outputs, gt_labels):
         r"""
@@ -25,16 +28,17 @@ class AttrBaseline(Baseline):
         must be the same as the outputs of the model forwarding.
         """
         # model predictions
-        cls_outputs = outputs['cls_outputs']
+        cls_outputs = outputs["cls_outputs"]

         loss_dict = {}
-        loss_names = self._cfg.MODEL.LOSSES.NAME
+        loss_names = self.loss_kwargs["loss_names"]

         if "BinaryCrossEntropyLoss" in loss_names:
+            bce_kwargs = self.loss_kwargs.get('bce')
             loss_dict["loss_bce"] = cross_entropy_sigmoid_loss(
                 cls_outputs,
                 gt_labels,
-                self.sample_weight,
-            ) * self._cfg.MODEL.LOSSES.BCE.SCALE
+                self.sample_weights,
+            ) * bce_kwargs.get('scale')

         return loss_dict
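The refactor moves both the sample weights and the BCE scale out of `self._cfg` and into attributes supplied from outside (`self.sample_weights`, `self.loss_kwargs`). For orientation, a hedged sketch of what a weighted sigmoid BCE of this kind computes; the real `cross_entropy_sigmoid_loss` lives in the project's bce_loss module and may differ in how it normalizes or weights positives and negatives:

```python
import torch
import torch.nn.functional as F

def weighted_sigmoid_bce(cls_outputs: torch.Tensor,
                         gt_labels: torch.Tensor,
                         sample_weights: torch.Tensor = None) -> torch.Tensor:
    """cls_outputs: raw logits of shape (N, num_attrs); gt_labels: 0/1 attribute targets."""
    weight = None
    if sample_weights is not None:
        # per-attribute weights of shape (num_attrs,), broadcast over the batch
        weight = sample_weights.to(cls_outputs.device).expand_as(cls_outputs)
    return F.binary_cross_entropy_with_logits(
        cls_outputs, gt_labels.float(), weight=weight, reduction="mean")
```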
@@ -13,10 +13,10 @@ from fastreid.data.data_utils import read_image
 class AttrDataset(Dataset):
     """Image Person Attribute Dataset"""

-    def __init__(self, img_items, attr_dict, transform=None):
+    def __init__(self, img_items, transform, attr_dict):
         self.img_items = img_items
-        self.attr_dict = attr_dict
         self.transform = transform
+        self.attr_dict = attr_dict

     def __len__(self):
         return len(self.img_items)
@@ -24,6 +24,7 @@ class AttrDataset(Dataset):
     def __getitem__(self, index):
         img_path, labels = self.img_items[index]
         img = read_image(img_path)

         if self.transform is not None: img = self.transform(img)

+        labels = torch.as_tensor(labels)
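Two things change for callers: the positional order is now (items, transform, attr_dict), and `__getitem__` converts the label list to a tensor itself. A hedged usage sketch of the new call-site order, mirroring `build_attr_train_loader`; the item values, attribute names and the re-export of `AttrDataset` from the project package are assumptions for illustration:

```python
import torchvision.transforms as T
from fastattr import AttrDataset   # assumes the project package exposes AttrDataset

train_transforms = T.Compose([T.Resize((256, 192)), T.ToTensor()])
train_items = [("datasets/example/0001.jpg", [0, 1, 0])]   # hypothetical (path, labels)
attr_dict = {0: "backpack", 1: "bag", 2: "hat"}            # hypothetical attribute names

# The old order was AttrDataset(train_items, attr_dict, train_transforms);
# with the new signature that would silently swap transform and attr_dict.
train_set = AttrDataset(train_items, train_transforms, attr_dict)
```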
@@ -12,7 +12,7 @@ import torch
 from fastreid.evaluation.evaluator import DatasetEvaluator
 from fastreid.utils import comm

-logger = logging.getLogger(__name__)
+logger = logging.getLogger("fastreid.attr_evaluation")


 class AttrEvaluator(DatasetEvaluator):
@@ -28,22 +28,16 @@ class AttrHead(EmbeddingHead):
         """
         pool_feat = self.pool_layer(features)
         neck_feat = self.bottleneck(pool_feat)
-        neck_feat = neck_feat[..., 0, 0]
+        neck_feat = neck_feat.view(neck_feat.size(0), -1)

-        if self.cls_layer.__class__.__name__ == 'Linear':
-            logits = F.linear(neck_feat, self.weight)
-        else:
-            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
+        logits = F.linear(neck_feat, self.weight)
+        logits = self.bnneck(logits)

         # Evaluation
         if not self.training:
-            logits = self.bnneck(logits * self.cls_layer.s)
             cls_outptus = torch.sigmoid(logits)
             return cls_outptus

-        cls_outputs = self.cls_layer(logits, targets)
-        cls_outputs = self.bnneck(cls_outputs)
-
         return {
-            'cls_outputs': cls_outputs,
+            "cls_outputs": logits,
         }
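The head now does a single bias-free linear projection followed by batch-norm over the logits and, at evaluation time, a plain sigmoid; the margin-based classifier branch is gone. The flattening change matters because indexing only works for 1x1 pooled maps, as this small illustration shows (the shapes are illustrative, not taken from the commit):

```python
import torch

pooled = torch.randn(8, 2048, 1, 1)              # typical global-pool output
flat_old = pooled[..., 0, 0]                     # old: picks the (0, 0) location
flat_new = pooled.view(pooled.size(0), -1)       # new: flattens whatever remains
assert torch.equal(flat_old, flat_new)           # identical for 1x1 maps

pooled_2x2 = torch.randn(8, 2048, 2, 2)          # e.g. a non-global pooling layer
# pooled_2x2[..., 0, 0] would silently drop 3/4 of the features;
# .view keeps them all:
print(pooled_2x2.view(pooled_2x2.size(0), -1).shape)   # torch.Size([8, 8192])
```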
@@ -20,9 +20,6 @@ _root = os.getenv("FASTREID_DATASETS", "datasets")


 def build_attr_train_loader(cfg):
-    cfg = cfg.clone()
-    cfg.defrost()
-
     train_items = list()
     attr_dict = None
     for d in cfg.DATASETS.NAMES:
@@ -30,13 +27,13 @@ def build_attr_train_loader(cfg):
         if comm.is_main_process():
             dataset.show_train()
         if attr_dict is not None:
-            assert attr_dict == dataset.attr_dict, "attr_dict in {} does not match with previous ones".format(d)
+            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
         else:
             attr_dict = dataset.attr_dict
         train_items.extend(dataset.train)

     train_transforms = build_transforms(cfg, is_train=True)
-    train_set = AttrDataset(train_items, attr_dict, train_transforms)
+    train_set = AttrDataset(train_items, train_transforms, attr_dict)

     num_workers = cfg.DATALOADER.NUM_WORKERS
     mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
@@ -55,16 +52,14 @@ def build_attr_train_loader(cfg):


 def build_attr_test_loader(cfg, dataset_name):
-    cfg = cfg.clone()
-    cfg.defrost()
-
     dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
+    attr_dict = dataset.attr_dict
     if comm.is_main_process():
         dataset.show_test()
     test_items = dataset.test

     test_transforms = build_transforms(cfg, is_train=False)
-    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)
+    test_set = AttrDataset(test_items, test_transforms, attr_dict)

     mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
     data_sampler = samplers.InferenceSampler(len(test_set))
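Both builders now pass transforms and attr_dict in the new order, and the train loader's dataset exposes `sample_weights`, which the Trainer below forwards to the model. How those weights are computed is not part of this diff; one common choice, shown here only as a hedged sketch, is the positive-label ratio per attribute:

```python
import torch

def estimate_sample_weights(train_items):
    """train_items: list of (img_path, label_vector) pairs with 0/1 attribute labels."""
    labels = torch.as_tensor([lbl for _, lbl in train_items], dtype=torch.float32)
    return labels.mean(dim=0)   # shape (num_attrs,): fraction of positives per attribute
```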
@@ -3,14 +3,12 @@
 @author: xingyu liao
 @contact: sherlockliao01@gmail.com
 """
-import logging
 import sys

 sys.path.append('.')

 from fastreid.config import get_cfg
 from fastreid.engine import DefaultTrainer
-from fastreid.modeling import build_model
 from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.utils.checkpoint import Checkpointer

@@ -18,22 +16,28 @@ from fastattr import *

 class Trainer(DefaultTrainer):
     sample_weights = None

-    def build_model(self, cfg):
+    @classmethod
+    def build_model(cls, cfg):
         """
         Returns:
             torch.nn.Module:
         It now calls :func:`fastreid.modeling.build_model`.
         Overwrite it if you'd like a different model.
         """
-        model = build_model(cfg, sample_weights=self.sample_weights)
-        logger = logging.getLogger("fastreid.attr_model")
-        logger.info("Model:\n{}".format(model))
+        model = DefaultTrainer.build_model(cfg)
+        if cfg.MODEL.LOSSES.BCE.WEIGHT_ENABLED and \
+                Trainer.sample_weights is not None:
+            setattr(model, "sample_weights", Trainer.sample_weights.to(model.device))
+        else:
+            setattr(model, "sample_weights", None)
         return model

-    def build_train_loader(self, cfg):
+    @classmethod
+    def build_train_loader(cls, cfg):
         data_loader = build_attr_train_loader(cfg)
-        self.sample_weights = data_loader.dataset.sample_weights
+        Trainer.sample_weights = data_loader.dataset.sample_weights
         return data_loader

     @classmethod
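fastreid's `DefaultTrainer` builds the train loader before the model, which is what lets the classmethod version hand the dataset's `sample_weights` to `build_model` through the `Trainer` class attribute instead of instance state. A self-contained toy of that hand-off pattern; the names here are generic stand-ins, not fastreid APIs:

```python
import torch

class ToyTrainerBase:
    def __init__(self, cfg):
        # mirrors the order used by fastreid's DefaultTrainer: loader first, then model
        self.data_loader = self.build_train_loader(cfg)
        self.model = self.build_model(cfg)

class ToyTrainer(ToyTrainerBase):
    sample_weights = None

    @classmethod
    def build_train_loader(cls, cfg):
        loader = cfg["loader"]
        ToyTrainer.sample_weights = loader["sample_weights"]   # stash on the class
        return loader

    @classmethod
    def build_model(cls, cfg):
        model = torch.nn.Linear(4, 3)
        # read back what build_train_loader stored and attach it to the model
        setattr(model, "sample_weights", ToyTrainer.sample_weights)
        return model

trainer = ToyTrainer({"loader": {"sample_weights": torch.tensor([0.2, 0.5, 0.3])}})
print(trainer.model.sample_weights)   # tensor written by build_train_loader
```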
@@ -26,25 +26,14 @@ MODEL:

 INPUT:
   SIZE_TEST: [256, 256]

   CROP:
     ENABLED: True
     SIZE: [224, 224]

+  FLIP:
+    ENABLED: True
+
-  AUGMIX:
-    ENABLED: False
-  AUTOAUG:
-    ENABLED: False
-  PADDING:
-    ENABLED: False
-  CJ:
-    ENABLED: False
-  AFFINE:
-    ENABLED: False
-  REA:
-    ENABLED: False
-
 DATALOADER:
   SAMPLER_TRAIN: TrainingSampler
   NUM_WORKERS: 8
@@ -20,7 +20,7 @@ from fastreid.utils.checkpoint import Checkpointer, PathManager
 from fastreid.utils import comm
 from fastreid.engine import DefaultTrainer

-from fastcls import *
+from fastclas import *


 class Trainer(DefaultTrainer):
@@ -34,8 +34,8 @@ INPUT:
   SCALE: [0.16, 1.]
   RATIO: [0.75, 1.33333]

-  PADDING:
-    ENABLED: False
+  FLIP:
+    ENABLED: True

   CJ:
     ENABLED: False
|
|||
SATURATION: 0.1
|
||||
HUE: 0.1
|
||||
|
||||
|
||||
DATALOADER:
|
||||
SAMPLER_TRAIN: TrainingSampler
|
||||
NUM_WORKERS: 8
|
||||
|
|
|
@@ -35,6 +35,11 @@ INPUT:
   SIZE_TRAIN: [ 256, 128 ]
   SIZE_TEST: [ 256, 128 ]

+  FLIP:
+    ENABLED: True
+
+  PADDING:
+    ENABLED: True
+
   AUGMIX:
     ENABLED: True
@@ -19,6 +19,9 @@ TEST:
     K2: 3
     LAMBDA: 0.8

+  FLIP:
+    ENABLED: True
+
   SAVE_DISTMAT: True

 OUTPUT_DIR: projects/NAIC20/logs/r34_ibn-128x256-submit
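Unlike INPUT.FLIP (a training augmentation), TEST.FLIP enables flip test-time augmentation for the NAIC20 submission config. A hedged sketch of what such an option typically does, namely averaging the embeddings of an image and its horizontal mirror; this is illustrative, not fastreid's exact implementation:

```python
import torch

@torch.no_grad()
def flip_tta_features(model, images: torch.Tensor) -> torch.Tensor:
    feat = model(images)                             # (N, D) embeddings
    feat_flip = model(torch.flip(images, dims=[3]))  # flip along the width axis (NCHW)
    return (feat + feat_flip) / 2
```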