mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
增加cifar100参数yaml
This commit is contained in:
parent
009f347d64
commit
f073e97d22
@ -70,10 +70,9 @@ class RecModel(TheseusLayer):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
backbone_config = config["Backbone"]
|
backbone_config = config["Backbone"]
|
||||||
backbone_name = backbone_config.pop("name")
|
backbone_name = backbone_config.pop("name")
|
||||||
self.decoup = False
|
|
||||||
if backbone_config.get('decoup', False):
|
|
||||||
self.decoup = backbone_config.pop('decoup')
|
|
||||||
self.backbone = eval(backbone_name)(**backbone_config)
|
self.backbone = eval(backbone_name)(**backbone_config)
|
||||||
|
self.head_feature_from = config.get('head_feature_from', 'neck')
|
||||||
|
|
||||||
if "BackboneStopLayer" in config:
|
if "BackboneStopLayer" in config:
|
||||||
backbone_stop_layer = config["BackboneStopLayer"]["name"]
|
backbone_stop_layer = config["BackboneStopLayer"]["name"]
|
||||||
self.backbone.stop_after(backbone_stop_layer)
|
self.backbone.stop_after(backbone_stop_layer)
|
||||||
@ -94,19 +93,16 @@ class RecModel(TheseusLayer):
|
|||||||
x = self.backbone(x)
|
x = self.backbone(x)
|
||||||
|
|
||||||
out["backbone"] = x
|
out["backbone"] = x
|
||||||
if self.decoup:
|
|
||||||
logits_index, features_index = self.decoup['logits_index'], self.decoup['features_index']
|
|
||||||
logits, feat = x[logits_index], x[features_index]
|
|
||||||
out['logits'] = logits
|
|
||||||
out['features'] =feat
|
|
||||||
return out
|
|
||||||
|
|
||||||
if self.neck is not None:
|
if self.neck is not None:
|
||||||
feat = self.neck(x)
|
feat = self.neck(x)
|
||||||
out["neck"] = feat
|
out["neck"] = feat
|
||||||
out["features"] = out['neck'] if self.neck else x
|
out["features"] = out['neck'] if self.neck else x
|
||||||
if self.head is not None:
|
if self.head is not None:
|
||||||
y = self.head(out['features'], label)
|
if self.head_feature_from == 'backbone':
|
||||||
|
y = self.head(out['backbone'], label)
|
||||||
|
elif self.head_feature_from == 'neck':
|
||||||
|
y = self.head(out['features'], label)
|
||||||
out["logits"] = y
|
out["logits"] = y
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import paddle
|
|||||||
import paddle.nn as nn
|
import paddle.nn as nn
|
||||||
import paddle.nn.functional as F
|
import paddle.nn.functional as F
|
||||||
from paddle import ParamAttr
|
from paddle import ParamAttr
|
||||||
|
from ..base.theseus_layer import TheseusLayer
|
||||||
"""
|
"""
|
||||||
backbone option "WideResNet"
|
backbone option "WideResNet"
|
||||||
code in this file is adpated from
|
code in this file is adpated from
|
||||||
@ -123,7 +124,7 @@ class Normalize(nn.Layer):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Wide_ResNet(nn.Layer):
|
class Wide_ResNet(TheseusLayer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_classes,
|
num_classes,
|
||||||
depth=28,
|
depth=28,
|
||||||
|
@ -20,6 +20,7 @@ from .vehicle_neck import VehicleNeck
|
|||||||
from paddle.nn import Tanh
|
from paddle.nn import Tanh
|
||||||
from .bnneck import BNNeck
|
from .bnneck import BNNeck
|
||||||
from .adamargin import AdaMargin
|
from .adamargin import AdaMargin
|
||||||
|
from .frfn_neck import FRFNNeck
|
||||||
|
|
||||||
__all__ = ['build_gear']
|
__all__ = ['build_gear']
|
||||||
|
|
||||||
@ -27,7 +28,7 @@ __all__ = ['build_gear']
|
|||||||
def build_gear(config):
|
def build_gear(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
|
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
|
||||||
'BNNeck', 'AdaMargin',
|
'BNNeck', 'AdaMargin', 'FRFNNeck'
|
||||||
]
|
]
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception(
|
assert module_name in support_dict, Exception(
|
||||||
|
32
ppcls/arch/gears/frfn_neck.py
Normal file
32
ppcls/arch/gears/frfn_neck.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Normalize(nn.Layer):
|
||||||
|
""" Ln normalization copied from
|
||||||
|
https://github.com/salesforce/CoMatch
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, power=2):
|
||||||
|
super(Normalize, self).__init__()
|
||||||
|
self.power = power
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
||||||
|
out = x.divide(norm)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FRFNNeck(nn.Layer):
|
||||||
|
def __init__(self, num_features, low_dim, **kwargs):
|
||||||
|
super(FRFNNeck, self).__init__()
|
||||||
|
self.l2norm = Normalize(2)
|
||||||
|
self.fc1 = nn.Linear(num_features, num_features)
|
||||||
|
self.relu_mlp = nn.LeakyReLU(negative_slope=0.1)
|
||||||
|
self.fc2 = nn.Linear(num_features, low_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.relu_mlp(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.l2norm(x)
|
||||||
|
return x
|
@ -26,6 +26,7 @@ Arch:
|
|||||||
name: RecModel
|
name: RecModel
|
||||||
infer_output_key: logits
|
infer_output_key: logits
|
||||||
infer_add_softmax: false
|
infer_add_softmax: false
|
||||||
|
head_feature_from: backbone
|
||||||
Backbone:
|
Backbone:
|
||||||
name: WideResNet
|
name: WideResNet
|
||||||
widen_factor: 8
|
widen_factor: 8
|
||||||
@ -33,13 +34,18 @@ Arch:
|
|||||||
dropout: 0
|
dropout: 0
|
||||||
num_classes: 100
|
num_classes: 100
|
||||||
low_dim: 64
|
low_dim: 64
|
||||||
proj: true
|
proj: false
|
||||||
proj_after: false
|
proj_after: false
|
||||||
|
BackboneStopLayer:
|
||||||
Decoup:
|
name: bn1
|
||||||
name: Decoup
|
Neck:
|
||||||
logits_index: 0
|
name: FRFNNeck
|
||||||
features_index: 1
|
num_features: 512
|
||||||
|
low_dim: 64
|
||||||
|
Head:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 512
|
||||||
|
class_num: 100
|
||||||
|
|
||||||
use_sync_bn: true
|
use_sync_bn: true
|
||||||
|
|
||||||
@ -66,33 +72,36 @@ Optimizer:
|
|||||||
use_nesterov: true
|
use_nesterov: true
|
||||||
weight_decay: 0.001
|
weight_decay: 0.001
|
||||||
lr:
|
lr:
|
||||||
name: 'cosine_schedule_with_warmup'
|
name: 'CosineFixmatch'
|
||||||
learning_rate: 0.03
|
learning_rate: 0.03
|
||||||
num_warmup_steps: 0
|
num_warmup_steps: 0
|
||||||
num_training_steps: 524800
|
|
||||||
|
|
||||||
DataLoader:
|
DataLoader:
|
||||||
mean: [0.5071, 0.4867, 0.4408]
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
std: [0.2675, 0.2565, 0.2761]
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: CIFAR100SSL
|
name: Cifar100
|
||||||
data_file: null
|
data_file: null
|
||||||
mode: 'train'
|
mode: 'train'
|
||||||
download: true
|
download: true
|
||||||
|
backend: 'pil'
|
||||||
sample_per_label: 100
|
sample_per_label: 100
|
||||||
expand_labels: 1
|
expand_labels: 1
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- RandomHorizontalFlip:
|
- RandFlipImage:
|
||||||
prob: 0.5
|
flip_code: 1
|
||||||
- RandomCrop:
|
- Pad_paddle_vision:
|
||||||
size: 32
|
|
||||||
padding: 4
|
padding: 4
|
||||||
padding_mode: "reflect"
|
padding_mode: reflect
|
||||||
- ToTensor:
|
- RandCropImageV2:
|
||||||
|
size: [32, 32]
|
||||||
|
- NormalizeImage:
|
||||||
- Normalize:
|
- Normalize:
|
||||||
|
scale: 1.0/255.0
|
||||||
mean: [0.5071, 0.4867, 0.4408]
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
std: [0.2675, 0.2565, 0.2761]
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
|
order: hwc
|
||||||
|
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
@ -105,46 +114,51 @@ DataLoader:
|
|||||||
|
|
||||||
UnLabelTrain:
|
UnLabelTrain:
|
||||||
dataset:
|
dataset:
|
||||||
name: CIFAR100SSL
|
name: Cifar100
|
||||||
data_file: null
|
data_file: null
|
||||||
mode: 'train'
|
mode: 'train'
|
||||||
|
backend: 'pil'
|
||||||
download: true
|
download: true
|
||||||
|
|
||||||
transform_w:
|
transform_ops_weak:
|
||||||
- RandomHorizontalFlip:
|
- RandFlipImage:
|
||||||
prob: 0.5
|
flip_code: 1
|
||||||
- RandomCrop:
|
- Pad_paddle_vision:
|
||||||
size: 32
|
|
||||||
padding: 4
|
padding: 4
|
||||||
padding_mode: 'reflect'
|
padding_mode: reflect
|
||||||
- ToTensor:
|
- RandCropImageV2:
|
||||||
- Normalize:
|
size: [32, 32]
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
mean: [0.5071, 0.4867, 0.4408]
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
std: [0.2675, 0.2565, 0.2761]
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
|
order: hwc
|
||||||
|
|
||||||
transform_s1:
|
transform_ops_strong:
|
||||||
- RandomHorizontalFlip:
|
- RandFlipImage:
|
||||||
prob: 0.5
|
flip_code: 1
|
||||||
- RandomCrop:
|
- Pad_paddle_vision:
|
||||||
size: 32
|
|
||||||
padding: 4
|
padding: 4
|
||||||
padding_mode: 'reflect'
|
padding_mode: reflect
|
||||||
- RandAugmentMC:
|
- RandCropImageV2:
|
||||||
n: 2
|
size: [32, 32]
|
||||||
m: 10
|
- RandAugment:
|
||||||
- ToTensor:
|
num_layers: 2
|
||||||
- Normalize:
|
magnitude: 10
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
mean: [0.5071, 0.4867, 0.4408]
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
std: [0.2675, 0.2565, 0.2761]
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
|
order: hwc
|
||||||
|
|
||||||
transform_s2:
|
transform_ops_strong2:
|
||||||
- RandomResizedCrop:
|
- RandCropImageV2:
|
||||||
size: 32
|
size: [32, 32]
|
||||||
- RandomHorizontalFlip:
|
- RandFlipImage:
|
||||||
prob: 0.5
|
flip_code: 1
|
||||||
- RandomApply:
|
- RandomApply:
|
||||||
transforms:
|
transforms:
|
||||||
- ColorJitter:
|
- RawColorJitter:
|
||||||
brightness: 0.4
|
brightness: 0.4
|
||||||
contrast: 0.4
|
contrast: 0.4
|
||||||
saturation: 0.4
|
saturation: 0.4
|
||||||
@ -152,7 +166,12 @@ DataLoader:
|
|||||||
p: 0.8
|
p: 0.8
|
||||||
- RandomGrayscale:
|
- RandomGrayscale:
|
||||||
p: 0.2
|
p: 0.2
|
||||||
- ToTensor:
|
- NormalizeImage:
|
||||||
|
scale: 1.0/255.0
|
||||||
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
|
order: hwc
|
||||||
|
|
||||||
|
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
@ -165,15 +184,17 @@ DataLoader:
|
|||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: CIFAR100SSL
|
name: Cifar100
|
||||||
mode: 'test'
|
mode: 'test'
|
||||||
|
backend: 'pil'
|
||||||
download: true
|
download: true
|
||||||
data_file: null
|
data_file: null
|
||||||
transform_ops:
|
transform_ops:
|
||||||
- ToTensor:
|
- NormalizeImage:
|
||||||
- Normalize:
|
scale: 1.0/255.0
|
||||||
mean: [0.5071, 0.4867, 0.4408]
|
mean: [0.5071, 0.4867, 0.4408]
|
||||||
std: [0.2675, 0.2565, 0.2761]
|
std: [0.2675, 0.2565, 0.2761]
|
||||||
|
order: hwc
|
||||||
sampler:
|
sampler:
|
||||||
name: DistributedBatchSampler
|
name: DistributedBatchSampler
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
@ -26,18 +26,26 @@ Arch:
|
|||||||
name: RecModel
|
name: RecModel
|
||||||
infer_output_key: logits
|
infer_output_key: logits
|
||||||
infer_add_softmax: false
|
infer_add_softmax: false
|
||||||
|
head_feature_from: backbone
|
||||||
Backbone:
|
Backbone:
|
||||||
name: WideResNet
|
name: WideResNet
|
||||||
decoup:
|
|
||||||
logits_index: 0
|
|
||||||
features_index: 1
|
|
||||||
widen_factor: 2
|
widen_factor: 2
|
||||||
depth: 28
|
depth: 28
|
||||||
dropout: 0
|
dropout: 0
|
||||||
num_classes: 10
|
num_classes: 10
|
||||||
low_dim: 64
|
low_dim: 64
|
||||||
proj: true
|
proj: false
|
||||||
proj_after: false
|
proj_after: false
|
||||||
|
BackboneStopLayer:
|
||||||
|
name: bn1
|
||||||
|
Neck:
|
||||||
|
name: FRFNNeck
|
||||||
|
num_features: 128
|
||||||
|
low_dim: 64
|
||||||
|
Head:
|
||||||
|
name: FC
|
||||||
|
embedding_size: 128
|
||||||
|
class_num: 10
|
||||||
|
|
||||||
use_sync_bn: true
|
use_sync_bn: true
|
||||||
|
|
||||||
@ -103,8 +111,6 @@ DataLoader:
|
|||||||
num_workers: 4
|
num_workers: 4
|
||||||
use_shared_memory: true
|
use_shared_memory: true
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
UnLabelTrain:
|
UnLabelTrain:
|
||||||
dataset:
|
dataset:
|
||||||
name: Cifar10
|
name: Cifar10
|
||||||
|
@ -1,227 +0,0 @@
|
|||||||
import PIL
|
|
||||||
import PIL.ImageDraw
|
|
||||||
import random
|
|
||||||
from paddle.vision.transforms import transforms as T
|
|
||||||
from paddle.vision.transforms.transforms import ColorJitter
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
PARAMETER_MAX = 10
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def AutoContrast(img, **kwarg):
|
|
||||||
return PIL.ImageOps.autocontrast(img)
|
|
||||||
|
|
||||||
|
|
||||||
def Brightness(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageEnhance.Brightness(img).enhance(v)
|
|
||||||
|
|
||||||
|
|
||||||
def Color(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageEnhance.Color(img).enhance(v)
|
|
||||||
|
|
||||||
def Contrast(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageEnhance.Contrast(img).enhance(v)
|
|
||||||
|
|
||||||
|
|
||||||
def Cutout(img, v, max_v, bias=0):
|
|
||||||
if v == 0:
|
|
||||||
return img
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
v = int(v * min(img.size))
|
|
||||||
return CutoutAbs(img, v)
|
|
||||||
|
|
||||||
|
|
||||||
def CutoutAbs(img, v, **kwarg):
|
|
||||||
w, h = img.size
|
|
||||||
x0 = np.random.uniform(0, w)
|
|
||||||
y0 = np.random.uniform(0, h)
|
|
||||||
x0 = int(max(0, x0 - v / 2.))
|
|
||||||
y0 = int(max(0, y0 - v / 2.))
|
|
||||||
x1 = int(min(w, x0 + v))
|
|
||||||
y1 = int(min(h, y0 + v))
|
|
||||||
xy = (x0, y0, x1, y1)
|
|
||||||
# gray
|
|
||||||
color = (127, 127, 127)
|
|
||||||
img = img.copy()
|
|
||||||
PIL.ImageDraw.Draw(img).rectangle(xy, color)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def Equalize(img, **kwarg):
|
|
||||||
return PIL.ImageOps.equalize(img)
|
|
||||||
|
|
||||||
|
|
||||||
def Identity(img, **kwarg):
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def Invert(img, **kwarg):
|
|
||||||
return PIL.ImageOps.invert(img)
|
|
||||||
|
|
||||||
|
|
||||||
def Posterize(img, v, max_v, bias=0):
|
|
||||||
v = _int_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageOps.posterize(img, v)
|
|
||||||
|
|
||||||
|
|
||||||
def Rotate(img, v, max_v, bias=0):
|
|
||||||
v = _int_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
return img.rotate(v)
|
|
||||||
|
|
||||||
|
|
||||||
def Sharpness(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageEnhance.Sharpness(img).enhance(v)
|
|
||||||
|
|
||||||
|
|
||||||
def ShearX(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, v, 0, 0, 1, 0))
|
|
||||||
|
|
||||||
|
|
||||||
def ShearY(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, 0, v, 1, 0))
|
|
||||||
|
|
||||||
|
|
||||||
def Solarize(img, v, max_v, bias=0):
|
|
||||||
v = _int_parameter(v, max_v) + bias
|
|
||||||
return PIL.ImageOps.solarize(img, 256 - v)
|
|
||||||
|
|
||||||
|
|
||||||
def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
|
|
||||||
v = _int_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
img_np = np.array(img).astype(np.int)
|
|
||||||
img_np = img_np + v
|
|
||||||
img_np = np.clip(img_np, 0, 255)
|
|
||||||
img_np = img_np.astype(np.uint8)
|
|
||||||
img = Image.fromarray(img_np)
|
|
||||||
return PIL.ImageOps.solarize(img, threshold)
|
|
||||||
|
|
||||||
|
|
||||||
def TranslateX(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
v = int(v * img.size[0])
|
|
||||||
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, v, 0, 1, 0))
|
|
||||||
|
|
||||||
|
|
||||||
def TranslateY(img, v, max_v, bias=0):
|
|
||||||
v = _float_parameter(v, max_v) + bias
|
|
||||||
if random.random() < 0.5:
|
|
||||||
v = -v
|
|
||||||
v = int(v * img.size[1])
|
|
||||||
return img.transform(img.size, PIL.Image.Transform.AFFINE, (1, 0, 0, 0, 1, v))
|
|
||||||
|
|
||||||
|
|
||||||
def _float_parameter(v, max_v):
|
|
||||||
return float(v) * max_v / PARAMETER_MAX
|
|
||||||
|
|
||||||
|
|
||||||
def _int_parameter(v, max_v):
|
|
||||||
return int(v * max_v / PARAMETER_MAX)
|
|
||||||
|
|
||||||
|
|
||||||
def fixmatch_augment_pool():
|
|
||||||
# FixMatch paper
|
|
||||||
augs = [(AutoContrast, None, None),
|
|
||||||
(Brightness, 0.9, 0.05),
|
|
||||||
(Color, 0.9, 0.05),
|
|
||||||
(Contrast, 0.9, 0.05),
|
|
||||||
(Equalize, None, None),
|
|
||||||
(Identity, None, None),
|
|
||||||
(Posterize, 4, 4),
|
|
||||||
(Rotate, 30, 0),
|
|
||||||
(Sharpness, 0.9, 0.05),
|
|
||||||
(ShearX, 0.3, 0),
|
|
||||||
(ShearY, 0.3, 0),
|
|
||||||
(Solarize, 256, 0),
|
|
||||||
(TranslateX, 0.3, 0),
|
|
||||||
(TranslateY, 0.3, 0)]
|
|
||||||
return augs
|
|
||||||
|
|
||||||
|
|
||||||
def my_augment_pool():
|
|
||||||
# Test
|
|
||||||
augs = [(AutoContrast, None, None),
|
|
||||||
(Brightness, 1.8, 0.1),
|
|
||||||
(Color, 1.8, 0.1),
|
|
||||||
(Contrast, 1.8, 0.1),
|
|
||||||
(Cutout, 0.2, 0),
|
|
||||||
(Equalize, None, None),
|
|
||||||
(Invert, None, None),
|
|
||||||
(Posterize, 4, 4),
|
|
||||||
(Rotate, 30, 0),
|
|
||||||
(Sharpness, 1.8, 0.1),
|
|
||||||
(ShearX, 0.3, 0),
|
|
||||||
(ShearY, 0.3, 0),
|
|
||||||
(Solarize, 256, 0),
|
|
||||||
(SolarizeAdd, 110, 0),
|
|
||||||
(TranslateX, 0.45, 0),
|
|
||||||
(TranslateY, 0.45, 0)]
|
|
||||||
return augs
|
|
||||||
|
|
||||||
|
|
||||||
class RandAugmentPC(object):
|
|
||||||
def __init__(self, n, m):
|
|
||||||
assert n >= 1
|
|
||||||
assert 1 <= m <= 10
|
|
||||||
self.n = n
|
|
||||||
self.m = m
|
|
||||||
self.augment_pool = my_augment_pool()
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
ops = random.choices(self.augment_pool, k=self.n)
|
|
||||||
for op, max_v, bias in ops:
|
|
||||||
prob = np.random.uniform(0.2, 0.8)
|
|
||||||
if random.random() + prob >= 1:
|
|
||||||
img = op(img, v=self.m, max_v=max_v, bias=bias)
|
|
||||||
img = CutoutAbs(img, int(32*0.5))
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
class RandAugmentMC(object):
|
|
||||||
def __init__(self, n, m):
|
|
||||||
assert n >= 1
|
|
||||||
assert 1 <= m <= 10
|
|
||||||
self.n = n
|
|
||||||
self.m = m
|
|
||||||
self.augment_pool = fixmatch_augment_pool()
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
ops = random.choices(self.augment_pool, k=self.n)
|
|
||||||
for op, max_v, bias in ops:
|
|
||||||
v = np.random.randint(1, self.m)
|
|
||||||
if random.random() < 0.5:
|
|
||||||
img = op(img, v=v, max_v=max_v, bias=bias)
|
|
||||||
img = CutoutAbs(img, int(32 * 0.5))
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
class RandomApply:
|
|
||||||
def __init__(self, p, transforms):
|
|
||||||
self.p = p
|
|
||||||
ts = []
|
|
||||||
for t in transforms:
|
|
||||||
for key in t.keys():
|
|
||||||
ts.append(eval(key)(**t[key]))
|
|
||||||
|
|
||||||
self.trans = T.Compose(ts)
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
timg = self.trans(img)
|
|
||||||
return timg
|
|
@ -11,8 +11,6 @@ import paddle
|
|||||||
|
|
||||||
|
|
||||||
def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
|
def train_epoch_fixmatch_ccssl(engine, epoch_id, print_batch_step):
|
||||||
print(engine.model.state_dict().keys())
|
|
||||||
assert 1==0
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
if not hasattr(engine, 'train_dataloader_iter'):
|
if not hasattr(engine, 'train_dataloader_iter'):
|
||||||
engine.train_dataloader_iter = iter(engine.train_dataloader)
|
engine.train_dataloader_iter = iter(engine.train_dataloader)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user