Mirror of https://github.com/JDAI-CV/fast-reid.git (synced 2025-06-03 14:50:47 +08:00)

Upload demo.py and example

commit c6e0176c53 (parent 3f3975e372)
demo/README.md (new file, 10 lines)
@@ -0,0 +1,10 @@
# FastReID Demo

We provide a command line tool to run a simple demo of builtin models.

You can run this command to get cosine similarities between different images:

```bash
cd demo/
sh run_demo.sh
```
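The "cosine similarity" the demo prints is simply a dot product of the feature vectors returned by the model, which are assumed here to be L2-normalized. A minimal NumPy sketch of the same computation, independent of the repo:

```python
import numpy as np

def cosine_similarity(feat_a, feat_b):
    """Cosine similarity between two feature vectors.

    demo.py relies on the model returning L2-normalized features and uses a
    bare dot product; the explicit normalization here makes this sketch
    independent of that assumption.
    """
    feat_a = np.ravel(feat_a)
    feat_b = np.ravel(feat_b)
    feat_a = feat_a / np.linalg.norm(feat_a)
    feat_b = feat_b / np.linalg.norm(feat_b)
    return float(np.dot(feat_a, feat_b))
```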
demo/demo.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import argparse
import os
import sys

import cv2
import numpy as np
import torch
from torch import nn
from torch.backends import cudnn

sys.path.append('..')

from fastreid.config import get_cfg
from fastreid.data.transforms import ToTensor
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer

cudnn.benchmark = True


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="FastReID demo for builtin models")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default="traced_module/",
        help="A file or directory to save the exported jit module.",
    )
    parser.add_argument(
        "--export-jitmodule",
        action='store_true',
        help="whether to export the reid model as a traced jit module",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


class ReidDemo(object):
    """
    ReID demo example
    """

    def __init__(self, cfg):
        self.cfg = cfg.clone()
        if cfg.MODEL.WEIGHTS.endswith('.pt'):
            # weights are an already traced jit module, load it directly
            self.model = torch.jit.load(cfg.MODEL.WEIGHTS)
        else:
            # build the model from config and load pre-trained weights
            self.model = build_model(cfg)
            Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)

        self.model.eval()
        # self.model = nn.DataParallel(self.model)
        self.model.cuda()

        # per-channel normalization statistics, shaped for broadcasting over NCHW
        num_channels = len(cfg.MODEL.PIXEL_MEAN)
        self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
        self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)

    def preprocess(self, img):
        # SIZE_TEST is (height, width); cv2.resize expects (width, height)
        img = cv2.resize(img, tuple(self.cfg.INPUT.SIZE_TEST[::-1]))
        img = ToTensor()(img)[None, :, :, :]
        return img.sub_(self.mean).div_(self.std)

    @torch.no_grad()
    def predict(self, img_path):
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        data = self.preprocess(img)
        output = self.model.inference(data.cuda())
        feat = output.cpu().data.numpy()
        return feat

    @classmethod
    @torch.no_grad()
    def export_jit_model(cls, cfg, model, output_dir):
        # trace the inference method with a dummy batch of the test input size
        example = torch.rand(1, len(cfg.MODEL.PIXEL_MEAN), *cfg.INPUT.SIZE_TEST)
        example = example.cuda()
        # if isinstance(model, (nn.DistributedDataParallel, nn.DataParallel)):
        #     model = model.module
        traced_script_module = torch.jit.trace_module(model, {"inference": example})
        os.makedirs(output_dir, exist_ok=True)
        traced_script_module.save(os.path.join(output_dir, "traced_reid_module.pt"))


if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    reidSystem = ReidDemo(cfg)
    if args.export_jitmodule and not isinstance(reidSystem.model, torch.jit.ScriptModule):
        reidSystem.export_jit_model(cfg, reidSystem.model, args.output)

    # extract one feature vector per input image
    feats = [reidSystem.predict(path) for path in args.input]

    # pairwise cosine similarities between the three input images
    cos_12 = np.dot(feats[0], feats[1].T).item()
    cos_13 = np.dot(feats[0], feats[2].T).item()
    cos_23 = np.dot(feats[1], feats[2].T).item()

    print('cosine similarity is {:.4f}, {:.4f}, {:.4f}'.format(cos_12, cos_13, cos_23))
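For reference, a hedged sketch of driving the ReidDemo class above from Python instead of the shell script; the config, weights, and image paths below are placeholders, not files shipped with this commit, and a CUDA-capable GPU is required as in demo.py:

```python
# Hypothetical usage of demo.py's ReidDemo; all paths are placeholders.
from demo import ReidDemo, get_parser, setup_cfg

args = get_parser().parse_args([
    "--config-file", "logs/market1501/baseline/config.yaml",
    "--input", "query.jpg", "gallery_same_id.jpg", "gallery_other_id.jpg",
    "--opts", "MODEL.WEIGHTS", "logs/market1501/baseline/model_final.pth",
])
cfg = setup_cfg(args)

reid = ReidDemo(cfg)                            # builds the model and loads the weights
feats = [reid.predict(p) for p in args.input]   # one (1, D) feature per image
print((feats[0] @ feats[1].T).item())           # higher for the same identity
```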
demo/run_demo.sh (new file, 9 lines)
@@ -0,0 +1,9 @@
gpus='0'

CUDA_VISIBLE_DEVICES=$gpus python demo.py --config-file 'logs/market1501/baseline/config.yaml' \
--input \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
--output 'logs/market1501/baseline/' \
--opts MODEL.WEIGHTS 'logs/market1501/baseline/model_final.pth'
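demo.py hard-codes the three pairwise comparisons matching the three images passed above; since --input also accepts a glob pattern, here is a hedged sketch (not part of the commit) of generalizing the comparison to any number of extracted features:

```python
import numpy as np

def pairwise_cosine(feats):
    """NxN cosine-similarity matrix for a list of (1, D) feature arrays,
    such as those returned by ReidDemo.predict."""
    mat = np.vstack([np.ravel(f) for f in feats])
    mat = mat / np.linalg.norm(mat, axis=1, keepdims=True)
    return mat @ mat.T
```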
@@ -14,7 +14,6 @@ from ...layers import bn_no_bias, Flatten


@REID_HEADS_REGISTRY.register()
class BNneckHead(nn.Module):
    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
        super().__init__()
        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES