initial commit

Matt Scott 2019-08-29 15:49:29 +08:00
parent d72939633d
commit 8290973c8e
54 changed files with 2453 additions and 1 deletion

9
.flake8 Normal file

@ -0,0 +1,9 @@
[flake8]
ignore = F401, F841, E402, E722, E999
max-line-length = 128
max-complexity = 18
format = pylint
show_source = True
statistics = True
count = True
exclude = tests,ret_benchmark/modeling/backbone

4
.git_push.sh Normal file

@ -0,0 +1,4 @@
git add .
git status
git commit -m 'update'
git push

30
.gitignore vendored Normal file

@ -0,0 +1,30 @@
resource
build
*.pyc
*.zip
*/__pycache__
__pycache__
# Package Files #
*.pkl
*.log
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
*.egg-info
#some local files
*/.settings/
*/.DS_Store
.DS_Store
*/.idea/
.idea/
gradlew
gradlew.bat
unused.txt
output/
*.egg-info/

335
LICENSE Normal file

@ -0,0 +1,335 @@
Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0)
Public License
For Multi-Similarity Loss for Deep Metric Learning (MS-Loss)
Copyright (c) 2014-present, Malong Technologies Co., Ltd. All rights reserved.
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.


@ -1 +1,76 @@
Coming soon!
[![License: CC BY-NC 4.0](https://licensebuttons.net/l/by-nc/4.0/80x15.png)](https://creativecommons.org/licenses/by-nc/4.0/)
# Multi-Similarity Loss for Deep Metric Learning (MS-Loss)
Code for the CVPR 2019 paper [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
<img src="misc/ms_loss.png" width="65%" height="65%">
### Performance compared with SOTA methods on CUB-200-2011
|Rank@K | 1 | 2 | 4 | 8 | 16 | 32 |
|:--- |:-:|:-:|:-:|:-:|:-: |:-: |
|Clustering<sup>64</sup> | 48.2 | 61.4 | 71.8 | 81.9 | - | - |
|ProxyNCA<sup>64</sup> | 49.2 | 61.9 | 67.9 | 72.4 | - | - |
|Smart Mining<sup>64</sup> | 49.8 | 62.3 | 74.1 | 83.3 | - | - |
|Our MS-Loss<sup>64</sup>| **57.4** |**69.8** |**80.0** |**87.8** |93.2 |96.4|
|HTL<sup>512</sup> | 57.1| 68.8| 78.7| 86.5| 92.5| 95.5 |
|ABIER<sup>512</sup> |57.5 |68.7 |78.3 |86.2 |91.9 |95.5 |
|Our MS-Loss<sup>512</sup>|**65.7** |**77.0** |**86.3**|**91.2** |**95.0** |**97.3**|
### Prepare the data and the pretrained model
Download the dataset from [CUB](http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz), place it in the ./resource/datasets/ folder, and build the data lists (train.txt, test.txt) in the format below; a small helper sketch for generating them follows the example:
```text
train/020.Yellow_breasted_Chat/Yellow_Breasted_Chat_0075_21715.jpg,0
train/020.Yellow_breasted_Chat/Yellow_Breasted_Chat_0012_21961.jpg,0
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0008_42703.jpg,1
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0009_795510.jpg,1
train/043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0003_795487.jpg,1
```
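A minimal sketch for generating these lists (a hypothetical helper, not part of this repository), assuming the CUB images are unpacked into class-named sub-folders under `train/` and `test/` inside `resource/datasets/CUB_200_2011/`:
```python
# Hypothetical helper (not part of this commit): writes train.txt / test.txt in the
# "relative/path.jpg,label" format shown above, assuming a <split>/<class_dir>/<image>
# layout under the dataset root.
import os

def build_list(root, split):
    classes = sorted(os.listdir(os.path.join(root, split)))
    with open(os.path.join(root, f"{split}.txt"), "w") as f:
        for label, cls in enumerate(classes):
            for name in sorted(os.listdir(os.path.join(root, split, cls))):
                f.write(f"{split}/{cls}/{name},{label}\n")

for split in ("train", "test"):
    build_list("resource/datasets/CUB_200_2011", split)
```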
Download the ImageNet-pretrained
[bninception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) model and put it in the folder ~/.torch/models/.
### Installation
```bash
pip install -r requirements.txt
python setup.py develop build
```
### Train and Test on CUB200-2011 with MS-Loss
```bash
sh run_cub.sh
```
Trained models will be saved in the ./output/ folder when using the default config.
The best recall@1 should be higher than 66 (65.7 reported in the paper).
### Contact
For any questions, please feel free to reach us at:
```
github@malong.com
```
### Citation
If you use this method or this code in your research, please cite as:
@inproceedings{wang2019multi,
title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning},
author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={5022--5030},
year={2019}
}
## License
MS-Loss is CC-BY-NC 4.0 licensed, as found in the [LICENSE](LICENSE) file. It is released for academic research / non-commercial use only. If you wish to use it for commercial purposes, please contact bd@malong.com.

65
ThirdPartyNotices.txt Normal file

@ -0,0 +1,65 @@
THIRD PARTY SOFTWARE NOTICES AND INFORMATION
Do Not Translate or Localize
This software incorporates material from the following third parties.
_____
Cadene/pretrained-models.pytorch
BSD 3-Clause License
Copyright (c) 2017, Remi Cadene
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_____
facebookresearch/maskrcnn-benchmark
MIT License
Copyright (c) 2018 Facebook
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

29
configs/example.yaml Normal file

@ -0,0 +1,29 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
MODEL:
BACKBONE:
NAME: bninception
SOLVER:
MAX_ITERS: 3000
STEPS: [1200, 2400]
OPTIMIZER_NAME: Adam
BASE_LR: 0.00003
WARMUP_ITERS: 0
WEIGHT_DECAY: 0.0005
DATA:
TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
TRAIN_BATCHSIZE: 80
TEST_BATCHSIZE: 256
NUM_WORKERS: 8
NUM_INSTANCES: 5
VALIDATION:
VERBOSE: 200

BIN
misc/ms_loss.png Normal file

Binary file not shown (new image, 88 KiB).

7
requirements.txt Normal file

@ -0,0 +1,7 @@
torch==1.1.0
numpy==1.15.4
yacs==0.1.4
setuptools==40.6.2
pytest==4.4.0
Pillow==6.1.0
torchvision==0.3.0


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .defaults import _C as cfg


@ -0,0 +1,94 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from yacs.config import CfgNode as CN
from .model_path import MODEL_PATH
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
_C = CN()
_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "bninception"
_C.MODEL.PRETRAIN = 'imagenet'
_C.MODEL.PRETRIANED_PATH = MODEL_PATH
_C.MODEL.HEAD = CN()
_C.MODEL.HEAD.NAME = "linear_norm"
_C.MODEL.HEAD.DIM = 512
_C.MODEL.WEIGHT = ""
# Checkpoint save dir
_C.SAVE_DIR = 'output'
# Loss
_C.LOSSES = CN()
_C.LOSSES.NAME = 'ms_loss'
# ms loss
_C.LOSSES.MULTI_SIMILARITY_LOSS = CN()
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS = 2.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG = 40.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.HARD_MINING = True
# Data option
_C.DATA = CN()
_C.DATA.TRAIN_IMG_SOURCE = 'resource/datasets/CUB_200_2011/train.txt'
_C.DATA.TEST_IMG_SOURCE = 'resource/datasets/CUB_200_2011/test.txt'
_C.DATA.TRAIN_BATCHSIZE = 70
_C.DATA.TEST_BATCHSIZE = 256
_C.DATA.NUM_WORKERS = 8
_C.DATA.NUM_INSTANCES = 5
# Input option
_C.INPUT = CN()
# INPUT CONFIG
_C.INPUT.MODE = 'BGR'
_C.INPUT.PIXEL_MEAN = [104. / 255, 117. / 255, 128. / 255]
_C.INPUT.PIXEL_STD = 3 * [1. / 255]
_C.INPUT.FLIP_PROB = 0.5
_C.INPUT.ORIGIN_SIZE = 256
_C.INPUT.CROP_SCALE = [0.16, 1]
_C.INPUT.CROP_SIZE = 227
# SOLVER
_C.SOLVER = CN()
_C.SOLVER.IS_FINETURN = False
_C.SOLVER.FINETURN_MODE_PATH = ''
_C.SOLVER.MAX_ITERS = 4000
_C.SOLVER.STEPS = [1000, 2000, 3000]
_C.SOLVER.OPTIMIZER_NAME = 'SGD'
_C.SOLVER.BASE_LR = 0.01
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.WARMUP_FACTOR = 0.01
_C.SOLVER.WARMUP_ITERS = 200
_C.SOLVER.WARMUP_METHOD = 'linear'
_C.SOLVER.CHECKPOINT_PERIOD = 200
_C.SOLVER.RNG_SEED = 1
# Logger
_C.LOGGER = CN()
_C.LOGGER.LEVEL = 20
_C.LOGGER.STREAM = 'stdout'
# Validation
_C.VALIDATION = CN()
_C.VALIDATION.VERBOSE = 200
_C.VALIDATION.IS_VALIDATION = True
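As a usage sketch (not part of this commit), a yaml file such as configs/example.yaml above is typically merged over these defaults with yacs; the import path `ret_benchmark.config` is an assumption based on the package layout:
```python
# Sketch: override the defaults above with configs/example.yaml via yacs.
# The import path is assumed from this package's layout
# (its __init__ exposes `from .defaults import _C as cfg`).
from ret_benchmark.config import cfg

cfg.merge_from_file("configs/example.yaml")   # yaml values override the defaults
cfg.freeze()                                  # make the config immutable
print(cfg.SOLVER.OPTIMIZER_NAME, cfg.SOLVER.BASE_LR)  # Adam 3e-05
```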


@ -0,0 +1,20 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
# -----------------------------------------------------------------------------
# Config definition of imagenet pretrained model path
# -----------------------------------------------------------------------------
from yacs.config import CfgNode as CN
MODEL_PATH = {
'bninception': "~/.torch/models/bn_inception-52deb4733.pth",
}
MODEL_PATH = CN(MODEL_PATH)


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .build import build_data


@ -0,0 +1,39 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from torch.utils.data import DataLoader
from .collate_batch import collate_fn
from .datasets import BaseDataSet
from .samplers import RandomIdentitySampler
from .transforms import build_transforms
def build_data(cfg, is_train=True):
transforms = build_transforms(cfg, is_train=is_train)
if is_train:
dataset = BaseDataSet(cfg.DATA.TRAIN_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
sampler = RandomIdentitySampler(dataset=dataset,
batch_size=cfg.DATA.TRAIN_BATCHSIZE,
num_instances=cfg.DATA.NUM_INSTANCES,
max_iters=cfg.SOLVER.MAX_ITERS
)
data_loader = DataLoader(dataset,
collate_fn=collate_fn,
batch_sampler=sampler,
num_workers=cfg.DATA.NUM_WORKERS,
pin_memory=True
)
else:
dataset = BaseDataSet(cfg.DATA.TEST_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
data_loader = DataLoader(dataset,
collate_fn=collate_fn,
shuffle=False,
batch_size=cfg.DATA.TEST_BATCHSIZE,
num_workers=cfg.DATA.NUM_WORKERS
)
return data_loader
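A hedged usage sketch for the loader (assuming the package is installed with `python setup.py develop`, the data lists exist, and the import paths follow the layout used in this commit):
```python
# Sketch: build the training loader and pull one batch.
from ret_benchmark.config import cfg       # import path assumed
from ret_benchmark.data import build_data  # import path assumed

cfg.merge_from_file("configs/example.yaml")
train_loader = build_data(cfg, is_train=True)
images, labels = next(iter(train_loader))
# With the example config: 80 images per batch = 16 identities x 5 instances.
print(images.shape, labels.shape)  # torch.Size([80, 3, 227, 227]) torch.Size([80])
```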


@ -0,0 +1,15 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
def collate_fn(batch):
imgs, labels = zip(*batch)
labels = [int(k) for k in labels]
labels = torch.tensor(labels, dtype=torch.int64)
return torch.stack(imgs, dim=0), labels
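For illustration, a small self-contained check (not from the repo) of what the collate function produces from a list of (image, label-string) pairs:
```python
# Small check of the collate behaviour: stack images, convert labels to int64.
import torch
from ret_benchmark.data.collate_batch import collate_fn  # module path assumed

batch = [(torch.zeros(3, 227, 227), "0"), (torch.ones(3, 227, 227), "1")]
imgs, labels = collate_fn(batch)
print(imgs.shape)  # torch.Size([2, 3, 227, 227])
print(labels)      # tensor([0, 1])
```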


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .base_dataset import BaseDataSet


@ -0,0 +1,64 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import os
import re
from collections import defaultdict
from torch.utils.data import Dataset
from ret_benchmark.utils.img_reader import read_image
class BaseDataSet(Dataset):
"""
Basic Dataset read image path from img_source
img_source: list of img_path and label
"""
def __init__(self, img_source, transforms=None, mode="RGB"):
self.mode = mode
self.transforms = transforms
self.root = os.path.dirname(img_source)
assert os.path.exists(img_source), f"{img_source} NOT found."
self.img_source = img_source
self.label_list = list()
self.path_list = list()
self._load_data()
self.label_index_dict = self._build_label_index_dict()
def __len__(self):
return len(self.label_list)
def __repr__(self):
return self.__str__()
def __str__(self):
return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"
def _load_data(self):
with open(self.img_source, 'r') as f:
for line in f:
_path, _label = re.split(r",| ", line.strip())
self.path_list.append(_path)
self.label_list.append(_label)
def _build_label_index_dict(self):
index_dict = defaultdict(list)
for i, label in enumerate(self.label_list):
index_dict[label].append(i)
return index_dict
def __getitem__(self, index):
path = self.path_list[index]
img_path = os.path.join(self.root, path)
label = self.label_list[index]
img = read_image(img_path, mode=self.mode)
if self.transforms is not None:
img = self.transforms(img)
return img, label
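A brief usage sketch, assuming train.txt follows the comma-separated format from the README and that the module path matches the import used in build.py:
```python
# Sketch: inspect the dataset built from a data list ("relative/path.jpg,label" lines).
from ret_benchmark.data.datasets import BaseDataSet  # import path assumed

ds = BaseDataSet("resource/datasets/CUB_200_2011/train.txt", transforms=None, mode="BGR")
print(ds)                              # | Dataset Info |datasize: ...|num_labels: ...|
img, label = ds[0]                     # raw image (no transforms) and its label string
print(ds.label_index_dict[label][:5])  # indices of other images sharing this label
```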


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .ret_metric import RetMetric


@ -0,0 +1,44 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import numpy as np
class RetMetric(object):
def __init__(self, feats, labels):
if len(feats) == 2 and type(feats) == list:
"""
feats = [gallery_feats, query_feats]
labels = [gallery_labels, query_labels]
"""
self.is_equal_query = False
self.gallery_feats, self.query_feats = feats
self.gallery_labels, self.query_labels = labels
else:
self.is_equal_query = True
self.gallery_feats = self.query_feats = feats
self.gallery_labels = self.query_labels = labels
self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
def recall_k(self, k=1):
m = len(self.sim_mat)
match_counter = 0
for i in range(m):
pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]
thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim)
if np.sum(neg_sim > thresh) < k:
match_counter += 1
return float(match_counter) / m
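A small worked example of the recall@k logic (illustrative numbers only): in the self-query case the second-highest positive similarity is used as the threshold, so each query's similarity to itself (1.0) is ignored.
```python
# Self-contained illustration of recall@1 with gallery == query.
import numpy as np
from ret_benchmark.data.evaluations import RetMetric  # same import as in the trainer

# Four L2-normalised 2-D embeddings; items (0, 1) and (2, 3) share labels.
feats = np.array([[1.0, 0.0],
                  [0.8, 0.6],
                  [0.0, 1.0],
                  [-0.6, 0.8]])
labels = np.array([0, 0, 1, 1])
metric = RetMetric(feats=feats, labels=labels)
print(metric.recall_k(1))  # 1.0: every query's nearest non-self neighbour shares its label
```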


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .random_identity_sampler import RandomIdentitySampler


@ -0,0 +1,71 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import copy
import random
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
"""
Randomly sample N identities, then for each identity,
randomly sample K instances, therefore batch size is N*K.
Args:
- dataset (BaseDataSet).
- num_instances (int): number of instances per identity in a batch.
- batch_size (int): number of examples in a batch.
"""
def __init__(self, dataset, batch_size, num_instances, max_iters):
self.label_index_dict = dataset.label_index_dict
self.batch_size = batch_size
self.K = num_instances
self.num_labels_per_batch = self.batch_size // self.K
self.max_iters = max_iters
self.labels = list(self.label_index_dict.keys())
def __len__(self):
return self.max_iters
def __repr__(self):
return self.__str__()
def __str__(self):
return f"|Sampler| iters {self.max_iters}| K {self.K}| M {self.batch_size}|"
def _prepare_batch(self):
batch_idxs_dict = defaultdict(list)
for label in self.labels:
idxs = copy.deepcopy(self.label_index_dict[label])
if len(idxs) < self.K:
idxs.extend(np.random.choice(idxs, size=self.K - len(idxs), replace=True))
random.shuffle(idxs)
batch_idxs_dict[label] = [idxs[i * self.K: (i + 1) * self.K] for i in range(len(idxs) // self.K)]
avai_labels = copy.deepcopy(self.labels)
return batch_idxs_dict, avai_labels
def __iter__(self):
batch_idxs_dict, avai_labels = self._prepare_batch()
for _ in range(self.max_iters):
batch = []
if len(avai_labels) < self.num_labels_per_batch:
batch_idxs_dict, avai_labels = self._prepare_batch()
selected_labels = random.sample(avai_labels, self.num_labels_per_batch)
for label in selected_labels:
batch_idxs = batch_idxs_dict[label].pop(0)
batch.extend(batch_idxs)
if len(batch_idxs_dict[label]) == 0:
avai_labels.remove(label)
yield batch
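A minimal sketch of the batch structure the sampler yields, using a toy object that exposes label_index_dict (hypothetical data, not from the repo):
```python
# Toy illustration: 4 identities x 6 samples each, batches of 2 identities x 3 instances.
from types import SimpleNamespace
from ret_benchmark.data.samplers import RandomIdentitySampler  # import path assumed

toy = SimpleNamespace(
    label_index_dict={str(l): list(range(l * 6, (l + 1) * 6)) for l in range(4)}
)
sampler = RandomIdentitySampler(dataset=toy, batch_size=6, num_instances=3, max_iters=5)
batch = next(iter(sampler))
print(batch)  # e.g. [7, 10, 6, 19, 21, 23] -- 3 indices from each of 2 identities
```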


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .build import build_transforms


@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torchvision.transforms as T
def build_transforms(cfg, is_train=True):
normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN,
std=cfg.INPUT.PIXEL_STD)
if is_train:
transform = T.Compose([
T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
T.RandomResizedCrop(
scale=cfg.INPUT.CROP_SCALE,
size=cfg.INPUT.CROP_SIZE
),
T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB),
T.ToTensor(),
normalize_transform,
])
else:
transform = T.Compose([
T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
T.CenterCrop(cfg.INPUT.CROP_SIZE),
T.ToTensor(),
normalize_transform
])
return transform
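A quick sketch (illustrative only) checking the training pipeline's output shape on a blank image:
```python
# Sketch: run the training transform on a dummy PIL image and check the tensor shape.
from PIL import Image
from ret_benchmark.config import cfg                         # import path assumed
from ret_benchmark.data.transforms import build_transforms   # import path assumed

t = build_transforms(cfg, is_train=True)
img = Image.new("RGB", (500, 400))
x = t(img)
print(x.shape)  # torch.Size([3, 227, 227]) with the default CROP_SIZE
```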


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .trainer import do_train


@ -0,0 +1,119 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import datetime
import time
import numpy as np
import torch
from ret_benchmark.data.evaluations import RetMetric
from ret_benchmark.utils.feat_extractor import feat_extractor
from ret_benchmark.utils.freeze_bn import set_bn_eval
from ret_benchmark.utils.metric_logger import MetricLogger
def do_train(
cfg,
model,
train_loader,
val_loader,
optimizer,
scheduler,
criterion,
checkpointer,
device,
checkpoint_period,
arguments,
logger
):
logger.info("Start training")
meters = MetricLogger(delimiter=" ")
max_iter = len(train_loader)
start_iter = arguments["iteration"]
best_iteration = -1
best_recall = 0
start_training_time = time.time()
end = time.time()
for iteration, (images, targets) in enumerate(train_loader, start_iter):
if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
model.eval()
logger.info('Validation')
labels = val_loader.dataset.label_list
labels = np.array([int(k) for k in labels])
feats = feat_extractor(model, val_loader, logger=logger)
ret_metric = RetMetric(feats=feats, labels=labels)
recall_curr = ret_metric.recall_k(1)
if recall_curr > best_recall:
best_recall = recall_curr
best_iteration = iteration
logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}')
checkpointer.save(f"best_model")
else:
logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}')
model.train()
model.apply(set_bn_eval)
data_time = time.time() - end
iteration = iteration + 1
arguments["iteration"] = iteration
scheduler.step()
images = images.to(device)
targets = torch.stack([target.to(device) for target in targets])
feats = model(images)
loss = criterion(feats, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_time = time.time() - end
end = time.time()
meters.update(time=batch_time, data=data_time, loss=loss.item())
eta_seconds = meters.time.global_avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if iteration % 20 == 0 or iteration == max_iter:
logger.info(
meters.delimiter.join(
[
"eta: {eta}",
"iter: {iter}",
"{meters}",
"lr: {lr:.6f}",
"max mem: {memory:.1f} GB",
]
).format(
eta=eta_string,
iter=iteration,
meters=str(meters),
lr=optimizer.param_groups[0]["lr"],
memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
)
)
if iteration % checkpoint_period == 0:
checkpointer.save("model_{:06d}".format(iteration))
total_training_time = time.time() - start_training_time
total_time_str = str(datetime.timedelta(seconds=total_training_time))
logger.info(
"Total training time: {} ({:.4f} s / it)".format(
total_time_str, total_training_time / (max_iter)
)
)
logger.info(f"Best iteration: {best_iteration :06d} | best recall {best_recall} ")


@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .build import build_loss


@ -0,0 +1,16 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .multi_similarity_loss import MultiSimilarityLoss
from .registry import LOSS
def build_loss(cfg):
loss_name = cfg.LOSSES.NAME
assert loss_name in LOSS, \
f'loss name {loss_name} is not registered in registry'
return LOSS[loss_name](cfg)


@ -0,0 +1,55 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from ret_benchmark.losses.registry import LOSS
@LOSS.register('ms_loss')
class MultiSimilarityLoss(nn.Module):
def __init__(self, cfg):
super(MultiSimilarityLoss, self).__init__()
self.thresh = 0.5
self.margin = 0.1
self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS
self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG
def forward(self, feats, labels):
assert feats.size(0) == labels.size(0), \
f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
batch_size = feats.size(0)
sim_mat = torch.matmul(feats, torch.t(feats))
epsilon = 1e-5
loss = list()
for i in range(batch_size):
pos_pair_ = sim_mat[i][labels == labels[i]]
pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]
neg_pair_ = sim_mat[i][labels != labels[i]]
neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]
if len(neg_pair) < 1 or len(pos_pair) < 1:
continue
# weighting step
pos_loss = 1.0 / self.scale_pos * torch.log(
1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
neg_loss = 1.0 / self.scale_neg * torch.log(
1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
loss.append(pos_loss + neg_loss)
if len(loss) == 0:
return torch.zeros([], requires_grad=True)
loss = sum(loss) / batch_size
return loss
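A short usage sketch with random L2-normalised embeddings (the config's default head is linear_norm, so unit-norm features are assumed; import paths are assumptions based on the package layout):
```python
# Sketch: evaluate the MS loss on random unit-norm embeddings.
import torch
import torch.nn.functional as F
from ret_benchmark.config import cfg         # import path assumed
from ret_benchmark.losses import build_loss  # import path assumed

criterion = build_loss(cfg)                        # cfg.LOSSES.NAME defaults to 'ms_loss'
feats = F.normalize(torch.randn(10, 512), dim=1)   # 10 embeddings with unit L2 norm
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
print(criterion(feats, labels).item())
```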


@ -0,0 +1,10 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from ret_benchmark.utils.registry import Registry
LOSS = Registry()
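ret_benchmark/utils/registry.py is not included in this excerpt; a registry compatible with the @LOSS.register('ms_loss') and @registry.BACKBONES.register('bninception') decorators used elsewhere in this commit is typically a thin dict wrapper, roughly as sketched below (an assumption, not the repo's actual implementation):
```python
# Hedged sketch of a decorator-based registry compatible with the usage above;
# the real ret_benchmark.utils.registry.Registry is not shown in this commit view.
class Registry(dict):
    def register(self, name):
        def decorator(obj):
            assert name not in self, f"'{name}' is already registered"
            self[name] = obj
            return obj
        return decorator

LOSS = Registry()

@LOSS.register('example_loss')  # hypothetical name, for illustration only
class ExampleLoss:
    pass

assert LOSS['example_loss'] is ExampleLoss
```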


@ -0,0 +1,11 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .backbone import build_backbone
from .build import build_model
from .heads import build_head
from .registry import BACKBONES, HEADS


@ -0,0 +1 @@
from .build import build_backbone


@ -0,0 +1,520 @@
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from ret_benchmark.modeling import registry
@registry.BACKBONES.register('bninception')
class BNInception(nn.Module):
def __init__(self):
super(BNInception, self).__init__()
inplace = True
self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True)
self.conv1_relu_7x7 = nn.ReLU(inplace)
self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.conv2_relu_3x3_reduce = nn.ReLU(inplace)
self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.conv2_relu_3x3 = nn.ReLU(inplace)
self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_1x1 = nn.ReLU(inplace)
self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_3x3 = nn.ReLU(inplace)
self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True)
self.inception_3a_relu_pool_proj = nn.ReLU(inplace)
self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_1x1 = nn.ReLU(inplace)
self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_3x3 = nn.ReLU(inplace)
self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_pool_proj = nn.ReLU(inplace)
self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
self.inception_3c_relu_3x3 = nn.ReLU(inplace)
self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_4a_relu_1x1 = nn.ReLU(inplace)
self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4a_relu_3x3 = nn.ReLU(inplace)
self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_pool_proj = nn.ReLU(inplace)
self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4b_relu_1x1 = nn.ReLU(inplace)
self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_3x3 = nn.ReLU(inplace)
self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_pool_proj = nn.ReLU(inplace)
self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_1x1 = nn.ReLU(inplace)
self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_3x3 = nn.ReLU(inplace)
self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_pool_proj = nn.ReLU(inplace)
self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4d_relu_1x1 = nn.ReLU(inplace)
self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_3x3 = nn.ReLU(inplace)
self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4d_relu_pool_proj = nn.ReLU(inplace)
self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4e_relu_3x3 = nn.ReLU(inplace)
self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
self.inception_5a_relu_1x1 = nn.ReLU(inplace)
self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
self.inception_5a_relu_3x3 = nn.ReLU(inplace)
self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_5a_relu_pool_proj = nn.ReLU(inplace)
self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
self.inception_5b_relu_1x1 = nn.ReLU(inplace)
self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace)
self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
self.inception_5b_relu_3x3 = nn.ReLU(inplace)
self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace)
self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace)
self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace)
self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_5b_relu_pool_proj = nn.ReLU(inplace)
def features(self, input):
conv1_7x7_s2_out = self.conv1_7x7_s2(input)
conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(
inception_3a_double_3x3_reduce_out)
inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(
inception_3a_double_3x3_reduce_bn_out)
inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
inception_3a_output_out = torch.cat(
[inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out,
inception_3a_relu_pool_proj_out], 1)
inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(
inception_3b_double_3x3_reduce_out)
inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(
inception_3b_double_3x3_reduce_bn_out)
inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
inception_3b_output_out = torch.cat(
[inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out,
inception_3b_relu_pool_proj_out], 1)
inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(
inception_3c_double_3x3_reduce_out)
inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(
inception_3c_double_3x3_reduce_bn_out)
inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
inception_3c_output_out = torch.cat(
[inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1)
inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(
inception_4a_double_3x3_reduce_out)
inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(
inception_4a_double_3x3_reduce_bn_out)
inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
inception_4a_output_out = torch.cat(
[inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out,
inception_4a_relu_pool_proj_out], 1)
inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(
inception_4b_double_3x3_reduce_out)
inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(
inception_4b_double_3x3_reduce_bn_out)
inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
inception_4b_output_out = torch.cat(
[inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out,
inception_4b_relu_pool_proj_out], 1)
inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(
inception_4c_double_3x3_reduce_out)
inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(
inception_4c_double_3x3_reduce_bn_out)
inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
inception_4c_output_out = torch.cat(
[inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out,
inception_4c_relu_pool_proj_out], 1)
inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(
inception_4d_double_3x3_reduce_out)
inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(
inception_4d_double_3x3_reduce_bn_out)
inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
inception_4d_output_out = torch.cat(
[inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out,
inception_4d_relu_pool_proj_out], 1)
inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(
inception_4e_double_3x3_reduce_out)
inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(
inception_4e_double_3x3_reduce_bn_out)
inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
inception_4e_output_out = torch.cat(
[inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1)
inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(
inception_5a_double_3x3_reduce_out)
inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(
inception_5a_double_3x3_reduce_bn_out)
inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
inception_5a_output_out = torch.cat(
[inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out,
inception_5a_relu_pool_proj_out], 1)
inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(
inception_5b_double_3x3_reduce_out)
inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(
inception_5b_double_3x3_reduce_bn_out)
inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
inception_5b_output_out = torch.cat(
[inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out,
inception_5b_relu_pool_proj_out], 1)
return inception_5b_output_out
def logits(self, features):
x = F.adaptive_max_pool2d(features, output_size=1)
x = x.view(x.size(0), -1)
return x
def forward(self, input):
x = self.features(input)
x = self.logits(x)
return x
def load_param(self, model_path):
param_dict = torch.load(model_path)
for i in param_dict:
if 'last_linear' in i:
continue
self.state_dict()[i].copy_(param_dict[i])
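
# A minimal smoke-test sketch (CPU, random weights): the backbone maps a batch of
# 3x224x224 images to 1024-d pooled features (352 + 320 + 224 + 128 channels from the
# concatenated inception_5b branches above).
if __name__ == '__main__':
    net = BNInception()
    out = net(torch.randn(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 1024])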

View File

@ -0,0 +1,8 @@
from ret_benchmark.modeling.registry import BACKBONES
from .bninception import BNInception
def build_backbone(cfg):
    assert cfg.MODEL.BACKBONE.NAME in BACKBONES, f"backbone {cfg.MODEL.BACKBONE.NAME} is not defined"
return BACKBONES[cfg.MODEL.BACKBONE.NAME]()

View File

@ -0,0 +1,35 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import os
from collections import OrderedDict
import torch
from torch.nn.modules import Sequential
from .backbone import build_backbone
from .heads import build_head
def build_model(cfg):
backbone = build_backbone(cfg)
head = build_head(cfg)
model = Sequential(OrderedDict([
('backbone', backbone),
('head', head)
]))
if cfg.MODEL.PRETRAIN == 'imagenet':
        print('Loading ImageNet pretrained model ...')
pretrained_path = os.path.expanduser(cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME])
model.backbone.load_param(pretrained_path)
elif os.path.exists(cfg.MODEL.PRETRAIN):
ckp = torch.load(cfg.MODEL.PRETRAIN)
model.load_state_dict(ckp['model_state_dict'])
return model
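
# Typical use (a hedged sketch; the yaml path mirrors run_cub.sh and tools/main.py):
#
#   from ret_benchmark.config import cfg
#   cfg.merge_from_file('configs/example.yaml')
#   model = build_model(cfg)            # Sequential(backbone -> head)
#   embeddings = model(images)          # unit-norm vectors of size cfg.MODEL.HEAD.DIM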

View File

@ -0,0 +1,8 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from .build import build_head

View File

@ -0,0 +1,15 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from ret_benchmark.modeling.registry import HEADS
from .linear_norm import LinearNorm
def build_head(cfg):
assert cfg.MODEL.HEAD.NAME in HEADS, f"head {cfg.MODEL.HEAD.NAME} is not defined"
return HEADS[cfg.MODEL.HEAD.NAME](cfg, in_channels=1024)

View File

@ -0,0 +1,25 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from ret_benchmark.modeling.registry import HEADS
from ret_benchmark.utils.init_methods import weights_init_kaiming
@HEADS.register('linear_norm')
class LinearNorm(nn.Module):
def __init__(self, cfg, in_channels):
super(LinearNorm, self).__init__()
self.fc = nn.Linear(in_channels, cfg.MODEL.HEAD.DIM)
self.fc.apply(weights_init_kaiming)
def forward(self, x):
x = self.fc(x)
x = nn.functional.normalize(x, p=2, dim=1)
return x
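
# A quick sanity sketch: the head projects pooled backbone features to unit-norm
# embeddings. A stub config is used purely for illustration (512 is an assumed value
# for MODEL.HEAD.DIM, not a repo default).
if __name__ == '__main__':
    from types import SimpleNamespace
    stub_cfg = SimpleNamespace(MODEL=SimpleNamespace(HEAD=SimpleNamespace(DIM=512)))
    head = LinearNorm(stub_cfg, in_channels=1024)
    emb = head(torch.rand(8, 1024))
    print(emb.shape, emb.norm(p=2, dim=1))  # (8, 512), every row norm ~1.0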

View File

@ -0,0 +1,12 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
from ret_benchmark.utils.registry import Registry
BACKBONES = Registry()
HEADS = Registry()

View File

@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from .build import build_optimizer
from .build import build_lr_scheduler
from .lr_scheduler import WarmupMultiStepLR

View File

@ -0,0 +1,29 @@
import torch
from .lr_scheduler import WarmupMultiStepLR
def build_optimizer(cfg, model):
params = []
for key, value in model.named_parameters():
if not value.requires_grad:
continue
lr_mul = 1.0
if "backbone" in key:
lr_mul = 0.1
params += [{"params": [value], "lr_mul": lr_mul}]
optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params,
lr=cfg.SOLVER.BASE_LR,
weight_decay=cfg.SOLVER.WEIGHT_DECAY)
return optimizer
def build_lr_scheduler(cfg, optimizer):
return WarmupMultiStepLR(
optimizer,
cfg.SOLVER.STEPS,
cfg.SOLVER.GAMMA,
warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
warmup_iters=cfg.SOLVER.WARMUP_ITERS,
warmup_method=cfg.SOLVER.WARMUP_METHOD,
)
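
# An illustrative sketch with a stub config (all values are assumptions, not the repo
# defaults). Parameters whose names contain "backbone" get lr_mul=0.1 attached to their
# param group; the plain head below keeps lr_mul=1.0.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch.nn as nn
    solver = SimpleNamespace(OPTIMIZER_NAME='Adam', BASE_LR=1e-4, WEIGHT_DECAY=5e-4,
                             STEPS=[2000, 4000], GAMMA=0.1, WARMUP_FACTOR=0.01,
                             WARMUP_ITERS=500, WARMUP_METHOD='linear')
    net = nn.Linear(8, 4)
    optimizer = build_optimizer(SimpleNamespace(SOLVER=solver), net)
    scheduler = build_lr_scheduler(SimpleNamespace(SOLVER=solver), optimizer)
    print(optimizer.param_groups[0]['lr'])  # BASE_LR scaled by the initial warmup factor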

View File

@ -0,0 +1,49 @@
from bisect import bisect_right
import torch
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer,
milestones,
gamma=0.1,
warmup_factor=1.0 / 3,
warmup_iters=500,
warmup_method="linear",
last_epoch=-1,
):
if not list(milestones) == sorted(milestones):
raise ValueError(
"Milestones should be a list of" " increasing integers. Got {}",
milestones,
)
if warmup_method not in ("constant", "linear"):
raise ValueError(
"Only 'constant' or 'linear' warmup_method accepted"
"got {}".format(warmup_method)
)
self.milestones = milestones
self.gamma = gamma
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
warmup_factor = 1
if self.last_epoch < self.warmup_iters:
if self.warmup_method == "constant":
warmup_factor = self.warmup_factor
elif self.warmup_method == "linear":
alpha = float(self.last_epoch) / self.warmup_iters
warmup_factor = self.warmup_factor * (1 - alpha) + alpha
return [
base_lr * warmup_factor * self.gamma ** bisect_right(
self.milestones,
self.last_epoch
)
for base_lr in self.base_lrs
]
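
# A small sketch of the resulting schedule (assumed toy values): the LR ramps linearly
# from warmup_factor * base_lr up to base_lr over warmup_iters steps, then is multiplied
# by gamma at each milestone.
if __name__ == '__main__':
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmupMultiStepLR(optimizer, milestones=[10, 20], gamma=0.1,
                                  warmup_factor=1.0 / 3, warmup_iters=5)
    lrs = []
    for _ in range(25):
        lrs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()
    print(lrs[0], lrs[5], lrs[10], lrs[20])  # ~0.0333, 0.1, 0.01, 0.001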

View File

@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import torch
from ret_benchmark.utils.model_serialization import load_state_dict
class Checkpointer(object):
def __init__(
self,
model,
optimizer=None,
scheduler=None,
save_dir="",
save_to_disk=None,
logger=None,
):
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.save_dir = save_dir
self.save_to_disk = save_to_disk
if logger is None:
logger = logging.getLogger(__name__)
self.logger = logger
def save(self, name):
if not self.save_dir:
return
data = {}
data["model"] = self.model.state_dict()
if self.optimizer is not None:
data["optimizer"] = self.optimizer.state_dict()
if self.scheduler is not None:
data["scheduler"] = self.scheduler.state_dict()
save_file = os.path.join(self.save_dir, "{}.pth".format(name))
self.logger.info("Saving checkpoint to {}".format(save_file))
torch.save(data, save_file)
def load(self, f=None):
if self.has_checkpoint():
# override argument with existing checkpoint
f = self.get_checkpoint_file()
if not f:
# no checkpoint could be found
self.logger.info("No checkpoint found. Initializing model from scratch")
return {}
self.logger.info("Loading checkpoint from {}".format(f))
checkpoint = self._load_file(f)
self._load_model(checkpoint)
if "optimizer" in checkpoint and self.optimizer:
self.logger.info("Loading optimizer from {}".format(f))
self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
if "scheduler" in checkpoint and self.scheduler:
self.logger.info("Loading scheduler from {}".format(f))
self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
save_file = os.path.join(self.save_dir, "last_checkpoint")
return os.path.exists(save_file)
def get_checkpoint_file(self):
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with open(save_file, "r") as f:
last_saved = f.read()
last_saved = last_saved.strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
last_saved = ""
return last_saved
def tag_last_checkpoint(self, last_filename):
save_file = os.path.join(self.save_dir, "last_checkpoint")
with open(save_file, "w") as f:
f.write(last_filename)
def _load_file(self, f):
return torch.load(f, map_location=torch.device("cpu"))
def _load_model(self, checkpoint):
load_state_dict(self.model, checkpoint.pop("model"))
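
# A minimal save/load round trip (a hedged sketch; 'tmp_ckpt' is an assumed scratch dir):
if __name__ == '__main__':
    import torch.nn as nn
    net = nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    os.makedirs('tmp_ckpt', exist_ok=True)
    saver = Checkpointer(net, optimizer=opt, save_dir='tmp_ckpt')
    saver.save('iter_0')                              # writes tmp_ckpt/iter_0.pth
    saver.tag_last_checkpoint('tmp_ckpt/iter_0.pth')
    extra = Checkpointer(nn.Linear(4, 2), save_dir='tmp_ckpt').load()
    print(list(extra.keys()))                         # any checkpoint data besides the model weights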

View File

@ -0,0 +1,29 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import copy
import os
from ret_benchmark.config import cfg as g_cfg
def get_config_root_path():
''' Path to configs for unit tests '''
# cur_file_dir is root/tests/env_tests
cur_file_dir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
ret = os.path.dirname(os.path.dirname(cur_file_dir))
ret = os.path.join(ret, "configs")
return ret
def load_config(rel_path):
''' Load config from file path specified as path relative to config_root '''
cfg_path = os.path.join(get_config_root_path(), rel_path)
return load_config_from_file(cfg_path)
def load_config_from_file(file_path):
''' Load config from file path specified as absolute path '''
ret = copy.deepcopy(g_cfg)
ret.merge_from_file(file_path)
return ret

View File

@ -0,0 +1,27 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
import numpy as np
def feat_extractor(model, data_loader, logger=None):
model.eval()
feats = list()
for i, batch in enumerate(data_loader):
imgs = batch[0].cuda()
with torch.no_grad():
out = model(imgs).data.cpu().numpy()
feats.append(out)
if logger is not None and (i + 1) % 100 == 0:
logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]')
del out
feats = np.vstack(feats)
return feats
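
# A hedged usage sketch (assumes a CUDA device and a DataLoader whose batches are
# (images, labels, ...) tuples, as produced by the data pipeline in this repo):
#
#   model = build_model(cfg).cuda()
#   val_loader = build_data(cfg, is_train=False)
#   feats = feat_extractor(model, val_loader, logger=logger)  # ndarray of shape [N, emb_dim]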

View File

@ -0,0 +1,14 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
# Batch Norm Freezer
# Note: freezing BN gives an additional ~2% improvement on CUB (it has no effect on the other benchmarks)
def set_bn_eval(m):
classname = m.__class__.__name__
if classname.find('BatchNorm') != -1:
m.eval()
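
# Typical use: call model.apply(set_bn_eval) after model.train() so BatchNorm keeps its
# (pretrained) running statistics while the rest of the network still trains.
if __name__ == '__main__':
    import torch.nn as nn
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    net.train()
    net.apply(set_bn_eval)
    print(net[0].training, net[1].training)  # True False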

View File

@ -0,0 +1,21 @@
import os.path as osp
from PIL import Image
def read_image(img_path, mode='RGB'):
"""Keep reading image until succeed.
This can avoid IOError incurred by heavy IO process."""
got_img = False
if not osp.exists(img_path):
raise IOError(f"{img_path} does not exist")
while not got_img:
try:
img = Image.open(img_path).convert("RGB")
if mode == "BGR":
r, g, b = img.split()
img = Image.merge("RGB", (b, g, r))
got_img = True
except IOError:
print(f"IOError incurred when reading '{img_path}'. Will redo.")
pass
return img
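
# Example (the path is hypothetical): mode='BGR' swaps the channel order, which some
# Caffe-converted pretrained backbones expect.
#
#   img = read_image('/data/CUB_200_2011/images/some_bird.jpg', mode='BGR')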

View File

@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
from torch import nn
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
nn.init.constant_(m.bias, 0.0)
elif classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif classname.find('BatchNorm') != -1:
if m.affine:
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.normal_(m.weight, std=0.001)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)

View File

@ -0,0 +1,32 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import os
import sys
import logging
_streams = {
"stdout": sys.stdout
}
def setup_logger(name: str, level: int, stream: str = "stdout") -> logging.Logger:
global _streams
if stream not in _streams:
log_folder = os.path.dirname(stream)
os.makedirs(log_folder, exist_ok=True)
_streams[stream] = open(stream, 'w')
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(level)
sh = logging.StreamHandler(stream=_streams[stream])
sh.setLevel(level)
formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
sh.setFormatter(formatter)
logger.addHandler(sh)
return logger
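
# Usage sketch: log to stdout by default, or pass a file path as `stream` to create the
# parent folder and log there instead (the path below is illustrative).
if __name__ == '__main__':
    log = setup_logger('Train', logging.INFO)
    log.info('logging to stdout')
    file_log = setup_logger('TrainFile', logging.DEBUG, stream='output/train.log')
    file_log.debug('logging to output/train.log')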

View File

@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import defaultdict
from collections import deque
import torch
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20):
self.deque = deque(maxlen=window_size)
self.series = []
self.total = 0.0
self.count = 0
def update(self, value):
self.deque.append(value)
self.series.append(value)
self.count += 1
self.total += value
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque))
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
)
return self.delimiter.join(loss_str)
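
# A short sketch of how the meters are updated and read back during training:
if __name__ == '__main__':
    meters = MetricLogger(delimiter='  ')
    for step in range(30):
        meters.update(loss=1.0 / (step + 1), acc=step / 30)
    print(str(meters))             # "loss: <window median> (<global avg>)  acc: ..."
    print(meters.loss.global_avg)  # average over the whole series, not just the window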

View File

@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import OrderedDict
import logging
import torch
def align_and_update_state_dicts(model_state_dict, loaded_state_dict):
"""
    Strategy: suppose that the models we create have prefixes appended to each of their
    keys, for example due to an extra level of nesting that the original pre-trained
    weights from ImageNet won't contain. For example, model.state_dict()
might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
res2.conv1.weight. We thus want to match both parameters together.
For that, we look for each model weight, look among all loaded keys if there is one
that is a suffix of the current weight name, and use it if that's the case.
    If multiple matches exist, take the one with the longest matching name.
    For example, for the same model as before, the pretrained
weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
we want to match backbone[0].body.conv1.weight to conv1.weight, and
backbone[0].body.res2.conv1.weight to res2.conv1.weight.
"""
current_keys = sorted(list(model_state_dict.keys()))
loaded_keys = sorted(list(loaded_state_dict.keys()))
    # get a matrix of string matches, where each (i, j) entry corresponds to the length of the
    # loaded_key string, if it matches
match_matrix = [
len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
]
match_matrix = torch.as_tensor(match_matrix).view(
len(current_keys), len(loaded_keys)
)
max_match_size, idxs = match_matrix.max(1)
# remove indices that correspond to no-match
idxs[max_match_size == 0] = -1
# used for logging
max_size = max([len(key) for key in current_keys]) if current_keys else 1
max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
logger = logging.getLogger(__name__)
for idx_new, idx_old in enumerate(idxs.tolist()):
if idx_old == -1:
continue
key = current_keys[idx_new]
key_old = loaded_keys[idx_old]
model_state_dict[key] = loaded_state_dict[key_old]
logger.info(
log_str_template.format(
key,
max_size,
key_old,
max_size_loaded,
tuple(loaded_state_dict[key_old].shape),
)
)
def strip_prefix_if_present(state_dict, prefix):
keys = sorted(state_dict.keys())
if not all(key.startswith(prefix) for key in keys):
return state_dict
stripped_state_dict = OrderedDict()
for key, value in state_dict.items():
stripped_state_dict[key.replace(prefix, "")] = value
return stripped_state_dict
def load_state_dict(model, loaded_state_dict):
model_state_dict = model.state_dict()
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching
loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
align_and_update_state_dicts(model_state_dict, loaded_state_dict)
# use strict loading
model.load_state_dict(model_state_dict)
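
# A tiny sketch of the suffix matching (names here are illustrative): weights saved from
# a bare module still load into a model that nests it under an extra prefix.
if __name__ == '__main__':
    import torch.nn as nn
    wrapped = nn.Sequential(OrderedDict([('head', nn.Linear(4, 2))]))  # keys: head.weight, head.bias
    plain = nn.Linear(4, 2)                                            # keys: weight, bias
    load_state_dict(wrapped, plain.state_dict())
    print(torch.equal(wrapped.head.weight, plain.weight))              # True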

View File

@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
class Registry(dict):
    '''
    A helper class for registering and managing modules. It extends a dictionary
    and provides register functions.
    E.g. creating a registry:
        some_registry = Registry({"default": default_module})
    There are two ways of registering new modules:
    1) the normal way is just calling the register function:
        def foo():
            ...
        some_registry.register("foo_module", foo)
    2) used as a decorator when declaring the module:
        @some_registry.register("foo_module")
        @some_registry.register("foo_module_nickname")
        def foo():
            ...
    Accessing a module is just like using a dictionary, e.g.:
        f = some_registry["foo_module"]
    '''
def __init__(self, *args, **kwargs):
super(Registry, self).__init__(*args, **kwargs)
def register(self, module_name, module=None):
# used as function call
if module is not None:
_register_generic(self, module_name, module)
return
# used as decorator
def register_fn(fn):
_register_generic(self, module_name, fn)
return fn
return register_fn

1
run_cub.sh Normal file
View File

@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml

25
setup.py Normal file
View File

@ -0,0 +1,25 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CppExtension
requirements = ["torch", "torchvision"]
setup(
name="ret_benchmark",
version="0.1",
author="Malong Technologies",
url="https://github.com/MalongTech/research-ms-loss",
description="ms-loss",
packages=find_packages(exclude=("configs", "tests")),
install_requires=requirements,
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)

78
tools/main.py Normal file
View File

@ -0,0 +1,78 @@
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.
import argparse
import torch
from ret_benchmark.config import cfg
from ret_benchmark.data import build_data
from ret_benchmark.engine.trainer import do_train
from ret_benchmark.losses import build_loss
from ret_benchmark.modeling import build_model
from ret_benchmark.solver import build_lr_scheduler, build_optimizer
from ret_benchmark.utils.logger import setup_logger
from ret_benchmark.utils.checkpoint import Checkpointer
def train(cfg):
logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
logger.info(cfg)
model = build_model(cfg)
device = torch.device(cfg.MODEL.DEVICE)
model.to(device)
criterion = build_loss(cfg)
optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)
train_loader = build_data(cfg, is_train=True)
val_loader = build_data(cfg, is_train=False)
logger.info(train_loader.dataset)
logger.info(val_loader.dataset)
arguments = dict()
arguments["iteration"] = 0
checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)
do_train(
cfg,
model,
train_loader,
val_loader,
optimizer,
scheduler,
criterion,
checkpointer,
device,
checkpoint_period,
arguments,
logger
)
def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a retrieval network')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='config file',
default=None,
type=str)
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
cfg.merge_from_file(args.cfg_file)
train(cfg)