78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
import os
|
|
from .base_logger import BaseLogger
|
|
|
|
class WandbLogger(BaseLogger):
    """Experiment logger that forwards metrics and checkpoints to Weights & Biases.

    A wandb run is started (or an already-active one is reused) eagerly in
    ``__init__``; metric logging, checkpoint artifacts and shutdown all go
    through that run.

    Args:
        project: wandb project name, forwarded to ``wandb.init``.
        name: display name of the run, forwarded to ``wandb.init``.
        id: run id, forwarded to ``wandb.init`` together with
            ``resume="allow"``.
        entity: wandb entity, forwarded to ``wandb.init``.
        save_dir: passed to ``wandb.init(dir=...)``; ``log_model`` also reads
            checkpoint files from this directory.
        config: optional mapping merged into ``run.config`` after the run
            starts (skipped when empty or None).
        **kwargs: extra keyword arguments for ``wandb.init``; they override
            the values built from the named arguments above.

    Raises:
        ModuleNotFoundError: if the ``wandb`` package is not installed.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            ) from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None

        # Arguments for wandb.init. Caller-supplied **kwargs are applied last,
        # so they take precedence over the named arguments above.
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow",
        )
        self._wandb_init.update(kwargs)

        # Start (or attach to) the run now so the config can be recorded on it.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, starting it on first access.

        If ``wandb.run`` is already set (a run is in progress in this
        process), that run is reused instead of calling ``wandb.init``.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # This module defines no module-level ``logger``; use the
                # stdlib logger so this branch does not raise NameError.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the wandb run.

        Keys are namespaced as ``"<prefix.lower()>/<key>"``. When ``prefix``
        is empty or None, keys are logged unchanged (no leading "/").

        Args:
            metrics: mapping of metric name (str) to value.
            prefix: optional namespace for the metric names.
            step: optional step value forwarded to ``run.log``.
        """
        if prefix:
            namespace = prefix.lower() + "/"
        else:
            # Without a prefix, do not emit keys that start with "/".
            namespace = ""
        updated_metrics = {namespace + k: v for k, v in metrics.items()}
        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>`` and is tagged with ``prefix``
        as an alias, plus ``"best"`` when ``is_best`` is true.

        Args:
            is_best: whether to add the ``"best"`` alias.
            prefix: checkpoint file stem inside ``save_dir``; also used as an
                alias.
            metadata: optional metadata attached to the artifact.

        Note:
            Requires ``save_dir`` to have been given; with ``save_dir=None``
            ``os.path.join`` raises ``TypeError``.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        # The file is stored under a fixed name inside the artifact,
        # independent of ``prefix``.
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix]
        if is_best:
            aliases.append("best")
        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the underlying wandb run."""
        self.run.finish()