mirror of https://github.com/JDAI-CV/fast-reid.git
feat: support multi-teacher kd
Summary: support multi-teacher kd with logits and overhaul distillationpull/397/head
parent
db8670db63
commit
77a91b1204
|
@ -1,8 +1,37 @@
|
|||
.idea
|
||||
|
||||
logs
|
||||
|
||||
# compilation and distribution
|
||||
__pycache__
|
||||
.DS_Store
|
||||
.vscode
|
||||
_ext
|
||||
*.pyc
|
||||
*.pyd
|
||||
*.so
|
||||
logs/
|
||||
.ipynb_checkpoints
|
||||
logs
|
||||
*.dll
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
|
||||
# pytorch/python/numpy formats
|
||||
*.pth
|
||||
*.pkl
|
||||
*.npy
|
||||
*.ts
|
||||
model_ts*.txt
|
||||
|
||||
# ipython/jupyter notebooks
|
||||
*.ipynb
|
||||
**/.ipynb_checkpoints/
|
||||
|
||||
# Editor temporaries
|
||||
*.swn
|
||||
*.swo
|
||||
*.swp
|
||||
*~
|
||||
|
||||
# editor settings
|
||||
.idea
|
||||
.vscode
|
||||
_darcs
|
||||
.DS_Store
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
# Changelog
|
||||
|
||||
### v1.1 (29/01/2021)
|
||||
|
||||
#### New Features
|
||||
|
||||
- NAIC20(reid track) [1-st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
|
||||
- Multi-teacher Knowledge Distillation
|
||||
|
||||
#### Bug Fixes
|
||||
|
||||
#### Improvements
|
30
README.md
30
README.md
|
@ -4,48 +4,52 @@ FastReID is a research platform that implements state-of-the-art re-identificati
|
|||
|
||||
## What's New
|
||||
|
||||
- [Jan 2021] NAIC20(reid track) [1-st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released!
|
||||
- [Jan 2021] NAIC20(reid track) [1-st solution](projects/NAIC20) based on fastreid has been released!
|
||||
- [Jan 2021] FastReID V1.0 has been released!🎉
|
||||
Support many tasks beyond reid, such image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
|
||||
- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
|
||||
- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
|
||||
- [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
|
||||
- [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
|
||||
- [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
|
||||
- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
|
||||
- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
|
||||
- [Aug 2020] ONNX/TensorRT converter is supported.
|
||||
- [Jul 2020] Distributed training with multiple GPUs, it trains much faster.
|
||||
- Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multi-datasets simultaneously, etc.
|
||||
- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open source more research projects in this way.
|
||||
- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
|
||||
- Remove [ignite](https://github.com/pytorch/ignite)(a high-level library) dependency and powered by [PyTorch](https://pytorch.org/).
|
||||
|
||||
We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
|
||||
|
||||
## Changelog
|
||||
|
||||
Please refer to [changelog.md](CHANGELOG.md) for details and release history.
|
||||
|
||||
## Installation
|
||||
|
||||
See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
|
||||
See [INSTALL.md](INSTALL.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
|
||||
|
||||
See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
|
||||
See [GETTING_STARTED.md](GETTING_STARTED.md).
|
||||
|
||||
Learn more at out [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
|
||||
Learn more at out [documentation](https://fast-reid.readthedocs.io/). And see [projects/](projects) for some projects that are build on top of fastreid.
|
||||
|
||||
## Model Zoo and Baselines
|
||||
|
||||
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
|
||||
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](MODEL_ZOO.md).
|
||||
|
||||
## Deployment
|
||||
|
||||
We provide some examples and scripts to convert fastreid model to Caffe, ONNX and TensorRT format in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
|
||||
We provide some examples and scripts to convert fastreid model to Caffe, ONNX and TensorRT format in [Fastreid deploy](tools/deploy).
|
||||
|
||||
## License
|
||||
|
||||
Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
|
||||
Fastreid is released under the [Apache 2.0 license](LICENSE).
|
||||
|
||||
## Citing Fastreid
|
||||
## Citing FastReID
|
||||
|
||||
If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
|
||||
If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
|
||||
|
||||
```BibTeX
|
||||
@article{he2020fastreid,
|
||||
|
|
|
@ -128,8 +128,8 @@ _C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
_C.KD = CN()
|
||||
_C.KD.MODEL_CONFIG = ""
|
||||
_C.KD.MODEL_WEIGHTS = ""
|
||||
_C.KD.MODEL_CONFIG = ['',]
|
||||
_C.KD.MODEL_WEIGHTS = ['',]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# INPUT
|
||||
|
|
|
@ -22,22 +22,25 @@ class Distiller(Baseline):
|
|||
super(Distiller, self).__init__(cfg)
|
||||
|
||||
# Get teacher model config
|
||||
cfg_t = get_cfg()
|
||||
cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG)
|
||||
model_ts = []
|
||||
for i in range(len(cfg.KD.MODEL_CONFIG)):
|
||||
cfg_t = get_cfg()
|
||||
cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
|
||||
|
||||
model_t = build_model(cfg_t)
|
||||
logger.info("Teacher model:\n{}".format(model_t))
|
||||
model_t = build_model(cfg_t)
|
||||
|
||||
# No gradients for teacher model
|
||||
for param in model_t.parameters():
|
||||
param.requires_grad_(False)
|
||||
# No gradients for teacher model
|
||||
for param in model_t.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info("Loading teacher model weights ...")
|
||||
Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS)
|
||||
logger.info("Loading teacher model weights ...")
|
||||
Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
|
||||
|
||||
model_ts.append(model_t)
|
||||
|
||||
# Not register teacher model as `nn.Module`, this is
|
||||
# make sure teacher model weights not saved
|
||||
self.model_t = [model_t.backbone, model_t.heads]
|
||||
self.model_ts = model_ts
|
||||
|
||||
def forward(self, batched_inputs):
|
||||
if self.training:
|
||||
|
@ -51,10 +54,13 @@ class Distiller(Baseline):
|
|||
|
||||
s_outputs = self.heads(s_feat, targets)
|
||||
|
||||
t_outputs = []
|
||||
# teacher model forward
|
||||
with torch.no_grad():
|
||||
t_feat = self.model_t[0](images)
|
||||
t_outputs = self.model_t[1](t_feat, targets)
|
||||
for model_t in self.model_ts:
|
||||
t_feat = model_t.backbone(images)
|
||||
t_output = model_t.heads(t_feat, targets)
|
||||
t_outputs.append(t_output)
|
||||
|
||||
losses = self.losses(s_outputs, t_outputs, targets)
|
||||
return losses
|
||||
|
@ -71,8 +77,12 @@ class Distiller(Baseline):
|
|||
loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)
|
||||
|
||||
s_logits = s_outputs["pred_class_logits"]
|
||||
t_logits = t_outputs["pred_class_logits"].detach()
|
||||
loss_dict["loss_jsdiv"] = self.jsdiv_loss(s_logits, t_logits)
|
||||
loss_jsdiv = 0.
|
||||
for t_output in t_outputs:
|
||||
t_logits = t_output["pred_class_logits"].detach()
|
||||
loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
|
||||
|
||||
loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
|
|
@ -8,8 +8,8 @@ MODEL:
|
|||
WITH_IBN: False
|
||||
|
||||
KD:
|
||||
MODEL_CONFIG: projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
|
||||
MODEL_WEIGHTS: projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth
|
||||
MODEL_CONFIG: ("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
|
||||
MODEL_WEIGHTS: ("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("DukeMTMC",)
|
||||
|
|
|
@ -61,16 +61,18 @@ class DistillerOverhaul(Distiller):
|
|||
super().__init__(cfg)
|
||||
|
||||
s_channels = self.backbone.get_channel_nums()
|
||||
t_channels = self.model_t[0].get_channel_nums()
|
||||
|
||||
self.connectors = nn.ModuleList(
|
||||
[build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
|
||||
for i in range(len(self.model_ts)):
|
||||
t_channels = self.model_ts[i].backbone.get_channel_nums()
|
||||
|
||||
teacher_bns = self.model_t[0].get_bn_before_relu()
|
||||
margins = [get_margin_from_BN(bn) for bn in teacher_bns]
|
||||
for i, margin in enumerate(margins):
|
||||
self.register_buffer("margin%d" % (i + 1),
|
||||
margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
|
||||
setattr(self, "connectors_{}".format(i), nn.ModuleList(
|
||||
[build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]))
|
||||
|
||||
teacher_bns = self.model_ts[i].backbone.get_bn_before_relu()
|
||||
margins = [get_margin_from_BN(bn) for bn in teacher_bns]
|
||||
for j, margin in enumerate(margins):
|
||||
self.register_buffer("margin{}_{}".format(i, j + 1),
|
||||
margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
|
||||
|
||||
def forward(self, batched_inputs):
|
||||
if self.training:
|
||||
|
@ -84,20 +86,25 @@ class DistillerOverhaul(Distiller):
|
|||
|
||||
s_outputs = self.heads(s_feat, targets)
|
||||
|
||||
t_feats_list = []
|
||||
t_outputs = []
|
||||
# teacher model forward
|
||||
with torch.no_grad():
|
||||
t_feats, t_feat = self.model_t[0].extract_feature(images, preReLU=True)
|
||||
t_outputs = self.model_t[1](t_feat, targets)
|
||||
for model_t in self.model_ts:
|
||||
t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
|
||||
t_output = model_t.heads(t_feat, targets)
|
||||
t_feats_list.append(t_feats)
|
||||
t_outputs.append(t_output)
|
||||
|
||||
losses = self.losses(s_outputs, s_feats, t_outputs, t_feats, targets)
|
||||
losses = self.losses(s_outputs, s_feats, t_outputs, t_feats_list, targets)
|
||||
return losses
|
||||
|
||||
else:
|
||||
outputs = super(DistillerOverhaul, self).forward(batched_inputs)
|
||||
return outputs
|
||||
|
||||
def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
|
||||
r"""
|
||||
def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
|
||||
"""
|
||||
Compute loss from modeling's outputs, the loss function input arguments
|
||||
must be the same as the outputs of the model forwarding.
|
||||
"""
|
||||
|
@ -106,11 +113,12 @@ class DistillerOverhaul(Distiller):
|
|||
# Overhaul distillation loss
|
||||
feat_num = len(s_feats)
|
||||
loss_distill = 0
|
||||
for i in range(feat_num):
|
||||
s_feats[i] = self.connectors[i](s_feats[i])
|
||||
loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(
|
||||
self, "margin%d" % (i + 1)).to(s_feats[i].dtype)) / 2 ** (feat_num - i - 1)
|
||||
for i in range(len(t_feats_list)):
|
||||
for j in range(feat_num):
|
||||
s_feats_connect = getattr(self, "connectors_{}".format(i))[j](s_feats[j])
|
||||
loss_distill += distillation_loss(s_feats_connect, t_feats_list[i][j].detach(), getattr(
|
||||
self, "margin{}_{}".format(i, j + 1)).to(s_feats_connect.dtype)) / 2 ** (feat_num - j - 1)
|
||||
|
||||
loss_dict["loss_overhaul"] = loss_distill / len(gt_labels) / 10000
|
||||
loss_dict["loss_overhaul"] = loss_distill / len(t_feats_list) / len(gt_labels) / 10000
|
||||
|
||||
return loss_dict
|
||||
|
|
Loading…
Reference in New Issue