Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Fix] Avoid creating a new logger in PretrainedInit (#791)
* use current logger
* remove get_current_instance
* remove logger parameter at weight_init
* remove elif branch
This commit is contained in:
parent: d876d4e0f8
commit: 504fdc371a
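The gist of the fix: `PretrainedInit` previously grabbed a dedicated `'mmengine'` logger with `MMLogger.get_instance('mmengine')`, so its load messages bypassed whatever logger a downstream project had configured. Passing `logger='current'` to `print_log` routes them to the active logger instead. A minimal sketch of that dispatch (the `'my_project'` name is illustrative, not part of the commit):

```python
from mmengine.logging import MMLogger, print_log

# A downstream repo typically creates its own logger first.
logger = MMLogger.get_instance('my_project', log_level='INFO')

# With logger='current', print_log resolves to the most recently created
# MMLogger ('my_project' here) instead of forcing a separate 'mmengine' one.
print_log('load model from: checkpoint.pth', logger='current')
```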
mmengine/model/weight_init.py:

```diff
@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
-from mmengine.logging import MMLogger, print_log
+from mmengine.logging import print_log
 from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
 
 
@@ -481,22 +481,21 @@ class PretrainedInit:
         from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
                                                 load_checkpoint,
                                                 load_state_dict)
-        logger = MMLogger.get_instance('mmengine')
         if self.prefix is None:
-            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            print_log(f'load model from: {self.checkpoint}', logger='current')
             load_checkpoint(
                 module,
                 self.checkpoint,
                 map_location=self.map_location,
                 strict=False,
-                logger=logger)
+                logger='current')
         else:
             print_log(
                 f'load {self.prefix} in model from: {self.checkpoint}',
-                logger=logger)
+                logger='current')
             state_dict = _load_checkpoint_with_prefix(
                 self.prefix, self.checkpoint, map_location=self.map_location)
-            load_state_dict(module, state_dict, strict=False, logger=logger)
+            load_state_dict(module, state_dict, strict=False, logger='current')
 
         if hasattr(module, '_params_init_info'):
             update_init_info(module, init_info=self._get_init_info())
```
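For context, `PretrainedInit` is usually reached through a module's `init_cfg` rather than being called directly. A hedged sketch of that path (`ToyModel` and the checkpoint path are placeholders, not part of the commit):

```python
import torch.nn as nn
from mmengine.model import BaseModule


class ToyModel(BaseModule):
    """Placeholder module; init_cfg type 'Pretrained' maps to PretrainedInit."""

    def __init__(self):
        super().__init__(
            init_cfg=dict(type='Pretrained', checkpoint='./toy.pth'))
        self.linear = nn.Linear(4, 4)


model = ToyModel()
# init_weights() invokes PretrainedInit, which now logs via the current logger.
model.init_weights()
```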
mmengine/runner/checkpoint.py:

```diff
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import io
+import logging
 import os
 import os.path as osp
 import pkgutil
@@ -106,10 +107,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
         err_msg = '\n'.join(err_msg)
         if strict:
             raise RuntimeError(err_msg)
-        elif logger is not None:
-            logger.warning(err_msg)
         else:
-            print(err_msg)
+            print_log(err_msg, logger=logger, level=logging.WARNING)
 
 
 def get_torchvision_models():
```
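The removed `elif` branch is safe to drop because `print_log` itself dispatches on its `logger` argument. A hedged sketch of the cases the single call now covers (`msg` is illustrative):

```python
import logging

from mmengine.logging import MMLogger, print_log

msg = 'The model and loaded state dict do not match exactly'

# logger=None falls back to plain print, replacing the old `print(err_msg)`.
print_log(msg, logger=None, level=logging.WARNING)

# A concrete Logger instance is used directly, replacing the removed
# `logger.warning(err_msg)` elif branch.
print_log(msg, logger=MMLogger.get_instance('mmengine'), level=logging.WARNING)

# 'current' resolves to the most recently created MMLogger.
print_log(msg, logger='current', level=logging.WARNING)
```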