mirror of https://github.com/JDAI-CV/fast-reid.git
fixup finetune problem
Summary: support finetuning from another model with a different number of classes, and simplify the calling convention (#325). Closes #325.
parent f496193f17
commit 7e9a4775da
fastreid/engine/defaults.py
@@ -44,11 +44,6 @@ def default_argument_parser():
     """
     parser = argparse.ArgumentParser(description="fastreid Training")
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
-    parser.add_argument(
-        "--finetune",
-        action="store_true",
-        help="whether to attempt to finetune from the trained model",
-    )
     parser.add_argument(
         "--resume",
         action="store_true",
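With the `--finetune` flag removed, finetuning is driven entirely by the config: pointing `MODEL.WEIGHTS` at a trained model is enough. A minimal sketch (the config path and weight file below are placeholders, not part of the commit):

    from fastreid.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")  # placeholder config
    cfg.MODEL.WEIGHTS = "logs/dukemtmc/model_final.pth"          # trained model to finetune from
    # DefaultTrainer(cfg).resume_or_load(resume=False) will now load these weights,
    # skipping any tensors whose shapes differ (e.g. a classifier head with a
    # different class count).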
@@ -244,8 +239,13 @@ class DefaultTrainer(SimpleTrainer):

     def resume_or_load(self, resume=True):
         """
-        If `resume==True`, and last checkpoint exists, resume from it.
-        Otherwise, load a model specified by the config.
+        If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+        a `last_checkpoint` file), resume from that file. Resuming means loading all
+        available states (e.g. optimizer and scheduler) and updating the iteration counter
+        from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+        Otherwise, this is considered an independent training. The method will load model
+        weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
+        from iteration 0.
         Args:
             resume (bool): whether to do resume or not
         """
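In short, the two modes of `resume_or_load` look like this (a usage sketch, not code from the commit):

    trainer = DefaultTrainer(cfg)

    # Continue an interrupted run: restores model, optimizer, scheduler and the
    # iteration counter from cfg.OUTPUT_DIR/last_checkpoint; cfg.MODEL.WEIGHTS is ignored.
    trainer.resume_or_load(resume=True)

    # Independent training (including finetuning): loads only the model weights
    # from cfg.MODEL.WEIGHTS and starts from iteration 0.
    trainer.resume_or_load(resume=False)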
@@ -468,7 +468,6 @@ class DefaultTrainer(SimpleTrainer):
     because some hyper-params, such as MAX_ITER, mean training epochs rather than iters,
     so we need to convert the specific hyper-params to training iterations.
     """
-
     cfg = cfg.clone()
     frozen = cfg.is_frozen()
     cfg.defrost()
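The `clone()`/`is_frozen()`/`defrost()` dance above is the standard yacs pattern for mutating a possibly frozen config without side effects on the caller. A self-contained illustration (toy keys, not fast-reid's schema):

    from yacs.config import CfgNode

    cfg = CfgNode({"SOLVER": CfgNode({"MAX_ITER": 120})})
    cfg.freeze()

    def scale_iters(cfg, iters_per_epoch):
        cfg = cfg.clone()              # work on a copy, never the caller's object
        frozen = cfg.is_frozen()
        cfg.defrost()
        cfg.SOLVER.MAX_ITER *= iters_per_epoch  # interpret MAX_ITER as epochs
        if frozen:
            cfg.freeze()               # restore the original mutability
        return cfg

    print(scale_iters(cfg, 100).SOLVER.MAX_ITER)  # 12000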
fastreid/utils/checkpoint.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-
 import collections
 import copy
 import logging
 import os
 from collections import defaultdict
 from typing import Any
+from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable

 import numpy as np
 import torch
@@ -17,6 +17,23 @@ from torch.nn.parallel import DataParallel, DistributedDataParallel
 from fastreid.utils.file_io import PathManager


+class _IncompatibleKeys(
+    NamedTuple(
+        # pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
+        "IncompatibleKeys",
+        [
+            ("missing_keys", List[str]),
+            ("unexpected_keys", List[str]),
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+            ("incorrect_shapes", List[Tuple]),
+        ],
+    )
+):
+    pass
+
+
 class Checkpointer(object):
     """
     A checkpointer that can save/load model as well as extra checkpointable
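The `_IncompatibleKeys` tuple extends the pair that PyTorch's `load_state_dict` returns with a third field for shape mismatches. Illustrative values (the key name is made up for the example):

    keys = _IncompatibleKeys(
        missing_keys=["heads.classifier.weight"],
        unexpected_keys=[],
        incorrect_shapes=[("heads.classifier.weight", (1000, 2048), (751, 2048))],
    )
    assert keys.incorrect_shapes[0][1] == (1000, 2048)  # shape in the checkpoint
    assert keys.incorrect_shapes[0][2] == (751, 2048)   # shape in the model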
@@ -50,7 +67,9 @@ class Checkpointer(object):
         self.save_dir = save_dir
         self.save_to_disk = save_to_disk
+        self.path_manager = PathManager

-    def save(self, name: str, **kwargs: dict):
+    def save(self, name: str, **kwargs: Dict[str, str]):
         """
         Dump model and checkpointables to a file.
         Args:
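Usage of `save` is unchanged; extra keyword arguments are stored alongside the weights (a sketch, assuming a constructed `checkpointer`):

    checkpointer.save("model_0009", iteration=9)
    # writes <save_dir>/model_0009.pth containing the model state_dict,
    # each checkpointable's state_dict, and {"iteration": 9}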
@@ -74,13 +93,15 @@ class Checkpointer(object):
             torch.save(data, f)
         self.tag_last_checkpoint(basename)

-    def load(self, path: str):
+    def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
         """
         Load from the given checkpoint. When path points to network file, this
         function has to be called on all ranks.
         Args:
             path (str): path or url to the checkpoint. If empty, will not load
                 anything.
+            checkpointables (list): List of checkpointable names to load. If not
+                specified (None), will load all the possible checkpointables.
         Returns:
             dict:
                 extra data loaded from the checkpoint that has not been
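The new `checkpointables` argument makes partial loading explicit, which is exactly what finetuning needs (a sketch, assuming `model`, `optimizer`, `scheduler` and `path` exist):

    checkpointer = Checkpointer(model, "logs/market1501", optimizer=optimizer, scheduler=scheduler)

    checkpointer.load(path)                      # model + optimizer + scheduler
    checkpointer.load(path, checkpointables=[])  # model weights only (finetune case)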
@@ -89,21 +110,25 @@ class Checkpointer(object):
         """
         if not path:
             # no checkpoint provided
-            self.logger.info(
-                "No checkpoint found. Training model from scratch"
-            )
+            self.logger.info("No checkpoint found. Training model from scratch")
             return {}
         self.logger.info("Loading checkpoint from {}".format(path))
         if not os.path.isfile(path):
-            path = PathManager.get_local_path(path)
+            path = self.path_manager.get_local_path(path)
             assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

         checkpoint = self._load_file(path)
-        self._load_model(checkpoint)
-        for key, obj in self.checkpointables.items():
-            if key in checkpoint:
+        incompatible = self._load_model(checkpoint)
+        if (
+            incompatible is not None
+        ):  # handle some existing subclasses that returns None
+            self._log_incompatible_keys(incompatible)
+
+        for key in self.checkpointables if checkpointables is None else checkpointables:
+            if key in checkpoint:  # pyre-ignore
                 self.logger.info("Loading {} from {}".format(key, path))
-                obj.load_state_dict(checkpoint.pop(key))
+                obj = self.checkpointables[key]
+                obj.load_state_dict(checkpoint.pop(key))  # pyre-ignore

         # return any further checkpoint data
         return checkpoint
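Whatever remains in the checkpoint dict after the model and checkpointables are popped is returned to the caller; by the usual detectron2-style convention (assumed here, not shown in this diff) the trainer uses it to restore its iteration counter:

    extra = checkpointer.load(path)
    start_iter = extra.get("iteration", -1) + 1  # -1 -> start from 0 when absent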
@@ -158,7 +183,9 @@ class Checkpointer(object):
         """
         if resume and self.has_checkpoint():
             path = self.get_checkpoint_file()
-        return self.load(path)
+            return self.load(path)
+        else:
+            return self.load(path, checkpointables=[])

     def tag_last_checkpoint(self, last_filename_basename: str):
         """
@@ -199,26 +226,40 @@ class Checkpointer(object):
         # work around https://github.com/pytorch/pytorch/issues/24139
         model_state_dict = self.model.state_dict()
+        incorrect_shapes = []
         for k in list(checkpoint_state_dict.keys()):
             if k in model_state_dict:
                 shape_model = tuple(model_state_dict[k].shape)
                 shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                 if shape_model != shape_checkpoint:
-                    self.logger.warning(
-                        "'{}' has shape {} in the checkpoint but {} in the "
-                        "model! Skipped.".format(
-                            k, shape_checkpoint, shape_model
-                        )
-                    )
+                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                     checkpoint_state_dict.pop(k)

-        incompatible = self.model.load_state_dict(
-            checkpoint_state_dict, strict=False
-        )
-        if incompatible.missing_keys:
-            self.logger.info(
-                get_missing_parameters_message(incompatible.missing_keys)
-            )
+        incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
+        return _IncompatibleKeys(
+            missing_keys=incompatible.missing_keys,
+            unexpected_keys=incompatible.unexpected_keys,
+            incorrect_shapes=incorrect_shapes,
+        )
+
+    def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
+        """
+        Log information about the incompatible keys returned by ``_load_model``.
+        """
+        for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
+            self.logger.warning(
+                "Skip loading parameter '{}' to the model due to incompatible "
+                "shapes: {} in the checkpoint but {} in the "
+                "model! You might want to double check if this is expected.".format(
+                    k, shape_checkpoint, shape_model
+                )
+            )
+        if incompatible.missing_keys:
+            missing_keys = _filter_reused_missing_keys(
+                self.model, incompatible.missing_keys
+            )
+            if missing_keys:
+                self.logger.info(get_missing_parameters_message(missing_keys))
         if incompatible.unexpected_keys:
             self.logger.info(
                 get_unexpected_parameters_message(incompatible.unexpected_keys)
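This is the heart of the finetune fix: parameters whose shapes disagree are dropped before `load_state_dict(strict=False)`, so a checkpoint trained with a different number of classes loads cleanly instead of crashing. A self-contained toy version of the same logic:

    import torch.nn as nn

    model = nn.Linear(8, 751)               # new task: 751 identities
    ckpt = nn.Linear(8, 1000).state_dict()  # old task: 1000 identities

    model_sd = model.state_dict()
    incorrect_shapes = []
    for k in list(ckpt.keys()):
        if k in model_sd and tuple(ckpt[k].shape) != tuple(model_sd[k].shape):
            incorrect_shapes.append((k, tuple(ckpt[k].shape), tuple(model_sd[k].shape)))
            ckpt.pop(k)                     # same filtering as _load_model above

    incompatible = model.load_state_dict(ckpt, strict=False)
    print(incorrect_shapes)           # [('weight', (1000, 8), (751, 8)), ('bias', (1000,), (751,))]
    print(incompatible.missing_keys)  # ['weight', 'bias'] -- left randomly initialized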
@@ -297,7 +338,27 @@ class PeriodicCheckpointer:
             self.checkpointer.save(name, **kwargs)


-def get_missing_parameters_message(keys: list):
+def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
+    """
+    Filter "missing keys" to not include keys that have been loaded with another name.
+    """
+    keyset = set(keys)
+    param_to_names = defaultdict(set)  # param -> names that points to it
+    for module_prefix, module in _named_modules_with_dup(model):
+        for name, param in list(module.named_parameters(recurse=False)) + list(
+            module.named_buffers(recurse=False)  # pyre-ignore
+        ):
+            full_name = (module_prefix + "." if module_prefix else "") + name
+            param_to_names[param].add(full_name)
+    for names in param_to_names.values():
+        # if one name appears missing but its alias exists, then this
+        # name is not considered missing
+        if any(n in keyset for n in names) and not all(n in keyset for n in names):
+            [keyset.remove(n) for n in names if n in keyset]
+    return list(keyset)
+
+
+def get_missing_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the model but not found in a checkpoint.
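`_filter_reused_missing_keys` matters when one tensor is reachable under several names (weight tying): if any alias was loaded, the others should not be reported as missing. A toy module that triggers the case:

    import torch.nn as nn

    class TiedDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.proj = nn.Linear(4, 10, bias=False)
            self.proj.weight = self.embed.weight  # one parameter, two names

    m = TiedDecoder()
    # m.state_dict() stores the tensor under both "embed.weight" and "proj.weight".
    # If a checkpoint provides only "proj.weight", plain load_state_dict(strict=False)
    # reports "embed.weight" as missing even though the parameter was in fact loaded;
    # _filter_reused_missing_keys drops such aliases from the report.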
@@ -307,14 +368,14 @@ def get_missing_parameters_message(keys: list):
         str: message.
     """
     groups = _group_checkpoint_keys(keys)
-    msg = "Some model parameters are not in the checkpoint:\n"
+    msg = "Some model parameters or buffers are not found in the checkpoint:\n"
     msg += "\n".join(
         " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
     )
     return msg


-def get_unexpected_parameters_message(keys: list):
+def get_unexpected_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the checkpoint but not found in the model.
@@ -324,15 +385,14 @@ def get_unexpected_parameters_message(keys: list):
         str: message.
     """
     groups = _group_checkpoint_keys(keys)
-    msg = "The checkpoint contains parameters not used by the model:\n"
+    msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
     msg += "\n".join(
-        " " + colored(k + _group_to_str(v), "magenta")
-        for k, v in groups.items()
+        " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
     )
     return msg


-def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
+def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
     """
     Strip the prefix in metadata, if any.
     Args:
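Despite the docstring's focus on metadata, the helper also renames the keys in place when every key carries the prefix (that part of the body is outside these hunks); the typical use is removing the `module.` prefix that `DistributedDataParallel` adds. Illustrative:

    import collections
    import torch

    sd = collections.OrderedDict([("module.backbone.conv1.weight", torch.zeros(2, 2))])
    _strip_prefix_if_present(sd, "module.")
    print(list(sd.keys()))  # ['backbone.conv1.weight']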
@@ -349,7 +409,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):

     # also strip the prefix in metadata, if any..
     try:
-        metadata = state_dict._metadata
+        metadata = state_dict._metadata  # pyre-ignore
     except AttributeError:
         pass
     else:
@@ -365,7 +425,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
             metadata[newkey] = metadata.pop(key)


-def _group_checkpoint_keys(keys: list):
+def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
     """
     Group keys based on common prefixes. A prefix is the string up to the final
     "." in each key.
|
|||
return groups
|
||||
|
||||
|
||||
def _group_to_str(group: list):
|
||||
def _group_to_str(group: List[str]) -> str:
|
||||
"""
|
||||
Format a group of parameter name suffixes into a loggable string.
|
||||
Args:
|
||||
|
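Together, these two helpers compress long key lists in the log messages. Expected behavior (illustrative key names):

    groups = _group_checkpoint_keys(["backbone.conv1.weight", "backbone.conv1.bias", "pixel_mean"])
    # {'backbone.conv1': ['weight', 'bias'], 'pixel_mean': []}
    print("pixel_mean" + _group_to_str(groups["pixel_mean"]))          # pixel_mean
    print("backbone.conv1" + _group_to_str(groups["backbone.conv1"]))  # backbone.conv1.{weight, bias}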
@@ -401,3 +461,18 @@ def _group_to_str(group: list):
         return "." + group[0]

     return ".{" + ", ".join(group) + "}"
+
+
+def _named_modules_with_dup(
+    model: nn.Module, prefix: str = ""
+) -> Iterable[Tuple[str, nn.Module]]:
+    """
+    The same as `model.named_modules()`, except that it includes
+    duplicated modules that have more than one name.
+    """
+    yield prefix, model
+    for name, module in model._modules.items():  # pyre-ignore
+        if module is None:
+            continue
+        submodule_prefix = prefix + ("." if prefix else "") + name
+        yield from _named_modules_with_dup(module, submodule_prefix)
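The difference from `nn.Module.named_modules()` shows up when one module is mounted at two paths (a quick check, assuming the helper above is in scope):

    import torch.nn as nn

    shared = nn.Linear(4, 4)
    model = nn.Sequential(shared, shared)  # same module mounted twice

    print(len(list(model.named_modules())))           # 2 -- duplicates removed
    print(len(list(_named_modules_with_dup(model))))  # 3 -- root plus both mounts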
tools/train_net.py
@@ -40,7 +40,6 @@ def main(args):
         return res

     trainer = DefaultTrainer(cfg)
-    if args.finetune: Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS)  # load trained model to funetune

     trainer.resume_or_load(resume=args.resume)
     return trainer.train()
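End to end, the simplified calling convention looks roughly like this (config and weight paths are placeholders; fast-reid's argument parser accepts trailing `KEY VALUE` config overrides):

    # finetune: no flag needed, just point the config at the trained weights
    python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
        MODEL.WEIGHTS logs/dukemtmc/model_final.pth

    # resume an interrupted run from cfg.OUTPUT_DIR
    python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml --resume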