mirror of https://github.com/JosephKJ/OWOD.git
74 lines
3.0 KiB
Python
74 lines
3.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
import pickle
|
|
from fvcore.common.checkpoint import Checkpointer
|
|
from fvcore.common.file_io import PathManager
|
|
|
|
import detectron2.utils.comm as comm
|
|
|
|
from .c2_model_loading import align_and_update_state_dicts
|
|
|
|
|
|
class DetectionCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` specialised for detection models.

    On top of the base class it understands checkpoints published in the
    Detectron / Detectron2 model zoos (``.pkl`` pickles) and routes legacy
    Caffe2-style weights through name-matching conversion before loading.
    """

    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
        # By default, saving is enabled only when comm.is_main_process() is
        # true; an explicit save_to_disk argument overrides that.
        main_proc = comm.is_main_process()
        should_save = main_proc if save_to_disk is None else save_to_disk
        super().__init__(model, save_dir, save_to_disk=should_save, **checkpointables)

    def _load_file(self, filename):
        """
        Read a checkpoint file and return it as a dict with a ``"model"`` entry.

        ``.pkl`` files are unpickled directly (latin1 encoding):

        * if they contain both ``"model"`` and ``"__author__"`` they are
          returned unchanged (Detectron2 model-zoo layout);
        * otherwise they are treated as Caffe2 / Detectron1 weights: the
          ``"blobs"`` sub-dict is used when present, ``*_momentum`` entries
          are dropped, and the result is tagged for heuristic name matching.

        Any other file goes through the base-class loader; a result without a
        ``"model"`` key is wrapped as ``{"model": <result>}``.
        """
        if not filename.endswith(".pkl"):
            # Native checkpoint: defer to the base-class loader.
            loaded = super()._load_file(filename)
            return loaded if "model" in loaded else {"model": loaded}

        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load .pkl checkpoints from trusted sources.
        with PathManager.open(filename, "rb") as handle:
            payload = pickle.load(handle, encoding="latin1")

        if "model" in payload and "__author__" in payload:
            # Already in Detectron2 model-zoo layout; use it as-is.
            self.logger.info("Reading a file from '{}'".format(payload["__author__"]))
            return payload

        # Otherwise assume a Caffe2 / Detectron1 model-zoo pickle. Detection
        # models nest their weights under "blobs"; ImageNet models do not.
        blobs = payload["blobs"] if "blobs" in payload else payload
        weights = {
            name: value
            for name, value in blobs.items()
            if not name.endswith("_momentum")
        }
        return {"model": weights, "__author__": "Caffe2", "matching_heuristics": True}

    def _load_model(self, checkpoint):
        """
        Load ``checkpoint["model"]`` into ``self.model``.

        When ``checkpoint["matching_heuristics"]`` is truthy, the weights are
        first converted from ndarrays to tensors. They are then re-keyed
        against the model's own state dict via
        :func:`align_and_update_state_dicts`, with Caffe2 conversion enabled
        when ``checkpoint["__author__"] == "Caffe2"``. The actual loading is
        delegated to the base class.

        Returns:
            The incompatible-keys object produced by the base class. Any
            ``pixel_mean`` / ``pixel_std`` entries are removed from its
            ``missing_keys`` when they are direct buffers of the model. Returns
            ``None`` if the base class returned ``None`` (older fvcore).
        """
        if checkpoint.get("matching_heuristics", False):
            self._convert_ndarray_to_tensor(checkpoint["model"])
            # align_and_update_state_dicts is expected to fill `target` in
            # place (its return value is unused); `target` then replaces the
            # checkpoint's weights.
            target = self.model.state_dict()
            align_and_update_state_dicts(
                target,
                checkpoint["model"],
                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
            )
            checkpoint["model"] = target

        # Both heuristic-matched and native checkpoints load the standard way.
        incompatible = super()._load_model(checkpoint)
        if incompatible is None:  # older fvcore versions return nothing
            return None

        # pixel_mean / pixel_std may be absent from old checkpoints, but they
        # are initialised from the config anyway, so don't report them missing.
        own_buffers = {name for name, _ in self.model.named_buffers(recurse=False)}
        for key in ("pixel_mean", "pixel_std"):
            if key not in own_buffers:
                continue
            try:
                incompatible.missing_keys.remove(key)
            except ValueError:
                pass  # key was not reported as missing in the first place
        return incompatible
|