From b8a158937756c82c94488d244a76ff7559b96e9f Mon Sep 17 00:00:00 2001 From: Yang Nie Date: Fri, 17 Mar 2023 00:35:51 +0800 Subject: [PATCH] update data augment and init method for MobileViTv3-v2 --- ppcls/arch/backbone/model_zoo/mobilevit_v3.py | 18 ++++++++++++++---- .../ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml | 8 ++++---- .../MobileViTv3/MobileViTv3_x0_75.yaml | 8 ++++---- .../ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml | 8 ++++---- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/ppcls/arch/backbone/model_zoo/mobilevit_v3.py b/ppcls/arch/backbone/model_zoo/mobilevit_v3.py index 61b2d8456..d2e569511 100644 --- a/ppcls/arch/backbone/model_zoo/mobilevit_v3.py +++ b/ppcls/arch/backbone/model_zoo/mobilevit_v3.py @@ -687,15 +687,25 @@ class MobileViTv3(nn.Layer): def _init_weights(self, m): if isinstance(m, nn.Conv2D): + fan_in = m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3] fan_out = m.weight.shape[0] * m.weight.shape[2] * m.weight.shape[3] - nn.initializer.KaimingNormal(fan_in=fan_out)(m.weight) - if m.bias is not None: - nn.initializer.Constant(0)(m.bias) + if self.mobilevit_v2_based: + bound = 1.0 / fan_in**0.5 + nn.initializer.Uniform(-bound, bound)(m.weight) + if m.bias is not None: + nn.initializer.Uniform(-bound, bound)(m.bias) + else: + nn.initializer.KaimingNormal(fan_in=fan_out)(m.weight) + if m.bias is not None: + nn.initializer.Constant(0)(m.bias) elif isinstance(m, nn.BatchNorm2D): nn.initializer.Constant(1)(m.weight) nn.initializer.Constant(0)(m.bias) elif isinstance(m, nn.Linear): - nn.initializer.TruncatedNormal(std=.02)(m.weight) + if self.mobilevit_v2_based: + nn.initializer.XavierUniform()(m.weight) + else: + nn.initializer.TruncatedNormal(std=.02)(m.weight) if m.bias is not None: nn.initializer.Constant(0)(m.bias) diff --git a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml index aa9591cf3..2465210c3 100644 --- a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml +++ b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_5.yaml @@ -93,15 +93,15 @@ DataLoader: r1: 0.3 attempt: 10 use_log_aspect: True - mode: pixel + mode: const batch_transform_ops: - OpSampler: MixupOperator: alpha: 0.2 - prob: 0.5 + prob: 0.25 CutmixOperator: alpha: 1.0 - prob: 0.5 + prob: 0.25 sampler: name: DistributedBatchSampler batch_size: 128 @@ -111,7 +111,7 @@ DataLoader: num_workers: 4 use_shared_memory: True Eval: - dataset: + dataset: name: ImageNetDataset image_root: ./dataset/ILSVRC2012/ cls_label_path: ./dataset/ILSVRC2012/val_list.txt diff --git a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_75.yaml b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_75.yaml index 837c1ace9..1a97e3535 100644 --- a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_75.yaml +++ b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x0_75.yaml @@ -93,15 +93,15 @@ DataLoader: r1: 0.3 attempt: 10 use_log_aspect: True - mode: pixel + mode: const batch_transform_ops: - OpSampler: MixupOperator: alpha: 0.2 - prob: 0.5 + prob: 0.25 CutmixOperator: alpha: 1.0 - prob: 0.5 + prob: 0.25 sampler: name: DistributedBatchSampler batch_size: 128 @@ -111,7 +111,7 @@ DataLoader: num_workers: 4 use_shared_memory: True Eval: - dataset: + dataset: name: ImageNetDataset image_root: ./dataset/ILSVRC2012/ cls_label_path: ./dataset/ILSVRC2012/val_list.txt diff --git a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml index f9de6f559..c973d8796 100644 --- a/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml +++ b/ppcls/configs/ImageNet/MobileViTv3/MobileViTv3_x1_0.yaml @@ -93,15 +93,15 @@ DataLoader: r1: 0.3 attempt: 10 use_log_aspect: True - mode: pixel + mode: const batch_transform_ops: - OpSampler: MixupOperator: alpha: 0.2 - prob: 0.5 + prob: 0.25 CutmixOperator: alpha: 1.0 - prob: 0.5 + prob: 0.25 sampler: name: DistributedBatchSampler batch_size: 128 @@ -111,7 +111,7 @@ DataLoader: num_workers: 4 use_shared_memory: True Eval: - dataset: + dataset: name: ImageNetDataset image_root: ./dataset/ILSVRC2012/ cls_label_path: ./dataset/ILSVRC2012/val_list.txt