commit 58aabc09e0
first commit
|
CODE_OF_CONDUCT.md
@@ -0,0 +1,5 @@
|
|||
# Code of Conduct
|
||||
|
||||
Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
|
||||
Please read the [full text](https://code.fb.com/codeofconduct/)
|
||||
so that you can understand what actions will and will not be tolerated.
|
|
CONTRIBUTING.md
@@ -0,0 +1,4 @@
|
|||
# Contributing
|
||||
|
||||
In the context of this project, we do not expect pull requests.
|
||||
If you find a bug, or would like to suggest an improvement, please open an issue.
|
Binary file not shown (added; size: 4.0 MiB).
Binary file not shown (added; size: 20 MiB).
|
LICENSE
@@ -0,0 +1,399 @@
|
|||
Attribution-NonCommercial 4.0 International
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial 4.0 International Public
|
||||
License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial 4.0 International Public License ("Public
|
||||
License"). To the extent this Public License may be interpreted as a
|
||||
contract, You are granted the Licensed Rights in consideration of Your
|
||||
acceptance of these terms and conditions, and the Licensor grants You
|
||||
such rights in consideration of benefits the Licensor receives from
|
||||
making the Licensed Material available under these terms and
|
||||
conditions.
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
d. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
f. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
g. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
h. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
i. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
j. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
k. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
l. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
4. If You Share Adapted Material You produce, the Adapter's
|
||||
License You apply must not prevent recipients of the Adapted
|
||||
Material from complying with this Public License.
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material; and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
|
README.md
@@ -0,0 +1,179 @@
|
|||
# Self-Supervised Vision Transformers with DINO
|
||||
|
||||
PyTorch implementation and pretrained models of DINO, as described in [Emerging Properties in Self-Supervised Vision Transformers](https://arxiv.org/abs/2104.14294).
|
||||
<div align="center">
|
||||
<img width="100%" alt="DINO illustration" src=".github/dino.gif">
|
||||
</div>
|
||||
|
||||
## Pretrained models
|
||||
You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the training and evaluation logs.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>arch</th>
|
||||
<th>params</th>
|
||||
<th>k-nn</th>
|
||||
<th>linear</th>
|
||||
<th colspan="5">download</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>DeiT-S/16</td>
|
||||
<td>21M</td>
|
||||
<td>74.5%</td>
|
||||
<td>77.0%</td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth">backbone only</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_full_checkpoint.pth">full checkpoint</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/args.txt">args</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_log.txt">logs</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">eval logs</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>DeiT-S/8</td>
|
||||
<td>21M</td>
|
||||
<td>78.3%</td>
|
||||
<td>79.7%</td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth">backbone only</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth">full checkpoint</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/args.txt">args</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_log.txt">logs</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_eval_linear_log.txt">eval logs</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ViT-B/16</td>
|
||||
<td>85M</td>
|
||||
<td>76.1%</td>
|
||||
<td>78.2%</td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth">backbone only</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth">full checkpoint</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/args.txt">args</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_log.txt">logs</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_eval_linear_log.txt">eval logs</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ViT-B/8</td>
|
||||
<td>85M</td>
|
||||
<td>77.4%</td>
|
||||
<td>80.1%</td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth">backbone only</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_full_checkpoint.pth">full checkpoint</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/args.txt">args</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_log.txt">logs</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_eval_linear_log.txt">eval logs</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>ResNet-50</td>
|
||||
<td>23M</td>
|
||||
<td>67.5%</td>
|
||||
<td>75.3%</td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth">backbone only</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_full_checkpoint.pth">full checkpoint</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/args.txt">args</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_log.txt">logs</a></td>
|
||||
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain_eval_linear_log.txt">eval logs</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
The pretrained models are available on PyTorch Hub.
|
||||
```python
|
||||
import torch
|
||||
deits16 = torch.hub.load('facebookresearch/dino', 'dino_deits16')
|
||||
deits8 = torch.hub.load('facebookresearch/dino', 'dino_deits8')
|
||||
vitb16 = torch.hub.load('facebookresearch/dino', 'dino_vitb16')
|
||||
vitb8 = torch.hub.load('facebookresearch/dino', 'dino_vitb8')
|
||||
resnet50 = torch.hub.load('facebookresearch/dino', 'dino_resnet50')
|
||||
```
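As a quick usage sketch (not part of the original instructions; the image path is a placeholder), you can extract a feature vector from a single image with one of these backbones, using preprocessing similar to the evaluation transforms in this repo:
```python
import torch
from PIL import Image
from torchvision import transforms

model = torch.hub.load('facebookresearch/dino', 'dino_deits16')
model.eval()

# ImageNet-style evaluation preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = preprocess(Image.open("image.jpg").convert("RGB")).unsqueeze(0)  # placeholder path
with torch.no_grad():
    features = model(img)  # CLS embedding, e.g. a (1, 384) tensor for DeiT-S/16
```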
|
||||
|
||||
## Training
|
||||
|
||||
### Documentation
|
||||
Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0, and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For a glimpse at the full documentation of DINO training, please run:
|
||||
```
|
||||
python main_dino.py --help
|
||||
```
|
||||
|
||||
### Vanilla DINO training :sauropod:
|
||||
Run DINO with the DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days, and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide [training](/to/do) and [linear evaluation](/to/do) logs for this run to help reproducibility.
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
||||
```
|
||||
|
||||
### Multi-node training
|
||||
We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs):
|
||||
```
|
||||
python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
DINO with ViT-base network.
|
||||
</summary>
|
||||
|
||||
```
|
||||
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Boosting DINO performance :t-rex:
|
||||
You can improve the performance of the vanilla run by:
|
||||
- training for more epochs: `--epochs 300`,
|
||||
- increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`,
|
||||
- removing last layer normalization (only safe with `--arch deit_small`): `--norm_last_layer false`.
|
||||
|
||||
<details>
|
||||
<summary>
|
||||
Full command.
|
||||
</summary>
|
||||
|
||||
```
|
||||
python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide [training](/to/do) and [linear evaluation](/to/do) logs for this run to help reproducibility.
|
||||
|
||||
### ResNet-50 and other convnets trainings
|
||||
This code also works for training DINO on convolutional networks, such as ResNet-50. We highly recommend adapting some optimization arguments in this case. For example, here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs:
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
|
||||
```
|
||||
|
||||
## Evaluation: k-NN classification on ImageNet
|
||||
To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet
|
||||
```
|
||||
If you choose not to specify `--pretrained_weights`, then DINO reference weights are used by default. If instead you want to evaluate checkpoints from a run of your own, you can run, for example:
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet
|
||||
```
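If you want to avoid recomputing the features on every run, `eval_knn.py` accepts `--dump_features` and `--load_features` (see the argument parser below). A sketch of that workflow, with placeholder paths:
```
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --dump_features /path/to/saved_features --data_path /path/to/imagenet
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --load_features /path/to/saved_features --data_path /path/to/imagenet
```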
|
||||
|
||||
## Evaluation: Linear classification on ImageNet
|
||||
To train a supervised linear classifier on frozen weights on a single node with 8 GPUs, run:
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
|
||||
```
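The defaults correspond to the DeiT-Small setting (`--n_last_blocks 4 --avgpool_patchtokens false`). For ViT-Base, the help text in `eval_linear.py` suggests using the last block only together with average-pooled patch tokens, so a command along these lines (a sketch, not a command from the original instructions) should work:
```
python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --arch vit_base --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet
```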
|
||||
|
||||
## Self-attention visualization
|
||||
You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:
|
||||
```
|
||||
python visualize_attention.py
|
||||
```
|
||||
<div align="center">
|
||||
<img width="100%" alt="Self-attention from a Vision Transformer with 8x8 patches trained with DINO" src=".github/attention_maps.png">
|
||||
</div>
|
||||
|
||||
## License
|
||||
See the [LICENSE](LICENSE) file for more details.
|
||||
|
||||
## Citation
|
||||
If you find this repository useful, please consider giving a star :star: and citation :t-rex::
|
||||
```
|
||||
@article{caron2021emerging,
|
||||
title={Emerging Properties in Self-Supervised Vision Transformers},
|
||||
author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
|
||||
journal={arXiv preprint arXiv:2104.14294},
|
||||
year={2021}
|
||||
}
|
||||
```
|
|
eval_knn.py
@@ -0,0 +1,217 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms as pth_transforms
|
||||
|
||||
import utils
|
||||
import vision_transformer as vits
|
||||
|
||||
|
||||
def extract_feature_pipeline(args):
|
||||
# ============ preparing data ... ============
|
||||
transform = pth_transforms.Compose([
|
||||
pth_transforms.Resize(256, interpolation=3),
|
||||
pth_transforms.CenterCrop(224),
|
||||
pth_transforms.ToTensor(),
|
||||
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform)
|
||||
dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform)
|
||||
sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
|
||||
data_loader_train = torch.utils.data.DataLoader(
|
||||
dataset_train,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
)
|
||||
data_loader_val = torch.utils.data.DataLoader(
|
||||
dataset_val,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
)
|
||||
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
|
||||
|
||||
# ============ building network ... ============
|
||||
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
||||
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
||||
model.cuda()
|
||||
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
||||
model.eval()
|
||||
|
||||
# ============ extract features ... ============
|
||||
print("Extracting features for train set...")
|
||||
train_features = extract_features(model, data_loader_train)
|
||||
print("Extracting features for val set...")
|
||||
test_features = extract_features(model, data_loader_val)
|
||||
|
||||
if utils.get_rank() == 0:
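        # L2-normalize the features so that the dot products computed in
        # knn_classifier correspond to cosine similarities.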
|
||||
train_features = nn.functional.normalize(train_features, dim=1, p=2)
|
||||
test_features = nn.functional.normalize(test_features, dim=1, p=2)
|
||||
|
||||
train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long()
|
||||
test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long()
|
||||
# save features and labels
|
||||
if args.dump_features and dist.get_rank() == 0:
|
||||
torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth"))
|
||||
torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth"))
|
||||
torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth"))
|
||||
torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth"))
|
||||
return train_features, test_features, train_labels, test_labels
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_features(model, data_loader):
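    # Each rank forwards its own shard of the dataset; sample indices and features
    # are all-gathered across processes so that rank 0 can write every feature row
    # into a single (num_samples, feat_dim) matrix in dataset order.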
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
features = None
|
||||
for samples, index in metric_logger.log_every(data_loader, 10):
|
||||
samples = samples.cuda(non_blocking=True)
|
||||
index = index.cuda(non_blocking=True)
|
||||
feats = model(samples).clone()
|
||||
|
||||
# init storage feature matrix
|
||||
if dist.get_rank() == 0 and features is None:
|
||||
features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
|
||||
if args.use_cuda:
|
||||
features = features.cuda(non_blocking=True)
|
||||
print(f"Storing features into tensor of shape {features.shape}")
|
||||
|
||||
# get indexes from all processes
|
||||
y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
|
||||
y_l = list(y_all.unbind(0))
|
||||
y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
|
||||
y_all_reduce.wait()
|
||||
index_all = torch.cat(y_l)
|
||||
|
||||
# share features between processes
|
||||
feats_all = torch.empty(
|
||||
dist.get_world_size(),
|
||||
feats.size(0),
|
||||
feats.size(1),
|
||||
dtype=feats.dtype,
|
||||
device=feats.device,
|
||||
)
|
||||
output_l = list(feats_all.unbind(0))
|
||||
output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
|
||||
output_all_reduce.wait()
|
||||
|
||||
# update storage feature matrix
|
||||
if dist.get_rank() == 0:
|
||||
if args.use_cuda:
|
||||
features.index_copy_(0, index_all, torch.cat(output_l))
|
||||
else:
|
||||
features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
|
||||
return features
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000):
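    # Weighted k-NN voting: each chunk of test features is compared (dot product on
    # L2-normalized features, i.e. cosine similarity) against all train features;
    # the top-k neighbors vote for their class with weight exp(similarity / T).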
|
||||
top1, top5, total = 0.0, 0.0, 0
|
||||
train_features = train_features.t()
|
||||
num_test_images, num_chunks = test_labels.shape[0], 100
|
||||
imgs_per_chunk = num_test_images // num_chunks
|
||||
retrieval_one_hot = torch.zeros(k, num_classes).cuda()
|
||||
for idx in range(0, num_test_images, imgs_per_chunk):
|
||||
# get the features for test images
|
||||
features = test_features[
|
||||
idx : min((idx + imgs_per_chunk), num_test_images), :
|
||||
]
|
||||
targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)]
|
||||
batch_size = targets.shape[0]
|
||||
|
||||
# calculate the dot product and compute top-k neighbors
|
||||
similarity = torch.mm(features, train_features)
|
||||
distances, indices = similarity.topk(k, largest=True, sorted=True)
|
||||
candidates = train_labels.view(1, -1).expand(batch_size, -1)
|
||||
retrieved_neighbors = torch.gather(candidates, 1, indices)
|
||||
|
||||
retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
|
||||
retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
|
||||
distances_transform = distances.clone().div_(T).exp_()
|
||||
probs = torch.sum(
|
||||
torch.mul(
|
||||
retrieval_one_hot.view(batch_size, -1, num_classes),
|
||||
distances_transform.view(batch_size, -1, 1),
|
||||
),
|
||||
1,
|
||||
)
|
||||
_, predictions = probs.sort(1, True)
|
||||
|
||||
# find the predictions that match the target
|
||||
correct = predictions.eq(targets.data.view(-1, 1))
|
||||
top1 = top1 + correct.narrow(1, 0, 1).sum().item()
|
||||
        top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item()  # guard for k < 5
|
||||
total += targets.size(0)
|
||||
top1 = top1 * 100.0 / total
|
||||
top5 = top5 * 100.0 / total
|
||||
return top1, top5
|
||||
|
||||
|
||||
class ReturnIndexDataset(datasets.ImageFolder):
|
||||
def __getitem__(self, idx):
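        # Return the dataset index instead of the label, so that extracted features
        # can be written back to the correct row after the distributed all_gather.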
|
||||
img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
|
||||
return img, idx
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet')
|
||||
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
|
||||
parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int,
|
||||
        help='Number of nearest neighbors to use. 20 usually works best.')
|
||||
parser.add_argument('--temperature', default=0.07, type=float,
|
||||
help='Temperature used in the voting coefficient')
|
||||
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
||||
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
|
||||
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
|
||||
parser.add_argument('--arch', default='deit_small', type=str,
|
||||
choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
|
||||
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
||||
parser.add_argument("--checkpoint_key", default="teacher", type=str,
|
||||
help='Key to use in the checkpoint (example: "teacher")')
|
||||
parser.add_argument('--dump_features', default=None,
|
||||
help='Path where to save computed features, empty for no saving')
|
||||
parser.add_argument('--load_features', default=None, help="""If the features have
|
||||
already been computed, where to find them.""")
|
||||
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
||||
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
||||
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
||||
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
||||
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
utils.init_distributed_mode(args)
|
||||
print("git:\n {}\n".format(utils.get_sha()))
|
||||
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
||||
cudnn.benchmark = True
|
||||
|
||||
if args.load_features:
|
||||
train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth"))
|
||||
test_features = torch.load(os.path.join(args.load_features, "testfeat.pth"))
|
||||
train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth"))
|
||||
test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth"))
|
||||
else:
|
||||
# need to extract features !
|
||||
train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
|
||||
|
||||
if utils.get_rank() == 0:
|
||||
if args.use_cuda:
|
||||
train_features = train_features.cuda()
|
||||
test_features = test_features.cuda()
|
||||
train_labels = train_labels.cuda()
|
||||
test_labels = test_labels.cuda()
|
||||
|
||||
print("Features are ready!\nStart the k-NN classification.")
|
||||
for k in args.nb_knn:
|
||||
top1, top5 = knn_classifier(train_features, train_labels,
|
||||
test_features, test_labels, k, args.temperature)
|
||||
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
|
||||
dist.barrier()
|
|
eval_linear.py
@@ -0,0 +1,221 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torchvision import datasets
|
||||
from torchvision import transforms as pth_transforms
|
||||
|
||||
import utils
|
||||
import vision_transformer as vits
|
||||
|
||||
|
||||
def eval_linear(args):
|
||||
utils.init_distributed_mode(args)
|
||||
print("git:\n {}\n".format(utils.get_sha()))
|
||||
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
||||
cudnn.benchmark = True
|
||||
|
||||
# ============ preparing data ... ============
|
||||
train_transform = pth_transforms.Compose([
|
||||
pth_transforms.RandomResizedCrop(224),
|
||||
pth_transforms.RandomHorizontalFlip(),
|
||||
pth_transforms.ToTensor(),
|
||||
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
val_transform = pth_transforms.Compose([
|
||||
pth_transforms.Resize(256, interpolation=3),
|
||||
pth_transforms.CenterCrop(224),
|
||||
pth_transforms.ToTensor(),
|
||||
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)
|
||||
dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform)
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset_train,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
val_loader = torch.utils.data.DataLoader(
|
||||
dataset_val,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.")
|
||||
|
||||
# ============ building network ... ============
|
||||
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
|
||||
# load weights to evaluate
|
||||
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
|
||||
|
||||
linear_classifier = LinearClassifier(model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)))
|
||||
linear_classifier = linear_classifier.cuda()
|
||||
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu])
|
||||
|
||||
# set optimizer
|
||||
optimizer = torch.optim.SGD(
|
||||
linear_classifier.parameters(),
|
||||
args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
|
||||
momentum=0.9,
|
||||
weight_decay=0, # we do not apply weight decay
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0)
|
||||
|
||||
# Optionally resume from a checkpoint
|
||||
to_restore = {"epoch": 0, "best_acc": 0.}
|
||||
utils.restart_from_checkpoint(
|
||||
os.path.join(args.output_dir, "checkpoint.pth.tar"),
|
||||
run_variables=to_restore,
|
||||
state_dict=linear_classifier,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
start_epoch = to_restore["epoch"]
|
||||
best_acc = to_restore["best_acc"]
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens)
|
||||
scheduler.step()
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
'epoch': epoch}
|
||||
if epoch % args.val_freq == 0 or epoch == args.epochs - 1:
|
||||
test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
|
||||
print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
|
||||
best_acc = max(best_acc, test_stats["acc1"])
|
||||
print(f'Max accuracy so far: {best_acc:.2f}%')
|
||||
log_stats = {**{k: v for k, v in log_stats.items()},
|
||||
**{f'test_{k}': v for k, v in test_stats.items()}}
|
||||
if utils.is_main_process():
|
||||
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
save_dict = {
|
||||
"epoch": epoch + 1,
|
||||
"state_dict": linear_classifier.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"best_acc": best_acc,
|
||||
}
|
||||
torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar"))
|
||||
print("Training of the supervised linear classifier on frozen features completed.\n"
|
||||
"Top-1 test accuracy: {acc:.1f}".format(acc=max_accuracy))
|
||||
|
||||
|
||||
def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool):
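    # The backbone `model` stays frozen (its forward pass runs under torch.no_grad());
    # only the linear classifier on top of the extracted features is optimized.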
|
||||
linear_classifier.train()
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
header = 'Epoch: [{}]'.format(epoch)
|
||||
for (inp, target) in metric_logger.log_every(loader, 20, header):
|
||||
# move to gpu
|
||||
inp = inp.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# forward
|
||||
with torch.no_grad():
|
||||
output = model.forward_return_n_last_blocks(inp, n, avgpool)
|
||||
output = linear_classifier(output)
|
||||
|
||||
# compute cross entropy loss
|
||||
loss = nn.CrossEntropyLoss()(output, target)
|
||||
|
||||
# compute the gradients
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# step
|
||||
optimizer.step()
|
||||
|
||||
# log
|
||||
torch.cuda.synchronize()
|
||||
metric_logger.update(loss=loss.item())
|
||||
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
||||
# gather the stats from all processes
|
||||
metric_logger.synchronize_between_processes()
|
||||
print("Averaged stats:", metric_logger)
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_network(val_loader, model, linear_classifier, n, avgpool):
|
||||
linear_classifier.eval()
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
header = 'Test:'
|
||||
for inp, target in metric_logger.log_every(val_loader, 20, header):
|
||||
# move to gpu
|
||||
inp = inp.cuda(non_blocking=True)
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = model.forward_return_n_last_blocks(inp, n, avgpool)
|
||||
output = linear_classifier(output)
|
||||
loss = nn.CrossEntropyLoss()(output, target)
|
||||
|
||||
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
||||
|
||||
batch_size = inp.shape[0]
|
||||
metric_logger.update(loss=loss.item())
|
||||
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
||||
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
|
||||
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
|
||||
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
|
||||
class LinearClassifier(nn.Module):
|
||||
"""Linear layer to train on top of frozen features"""
|
||||
def __init__(self, dim, num_labels=1000):
|
||||
super(LinearClassifier, self).__init__()
|
||||
self.linear = nn.Linear(dim, num_labels)
|
||||
self.linear.weight.data.normal_(mean=0.0, std=0.01)
|
||||
self.linear.bias.data.zero_()
|
||||
|
||||
def forward(self, x):
|
||||
# flatten
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
# linear layer
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
|
||||
parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
|
||||
for the `n` last blocks. We use `n=4` when evaluating DeiT-Small and `n=1` with ViT-Base.""")
|
||||
parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
|
||||
help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
|
||||
We typically set this to False for DeiT-Small and to True with ViT-Base.""")
|
||||
parser.add_argument('--arch', default='deit_small', type=str,
|
||||
choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
|
||||
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
|
||||
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
|
||||
parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
|
||||
parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of
|
||||
training (highest LR used during training). The learning rate is linearly scaled
|
||||
with the batch size, and specified here for a reference batch size of 256.
|
||||
We recommend tweaking the LR depending on the checkpoint evaluated.""")
|
||||
parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size')
|
||||
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
||||
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
||||
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
||||
parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
|
||||
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
||||
parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.")
|
||||
parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints')
|
||||
args = parser.parse_args()
|
||||
eval_linear(args)
|
|
hubconf.py
@@ -0,0 +1,83 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import torch
|
||||
from torchvision.models.resnet import resnet50
|
||||
|
||||
import vision_transformer as vits
|
||||
|
||||
dependencies = ["torch", "torchvision"]
|
||||
|
||||
|
||||
def dino_deits16(pretrained=True, **kwargs):
|
||||
"""
|
||||
DeiT-Small/16x16 pre-trained with DINO.
|
||||
Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
|
||||
"""
|
||||
model = vits.__dict__["deit_small"](patch_size=16, num_classes=0, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def dino_deits8(pretrained=True, **kwargs):
|
||||
"""
|
||||
DeiT-Small/8x8 pre-trained with DINO.
|
||||
Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
|
||||
"""
|
||||
model = vits.__dict__["deit_small"](patch_size=8, num_classes=0, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def dino_vitb16(pretrained=True, **kwargs):
|
||||
"""
|
||||
ViT-Base/16x16 pre-trained with DINO.
|
||||
Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
|
||||
"""
|
||||
model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def dino_vitb8(pretrained=True, **kwargs):
|
||||
"""
|
||||
ViT-Base/8x8 pre-trained with DINO.
|
||||
Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
|
||||
"""
|
||||
model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def dino_resnet50(pretrained=True, **kwargs):
|
||||
"""
|
||||
ResNet-50 pre-trained with DINO.
|
||||
Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark.
|
||||
Note that `fc.weight` and `fc.bias` are randomly initialized.
|
||||
"""
|
||||
model = resnet50(pretrained=False, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
return model
|
|
main_dino.py
@@ -0,0 +1,455 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import time
|
||||
import math
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import datasets, transforms
|
||||
from torchvision import models as torchvision_models
|
||||
|
||||
import utils
|
||||
import vision_transformer as vits
|
||||
from vision_transformer import DINOHead
|
||||
|
||||
torchvision_archs = sorted(name for name in torchvision_models.__dict__
|
||||
if name.islower() and not name.startswith("__")
|
||||
and callable(torchvision_models.__dict__[name]))
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('DINO', add_help=False)
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--arch', default='deit_small', type=str,
|
||||
choices=['deit_tiny', 'deit_small', 'vit_base'] + torchvision_archs,
|
||||
help="""Name of architecture to train. For quick experiments with ViTs,
|
||||
we recommend using deit_tiny or deit_small.""")
|
||||
parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
|
||||
of input square patches - default 16 (for 16x16 patches). Using smaller
|
||||
values leads to better performance but requires more memory. Applies only
|
||||
for ViTs (deit_tiny, deit_small and vit_base). If <16, we recommend disabling
|
||||
mixed precision training (--use_fp16 false) to avoid instabilities.""")
|
||||
parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
|
||||
the DINO head output. For complex and large datasets large values (like 65k) work well.""")
|
||||
parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
|
||||
help="""Whether or not to weight normalize the last layer of the DINO head.
|
||||
Not normalizing leads to better performance but can make the training unstable.
|
||||
In our experiments, we typically set this parameter to False with deit_small and True with vit_base.""")
|
||||
parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
|
||||
parameter for teacher update. The value is increased to 1 during training with cosine schedule.
|
||||
We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
|
||||
parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
|
||||
help="Whether to use batch normalizations in projection head (Default: False)")
|
||||
|
||||
# Temperature teacher parameters
|
||||
parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
|
||||
help="""Initial value for the teacher temperature: 0.04 works well in most cases.
|
||||
Try decreasing it if the training loss does not decrease.""")
|
||||
parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
|
||||
of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
|
||||
starting with the default value of 0.04 and increasing it slightly if needed.""")
|
||||
parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
|
||||
help='Number of warmup epochs for the teacher temperature (Default: 0).')
|
||||
|
||||
# Training/Optimization parameters
|
||||
parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
|
||||
to use half precision for training. Reduces training time and memory requirements,
|
||||
but can cause instability and a slight drop in performance. We recommend disabling
|
||||
mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
|
||||
parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
|
||||
weight decay. With ViT, a smaller value at the beginning of training works well.""")
|
||||
parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
|
||||
weight decay. We use a cosine schedule for WD and using a larger decay by
|
||||
the end of training improves performance for ViTs.""")
|
||||
parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
|
||||
gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
|
||||
help optimization for larger ViT architectures. 0 for disabling.""")
|
||||
parser.add_argument('--batch_size_per_gpu', default=64, type=int,
|
||||
help='Per-GPU batch size: number of distinct images loaded on one GPU.')
|
||||
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
|
||||
parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
|
||||
during which we keep the output layer fixed. Typically doing so during
|
||||
the first epoch helps training. Try increasing this value if the loss does not decrease.""")
|
||||
parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of
|
||||
linear warmup (highest LR used during training). The learning rate is linearly scaled
|
||||
with the batch size, and specified here for a reference batch size of 256.""")
|
||||
parser.add_argument("--warmup_epochs", default=10, type=int,
|
||||
help="Number of epochs for the linear learning-rate warm up.")
|
||||
parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
|
||||
end of optimization. We use a cosine LR schedule with linear warmup.""")
|
||||
parser.add_argument('--optimizer', default='adamw', type=str,
|
||||
choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
|
||||
|
||||
# Multi-crop parameters
|
||||
parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
|
||||
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
||||
Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
|
||||
recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
|
||||
parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
|
||||
local views to generate. Set this parameter to 0 to disable multi-crop training.
|
||||
When disabling multi-crop, we recommend using "--global_crops_scale 0.14 1." """)
|
||||
parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
|
||||
help="""Scale range of the cropped image before resizing, relatively to the origin image.
|
||||
Used for small local view cropping of multi-crop.""")
|
||||
|
||||
# Misc
|
||||
parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str,
|
||||
help='Please specify path to the ImageNet training data.')
|
||||
parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
|
||||
parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
|
||||
parser.add_argument('--seed', default=0, type=int, help='Random seed.')
|
||||
parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
|
||||
parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
|
||||
distributed training; see https://pytorch.org/docs/stable/distributed.html""")
|
||||
parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
|
||||
return parser
|
||||
|
||||
|
||||
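# Example launch for a single node with 8 GPUs (a sketch, not from this file: it assumes a launcher
# such as torchrun / torch.distributed.launch that sets the RANK, WORLD_SIZE and LOCAL_RANK environment
# variables expected by utils.init_distributed_mode; paths are placeholders):
#   python -m torch.distributed.launch --nproc_per_node=8 main_dino.py \
#       --arch deit_small --data_path /path/to/imagenet/train/ --output_dir ./output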
def train_dino(args):
|
||||
utils.init_distributed_mode(args)
|
||||
utils.fix_random_seeds(args.seed)
|
||||
print("git:\n {}\n".format(utils.get_sha()))
|
||||
print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
||||
cudnn.benchmark = True
|
||||
|
||||
# ============ preparing data ... ============
|
||||
transform = DataAugmentationDINO(
|
||||
args.global_crops_scale,
|
||||
args.local_crops_scale,
|
||||
args.local_crops_number,
|
||||
)
|
||||
dataset = datasets.ImageFolder(args.data_path, transform=transform)
|
||||
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size_per_gpu,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
print(f"Data loaded: there are {len(dataset)} images.")
|
||||
|
||||
# ============ building student and teacher networks ... ============
|
||||
# if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
|
||||
if args.arch in vits.__dict__.keys():
|
||||
student = vits.__dict__[args.arch](
|
||||
patch_size=args.patch_size,
|
||||
drop_path_rate=0.1, # stochastic depth
|
||||
)
|
||||
teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
|
||||
student.head = DINOHead(
|
||||
student.embed_dim,
|
||||
args.out_dim,
|
||||
use_bn=args.use_bn_in_head,
|
||||
norm_last_layer=args.norm_last_layer,
|
||||
)
|
||||
teacher.head = DINOHead(teacher.embed_dim, args.out_dim, args.use_bn_in_head)
|
||||
|
||||
# otherwise, we check if the architecture is in torchvision models
|
||||
elif args.arch in torchvision_models.__dict__.keys():
|
||||
student = torchvision_models.__dict__[args.arch]()
|
||||
teacher = torchvision_models.__dict__[args.arch]()
|
||||
embed_dim = student.fc.weight.shape[1]
|
||||
student = utils.MultiCropWrapper(student, DINOHead(
|
||||
embed_dim,
|
||||
args.out_dim,
|
||||
use_bn=args.use_bn_in_head,
|
||||
norm_last_layer=args.norm_last_layer,
|
||||
))
|
||||
teacher = utils.MultiCropWrapper(
|
||||
teacher,
|
||||
DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
|
||||
)
|
||||
else:
|
||||
print(f"Unknow architecture: {args.arch}")
|
||||
|
||||
# move networks to gpu
|
||||
student, teacher = student.cuda(), teacher.cuda()
|
||||
# synchronize batch norms (if any)
|
||||
if utils.has_batchnorms(student):
|
||||
student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
|
||||
teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
|
||||
|
||||
# we need DDP wrapper to have synchro batch norms working...
|
||||
teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
|
||||
teacher_without_ddp = teacher.module
|
||||
else:
|
||||
# teacher_without_ddp and teacher are the same thing
|
||||
teacher_without_ddp = teacher
|
||||
student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
|
||||
# teacher and student start with the same weights
|
||||
teacher_without_ddp.load_state_dict(student.module.state_dict())
|
||||
# there is no backpropagation through the teacher, so no need for gradients
|
||||
for p in teacher.parameters():
|
||||
p.requires_grad = False
|
||||
print(f"Student and Teacher are built: they are both {args.arch} network.")
|
||||
|
||||
# ============ preparing loss ... ============
|
||||
dino_loss = DINOLoss(
|
||||
args.out_dim,
|
||||
args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number
|
||||
args.warmup_teacher_temp,
|
||||
args.teacher_temp,
|
||||
args.warmup_teacher_temp_epochs,
|
||||
args.epochs,
|
||||
).cuda()
|
||||
|
||||
# ============ preparing optimizer ... ============
|
||||
params_groups = utils.get_params_groups(student)
|
||||
if args.optimizer == "adamw":
|
||||
optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
|
||||
elif args.optimizer == "sgd":
|
||||
optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler
|
||||
elif args.optimizer == "lars":
|
||||
optimizer = utils.LARS(params_groups) # to use with convnet and large batches
|
||||
# for mixed precision training
|
||||
fp16_scaler = None
|
||||
if args.use_fp16:
|
||||
fp16_scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
# ============ init schedulers ... ============
|
||||
lr_schedule = utils.cosine_scheduler(
|
||||
args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
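# e.g. 2 nodes x 8 GPUs x 64 images per GPU = global batch of 1024, so the peak lr is args.lr * 1024 / 256 = 4 * args.lr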
|
||||
args.min_lr,
|
||||
args.epochs, len(data_loader),
|
||||
warmup_epochs=args.warmup_epochs,
|
||||
)
|
||||
wd_schedule = utils.cosine_scheduler(
|
||||
args.weight_decay,
|
||||
args.weight_decay_end,
|
||||
args.epochs, len(data_loader),
|
||||
)
|
||||
# momentum parameter is increased to 1. during training with a cosine schedule
|
||||
momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
|
||||
args.epochs, len(data_loader))
|
||||
print(f"Loss, optimizer and schedulers ready.")
|
||||
|
||||
# ============ optionally resume training ... ============
|
||||
to_restore = {"epoch": 0}
|
||||
utils.restart_from_checkpoint(
|
||||
os.path.join(args.output_dir, "checkpoint.pth"),
|
||||
run_variables=to_restore,
|
||||
student=student,
|
||||
teacher=teacher,
|
||||
optimizer=optimizer,
|
||||
fp16_scaler=fp16_scaler,
|
||||
dino_loss=dino_loss,
|
||||
)
|
||||
start_epoch = to_restore["epoch"]
|
||||
|
||||
start_time = time.time()
|
||||
print("Starting DINO training !")
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
data_loader.sampler.set_epoch(epoch)
|
||||
|
||||
# ============ training one epoch of DINO ... ============
|
||||
train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
|
||||
data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
|
||||
epoch, fp16_scaler, args)
|
||||
|
||||
# ============ writing logs ... ============
|
||||
save_dict = {
|
||||
'student': student.state_dict(),
|
||||
'teacher': teacher.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'epoch': epoch + 1,
|
||||
'args': args,
|
||||
'dino_loss': dino_loss.state_dict(),
|
||||
}
|
||||
if fp16_scaler is not None:
|
||||
save_dict['fp16_scaler'] = fp16_scaler.state_dict()
|
||||
utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
|
||||
if args.saveckp_freq and epoch % args.saveckp_freq == 0:
|
||||
utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
'epoch': epoch}
|
||||
if utils.is_main_process():
|
||||
with (Path(args.output_dir) / "log.txt").open("a") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
|
||||
|
||||
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
|
||||
optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
|
||||
fp16_scaler, args):
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
|
||||
for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
||||
# update weight decay and learning rate according to their schedule
|
||||
it = len(data_loader) * epoch + it # global training iteration
|
||||
for i, param_group in enumerate(optimizer.param_groups):
|
||||
param_group["lr"] = lr_schedule[it]
|
||||
if i == 0: # only the first group is regularized
|
||||
param_group["weight_decay"] = wd_schedule[it]
|
||||
|
||||
# move images to gpu
|
||||
images = [im.cuda(non_blocking=True) for im in images]
|
||||
# teacher and student forward passes + compute dino loss
|
||||
with torch.cuda.amp.autocast(fp16_scaler is not None):
|
||||
teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher
|
||||
student_output = student(images)
|
||||
loss = dino_loss(student_output, teacher_output, epoch)
|
||||
|
||||
if not math.isfinite(loss.item()):
|
||||
print("Loss is {}, stopping training".format(loss.item()), force=True)
|
||||
sys.exit(1)
|
||||
|
||||
# student update
|
||||
optimizer.zero_grad()
|
||||
param_norms = None
|
||||
if fp16_scaler is None:
|
||||
loss.backward()
|
||||
if args.clip_grad:
|
||||
param_norms = utils.clip_gradients(student, args.clip_grad)
|
||||
utils.cancel_gradients_last_layer(epoch, student,
|
||||
args.freeze_last_layer)
|
||||
optimizer.step()
|
||||
else:
|
||||
fp16_scaler.scale(loss).backward()
|
||||
if args.clip_grad:
|
||||
fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
||||
param_norms = utils.clip_gradients(student, args.clip_grad)
|
||||
utils.cancel_gradients_last_layer(epoch, student,
|
||||
args.freeze_last_layer)
|
||||
fp16_scaler.step(optimizer)
|
||||
fp16_scaler.update()
|
||||
|
||||
# EMA update for the teacher
|
||||
with torch.no_grad():
|
||||
m = momentum_schedule[it] # momentum parameter
|
||||
for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
|
||||
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
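# i.e. teacher = m * teacher + (1 - m) * student, with m ramped from args.momentum_teacher towards 1 by the cosine schedule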
|
||||
|
||||
# logging
|
||||
torch.cuda.synchronize()
|
||||
metric_logger.update(loss=loss.item())
|
||||
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
||||
metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
|
||||
# gather the stats from all processes
|
||||
metric_logger.synchronize_between_processes()
|
||||
print("Averaged stats:", metric_logger)
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
|
||||
class DINOLoss(nn.Module):
|
||||
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
|
||||
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
|
||||
center_momentum=0.9):
|
||||
super().__init__()
|
||||
self.student_temp = student_temp
|
||||
self.center_momentum = center_momentum
|
||||
self.ncrops = ncrops
|
||||
self.register_buffer("center", torch.zeros(1, out_dim))
|
||||
# we apply a warm up for the teacher temperature because
|
||||
# too high a temperature makes the training unstable at the beginning
|
||||
self.teacher_temp_schedule = np.concatenate((
|
||||
np.linspace(warmup_teacher_temp,
|
||||
teacher_temp, warmup_teacher_temp_epochs),
|
||||
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
|
||||
))
|
||||
|
||||
def forward(self, student_output, teacher_output, epoch):
|
||||
"""
|
||||
Cross-entropy between softmax outputs of the teacher and student networks.
|
||||
"""
|
||||
student_out = student_output / self.student_temp
|
||||
student_out = student_out.chunk(self.ncrops)
|
||||
|
||||
# teacher centering and sharpening
|
||||
temp = self.teacher_temp_schedule[epoch]
|
||||
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
|
||||
teacher_out = teacher_out.detach().chunk(2)
|
||||
|
||||
total_loss = 0
|
||||
n_loss_terms = 0
|
||||
for iq, q in enumerate(teacher_out):
|
||||
for v in range(len(student_out)):
|
||||
if v == iq:
|
||||
# we skip cases where student and teacher operate on the same view
|
||||
continue
|
||||
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
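# per-sample cross-entropy H(q, p_v) = - sum_d q_d * log p_v,d, where p_v is the softmax of the (temperature-scaled) student logits for view v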
|
||||
total_loss += loss.mean()
|
||||
n_loss_terms += 1
|
||||
total_loss /= n_loss_terms
|
||||
self.update_center(teacher_output)
|
||||
return total_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def update_center(self, teacher_output):
|
||||
"""
|
||||
Update center used for teacher output.
|
||||
"""
|
||||
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
||||
dist.all_reduce(batch_center)
|
||||
batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
|
||||
|
||||
# ema update
|
||||
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
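# i.e. center = center_momentum * center + (1 - center_momentum) * mean(teacher_output) over the full distributed batch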
|
||||
|
||||
|
||||
class DataAugmentationDINO(object):
|
||||
def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
|
||||
flip_and_color_jitter = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomApply(
|
||||
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
|
||||
p=0.8
|
||||
),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
])
|
||||
normalize = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
# first global crop
|
||||
self.global_transfo1 = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
|
||||
flip_and_color_jitter,
|
||||
utils.GaussianBlur(1.0),
|
||||
normalize,
|
||||
])
|
||||
# second global crop
|
||||
self.global_transfo2 = transforms.Compose([
|
||||
transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
|
||||
flip_and_color_jitter,
|
||||
utils.GaussianBlur(0.1),
|
||||
utils.Solarization(0.2),
|
||||
normalize,
|
||||
])
|
||||
# transformation for the local small crops
|
||||
self.local_crops_number = local_crops_number
|
||||
self.local_transfo = transforms.Compose([
|
||||
transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
|
||||
flip_and_color_jitter,
|
||||
utils.GaussianBlur(p=0.5),
|
||||
normalize,
|
||||
])
|
||||
|
||||
def __call__(self, image):
|
||||
crops = []
|
||||
crops.append(self.global_transfo1(image))
|
||||
crops.append(self.global_transfo2(image))
|
||||
for _ in range(self.local_crops_number):
|
||||
crops.append(self.local_transfo(image))
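# the returned list holds the 2 global 224x224 crops first, then local_crops_number local 96x96 crops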
|
||||
return crops
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
train_dino(args)
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
A script to run multinode training with submitit.
|
||||
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import main_dino
|
||||
import submitit
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()])
|
||||
parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
|
||||
parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
|
||||
parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
|
||||
|
||||
parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
|
||||
parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
|
||||
parser.add_argument('--comment', default="", type=str,
|
||||
help='Comment to pass to scheduler, e.g. priority message')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_shared_folder() -> Path:
|
||||
user = os.getenv("USER")
|
||||
if Path("/checkpoint/").is_dir():
|
||||
p = Path(f"/checkpoint/{user}/experiments")
|
||||
p.mkdir(exist_ok=True)
|
||||
return p
|
||||
raise RuntimeError("No shared folder available")
|
||||
|
||||
|
||||
def get_init_file():
|
||||
# Init file must not exist, but its parent dir must exist.
|
||||
os.makedirs(str(get_shared_folder()), exist_ok=True)
|
||||
init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
|
||||
if init_file.exists():
|
||||
os.remove(str(init_file))
|
||||
return init_file
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def __call__(self):
|
||||
import main_dino
|
||||
|
||||
self._setup_gpu_args()
|
||||
main_dino.train_dino(self.args)
|
||||
|
||||
def checkpoint(self):
|
||||
import os
|
||||
import submitit
|
||||
|
||||
self.args.dist_url = get_init_file().as_uri()
|
||||
print("Requeuing ", self.args)
|
||||
empty_trainer = type(self)(self.args)
|
||||
return submitit.helpers.DelayedSubmission(empty_trainer)
|
||||
|
||||
def _setup_gpu_args(self):
|
||||
import submitit
|
||||
from pathlib import Path
|
||||
|
||||
job_env = submitit.JobEnvironment()
|
||||
self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
|
||||
self.args.gpu = job_env.local_rank
|
||||
self.args.rank = job_env.global_rank
|
||||
self.args.world_size = job_env.num_tasks
|
||||
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.output_dir == "":
|
||||
args.output_dir = get_shared_folder() / "%j"
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
|
||||
|
||||
num_gpus_per_node = args.ngpus
|
||||
nodes = args.nodes
|
||||
timeout_min = args.timeout
|
||||
|
||||
partition = args.partition
|
||||
kwargs = {}
|
||||
if args.use_volta32:
|
||||
kwargs['slurm_constraint'] = 'volta32gb'
|
||||
if args.comment:
|
||||
kwargs['slurm_comment'] = args.comment
|
||||
|
||||
executor.update_parameters(
|
||||
mem_gb=40 * num_gpus_per_node,
|
||||
gpus_per_node=num_gpus_per_node,
|
||||
tasks_per_node=num_gpus_per_node, # one task per GPU
|
||||
cpus_per_task=10,
|
||||
nodes=nodes,
|
||||
timeout_min=timeout_min, # max is 60 * 72
|
||||
# Below are cluster dependent parameters
|
||||
slurm_partition=partition,
|
||||
slurm_signal_delay_s=120,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
executor.update_parameters(name="dino")
|
||||
|
||||
args.dist_url = get_init_file().as_uri()
|
||||
|
||||
trainer = Trainer(args)
|
||||
job = executor.submit(trainer)
|
||||
|
||||
print(f"Submitted job_id: {job.job_id}")
|
||||
print(f"Logs and checkpoints will be saved at: {args.output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,594 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions.
|
||||
|
||||
Mostly copy-paste from torchvision references or other public repos like DETR:
|
||||
https://github.com/facebookresearch/detr/blob/master/util/misc.py
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import math
|
||||
import random
|
||||
import datetime
|
||||
import subprocess
import warnings
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
from PIL import ImageFilter, ImageOps
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
"""
|
||||
Apply Gaussian Blur to the PIL image.
|
||||
"""
|
||||
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
|
||||
self.prob = p
|
||||
self.radius_min = radius_min
|
||||
self.radius_max = radius_max
|
||||
|
||||
def __call__(self, img):
|
||||
do_it = random.random() <= self.prob
|
||||
if not do_it:
|
||||
return img
|
||||
|
||||
return img.filter(
|
||||
ImageFilter.GaussianBlur(
|
||||
radius=random.uniform(self.radius_min, self.radius_max)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Solarization(object):
|
||||
"""
|
||||
Apply Solarization to the PIL image.
|
||||
"""
|
||||
def __init__(self, p):
|
||||
self.p = p
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
return ImageOps.solarize(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
|
||||
if os.path.isfile(pretrained_weights):
|
||||
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
||||
if checkpoint_key is not None and checkpoint_key in state_dict:
|
||||
print(f"Take key {checkpoint_key} in provided checkpoint dict")
|
||||
state_dict = state_dict[checkpoint_key]
|
||||
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
||||
msg = model.load_state_dict(state_dict, strict=False)
|
||||
print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
|
||||
else:
|
||||
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
|
||||
url = None
|
||||
if model_name == "deit_small" and patch_size == 16:
|
||||
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
||||
elif model_name == "deit_small" and patch_size == 8:
|
||||
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
|
||||
elif model_name == "vit_base" and patch_size == 16:
|
||||
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
||||
elif model_name == "vit_base" and patch_size == 8:
|
||||
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
||||
if url is not None:
|
||||
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
||||
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
else:
|
||||
print("There is no reference weights available for this model => We use random weights.")
|
||||
|
||||
|
||||
def clip_gradients(model, clip):
|
||||
norms = []
|
||||
for name, p in model.named_parameters():
|
||||
if p.grad is not None:
|
||||
param_norm = p.grad.data.norm(2)
|
||||
norms.append(param_norm.item())
|
||||
clip_coef = clip / (param_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
p.grad.data.mul_(clip_coef)
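# note: each parameter's gradient is rescaled independently so that its own L2 norm is at most `clip`
# (per-parameter clipping, unlike torch.nn.utils.clip_grad_norm_ which clips the global norm)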
|
||||
return norms
|
||||
|
||||
|
||||
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
||||
if epoch >= freeze_last_layer:
|
||||
return
|
||||
for n, p in model.named_parameters():
|
||||
if "last_layer" in n:
|
||||
p.grad = None
|
||||
|
||||
|
||||
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
||||
"""
|
||||
Re-start from checkpoint
|
||||
"""
|
||||
if not os.path.isfile(ckp_path):
|
||||
return
|
||||
print("Found checkpoint at {}".format(ckp_path))
|
||||
|
||||
# open checkpoint file
|
||||
checkpoint = torch.load(ckp_path, map_location="cpu")
|
||||
|
||||
# key is what to look for in the checkpoint file
|
||||
# value is the object to load
|
||||
# example: {'state_dict': model}
|
||||
for key, value in kwargs.items():
|
||||
if key in checkpoint and value is not None:
|
||||
try:
|
||||
msg = value.load_state_dict(checkpoint[key], strict=False)
|
||||
print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
|
||||
except TypeError:
|
||||
try:
|
||||
msg = value.load_state_dict(checkpoint[key])
|
||||
print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
|
||||
except ValueError:
|
||||
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
||||
else:
|
||||
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
||||
|
||||
# reload variables that are important for the run
|
||||
if run_variables is not None:
|
||||
for var_name in run_variables:
|
||||
if var_name in checkpoint:
|
||||
run_variables[var_name] = checkpoint[var_name]
|
||||
|
||||
|
||||
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
||||
warmup_schedule = np.array([])
|
||||
warmup_iters = warmup_epochs * niter_per_ep
|
||||
if warmup_epochs > 0:
|
||||
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
||||
|
||||
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
||||
schedule = np.array([final_value + 0.5 * (base_value - final_value) * (1 + \
|
||||
math.cos(math.pi * i / (len(iters)))) for i in iters])
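# after warmup, step i follows final_value + 0.5 * (base_value - final_value) * (1 + cos(pi * i / T)),
# a cosine decay from base_value to final_value over the T = epochs * niter_per_ep - warmup_iters remaining steps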
|
||||
|
||||
schedule = np.concatenate((warmup_schedule, schedule))
|
||||
assert len(schedule) == epochs * niter_per_ep
|
||||
return schedule
|
||||
|
||||
|
||||
def bool_flag(s):
|
||||
"""
|
||||
Parse boolean arguments from the command line.
|
||||
"""
|
||||
FALSY_STRINGS = {"off", "false", "0"}
|
||||
TRUTHY_STRINGS = {"on", "true", "1"}
|
||||
if s.lower() in FALSY_STRINGS:
|
||||
return False
|
||||
elif s.lower() in TRUTHY_STRINGS:
|
||||
return True
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
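# e.g. "--use_fp16 false" / "off" / "0" parses to False; "true" / "on" / "1" parses to True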
|
||||
|
||||
|
||||
def fix_random_seeds(seed=31):
|
||||
"""
|
||||
Fix random seeds.
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.6f} ({global_avg:.6f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.6f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.6f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.6f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in the master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Code is not suited for non-distributed mode. Exiting.')
|
||||
sys.exit(1)
|
||||
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=args.dist_url,
|
||||
world_size=args.world_size,
|
||||
rank=args.rank,
|
||||
)
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
dist.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
||||
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2)
|
||||
|
||||
with torch.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||
# type: (Tensor, float, float, float, float) -> Tensor
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
class LARS(torch.optim.Optimizer):
|
||||
"""
|
||||
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
||||
"""
|
||||
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
|
||||
weight_decay_filter=None, lars_adaptation_filter=None):
|
||||
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
|
||||
eta=eta, weight_decay_filter=weight_decay_filter,
|
||||
lars_adaptation_filter=lars_adaptation_filter)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
for g in self.param_groups:
|
||||
for p in g['params']:
|
||||
dp = p.grad
|
||||
|
||||
if dp is None:
|
||||
continue
|
||||
|
||||
if p.ndim != 1:
|
||||
dp = dp.add(p, alpha=g['weight_decay'])
|
||||
|
||||
if p.ndim != 1:
|
||||
param_norm = torch.norm(p)
|
||||
update_norm = torch.norm(dp)
|
||||
one = torch.ones_like(param_norm)
|
||||
q = torch.where(param_norm > 0.,
|
||||
torch.where(update_norm > 0,
|
||||
(g['eta'] * param_norm / update_norm), one), one)
|
||||
dp = dp.mul(q)
|
||||
|
||||
param_state = self.state[p]
|
||||
if 'mu' not in param_state:
|
||||
param_state['mu'] = torch.zeros_like(p)
|
||||
mu = param_state['mu']
|
||||
mu.mul_(g['momentum']).add_(dp)
|
||||
|
||||
p.add_(mu, alpha=-g['lr'])
|
||||
|
||||
|
||||
class MultiCropWrapper(nn.Module):
|
||||
"""
|
||||
Perform forward pass separately on each resolution input.
|
||||
The inputs corresponding to a single resolution are grouped together and a single
|
||||
forward pass is run on the same-resolution inputs. Hence we do as many
|
||||
forward passes as there are different resolutions used. We then
|
||||
concatenate all the output features.
|
||||
"""
|
||||
def __init__(self, backbone, head):
|
||||
super(MultiCropWrapper, self).__init__()
|
||||
backbone.fc = nn.Identity()
|
||||
self.backbone = backbone
|
||||
self.head = head
|
||||
|
||||
def forward(self, x):
|
||||
# convert to list
|
||||
if not isinstance(x, list):
|
||||
x = [x]
|
||||
idx_crops = torch.cumsum(torch.unique_consecutive(
|
||||
torch.tensor([inp.shape[-1] for inp in x]),
|
||||
return_counts=True,
|
||||
)[1], 0)
|
||||
start_idx = 0
|
||||
for end_idx in idx_crops:
|
||||
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
|
||||
if start_idx == 0:
|
||||
output = _out
|
||||
else:
|
||||
output = torch.cat((output, _out))
|
||||
start_idx = end_idx
|
||||
# Run the head forward on the concatenated features.
|
||||
return self.head(output)
|
||||
|
||||
|
||||
def get_params_groups(model):
|
||||
regularized = []
|
||||
not_regularized = []
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
# we do not regularize biases nor Norm parameters
|
||||
if name.endswith(".bias") or len(param.shape) == 1:
|
||||
not_regularized.append(param)
|
||||
else:
|
||||
regularized.append(param)
|
||||
return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
|
||||
|
||||
|
||||
def has_batchnorms(model):
|
||||
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bn_types):
|
||||
return True
|
||||
return False
|
|
@ -0,0 +1,334 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Mostly copy-paste from timm library.
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from utils import trunc_normal_
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
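# each sample's residual branch is zeroed with probability drop_prob; dividing by keep_prob keeps the expected output equal to x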
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
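# qkv has shape (3, B, num_heads, N, head_dim) after the reshape and permute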
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x, attn
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
def forward(self, x, return_attention=False):
|
||||
y, attn = self.attn(self.norm1(x))
|
||||
if return_attention:
|
||||
return attn
|
||||
x = x + self.drop_path(y)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer """
|
||||
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
|
||||
super().__init__()
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x):
|
||||
# convert to list
|
||||
if not isinstance(x, list):
|
||||
x = [x]
|
||||
# Perform forward pass separately on each resolution input.
|
||||
# The inputs corresponding to a single resolution are grouped together and a single
|
||||
# forward pass is run on the same-resolution inputs. Hence we do as many
|
||||
# forward passes as there are different resolutions used. We then
|
||||
# concatenate all the output features.
|
||||
idx_crops = torch.cumsum(torch.unique_consecutive(
|
||||
torch.tensor([inp.shape[-1] for inp in x]),
|
||||
return_counts=True,
|
||||
)[1], 0)
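# e.g. with 2 global 224x224 crops and 8 local 96x96 crops the per-crop sizes are [224, 224, 96, ..., 96],
# so unique_consecutive counts are [2, 8] and idx_crops = [2, 10]: one backbone pass per resolution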
|
||||
start_idx = 0
|
||||
for end_idx in idx_crops:
|
||||
_out = self.forward_features(torch.cat(x[start_idx: end_idx]))
|
||||
if start_idx == 0:
|
||||
output = _out
|
||||
else:
|
||||
output = torch.cat((output, _out))
|
||||
start_idx = end_idx
|
||||
# Run the head forward on the concatenated features.
|
||||
return self.head(output)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
return x[:, 0]
|
||||
|
||||
def interpolate_pos_encoding(self, x, pos_embed):
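# When the input produces a different number of patches than the positional embedding was trained for
# (e.g. a different image resolution), bicubically resize the patch position grid; the class token
# embedding is left unchanged.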
|
||||
npatch = x.shape[1] - 1
|
||||
N = pos_embed.shape[1] - 1
|
||||
if npatch == N:
|
||||
return pos_embed
|
||||
class_emb = pos_embed[:, 0]
|
||||
pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
pos_embed = nn.functional.interpolate(
|
||||
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
||||
scale_factor=math.sqrt(npatch / N),
|
||||
mode='bicubic',
|
||||
)
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
||||
|
||||
def forward_selfattention(self, x):
|
||||
B, nc, w, h = x.shape
|
||||
N = self.pos_embed.shape[1] - 1
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# interpolate patch embeddings
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_embed.patch_size
|
||||
h0 = h // self.patch_embed.patch_size
|
||||
class_pos_embed = self.pos_embed[:, 0]
|
||||
patch_pos_embed = self.pos_embed[:, 1:]
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
||||
mode='bicubic',
|
||||
)
|
||||
if w0 != patch_pos_embed.shape[-2]:
|
||||
helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
|
||||
patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
|
||||
if h0 != patch_pos_embed.shape[-1]:
|
||||
helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
|
||||
patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for i, blk in enumerate(self.blocks):
|
||||
if i < len(self.blocks) - 1:
|
||||
x = blk(x)
|
||||
else:
|
||||
return blk(x, return_attention=True)
|
||||
|
||||
def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
# we will return the [CLS] tokens from the `n` last blocks
|
||||
output = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if len(self.blocks) - i <= n:
|
||||
output.append(self.norm(x)[:, 0])
|
||||
if return_patch_avgpool:
|
||||
x = self.norm(x)
|
||||
# In addition to the [CLS] tokens from the `n` last blocks, we also return
|
||||
# the patch tokens from the last block. This is useful for linear eval.
|
||||
output.append(torch.mean(x[:, 1:], dim=1))
|
||||
return torch.cat(output, dim=-1)
|
||||
|
||||
|
||||
def deit_tiny(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def deit_small(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def vit_base(patch_size=16, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
||||
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
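The factory functions and `DINOHead` above are the pieces the training and evaluation scripts assemble: a ViT backbone producing [CLS] features, and a projection head on top. As orientation, here is a minimal sketch (not repository code) of how the two could be wired together; the module name `vision_transformer`, the 65536 output dimension, and the 224x224 input are illustrative assumptions:

```python
import torch
import vision_transformer as vits  # assumes this file is importable as vision_transformer.py

# ViT-Small/16 backbone, features only (num_classes=0)
backbone = vits.deit_small(patch_size=16, num_classes=0)
# projection head: 384 is the embed_dim of deit_small; out_dim is illustrative
head = vits.DINOHead(in_dim=384, out_dim=65536, norm_last_layer=True)

x = torch.randn(2, 3, 224, 224)                        # dummy batch of two RGB images
feats = backbone.forward_return_n_last_blocks(x, n=1)  # [CLS] token of the last block: (2, 384)
out = head(feats)                                      # (2, 65536)
print(feats.shape, out.shape)
```

Note that the head L2-normalizes the bottleneck features before the weight-normalized last layer, which is why `norm_last_layer` only toggles whether the last layer's weight scale is learned.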
@ -0,0 +1,193 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import sys
import argparse
import cv2
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits

def apply_mask(image, mask, color, alpha=0.5):
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image


def random_colors(N, bright=True):
    """
    Generate random colors.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors

def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    fig = plt.figure(figsize=figsize, frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis('off')
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        if blur:
            _mask = cv2.blur(_mask, (10, 10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect='auto')
    fig.savefig(fname)
    print(f"{fname} saved.")
    return

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Visualize Self-Attention maps')
    parser.add_argument('--arch', default='deit_small', type=str,
        choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
    parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str,
        help="Path to pretrained weights to load.")
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
    parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
    parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
        obtained by thresholding the self-attention maps to keep xx% of the mass.""")
    args = parser.parse_args()

    # build model
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.cuda()
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        url = None
        if args.arch == "deit_small" and args.patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif args.arch == "deit_small" and args.patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in our paper
        elif args.arch == "vit_base" and args.patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif args.arch == "vit_base" and args.patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
        else:
            print("There are no reference weights available for this model => We use random weights.")

    # open image
    if args.image_path is None:
        # user has not specified any image - we use our own image
        print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
        print("Since no image path has been provided, we take the first image in our paper.")
        response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
        img = Image.open(BytesIO(response.content))
        img = img.convert('RGB')
    elif os.path.isfile(args.image_path):
        with open(args.image_path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
    else:
        print(f"Provided image path {args.image_path} is not valid.")
        sys.exit(1)
    transform = pth_transforms.Compose([
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img)

    # make the image divisible by the patch size
    w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // args.patch_size
    h_featmap = img.shape[-1] // args.patch_size

    attentions = model.forward_selfattention(img.cuda())

    nh = attentions.shape[1]  # number of heads

    # we keep only the output patch attention
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    # we keep only a certain percentage of the mass
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - args.threshold)
    idx2 = torch.argsort(idx)
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
    # interpolate
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    # save attentions heatmaps
    os.makedirs(args.output_dir, exist_ok=True)
    torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
    for j in range(nh):
        fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
        plt.imsave(fname=fname, arr=attentions[j], format='png')
        print(f"{fname} saved.")

    image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
    for j in range(nh):
        display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) + ".png"), blur=False)
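The `--threshold` flag above controls how much of each head's attention mass is kept when building the binary masks passed to `display_instances`: per head, the smallest set of patches whose attention sums to the requested fraction is kept. A self-contained toy sketch of that sort-and-cumulate step (dummy shapes and values, not repository code):

```python
import torch

threshold = 0.6             # keep the patches carrying 60% of the attention mass
attn = torch.rand(6, 3600)  # dummy (num_heads, num_patches) CLS-to-patch attention

val, idx = torch.sort(attn)                 # ascending sort per head
val = val / val.sum(dim=1, keepdim=True)    # normalize so each head sums to 1
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - threshold)          # True for the largest values summing to ~threshold
idx2 = torch.argsort(idx)                   # inverse permutation of the sort
for head in range(th_attn.shape[0]):
    th_attn[head] = th_attn[head][idx2[head]]  # scatter back to the original patch order
print(th_attn.float().mean(dim=1))          # fraction of patches kept per head
```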