Integration of the WandbLogger with the latest changes in the PaddleOCR integrations
parent e4ab0ebe86
commit dde60d69be
README.md
@@ -49,6 +49,7 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
 - Support user-defined training, provides rich predictive inference deployment solutions
 - Support PIP installation, easy to use
 - Support Linux, Windows, MacOS and other systems
+- Supports metric logging to [VisualDL](https://www.paddlepaddle.org.cn/documentation/docs/en/guides/03_VisualDL/visualdl_usage_en.html) and [Weights & Biases](https://docs.wandb.ai/)

 ## Visualization
doc/doc_en/config_en.md
@@ -36,6 +36,7 @@ Take rec_chinese_lite_train_v2.0.yml as an example
 | pretrained_model | Set the path of the pre-trained model | ./pretrain_models/CRNN/best_accuracy | \ |
 | checkpoints | Set the model parameter path | None | Used to load parameters after an interruption to continue training |
 | use_visualdl | Set whether to enable VisualDL for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
+| use_wandb | Set whether to enable W&B for visual log display | False | [Documentation](https://docs.wandb.ai/) |
 | infer_img | Set inference image path or folder path | ./infer_img | \ |
 | character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If character_dict_path is None, the model can only recognize numbers and lowercase letters |
 | max_text_length | Set the maximum length of text | 25 | \ |
@@ -66,7 +67,7 @@ In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck
 | :---------------------: | :---------------------: | :--------------: | :--------------------: |
 | model_type | Network Type | rec | Currently supports `rec`, `det`, `cls` |
 | algorithm | Model name | CRNN | See [algorithm_overview](./algorithm_overview_en.md) for the support list |
-| **Transform** | Set the transformation method | - | Currently only recognition algorithms are supported, see [ppocr/modeling/transforms](../../ppocr/modeling/transforms) for details |
+| **Transform** | Set the transformation method | - | Currently only recognition algorithms are supported, see [ppocr/modeling/transform](../../ppocr/modeling/transform) for details |
 | name | Transformation class name | TPS | Currently supports `TPS` |
 | num_fiducial | Number of TPS control points | 20 | Ten on the top and bottom |
 | loc_lr | Localization network learning rate | 0.1 | \ |
@@ -130,6 +131,17 @@ In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck
 | drop_last | Whether to discard the last incomplete mini-batch when the number of samples in the dataset is not divisible by batch_size | True | \ |
 | num_workers | The number of sub-processes used to load data; if 0, no sub-process is started and data is loaded in the main process | 8 | \ |
+
+### Weights & Biases ([W&B](../../ppocr/utils/loggers/wandb_logger.py))
+
+| Parameter | Use | Defaults | Note |
+| :---------------------: | :---------------------: | :--------------: | :--------------------: |
+| project | Project to which the run is to be logged | uncategorized | \ |
+| name | Alias/name of the run | Randomly generated by wandb | \ |
+| id | ID of the run | Randomly generated by wandb | \ |
+| entity | User or team to which the run is being logged | The logged-in user | \ |
+| save_dir | Local directory in which all models and other data are saved | wandb | \ |
+| config | Model configuration to be saved with the run | None | \ |
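The tables above describe the `Global` flags and the optional `wandb` block that `preprocess` (see the `tools/program.py` hunk further down) forwards to `WandbLogger` as keyword arguments. For reference, a minimal hypothetical config excerpt; the project and entity names are placeholders:

```yaml
# Hypothetical excerpt of a training config such as rec_chinese_lite_train_v2.0.yml
Global:
  use_visualdl: False
  use_wandb: True                  # turn on metric logging to W&B
  save_model_dir: ./output/rec/    # checkpoints and the wandb dir live here

# Optional block; every key is passed as a keyword argument to WandbLogger
wandb:
  project: PaddleOCR               # placeholder project name
  entity: my-team                  # placeholder user or team name
```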
<a name="3-multilingual-config-file-generation"></a>
## 3. Multilingual Config File Generation
ppocr/utils/loggers/__init__.py
@@ -0,0 +1,2 @@
+from .vdl_logger import VDLLogger
+from .wandb_logger import WandbLogger
ppocr/utils/loggers/base_logger.py
@@ -0,0 +1,15 @@
+import os
+from abc import ABC, abstractmethod
+
+
+class BaseLogger(ABC):
+    def __init__(self, save_dir):
+        self.save_dir = save_dir
+        os.makedirs(self.save_dir, exist_ok=True)
+
+    @abstractmethod
+    def log_metrics(self, metrics, prefix=None):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
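The two concrete loggers below both implement this interface. As an illustration only, a minimal hypothetical subclass (not part of this change) that prints metrics to stdout:

```python
from ppocr.utils.loggers.base_logger import BaseLogger


class PrintLogger(BaseLogger):
    """Hypothetical example: satisfies the BaseLogger interface via stdout."""

    def log_metrics(self, metrics, prefix=None, step=None):
        # Same "<prefix>/<name>" convention as VDLLogger and WandbLogger.
        prefix = prefix or ""
        for name, value in metrics.items():
            print("{}/{}: {} (step {})".format(prefix, name, value, step))

    def close(self):
        pass  # nothing to flush


# BaseLogger.__init__ creates save_dir, so instantiation needs a directory:
# PrintLogger("./output").log_metrics({"loss": 0.5}, prefix="TRAIN", step=1)
```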
ppocr/utils/loggers/vdl_logger.py
@@ -0,0 +1,18 @@
+from .base_logger import BaseLogger
+from visualdl import LogWriter
+
+
+class VDLLogger(BaseLogger):
+    def __init__(self, save_dir):
+        super().__init__(save_dir)
+        self.vdl_writer = LogWriter(logdir=save_dir)
+
+    def log_metrics(self, metrics, prefix=None, step=None):
+        if not prefix:
+            prefix = ""
+        updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
+
+        for k, v in updated_metrics.items():
+            self.vdl_writer.add_scalar(k, v, step)
+
+    def close(self):
+        self.vdl_writer.close()
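A quick usage sketch (paths are illustrative): each metric becomes a `<prefix>/<name>` scalar in the VisualDL log directory.

```python
from ppocr.utils.loggers import VDLLogger

vdl = VDLLogger(save_dir="./output/vdl")  # BaseLogger creates the directory
vdl.log_metrics(metrics={"loss": 1.23, "acc": 0.87}, prefix="TRAIN", step=100)
vdl.close()
```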
ppocr/utils/loggers/wandb_logger.py
@@ -0,0 +1,78 @@
+import os
+from .base_logger import BaseLogger
+from ppocr.utils.logging import get_logger
+
+
+class WandbLogger(BaseLogger):
+    def __init__(self,
+                 project=None,
+                 name=None,
+                 id=None,
+                 entity=None,
+                 save_dir=None,
+                 config=None,
+                 **kwargs):
+        try:
+            import wandb
+            self.wandb = wandb
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Please install wandb using `pip install wandb`")
+
+        self.project = project
+        self.name = name
+        self.id = id
+        self.save_dir = save_dir
+        self.config = config
+        self.kwargs = kwargs
+        self.entity = entity
+        self._run = None
+        self._wandb_init = dict(
+            project=self.project,
+            name=self.name,
+            id=self.id,
+            entity=self.entity,
+            dir=self.save_dir,
+            resume="allow")
+        self._wandb_init.update(**kwargs)
+
+        _ = self.run
+
+        if self.config:
+            self.run.config.update(self.config)
+
+    @property
+    def run(self):
+        if self._run is None:
+            if self.wandb.run is not None:
+                get_logger().info(
+                    "There is a wandb run already in progress and newly "
+                    "created instances of `WandbLogger` will reuse this run. "
+                    "If this is not desired, call `wandb.finish()` before "
+                    "instantiating `WandbLogger`.")
+                self._run = self.wandb.run
+            else:
+                self._run = self.wandb.init(**self._wandb_init)
+        return self._run
+
+    def log_metrics(self, metrics, prefix=None, step=None):
+        if not prefix:
+            prefix = ""
+        updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
+
+        self.run.log(updated_metrics, step=step)
+
+    def log_model(self, is_best, prefix, metadata=None):
+        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
+        artifact = self.wandb.Artifact(
+            'model-{}'.format(self.run.id), type='model', metadata=metadata)
+        artifact.add_file(model_path, name="model_ckpt.pdparams")
+
+        aliases = [prefix]
+        if is_best:
+            aliases.append("best")
+
+        self.run.log_artifact(artifact, aliases=aliases)
+
+    def close(self):
+        self.run.finish()
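A usage sketch, assuming `wandb` is installed and the user is logged in; the project and run names are placeholders. Note that `log_model` expects a checkpoint at `<save_dir>/<prefix>.pdparams`, which matches how `save_model` lays out checkpoints during training.

```python
from ppocr.utils.loggers import WandbLogger

wb = WandbLogger(
    project="PaddleOCR",       # placeholder project
    name="crnn-baseline",      # placeholder run name
    save_dir="./output",
    config={"lr": 0.001})      # appears in the run's config panel

wb.log_metrics(metrics={"loss": 0.42}, prefix="TRAIN", step=1)
# Uploads ./output/best_accuracy.pdparams as a versioned model artifact,
# aliased "best_accuracy" and "best".
wb.log_model(is_best=True, prefix="best_accuracy")
wb.close()
```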
tools/program.py
@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
 from ppocr.utils.utility import print_dict, AverageMeter
 from ppocr.utils.logging import get_logger
+from ppocr.utils.loggers import VDLLogger, WandbLogger
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader
@@ -161,7 +162,7 @@ def train(config,
           eval_class,
           pre_best_model_dict,
           logger,
-          vdl_writer=None,
+          log_writer=None,
           scaler=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
@@ -288,10 +289,8 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)

-            if vdl_writer is not None and dist.get_rank() == 0:
-                for k, v in train_stats.get().items():
-                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
-                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
+            if log_writer is not None and dist.get_rank() == 0:
+                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)

             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or
@@ -337,11 +336,9 @@ def train(config,
                 logger.info(cur_metric_str)

                 # logger metric
-                if vdl_writer is not None:
-                    for k, v in cur_metric.items():
-                        if isinstance(v, (float, int)):
-                            vdl_writer.add_scalar('EVAL/{}'.format(k),
-                                                  cur_metric[k], global_step)
+                if log_writer is not None:
+                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)

                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
                     best_model_dict.update(cur_metric)
@@ -362,10 +359,13 @@ def train(config,
                     ]))
                 logger.info(best_str)
                 # logger best metric
-                if vdl_writer is not None:
-                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
-                                          best_model_dict[main_indicator],
-                                          global_step)
+                if log_writer is not None:
+                    log_writer.log_metrics(metrics={
+                        "best_{}".format(main_indicator): best_model_dict[main_indicator]
+                    }, prefix="EVAL", step=global_step)
+
+                    if isinstance(log_writer, WandbLogger):
+                        log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)

             reader_start = time.time()
         if dist.get_rank() == 0:
@@ -380,6 +380,10 @@ def train(config,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step)
+
+            if isinstance(log_writer, WandbLogger):
+                log_writer.log_model(is_best=False, prefix="latest")
+
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
             save_model(
                 model,
@@ -392,11 +396,15 @@ def train(config,
                 best_model_dict=best_model_dict,
                 epoch=epoch,
                 global_step=global_step)
+
+            if isinstance(log_writer, WandbLogger):
+                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
-    if dist.get_rank() == 0 and vdl_writer is not None:
-        vdl_writer.close()
+    if dist.get_rank() == 0 and log_writer is not None:
+        log_writer.close()
     return
@@ -553,15 +561,22 @@ def preprocess(is_train=False):

     config['Global']['distributed'] = dist.get_world_size() != 1

-    if config['Global']['use_visualdl'] and dist.get_rank() == 0:
-        from visualdl import LogWriter
+    if "use_visualdl" in config['Global'] and config['Global']['use_visualdl'] and dist.get_rank() == 0:
         save_model_dir = config['Global']['save_model_dir']
-        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
-        os.makedirs(vdl_writer_path, exist_ok=True)
-        vdl_writer = LogWriter(logdir=vdl_writer_path)
+        log_writer = VDLLogger(save_model_dir)
+    elif ("use_wandb" in config['Global'] and config['Global']['use_wandb']) or "wandb" in config:
+        save_dir = config['Global']['save_model_dir']
+        wandb_writer_path = "{}/wandb".format(save_dir)
+        if "wandb" in config:
+            wandb_params = config['wandb']
+        else:
+            wandb_params = dict()
+        wandb_params.update({'save_dir': save_dir})
+        log_writer = WandbLogger(**wandb_params, config=config)
     else:
-        vdl_writer = None
+        log_writer = None
     print_dict(config, logger)
     logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                             device))
-    return config, device, logger, vdl_writer
+    return config, device, logger, log_writer
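Taken together, callers see only the unified writer. A minimal hypothetical driver (not part of this change), assuming it runs from the repository root with a config supplied via `-c` as `tools/train.py` does:

```python
import tools.program as program

# preprocess() now returns a unified log_writer (VDLLogger, WandbLogger, or
# None) in place of the raw VisualDL writer it used to return.
config, device, logger, log_writer = program.preprocess(is_train=True)

if log_writer is not None:
    log_writer.log_metrics(metrics={"sanity": 1.0}, prefix="DEBUG", step=0)
    log_writer.close()
```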