[Fix] Avoid creating a new logger in PretrainedInit (#791)

* use current logger

* remove get_current_instance

* remove logger parameter at weight_init

* remove elif branch
This commit is contained in:
谢昕辰 2022-12-12 14:16:15 +08:00 committed by GitHub
parent d876d4e0f8
commit 504fdc371a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 9 deletions

View File

@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from mmengine.logging import MMLogger, print_log
from mmengine.logging import print_log
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
@@ -481,22 +481,21 @@ class PretrainedInit:
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
load_checkpoint,
load_state_dict)
logger = MMLogger.get_instance('mmengine')
if self.prefix is None:
print_log(f'load model from: {self.checkpoint}', logger=logger)
print_log(f'load model from: {self.checkpoint}', logger='current')
load_checkpoint(
module,
self.checkpoint,
map_location=self.map_location,
strict=False,
logger=logger)
logger='current')
else:
print_log(
f'load {self.prefix} in model from: {self.checkpoint}',
logger=logger)
logger='current')
state_dict = _load_checkpoint_with_prefix(
self.prefix, self.checkpoint, map_location=self.map_location)
load_state_dict(module, state_dict, strict=False, logger=logger)
load_state_dict(module, state_dict, strict=False, logger='current')
if hasattr(module, '_params_init_info'):
update_init_info(module, init_info=self._get_init_info())

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import io
import logging
import os
import os.path as osp
import pkgutil
@@ -106,10 +107,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
err_msg = '\n'.join(err_msg)
if strict:
raise RuntimeError(err_msg)
elif logger is not None:
logger.warning(err_msg)
else:
print(err_msg)
print_log(err_msg, logger=logger, level=logging.WARNING)
def get_torchvision_models():