add MoCo V2 (#2757)
parent 5d06a88a36
commit 122f7f9782
@ -0,0 +1,183 @@
|
|||
# rfc_task-137: Design Document for Implementing the MoCo-v2 Model in PaddleClas

| Model name | MoCo-v2 |
|---|---|
| Paper | https://arxiv.org/pdf/2003.04297.pdf |
| Reference projects | https://github.com/PaddlePaddle/PASSL https://github.com/facebookresearch/MoCo |
| Author | 张乐 |
| Submission date | 2022-03-11 |
| PaddlePaddle version | PaddlePaddle 2.4.1 |
| File name | rfc_task_137_model_MoCo-v2.md |

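# MoCo-v2 Model: PaddleClas Implementation Design Document

## 1. Overview

MoCo-v2[<sup>2</sup>](#moco-v2) builds on MoCo with three changes: additional data augmentation, a multi-layer MLP in place of the single fc layer, and a cosine learning-rate decay schedule. This document therefore focuses on the MoCo model itself.

MoCo[<sup>1</sup>](#moco-v1) is a self-supervised contrastive learning framework that learns good image representations from large-scale image datasets. Its pre-trained models can be plugged into many downstream vision tasks, such as image classification, object detection, and segmentation.

**MoCo framework overview**
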
**Forward pass**

We walk through the MoCo framework by following the forward pass of an input minibatch $minibatchImgs=\{I_1,I_2,\dots,I_N\}$. First, two transforms $view_1$ and $view_2$ are applied to each $I_n$:

$$I^{view1}_n=view_1(I_n)$$

$$I^{view2}_n=view_2(I_n)$$

Here $view_1$ and $view_2$ each denote a sequence of image preprocessing transforms (random cropping, grayscale conversion, mean normalization, and so on; see the paper's source code for details), and $N$ is the minibatch size. Each input image $I_n$ thus yields two transformed images, $I^{view1}_n$ and $I^{view2}_n$.

Next, $I^{view1}_n$ and $I^{view2}_n$ are fed to two separate encoders:

$$q_n=L2_{normalization}(Encoder_1(I^{view1}_n))$$

$$k_n=L2_{normalization}(Encoder_2(I^{view2}_n))$$

Both $q_n$ and $k_n$ have feature dimension $K$. $Encoder_1$ and $Encoder_2$ each consist of a ResNet50 backbone followed by an MLP.

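Contrastive learning needs both positive and negative samples. The authors treat the two views of each input image as a positive pair. Negatives come from a very large **dynamic** dictionary $Dict$ of size $K\times C$ (feature dimension $K$, $C$ entries). $Loss_+$ is computed by taking the dot product of each pair in the positive sets $q_+=\{q_1,q_2,\dots,q_N\}$ and $k_+=\{k_1,k_2,\dots,k_N\}$:

$$Loss_+=\{l^{1}_+;l^{2}_+; \dots;l^{N}_+\}=\{ q_1\cdot k_1; q_2\cdot k_2;\dots; q_N\cdot k_N \}, \quad Loss_+\in \mathbb{R}^{N \times 1}$$

$Loss_-$ is computed as:

$$l^{n,c}_-=q_n \cdot Dict_{:,c}, \quad c=1,\dots,C, \quad Loss_-\in \mathbb{R}^{N \times C}$$

The final loss is:

$$Loss=concat(Loss_+, Loss_-)\in \mathbb{R}^{N \times (1+C)}$$

The dictionary $Dict$ can be viewed as a latent feature space throughout representation learning, and the authors found that the larger the dictionary, the better the learned visual representations. After each forward pass, the keys $k_n$ of the current minibatch are **enqueued** into $Dict$ and the oldest minibatch is **dequeued**.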
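
The enqueue/dequeue behaviour described above can be sketched as a fixed-size ring buffer. This is a minimal pure-Python illustration rather than the PaddleClas implementation: the class name `KeyQueue` and its fields are hypothetical, and keys are plain lists instead of tensors.

```python
class KeyQueue:
    """Fixed-size ring buffer holding the most recent keys (illustrative only)."""

    def __init__(self, size):
        self.size = size            # C: number of dictionary entries
        self.slots = [None] * size  # one key vector per slot
        self.ptr = 0                # next slot to overwrite (the oldest entry)

    def enqueue(self, keys):
        """Overwrite the oldest slots with the keys of the current minibatch."""
        # Assumption for simplicity: the queue size is a multiple of the batch size.
        assert self.size % len(keys) == 0
        for i, key in enumerate(keys):
            self.slots[self.ptr + i] = key
        self.ptr = (self.ptr + len(keys)) % self.size


queue = KeyQueue(size=4)
queue.enqueue([[1.0, 0.0], [0.0, 1.0]])    # batch 1 fills slots 0-1
queue.enqueue([[0.6, 0.8], [0.8, 0.6]])    # batch 2 fills slots 2-3
queue.enqueue([[-1.0, 0.0], [0.0, -1.0]])  # batch 3 replaces batch 1
```

Because the pointer wraps around, enqueueing a new minibatch and dequeueing the oldest one are the same write.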
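
The learning objective is the cross-entropy loss:

$$Loss_{crossentropy}=-\log \frac{\exp(l_+/ \tau)}{ \sum_{n} \exp(l_n / \tau)}$$

where the sum runs over all $1+C$ logits of a sample (the positive one and the $C$ negatives), and the hyperparameter $\tau$ is set to 0.07.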
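
As a rough illustration of this objective, the sketch below computes the loss for a single query in pure Python. The function name `info_nce_loss` and the toy vectors are assumptions made for the example, and a real implementation would operate on batched tensors.

```python
import math


def info_nce_loss(q, k, negatives, tau=0.07):
    """Cross-entropy over [positive, negatives] logits with the positive at index 0."""

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    logits = [dot(q, k)] + [dot(q, d) for d in negatives]
    scaled = [value / tau for value in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    top = max(scaled)
    log_sum = top + math.log(sum(math.exp(s - top) for s in scaled))
    return log_sum - scaled[0]


# Unit-length toy vectors: in the first case the key matches the query,
# in the second case one of the negatives matches it instead.
q = [1.0, 0.0]
loss_good = info_nce_loss(q, k=[1.0, 0.0], negatives=[[0.0, 1.0], [-1.0, 0.0]])
loss_bad = info_nce_loss(q, k=[0.0, 1.0], negatives=[[1.0, 0.0], [-1.0, 0.0]])
```

The loss is small when the positive logit dominates and large when a negative does, which is what pushes $q_n$ towards its own key and away from the dictionary entries.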

**Backward pass**

During backpropagation, gradients are used only to update the parameters $Param_{Encoder_1}$ of $Encoder_1$. To keep the representations stored in the dynamic dictionary $Dict$ consistent, the parameters $Param_{Encoder_2}$ of $Encoder_2$ are updated as:

$$Param_{Encoder_2}=m \cdot Param_{Encoder_2} + ( 1- m ) \cdot Param_{Encoder_1} $$

where the hyperparameter $m$ is set to 0.999.
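
The momentum update can be sketched on plain Python floats as follows. The helper name `momentum_update` and the toy parameter values are hypothetical. The point of the example is only to show how slowly the key encoder tracks the query encoder when $m$ is close to 1.

```python
def momentum_update(params_k, params_q, m=0.999):
    """Return the key-encoder parameters after one momentum (EMA) step."""
    return [m * pk + (1.0 - m) * pq for pk, pq in zip(params_k, params_q)]


params_q = [1.0, -2.0]  # query-encoder parameters (updated by gradients)
params_k = [0.0, 0.0]   # key-encoder parameters (never receive gradients)
for _ in range(1000):
    params_k = momentum_update(params_k, params_q)
```

With the query parameters held fixed, the key parameters cover a fraction $1-m^{t}$ of the gap after $t$ steps, so even after 1000 steps they are only about 63% of the way there.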

## 2. Design Approach and Implementation Plan

### Model backbone (already implemented in PaddleClas)

- ResNet50 backbone (with the final fully connected layer removed)
- The MLP consists of two fully connected layers, FC1 ($2048 \times 2048$) and FC2 ($2048 \times 128$)
- Dynamic dictionary size: $65536$

### optimizer

- SGD: stochastic gradient descent optimizer
- Initial learning rate: $0.03$
- Weight decay: $1e-4$
- Momentum of SGD: $0.9$

### Training strategy (already implemented in PaddleClas)

- batch-size: 256
- Single machine with 8 V100 GPUs
- Shuffle BN on each GPU
- Total epochs: $200$
- Step lr schedule (MoCo v1): at $epoch=[120, 160]$, $lr=lr \times 0.1$
- Learning-rate decay strategy (MoCo-v2): $cosine$
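
The two schedules listed above can be compared with a short sketch. It assumes per-epoch updates, no warmup, and a minimum learning rate of 0. The function names are hypothetical and do not correspond to PaddleClas scheduler classes.

```python
import math


def cosine_lr(epoch, base_lr=0.03, total_epochs=200):
    """Cosine-annealed learning rate at the given epoch (no warmup, floor of 0)."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


def step_lr(epoch, base_lr=0.03, milestones=(120, 160), gamma=0.1):
    """Step schedule: multiply by gamma at each milestone that has been reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


schedule = [(e, cosine_lr(e), step_lr(e)) for e in (0, 100, 120, 160, 200)]
```

The cosine schedule decays smoothly from the initial learning rate to zero over the 200 epochs, while the step schedule stays flat and drops by a factor of 10 at each milestone.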

### metric (already implemented in PaddleClas)

- top1
- top5

### dataset

- Dataset: ImageNet
- Data augmentation (basic transforms already implemented in PaddleClas)

```Python
# PyTorch code
import torchvision.transforms as transforms
import moco.loader  # from https://github.com/facebookresearch/MoCo

# ImageNet channel statistics
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8  # not strengthened
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
```

- Paired random image transforms and Gaussian blur (**basic transforms already implemented in PASSL; they need to be ported to PaddleClas**)

```python
# PyTorch code
import random

from PIL import ImageFilter


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Each call draws fresh random parameters, so q and k differ.
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[0.1, 2.0]):
        self.sigma = sigma

    def __call__(self, x):
        # x is a PIL image; the blur radius is sampled uniformly from the sigma range.
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
```

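### Comparison of the PASSL and PaddleClas frameworks

- The per-layer parameter names of the base ResNet50 model differ between the two projects, so weights trained with PASSL must be converted before PaddleClas can use them
- PASSL uses a Register class to tie together the model's architecture, backbone, neck, head, dataset, optimizer, and hook functions, so that the whole training process can be driven from the command line with a single yaml file. PaddleClas works similarly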
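
Converting weights between the two projects amounts to renaming keys in the saved state dict. The sketch below shows the idea with a prefix mapping. The parameter names on both sides are invented for illustration, and the real mapping has to be read off the two model definitions.

```python
def rename_keys(state_dict, prefix_map):
    """Return a copy of state_dict with the first matching key prefix rewritten."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        renamed[new_key] = value
    return renamed


# Hypothetical names on both sides; strings stand in for the weight tensors.
passl_weights = {"encoder_q.0.conv1.weight": "w0", "encoder_q.1.mlp.0.weight": "w1"}
prefix_map = {"encoder_q.0.conv1.": "backbone.stem.0.conv."}
converted = rename_keys(passl_weights, prefix_map)
```

Keys that match no prefix are copied unchanged, which makes it easy to spot parameters the mapping does not yet cover.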

### Detailed design

1. Add the backbone, neck, and head configuration of the MoCo model to model_zoo. The backbone is the ResNet50 network defined with theseuslayer;
2. The MoCo.yaml format follows the PaddleClas project;
3. Add a GaussianBlur class in ppcls.data.preprocess.ops.operators.py;
4. Refactor the `__init__` and `__getitem__` methods of the ImageNetDataset class. The original ImageNetDataset can only return (img, label). Add an option to return (sample_1, sample_2, label), where sample_1 and sample_2 are obtained by passing img through view_trans1 and view_trans2 respectively;
5. In train.py
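
The two-view behaviour described in step 4 can be sketched without any framework dependency. `TwoViewDataset`, its constructor arguments, and the toy `jitter` op are hypothetical stand-ins for the dataset class and transform ops. Images are lists of numbers so that the example runs on its own.

```python
import random


class TwoViewDataset:
    """Hypothetical dataset wrapper returning two augmented views per image."""

    def __init__(self, samples, base_ops, view_ops1, view_ops2,
                 return_two_sample=True):
        self.samples = samples  # list of (img, label) pairs
        self.base_ops = base_ops
        self.view_ops1 = view_ops1
        self.view_ops2 = view_ops2
        self.return_two_sample = return_two_sample

    @staticmethod
    def _apply(img, ops):
        for op in ops:
            img = op(img)
        return img

    def __getitem__(self, idx):
        img, label = self.samples[idx]
        if not self.return_two_sample:
            return self._apply(img, self.base_ops), label
        # Each view runs the shared ops independently, then its own view ops.
        sample_1 = self._apply(self._apply(img, self.base_ops), self.view_ops1)
        sample_2 = self._apply(self._apply(img, self.base_ops), self.view_ops2)
        return sample_1, sample_2, label

    def __len__(self):
        return len(self.samples)


def jitter(img):
    """Toy random op standing in for crop / colour jitter / blur."""
    return [x + random.uniform(-0.1, 0.1) for x in img]


random.seed(0)
dataset = TwoViewDataset(
    [([1.0, 2.0], 7)], base_ops=[list], view_ops1=[jitter], view_ops2=[jitter])
sample_1, sample_2, label = dataset[0]
```

Keeping the `(img, label)` path behind a flag means existing configs that expect a single sample keep working.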
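
## 3. Testing Methods for Functional Modules

| Module | Test method |
|---|---|
| Forward alignment | Given the same input, check whether the output of the PaddleClas model matches that of the official PyTorch version |
| Backward alignment | Given the same input, check the parameter updates from the backward pass and compare whether the PaddleClas implementation and the official PyTorch version update parameters identically |
| Image preprocessing | Write the Paddle version by following the official implementation |
| Hyperparameter configuration | Keep consistent with the official implementation |
| Training environment | Preferably also 8 V100 GPUs, using single-machine multi-GPU distributed training as in the official setup |
| Accuracy alignment | After pre-training and fine-tuning on the provided small dataset, reach the same accuracy as the original PASSL model |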
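
A forward-alignment check of the kind listed in the table usually reduces to comparing two flattened outputs within a tolerance. The helper names, the default tolerance, and the numbers below are all made up for illustration. No real model outputs are shown.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equally sized sequences."""
    assert len(a) == len(b), "outputs must have the same number of elements"
    return max(abs(x - y) for x, y in zip(a, b))


def is_aligned(a, b, atol=1e-5):
    """True when every element of a is within atol of the matching element of b."""
    return max_abs_diff(a, b) <= atol


# Made-up numbers standing in for flattened outputs of the two implementations.
paddle_out = [0.12345, -1.50002, 3.0]
torch_out = [0.12346, -1.50001, 3.0]
aligned = is_aligned(paddle_out, torch_out, atol=1e-4)
```

The same comparison can be applied to parameters after one optimizer step to check backward alignment.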

## 4. Feasibility Analysis and Schedule

| Dates | Milestone | Duration |
|---|---|---|
| 03.11-03.19 | Get familiar with the tooling; forward alignment | 9 days |
| 03.20-04.02 | Backward alignment | 14 days |
| 04.03-04.16 | Training alignment | 14 days |
| 04.16-04.29 | Code merge | 14 days |

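## 5. Risks and Impact

Risks:

- A trained MoCo model is generally used as an image feature extractor, so there is no inference process in the usual sense
- **All models and algorithms in PaddleClas must pass the PaddlePaddle training-inference integration certification; a newly added model currently only needs to pass the basic training and inference certification**. This conflicts with how MoCo is trained and used, so can explicit certification requirements be specified for MoCo-v2?
- The task title is MoCo-v2. Should the MoCo-v1 code also be included when merging (the original PASSL project implements it)?
- PASSL also has a MoCo-Clas classification model. Should this module be included when merging (the original PASSL project implements it)?
- Parts of train.py may need to be modified

Impact:

The data Dataloader, data augmentation, and model are all newly added scripts and do not affect other modules
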
# Glossary

MoCo (Momentum Contrast)

# Attachments and References

<div id="moco-v1"></div>

[1] He K, Fan H, Wu Y, et al. Momentum contrast for unsupervised visual representation learning[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020: 9729-9738.

<div id="moco-v2"></div>

[2] Chen X, Fan H, Girshick R, et al. Improved baselines with momentum contrastive learning[J]. arXiv preprint arXiv:2003.04297, 2020.
|
|
@ -75,7 +75,8 @@ from .model_zoo.foundation_vit import CLIP_vit_base_patch32_224, CLIP_vit_base_p
|
|||
from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .model_zoo.moco import MoCo_V1, MoCo_V2
from .model_zoo.moco_finetune import MoCo_finetune
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
from .variant_models.resnet_variant import ResNet50_metabin
|
||||
|
|
|
@ -346,7 +346,7 @@ class ResNet(TheseusLayer):
|
|||
[32, 32, 3, 1], [32, 64, 3, 1]]
|
||||
}
|
||||
|
||||
self.stem = nn.Sequential(* [
|
||||
self.stem = nn.Sequential(*[
|
||||
ConvBNLayer(
|
||||
num_channels=in_c,
|
||||
num_filters=out_c,
|
||||
|
@ -396,7 +396,7 @@ class ResNet(TheseusLayer):
|
|||
|
||||
self.data_format = data_format
|
||||
|
||||
super().init_res(
|
||||
super().init_net(
|
||||
stages_pattern,
|
||||
return_patterns=return_patterns,
|
||||
return_stages=return_stages)
|
||||
|
|
|
@ -0,0 +1,354 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/abs/1911.05722

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
from ppcls.utils.initializer import kaiming_normal_, constant_, normal_
from ..legendary_models import *
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {"MoCo_V1": "UNKNOWN", "MoCo_V2": "UNKNOWN"}

__all__ = list(MODEL_URLS.keys())


class LinearNeck(nn.Layer):
    """Linear neck: fc only.
    """

    def __init__(self, in_channels, out_channels, with_avg_pool=False):
        super(LinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(in_channels, out_channels)

    def forward(self, x):

        if self.with_avg_pool:
            x = self.avgpool(x)
        return self.fc(x.reshape([x.shape[0], -1]))


class NonLinearNeck(nn.Layer):
    """The non-linear neck in MoCo v2: fc-relu-fc.
    """

    def __init__(self,
                 in_channels,
                 hid_channels,
                 out_channels,
                 with_avg_pool=False):
        super(NonLinearNeck, self).__init__()
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))

        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels),
            nn.ReLU(), nn.Linear(hid_channels, out_channels))

    def forward(self, x):

        if self.with_avg_pool:
            x = self.avgpool(x)
        return self.mlp(x.reshape([x.shape[0], -1]))


class ContrastiveHead(nn.Layer):
    """Head for contrastive learning.

    Args:
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Default: 0.1.
    """

    def __init__(self, temperature=0.1):
        super(ContrastiveHead, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, pos, neg):
        """Forward head.

        Args:
            pos (Tensor): Nx1 positive similarity.
            neg (Tensor): Nxk negative similarity.

        Returns:
            tuple[Tensor, Tensor]: Nx(1+k) logits divided by the temperature,
                and Nx1 int64 labels that are all zero (the positive is at index 0).
        """
        N = pos.shape[0]
        logits = paddle.concat((pos, neg), axis=1)
        logits /= self.temperature
        labels = paddle.zeros((N, 1), dtype='int64')

        return logits, labels


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


class MoCo(nn.Layer):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self,
                 backbone_config,
                 neck_config,
                 head_config,
                 dim=128,
                 K=65536,
                 m=0.999,
                 T=0.07):
        """
        Initialize a `MoCoV1` or `MoCoV2` model depending on the arguments.
        Args:
            backbone_config (dict): config of backbone(eg: ResNet50).
            neck_config (dict): config of neck(eg: MLP or FC)
            head_config (dict): config of head
            dim (int): feature dimension. Default: 128.
            K (int): queue size; number of negative keys. Default: 65536.
            m (float): moco momentum of updating key encoder. Default: 0.999.
            T (float): softmax temperature. Default: 0.07.
        """
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T

        backbone_type = backbone_config.pop('name')
        backbone = eval(backbone_type)

        neck_type = neck_config.pop('name')
        neck = eval(neck_type)

        head_type = head_config.pop('name')
        head = eval(head_type)

        backbone_1 = backbone()
        backbone_1.stop_after(stop_layer_name='avg_pool')
        backbone_2 = backbone()
        backbone_2.stop_after(stop_layer_name='avg_pool')

        self.encoder_q = nn.Sequential(backbone_1, neck(**neck_config))
        self.encoder_k = nn.Sequential(backbone_2, neck(**neck_config))

        self.backbone = self.encoder_q[0]

        self.head = head(**head_config)

        # initialize function by kaiming
        self.init_parameters()

        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.set_value(param_q)  # moco initialize
            param_k.stop_gradient = True  # not update by gradient

        # frozen bn normal
        freeze_batchnorm_statictis(self.encoder_k)

        # create the queue
        self.register_buffer("queue", paddle.randn([dim, K]))
        self.queue = nn.functional.normalize(self.queue, axis=0)

        self.register_buffer("queue_ptr", paddle.zeros([1], 'int64'))

    def init_parameters(self, init_linear='kaiming', std=0.01, bias=0.):
        assert init_linear in ['normal', 'kaiming'], \
            "Undefined init_linear: {}".format(init_linear)
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                kaiming_normal_(m, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.layer.norm._BatchNormBase, nn.GroupNorm)):
                constant_(m, 1)
            elif isinstance(m, nn.Linear):
                if init_linear == 'normal':
                    normal_(m, std=std, bias=bias)
                else:
                    kaiming_normal_(m, mode='fan_in', nonlinearity='relu')

    @paddle.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            paddle.assign((param_k * self.m + param_q * (1. - self.m)),
                          param_k)
            param_k.stop_gradient = True

    @paddle.no_grad()
    def _dequeue_and_enqueue(self, keys):
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr[0])
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose([1, 0])
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @paddle.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = paddle.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        if paddle.distributed.get_world_size() > 1:
            paddle.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = paddle.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = paddle.distributed.get_rank()
        idx_this = idx_shuffle.reshape([num_gpus, -1])[gpu_idx]
        return paddle.index_select(x_gather, idx_this), idx_unshuffle

    @paddle.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = paddle.distributed.get_rank()
        idx_this = idx_unshuffle.reshape([num_gpus, -1])[gpu_idx]

        return paddle.index_select(x_gather, idx_this)

    def train_iter(self, inputs, **kwargs):
        img_q, img_k = inputs

        # compute query features
        q = self.encoder_q(img_q)  # queries: NxC
        q = nn.functional.normalize(q, axis=1)

        # compute key features
        with paddle.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            img_k = paddle.to_tensor(img_k)
            im_k, idx_unshuffle = self._batch_shuffle_ddp(img_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, axis=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # FIXME: Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = paddle.sum(q * k, axis=1).unsqueeze(-1)
        # negative logits: NxK
        l_neg = paddle.matmul(q, self.queue.clone().detach())

        outputs = self.head(l_pos, l_neg)
        self._dequeue_and_enqueue(k)
        # add return label

        return outputs

    def forward(self, inputs, mode='train', **kwargs):
        if mode == 'train':
            return self.train_iter(inputs, **kwargs)
        elif mode == 'test':
            # NOTE: `test_iter` is not defined in this class as shown here.
            return self.test_iter(inputs, **kwargs)
        elif mode == 'extract':
            return self.backbone(inputs)
        else:
            raise Exception("No such mode: {}".format(mode))


@paddle.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    """
    if paddle.distributed.get_world_size() < 2:
        return tensor

    tensors_gather = []
    paddle.distributed.all_gather(tensors_gather, tensor)

    output = paddle.concat(tensors_gather, axis=0)
    return output


def freeze_batchnorm_statictis(layer):
    # NOTE: as written, `freeze_bn` is defined but never applied to `layer`,
    # so calling this function has no effect.
    def freeze_bn(layer):
        if isinstance(layer, (nn.layer.norm._BatchNormBase)):
            layer._use_global_stats = True


def MoCo_V1(backbone, neck, head, pretrained=False, use_ssld=False):
    model = MoCo(
        backbone_config=backbone, neck_config=neck, head_config=head, T=0.07)
    _load_pretrained(
        pretrained, model, MODEL_URLS["MoCo_V1"], use_ssld=use_ssld)
    return model


def MoCo_V2(backbone, neck, head, pretrained=False, use_ssld=False):
    model = MoCo(
        backbone_config=backbone, neck_config=neck, head_config=head, T=0.2)
    _load_pretrained(
        pretrained, model, MODEL_URLS["MoCo_V2"], use_ssld=use_ssld)
    return model
|
|
@ -0,0 +1,139 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://arxiv.org/abs/1911.05722

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
from ....utils import logger
from ppcls.utils.initializer import normal_
from ..legendary_models import *
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

MODEL_URLS = {"MoCo_finetune": "UNKNOWN"}

__all__ = list(MODEL_URLS.keys())


class ClasHead(nn.Layer):
    """Simple classifier head.
    """

    def __init__(self, with_avg_pool=False, in_channels=2048, class_num=1000):
        super(ClasHead, self).__init__()
        self.with_avg_pool = with_avg_pool
        self.in_channels = in_channels
        self.num_classes = class_num

        if self.with_avg_pool:
            self.avg_pool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(in_channels, class_num)
        # reset_parameters(self.fc_cls)
        normal_(self.fc, mean=0.0, std=0.01, bias=0.0)

    def forward(self, x):
        if self.with_avg_pool:
            x = self.avg_pool(x)
        x = paddle.reshape(x, [-1, self.in_channels])
        x = self.fc(x)
        return x


def _load_pretrained(pretrained_config, model, use_ssld=False):
    if pretrained_config is not None:
        if pretrained_config.startswith("http"):
            load_dygraph_pretrain_from_url(model, pretrained_config)
        else:
            load_dygraph_pretrain(model, pretrained_config)


class Classification(nn.Layer):
    """
    Simple image classification.
    """

    def __init__(self, backbone, head, with_sobel=False):
        super(Classification, self).__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, inputs):
        x = self.backbone(inputs)
        x = self.head(x)
        return x


def freeze_batchnorm_statictis(layer):
    # NOTE: as written, `freeze_bn` is defined but never applied to `layer`,
    # so calling this function has no effect.
    def freeze_bn(layer):
        if isinstance(layer, nn.BatchNorm):
            layer._use_global_stats = True


def freeze_params(model):
    from ppcls.arch.backbone.legendary_models.resnet import ConvBNLayer, BottleneckBlock
    for item in ['stem', 'max_pool', 'blocks', 'avg_pool']:
        m = getattr(model, item)
        if isinstance(m, nn.Sequential):
            for item in m:
                if isinstance(item, ConvBNLayer):
                    print(item.bn)
                    freeze_batchnorm_statictis(item.bn)

                if isinstance(item, BottleneckBlock):
                    freeze_batchnorm_statictis(item.conv0.bn)
                    freeze_batchnorm_statictis(item.conv1.bn)
                    freeze_batchnorm_statictis(item.conv2.bn)
                    if hasattr(item, 'short'):
                        freeze_batchnorm_statictis(item.short.bn)

        for param in m.parameters():
            param.trainable = False


def MoCo_finetune(backbone, head, pretrained=False, use_ssld=False):
    backbone_config = backbone
    head_config = head
    backbone_name = backbone_config.pop('name')
    backbone = eval(backbone_name)(**backbone_config)

    # stop layer for backbone
    stop_layer_name = backbone_config.pop('stop_layer_name', None)
    if stop_layer_name:
        backbone.stop_after(stop_layer_name=stop_layer_name)
    # freeze specified layer before
    freeze_layer_name = backbone_config.pop('freeze_befor', None)
    if freeze_layer_name:
        ret = backbone.freeze_befor(freeze_layer_name)
        if ret:
            logger.info(
                "moco_clas backbone successfully freeze param update befor the layer: {}".
                format(freeze_layer_name))
        else:
            logger.error(
                "moco_clas backbone failurely freeze param update befor the layer: {}".
                format(freeze_layer_name))

    freeze_params(backbone)
    head_name = head_config.pop('name')
    head = eval(head_name)(**head_config)
    model = Classification(backbone=backbone, head=head)

    # load pretrain_moco_model weight
    pretrained_config = backbone_config.pop('pretrained_model')
    _load_pretrained(pretrained_config, model, use_ssld=use_ssld)
    return model
|
|
@ -0,0 +1,130 @@
|
|||
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 50
  # train_epoch_iter_two_samples
  train_mode: iter_two_samples
  eval_during_train: False
  eval_interval: 1
  epochs: 200
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False


# model architecture
Arch:
  name: MoCo_V2
  backbone:
    name: ResNet50
    stop_layer_name: AvgPool2D
  neck:
    name: NonLinearNeck
    in_channels: 2048
    hid_channels: 2048
    out_channels: 128
  head:
    name: ContrastiveHead
    temperature: 0.2

# loss function config
Loss:
  Train:
    - CELoss:
        weight: 1.0


Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 0.0001
  lr:
    name: Cosine
    learning_rate: 0.03
    T_max: 200


# data loader for train
DataLoader:
  Train:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      return_label: False
      return_two_sample: True
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandomResizedCrop:
            size: 224
            scale: [0.2, 1.]
      view_trans1:
        - RandomApply:
            transforms:
              - RawColorJitter:
                  brightness: 0.4
                  contrast: 0.4
                  saturation: 0.4
                  hue: 0.1
            p: 0.8
        - RandomGrayscale:
            p: 0.2
        - RandomApply:
            transforms:
              - GaussianBlur:
                  sigma: [0.1, 2.0]
            p: 0.5
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
      view_trans2:
        - RandomApply:
            transforms:
              - RawColorJitter:
                  brightness: 0.4
                  contrast: 0.4
                  saturation: 0.4
                  hue: 0.1
            p: 0.8
        - RandomGrayscale:
            p: 0.2
        - RandomApply:
            transforms:
              - GaussianBlur:
                  sigma: [0.1, 2.0]
            p: 0.5
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: True
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
# global configs
Global:
  checkpoints: null
  output_dir: ./output/
  device: gpu
  save_interval: 20
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  # training model under @to_static
  to_static: False


Arch:
  name: MoCo_finetune
  pretrained_model: ./pretrain/moco_v2_bs_256_epoch_200
  backbone:
    name: ResNet50
    stop_layer_name: avg_pool
    freeze_befor: avg_pool
  head:
    name: ClasHead
    class_num: 1000


# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0


Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: MultiStepDecay
    learning_rate: 30.0
    milestones: [60, 80]


DataLoader:
  Train:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      return_label: True
      return_two_sample: False
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandomResizedCrop:
            size: 224
        - RandomHorizontalFlip:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      shuffle: True
      drop_last: False
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: MoCoImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      return_label: True
      return_two_sample: False
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - Resize:
            size: 256
        - CenterCrop:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      shuffle: True
      drop_last: True
    loader:
      num_workers: 4
      use_shared_memory: True

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
|
|
@ -14,3 +14,4 @@ from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDat
|
|||
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
from ppcls.data.dataloader.moco_imagenet_dataset import MoCoImageNetDataset
|
||||
|
|
|
@ -72,4 +72,4 @@ class ImageNetDataset(CommonDataset):
|
|||
else:
|
||||
self.labels.append(np.int64(line[1]))
|
||||
assert os.path.exists(self.images[
|
||||
-1]), f"path {self.images[-1]} does not exist."
|
||||
-1]), f"path {self.images[-1]} does not exist."
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np
import os
from ppcls.utils import logger
from .common_dataset import CommonDataset, create_operators
from ppcls.data.preprocess import transform


class MoCoImageNetDataset(CommonDataset):
|
||||
"""MoCoImageNetDataset
|
||||
|
||||
Args:
|
||||
image_root (str): image root, path to `ILSVRC2012`
|
||||
cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
|
||||
return_label (bool, optional): whether to return the original label.
|
||||
return_two_sample (bool, optional): whether to return two views of the original image.
|
||||
transform_ops (list, optional): list of transform op(s). Defaults to None.
|
||||
delimiter (str, optional): delimiter. Defaults to None.
|
||||
relabel (bool, optional): whether to remap labels when the original labels do not start from 0 or are not contiguous. Defaults to False.
|
||||
view_trans1 (list): some transform op(s) for view1.
|
||||
view_trans2 (list): some transform op(s) for view2.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_root,
|
||||
cls_label_path,
|
||||
return_label=True,
|
||||
return_two_sample=False,
|
||||
transform_ops=None,
|
||||
delimiter=None,
|
||||
relabel=False,
|
||||
view_trans1=None,
|
||||
view_trans2=None, ):
|
||||
self.delimiter = delimiter if delimiter is not None else " "
|
||||
self.relabel = relabel
|
||||
super(MoCoImageNetDataset, self).__init__(image_root, cls_label_path,
|
||||
transform_ops)
|
||||
|
||||
self.return_label = return_label
|
||||
self.return_two_sample = return_two_sample
|
||||
|
||||
if self.return_two_sample:
|
||||
self.view_transform1 = create_operators(view_trans1)
|
||||
self.view_transform2 = create_operators(view_trans2)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
with open(self.images[idx], 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
if self.return_two_sample:
|
||||
sample1 = transform(img, self._transform_ops)
|
||||
sample2 = transform(img, self._transform_ops)
|
||||
sample1 = transform(sample1, self.view_transform1)
|
||||
sample2 = transform(sample2, self.view_transform2)
|
||||
|
||||
if self.return_label:
|
||||
return (sample1, sample2, self.labels[idx])
|
||||
else:
|
||||
return (sample1, sample2)
|
||||
|
||||
if self._transform_ops:
|
||||
img = transform(img, self._transform_ops)
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
return (img, self.labels[idx])
|
||||
|
||||
except Exception as ex:
|
||||
logger.error("Exception occurred when parse line: {} with msg: {}".
|
||||
format(self.images[idx], ex))
|
||||
rnd_idx = np.random.randint(self.__len__())
|
||||
return self.__getitem__(rnd_idx)
|
||||
|
||||
def _load_anno(self, seed=None):
|
||||
assert os.path.exists(
|
||||
self._cls_path), f"path {self._cls_path} does not exist."
|
||||
assert os.path.exists(
|
||||
self._img_root), f"path {self._img_root} does not exist."
|
||||
self.images = []
|
||||
self.labels = []
|
||||
|
||||
with open(self._cls_path) as fd:
|
||||
lines = fd.readlines()
|
||||
if self.relabel:
|
||||
label_set = set()
|
||||
for line in lines:
|
||||
line = line.strip().split(self.delimiter)
|
||||
label_set.add(np.int64(line[1]))
|
||||
label_map = {
|
||||
oldlabel: newlabel
|
||||
for newlabel, oldlabel in enumerate(label_set)
|
||||
}
|
||||
|
||||
if seed is not None:
|
||||
np.random.RandomState(seed).shuffle(lines)
|
||||
for line in lines:
|
||||
line = line.strip().split(self.delimiter)
|
||||
self.images.append(os.path.join(self._img_root, line[0]))
|
||||
if self.relabel:
|
||||
self.labels.append(label_map[np.int64(line[1])])
|
||||
else:
|
||||
self.labels.append(np.int64(line[1]))
|
||||
assert os.path.exists(self.images[
|
||||
-1]), f"path {self.images[-1]} does not exist."
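The two-view branch of `__getitem__` above can be sketched in plain Python. This is only an illustration under stated assumptions: `apply_ops`, `make_two_views`, `random_shift`, and `scale` are hypothetical stand-ins, a list of numbers plays the role of the image, and no PaddleClas operator is used. The point it shows is that the shared ops run once per view, so any randomness in them differs between the two samples.

```python
import random


def apply_ops(data, ops):
    """Apply callables in order, like the transform(data, ops) helper in the diff."""
    for op in ops:
        data = op(data)
    return data


def make_two_views(img, base_ops, view_ops1, view_ops2):
    """Hypothetical helper: build two views of one input.

    The shared base ops are applied separately for each view, so random
    ops inside them can produce different results per view.
    """
    sample1 = apply_ops(apply_ops(img, base_ops), view_ops1)
    sample2 = apply_ops(apply_ops(img, base_ops), view_ops2)
    return sample1, sample2


rng = random.Random(0)


def random_shift(pixels):
    # Toy random op standing in for a random crop or color jitter.
    offset = rng.randint(1, 3)
    return [p + offset for p in pixels]


def scale(pixels):
    # Toy deterministic op standing in for normalization.
    return [p / 255.0 for p in pixels]


toy_image = [0, 128, 255]
view1, view2 = make_two_views(toy_image, [random_shift], [scale], [scale])
print(view1, view2)
```

Because the input is passed through the base ops twice, the original toy image is left unchanged and each view gets its own random draw.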
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop, Transpose
|
||||
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
|
||||
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
|
||||
from ppcls.data.preprocess.ops.randaugment import RandomApply
|
||||
|
@ -48,6 +49,8 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
|
|||
from ppcls.data.preprocess.ops.operators import Padv2
|
||||
from ppcls.data.preprocess.ops.operators import RandomRot90
|
||||
from ppcls.data.preprocess.ops.operators import PCALighting
|
||||
from ppcls.data.preprocess.ops.operators import GaussianBlur
|
||||
|
||||
from .ops.operators import format_data
|
||||
from paddle.vision.transforms import Pad as Pad_paddle_vision
|
||||
|
||||
|
@ -58,6 +61,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
import random
|
||||
|
||||
|
||||
def transform(data, ops=[]):
|
||||
""" transform """
|
||||
for op in ops:
|
||||
|
@ -139,4 +143,4 @@ class TimmAutoAugment(RawTimmAutoAugment):
|
|||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
||||
return img
|
||||
return img
|
||||
|
|
|
@ -24,12 +24,13 @@ import math
|
|||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
|
||||
from PIL import ImageFilter, Image, ImageOps, __version__ as PILLOW_VERSION
|
||||
from paddle.vision.transforms import ColorJitter as RawColorJitter
|
||||
from paddle.vision.transforms import CenterCrop, Resize
|
||||
from paddle.vision.transforms import RandomRotation as RawRandomRotation
|
||||
from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
|
||||
from paddle.vision.transforms import functional as F
|
||||
from paddle.vision.transforms import transforms as T
|
||||
from .autoaugment import ImageNetPolicy
|
||||
from .functional import augmentations
|
||||
from ppcls.utils import logger
|
||||
|
@ -742,8 +743,8 @@ class Pad(object):
|
|||
# Process fill color for affine transforms
|
||||
major_found, minor_found = (int(v)
|
||||
for v in PILLOW_VERSION.split('.')[:2])
|
||||
major_required, minor_required = (int(v) for v in
|
||||
min_pil_version.split('.')[:2])
|
||||
major_required, minor_required = (
|
||||
int(v) for v in min_pil_version.split('.')[:2])
|
||||
if major_found < major_required or (major_found == major_required and
|
||||
minor_found < minor_required):
|
||||
if fill is None:
|
||||
|
@ -858,6 +859,25 @@ class BlurImage(object):
|
|||
return {"img": img, "blur_image": label}
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
|
||||
|
||||
def __init__(self, sigma=[.1, 2.], backend="cv2"):
|
||||
self.sigma = sigma
|
||||
self.kernel_size = 23
|
||||
self.backend = backend
|
||||
|
||||
def __call__(self, x):
|
||||
sigma = np.random.uniform(self.sigma[0], self.sigma[1])
|
||||
if self.backend == "PIL":
|
||||
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
|
||||
return x
|
||||
else:
|
||||
x = cv2.GaussianBlur(
|
||||
np.array(x), (self.kernel_size, self.kernel_size), sigma)
|
||||
return Image.fromarray(x.astype(np.uint8))
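To make the role of the fixed kernel size (23) and the uniformly sampled sigma more concrete, here is a minimal standard-library sketch of a normalized 1D Gaussian kernel. It is an assumption-labeled illustration, not the code path `cv2.GaussianBlur` actually runs: `gaussian_kernel_1d` is a hypothetical helper, and the only values taken from the diff are the kernel size and the default sigma range `[0.1, 2.0]`.

```python
import math
import random


def gaussian_kernel_1d(kernel_size, sigma):
    """Hypothetical helper: normalized 1D Gaussian weights for an odd kernel size."""
    center = (kernel_size - 1) / 2.0
    weights = [
        math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2))
        for i in range(kernel_size)
    ]
    total = sum(weights)
    return [w / total for w in weights]


# Sample sigma the same way the class above does, but with the stdlib RNG.
sigma = random.uniform(0.1, 2.0)
kernel = gaussian_kernel_1d(23, sigma)
print(f"sigma={sigma:.3f}, center weight={kernel[11]:.4f}")
```

A small sigma concentrates almost all weight on the center tap (very light blur), while a sigma near 2.0 spreads it across several neighbors, which is why sampling sigma per call varies the blur strength between views.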
|
||||
|
||||
|
||||
class RandomGrayscale(object):
|
||||
"""Randomly convert image to grayscale with a probability of p (default 0.1).
|
||||
|
||||
|
@ -878,14 +898,20 @@ class RandomGrayscale(object):
|
|||
def __call__(self, img):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be converted to grayscale.
|
||||
img (PIL.Image|np.array): Image to be converted to grayscale.
|
||||
|
||||
Returns:
|
||||
PIL Image: Randomly grayscaled image.
|
||||
"""
|
||||
num_output_channels = 1 if img.mode == 'L' else 3
|
||||
if random.random() < self.p:
|
||||
return F.to_grayscale(img, num_output_channels=num_output_channels)
|
||||
if isinstance(img, Image.Image):
|
||||
if img.mode == 'L':
|
||||
num_output_channels = 1
|
||||
|
||||
if isinstance(img, np.ndarray) or isinstance(img, Image.Image):
|
||||
num_output_channels = 3
|
||||
if random.random() < self.p:
|
||||
return F.to_grayscale(
|
||||
img, num_output_channels=num_output_channels)
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
|
|
|
@ -259,4 +259,4 @@ class RandAugmentV2(RandAugment):
|
|||
"equalize": lambda img, _: ImageOps.equalize(img),
|
||||
"invert": lambda img, _: ImageOps.invert(img),
|
||||
"cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0])
|
||||
}
|
||||
}
|
|
@ -330,6 +330,7 @@ class Engine(object):
|
|||
if self.config["Global"]["distributed"]:
|
||||
dist.init_parallel_env()
|
||||
self.model = paddle.DataParallel(self.model)
|
||||
|
||||
if self.mode == 'train' and len(self.train_loss_func.parameters(
|
||||
)) > 0:
|
||||
self.train_loss_func = paddle.DataParallel(
|
||||
|
|
|
@ -16,3 +16,4 @@ from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
|
|||
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
|
||||
from ppcls.engine.train.train_progressive import train_epoch_progressive
|
||||
from ppcls.engine.train.train_metabin import train_epoch_metabin
|
||||
from ppcls.engine.train.train_iter_two_samples import train_epoch_iter_two_samples
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import time
|
||||
import paddle
|
||||
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
|
||||
from ppcls.utils import profiler
|
||||
|
||||
|
||||
def train_epoch_iter_two_samples(engine, epoch_id, print_batch_step):
|
||||
tic = time.time()
|
||||
|
||||
if not hasattr(engine, "train_dataloader_iter"):
|
||||
engine.train_dataloader_iter = iter(engine.train_dataloader)
|
||||
|
||||
for iter_id in range(engine.iter_per_epoch):
|
||||
# fetch data batch from dataloader
|
||||
try:
|
||||
batch = next(engine.train_dataloader_iter)
|
||||
except Exception:
|
||||
engine.train_dataloader_iter = iter(engine.train_dataloader)
|
||||
batch = next(engine.train_dataloader_iter)
|
||||
|
||||
profiler.add_profiler_step(engine.config["profiler_options"])
|
||||
if iter_id == 5:
|
||||
for key in engine.time_info:
|
||||
engine.time_info[key].reset()
|
||||
engine.time_info["reader_cost"].update(time.time() - tic)
|
||||
# view_1_samples: batch[0]
|
||||
# view_2_samples: batch[1]
|
||||
batch_size = batch[0].shape[0]
|
||||
engine.global_step += 1
|
||||
|
||||
# image input
|
||||
if engine.amp:
|
||||
amp_level = engine.config["AMP"].get("level", "O1").upper()
|
||||
with paddle.amp.auto_cast(
|
||||
custom_black_list={
|
||||
"flatten_contiguous_range", "greater_than"
|
||||
},
|
||||
level=amp_level):
|
||||
logits, labels = forward(engine, batch)
|
||||
loss_dict = engine.train_loss_func(logits, labels)
|
||||
else:
|
||||
logits, labels = forward(engine, batch)
|
||||
loss_dict = engine.train_loss_func(logits, labels)
|
||||
|
||||
# loss
|
||||
loss = loss_dict["loss"] / engine.update_freq
|
||||
|
||||
# backward & step opt
|
||||
if engine.amp:
|
||||
scaled = engine.scaler.scale(loss)
|
||||
scaled.backward()
|
||||
if (iter_id + 1) % engine.update_freq == 0:
|
||||
for i in range(len(engine.optimizer)):
|
||||
engine.scaler.minimize(engine.optimizer[i], scaled)
|
||||
else:
|
||||
loss.backward()
|
||||
if (iter_id + 1) % engine.update_freq == 0:
|
||||
for i in range(len(engine.optimizer)):
|
||||
engine.optimizer[i].step()
|
||||
|
||||
if (iter_id + 1) % engine.update_freq == 0:
|
||||
# clear grad
|
||||
for i in range(len(engine.optimizer)):
|
||||
engine.optimizer[i].clear_grad()
|
||||
# step lr(by step)
|
||||
for i in range(len(engine.lr_sch)):
|
||||
if not getattr(engine.lr_sch[i], "by_epoch", False):
|
||||
engine.lr_sch[i].step()
|
||||
# update ema
|
||||
if engine.ema:
|
||||
engine.model_ema.update(engine.model)
|
||||
|
||||
# below code just for logging
|
||||
# update metric_for_logger
|
||||
update_metric(engine, logits, [labels], batch_size)
|
||||
# update_loss_for_logger
|
||||
update_loss(engine, loss_dict, batch_size)
|
||||
engine.time_info["batch_cost"].update(time.time() - tic)
|
||||
if iter_id % print_batch_step == 0:
|
||||
log_info(engine, batch_size, epoch_id, iter_id)
|
||||
tic = time.time()
|
||||
|
||||
# step lr(by epoch)
|
||||
for i in range(len(engine.lr_sch)):
|
||||
if getattr(engine.lr_sch[i], "by_epoch", False) and \
|
||||
type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
|
||||
engine.lr_sch[i].step()
|
||||
|
||||
|
||||
def forward(engine, batch):
|
||||
return engine.model(batch)
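The `update_freq` logic in the loop above amounts to gradient accumulation: the loss is divided by `update_freq` on every iteration, and the optimizer steps only when `(iter_id + 1) % update_freq == 0`. The following sketch only simulates which iterations would trigger a step; `simulate_accumulation` is a hypothetical helper and no Paddle API is involved.

```python
def simulate_accumulation(iter_per_epoch, update_freq):
    """Hypothetical helper: return the iteration ids at which a step would fire."""
    step_iters = []
    for iter_id in range(iter_per_epoch):
        # Every iteration adds loss / update_freq to the accumulated gradient;
        # the step, clear_grad and per-step LR update happen only on this condition.
        if (iter_id + 1) % update_freq == 0:
            step_iters.append(iter_id)
    return step_iters


print(simulate_accumulation(10, 4))  # [3, 7]
```

With 10 iterations and `update_freq = 4`, the last two iterations accumulate gradients without a step inside the epoch. As far as the loop shown here goes, gradients are cleared only alongside a step, so those leftovers would fold into the first step of the next epoch; choosing an `iter_per_epoch` divisible by `update_freq` avoids that.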
|