Initial commit
commit
38aae447f5
|
@ -0,0 +1,80 @@
|
|||
# Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all project spaces, and it also applies when
|
||||
an individual is representing the project or its community in public spaces.
|
||||
Examples of representing a project or community include using an official
|
||||
project e-mail address, posting via an official social media account, or acting
|
||||
as an appointed representative at an online or offline event. Representation of
|
||||
a project may be further defined and clarified by project maintainers.
|
||||
|
||||
This Code of Conduct also applies outside the project spaces when there is a
|
||||
reasonable belief that an individual's behavior may have a negative impact on
|
||||
the project or its community.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
|
@ -0,0 +1,210 @@
|
|||
## MoCo v3 Reference Setups and Models
|
||||
|
||||
Here we document the reference commands for pre-training and evaluating various MoCo v3 models.
|
||||
|
||||
### ResNet-50 models
|
||||
|
||||
With batch 4096, the training of all ResNet-50 models can fit into 2 nodes with a total of 16 Volta 32G GPUs.
|
||||
|
||||
<details>
|
||||
<summary>ResNet-50, 100-epoch pre-training.</summary>
|
||||
|
||||
On the first node, run:
|
||||
```
|
||||
python main_moco.py \
|
||||
--moco-m-cos --crop-min=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 2 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On the second node, run the same command with `--rank 1`.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ResNet-50, 300-epoch pre-training.</summary>
|
||||
|
||||
On the first node, run:
|
||||
```
|
||||
python main_moco.py \
|
||||
--lr=.3 --epochs=300 \
|
||||
--moco-m-cos --crop-min=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 2 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On the second node, run the same command with `--rank 1`.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ResNet-50, 1000-epoch pre-training.</summary>
|
||||
|
||||
On the first node, run:
|
||||
```
|
||||
python main_moco.py \
|
||||
--lr=.3 --wd=1.5e-6 --epochs=1000 \
|
||||
--moco-m=0.996 --moco-m-cos --crop-min=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 2 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On the second node, run the same command with `--rank 1`.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ResNet-50, linear classification.</summary>
|
||||
|
||||
Run on single node:
|
||||
```
|
||||
python main_lincls.py \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
--pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
</details>
|
||||
|
||||
Below are our pre-trained ResNet-50 models and logs.
|
||||
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">linear<br/>acc</th>
|
||||
<th valign="center">pretrain<br/>files</th>
|
||||
<th valign="center">linear<br/>files</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="right">100</td>
|
||||
<td align="center">68.9</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-100ep/r-50-100ep.pth.tar">chpt</a></td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-100ep/linear-100ep.pth.tar">chpt</a> /
|
||||
<a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-100ep/linear-100ep.std">log</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="right">300</td>
|
||||
<td align="center">72.8</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-300ep/r-50-300ep.pth.tar">chpt</a></td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-300ep/linear-300ep.pth.tar">chpt</a> /
|
||||
<a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-300ep/linear-300ep.std">log</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="right">1000</td>
|
||||
<td align="center">74.6</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/r-50-1000ep.pth.tar">chpt</a></td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/linear-1000ep.pth.tar">chpt</a> /
|
||||
<a href="https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/linear-1000ep.std">log</a></td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
|
||||
### ViT Models
|
||||
|
||||
All ViT models are pre-trained for 300 epochs with AdamW.
|
||||
|
||||
<details>
|
||||
<summary>ViT-Small, 1-node (8-GPU), 1024-batch pre-training.</summary>
|
||||
|
||||
This setup fits into a single node of 8 Volta 32G GPUs, for ease of debugging.
|
||||
```
|
||||
python main_moco.py \
|
||||
-a vit_small -b 1024 \
|
||||
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
|
||||
--epochs=300 --warmup-epochs=40 \
|
||||
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ViT-Small, 4-node (32-GPU) pre-training.</summary>
|
||||
|
||||
On the first node, run:
|
||||
```
|
||||
python main_moco.py \
|
||||
-a vit_small \
|
||||
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
|
||||
--epochs=300 --warmup-epochs=40 \
|
||||
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 8 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On other nodes, run the same command with `--rank 1`, ..., `--rank 3` respectively.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ViT-Small, linear classification.</summary>
|
||||
|
||||
Run on single node:
|
||||
```
|
||||
python main_lincls.py \
|
||||
-a vit_small --lr=3 \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
--pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ViT-Base, 8-node (64-GPU) pre-training.</summary>
|
||||
|
||||
```
|
||||
python main_moco.py \
|
||||
-a vit_base \
|
||||
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
|
||||
--epochs=300 --warmup-epochs=40 \
|
||||
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 8 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On other nodes, run the same command with `--rank 1`, ..., `--rank 7` respectively.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ViT-Base, linear classification.</summary>
|
||||
|
||||
Run on single node:
|
||||
```
|
||||
python main_lincls.py \
|
||||
-a vit_base --lr=3 \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
--pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
Below are our pre-trained ViT models and logs (batch 4096).
|
||||
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">model</th>
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">linear<br/>acc</th>
|
||||
<th valign="center">pretrain<br/>files</th>
|
||||
<th valign="center">linear<br/>files</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="left">ViT-Small</td>
|
||||
<td align="center">300</td>
|
||||
<td align="center">73.2</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/vit-s-300ep/vit-s-300ep.pth.tar">chpt</a></td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/vit-s-300ep/linear-vit-s-300ep.pth.tar">chpt</a> /
|
||||
<a href="https://dl.fbaipublicfiles.com/moco-v3/vit-s-300ep/linear-vit-s-300ep.std">log</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">ViT-Base</td>
|
||||
<td align="center">300</td>
|
||||
<td align="center">76.7</td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar">chpt</a></td>
|
||||
<td align="center"><a href="https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/linear-vit-b-300ep.pth.tar">chpt</a> /
|
||||
<a href="https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/linear-vit-b-300ep.std">log</a></td>
|
||||
</tr>
|
||||
</tbody></table>
|
|
@ -0,0 +1,31 @@
|
|||
# Contributing to moco-v3
|
||||
We want to make contributing to this project as easy and transparent as
|
||||
possible.
|
||||
|
||||
## Pull Requests
|
||||
We actively welcome your pull requests.
|
||||
|
||||
1. Fork the repo and create your branch from `master`.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If you've changed APIs, update the documentation.
|
||||
4. Ensure the test suite passes.
|
||||
5. Make sure your code lints.
|
||||
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||
|
||||
## Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Facebook's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
## Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
## License
|
||||
By contributing to moco-v3, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
|
@ -0,0 +1,399 @@
|
|||
Attribution-NonCommercial 4.0 International
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial 4.0 International Public
|
||||
License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial 4.0 International Public License ("Public
|
||||
License"). To the extent this Public License may be interpreted as a
|
||||
contract, You are granted the Licensed Rights in consideration of Your
|
||||
acceptance of these terms and conditions, and the Licensor grants You
|
||||
such rights in consideration of benefits the Licensor receives from
|
||||
making the Licensed Material available under these terms and
|
||||
conditions.
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
d. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
f. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
g. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
h. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
i. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
j. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
k. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
l. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
4. If You Share Adapted Material You produce, the Adapter's
|
||||
License You apply must not prevent recipients of the Adapted
|
||||
Material from complying with this Public License.
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material; and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
|
@ -0,0 +1,197 @@
|
|||
## MoCo v3 for Self-supervised ResNet and ViT
|
||||
|
||||
### Introduction
|
||||
This is a PyTorch implementation of [MoCo v3](https://arxiv.org/abs/2104.02057) for self-supervised ResNet and ViT.
|
||||
|
||||
The original MoCo v3 was implemented in Tensorflow and run in TPUs. This repo re-implements in PyTorch and GPUs. Despite the library and numerical differences, this repo reproduces the results and observations in the paper.
|
||||
|
||||
### Main Results
|
||||
|
||||
The following results are based on ImageNet-1k self-supervised pre-training, followed by ImageNet-1k supervised training for linear evaluation or end-to-end fine-tuning. All results in these tables are based on a batch size of 4096.
|
||||
|
||||
#### ResNet-50, linear classification
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">pretrain<br/>crops</th>
|
||||
<th valign="center">linear<br/>acc</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="right">100</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">68.9</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">72.8</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="right">1000</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">74.6</td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
#### ViT, linear classification
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">model</th>
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">pretrain<br/>crops</th>
|
||||
<th valign="center">linear<br/>acc</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="left">ViT-Small</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">73.2</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">ViT-Base</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">76.7</td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
#### ViT, end-to-end fine-tuning
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">model</th>
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">pretrain<br/>crops</th>
|
||||
<th valign="center">e2e<br/>acc</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="left">ViT-Small</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">81.4</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">ViT-Base</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="center">83.2</td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
The end-to-end fine-tuning results are obtained using the [DeiT](https://github.com/facebookresearch/deit) repo, using all the default DeiT configs. ViT-B is fine-tuned for 150 epochs (vs DeiT-B's 300ep, which has 81.8% accuracy).
|
||||
|
||||
### Usage: Preparation
|
||||
|
||||
Install PyTorch and download the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). Similar to [MoCo v1/2](https://github.com/facebookresearch/moco), this repo contains minimal modifications on the official PyTorch ImageNet code. We assume the user can successfully run the official PyTorch ImageNet code.
|
||||
For ViT models, install [timm](https://github.com/rwightman/pytorch-image-models) (`timm==0.4.9`).
|
||||
|
||||
The code has been tested with CUDA 10.2/CuDNN 7.6.5, PyTorch 1.9.0 and timm 0.4.9.
|
||||
|
||||
### Usage: Self-supervised Pre-Training
|
||||
|
||||
Below are three examples for MoCo v3 pre-training.
|
||||
|
||||
#### ResNet-50 with 2-node (16-GPU) training, batch 4096
|
||||
|
||||
On the first node, run:
|
||||
```
|
||||
python main_moco.py \
|
||||
--moco-m-cos --crop-min=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 2 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On the second node, run the same command with `--rank 1`.
|
||||
With a batch size of 4096, the training can fit into 2 nodes with a total of 16 Volta 32G GPUs.
|
||||
|
||||
|
||||
#### ViT-Small with 1-node (8-GPU) training, batch 1024
|
||||
|
||||
```
|
||||
python main_moco.py \
|
||||
-a vit_small -b 1024 \
|
||||
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
|
||||
--epochs=300 --warmup-epochs=40 \
|
||||
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
|
||||
#### ViT-Base with 8-node training, batch 4096
|
||||
|
||||
With a batch size of 4096, ViT-Base is trained with 8 nodes:
|
||||
```
|
||||
python main_moco.py \
|
||||
-a vit_base \
|
||||
--optimizer=adamw --lr=1.5e-4 --weight-decay=.1 \
|
||||
--epochs=300 --warmup-epochs=40 \
|
||||
--stop-grad-conv1 --moco-m-cos --moco-t=.2 \
|
||||
--dist-url 'tcp://[your first node address]:[specified port]' \
|
||||
--multiprocessing-distributed --world-size 8 --rank 0 \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
On other nodes, run the same command with `--rank 1`, ..., `--rank 7` respectively.
|
||||
|
||||
#### Notes:
|
||||
1. The batch size specified by `-b` is the total batch size across all GPUs.
|
||||
1. The learning rate specified by `--lr` is the *base* lr, and is adjusted by the [linear lr scaling rule](https://arxiv.org/abs/1706.02677) in [this line](https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L213).
|
||||
1. Using a smaller batch size has a more stable result (see paper), but has lower speed. Using a large batch size is critical for good speed in TPUs (as we did in the paper).
|
||||
1. In this repo, only *multi-gpu*, *DistributedDataParallel* training is supported; single-gpu or DataParallel training is not supported. This code is improved to better suit the *multi-node* setting, and by default uses automatic *mixed-precision* for pre-training.
|
||||
|
||||
### Usage: Linear Classification
|
||||
|
||||
By default, we use momentum-SGD and a batch size of 1024 for linear classification on frozen features/weights. This can be done with a single 8-GPU node.
|
||||
|
||||
```
|
||||
python main_lincls.py \
|
||||
-a [architecture] --lr [learning rate] \
|
||||
--dist-url 'tcp://localhost:10001' \
|
||||
--multiprocessing-distributed --world-size 1 --rank 0 \
|
||||
--pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
[your imagenet-folder with train and val folders]
|
||||
```
|
||||
|
||||
### Usage: End-to-End Fine-tuning ViT
|
||||
|
||||
To perform end-to-end fine-tuning for ViT, use our script to convert the pre-trained ViT checkpoint to [DEiT](https://github.com/facebookresearch/deit) format:
|
||||
```
|
||||
python convert_to_deit.py \
|
||||
--input [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
--output [target checkpoint file].pth
|
||||
```
|
||||
Then run the training (in the DeiT repo) with the converted checkpoint:
|
||||
```
|
||||
python $DEIT_DIR/main.py \
|
||||
--resume [target checkpoint file].pth \
|
||||
--epochs 150
|
||||
```
|
||||
This gives us 83.2% accuracy for ViT-Base with 150-epoch fine-tuning.
|
||||
|
||||
**Note**:
|
||||
1. We use `--resume` rather than `--finetune` in the DeiT repo, as its `--finetune` option trains under eval mode. When loading the pre-trained model, revise `model_without_ddp.load_state_dict(checkpoint['model'])` with `strict=False`.
|
||||
1. Our ViT-Small is with `heads=12` in the Transformer block, while by default in DeiT it is `heads=6`. Please modify the DeiT code accordingly when fine-tuning our ViT-Small model.
|
||||
|
||||
### Model Configs
|
||||
|
||||
See the commands listed in [CONFIG.md](https://github.com/facebookresearch/moco-v3/blob/main/CONFIG.md).
|
||||
|
||||
### Transfer Learning
|
||||
|
||||
See the instruction in the [transfer](https://github.com/facebookresearch/moco-v3/tree/main/transfer) dir.
|
||||
|
||||
### License
|
||||
|
||||
This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
|
||||
|
||||
### Citation
|
||||
```
|
||||
@Article{chen2021mocov3,
|
||||
author = {Xinlei Chen* and Saining Xie* and Kaiming He},
|
||||
title = {An Empirical Study of Training Self-Supervised Vision Transformers},
|
||||
journal = {arXiv preprint arXiv:2104.02057},
|
||||
year = {2021},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Convert MoCo Pre-Traind Model to DEiT')
|
||||
parser.add_argument('--input', default='', type=str, metavar='PATH', required=True,
|
||||
help='path to moco pre-trained checkpoint')
|
||||
parser.add_argument('--output', default='', type=str, metavar='PATH', required=True,
|
||||
help='path to output checkpoint in DEiT format')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
# load input
|
||||
checkpoint = torch.load(args.input, map_location="cpu")
|
||||
state_dict = checkpoint['state_dict']
|
||||
for k in list(state_dict.keys()):
|
||||
# retain only base_encoder up to before the embedding layer
|
||||
if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.head'):
|
||||
# remove prefix
|
||||
state_dict[k[len("module.base_encoder."):]] = state_dict[k]
|
||||
# delete renamed or unused k
|
||||
del state_dict[k]
|
||||
|
||||
# make output directory if necessary
|
||||
output_dir = os.path.dirname(args.output)
|
||||
if not os.path.isdir(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
# save to output
|
||||
torch.save({'model': state_dict}, args.output)
|
|
@ -0,0 +1,524 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import builtins
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.models as torchvision_models
|
||||
|
||||
import vits
|
||||
|
||||
torchvision_model_names = sorted(name for name in torchvision_models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
and callable(torchvision_models.__dict__[name]))
|
||||
|
||||
model_names = ['vit_small', 'vit_base', 'vit_conv_small', 'vit_conv_base'] + torchvision_model_names
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
|
||||
choices=model_names,
|
||||
help='model architecture: ' +
|
||||
' | '.join(model_names) +
|
||||
' (default: resnet50)')
|
||||
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 32)')
|
||||
parser.add_argument('--epochs', default=90, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('-b', '--batch-size', default=1024, type=int,
|
||||
metavar='N',
|
||||
help='mini-batch size (default: 1024), this is the total '
|
||||
'batch size of all GPUs on the current node when '
|
||||
'using Data Parallel or Distributed Data Parallel')
|
||||
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
|
||||
metavar='LR', help='initial (base) learning rate', dest='lr')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
|
||||
metavar='W', help='weight decay (default: 0.)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
||||
help='evaluate model on validation set')
|
||||
parser.add_argument('--world-size', default=-1, type=int,
|
||||
help='number of nodes for distributed training')
|
||||
parser.add_argument('--rank', default=-1, type=int,
|
||||
help='node rank for distributed training')
|
||||
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
||||
help='url used to set up distributed training')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
help='distributed backend')
|
||||
parser.add_argument('--seed', default=None, type=int,
|
||||
help='seed for initializing training. ')
|
||||
parser.add_argument('--gpu', default=None, type=int,
|
||||
help='GPU id to use.')
|
||||
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
||||
help='Use multi-processing distributed training to launch '
|
||||
'N processes per node, which has N GPUs. This is the '
|
||||
'fastest way to use PyTorch for either single node or '
|
||||
'multi node data parallel training')
|
||||
|
||||
# additional configs:
|
||||
parser.add_argument('--pretrained', default='', type=str,
|
||||
help='path to moco pretrained checkpoint')
|
||||
|
||||
best_acc1 = 0
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.deterministic = True
|
||||
warnings.warn('You have chosen to seed training. '
|
||||
'This will turn on the CUDNN deterministic setting, '
|
||||
'which can slow down your training considerably! '
|
||||
'You may see unexpected behavior when restarting '
|
||||
'from checkpoints.')
|
||||
|
||||
if args.gpu is not None:
|
||||
warnings.warn('You have chosen a specific GPU. This will completely '
|
||||
'disable data parallelism.')
|
||||
|
||||
if args.dist_url == "env://" and args.world_size == -1:
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
if args.multiprocessing_distributed:
|
||||
# Since we have ngpus_per_node processes per node, the total world_size
|
||||
# needs to be adjusted accordingly
|
||||
args.world_size = ngpus_per_node * args.world_size
|
||||
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
||||
# main_worker process function
|
||||
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
||||
else:
|
||||
# Simply call main_worker function
|
||||
main_worker(args.gpu, ngpus_per_node, args)
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, args):
|
||||
global best_acc1
|
||||
args.gpu = gpu
|
||||
|
||||
# suppress printing if not master
|
||||
if args.multiprocessing_distributed and args.gpu != 0:
|
||||
def print_pass(*args):
|
||||
pass
|
||||
builtins.print = print_pass
|
||||
|
||||
if args.gpu is not None:
|
||||
print("Use GPU: {} for training".format(args.gpu))
|
||||
|
||||
if args.distributed:
|
||||
if args.dist_url == "env://" and args.rank == -1:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
if args.multiprocessing_distributed:
|
||||
# For multiprocessing distributed training, rank needs to be the
|
||||
# global rank among all the processes
|
||||
args.rank = args.rank * ngpus_per_node + gpu
|
||||
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
# create model
|
||||
print("=> creating model '{}'".format(args.arch))
|
||||
if args.arch.startswith('vit'):
|
||||
model = vits.__dict__[args.arch]()
|
||||
linear_keyword = 'head'
|
||||
else:
|
||||
model = torchvision_models.__dict__[args.arch]()
|
||||
linear_keyword = 'fc'
|
||||
|
||||
# freeze all layers but the last fc
|
||||
for name, param in model.named_parameters():
|
||||
if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]:
|
||||
param.requires_grad = False
|
||||
# init the fc layer
|
||||
getattr(model, linear_keyword).weight.data.normal_(mean=0.0, std=0.01)
|
||||
getattr(model, linear_keyword).bias.data.zero_()
|
||||
|
||||
# load from pre-trained, before DistributedDataParallel constructor
|
||||
if args.pretrained:
|
||||
if os.path.isfile(args.pretrained):
|
||||
print("=> loading checkpoint '{}'".format(args.pretrained))
|
||||
checkpoint = torch.load(args.pretrained, map_location="cpu")
|
||||
|
||||
# rename moco pre-trained keys
|
||||
state_dict = checkpoint['state_dict']
|
||||
for k in list(state_dict.keys()):
|
||||
# retain only base_encoder up to before the embedding layer
|
||||
if k.startswith('module.base_encoder') and not k.startswith('module.base_encoder.%s' % linear_keyword):
|
||||
# remove prefix
|
||||
state_dict[k[len("module.base_encoder."):]] = state_dict[k]
|
||||
# delete renamed or unused k
|
||||
del state_dict[k]
|
||||
|
||||
args.start_epoch = 0
|
||||
msg = model.load_state_dict(state_dict, strict=False)
|
||||
assert set(msg.missing_keys) == {"%s.weight" % linear_keyword, "%s.bias" % linear_keyword}
|
||||
|
||||
print("=> loaded pre-trained model '{}'".format(args.pretrained))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(args.pretrained))
|
||||
|
||||
# infer learning rate before changing batch size
|
||||
init_lr = args.lr * args.batch_size / 256
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print('using CPU, this will be slow')
|
||||
elif args.distributed:
|
||||
# For multiprocessing distributed, DistributedDataParallel constructor
|
||||
# should always set the single device scope, otherwise,
|
||||
# DistributedDataParallel will use all available devices.
|
||||
if args.gpu is not None:
|
||||
torch.cuda.set_device(args.gpu)
|
||||
model.cuda(args.gpu)
|
||||
# When using a single GPU per process and per
|
||||
# DistributedDataParallel, we need to divide the batch size
|
||||
# ourselves based on the total number of GPUs we have
|
||||
args.batch_size = int(args.batch_size / ngpus_per_node)
|
||||
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
else:
|
||||
model.cuda()
|
||||
# DistributedDataParallel will divide and allocate batch_size to all
|
||||
# available GPUs if device_ids are not set
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
elif args.gpu is not None:
|
||||
torch.cuda.set_device(args.gpu)
|
||||
model = model.cuda(args.gpu)
|
||||
else:
|
||||
# DataParallel will divide and allocate batch_size to all available GPUs
|
||||
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
|
||||
model.features = torch.nn.DataParallel(model.features)
|
||||
model.cuda()
|
||||
else:
|
||||
model = torch.nn.DataParallel(model).cuda()
|
||||
|
||||
# define loss function (criterion) and optimizer
|
||||
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
|
||||
|
||||
# optimize only the linear classifier
|
||||
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
|
||||
assert len(parameters) == 2 # weight, bias
|
||||
|
||||
optimizer = torch.optim.SGD(parameters, init_lr,
|
||||
momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
if args.resume:
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> loading checkpoint '{}'".format(args.resume))
|
||||
if args.gpu is None:
|
||||
checkpoint = torch.load(args.resume)
|
||||
else:
|
||||
# Map model to be loaded to specified single gpu.
|
||||
loc = 'cuda:{}'.format(args.gpu)
|
||||
checkpoint = torch.load(args.resume, map_location=loc)
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
best_acc1 = checkpoint['best_acc1']
|
||||
if args.gpu is not None:
|
||||
# best_acc1 may be from a checkpoint from a different GPU
|
||||
best_acc1 = best_acc1.to(args.gpu)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("=> loaded checkpoint '{}' (epoch {})"
|
||||
.format(args.resume, checkpoint['epoch']))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
# Data loading code
|
||||
traindir = os.path.join(args.data, 'train')
|
||||
valdir = os.path.join(args.data, 'val')
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
train_dataset = datasets.ImageFolder(
|
||||
traindir,
|
||||
transforms.Compose([
|
||||
transforms.RandomResizedCrop(224),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
]))
|
||||
|
||||
if args.distributed:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
||||
else:
|
||||
train_sampler = None
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
||||
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
||||
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
datasets.ImageFolder(valdir, transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
normalize,
|
||||
])),
|
||||
batch_size=256, shuffle=False,
|
||||
num_workers=args.workers, pin_memory=True)
|
||||
|
||||
if args.evaluate:
|
||||
validate(val_loader, model, criterion, args)
|
||||
return
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
adjust_learning_rate(optimizer, init_lr, epoch, args)
|
||||
|
||||
# train for one epoch
|
||||
train(train_loader, model, criterion, optimizer, epoch, args)
|
||||
|
||||
# evaluate on validation set
|
||||
acc1 = validate(val_loader, model, criterion, args)
|
||||
|
||||
# remember best acc@1 and save checkpoint
|
||||
is_best = acc1 > best_acc1
|
||||
best_acc1 = max(acc1, best_acc1)
|
||||
|
||||
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
||||
and args.rank == 0): # only the first GPU saves checkpoint
|
||||
save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'arch': args.arch,
|
||||
'state_dict': model.state_dict(),
|
||||
'best_acc1': best_acc1,
|
||||
'optimizer' : optimizer.state_dict(),
|
||||
}, is_best)
|
||||
if epoch == args.start_epoch:
|
||||
sanity_check(model.state_dict(), args.pretrained, linear_keyword)
|
||||
|
||||
|
||||
def train(train_loader, model, criterion, optimizer, epoch, args):
|
||||
batch_time = AverageMeter('Time', ':6.3f')
|
||||
data_time = AverageMeter('Data', ':6.3f')
|
||||
losses = AverageMeter('Loss', ':.4e')
|
||||
top1 = AverageMeter('Acc@1', ':6.2f')
|
||||
top5 = AverageMeter('Acc@5', ':6.2f')
|
||||
progress = ProgressMeter(
|
||||
len(train_loader),
|
||||
[batch_time, data_time, losses, top1, top5],
|
||||
prefix="Epoch: [{}]".format(epoch))
|
||||
|
||||
"""
|
||||
Switch to eval mode:
|
||||
Under the protocol of linear classification on frozen features/models,
|
||||
it is not legitimate to change any part of the pre-trained model.
|
||||
BatchNorm in train mode may update running mean/std (even though it receives
|
||||
no gradient), and these running statistics are part of the pre-trained model state too.
|
||||
"""
|
||||
model.eval()
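# the backbone is frozen (only the linear head's weight/bias require grad), and eval mode additionally keeps BN running stats fixed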
|
||||
|
||||
end = time.time()
|
||||
for i, (images, target) in enumerate(train_loader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
if args.gpu is not None:
|
||||
images = images.cuda(args.gpu, non_blocking=True)
|
||||
if torch.cuda.is_available():
|
||||
target = target.cuda(args.gpu, non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||
losses.update(loss.item(), images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
|
||||
# compute gradient and do SGD step
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
progress.display(i)
|
||||
|
||||
|
||||
def validate(val_loader, model, criterion, args):
|
||||
batch_time = AverageMeter('Time', ':6.3f')
|
||||
losses = AverageMeter('Loss', ':.4e')
|
||||
top1 = AverageMeter('Acc@1', ':6.2f')
|
||||
top5 = AverageMeter('Acc@5', ':6.2f')
|
||||
progress = ProgressMeter(
|
||||
len(val_loader),
|
||||
[batch_time, losses, top1, top5],
|
||||
prefix='Test: ')
|
||||
|
||||
# switch to evaluate mode
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
end = time.time()
|
||||
for i, (images, target) in enumerate(val_loader):
|
||||
if args.gpu is not None:
|
||||
images = images.cuda(args.gpu, non_blocking=True)
|
||||
if torch.cuda.is_available():
|
||||
target = target.cuda(args.gpu, non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
|
||||
# measure accuracy and record loss
|
||||
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
||||
losses.update(loss.item(), images.size(0))
|
||||
top1.update(acc1[0], images.size(0))
|
||||
top5.update(acc5[0], images.size(0))
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
progress.display(i)
|
||||
|
||||
# TODO: this should also be done with the ProgressMeter
|
||||
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
||||
.format(top1=top1, top5=top5))
|
||||
|
||||
return top1.avg
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, 'model_best.pth.tar')
|
||||
|
||||
|
||||
def sanity_check(state_dict, pretrained_weights, linear_keyword):
|
||||
"""
|
||||
Linear classifier should not change any weights other than the linear layer.
|
||||
This sanity check asserts that nothing unexpected has happened (e.g., BN stats being updated).
|
||||
"""
|
||||
print("=> loading '{}' for sanity check".format(pretrained_weights))
|
||||
checkpoint = torch.load(pretrained_weights, map_location="cpu")
|
||||
state_dict_pre = checkpoint['state_dict']
|
||||
|
||||
for k in list(state_dict.keys()):
|
||||
# only ignore linear layer
|
||||
if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k:
|
||||
continue
|
||||
|
||||
# name in pretrained model
|
||||
k_pre = 'module.base_encoder.' + k[len('module.'):] \
|
||||
if k.startswith('module.') else 'module.base_encoder.' + k
|
||||
|
||||
assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
|
||||
'{} is changed in linear classifier training.'.format(k)
|
||||
|
||||
print("=> sanity check passed.")
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self, name, fmt=':f'):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class ProgressMeter(object):
|
||||
def __init__(self, num_batches, meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.prefix = prefix
|
||||
|
||||
def display(self, batch):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
print('\t'.join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = '{:' + str(num_digits) + 'd}'
|
||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, init_lr, epoch, args):
|
||||
"""Decay the learning rate based on schedule"""
|
||||
cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = cur_lr
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
with torch.no_grad():
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,438 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import builtins
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed as dist
|
||||
import torch.optim
|
||||
import torch.multiprocessing as mp
|
||||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.transforms as transforms
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.models as torchvision_models
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import moco.builder
|
||||
import moco.loader
|
||||
import moco.optimizer
|
||||
|
||||
import vits
|
||||
|
||||
|
||||
torchvision_model_names = sorted(name for name in torchvision_models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
and callable(torchvision_models.__dict__[name]))
|
||||
|
||||
model_names = ['vit_small', 'vit_base', 'vit_conv_small', 'vit_conv_base'] + torchvision_model_names
|
||||
|
||||
parser = argparse.ArgumentParser(description='MoCo ImageNet Pre-Training')
|
||||
parser.add_argument('data', metavar='DIR',
|
||||
help='path to dataset')
|
||||
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
|
||||
choices=model_names,
|
||||
help='model architecture: ' +
|
||||
' | '.join(model_names) +
|
||||
' (default: resnet50)')
|
||||
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
|
||||
help='number of data loading workers (default: 32)')
|
||||
parser.add_argument('--epochs', default=100, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('-b', '--batch-size', default=4096, type=int,
|
||||
metavar='N',
|
||||
help='mini-batch size (default: 4096), this is the total '
|
||||
'batch size of all GPUs on the current node when '
|
||||
'using Data Parallel or Distributed Data Parallel')
|
||||
parser.add_argument('--lr', '--learning-rate', default=0.6, type=float,
|
||||
metavar='LR', help='initial (base) learning rate', dest='lr')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=1e-6, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-6)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--world-size', default=-1, type=int,
|
||||
help='number of nodes for distributed training')
|
||||
parser.add_argument('--rank', default=-1, type=int,
|
||||
help='node rank for distributed training')
|
||||
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
||||
help='url used to set up distributed training')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
help='distributed backend')
|
||||
parser.add_argument('--seed', default=None, type=int,
|
||||
help='seed for initializing training. ')
|
||||
parser.add_argument('--gpu', default=None, type=int,
|
||||
help='GPU id to use.')
|
||||
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
||||
help='Use multi-processing distributed training to launch '
|
||||
'N processes per node, which has N GPUs. This is the '
|
||||
'fastest way to use PyTorch for either single node or '
|
||||
'multi node data parallel training')
|
||||
|
||||
# moco specific configs:
|
||||
parser.add_argument('--moco-dim', default=256, type=int,
|
||||
help='feature dimension (default: 256)')
|
||||
parser.add_argument('--moco-mlp-dim', default=4096, type=int,
|
||||
help='hidden dimension in MLPs (default: 4096)')
|
||||
parser.add_argument('--moco-m', default=0.99, type=float,
|
||||
help='moco momentum of updating momentum encoder (default: 0.99)')
|
||||
parser.add_argument('--moco-m-cos', action='store_true',
|
||||
help='gradually increase moco momentum to 1 with a '
|
||||
'half-cycle cosine schedule')
|
||||
parser.add_argument('--moco-t', default=1.0, type=float,
|
||||
help='softmax temperature (default: 1.0)')
|
||||
|
||||
# vit specific configs:
|
||||
parser.add_argument('--stop-grad-conv1', action='store_true',
|
||||
help='stop-grad after first conv, or patch embedding')
|
||||
|
||||
# other upgrades
|
||||
parser.add_argument('--optimizer', default='lars', type=str,
|
||||
choices=['lars', 'adamw'],
|
||||
help='optimizer used (default: lars)')
|
||||
parser.add_argument('--warmup-epochs', default=10, type=int, metavar='N',
|
||||
help='number of warmup epochs')
|
||||
parser.add_argument('--crop-min', default=0.08, type=float,
|
||||
help='minimum scale for random cropping (default: 0.08)')
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.deterministic = True
|
||||
warnings.warn('You have chosen to seed training. '
|
||||
'This will turn on the CUDNN deterministic setting, '
|
||||
'which can slow down your training considerably! '
|
||||
'You may see unexpected behavior when restarting '
|
||||
'from checkpoints.')
|
||||
|
||||
if args.gpu is not None:
|
||||
warnings.warn('You have chosen a specific GPU. This will completely '
|
||||
'disable data parallelism.')
|
||||
|
||||
if args.dist_url == "env://" and args.world_size == -1:
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
if args.multiprocessing_distributed:
|
||||
# Since we have ngpus_per_node processes per node, the total world_size
|
||||
# needs to be adjusted accordingly
|
||||
args.world_size = ngpus_per_node * args.world_size
|
||||
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
||||
# main_worker process function
|
||||
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
||||
else:
|
||||
# Simply call main_worker function
|
||||
main_worker(args.gpu, ngpus_per_node, args)
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, args):
|
||||
args.gpu = gpu
|
||||
|
||||
# suppress printing if not the first GPU of the first node
|
||||
if args.multiprocessing_distributed and (args.gpu != 0 or args.rank != 0):
|
||||
def print_pass(*args):
|
||||
pass
|
||||
builtins.print = print_pass
|
||||
|
||||
if args.gpu is not None:
|
||||
print("Use GPU: {} for training".format(args.gpu))
|
||||
|
||||
if args.distributed:
|
||||
if args.dist_url == "env://" and args.rank == -1:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
if args.multiprocessing_distributed:
|
||||
# For multiprocessing distributed training, rank needs to be the
|
||||
# global rank among all the processes
|
||||
args.rank = args.rank * ngpus_per_node + gpu
|
||||
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
# create model
|
||||
print("=> creating model '{}'".format(args.arch))
|
||||
if args.arch.startswith('vit'):
|
||||
model = moco.builder.MoCo_ViT(
|
||||
partial(vits.__dict__[args.arch], stop_grad_conv1=args.stop_grad_conv1),
|
||||
args.moco_dim, args.moco_mlp_dim, args.moco_t)
|
||||
else:
|
||||
model = moco.builder.MoCo_ResNet(
|
||||
partial(torchvision_models.__dict__[args.arch], zero_init_residual=True),
|
||||
args.moco_dim, args.moco_mlp_dim, args.moco_t)
|
||||
|
||||
# infer learning rate before changing batch size
|
||||
args.lr = args.lr * args.batch_size / 256
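# linear lr scaling rule: the base lr is scaled by total batch size / 256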
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print('using CPU, this will be slow')
|
||||
elif args.distributed:
|
||||
# apply SyncBN
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
# For multiprocessing distributed, DistributedDataParallel constructor
|
||||
# should always set the single device scope, otherwise,
|
||||
# DistributedDataParallel will use all available devices.
|
||||
if args.gpu is not None:
|
||||
torch.cuda.set_device(args.gpu)
|
||||
model.cuda(args.gpu)
|
||||
# When using a single GPU per process and per
|
||||
# DistributedDataParallel, we need to divide the batch size
|
||||
# ourselves based on the total number of GPUs we have
|
||||
args.batch_size = int(args.batch_size / args.world_size)
|
||||
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
else:
|
||||
model.cuda()
|
||||
# DistributedDataParallel will divide and allocate batch_size to all
|
||||
# available GPUs if device_ids are not set
|
||||
model = torch.nn.parallel.DistributedDataParallel(model)
|
||||
elif args.gpu is not None:
|
||||
torch.cuda.set_device(args.gpu)
|
||||
model = model.cuda(args.gpu)
|
||||
# comment out the following line for debugging
|
||||
raise NotImplementedError("Only DistributedDataParallel is supported.")
|
||||
else:
|
||||
# AllGather/rank implementation in this code only supports DistributedDataParallel.
|
||||
raise NotImplementedError("Only DistributedDataParallel is supported.")
|
||||
print(model) # print model after SyncBatchNorm
|
||||
|
||||
if args.optimizer == 'lars':
|
||||
optimizer = moco.optimizer.LARS(model.parameters(), args.lr,
|
||||
weight_decay=args.weight_decay,
|
||||
momentum=args.momentum)
|
||||
elif args.optimizer == 'adamw':
|
||||
optimizer = torch.optim.AdamW(model.parameters(), args.lr,
|
||||
weight_decay=args.weight_decay)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
summary_writer = SummaryWriter() if args.rank == 0 else None
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
if args.resume:
|
||||
if os.path.isfile(args.resume):
|
||||
print("=> loading checkpoint '{}'".format(args.resume))
|
||||
if args.gpu is None:
|
||||
checkpoint = torch.load(args.resume)
|
||||
else:
|
||||
# Map model to be loaded to specified single gpu.
|
||||
loc = 'cuda:{}'.format(args.gpu)
|
||||
checkpoint = torch.load(args.resume, map_location=loc)
|
||||
args.start_epoch = checkpoint['epoch']
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
scaler.load_state_dict(checkpoint['scaler'])
|
||||
print("=> loaded checkpoint '{}' (epoch {})"
|
||||
.format(args.resume, checkpoint['epoch']))
|
||||
else:
|
||||
print("=> no checkpoint found at '{}'".format(args.resume))
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
# Data loading code
|
||||
traindir = os.path.join(args.data, 'train')
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
|
||||
# follow BYOL's augmentation recipe: https://arxiv.org/abs/2006.07733
|
||||
augmentation1 = [
|
||||
transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)),
|
||||
transforms.RandomApply([
|
||||
transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
|
||||
], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=1.0),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
]
|
||||
|
||||
augmentation2 = [
|
||||
transforms.RandomResizedCrop(224, scale=(args.crop_min, 1.)),
|
||||
transforms.RandomApply([
|
||||
transforms.ColorJitter(0.4, 0.4, 0.2, 0.1) # not strengthened
|
||||
], p=0.8),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.1),
|
||||
transforms.RandomApply([moco.loader.Solarize()], p=0.2),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
normalize
|
||||
]
|
||||
|
||||
train_dataset = datasets.ImageFolder(
|
||||
traindir,
|
||||
moco.loader.TwoCropsTransform(transforms.Compose(augmentation1),
|
||||
transforms.Compose(augmentation2)))
|
||||
|
||||
if args.distributed:
|
||||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
||||
else:
|
||||
train_sampler = None
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
||||
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
# train for one epoch
|
||||
train(train_loader, model, optimizer, scaler, summary_writer, epoch, args)
|
||||
|
||||
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
||||
and args.rank == 0): # only the first GPU saves checkpoint
|
||||
save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'arch': args.arch,
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer' : optimizer.state_dict(),
|
||||
'scaler': scaler.state_dict(),
|
||||
}, is_best=False, filename='checkpoint_%04d.pth.tar' % epoch)
|
||||
|
||||
if args.rank == 0:
|
||||
summary_writer.close()
|
||||
|
||||
def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args):
|
||||
batch_time = AverageMeter('Time', ':6.3f')
|
||||
data_time = AverageMeter('Data', ':6.3f')
|
||||
learning_rates = AverageMeter('LR', ':.4e')
|
||||
losses = AverageMeter('Loss', ':.4e')
|
||||
progress = ProgressMeter(
|
||||
len(train_loader),
|
||||
[batch_time, data_time, learning_rates, losses],
|
||||
prefix="Epoch: [{}]".format(epoch))
|
||||
|
||||
# switch to train mode
|
||||
model.train()
|
||||
|
||||
end = time.time()
|
||||
iters_per_epoch = len(train_loader)
|
||||
moco_m = args.moco_m
|
||||
for i, (images, _) in enumerate(train_loader):
|
||||
# measure data loading time
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# adjust learning rate and momentum coefficient per iteration
|
||||
lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args)
|
||||
learning_rates.update(lr)
|
||||
if args.moco_m_cos:
|
||||
moco_m = adjust_moco_momentum(epoch + i / iters_per_epoch, args)
|
||||
|
||||
if args.gpu is not None:
|
||||
images[0] = images[0].cuda(args.gpu, non_blocking=True)
|
||||
images[1] = images[1].cuda(args.gpu, non_blocking=True)
|
||||
|
||||
# compute output
|
||||
with torch.cuda.amp.autocast(True):
|
||||
loss = model(images[0], images[1], moco_m)
|
||||
|
||||
losses.update(loss.item(), images[0].size(0))
|
||||
if args.rank == 0:
|
||||
summary_writer.add_scalar("loss", loss.item(), epoch * iters_per_epoch + i)
|
||||
|
||||
# compute gradient and do SGD step
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
progress.display(i)
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, 'model_best.pth.tar')
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self, name, fmt=':f'):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class ProgressMeter(object):
|
||||
def __init__(self, num_batches, meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.prefix = prefix
|
||||
|
||||
def display(self, batch):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
print('\t'.join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = '{:' + str(num_digits) + 'd}'
|
||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch, args):
|
||||
"""Decays the learning rate with half-cycle cosine after warmup"""
|
||||
if epoch < args.warmup_epochs:
|
||||
lr = args.lr * epoch / args.warmup_epochs
|
||||
else:
|
||||
lr = args.lr * 0.5 * (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
return lr
|
||||
|
||||
|
||||
def adjust_moco_momentum(epoch, args):
|
||||
"""Adjust moco momentum based on current epoch"""
|
||||
m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.moco_m)
|
||||
return m
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MoCo(nn.Module):
|
||||
"""
|
||||
Build a MoCo model with a base encoder, a momentum encoder, and two MLPs
|
||||
https://arxiv.org/abs/1911.05722
|
||||
"""
|
||||
def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
|
||||
"""
|
||||
dim: feature dimension (default: 256)
|
||||
mlp_dim: hidden dimension in MLPs (default: 4096)
|
||||
T: softmax temperature (default: 1.0)
|
||||
"""
|
||||
super(MoCo, self).__init__()
|
||||
|
||||
self.T = T
|
||||
|
||||
# build encoders
|
||||
self.base_encoder = base_encoder(num_classes=mlp_dim)
|
||||
self.momentum_encoder = base_encoder(num_classes=mlp_dim)
|
||||
|
||||
self._build_projector_and_predictor_mlps(dim, mlp_dim)
|
||||
|
||||
for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
|
||||
param_m.data.copy_(param_b.data) # initialize
|
||||
param_m.requires_grad = False # not update by gradient
|
||||
|
||||
def _build_mlp(self, num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
|
||||
mlp = []
|
||||
for l in range(num_layers):
|
||||
dim1 = input_dim if l == 0 else mlp_dim
|
||||
dim2 = output_dim if l == num_layers - 1 else mlp_dim
|
||||
|
||||
mlp.append(nn.Linear(dim1, dim2, bias=False))
|
||||
|
||||
if l < num_layers - 1:
|
||||
mlp.append(nn.BatchNorm1d(dim2))
|
||||
mlp.append(nn.ReLU(inplace=True))
|
||||
elif last_bn:
|
||||
# follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
|
||||
# for simplicity, we further removed gamma in BN
|
||||
mlp.append(nn.BatchNorm1d(dim2, affine=False))
|
||||
|
||||
return nn.Sequential(*mlp)
|
||||
|
||||
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_momentum_encoder(self, m):
|
||||
"""Momentum update of the momentum encoder"""
|
||||
for param_b, param_m in zip(self.base_encoder.parameters(), self.momentum_encoder.parameters()):
|
||||
param_m.data = param_m.data * m + param_b.data * (1. - m)
|
||||
|
||||
def contrastive_loss(self, q, k):
|
||||
# normalize
|
||||
q = nn.functional.normalize(q, dim=1)
|
||||
k = nn.functional.normalize(k, dim=1)
|
||||
# gather all targets
|
||||
k = concat_all_gather(k)
|
||||
# Einstein sum is more intuitive
|
||||
logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
|
||||
N = logits.shape[0] # batch size per GPU
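# the positive key for query i on this rank sits at global index rank * N + i in the gathered keys; all other keys act as negatives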
|
||||
labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
|
||||
return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)
|
||||
|
||||
def forward(self, x1, x2, m):
|
||||
"""
|
||||
Input:
|
||||
x1: first views of images
|
||||
x2: second views of images
|
||||
m: moco momentum
|
||||
Output:
|
||||
loss
|
||||
"""
|
||||
|
||||
# compute features
|
||||
q1 = self.predictor(self.base_encoder(x1))
|
||||
q2 = self.predictor(self.base_encoder(x2))
|
||||
|
||||
with torch.no_grad(): # no gradient
|
||||
self._update_momentum_encoder(m) # update the momentum encoder
|
||||
|
||||
# compute momentum features as targets
|
||||
k1 = self.momentum_encoder(x1)
|
||||
k2 = self.momentum_encoder(x2)
|
||||
|
||||
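# symmetrized loss: each view's query is paired with the other view's momentum key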
return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1)
|
||||
|
||||
|
||||
class MoCo_ResNet(MoCo):
|
||||
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
|
||||
hidden_dim = self.base_encoder.fc.weight.shape[1]
|
||||
del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer
|
||||
|
||||
# projectors
|
||||
self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)
|
||||
self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim)
|
||||
|
||||
# predictor
|
||||
self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False)
|
||||
|
||||
|
||||
class MoCo_ViT(MoCo):
|
||||
def _build_projector_and_predictor_mlps(self, dim, mlp_dim):
|
||||
hidden_dim = self.base_encoder.head.weight.shape[1]
|
||||
del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer
|
||||
|
||||
# projectors
|
||||
self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)
|
||||
self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim)
|
||||
|
||||
# predictor
|
||||
self.predictor = self._build_mlp(2, dim, mlp_dim, dim)
|
||||
|
||||
|
||||
# utils
|
||||
@torch.no_grad()
|
||||
def concat_all_gather(tensor):
|
||||
"""
|
||||
Performs all_gather operation on the provided tensors.
|
||||
*** Warning ***: torch.distributed.all_gather has no gradient.
|
||||
"""
|
||||
tensors_gather = [torch.ones_like(tensor)
|
||||
for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
||||
|
||||
output = torch.cat(tensors_gather, dim=0)
|
||||
return output
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
import math
|
||||
import random
|
||||
import torchvision.transforms.functional as tf
|
||||
|
||||
|
||||
class TwoCropsTransform:
|
||||
"""Take two random crops of one image"""
|
||||
|
||||
def __init__(self, base_transform1, base_transform2):
|
||||
self.base_transform1 = base_transform1
|
||||
self.base_transform2 = base_transform2
|
||||
|
||||
def __call__(self, x):
|
||||
im1 = self.base_transform1(x)
|
||||
im2 = self.base_transform2(x)
|
||||
return [im1, im2]
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
"""Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""
|
||||
|
||||
def __init__(self, sigma=[.1, 2.]):
|
||||
self.sigma = sigma
|
||||
|
||||
def __call__(self, x):
|
||||
sigma = random.uniform(self.sigma[0], self.sigma[1])
|
||||
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
|
||||
return x
|
||||
|
||||
|
||||
class Solarize(object):
|
||||
"""Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733"""
|
||||
|
||||
def __call__(self, x):
|
||||
return ImageOps.solarize(x)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class LARS(torch.optim.Optimizer):
|
||||
"""
|
||||
LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
|
||||
"""
|
||||
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
|
||||
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for g in self.param_groups:
|
||||
for p in g['params']:
|
||||
dp = p.grad
|
||||
|
||||
if dp is None:
|
||||
continue
|
||||
|
||||
if p.ndim > 1: # apply weight decay and LARS scaling only to >1D params (i.e., not normalization gamma/beta or biases)
|
||||
dp = dp.add(p, alpha=g['weight_decay'])
|
||||
param_norm = torch.norm(p)
|
||||
update_norm = torch.norm(dp)
|
||||
one = torch.ones_like(param_norm)
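# LARS trust ratio: scale the update by trust_coefficient * ||p|| / ||dp||, falling back to 1 when either norm is zero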
|
||||
q = torch.where(param_norm > 0.,
|
||||
torch.where(update_norm > 0,
|
||||
(g['trust_coefficient'] * param_norm / update_norm), one),
|
||||
one)
|
||||
dp = dp.mul(q)
|
||||
|
||||
param_state = self.state[p]
|
||||
if 'mu' not in param_state:
|
||||
param_state['mu'] = torch.zeros_like(p)
|
||||
mu = param_state['mu']
|
||||
mu.mul_(g['momentum']).add_(dp)
|
||||
p.add_(mu, alpha=-g['lr'])
|
|
@ -0,0 +1,128 @@
|
|||
## MoCo v3 Transfer Learning with ViT
|
||||
|
||||
This folder includes the transfer learning experiments on the CIFAR-10, CIFAR-100, Flowers, and Pets datasets. We provide fine-tuning recipes for the ViT-Base model.
|
||||
|
||||
### Transfer Results
|
||||
|
||||
The following results are based on ImageNet-1k self-supervised pre-training, followed by end-to-end fine-tuning on the downstream datasets. All runs use a batch size of 128 and 100 fine-tuning epochs.
|
||||
|
||||
#### ViT-Base, transfer learning
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="center">dataset</th>
|
||||
<th valign="center">pretrain<br/>epochs</th>
|
||||
<th valign="center">pretrain<br/>crops</th>
|
||||
<th valign="center">finetune<br/>epochs</th>
|
||||
<th valign="center">transfer<br/>acc</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr>
|
||||
<td align="left">CIFAR-10</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="right">100</td>
|
||||
<td align="center">98.9</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">CIFAR-100</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="right">100</td>
|
||||
<td align="center">90.5</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">Flowers</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="right">100</td>
|
||||
<td align="center">97.7</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="left">Pets</td>
|
||||
<td align="right">300</td>
|
||||
<td align="center">2x224</td>
|
||||
<td align="right">100</td>
|
||||
<td align="center">93.2</td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
Similar to the end-to-end fine-tuning experiment on ImageNet, the transfer learning results are also obtained using the [DeiT](https://github.com/facebookresearch/deit) repo, with the default model `deit_base_patch16_224`.
|
||||
|
||||
### Preparation: Transfer learning with ViT
|
||||
|
||||
To perform transfer learning with ViT, use our script to convert the pre-trained ViT checkpoint to [DeiT](https://github.com/facebookresearch/deit) format:
|
||||
```
|
||||
python convert_to_deit.py \
|
||||
--input [your checkpoint path]/[your checkpoint file].pth.tar \
|
||||
--output [target checkpoint file].pth
|
||||
```
|
||||
Then copy (or replace) the following files to the DeiT folder:
|
||||
```
|
||||
datasets.py
|
||||
oxford_flowers_dataset.py
|
||||
oxford_pets_dataset.py
|
||||
```
|
||||
|
||||
#### Download and prepare the datasets
|
||||
|
||||
Pets [\[Homepage\]](https://www.robots.ox.ac.uk/~vgg/data/pets/)
|
||||
```
|
||||
./data/
|
||||
└── ./data/pets/
|
||||
├── ./data/pets/annotations/ # split and label files
|
||||
└── ./data/pets/images/ # data images
|
||||
```
|
||||
|
||||
Flowers [\[Homepage\]](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
|
||||
```
|
||||
./data/
|
||||
└── ./data/flowers/
|
||||
├── ./data/flowers/jpg/ # jpg images
|
||||
├── ./data/flowers/setid.mat # dataset split
|
||||
└── ./data/flowers/imagelabels.mat # labels
|
||||
```
|
||||
|
||||
|
||||
CIFAR-10/CIFAR-100 datasets will be downloaded automatically.
|
||||
|
||||
|
||||
### Transfer learning scripts (with an 8-GPU machine):
|
||||
|
||||
#### CIFAR-10
|
||||
```
|
||||
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
|
||||
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
|
||||
--no-pin-mem --warmup-epochs 3 --data-set cifar10 --data-path [cifar-10 data path] --no-repeated-aug \
|
||||
--resume [your pretrain checkpoint file] \
|
||||
--reprob 0.0 --drop-path 0.1 --mixup 0.8 --cutmix 1
|
||||
```
|
||||
|
||||
#### CIFAR-100
|
||||
```
|
||||
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
|
||||
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
|
||||
--no-pin-mem --warmup-epochs 3 --data-set cifar100 --data-path [cifar-100 data path] --no-repeated-aug \
|
||||
--resume [your pretrain checkpoint file] \
|
||||
--reprob 0.0 --drop-path 0.1 --mixup 0.5 --cutmix 1
|
||||
```
|
||||
|
||||
#### Flowers
|
||||
```
|
||||
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
|
||||
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.3 --eval-freq 10 \
|
||||
--no-pin-mem --warmup-epochs 3 --data-set flowers --data-path [oxford-flowers data path] --no-repeated-aug \
|
||||
--resume [your pretrain checkpoint file] \
|
||||
--reprob 0.25 --drop-path 0.1 --mixup 0 --cutmix 0
|
||||
```
|
||||
|
||||
#### Pets
|
||||
```
|
||||
python -u -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
|
||||
--batch-size 128 --output_dir [your output dir path] --epochs 100 --lr 3e-4 --weight-decay 0.1 --eval-freq 10 \
|
||||
--no-pin-mem --warmup-epochs 3 --data-set pets --data-path [oxford-pets data path] --no-repeated-aug \
|
||||
--resume [your pretrain checkpoint file] \
|
||||
--reprob 0 --drop-path 0 --mixup 0.8 --cutmix 0
|
||||
```
|
||||
|
||||
**Note**:
|
||||
Similar to the ImageNet end-to-end fine-tuning experiment, we use `--resume` rather than `--finetune` in the DeiT repo, because its `--finetune` option trains under eval mode. When loading the pre-trained model, change `model_without_ddp.load_state_dict(checkpoint['model'])` to pass `strict=False`.
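
For reference, here is a minimal sketch of that change, assuming DeiT's `main.py` variable names (`model_without_ddp`, `checkpoint`) and that the converted checkpoint stores its weights under the `'model'` key:
```
checkpoint = torch.load(args.resume, map_location='cpu')
# strict=False tolerates the classification head that is absent from the converted MoCo v3 checkpoint
msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
print(msg.missing_keys)  # expect only the classifier head weights to be listed as missing
```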
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision.datasets.folder import ImageFolder, default_loader
|
||||
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
import oxford_flowers_dataset, oxford_pets_dataset
|
||||
|
||||
|
||||
def build_transform(is_train, args):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomResizedCrop((args.input_size, args.input_size), scale=(0.05, 1.0)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
|
||||
])
|
||||
transform_test = transforms.Compose([
|
||||
transforms.Resize(int((256 / 224) * args.input_size)),
|
||||
transforms.CenterCrop(args.input_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
|
||||
])
|
||||
return transform_train if is_train else transform_test
|
||||
|
||||
|
||||
def build_dataset(is_train, args):
|
||||
transform = build_transform(is_train, args)
|
||||
|
||||
if args.data_set == 'imagenet':
|
||||
raise NotImplementedError("Only [cifar10, cifar100, flowers, pets] are supported; \
|
||||
for imagenet end-to-end finetuning, please refer to the instructions in the main README.")
|
||||
|
||||
if args.data_set == 'imagenet':
|
||||
root = os.path.join(args.data_path, 'train' if is_train else 'val')
|
||||
dataset = datasets.ImageFolder(root, transform=transform)
|
||||
nb_classes = 1000
|
||||
|
||||
elif args.data_set == 'cifar10':
|
||||
dataset = datasets.CIFAR10(root=args.data_path,
|
||||
train=is_train,
|
||||
download=True,
|
||||
transform=transform)
|
||||
nb_classes = 10
|
||||
elif args.data_set == "cifar100":
|
||||
dataset = datasets.CIFAR100(root=args.data_path,
|
||||
train=is_train,
|
||||
download=True,
|
||||
transform=transform)
|
||||
nb_classes = 100
|
||||
elif args.data_set == "flowers":
|
||||
dataset = oxford_flowers_dataset.Flowers(root=args.data_path,
|
||||
train=is_train,
|
||||
download=False,
|
||||
transform=transform)
|
||||
nb_classes = 102
|
||||
elif args.data_set == "pets":
|
||||
dataset = oxford_pets_dataset.Pets(root=args.data_path,
|
||||
train=is_train,
|
||||
download=False,
|
||||
transform=transform)
|
||||
nb_classes = 37
|
||||
else:
|
||||
raise NotImplementedError("Only [cifar10, cifar100, flowers, pets] are supported; \
|
||||
for imagenet end-to-end finetuning, please refer to the instructions in the main README.")
|
||||
|
||||
return dataset, nb_classes
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import print_function
|
||||
from PIL import Image
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
import scipy.io
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
|
||||
class Flowers(VisionDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
train=True,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
download=False,
|
||||
):
|
||||
|
||||
super(Flowers, self).__init__(root, transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
base_folder = root
|
||||
self.image_folder = os.path.join(base_folder, "jpg")
|
||||
label_file = os.path.join(base_folder, "imagelabels.mat")
|
||||
setid_file = os.path.join(base_folder, "setid.mat")
|
||||
|
||||
self.train = train
|
||||
|
||||
self.labels = scipy.io.loadmat(label_file)["labels"][0]
|
||||
train_list = scipy.io.loadmat(setid_file)["trnid"][0]
|
||||
val_list = scipy.io.loadmat(setid_file)["valid"][0]
|
||||
test_list = scipy.io.loadmat(setid_file)["tstid"][0]
|
||||
trainval_list = np.concatenate([train_list, val_list])
|
||||
|
||||
if self.train:
|
||||
self.img_files = trainval_list
|
||||
else:
|
||||
self.img_files = test_list
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_name = "image_%05d.jpg" % self.img_files[index]
|
||||
target = self.labels[self.img_files[index] - 1] - 1
|
||||
img = Image.open(os.path.join(self.image_folder, img_name))
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_files)
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from PIL import Image
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path
|
||||
import pickle
|
||||
import scipy.io
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
|
||||
class Pets(VisionDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
train: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
|
||||
super(Pets, self).__init__(root, transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
base_folder = root
|
||||
self.train = train
|
||||
annotations_path_dir = os.path.join(base_folder, "annotations")
|
||||
self.image_path_dir = os.path.join(base_folder, "images")
|
||||
|
||||
if self.train:
|
||||
split_file = os.path.join(annotations_path_dir, "trainval.txt")
|
||||
with open(split_file) as f:
|
||||
self.images_list = f.readlines()
|
||||
else:
|
||||
split_file = os.path.join(annotations_path_dir, "test.txt")
|
||||
with open(split_file) as f:
|
||||
self.images_list = f.readlines()
|
||||
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
|
||||
img_name, label, species, _ = self.images_list[index].strip().split(" ")
|
||||
|
||||
img_name += ".jpg"
|
||||
target = int(label) - 1
|
||||
|
||||
img = Image.open(os.path.join(self.image_path_dir, img_name))
|
||||
img = img.convert('RGB')
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.images_list)
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial, reduce
|
||||
from operator import mul
|
||||
|
||||
from timm.models.vision_transformer import VisionTransformer, _cfg
|
||||
from timm.models.layers.helpers import to_2tuple
|
||||
from timm.models.layers import PatchEmbed
|
||||
|
||||
__all__ = [
|
||||
'vit_small',
|
||||
'vit_base',
|
||||
'vit_conv_small',
|
||||
'vit_conv_base',
|
||||
]
|
||||
|
||||
|
||||
class VisionTransformerMoCo(VisionTransformer):
|
||||
def __init__(self, stop_grad_conv1=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# Use fixed 2D sin-cos position embedding
|
||||
self.build_2d_sincos_position_embedding()
|
||||
|
||||
# weight initialization
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
if 'qkv' in name:
|
||||
# treat the weights of Q, K, V separately
|
||||
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
|
||||
nn.init.uniform_(m.weight, -val, val)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
|
||||
if isinstance(self.patch_embed, PatchEmbed):
|
||||
# xavier_uniform initialization
|
||||
val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
|
||||
nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
|
||||
nn.init.zeros_(self.patch_embed.proj.bias)
|
||||
|
||||
if stop_grad_conv1:
|
||||
self.patch_embed.proj.weight.requires_grad = False
|
||||
self.patch_embed.proj.bias.requires_grad = False
|
||||
|
||||
def build_2d_sincos_position_embedding(self, temperature=10000.):
|
||||
h, w = self.patch_embed.grid_size
|
||||
grid_w = torch.arange(w, dtype=torch.float32)
|
||||
grid_h = torch.arange(h, dtype=torch.float32)
|
||||
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
|
||||
assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
||||
pos_dim = self.embed_dim // 4
|
||||
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
||||
omega = 1. / (temperature**omega)
|
||||
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
|
||||
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
|
||||
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
|
||||
|
||||
assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
|
||||
pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
|
||||
self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
|
||||
self.pos_embed.requires_grad = False
|
||||
|
||||
|
||||
class ConvStem(nn.Module):
|
||||
"""
|
||||
ConvStem, from "Early Convolutions Help Transformers See Better", Xiao et al., https://arxiv.org/abs/2106.14881
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
||||
super().__init__()
|
||||
|
||||
assert patch_size == 16, 'ConvStem only supports patch size of 16'
|
||||
assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
|
||||
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
# build stem, similar to the design in https://arxiv.org/abs/2106.14881
|
||||
stem = []
|
||||
input_dim, output_dim = 3, embed_dim // 8
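# four stride-2 3x3 convs give an overall stride of 16 (the patch size); channels double each stage, then a 1x1 conv maps to embed_dim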
|
||||
for l in range(4):
|
||||
stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
|
||||
stem.append(nn.BatchNorm2d(output_dim))
|
||||
stem.append(nn.ReLU(inplace=True))
|
||||
input_dim = output_dim
|
||||
output_dim *= 2
|
||||
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
||||
self.proj = nn.Sequential(*stem)
|
||||
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
def vit_small(**kwargs):
|
||||
model = VisionTransformerMoCo(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|
||||
|
||||
def vit_base(**kwargs):
|
||||
model = VisionTransformerMoCo(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|
||||
|
||||
def vit_conv_small(**kwargs):
|
||||
# minus one ViT block
|
||||
model = VisionTransformerMoCo(
|
||||
patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|
||||
|
||||
def vit_conv_base(**kwargs):
|
||||
# minus one ViT block
|
||||
model = VisionTransformerMoCo(
|
||||
patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|