update readme
parent
f6aab3ee89
commit
c57773ec0b
|
@ -33,6 +33,7 @@ You can find some research projects that are built on top of Torchreid `here <ht
|
|||
|
||||
What's new
|
||||
------------
|
||||
- [May 2020] Added the person attribute recognition code used in [Omni-Scale Feature Learning for Person Re-Identification (ICCV'19)](https://arxiv.org/abs/1905.00953).
|
||||
- [May 2020] ``1.2.1``: Added a simple API for feature extraction (``torchreid/utils/feature_extractor.py``). See the `documentation <https://kaiyangzhou.github.io/deep-person-reid/user_guide.html>`_ for the instruction.
|
||||
- [Apr 2020] Code for reproducing the experiments of `deep mutual learning <https://zpascal.net/cvpr2018/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf>`_ in the `OSNet paper <https://arxiv.org/pdf/1905.00953v6.pdf>`__ (Supp. B) has been released at ``projects/DML``.
|
||||
- [Apr 2020] Upgraded to ``1.2.0``. The engine class has been made more model-agnostic to improve extensibility. See `Engine <torchreid/engine/engine.py>`_ and `ImageSoftmaxEngine <torchreid/engine/image/softmax.py>`_ for more details. Credit to `Dassl.pytorch <https://github.com/KaiyangZhou/Dassl.pytorch>`_.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
Here are some research projects built on [Torchreid](https://arxiv.org/abs/1910.10093).
|
||||
|
||||
+ [Learning Generalisable Omni-Scale Representations for Person Re-Identification](OSNet_AIN)
|
||||
+ [Deep Mutual Learning (CVPR'18)](DML)
|
||||
+ `OSNet_AIN`: [Learning Generalisable Omni-Scale Representations for Person Re-Identification](https://arxiv.org/abs/1910.06827)
|
||||
+ `DML`: [Deep Mutual Learning (CVPR'18)](https://arxiv.org/abs/1706.00384)
|
||||
+ `attribute_recognition`: [Omni-Scale Feature Learning for Person Re-Identification (ICCV'19)](https://arxiv.org/abs/1905.00953)
|
|
@ -0,0 +1,18 @@
|
|||
# Person Attribute Recognition
|
||||
This code was developed for the experiment of person attribute recognition in [Omni-Scale Feature Learning for Person Re-Identification (ICCV'19)](https://arxiv.org/abs/1905.00953).
|
||||
|
||||
## Download data
|
||||
Download the PA-100K dataset from [https://github.com/xh-liu/HydraPlus-Net](https://github.com/xh-liu/HydraPlus-Net), and extract the file under the folder where you store your data (say $DATASET). The folder structure should look like
|
||||
```bash
|
||||
$DATASET/
|
||||
pa100k/
|
||||
data/ # images
|
||||
annotation/
|
||||
annotation.mat
|
||||
```
|
||||
|
||||
## Train
|
||||
The training command is provided in `train.sh`. Run `bash train.sh $DATASET` to start training.
|
||||
|
||||
## Test
|
||||
To test a pretrained model, add the following two arguments to `train.sh`: `--load-weights $PATH_TO_WEIGHTS --evaluate`.
|
|
@ -53,6 +53,7 @@ def init_dataset(use_gpu):
|
|||
mode='train',
|
||||
verbose=True
|
||||
)
|
||||
|
||||
valset = datasets.init_dataset(
|
||||
args.dataset,
|
||||
root=args.root,
|
||||
|
@ -60,6 +61,7 @@ def init_dataset(use_gpu):
|
|||
mode='val',
|
||||
verbose=False
|
||||
)
|
||||
|
||||
testset = datasets.init_dataset(
|
||||
args.dataset,
|
||||
root=args.root,
|
||||
|
@ -108,6 +110,7 @@ def main():
|
|||
use_gpu = torch.cuda.is_available() and not args.use_cpu
|
||||
log_name = 'test.log' if args.evaluate else 'train.log'
|
||||
sys.stdout = Logger(osp.join(args.save_dir, log_name))
|
||||
|
||||
print('** Arguments **')
|
||||
arg_keys = list(args.__dict__.keys())
|
||||
arg_keys.sort()
|
||||
|
@ -116,6 +119,7 @@ def main():
|
|||
print('\n')
|
||||
print('Collecting env info ...')
|
||||
print('** System info **\n{}\n'.format(collect_env_info()))
|
||||
|
||||
if use_gpu:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
|
@ -123,9 +127,8 @@ def main():
|
|||
'Currently using CPU, however, GPU is highly recommended'
|
||||
)
|
||||
|
||||
trainloader, valloader, testloader, num_attrs, attr_dict = init_dataset(
|
||||
use_gpu
|
||||
)
|
||||
dataset_vars = init_dataset(use_gpu)
|
||||
trainloader, valloader, testloader, num_attrs, attr_dict = dataset_vars
|
||||
|
||||
if args.weighted_bce:
|
||||
print('Use weighted binary cross entropy')
|
||||
|
@ -139,6 +142,7 @@ def main():
|
|||
print('BCE weights: {}'.format(bce_weights))
|
||||
bce_weights = bce_weights.expand(args.batch_size, num_attrs)
|
||||
criterion = nn.BCEWithLogitsLoss(weight=bce_weights)
|
||||
|
||||
else:
|
||||
print('Use plain binary cross entropy')
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
|
@ -147,7 +151,7 @@ def main():
|
|||
model = models.build_model(
|
||||
args.arch,
|
||||
num_attrs,
|
||||
pretrained=(not args.no_pretrained),
|
||||
pretrained=not args.no_pretrained,
|
||||
use_gpu=use_gpu
|
||||
)
|
||||
num_params, flops = compute_model_complexity(
|
||||
|
@ -181,12 +185,9 @@ def main():
|
|||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
start_epoch = checkpoint['epoch']
|
||||
best_result = checkpoint['label_mA']
|
||||
print("Loaded checkpoint from '{}'".format(args.resume))
|
||||
print(
|
||||
"- start_epoch: {}\n- label_mA: {}".format(
|
||||
start_epoch, best_result
|
||||
)
|
||||
)
|
||||
print('Loaded checkpoint from "{}"'.format(args.resume))
|
||||
print('- start epoch: {}'.format(start_epoch))
|
||||
print('- label_mA: {}'.format(best_result))
|
||||
|
||||
time_start = time.time()
|
||||
|
||||
|
@ -199,6 +200,7 @@ def main():
|
|||
is_best = label_mA > best_result
|
||||
if is_best:
|
||||
best_result = label_mA
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
'state_dict': model.state_dict(),
|
||||
|
@ -340,7 +342,7 @@ def test(model, testloader, attr_dict, use_gpu):
|
|||
|
||||
if (batch_idx+1) % args.print_freq == 0:
|
||||
print(
|
||||
"Processed batch {}/{}".format(batch_idx + 1, len(testloader))
|
||||
'Processed batch {}/{}'.format(batch_idx + 1, len(testloader))
|
||||
)
|
||||
|
||||
if args.save_prediction:
|
||||
|
@ -349,24 +351,24 @@ def test(model, testloader, attr_dict, use_gpu):
|
|||
img_path = img_paths[idx]
|
||||
probs = orig_outputs[idx, :]
|
||||
labels = attrs[idx, :]
|
||||
txtfile.write("{}\n".format(img_path))
|
||||
txtfile.write("*** Correct prediction ***\n")
|
||||
txtfile.write('{}\n'.format(img_path))
|
||||
txtfile.write('*** Correct prediction ***\n')
|
||||
for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
|
||||
if label:
|
||||
attr_name = attr_dict[attr_idx]
|
||||
info = "{}: {:.1%} ".format(attr_name, prob)
|
||||
info = '{}: {:.1%} '.format(attr_name, prob)
|
||||
txtfile.write(info)
|
||||
txtfile.write("\n*** Incorrect prediction ***\n")
|
||||
txtfile.write('\n*** Incorrect prediction ***\n')
|
||||
for attr_idx, (label, prob) in enumerate(zip(labels, probs)):
|
||||
if not label and prob > 0.5:
|
||||
attr_name = attr_dict[attr_idx]
|
||||
info = "{}: {:.1%} ".format(attr_name, prob)
|
||||
info = '{}: {:.1%} '.format(attr_name, prob)
|
||||
txtfile.write(info)
|
||||
txtfile.write("\n\n")
|
||||
txtfile.write('\n\n')
|
||||
txtfile.close()
|
||||
|
||||
print(
|
||||
"=> BatchTime(s)/BatchSize(img): {:.4f}/{}".format(
|
||||
'=> BatchTime(s)/BatchSize(img): {:.4f}/{}'.format(
|
||||
batch_time.avg, args.batch_size
|
||||
)
|
||||
)
|
||||
|
@ -381,14 +383,14 @@ def test(model, testloader, attr_dict, use_gpu):
|
|||
label_mA_verbose = (term1+term2) * 0.5
|
||||
label_mA = label_mA_verbose.mean()
|
||||
|
||||
print("* Results *")
|
||||
print(" # test persons: {}".format(num_persons))
|
||||
print(" (instance-based) accuracy: {:.1%}".format(ins_acc))
|
||||
print(" (instance-based) precition: {:.1%}".format(ins_prec))
|
||||
print(" (instance-based) recall: {:.1%}".format(ins_rec))
|
||||
print(" (instance-based) f1-score: {:.1%}".format(ins_f1))
|
||||
print(" (label-based) mean accuracy: {:.1%}".format(label_mA))
|
||||
print(" mA for each attribute: {}".format(label_mA_verbose))
|
||||
print('* Results *')
|
||||
print(' # test persons: {}'.format(num_persons))
|
||||
print(' (instance-based) accuracy: {:.1%}'.format(ins_acc))
|
||||
print(' (instance-based) precition: {:.1%}'.format(ins_prec))
|
||||
print(' (instance-based) recall: {:.1%}'.format(ins_rec))
|
||||
print(' (instance-based) f1-score: {:.1%}'.format(ins_f1))
|
||||
print(' (label-based) mean accuracy: {:.1%}'.format(label_mA))
|
||||
print(' mA for each attribute: {}'.format(label_mA_verbose))
|
||||
|
||||
return label_mA, ins_acc, ins_prec, ins_rec, ins_f1
|
||||
|
||||
|
|
Loading…
Reference in New Issue