Add cifar100 parameter yaml

zh-hike 2022-12-15 11:44:38 +00:00 committed by Walter
parent 009f347d64
commit f073e97d22
8 changed files with 120 additions and 292 deletions

View File

@@ -70,10 +70,9 @@ class RecModel(TheseusLayer):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
self.decoup = False
if backbone_config.get('decoup', False):
self.decoup = backbone_config.pop('decoup')
self.backbone = eval(backbone_name)(**backbone_config)
self.head_feature_from = config.get('head_feature_from', 'neck')
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer)
@@ -94,19 +93,16 @@ class RecModel(TheseusLayer):
x = self.backbone(x)
out["backbone"] = x
if self.decoup:
logits_index, features_index = self.decoup['logits_index'], self.decoup['features_index']
logits, feat = x[logits_index], x[features_index]
out['logits'] = logits
out['features'] = feat
return out
if self.neck is not None:
feat = self.neck(x)
out["neck"] = feat
out["features"] = out['neck'] if self.neck else x
if self.head is not None:
y = self.head(out['features'], label)
if self.head_feature_from == 'backbone':
y = self.head(out['backbone'], label)
elif self.head_feature_from == 'neck':
y = self.head(out['features'], label)
out["logits"] = y
return out

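A minimal sketch of the new `head_feature_from` routing, with toy layers sized like the CIFAR100 yaml below (the stand-in layers are hypothetical, not the PaddleClas classes): with `head_feature_from: backbone`, the head reads the raw backbone features, while `out["features"]` keeps the neck projection for the SSL loss.

```python
import paddle
import paddle.nn as nn

# Toy stand-ins (hypothetical) sized like the new CIFAR100 yaml:
# backbone features 512-d, neck projection 64-d, 100 classes.
backbone = nn.Linear(32, 512)
neck = nn.Linear(512, 64)
head = nn.Linear(512, 100)

def rec_forward(x, head_feature_from="neck"):
    out = {"backbone": backbone(x)}
    out["neck"] = neck(out["backbone"])
    out["features"] = out["neck"]
    if head_feature_from == "backbone":
        # new yaml setting: classifier reads the 512-d backbone features,
        # while out["features"] keeps the 64-d neck projection
        out["logits"] = head(out["backbone"])
    else:
        # default 'neck' path: the head would have to be sized to the neck output
        out["logits"] = head(out["features"])
    return out

out = rec_forward(paddle.randn([4, 32]), head_feature_from="backbone")
print(out["logits"].shape, out["features"].shape)   # [4, 100] [4, 64]
```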
View File

@@ -2,6 +2,7 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ..base.theseus_layer import TheseusLayer
"""
backbone option "WideResNet"
code in this file is adapted from
@@ -123,7 +124,7 @@ class Normalize(nn.Layer):
return out
class Wide_ResNet(nn.Layer):
class Wide_ResNet(TheseusLayer):
def __init__(self,
num_classes,
depth=28,

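The switch from `nn.Layer` to `TheseusLayer` is what makes the `BackboneStopLayer: bn1` entries in the new yaml files work: `RecModel` calls `self.backbone.stop_after(...)` (first hunk above), so the backbone must support truncation at a named sublayer. A rough self-contained sketch of that idea, with a hypothetical `stop_after` rather than the actual TheseusLayer implementation:

```python
import paddle
import paddle.nn as nn

# Hypothetical sketch, not the actual TheseusLayer code: the backbone offers a
# stop_after() that cuts the forward pass at a named sublayer such as "bn1".
class TruncatableBackbone(nn.Layer):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2D(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2D(16)
        self.head = nn.Linear(16, 10)
        self._stop_at = None

    def stop_after(self, name):
        self._stop_at = name                  # remember where to cut the network

    def forward(self, x):
        x = self.stem(x)
        x = self.bn1(x)
        if self._stop_at == "bn1":            # return features instead of logits
            return x
        return self.head(x.mean(axis=[2, 3]))

net = TruncatableBackbone()
net.stop_after("bn1")
print(net(paddle.randn([2, 3, 32, 32])).shape)    # [2, 16, 32, 32]
```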
View File

@@ -20,6 +20,7 @@ from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
from .bnneck import BNNeck
from .adamargin import AdaMargin
from .frfn_neck import FRFNNeck
__all__ = ['build_gear']
@@ -27,7 +28,7 @@ __all__ = ['build_gear']
def build_gear(config):
support_dict = [
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
'BNNeck', 'AdaMargin',
'BNNeck', 'AdaMargin', 'FRFNNeck'
]
module_name = config.pop('name')
assert module_name in support_dict, Exception(

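For reference, a minimal sketch of the gear registration pattern touched here: pop `name` from the config, check it against `support_dict`, then build the class from the remaining kwargs (the same eval-style construction `RecModel` uses for its backbone). `Tanh` is used below because it is importable on its own; the newly registered `FRFNNeck` resolves the same way with its `num_features`/`low_dim` kwargs.

```python
import paddle
from paddle.nn import Tanh   # one of the entries in support_dict above

# Sketch of the dispatch: pop 'name', check it against support_dict, then
# instantiate the class from the remaining keyword arguments.
config = {"name": "Tanh"}
support_dict = [
    "ArcMargin", "CosMargin", "CircleMargin", "FC", "VehicleNeck", "Tanh",
    "BNNeck", "AdaMargin", "FRFNNeck"
]
module_name = config.pop("name")
assert module_name in support_dict
gear = eval(module_name)(**config)        # -> paddle.nn.Tanh()
print(gear(paddle.to_tensor([0.5])))      # ~0.4621
```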
View File

@@ -0,0 +1,32 @@
import paddle.nn as nn
class Normalize(nn.Layer):
""" Ln normalization copied from
https://github.com/salesforce/CoMatch
"""
def __init__(self, power=2):
super(Normalize, self).__init__()
self.power = power
def forward(self, x):
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
out = x.divide(norm)
return out
class FRFNNeck(nn.Layer):
def __init__(self, num_features, low_dim, **kwargs):
super(FRFNNeck, self).__init__()
self.l2norm = Normalize(2)
self.fc1 = nn.Linear(num_features, num_features)
self.relu_mlp = nn.LeakyReLU(negative_slope=0.1)
self.fc2 = nn.Linear(num_features, low_dim)
def forward(self, x):
x = self.fc1(x)
x = self.relu_mlp(x)
x = self.fc2(x)
x = self.l2norm(x)
return x

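A quick shape check for the new neck (this assumes the `FRFNNeck` class from the file above is in scope), sized like the CIFAR100 config below:

```python
import paddle

# FRFNNeck projects backbone features (num_features) down to low_dim and
# L2-normalizes the result, so every output row has unit norm.
neck = FRFNNeck(num_features=512, low_dim=64)
feat = paddle.randn([4, 512])             # batch of backbone features
proj = neck(feat)
print(proj.shape)                         # [4, 64]
print(proj.pow(2).sum(axis=1).sqrt())     # ~1.0 per row, thanks to the L2 norm
```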
View File

@@ -26,6 +26,7 @@ Arch:
name: RecModel
infer_output_key: logits
infer_add_softmax: false
head_feature_from: backbone
Backbone:
name: WideResNet
widen_factor: 8
@@ -33,13 +34,18 @@ Arch:
dropout: 0
num_classes: 100
low_dim: 64
proj: true
proj: false
proj_after: false
Decoup:
name: Decoup
logits_index: 0
features_index: 1
BackboneStopLayer:
name: bn1
Neck:
name: FRFNNeck
num_features: 512
low_dim: 64
Head:
name: FC
embedding_size: 512
class_num: 100
use_sync_bn: true
@@ -66,33 +72,36 @@ Optimizer:
use_nesterov: true
weight_decay: 0.001
lr:
name: 'cosine_schedule_with_warmup'
name: 'CosineFixmatch'
learning_rate: 0.03
num_warmup_steps: 0
num_training_steps: 524800
DataLoader:
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
Train:
dataset:
name: CIFAR100SSL
name: Cifar100
data_file: null
mode: 'train'
download: true
backend: 'pil'
sample_per_label: 100
expand_labels: 1
transform_ops:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: "reflect"
- ToTensor:
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- NormalizeImage:
- Normalize:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
order: hwc
sampler:
name: DistributedBatchSampler
@@ -105,46 +114,51 @@ DataLoader:
UnLabelTrain:
dataset:
name: CIFAR100SSL
name: Cifar100
data_file: null
mode: 'train'
backend: 'pil'
download: true
transform_w:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
transform_ops_weak:
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: 'reflect'
- ToTensor:
- Normalize:
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
order: hwc
transform_s1:
- RandomHorizontalFlip:
prob: 0.5
- RandomCrop:
size: 32
transform_ops_strong:
- RandFlipImage:
flip_code: 1
- Pad_paddle_vision:
padding: 4
padding_mode: 'reflect'
- RandAugmentMC:
n: 2
m: 10
- ToTensor:
- Normalize:
padding_mode: reflect
- RandCropImageV2:
size: [32, 32]
- RandAugment:
num_layers: 2
magnitude: 10
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
order: hwc
transform_s2:
- RandomResizedCrop:
size: 32
- RandomHorizontalFlip:
prob: 0.5
transform_ops_strong2:
- RandCropImageV2:
size: [32, 32]
- RandFlipImage:
flip_code: 1
- RandomApply:
transforms:
- ColorJitter:
- RawColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
@@ -152,7 +166,12 @@ DataLoader:
p: 0.8
- RandomGrayscale:
p: 0.2
- ToTensor:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
order: hwc
sampler:
name: DistributedBatchSampler
@@ -165,15 +184,17 @@ DataLoader:
Eval:
dataset:
name: CIFAR100SSL
name: Cifar100
mode: 'test'
backend: 'pil'
download: true
data_file: null
transform_ops:
- ToTensor:
- Normalize:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5071, 0.4867, 0.4408]
std: [0.2675, 0.2565, 0.2761]
order: hwc
sampler:
name: DistributedBatchSampler
batch_size: 16

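For readers unfamiliar with the renamed operators, here is an approximate `paddle.vision.transforms` equivalent of the weak pipeline above (RandFlipImage with flip_code 1, Pad_paddle_vision with reflect padding 4, RandCropImageV2 to 32x32, NormalizeImage). The actual training uses the PaddleClas operators named in the config; this is illustration only.

```python
import numpy as np
from paddle.vision import transforms as T

# Approximate equivalent of the weak-augmentation branch in the yaml above.
weak = T.Compose([
    T.RandomHorizontalFlip(0.5),                    # RandFlipImage, flip_code: 1
    T.Pad(4, padding_mode="reflect"),               # Pad_paddle_vision
    T.RandomCrop(32),                               # RandCropImageV2, size 32x32
    T.ToTensor(),                                   # HWC uint8 -> CHW float in [0, 1]
    T.Normalize(mean=[0.5071, 0.4867, 0.4408],      # NormalizeImage mean/std
                std=[0.2675, 0.2565, 0.2761]),
])

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)   # stand-in CIFAR image
print(weak(img).shape)                              # [3, 32, 32]
```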
View File

@@ -26,18 +26,26 @@ Arch:
name: RecModel
infer_output_key: logits
infer_add_softmax: false
head_feature_from: backbone
Backbone:
name: WideResNet
decoup:
logits_index: 0
features_index: 1
widen_factor: 2
depth: 28
dropout: 0
num_classes: 10
low_dim: 64
proj: true
proj: false
proj_after: false
BackboneStopLayer:
name: bn1
Neck:
name: FRFNNeck
num_features: 128
low_dim: 64
Head:
name: FC
embedding_size: 128
class_num: 10
use_sync_bn: true
@@ -103,8 +111,6 @@ DataLoader:
num_workers: 4
use_shared_memory: true
UnLabelTrain:
dataset:
name: Cifar10

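The `num_features`/`embedding_size` values in the two configs track the usual WideResNet-28 channel widths (16, 16k, 32k, 64k for `widen_factor` k); this is an assumption about the backbone rather than something shown in the diff, but it matches both files:

```python
# Assumed WideResNet channel progression (16, 16*k, 32*k, 64*k): the final
# width must match Neck.num_features and Head.embedding_size in the configs.
def wrn_final_channels(widen_factor):
    return 64 * widen_factor

assert wrn_final_channels(8) == 512   # CIFAR100 config: widen_factor 8 -> 512
assert wrn_final_channels(2) == 128   # CIFAR10 config: widen_factor 2 -> 128
```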
View File

@@ -1,227 +0,0 @@
import PIL
import PIL.ImageDraw
import random
from paddle.vision.transforms import transforms as T
from paddle.vision.transforms.transforms import ColorJitter
import numpy as np
PARAMETER_MAX = 10
def AutoContrast(img, **kwarg):
return PIL.ImageOps.autocontrast(img)
def Brightness(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Color(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Color(img).enhance(v)
def Contrast(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Cutout(img, v, max_v, bias=0):
if v == 0:
return img
v = _float_parameter(v, max_v) + bias
v = int(v * min(img.size))
return CutoutAbs(img, v)
def CutoutAbs(img, v, **kwarg):
w, h = img.size
x0 = np.random.uniform(0, w)
y0 = np.random.uniform(0, h)
x0 = int(max(0, x0 - v / 2.))
y0 = int(max(0, y0 - v / 2.))
x1 = int(min(w, x0 + v))
y1 = int(min(h, y0 + v))
xy = (x0, y0, x1, y1)
# gray
color = (127, 127, 127)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
def Equalize(img, **kwarg):
return PIL.ImageOps.equalize(img)
def Identity(img, **kwarg):
return img
def Invert(img, **kwarg):
return PIL.ImageOps.invert(img)
def Posterize(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
return PIL.ImageOps.posterize(img, v)
def Rotate(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.rotate(v)
def Sharpness(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def ShearX(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, v, 0, 0, 1, 0))
def ShearY(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, 0, v, 1, 0))
def Solarize(img, v, max_v, bias=0):
v = _int_parameter(v, max_v) + bias
return PIL.ImageOps.solarize(img, 256 - v)
def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
v = _int_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
img_np = np.array(img).astype(np.int)
img_np = img_np + v
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
def TranslateX(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
v = int(v * img.size[0])
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, v, 0, 1, 0))
def TranslateY(img, v, max_v, bias=0):
v = _float_parameter(v, max_v) + bias
if random.random() < 0.5:
v = -v
v = int(v * img.size[1])
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, 0, 0, 1, v))
def _float_parameter(v, max_v):
return float(v) * max_v / PARAMETER_MAX
def _int_parameter(v, max_v):
return int(v * max_v / PARAMETER_MAX)
def fixmatch_augment_pool():
# FixMatch paper
augs = [(AutoContrast, None, None),
(Brightness, 0.9, 0.05),
(Color, 0.9, 0.05),
(Contrast, 0.9, 0.05),
(Equalize, None, None),
(Identity, None, None),
(Posterize, 4, 4),
(Rotate, 30, 0),
(Sharpness, 0.9, 0.05),
(ShearX, 0.3, 0),
(ShearY, 0.3, 0),
(Solarize, 256, 0),
(TranslateX, 0.3, 0),
(TranslateY, 0.3, 0)]
return augs
def my_augment_pool():
# Test
augs = [(AutoContrast, None, None),
(Brightness, 1.8, 0.1),
(Color, 1.8, 0.1),
(Contrast, 1.8, 0.1),
(Cutout, 0.2, 0),
(Equalize, None, None),
(Invert, None, None),
(Posterize, 4, 4),
(Rotate, 30, 0),
(Sharpness, 1.8, 0.1),
(ShearX, 0.3, 0),
(ShearY, 0.3, 0),
(Solarize, 256, 0),
(SolarizeAdd, 110, 0),
(TranslateX, 0.45, 0),
(TranslateY, 0.45, 0)]
return augs
class RandAugmentPC(object):
def __init__(self, n, m):
assert n >= 1
assert 1 <= m <= 10
self.n = n
self.m = m
self.augment_pool = my_augment_pool()
def __call__(self, img):
ops = random.choices(self.augment_pool, k=self.n)
for op, max_v, bias in ops:
prob = np.random.uniform(0.2, 0.8)
if random.random() + prob >= 1:
img = op(img, v=self.m, max_v=max_v, bias=bias)
img = CutoutAbs(img, int(32*0.5))
return img
class RandAugmentMC(object):
def __init__(self, n, m):
assert n >= 1
assert 1 <= m <= 10
self.n = n
self.m = m
self.augment_pool = fixmatch_augment_pool()
def __call__(self, img):
ops = random.choices(self.augment_pool, k=self.n)
for op, max_v, bias in ops:
v = np.random.randint(1, self.m)
if random.random() < 0.5:
img = op(img, v=v, max_v=max_v, bias=bias)
img = CutoutAbs(img, int(32 * 0.5))
return img
class RandomApply:
def __init__(self, p, transforms):
self.p = p
ts = []
for t in transforms:
for key in t.keys():
ts.append(eval(key)(**t[key]))
self.trans = T.Compose(ts)
def __call__(self, img):
timg = self.trans(img)
return timg

View File

@@ -11,8 +11,6 @@ import paddle
def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
print(engine.model.state_dict().keys())
assert 1==0
tic = time.time()
if not hasattr(engine, 'train_dataloader_iter'):
engine.train_dataloader_iter = iter(engine.train_dataloader)