mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
69 lines
1.7 KiB
Python
69 lines
1.7 KiB
Python
# encoding: utf-8
|
|
"""
|
|
@author: sherlock
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
from pprint import pprint
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.backends import cudnn
|
|
|
|
import network
|
|
from core.config import opt, update_config
|
|
from core.loader import get_data_provider
|
|
from core.solver import Solver
|
|
|
|
# Root logging configuration: records render as "[LEVEL]: message" and are
# written to stdout at INFO level and above.
FORMAT = '[%(levelname)s]: %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO, stream=sys.stdout)
|
|
|
|
|
|
def test(args):
    """Evaluate a trained re-id model on the test split.

    Logs the active configuration, builds the network class named by
    ``opt.network.name``, loads weights from ``args.load_model``, wraps the
    model in ``nn.DataParallel`` on GPU and runs ``Solver.test_func``.

    Args:
        args: parsed command-line namespace; must provide ``load_model``, the
            path to a checkpoint whose loaded object has a ``'state_dict'``
            entry.
    """
    # pprint() writes to stdout and returns None, so logging its return value
    # only ever logged "None". pformat() returns the text so it can be logged.
    from pprint import pformat

    logging.info('======= user config ======')
    logging.info(pformat(opt))
    logging.info(pformat(args))
    logging.info('======= end ======')

    # Only the test split and the number of query samples are needed here.
    _train_data, test_data, num_query = get_data_provider(opt)

    net = getattr(network, opt.network.name)(opt.dataset.num_classes, opt.network.last_stride)
    # Deserialize onto CPU: load_state_dict copies the values into the model's
    # own parameters anyway, and this avoids failing when the checkpoint was
    # saved on a GPU index that CUDA_VISIBLE_DEVICES no longer exposes.
    checkpoint = torch.load(args.load_model, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    net = nn.DataParallel(net).cuda()

    mod = Solver(opt, net)
    mod.test_func(test_data, num_query)
|
|
|
|
|
|
def main():
    """Command-line entry point: parse flags, apply the config, run ``test``.

    Flags:
        --config_file: optional path handed to ``update_config`` before testing.
        --load_model: required checkpoint path, forwarded to ``test``.
    """
    cli = argparse.ArgumentParser(description='reid model testing')
    cli.add_argument('--config_file', type=str, default=None,
                     help='Optional config file for params')
    cli.add_argument('--load_model', type=str, required=True,
                     help='load trained model for testing')
    parsed = cli.parse_args()

    # Overlay user settings onto the global ``opt`` only when a file was given.
    config_path = parsed.config_file
    if config_path is not None:
        update_config(config_path)

    # Select visible GPUs before ``test`` moves the model to CUDA.
    # NOTE(review): assumes opt.network.gpus is a str (os.environ requires one).
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.network.gpus
    # Enable cuDNN autotuning of convolution algorithms.
    cudnn.benchmark = True

    test(parsed)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|