# Wang Xin 045e5f6ac7
# add pre-commit workflow (#11973)
# * add pre-commit workflow
#
# * run 'pre-commit run --all-files'
#
# * setup python version
# 2024-04-21 21:46:20 +08:00
#
# 219 lines
# 7.0 KiB
# Python

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter
def override(dl, ks, v):
    """
    Recursively replace a value inside a nested dict/list structure.

    Args:
        dl(dict or list): dict or list to be modified in place
        ks(list): path of keys; list indices may be given as strings
            (e.g. "1") and are converted with ``str2num``
        v(str): new value; converted with ``str2num`` before assignment

    Raises:
        AssertionError: if ``dl`` is not a dict/list, ``ks`` is empty, a final
            list index is out of range, or an intermediate dict key is missing.
    """
    import logging

    def str2num(v):
        # Turn "2", "0.5", "True", "[1, 2]", ... into Python objects; anything
        # that fails to evaluate (e.g. a bare word) is kept as the raw string.
        # NOTE(security): eval() executes arbitrary expressions, so only pass
        # trusted values (e.g. the user's own command-line overrides).
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), "{} should be a list or a dict".format(dl)
    assert len(ks) > 0, "length of keys should be larger than 0"
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), "index({}) out of range({})".format(k, len(dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # A new key is accepted at the leaf level, but reported.
            if ks[0] not in dl:
                logging.getLogger(__name__).warning(
                    "A new field ({}) detected!".format(ks[0])
                )
            dl[ks[0]] = str2num(v)
        else:
            # Intermediate keys must already exist; no new sub-dicts are created.
            assert (
                ks[0] in dl
            ), "({}) doesn't exist in {}, a new dict field is invalid".format(ks[0], dl)
            override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
    """
    Recursively override the config with ``key.path=value`` options.

    Args:
        config(dict): dict to be modified in place
        options(list): list of pairs (key0.key1.idx.key2=value),
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]
            Only the first "=" separates key from value, so the value
            itself may contain "=".
    Returns:
        config(dict): the same (modified) config object
    """
    for opt in options or []:
        assert isinstance(opt, str), "option({}) should be a str".format(opt)
        assert "=" in opt, (
            "option({}) should contain a = "
            "to distinguish between key and value".format(opt)
        )
        # Split at the first "=" only, so values such as "a=b" survive intact.
        key, value = opt.split("=", 1)
        override(config, key.split("."), value)
    return config
class ArgsParser(ArgumentParser):
    """Argument parser that requires ``-c/--config`` and collects ``-o`` overrides."""

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument("-t", "--tag", default="0", help="tag for marking worker")
        # Repeatable: each "-o key.path=value" is appended to args.override.
        self.add_argument(
            "-o",
            "--override",
            action="append",
            default=[],
            help="config options to be overridden",
        )
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style image",
        )
        self.add_argument(
            "--text_corpus", default="PaddleOCR", help="text corpus to use"
        )
        self.add_argument(
            "--language", default="en", help="language of the text corpus"
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None); ``--config`` is mandatory."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, "Please specify --config=configure_file_path."
        return args
def load_config(file_path):
    """
    Read a YAML config file and return the parsed document.

    Args:
        file_path (str): path whose extension is ".yml" or ".yaml".
    Returns:
        The object produced by ``yaml.load`` for the file's contents.
    Raises:
        AssertionError: if the extension is not ".yml" or ".yaml".
    """
    _, suffix = os.path.splitext(file_path)
    assert suffix in (".yml", ".yaml"), "only support yaml files for now"
    # NOTE(security): yaml.Loader can construct arbitrary Python objects;
    # only load trusted config files (yaml.safe_load would be stricter).
    with open(file_path, "rb") as stream:
        return yaml.load(stream, Loader=yaml.Loader)
def gen_config(save_path="config.yml"):
    """
    Build the default SRNet training config and dump it to a YAML file.

    Args:
        save_path (str): output file path (default: "config.yml" in the
            current working directory).
    Returns:
        dict: the config dict that was written.
    """
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True,
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1,
            },
            # NOTE(review): the discriminator keys below appear to mirror a
            # constructor signature of the form (input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False,
            # init_type='normal', init_gain=0.02, gpu_id='cuda:0');
            # confirm against the model code.
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
        },
        "Loss": {"lamb": 10, "perceptual_lamb": 1, "muvar_lamb": 50, "style_lamb": 500},
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {"name": "lambda", "lr": 0.0002, "lr_decay_iters": 50},
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [
                    {
                        "DecodeImage": {
                            "to_rgb": True,
                            "to_np": False,
                            "channel_first": False,
                        }
                    },
                    {
                        "NormalizeImage": {
                            "scale": 1.0 / 255.0,
                            "mean": [0.485, 0.456, 0.406],
                            "std": [0.229, 0.224, 0.225],
                            "order": None,
                        }
                    },
                    {"ToCHWImage": None},
                ],
            },
        },
    }
    with open(save_path, "w", encoding="utf-8") as f:
        yaml.dump(base_config, f)
    return base_config
if __name__ == "__main__":
    # When run as a script, write the default config produced by gen_config().
    gen_config()