Mirror of https://github.com/msight-tech/research-ms-loss.git (synced 2025-06-03 14:48:45 +08:00)

initial commit

parent d72939633d
commit 8290973c8e
9  .flake8  Normal file

@@ -0,0 +1,9 @@
[flake8]
ignore = F401, F841, E402, E722, E999
max-line-length = 128
max-complexity=18
format=pylint
show_source = True
statistics = True
count = True
exclude = tests,ret_benchmark/modeling/backbone
4  .git_push.sh  Normal file

@@ -0,0 +1,4 @@
git add .
git status
git commit -m 'update'
git push
30  .gitignore  vendored  Normal file

@@ -0,0 +1,30 @@
resource
build
*.pyc
*.zip
*/__pycache__
__pycache__

# Package Files #
*.pkl
*.log
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
*.egg-info

#some local files
*/.settings/
*/.DS_Store
.DS_Store
*/.idea/
.idea/
gradlew
gradlew.bat
unused.txt
output/
*.egg-info/
335  LICENSE  Normal file

@@ -0,0 +1,335 @@
Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0) Public License

For Multi-Similarity Loss for Deep Metric Learning (MS-Loss)

Copyright (c) 2014-present, Malong Technologies Co., Ltd. All rights reserved.

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 -- Definitions.

  a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.

  b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.

  c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.

  d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.

  e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.

  f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.

  g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.

  h. Licensor means the individual(s) or entity(ies) granting rights under this Public License.

  i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.

  j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.

  k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.

  l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.

Section 2 -- Scope.

  a. License grant.

     1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:

        a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and

        b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.

     2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.

     3. Term. The term of this Public License is specified in Section 6(a).

     4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.

     5. Downstream recipients.

        a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.

        b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.

     6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).

  b. Other rights.

     1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.

     2. Patent and trademark rights are not licensed under this Public License.

     3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.

Section 3 -- License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

  a. Attribution.

     1. If You Share the Licensed Material (including in modified form), You must:

        a. retain the following if it is supplied by the Licensor with the Licensed Material:

           i.   identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);

           ii.  a copyright notice;

           iii. a notice that refers to this Public License;

           iv.  a notice that refers to the disclaimer of warranties;

           v.   a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

        b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and

        c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.

     2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.

     3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.

     4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.

Section 4 -- Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

  a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;

  b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and

  c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.

For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.

Section 5 -- Disclaimer of Warranties and Limitation of Liability.

  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.

  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.

  c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

Section 6 -- Term and Termination.

  a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.

  b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:

     1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

     2. upon express reinstatement by the Licensor.

     For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

  c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.

  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.

Section 7 -- Other Terms and Conditions.

  a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

  b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

  a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.

  b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.

  c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

  d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
77  README.md

@@ -1 +1,76 @@
-Coming soon!

[License: CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/)

# Multi-Similarity Loss for Deep Metric Learning (MS-Loss)

Code for the CVPR 2019 paper [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf).

<img src="misc/ms_loss.png" width="65%" height="65%">

### Performance compared with SOTA methods on CUB-200-2011

| Rank@K | 1 | 2 | 4 | 8 | 16 | 32 |
|:--- |:-:|:-:|:-:|:-:|:-:|:-:|
| Clustering<sup>64</sup> | 48.2 | 61.4 | 71.8 | 81.9 | - | - |
| ProxyNCA<sup>64</sup> | 49.2 | 61.9 | 67.9 | 72.4 | - | - |
| Smart Mining<sup>64</sup> | 49.8 | 62.3 | 74.1 | 83.3 | - | - |
| Our MS-Loss<sup>64</sup> | **57.4** | **69.8** | **80.0** | **87.8** | 93.2 | 96.4 |
| HTL<sup>512</sup> | 57.1 | 68.8 | 78.7 | 86.5 | 92.5 | 95.5 |
| ABIER<sup>512</sup> | 57.5 | 68.7 | 78.3 | 86.2 | 91.9 | 95.5 |
| Our MS-Loss<sup>512</sup> | **65.7** | **77.0** | **86.3** | **91.2** | **95.0** | **97.3** |

### Prepare the data and the pretrained model

Download the dataset from [CUB](http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz), put it in the ./resource/datasets/ folder, and build the data lists (train.txt and test.txt) in the following format:

```
train/020.Yellow_breasted_Chat/Yellow_Breasted_Chat_0075_21715.jpg,0
train/020.Yellow_breasted_Chat/Yellow_Breasted_Chat_0012_21961.jpg,0
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0008_42703.jpg,1
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0009_795510.jpg,1
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0003_795487.jpg,1
```
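Each line is a `relative/image/path,label` pair; the dataset class added below in ret_benchmark/data/datasets/base_dataset.py resolves the paths relative to the folder containing the list file. A minimal sketch for generating such a list, assuming the images were extracted into per-class folders under a train/ (or test/) split; this helper is illustrative and not part of this commit:

```python
import os

def build_list(split_dir, out_file):
    # One folder per class, e.g. 020.Yellow_breasted_Chat; sorting makes
    # the integer label ids deterministic across runs.
    classes = sorted(os.listdir(split_dir))
    split_name = os.path.basename(split_dir.rstrip('/'))
    with open(out_file, 'w') as f:
        for label, cls in enumerate(classes):
            for name in sorted(os.listdir(os.path.join(split_dir, cls))):
                # Paths are written relative to the folder that holds the
                # list file, which is how BaseDataSet joins them back.
                f.write(f"{split_name}/{cls}/{name},{label}\n")

build_list('resource/datasets/CUB_200_2011/train',
           'resource/datasets/CUB_200_2011/train.txt')
```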
Download the ImageNet pretrained model of [bninception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) and put it in the folder: ~/.torch/models/.

### Installation

```bash
pip install -r requirements.txt
python setup.py develop build
```

### Train and Test on CUB-200-2011 with MS-Loss

```bash
sh run_cub.sh
```

Trained models will be saved in the ./output/ folder if using the default config. The best recall@1 is higher than 66 (65.7 in the paper).

### Contact

For any questions, please feel free to reach us at:
```
github@malong.com
```

### Citation

If you use this method or this code in your research, please cite as:

    @inproceedings{wang2019multi,
      title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning},
      author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R},
      booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
      pages={5022--5030},
      year={2019}
    }

## License

MS-Loss is CC-BY-NC 4.0 licensed, as found in the [LICENSE](LICENSE) file. It is released for academic research / non-commercial use only. If you wish to use it for commercial purposes, please contact bd@malong.com.
65  ThirdPartyNotices.txt  Normal file

@@ -0,0 +1,65 @@
THIRD PARTY SOFTWARE NOTICES AND INFORMATION

Do Not Translate or Localize

This software incorporates material from the following third parties.

_____

Cadene/pretrained-models.pytorch

BSD 3-Clause License

Copyright (c) 2017, Remi Cadene
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

_____

facebookresearch/maskrcnn-benchmark

MIT License

Copyright (c) 2018 Facebook

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
29  configs/example.yaml  Normal file

@@ -0,0 +1,29 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

MODEL:
  BACKBONE:
    NAME: bninception

SOLVER:
  MAX_ITERS: 3000
  STEPS: [1200, 2400]
  OPTIMIZER_NAME: Adam
  BASE_LR: 0.00003
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0005

DATA:
  TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
  TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
  TRAIN_BATCHSIZE: 80
  TEST_BATCHSIZE: 256
  NUM_WORKERS: 8
  NUM_INSTANCES: 5

VALIDATION:
  VERBOSE: 200
BIN  misc/ms_loss.png  Normal file

Binary file not shown (new image, 88 KiB).
7  requirements.txt  Normal file

@@ -0,0 +1,7 @@
torch==1.1.0
numpy==1.15.4
yacs==0.1.4
setuptools==40.6.2
pytest==4.4.0
Pillow==6.1.0
torchvision==0.3.0
8  ret_benchmark/config/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .defaults import _C as cfg
94  ret_benchmark/config/defaults.py  Normal file

@@ -0,0 +1,94 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from yacs.config import CfgNode as CN
from .model_path import MODEL_PATH

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"

_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "bninception"

_C.MODEL.PRETRAIN = 'imagenet'
_C.MODEL.PRETRIANED_PATH = MODEL_PATH

_C.MODEL.HEAD = CN()
_C.MODEL.HEAD.NAME = "linear_norm"
_C.MODEL.HEAD.DIM = 512

_C.MODEL.WEIGHT = ""

# Checkpoint save dir
_C.SAVE_DIR = 'output'

# Loss
_C.LOSSES = CN()
_C.LOSSES.NAME = 'ms_loss'

# ms loss
_C.LOSSES.MULTI_SIMILARITY_LOSS = CN()
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS = 2.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG = 40.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.HARD_MINING = True

# Data option
_C.DATA = CN()
_C.DATA.TRAIN_IMG_SOURCE = 'resource/datasets/CUB_200_2011/train.txt'
_C.DATA.TEST_IMG_SOURCE = 'resource/datasets/CUB_200_2011/test.txt'
_C.DATA.TRAIN_BATCHSIZE = 70
_C.DATA.TEST_BATCHSIZE = 256
_C.DATA.NUM_WORKERS = 8
_C.DATA.NUM_INSTANCES = 5

# Input option
_C.INPUT = CN()

# INPUT CONFIG
_C.INPUT.MODE = 'BGR'
_C.INPUT.PIXEL_MEAN = [104. / 255, 117. / 255, 128. / 255]
_C.INPUT.PIXEL_STD = 3 * [1. / 255]

_C.INPUT.FLIP_PROB = 0.5
_C.INPUT.ORIGIN_SIZE = 256
_C.INPUT.CROP_SCALE = [0.16, 1]
_C.INPUT.CROP_SIZE = 227

# SOLVER
_C.SOLVER = CN()
_C.SOLVER.IS_FINETURN = False
_C.SOLVER.FINETURN_MODE_PATH = ''
_C.SOLVER.MAX_ITERS = 4000
_C.SOLVER.STEPS = [1000, 2000, 3000]
_C.SOLVER.OPTIMIZER_NAME = 'SGD'
_C.SOLVER.BASE_LR = 0.01
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.WARMUP_FACTOR = 0.01
_C.SOLVER.WARMUP_ITERS = 200
_C.SOLVER.WARMUP_METHOD = 'linear'
_C.SOLVER.CHECKPOINT_PERIOD = 200
_C.SOLVER.RNG_SEED = 1

# Logger
_C.LOGGER = CN()
_C.LOGGER.LEVEL = 20
_C.LOGGER.STREAM = 'stdout'

# Validation
_C.VALIDATION = CN()
_C.VALIDATION.VERBOSE = 200
_C.VALIDATION.IS_VALIDATION = True
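These defaults are plain yacs nodes, so a run config such as configs/example.yaml above only overrides the keys it cares about. A minimal sketch of how a training script would typically consume them, using standard yacs calls (the snippet itself is illustrative and not a file in this commit):

```python
# Illustrative only: merge a run config over the defaults defined above.
from ret_benchmark.config import cfg

cfg.merge_from_file('configs/example.yaml')    # YAML overrides the defaults
cfg.merge_from_list(['SOLVER.BASE_LR', 0.01])  # optional key/value overrides
cfg.freeze()                                   # lock the config for the run

print(cfg.SOLVER.OPTIMIZER_NAME)  # 'Adam', from example.yaml
print(cfg.DATA.TRAIN_BATCHSIZE)   # 80, overriding the default of 70
```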
20  ret_benchmark/config/model_path.py  Normal file

@@ -0,0 +1,20 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

# -----------------------------------------------------------------------------
# Config definition of imagenet pretrained model path
# -----------------------------------------------------------------------------


from yacs.config import CfgNode as CN

MODEL_PATH = {
    'bninception': "~/.torch/models/bn_inception-52deb4733.pth",
}

MODEL_PATH = CN(MODEL_PATH)
8  ret_benchmark/data/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_data
39  ret_benchmark/data/build.py  Normal file

@@ -0,0 +1,39 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from torch.utils.data import DataLoader

from .collate_batch import collate_fn
from .datasets import BaseDataSet
from .samplers import RandomIdentitySampler
from .transforms import build_transforms


def build_data(cfg, is_train=True):
    transforms = build_transforms(cfg, is_train=is_train)
    if is_train:
        dataset = BaseDataSet(cfg.DATA.TRAIN_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
        sampler = RandomIdentitySampler(dataset=dataset,
                                        batch_size=cfg.DATA.TRAIN_BATCHSIZE,
                                        num_instances=cfg.DATA.NUM_INSTANCES,
                                        max_iters=cfg.SOLVER.MAX_ITERS
                                        )
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_fn,
                                 batch_sampler=sampler,
                                 num_workers=cfg.DATA.NUM_WORKERS,
                                 pin_memory=True
                                 )
    else:
        dataset = BaseDataSet(cfg.DATA.TEST_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_fn,
                                 shuffle=False,
                                 batch_size=cfg.DATA.TEST_BATCHSIZE,
                                 num_workers=cfg.DATA.NUM_WORKERS
                                 )
    return data_loader
15  ret_benchmark/data/collate_batch.py  Normal file

@@ -0,0 +1,15 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch


def collate_fn(batch):
    imgs, labels = zip(*batch)
    labels = [int(k) for k in labels]
    labels = torch.tensor(labels, dtype=torch.int64)
    return torch.stack(imgs, dim=0), labels
8  ret_benchmark/data/datasets/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .base_dataset import BaseDataSet
64  ret_benchmark/data/datasets/base_dataset.py  Normal file

@@ -0,0 +1,64 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import os
import re
from collections import defaultdict

from torch.utils.data import Dataset
from ret_benchmark.utils.img_reader import read_image


class BaseDataSet(Dataset):
    """
    Basic dataset that reads image paths and labels from img_source,
    a list file of img_path,label pairs.
    """

    def __init__(self, img_source, transforms=None, mode="RGB"):
        self.mode = mode
        self.transforms = transforms
        self.root = os.path.dirname(img_source)
        assert os.path.exists(img_source), f"{img_source} NOT found."
        self.img_source = img_source

        self.label_list = list()
        self.path_list = list()
        self._load_data()
        self.label_index_dict = self._build_label_index_dict()

    def __len__(self):
        return len(self.label_list)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"

    def _load_data(self):
        with open(self.img_source, 'r') as f:
            for line in f:
                _path, _label = re.split(r",| ", line.strip())
                self.path_list.append(_path)
                self.label_list.append(_label)

    def _build_label_index_dict(self):
        index_dict = defaultdict(list)
        for i, label in enumerate(self.label_list):
            index_dict[label].append(i)
        return index_dict

    def __getitem__(self, index):
        path = self.path_list[index]
        img_path = os.path.join(self.root, path)
        label = self.label_list[index]

        img = read_image(img_path, mode=self.mode)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label
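A quick look at what this dataset yields, assuming the train.txt from the README step exists (illustrative only, not part of this commit):

```python
# Illustrative only: exercise BaseDataSet on the README's train.txt.
from ret_benchmark.data.datasets import BaseDataSet

ds = BaseDataSet('resource/datasets/CUB_200_2011/train.txt', mode='BGR')
print(ds)           # | Dataset Info |datasize: ...|num_labels: ...|
img, label = ds[0]  # PIL image (no transforms given) and its label string
```

Note that labels stay strings at this level; collate_fn, added earlier in this commit, casts them to an int64 tensor when batching.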
8  ret_benchmark/data/evaluations/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .ret_metric import RetMetric
44  ret_benchmark/data/evaluations/ret_metric.py  Normal file

@@ -0,0 +1,44 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import numpy as np


class RetMetric(object):
    def __init__(self, feats, labels):

        if len(feats) == 2 and type(feats) == list:
            """
            feats = [gallery_feats, query_feats]
            labels = [gallery_labels, query_labels]
            """
            self.is_equal_query = False

            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels

        else:
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))

    def recall_k(self, k=1):
        m = len(self.sim_mat)

        match_counter = 0

        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
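recall_k counts a query as correct when fewer than k negatives score above its best positive (second-best positive when the query set is the gallery, to skip the self-match). A sanity-check sketch on random L2-normalized features; illustrative only, not part of this commit:

```python
import numpy as np

rng = np.random.RandomState(0)
feats = rng.randn(100, 512).astype(np.float32)
feats /= np.linalg.norm(feats, axis=1, keepdims=True)  # dot product == cosine
labels = rng.randint(0, 10, size=100)

metric = RetMetric(feats=feats, labels=labels)  # single-set mode: query == gallery
for k in (1, 2, 4, 8):
    print(f'recall@{k}: {metric.recall_k(k):.3f}')
```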
8  ret_benchmark/data/samplers/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .random_identity_sampler import RandomIdentitySampler
71  ret_benchmark/data/samplers/random_identity_sampler.py  Normal file

@@ -0,0 +1,71 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import copy
import random
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data.sampler import Sampler


class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.
    Args:
    - dataset (BaseDataSet).
    - num_instances (int): number of instances per identity in a batch.
    - batch_size (int): number of examples in a batch.
    """

    def __init__(self, dataset, batch_size, num_instances, max_iters):
        self.label_index_dict = dataset.label_index_dict
        self.batch_size = batch_size
        self.K = num_instances
        self.num_labels_per_batch = self.batch_size // self.K
        self.max_iters = max_iters
        self.labels = list(self.label_index_dict.keys())

    def __len__(self):
        return self.max_iters

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"|Sampler| iters {self.max_iters}| K {self.K}| M {self.batch_size}|"

    def _prepare_batch(self):
        batch_idxs_dict = defaultdict(list)

        for label in self.labels:
            idxs = copy.deepcopy(self.label_index_dict[label])
            if len(idxs) < self.K:
                idxs.extend(np.random.choice(idxs, size=self.K - len(idxs), replace=True))
            random.shuffle(idxs)

            batch_idxs_dict[label] = [idxs[i * self.K: (i + 1) * self.K] for i in range(len(idxs) // self.K)]

        avai_labels = copy.deepcopy(self.labels)
        return batch_idxs_dict, avai_labels

    def __iter__(self):
        batch_idxs_dict, avai_labels = self._prepare_batch()
        for _ in range(self.max_iters):
            batch = []
            if len(avai_labels) < self.num_labels_per_batch:
                batch_idxs_dict, avai_labels = self._prepare_batch()

            selected_labels = random.sample(avai_labels, self.num_labels_per_batch)
            for label in selected_labels:
                batch_idxs = batch_idxs_dict[label].pop(0)
                batch.extend(batch_idxs)
                if len(batch_idxs_dict[label]) == 0:
                    avai_labels.remove(label)
            yield batch
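Each yielded batch is a list of dataset indices covering batch_size // num_instances identities with num_instances images each. A small illustration, assuming `dataset` is a constructed BaseDataSet (illustrative only, not part of this commit):

```python
# Illustrative only: with batch_size=80 and num_instances=5, every batch
# holds 16 identities x 5 images = 80 dataset indices.
sampler = RandomIdentitySampler(dataset=dataset, batch_size=80,
                                num_instances=5, max_iters=3)
for batch in sampler:
    print(len(batch))  # 80 indices; fed via DataLoader(batch_sampler=sampler)
```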
8  ret_benchmark/data/transforms/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_transforms
32  ret_benchmark/data/transforms/build.py  Normal file

@@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torchvision.transforms as T


def build_transforms(cfg, is_train=True):
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
                                      std=cfg.INPUT.PIXEL_STD)
    if is_train:
        transform = T.Compose([
            T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
            T.RandomResizedCrop(
                scale=cfg.INPUT.CROP_SCALE,
                size=cfg.INPUT.CROP_SIZE
            ),
            T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB),
            T.ToTensor(),
            normalize_transform,
        ])
    else:
        transform = T.Compose([
            T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
            T.CenterCrop(cfg.INPUT.CROP_SIZE),
            T.ToTensor(),
            normalize_transform
        ])
    return transform
8  ret_benchmark/engine/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .trainer import do_train
119  ret_benchmark/engine/trainer.py  Normal file

@@ -0,0 +1,119 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import datetime
import time

import numpy as np
import torch

from ret_benchmark.data.evaluations import RetMetric
from ret_benchmark.utils.feat_extractor import feat_extractor
from ret_benchmark.utils.freeze_bn import set_bn_eval
from ret_benchmark.utils.metric_logger import MetricLogger


def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
):
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_loader)

    start_iter = arguments["iteration"]
    best_iteration = -1
    best_recall = 0

    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets) in enumerate(train_loader, start_iter):

        if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
            model.eval()
            logger.info('Validation')
            labels = val_loader.dataset.label_list
            labels = np.array([int(k) for k in labels])
            feats = feat_extractor(model, val_loader, logger=logger)

            ret_metric = RetMetric(feats=feats, labels=labels)
            recall_curr = ret_metric.recall_k(1)

            if recall_curr > best_recall:
                best_recall = recall_curr
                best_iteration = iteration
                logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}')
                checkpointer.save("best_model")
            else:
                logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}')

        model.train()
        model.apply(set_bn_eval)

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = torch.stack([target.to(device) for target in targets])

        feats = model(images)
        loss = criterion(feats, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time, loss=loss.item())

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.1f} GB",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
                )
            )

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:06d}".format(iteration))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )

    logger.info(f"Best iteration: {best_iteration:06d} | best recall {best_recall}")
8  ret_benchmark/losses/__init__.py  Normal file

@@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_loss
16  ret_benchmark/losses/build.py  Normal file

@@ -0,0 +1,16 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .multi_similarity_loss import MultiSimilarityLoss
from .registry import LOSS


def build_loss(cfg):
    loss_name = cfg.LOSSES.NAME
    assert loss_name in LOSS, \
        f'loss name {loss_name} is not registered in registry'
    return LOSS[loss_name](cfg)
55  ret_benchmark/losses/multi_similarity_loss.py  Normal file

@@ -0,0 +1,55 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from ret_benchmark.losses.registry import LOSS


@LOSS.register('ms_loss')
class MultiSimilarityLoss(nn.Module):
    def __init__(self, cfg):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = 0.5
        self.margin = 0.1

        self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS
        self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG

    def forward(self, feats, labels):
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
        batch_size = feats.size(0)
        sim_mat = torch.matmul(feats, torch.t(feats))

        epsilon = 1e-5
        loss = list()

        for i in range(batch_size):
            pos_pair_ = sim_mat[i][labels == labels[i]]
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
            neg_pair_ = sim_mat[i][labels != labels[i]]

            neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
            pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]

            if len(neg_pair) < 1 or len(pos_pair) < 1:
                continue

            # weighting step
            pos_loss = 1.0 / self.scale_pos * torch.log(
                1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
            neg_loss = 1.0 / self.scale_neg * torch.log(
                1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
            loss.append(pos_loss + neg_loss)

        if len(loss) == 0:
            return torch.zeros([], requires_grad=True)

        loss = sum(loss) / batch_size
        return loss
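A minimal forward pass of the loss on random unit-norm embeddings, using the default SCALE_POS/SCALE_NEG values from the config added above (illustrative only, not part of this commit):

```python
# Illustrative only: run MultiSimilarityLoss on random embeddings.
import torch
import torch.nn.functional as F

from ret_benchmark.config import cfg  # supplies LOSSES.MULTI_SIMILARITY_LOSS.*
from ret_benchmark.losses.multi_similarity_loss import MultiSimilarityLoss

feats = F.normalize(torch.randn(80, 512), dim=1)  # cosine similarity via matmul
labels = torch.randint(0, 16, (80,))              # 16 identities, ~5 images each

criterion = MultiSimilarityLoss(cfg)
print(criterion(feats, labels).item())
```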
10  ret_benchmark/losses/registry.py  Normal file

@@ -0,0 +1,10 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from ret_benchmark.utils.registry import Registry

LOSS = Registry()
11  ret_benchmark/modeling/__init__.py  Normal file

@@ -0,0 +1,11 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .backbone import build_backbone
from .build import build_model
from .heads import build_head
from .registry import BACKBONES, HEADS
1  ret_benchmark/modeling/backbone/__init__.py  Normal file

@@ -0,0 +1 @@
from .build import build_backbone
520  ret_benchmark/modeling/backbone/bninception.py  Normal file

@@ -0,0 +1,520 @@
from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from ret_benchmark.modeling import registry


@registry.BACKBONES.register('bninception')
class BNInception(nn.Module):

    def __init__(self):
        super(BNInception, self).__init__()
        inplace = True
        self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True)
        self.conv1_relu_7x7 = nn.ReLU(inplace)
        self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
        self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.conv2_relu_3x3_reduce = nn.ReLU(inplace)
        self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True)
        self.conv2_relu_3x3 = nn.ReLU(inplace)
        self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
        self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3a_relu_1x1 = nn.ReLU(inplace)
        self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3a_relu_3x3 = nn.ReLU(inplace)
        self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True)
        self.inception_3a_relu_pool_proj = nn.ReLU(inplace)
        self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_1x1 = nn.ReLU(inplace)
        self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_3x3 = nn.ReLU(inplace)
        self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_pool_proj = nn.ReLU(inplace)
        self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_3c_relu_3x3 = nn.ReLU(inplace)
        self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
        self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_4a_relu_1x1 = nn.ReLU(inplace)
        self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4a_relu_3x3 = nn.ReLU(inplace)
        self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4b_relu_1x1 = nn.ReLU(inplace)
        self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_3x3 = nn.ReLU(inplace)
        self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_1x1 = nn.ReLU(inplace)
        self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_3x3 = nn.ReLU(inplace)
        self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
|
||||
self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace)
|
||||
self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
|
||||
self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_4c_relu_pool_proj = nn.ReLU(inplace)
|
||||
self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
|
||||
self.inception_4d_relu_1x1 = nn.ReLU(inplace)
|
||||
self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_4d_relu_3x3 = nn.ReLU(inplace)
|
||||
self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
|
||||
self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace)
|
||||
self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace)
|
||||
self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
|
||||
self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_4d_relu_pool_proj = nn.ReLU(inplace)
|
||||
self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||
self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_4e_relu_3x3 = nn.ReLU(inplace)
|
||||
self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
|
||||
self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace)
|
||||
self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||
self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
|
||||
self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace)
|
||||
self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
|
||||
self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
|
||||
self.inception_5a_relu_1x1 = nn.ReLU(inplace)
|
||||
self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
|
||||
self.inception_5a_relu_3x3 = nn.ReLU(inplace)
|
||||
self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
|
||||
self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
|
||||
self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace)
|
||||
self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
|
||||
self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace)
|
||||
self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
|
||||
self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_5a_relu_pool_proj = nn.ReLU(inplace)
|
||||
self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
|
||||
self.inception_5b_relu_1x1 = nn.ReLU(inplace)
|
||||
self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
|
||||
self.inception_5b_relu_3x3 = nn.ReLU(inplace)
|
||||
self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
|
||||
self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace)
|
||||
self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
|
||||
self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace)
|
||||
self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
|
||||
self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace)
|
||||
self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
|
||||
self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
|
||||
self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
|
||||
self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
|
||||
|
    def features(self, input):
        conv1_7x7_s2_out = self.conv1_7x7_s2(input)
        conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
        conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
        pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
        conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
        conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
        conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
        conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
        conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
        conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
        pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
        inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
        inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
        inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
        inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
        inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
        inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
        inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
        inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
        inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
        inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
        inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(
            inception_3a_double_3x3_reduce_out)
        inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(
            inception_3a_double_3x3_reduce_bn_out)
        inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
        inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
        inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
        inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
        inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
        inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
        inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
        inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
        inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
        inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
        inception_3a_output_out = torch.cat(
            [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out,
             inception_3a_relu_pool_proj_out], 1)
        inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
        inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
        inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
        inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
        inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
        inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
        inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
        inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
        inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
        inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
        inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(
            inception_3b_double_3x3_reduce_out)
        inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(
            inception_3b_double_3x3_reduce_bn_out)
        inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
        inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
        inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
        inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
        inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
        inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
        inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
        inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
        inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
        inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
        inception_3b_output_out = torch.cat(
            [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out,
             inception_3b_relu_pool_proj_out], 1)
        inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
        inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
        inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
        inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
        inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
        inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
        inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
        inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(
            inception_3c_double_3x3_reduce_out)
        inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(
            inception_3c_double_3x3_reduce_bn_out)
        inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
        inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
        inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
        inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
        inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
        inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
        inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
        inception_3c_output_out = torch.cat(
            [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1)
        inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
        inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
        inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
        inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
        inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
        inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
        inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
        inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
        inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
        inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
        inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(
            inception_4a_double_3x3_reduce_out)
        inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(
            inception_4a_double_3x3_reduce_bn_out)
        inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
        inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
        inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
        inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
        inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
        inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
        inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
        inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
        inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
        inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
        inception_4a_output_out = torch.cat(
            [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out,
             inception_4a_relu_pool_proj_out], 1)
        inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
        inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
        inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
        inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
        inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
        inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
        inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
        inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
        inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
        inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
        inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(
            inception_4b_double_3x3_reduce_out)
        inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(
            inception_4b_double_3x3_reduce_bn_out)
        inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
        inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
        inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
        inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
        inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
        inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
        inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
        inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
        inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
        inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
        inception_4b_output_out = torch.cat(
            [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out,
             inception_4b_relu_pool_proj_out], 1)
        inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
        inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
        inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
        inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
        inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
        inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
        inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
        inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
        inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
        inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
        inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(
            inception_4c_double_3x3_reduce_out)
        inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(
            inception_4c_double_3x3_reduce_bn_out)
        inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
        inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
        inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
        inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
        inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
        inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
        inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
        inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
        inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
        inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
        inception_4c_output_out = torch.cat(
            [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out,
             inception_4c_relu_pool_proj_out], 1)
        inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
        inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
        inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
        inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
        inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
        inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
        inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
        inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
        inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
        inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
        inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(
            inception_4d_double_3x3_reduce_out)
        inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(
            inception_4d_double_3x3_reduce_bn_out)
        inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
        inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
        inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
        inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
        inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
        inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
        inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
        inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
        inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
        inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
        inception_4d_output_out = torch.cat(
            [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out,
             inception_4d_relu_pool_proj_out], 1)
        inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
        inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
        inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
        inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
        inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
        inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
        inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
        inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(
            inception_4e_double_3x3_reduce_out)
        inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(
            inception_4e_double_3x3_reduce_bn_out)
        inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
        inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
        inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
        inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
        inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
        inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
        inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
        inception_4e_output_out = torch.cat(
            [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1)
        inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
        inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
        inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
        inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
        inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
        inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
        inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
        inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
        inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
        inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
        inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(
            inception_5a_double_3x3_reduce_out)
        inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(
            inception_5a_double_3x3_reduce_bn_out)
        inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
        inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
        inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
        inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
        inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
        inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
        inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
        inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
        inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
        inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
        inception_5a_output_out = torch.cat(
            [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out,
             inception_5a_relu_pool_proj_out], 1)
        inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
        inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
        inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
        inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
        inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
        inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
        inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
        inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
        inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
        inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
        inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(
            inception_5b_double_3x3_reduce_out)
        inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(
            inception_5b_double_3x3_reduce_bn_out)
        inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
        inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
        inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
        inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
        inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
        inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
        inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
        inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
        inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
        inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
        inception_5b_output_out = torch.cat(
            [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out,
             inception_5b_relu_pool_proj_out], 1)
        return inception_5b_output_out

    def logits(self, features):
        x = F.adaptive_max_pool2d(features, output_size=1)
        x = x.view(x.size(0), -1)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x

    def load_param(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            if 'last_linear' in i:  # skip the pretrained ImageNet classifier weights
                continue
            self.state_dict()[i].copy_(param_dict[i])
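A minimal smoke test for the backbone above (a sketch, not part of the diff; it assumes BNInception takes no constructor arguments, which matches how build_backbone invokes it below):

import torch
from ret_benchmark.modeling.backbone.bninception import BNInception

model = BNInception()
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
# inception_5b concatenates 352 + 320 + 224 + 128 = 1024 channels,
# and logits() max-pools them globally, so out.shape == (2, 1024)
print(out.shape)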
8
ret_benchmark/modeling/backbone/build.py
Normal file
@ -0,0 +1,8 @@
from ret_benchmark.modeling.registry import BACKBONES

from .bninception import BNInception


def build_backbone(cfg):
    assert cfg.MODEL.BACKBONE.NAME in BACKBONES, f"backbone {cfg.MODEL.BACKBONE.NAME} is not defined"
    return BACKBONES[cfg.MODEL.BACKBONE.NAME]()
35
ret_benchmark/modeling/build.py
Normal file
@ -0,0 +1,35 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.


import os
from collections import OrderedDict

import torch
from torch.nn.modules import Sequential

from .backbone import build_backbone
from .heads import build_head


def build_model(cfg):
    backbone = build_backbone(cfg)
    head = build_head(cfg)

    model = Sequential(OrderedDict([
        ('backbone', backbone),
        ('head', head)
    ]))

    if cfg.MODEL.PRETRAIN == 'imagenet':
        print('Loading ImageNet pretrained model ...')
        pretrained_path = os.path.expanduser(cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME])
        model.backbone.load_param(pretrained_path)
    elif os.path.exists(cfg.MODEL.PRETRAIN):
        ckp = torch.load(cfg.MODEL.PRETRAIN)
        model.load_state_dict(ckp['model_state_dict'])
    return model
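build_model wires the two stages into a named Sequential, which is what lets load_param target model.backbone and keeps checkpoint keys stable. A sketch of that composition with hypothetical stand-in modules:

import torch
from collections import OrderedDict
from torch.nn import Linear, Sequential

backbone = Linear(16, 1024)  # stand-in for BNInception
head = Linear(1024, 512)     # stand-in for LinearNorm with DIM=512
model = Sequential(OrderedDict([('backbone', backbone), ('head', head)]))

print([name for name, _ in model.named_children()])  # ['backbone', 'head']
print(model(torch.randn(4, 16)).shape)               # torch.Size([4, 512])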
8
ret_benchmark/modeling/heads/__init__.py
Normal file
@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_head
15
ret_benchmark/modeling/heads/build.py
Normal file
@ -0,0 +1,15 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from ret_benchmark.modeling.registry import HEADS

from .linear_norm import LinearNorm


def build_head(cfg):
    assert cfg.MODEL.HEAD.NAME in HEADS, f"head {cfg.MODEL.HEAD.NAME} is not defined"
    return HEADS[cfg.MODEL.HEAD.NAME](cfg, in_channels=1024)  # 1024 = BN-Inception's pooled feature size
25
ret_benchmark/modeling/heads/linear_norm.py
Normal file
@ -0,0 +1,25 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from ret_benchmark.modeling.registry import HEADS
from ret_benchmark.utils.init_methods import weights_init_kaiming


@HEADS.register('linear_norm')
class LinearNorm(nn.Module):
    def __init__(self, cfg, in_channels):
        super(LinearNorm, self).__init__()
        self.fc = nn.Linear(in_channels, cfg.MODEL.HEAD.DIM)
        self.fc.apply(weights_init_kaiming)

    def forward(self, x):
        x = self.fc(x)
        x = nn.functional.normalize(x, p=2, dim=1)  # L2-normalize embeddings onto the unit sphere
        return x
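A quick check of the head's contract, using a hypothetical minimal cfg object that carries only the field LinearNorm reads:

import torch
from types import SimpleNamespace
from ret_benchmark.modeling.heads.linear_norm import LinearNorm

cfg = SimpleNamespace(MODEL=SimpleNamespace(HEAD=SimpleNamespace(DIM=512)))
head = LinearNorm(cfg, in_channels=1024)

emb = head(torch.randn(4, 1024))
print(emb.shape)             # torch.Size([4, 512])
print(emb.norm(p=2, dim=1))  # ~1.0 per row: embeddings live on the unit sphere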
12
ret_benchmark/modeling/registry.py
Normal file
@ -0,0 +1,12 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.


from ret_benchmark.utils.registry import Registry

BACKBONES = Registry()
HEADS = Registry()
4
ret_benchmark/solver/__init__.py
Normal file
@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from .build import build_optimizer
from .build import build_lr_scheduler
from .lr_scheduler import WarmupMultiStepLR
29
ret_benchmark/solver/build.py
Normal file
@ -0,0 +1,29 @@
import torch

from .lr_scheduler import WarmupMultiStepLR


def build_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr_mul = 1.0
        if "backbone" in key:
            lr_mul = 0.1  # train the pretrained backbone with a 10x smaller learning rate
        params += [{"params": [value], "lr_mul": lr_mul}]
    optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params,
                                                                lr=cfg.SOLVER.BASE_LR,
                                                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    return optimizer


def build_lr_scheduler(cfg, optimizer):
    return WarmupMultiStepLR(
        optimizer,
        cfg.SOLVER.STEPS,
        cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
        warmup_method=cfg.SOLVER.WARMUP_METHOD,
    )
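The "lr_mul" entries are opaque to the optimizer itself (PyTorch keeps unknown param-group keys as plain metadata), so a trainer has to apply them explicitly. A sketch of that, under the assumption that the trainer rescales each group from the scheduler-set learning rate:

import torch

model = torch.nn.Linear(4, 4)
params = [{"params": [p], "lr_mul": 0.1} for p in model.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

# One-off rescale; a real loop would recompute from the scheduler's value
# each iteration instead of compounding the multiplier.
for group in optimizer.param_groups:
    group["lr"] = group["lr"] * group.get("lr_mul", 1.0)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.0001]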
49
ret_benchmark/solver/lr_scheduler.py
Normal file
@ -0,0 +1,49 @@
from bisect import bisect_right

import torch


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of"
                " increasing integers. Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(
                self.milestones,
                self.last_epoch
            )
            for base_lr in self.base_lrs
        ]
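A sketch of the resulting schedule with a dummy optimizer (values illustrative; the scheduler is stepped once per iteration, so milestones and warmup_iters are iteration counts):

import torch
from ret_benchmark.solver.lr_scheduler import WarmupMultiStepLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = WarmupMultiStepLR(opt, milestones=[6, 8], gamma=0.1,
                          warmup_factor=1.0 / 3, warmup_iters=4)
lrs = []
for _ in range(10):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()
# lr ramps linearly from 1/3 to 1.0 over the first 4 iterations,
# then drops by 10x at iterations 6 and 8
print(lrs)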
89
ret_benchmark/utils/checkpoint.py
Normal file
@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os

import torch
from ret_benchmark.utils.model_serialization import load_state_dict


class Checkpointer(object):
    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name):
        if not self.save_dir:
            return

        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()

        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

    def load(self, f=None):
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if the file doesn't exist, it may have just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint):
        load_state_dict(self.model, checkpoint.pop("model"))
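A round-trip sketch with a stand-in model and a hypothetical output directory; resuming relies on tag_last_checkpoint having recorded the latest file name:

import os
import torch
from ret_benchmark.utils.checkpoint import Checkpointer

os.makedirs("output", exist_ok=True)
model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

ckpt = Checkpointer(model, optimizer, save_dir="output")
ckpt.save("model_final")                            # writes output/model_final.pth
ckpt.tag_last_checkpoint("output/model_final.pth")  # enables auto-resume
extra = ckpt.load()                                 # restores model + optimizer state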
29
ret_benchmark/utils/config_util.py
Normal file
@ -0,0 +1,29 @@
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import copy
import os

from ret_benchmark.config import cfg as g_cfg


def get_config_root_path():
    ''' Path to configs for unit tests '''
    # cur_file_dir is <repo_root>/ret_benchmark/utils
    cur_file_dir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
    ret = os.path.dirname(os.path.dirname(cur_file_dir))
    ret = os.path.join(ret, "configs")
    return ret


def load_config(rel_path):
    ''' Load config from a file path specified relative to the config root '''
    cfg_path = os.path.join(get_config_root_path(), rel_path)
    return load_config_from_file(cfg_path)


def load_config_from_file(file_path):
    ''' Load config from a file path specified as an absolute path '''
    ret = copy.deepcopy(g_cfg)
    ret.merge_from_file(file_path)
    return ret
27
ret_benchmark/utils/feat_extractor.py
Normal file
@ -0,0 +1,27 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
import numpy as np


def feat_extractor(model, data_loader, logger=None):
    model.eval()
    feats = list()

    for i, batch in enumerate(data_loader):
        imgs = batch[0].cuda()

        with torch.no_grad():
            out = model(imgs).data.cpu().numpy()
            feats.append(out)

        if logger is not None and (i + 1) % 100 == 0:
            logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]')
        del out
    feats = np.vstack(feats)
    return feats
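A usage sketch with stand-in data (it assumes a CUDA device, since the extractor moves every batch with .cuda(), and a loader whose batches put images first):

import torch
from torch.utils.data import DataLoader, TensorDataset
from ret_benchmark.utils.feat_extractor import feat_extractor

data = TensorDataset(torch.randn(32, 8), torch.zeros(32))
loader = DataLoader(data, batch_size=16)
model = torch.nn.Linear(8, 4).cuda()

feats = feat_extractor(model, loader)  # numpy array of shape (32, 4)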
14
ret_benchmark/utils/freeze_bn.py
Normal file
@ -0,0 +1,14 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

# Batch Norm freezer
# Note: adds an additional 2% improvement on CUB (it brings no effect on the other benchmarks)

def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
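It is meant to be applied recursively with Module.apply after model.train(), so BN running statistics stay frozen while everything else keeps training:

import torch
from ret_benchmark.utils.freeze_bn import set_bn_eval

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
model.train()
model.apply(set_bn_eval)  # recursively switches every BatchNorm layer to eval
print(model[1].training)  # False: running mean/var are no longer updated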
21
ret_benchmark/utils/img_reader.py
Normal file
@ -0,0 +1,21 @@
import os.path as osp

from PIL import Image


def read_image(img_path, mode='RGB'):
    """Keep reading the image until it succeeds.
    This can avoid IOError incurred by a heavy IO process."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError(f"{img_path} does not exist")
    while not got_img:
        try:
            img = Image.open(img_path).convert("RGB")
            if mode == "BGR":
                r, g, b = img.split()
                img = Image.merge("RGB", (b, g, r))
            got_img = True
        except IOError:
            print(f"IOError incurred when reading '{img_path}'. Will redo.")
            pass
    return img
32
ret_benchmark/utils/init_methods.py
Normal file
@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
32
ret_benchmark/utils/logger.py
Normal file
@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import os
import sys
import logging

_streams = {
    "stdout": sys.stdout
}


def setup_logger(name: str, level: int, stream: str = "stdout") -> logging.Logger:
    global _streams
    if stream not in _streams:
        log_folder = os.path.dirname(stream)
        os.makedirs(log_folder, exist_ok=True)
        _streams[stream] = open(stream, 'w')
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(level)

    sh = logging.StreamHandler(stream=_streams[stream])
    sh.setLevel(level)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    return logger
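Console and file logging go through the same helper; passing a path as the stream lazily creates the folder and opens the file (a sketch, with a hypothetical log path):

import logging
from ret_benchmark.utils.logger import setup_logger

console = setup_logger('Train', logging.INFO)                          # writes to stdout
filelog = setup_logger('TrainFile', logging.INFO, 'output/train.log')  # writes to file
console.info('hello')  # e.g. "2019-07-22 12:00:00,000 Train INFO: hello"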
66
ret_benchmark/utils/metric_logger.py
Normal file
@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import defaultdict
from collections import deque

import torch


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.series = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            )
        return self.delimiter.join(loss_str)
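Training-loop bookkeeping then looks like this sketch; __str__ reports the median over the last window alongside the global average:

import torch
from ret_benchmark.utils.metric_logger import MetricLogger

meters = MetricLogger(delimiter=', ')
for step in range(50):
    meters.update(loss=torch.rand(1).item(), lr=0.01)
print(str(meters))  # e.g. "loss: 0.4812 (0.5023), lr: 0.0100 (0.0100)"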
78
ret_benchmark/utils/model_serialization.py
Normal file
@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import OrderedDict
import logging

import torch


def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
    """
    Strategy: suppose that the models that we will create will have prefixes appended
    to each of their keys, for example due to an extra level of nesting that the original
    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
    res2.conv1.weight. We thus want to match both parameters together.
    For that, we look at each model weight, look among all loaded keys if there is one
    that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with the longest
    corresponding name. For example, for the same model as before, the pretrained
    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
    we want to match backbone[0].body.conv1.weight to conv1.weight, and
    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
    """
    current_keys = sorted(list(model_state_dict.keys()))
    loaded_keys = sorted(list(loaded_state_dict.keys()))
    # get a matrix of string matches, where each (i, j) entry corresponds to the size of the
    # loaded_key string, if it matches
    match_matrix = [
        len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
    ]
    match_matrix = torch.as_tensor(match_matrix).view(
        len(current_keys), len(loaded_keys)
    )
    max_match_size, idxs = match_matrix.max(1)
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

    # used for logging
    max_size = max([len(key) for key in current_keys]) if current_keys else 1
    max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger(__name__)
    for idx_new, idx_old in enumerate(idxs.tolist()):
        if idx_old == -1:
            continue
        key = current_keys[idx_new]
        key_old = loaded_keys[idx_old]
        model_state_dict[key] = loaded_state_dict[key_old]
        logger.info(
            log_str_template.format(
                key,
                max_size,
                key_old,
                max_size_loaded,
                tuple(loaded_state_dict[key_old].shape),
            )
        )


def strip_prefix_if_present(state_dict, prefix):
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict
    stripped_state_dict = OrderedDict()
    for key, value in state_dict.items():
        stripped_state_dict[key.replace(prefix, "")] = value
    return stripped_state_dict


def load_state_dict(model, loaded_state_dict):
    model_state_dict = model.state_dict()
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching
    loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
    align_and_update_state_dicts(model_state_dict, loaded_state_dict)

    # use strict loading
    model.load_state_dict(model_state_dict)
46
ret_benchmark/utils/registry.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
||||
|
||||
|
||||
def _register_generic(module_dict, module_name, module):
|
||||
assert module_name not in module_dict
|
||||
module_dict[module_name] = module
|
||||
|
||||
|
||||
class Registry(dict):
    '''
    A helper class for managing module registration; it extends a dictionary
    and provides a register function.

    Eg. creating a registry:
        some_registry = Registry({"default": default_module})

    There are two ways of registering new modules:
    1) the normal way is just calling the register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2) used as a decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...

    Access of a module is just like using a dictionary, eg:
        f = some_registry["foo_module"]
    '''

    def __init__(self, *args, **kwargs):
        super(Registry, self).__init__(*args, **kwargs)

    def register(self, module_name, module=None):
        # used as function call
        if module is not None:
            _register_generic(self, module_name, module)
            return

        # used as decorator
        def register_fn(fn):
            _register_generic(self, module_name, fn)
            return fn

        return register_fn
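A minimal sketch of how the decorator form composes with dictionary-style lookup (the registry and function names below are illustrative, not taken from the repo):

    BACKBONES = Registry()

    @BACKBONES.register("toy_net")
    def build_toy_net():
        return "toy network instance"

    net = BACKBONES["toy_net"]()  # -> "toy network instance"
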
1
run_cub.sh
Normal file
@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml
25
setup.py
Normal file
@ -0,0 +1,25 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.


from setuptools import find_packages, setup
# import BuildExtension directly; the previous CppExtension import was unused
# and referencing torch.utils.cpp_extension.BuildExtension relied on its
# import side effect
from torch.utils.cpp_extension import BuildExtension


requirements = ["torch", "torchvision"]

setup(
    name="ret_benchmark",
    version="0.1",
    author="Malong Technologies",
    url="https://github.com/MalongTech/research-ms-loss",
    description="ms-loss",
    packages=find_packages(exclude=("configs", "tests")),
    install_requires=requirements,
    cmdclass={"build_ext": BuildExtension},
)
78
tools/main.py
Normal file
@ -0,0 +1,78 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.


import argparse
import torch

from ret_benchmark.config import cfg
from ret_benchmark.data import build_data
from ret_benchmark.engine.trainer import do_train
from ret_benchmark.losses import build_loss
from ret_benchmark.modeling import build_model
from ret_benchmark.solver import build_lr_scheduler, build_optimizer
from ret_benchmark.utils.logger import setup_logger
from ret_benchmark.utils.checkpoint import Checkpointer


def train(cfg):
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    criterion = build_loss(cfg)

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)

    logger.info(train_loader.dataset)
    logger.info(val_loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
    )


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a retrieval network')
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='config file',
        default=None,
        type=str)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    cfg.merge_from_file(args.cfg_file)
    train(cfg)
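The entry point can also be driven programmatically, which is handy in notebooks or quick experiments; a minimal sketch, reusing the same config file that run_cub.sh passes on the command line:

    # hypothetical programmatic use of the same entry point
    from ret_benchmark.config import cfg
    cfg.merge_from_file("configs/example.yaml")
    train(cfg)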