mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Avoid creating a new logger in PretrainedInit (#791)
* use current logger * remove get_current_instance * remove logger parameter at weight_init * remove elif branch
This commit is contained in:
parent
d876d4e0f8
commit
504fdc371a
@ -8,7 +8,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
|
||||
|
||||
|
||||
@ -481,22 +481,21 @@ class PretrainedInit:
|
||||
from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
|
||||
load_checkpoint,
|
||||
load_state_dict)
|
||||
logger = MMLogger.get_instance('mmengine')
|
||||
if self.prefix is None:
|
||||
print_log(f'load model from: {self.checkpoint}', logger=logger)
|
||||
print_log(f'load model from: {self.checkpoint}', logger='current')
|
||||
load_checkpoint(
|
||||
module,
|
||||
self.checkpoint,
|
||||
map_location=self.map_location,
|
||||
strict=False,
|
||||
logger=logger)
|
||||
logger='current')
|
||||
else:
|
||||
print_log(
|
||||
f'load {self.prefix} in model from: {self.checkpoint}',
|
||||
logger=logger)
|
||||
logger='current')
|
||||
state_dict = _load_checkpoint_with_prefix(
|
||||
self.prefix, self.checkpoint, map_location=self.map_location)
|
||||
load_state_dict(module, state_dict, strict=False, logger=logger)
|
||||
load_state_dict(module, state_dict, strict=False, logger='current')
|
||||
|
||||
if hasattr(module, '_params_init_info'):
|
||||
update_init_info(module, init_info=self._get_init_info())
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import pkgutil
|
||||
@ -106,10 +107,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
|
||||
err_msg = '\n'.join(err_msg)
|
||||
if strict:
|
||||
raise RuntimeError(err_msg)
|
||||
elif logger is not None:
|
||||
logger.warning(err_msg)
|
||||
else:
|
||||
print(err_msg)
|
||||
print_log(err_msg, logger=logger, level=logging.WARNING)
|
||||
|
||||
|
||||
def get_torchvision_models():
|
||||
|
Loading…
x
Reference in New Issue
Block a user