2019-03-17 15:16:48 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: sherlock
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
|
|
|
from .baseline import Baseline
|
|
|
|
|
|
|
|
|
|
|
|
def build_model(cfg, num_classes):
    """Create the ``Baseline`` network described by ``cfg``.

    Args:
        cfg: Configuration object. The attributes read here are
            ``cfg.MODEL.LAST_STRIDE``, ``cfg.MODEL.PRETRAIN_PATH``,
            ``cfg.MODEL.NECK``, ``cfg.TEST.NECK_FEAT``, ``cfg.MODEL.NAME``
            and ``cfg.MODEL.PRETRAIN_CHOICE``.
        num_classes: Forwarded to ``Baseline`` as its first positional
            argument.

    Returns:
        The newly constructed ``Baseline`` instance.
    """
    # No branching on cfg.MODEL.NAME happens here: the name is handed to
    # Baseline together with the rest of the settings.
    return Baseline(
        num_classes,
        cfg.MODEL.LAST_STRIDE,
        cfg.MODEL.PRETRAIN_PATH,
        cfg.MODEL.NECK,
        cfg.TEST.NECK_FEAT,
        cfg.MODEL.NAME,
        cfg.MODEL.PRETRAIN_CHOICE,
    )
|