add db++
parent
1315cdfc86
commit
04e7104194
|
@ -18,7 +18,7 @@ Global:
|
|||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
algorithm: DB++
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: ResNet
|
||||
|
|
|
@ -18,7 +18,7 @@ Global:
|
|||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
algorithm: DB++
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: ResNet
|
||||
|
|
|
@ -67,6 +67,23 @@ class TextDetector(object):
|
|||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = args.use_dilation
|
||||
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||
elif self.det_algorithm == "DB++":
|
||||
postprocess_params['name'] = 'DBPostProcess'
|
||||
postprocess_params["thresh"] = args.det_db_thresh
|
||||
postprocess_params["box_thresh"] = args.det_db_box_thresh
|
||||
postprocess_params["max_candidates"] = 1000
|
||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = args.use_dilation
|
||||
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||
pre_process_list[1] = {
|
||||
'NormalizeImage': {
|
||||
'std': [1.0, 1.0, 1.0],
|
||||
'mean':
|
||||
[0.48109378172549, 0.45752457890196, 0.40787054090196],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}
|
||||
elif self.det_algorithm == "EAST":
|
||||
postprocess_params['name'] = 'EASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||
|
@ -231,7 +248,7 @@ class TextDetector(object):
|
|||
preds['f_score'] = outputs[1]
|
||||
preds['f_tco'] = outputs[2]
|
||||
preds['f_tvo'] = outputs[3]
|
||||
elif self.det_algorithm in ['DB', 'PSE']:
|
||||
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
|
||||
preds['maps'] = outputs[0]
|
||||
elif self.det_algorithm == 'FCE':
|
||||
for i, output in enumerate(outputs):
|
||||
|
|
|
@ -307,7 +307,8 @@ def train(config,
|
|||
train_stats.update(stats)
|
||||
|
||||
if log_writer is not None and dist.get_rank() == 0:
|
||||
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
|
||||
log_writer.log_metrics(
|
||||
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
|
||||
|
||||
if dist.get_rank() == 0 and (
|
||||
(global_step > 0 and global_step % print_batch_step == 0) or
|
||||
|
@ -354,7 +355,8 @@ def train(config,
|
|||
|
||||
# logger metric
|
||||
if log_writer is not None:
|
||||
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
|
||||
log_writer.log_metrics(
|
||||
metrics=cur_metric, prefix="EVAL", step=global_step)
|
||||
|
||||
if cur_metric[main_indicator] >= best_model_dict[
|
||||
main_indicator]:
|
||||
|
@ -377,11 +379,18 @@ def train(config,
|
|||
logger.info(best_str)
|
||||
# logger best metric
|
||||
if log_writer is not None:
|
||||
log_writer.log_metrics(metrics={
|
||||
"best_{}".format(main_indicator): best_model_dict[main_indicator]
|
||||
}, prefix="EVAL", step=global_step)
|
||||
|
||||
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
|
||||
log_writer.log_metrics(
|
||||
metrics={
|
||||
"best_{}".format(main_indicator):
|
||||
best_model_dict[main_indicator]
|
||||
},
|
||||
prefix="EVAL",
|
||||
step=global_step)
|
||||
|
||||
log_writer.log_model(
|
||||
is_best=True,
|
||||
prefix="best_accuracy",
|
||||
metadata=best_model_dict)
|
||||
|
||||
reader_start = time.time()
|
||||
if dist.get_rank() == 0:
|
||||
|
@ -413,7 +422,8 @@ def train(config,
|
|||
epoch=epoch,
|
||||
global_step=global_step)
|
||||
if log_writer is not None:
|
||||
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
|
||||
log_writer.log_model(
|
||||
is_best=False, prefix='iter_epoch_{}'.format(epoch))
|
||||
|
||||
best_str = 'best metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
||||
|
@ -564,7 +574,7 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'DB++'
|
||||
]
|
||||
|
||||
if use_xpu:
|
||||
|
@ -585,7 +595,8 @@ def preprocess(is_train=False):
|
|||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||
log_writer = VDLLogger(save_model_dir)
|
||||
loggers.append(log_writer)
|
||||
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
|
||||
if ('use_wandb' in config['Global'] and
|
||||
config['Global']['use_wandb']) or 'wandb' in config:
|
||||
save_dir = config['Global']['save_model_dir']
|
||||
wandb_writer_path = "{}/wandb".format(save_dir)
|
||||
if "wandb" in config:
|
||||
|
|
Loading…
Reference in New Issue