feat: support multi-teacher kd

Summary: support multi-teacher kd with logits and overhaul distillation
pull/397/head
liaoxingyu 2021-01-29 17:25:31 +08:00
parent db8670db63
commit 77a91b1204
7 changed files with 118 additions and 55 deletions

.gitignore

@@ -1,8 +1,37 @@
.idea
logs
# compilation and distribution
__pycache__
.DS_Store
.vscode
_ext
*.pyc
*.pyd
*.so
logs/
.ipynb_checkpoints
logs
*.dll
*.egg-info/
build/
dist/
wheels/
# pytorch/python/numpy formats
*.pth
*.pkl
*.npy
*.ts
model_ts*.txt
# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/
# Editor temporaries
*.swn
*.swo
*.swp
*~
# editor settings
.idea
.vscode
_darcs
.DS_Store

CHANGELOG.md

@@ -0,0 +1,12 @@
# Changelog
### v1.1 (29/01/2021)
#### New Features
- NAIC20 (reid track) [1st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
- Multi-teacher Knowledge Distillation
#### Bug Fixes
#### Improvements

README.md

@@ -4,48 +4,52 @@ FastReID is a research platform that implements state-of-the-art re-identification algorithms.
## What's New
- [Jan 2021] NAIC20 (reid track) [1st solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released
- [Jan 2021] NAIC20 (reid track) [1st solution](projects/NAIC20) based on fastreid has been released
- [Jan 2021] FastReID V1.0 has been released🎉
Support many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
- [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
- [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
- [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks for [guan'an wang](https://github.com/wangguanan)'s contribution.
- [Aug 2020] ONNX/TensorRT converter is supported.
- [Jul 2020] Distributed training with multiple GPUs, which makes training much faster.
- Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multiple datasets simultaneously, etc.
- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open source more research projects in this way.
- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
- Removed the [ignite](https://github.com/pytorch/ignite) (a high-level library) dependency; now powered by [PyTorch](https://pytorch.org/).
We wrote a [Chinese blog post](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.
## Changelog
Please refer to [changelog.md](CHANGELOG.md) for details and release history.
## Installation
See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
See [INSTALL.md](INSTALL.md).
## Quick Start
The designed architecture follows this guide: [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template); you can check each folder's purpose yourself.
See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
See [GETTING_STARTED.md](GETTING_STARTED.md).
Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are built on top of fastreid.
Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](projects) for some projects that are built on top of fastreid.
## Model Zoo and Baselines
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](MODEL_ZOO.md).
## Deployment
We provide some examples and scripts to convert fastreid models to Caffe, ONNX and TensorRT formats in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
We provide some examples and scripts to convert fastreid models to Caffe, ONNX and TensorRT formats in [Fastreid deploy](tools/deploy).
## License
Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
Fastreid is released under the [Apache 2.0 license](LICENSE).
## Citing Fastreid
## Citing FastReID
If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
```BibTeX
@article{he2020fastreid,

fastreid/config/defaults.py

@@ -128,8 +128,8 @@ _C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
# -----------------------------------------------------------------------------
_C.KD = CN()
_C.KD.MODEL_CONFIG = ""
_C.KD.MODEL_WEIGHTS = ""
_C.KD.MODEL_CONFIG = ['',]
_C.KD.MODEL_WEIGHTS = ['',]
# -----------------------------------------------------------------------------
# INPUT
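The `KD` options above are now list-valued, so each teacher is described by a config/weights pair at the same index. Below is a minimal sketch of how those pairs could be read back and sanity-checked before building the teachers; the helper name `iter_teacher_specs` is an illustration, not part of fastreid, and the real construction happens inside `Distiller.__init__` further down.

```python
from fastreid.config import get_cfg

def iter_teacher_specs(cfg):
    """Yield one (teacher_cfg, weights_path) pair per teacher.
    Hypothetical helper for illustration only."""
    configs, weights = cfg.KD.MODEL_CONFIG, cfg.KD.MODEL_WEIGHTS
    assert len(configs) == len(weights), \
        "KD.MODEL_CONFIG and KD.MODEL_WEIGHTS must have the same length"
    for cfg_file, weights_file in zip(configs, weights):
        cfg_t = get_cfg()                # fresh default config per teacher
        cfg_t.merge_from_file(cfg_file)  # overlay the teacher's own config
        yield cfg_t, weights_file
```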


@@ -22,22 +22,25 @@ class Distiller(Baseline):
super(Distiller, self).__init__(cfg)
# Get teacher model config
cfg_t = get_cfg()
cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG)
model_ts = []
for i in range(len(cfg.KD.MODEL_CONFIG)):
cfg_t = get_cfg()
cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
model_t = build_model(cfg_t)
logger.info("Teacher model:\n{}".format(model_t))
model_t = build_model(cfg_t)
# No gradients for teacher model
for param in model_t.parameters():
param.requires_grad_(False)
# No gradients for teacher model
for param in model_t.parameters():
param.requires_grad_(False)
logger.info("Loading teacher model weights ...")
Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS)
logger.info("Loading teacher model weights ...")
Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
model_ts.append(model_t)
# Do not register the teacher models as `nn.Module`s, to
# make sure the teacher model weights are not saved in checkpoints
self.model_t = [model_t.backbone, model_t.heads]
self.model_ts = model_ts
def forward(self, batched_inputs):
if self.training:
@@ -51,10 +54,13 @@ class Distiller(Baseline):
s_outputs = self.heads(s_feat, targets)
t_outputs = []
# teacher model forward
with torch.no_grad():
t_feat = self.model_t[0](images)
t_outputs = self.model_t[1](t_feat, targets)
for model_t in self.model_ts:
t_feat = model_t.backbone(images)
t_output = model_t.heads(t_feat, targets)
t_outputs.append(t_output)
losses = self.losses(s_outputs, t_outputs, targets)
return losses
@@ -71,8 +77,12 @@ class Distiller(Baseline):
loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)
s_logits = s_outputs["pred_class_logits"]
t_logits = t_outputs["pred_class_logits"].detach()
loss_dict["loss_jsdiv"] = self.jsdiv_loss(s_logits, t_logits)
loss_jsdiv = 0.
for t_output in t_outputs:
t_logits = t_output["pred_class_logits"].detach()
loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)
return loss_dict
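The updated `losses` averages a Jensen-Shannon divergence between the student logits and each teacher's logits. A self-contained sketch of that averaging is given below, assuming a temperature-softened symmetric JS divergence; fastreid's actual `jsdiv_loss` may differ in temperature and weighting.

```python
import torch.nn.functional as F

def js_div(s_logits, t_logits, t=16.0):
    """Symmetric Jensen-Shannon divergence between softened distributions (sketch)."""
    p_s = F.softmax(s_logits / t, dim=1)
    p_t = F.softmax(t_logits / t, dim=1)
    p_m = 0.5 * (p_s + p_t)
    kl_s = F.kl_div(p_m.log(), p_s, reduction="batchmean")  # KL(p_s || p_m)
    kl_t = F.kl_div(p_m.log(), p_t, reduction="batchmean")  # KL(p_t || p_m)
    return 0.5 * (kl_s + kl_t) * (t * t)

def multi_teacher_logit_loss(s_logits, teacher_logits):
    """Average the divergence over all teachers, mirroring the loop above."""
    total = sum(js_div(s_logits, t.detach()) for t in teacher_logits)
    return total / len(teacher_logits)
```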


@@ -8,8 +8,8 @@ MODEL:
WITH_IBN: False
KD:
MODEL_CONFIG: projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
MODEL_WEIGHTS: projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth
MODEL_CONFIG: ("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
MODEL_WEIGHTS: ("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)
DATASETS:
NAMES: ("DukeMTMC",)
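With more than one teacher, the two tuples simply grow in lockstep, one config entry per weights entry. A sketch of the equivalent programmatic setup is shown below; the second teacher's paths are hypothetical placeholders, not files that ship with the project.

```python
from fastreid.config import get_cfg

cfg = get_cfg()
cfg.KD.MODEL_CONFIG = [
    "projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",
    "projects/FastDistill/logs/dukemtmc/r50_ibn/config.yaml",     # hypothetical second teacher
]
cfg.KD.MODEL_WEIGHTS = [
    "projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",
    "projects/FastDistill/logs/dukemtmc/r50_ibn/model_best.pth",  # hypothetical second teacher
]
```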


@@ -61,16 +61,18 @@ class DistillerOverhaul(Distiller):
super().__init__(cfg)
s_channels = self.backbone.get_channel_nums()
t_channels = self.model_t[0].get_channel_nums()
self.connectors = nn.ModuleList(
[build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
for i in range(len(self.model_ts)):
t_channels = self.model_ts[i].backbone.get_channel_nums()
teacher_bns = self.model_t[0].get_bn_before_relu()
margins = [get_margin_from_BN(bn) for bn in teacher_bns]
for i, margin in enumerate(margins):
self.register_buffer("margin%d" % (i + 1),
margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
setattr(self, "connectors_{}".format(i), nn.ModuleList(
[build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]))
teacher_bns = self.model_ts[i].backbone.get_bn_before_relu()
margins = [get_margin_from_BN(bn) for bn in teacher_bns]
for j, margin in enumerate(margins):
self.register_buffer("margin{}_{}".format(i, j + 1),
margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
def forward(self, batched_inputs):
if self.training:
@@ -84,20 +86,25 @@ class DistillerOverhaul(Distiller):
s_outputs = self.heads(s_feat, targets)
t_feats_list = []
t_outputs = []
# teacher model forward
with torch.no_grad():
t_feats, t_feat = self.model_t[0].extract_feature(images, preReLU=True)
t_outputs = self.model_t[1](t_feat, targets)
for model_t in self.model_ts:
t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
t_output = model_t.heads(t_feat, targets)
t_feats_list.append(t_feats)
t_outputs.append(t_output)
losses = self.losses(s_outputs, s_feats, t_outputs, t_feats, targets)
losses = self.losses(s_outputs, s_feats, t_outputs, t_feats_list, targets)
return losses
else:
outputs = super(DistillerOverhaul, self).forward(batched_inputs)
return outputs
def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
r"""
def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
"""
Compute losses from the model's outputs; the loss function's input arguments
must match the outputs of the model's forward pass.
"""
@@ -106,11 +113,12 @@ class DistillerOverhaul(Distiller):
# Overhaul distillation loss
feat_num = len(s_feats)
loss_distill = 0
for i in range(feat_num):
s_feats[i] = self.connectors[i](s_feats[i])
loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(
self, "margin%d" % (i + 1)).to(s_feats[i].dtype)) / 2 ** (feat_num - i - 1)
for i in range(len(t_feats_list)):
for j in range(feat_num):
s_feats_connect = getattr(self, "connectors_{}".format(i))[j](s_feats[j])
loss_distill += distillation_loss(s_feats_connect, t_feats_list[i][j].detach(), getattr(
self, "margin{}_{}".format(i, j + 1)).to(s_feats_connect.dtype)) / 2 ** (feat_num - j - 1)
loss_dict["loss_overhaul"] = loss_distill / len(gt_labels) / 10000
loss_dict["loss_overhaul"] = loss_distill / len(t_feats_list) / len(gt_labels) / 10000
return loss_dict
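Because the teacher models themselves are kept in a plain Python list (so their weights are not checkpointed), the per-teacher connectors and margins are instead attached with `setattr`/`register_buffer` and looked up by name at loss time. The toy sketch below illustrates that registration pattern; it uses a bare 1x1 convolution as a stand-in for fastreid's `build_feature_connector` and is not fastreid code.

```python
import torch.nn as nn

class PerTeacherConnectors(nn.Module):
    """Toy sketch of the naming scheme used above, not fastreid code."""
    def __init__(self, s_channels, t_channels_per_teacher):
        super().__init__()
        for i, t_channels in enumerate(t_channels_per_teacher):
            # setattr with an nn.ModuleList registers the connectors as trainable
            # submodules, so they are saved with the student checkpoint.
            setattr(self, "connectors_{}".format(i), nn.ModuleList(
                [nn.Conv2d(s, t, kernel_size=1)
                 for t, s in zip(t_channels, s_channels)]))

    def connect(self, teacher_idx, stage_idx, s_feat):
        # Retrieve by name, exactly as the loss loop does with getattr.
        return getattr(self, "connectors_{}".format(teacher_idx))[stage_idx](s_feat)
```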