mirror of https://github.com/YifanXu74/MQ-Det.git
22 lines
795 B
Python
22 lines
795 B
Python
|
import argparse
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from PIL import Image, ImageDraw, ImageFont
|
||
|
|
||
|
import groundingdino_new.datasets.transforms as T
|
||
|
from groundingdino_new.models import build_model
|
||
|
from groundingdino_new.util import box_ops
|
||
|
from groundingdino_new.util.slconfig import SLConfig
|
||
|
from groundingdino_new.util.utils import clean_state_dict, get_phrases_from_posmap
|
||
|
|
||
|
def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """Build a model from a config file and load checkpoint weights into it.

    Args:
        model_config_path: Path handed to ``SLConfig.fromfile``.
        model_checkpoint_path: Path handed to ``torch.load``; the loaded
            object must provide a ``"model"`` entry holding the state dict.
        cpu_only: When truthy the config's ``device`` is set to ``"cpu"``,
            otherwise to ``"cuda"``.

    Returns:
        The object produced by ``build_model`` after ``load_state_dict`` has
        been called on it. This function itself does not call ``.eval()`` or
        ``.to(...)`` on the model.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = "cpu" if cpu_only else "cuda"
    net = build_model(config)

    # The checkpoint is always deserialized onto the CPU, independent of
    # the device recorded in the config above.
    state = torch.load(model_checkpoint_path, map_location="cpu")
    weights = clean_state_dict(state["model"])

    # strict=False: missing/unexpected keys are reported in the returned
    # result instead of raising, so the report is printed for inspection.
    print(net.load_state_dict(weights, strict=False))
    return net
|