Update data augmentation and weight initialization method for MobileViTv3-v2
parent
c32e2b098a
commit
b8a1589377
|
@ -687,7 +687,14 @@ class MobileViTv3(nn.Layer):
|
||||||
|
|
||||||
def _init_weights(self, m):
|
def _init_weights(self, m):
|
||||||
if isinstance(m, nn.Conv2D):
|
if isinstance(m, nn.Conv2D):
|
||||||
|
fan_in = m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3]
|
||||||
fan_out = m.weight.shape[0] * m.weight.shape[2] * m.weight.shape[3]
|
fan_out = m.weight.shape[0] * m.weight.shape[2] * m.weight.shape[3]
|
||||||
|
if self.mobilevit_v2_based:
|
||||||
|
bound = 1.0 / fan_in**0.5
|
||||||
|
nn.initializer.Uniform(-bound, bound)(m.weight)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.initializer.Uniform(-bound, bound)(m.bias)
|
||||||
|
else:
|
||||||
nn.initializer.KaimingNormal(fan_in=fan_out)(m.weight)
|
nn.initializer.KaimingNormal(fan_in=fan_out)(m.weight)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.initializer.Constant(0)(m.bias)
|
nn.initializer.Constant(0)(m.bias)
|
||||||
|
@ -695,6 +702,9 @@ class MobileViTv3(nn.Layer):
|
||||||
nn.initializer.Constant(1)(m.weight)
|
nn.initializer.Constant(1)(m.weight)
|
||||||
nn.initializer.Constant(0)(m.bias)
|
nn.initializer.Constant(0)(m.bias)
|
||||||
elif isinstance(m, nn.Linear):
|
elif isinstance(m, nn.Linear):
|
||||||
|
if self.mobilevit_v2_based:
|
||||||
|
nn.initializer.XavierUniform()(m.weight)
|
||||||
|
else:
|
||||||
nn.initializer.TruncatedNormal(std=.02)(m.weight)
|
nn.initializer.TruncatedNormal(std=.02)(m.weight)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.initializer.Constant(0)(m.bias)
|
nn.initializer.Constant(0)(m.bias)
|
||||||
|
|
|
@ -93,15 +93,15 @@ DataLoader:
|
||||||
r1: 0.3
|
r1: 0.3
|
||||||
attempt: 10
|
attempt: 10
|
||||||
use_log_aspect: True
|
use_log_aspect: True
|
||||||
mode: pixel
|
mode: const
|
||||||
batch_transform_ops:
|
batch_transform_ops:
|
||||||
- OpSampler:
|
- OpSampler:
|
||||||
MixupOperator:
|
MixupOperator:
|
||||||
alpha: 0.2
|
alpha: 0.2
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
CutmixOperator:
|
CutmixOperator:
|
||||||
alpha: 1.0
|
alpha: 1.0
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 128
|
batch_size: 128
|
||||||
|
|
|
@ -93,15 +93,15 @@ DataLoader:
|
||||||
r1: 0.3
|
r1: 0.3
|
||||||
attempt: 10
|
attempt: 10
|
||||||
use_log_aspect: True
|
use_log_aspect: True
|
||||||
mode: pixel
|
mode: const
|
||||||
batch_transform_ops:
|
batch_transform_ops:
|
||||||
- OpSampler:
|
- OpSampler:
|
||||||
MixupOperator:
|
MixupOperator:
|
||||||
alpha: 0.2
|
alpha: 0.2
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
CutmixOperator:
|
CutmixOperator:
|
||||||
alpha: 1.0
|
alpha: 1.0
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 128
|
batch_size: 128
|
||||||
|
|
|
@ -93,15 +93,15 @@ DataLoader:
|
||||||
r1: 0.3
|
r1: 0.3
|
||||||
attempt: 10
|
attempt: 10
|
||||||
use_log_aspect: True
|
use_log_aspect: True
|
||||||
mode: pixel
|
mode: const
|
||||||
batch_transform_ops:
|
batch_transform_ops:
|
||||||
- OpSampler:
|
- OpSampler:
|
||||||
MixupOperator:
|
MixupOperator:
|
||||||
alpha: 0.2
|
alpha: 0.2
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
CutmixOperator:
|
CutmixOperator:
|
||||||
alpha: 1.0
|
alpha: 1.0
|
||||||
prob: 0.5
|
prob: 0.25
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 128
|
batch_size: 128
|
||||||
|
|
Loading…
Reference in New Issue