fix bug in randaug, train_progressive and efficientnet_v2
parent
7e12c73e0a
commit
4fdcda7c60
|
@ -268,7 +268,8 @@ v2_xl_block = [ # only for 21k pretraining.
|
|||
]
|
||||
efficientnetv2_params = {
|
||||
# params: (block, width, depth, dropout)
|
||||
"efficientnetv2-s": (v2_s_block, 1.0, 1.0, np.linspace(0.1, 0.3, 4)),
|
||||
"efficientnetv2-s":
|
||||
(v2_s_block, 1.0, 1.0, np.linspace(0.1, 0.3, 4).tolist()),
|
||||
"efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3),
|
||||
"efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4),
|
||||
"efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4),
|
||||
|
|
|
@ -109,6 +109,18 @@ class RandAugmentV2(RawRandAugmentV2):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class TimmAutoAugment(RawTimmAutoAugment):
|
||||
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
|
||||
|
|
|
@ -203,6 +203,8 @@ class RandAugmentV2(RandAugment):
|
|||
"cutout": int(40 * abso_level)
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/
|
||||
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot,
|
||||
|
|
|
@ -48,11 +48,12 @@ def train_epoch_progressive(engine, epoch_id, print_batch_step):
|
|||
cur_image_size = engine.config["DataLoader"]["Train"]["dataset"][
|
||||
"transform_ops"][1]["RandCropImage"]["progress_size"][stage_id]
|
||||
cur_magnitude = engine.config["DataLoader"]["Train"]["dataset"][
|
||||
"transform_ops"][3]["RandAugment"]["progress_magnitude"][stage_id]
|
||||
"transform_ops"][3]["RandAugmentV2"]["progress_magnitude"][
|
||||
stage_id]
|
||||
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][
|
||||
"RandCropImage"]["size"] = cur_image_size
|
||||
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][
|
||||
"RandAugment"]["magnitude"] = cur_magnitude
|
||||
"RandAugmentV2"]["magnitude"] = cur_magnitude
|
||||
engine.train_dataloader = build_dataloader(
|
||||
engine.config["DataLoader"],
|
||||
"Train",
|
||||
|
|
Loading…
Reference in New Issue