PyRetri/search/reid_search_modules/extract_dict.py

71 lines
1.3 KiB
Python
Raw Normal View History

2020-04-02 14:00:49 +08:00
# -*- coding: utf-8 -*-
from utils.search_modules import SearchModules
2020-04-15 14:44:22 +08:00
from pyretri.config import get_defaults_cfg
2020-04-02 14:00:49 +08:00
def _reid_extract_config() -> dict:
    """Build the feature-extraction config shared by every re-ID entry below.

    A new dict is returned on each call, so every ``extracts.add`` receives its
    own object. This matches the original file, where the same literal was
    written out separately for each dataset.
    """
    return {
        "assemble": 1,
        "extractor": {
            "name": "ReIDSeries",
            "ReIDSeries": {
                "extract_features": ["output"],
            }
        },
        "splitter": {
            # NOTE(review): "Identity" presumably means no region splitting — confirm.
            "name": "Identity",
        },
        "aggregators": {
            # NOTE(review): "GAP" is presumably global average pooling — confirm.
            "names": ["GAP"]
        },
    }


# Search-space registries. Each key ("market_res50", "duke_res50") is added to
# both `models` and `extracts`, pairing a model config with an extraction config.
models = SearchModules()
extracts = SearchModules()

# `ft_net` loaded from the `res50_market1501.pth` checkpoint.
models.add(
    "market_res50",
    {
        "name": "ft_net",
        "ft_net": {
            "load_checkpoint": "/data/my_model_zoo/res50_market1501.pth"
        }
    }
)
extracts.add("market_res50", _reid_extract_config())

# `ft_net` loaded from the `res50_duke.pth` checkpoint.
models.add(
    "duke_res50",
    {
        "name": "ft_net",
        "ft_net": {
            # NOTE(review): absolute path under a personal home directory, unlike
            # the market entry above (/data/my_model_zoo/). Point this at your own
            # checkpoint location before running.
            "load_checkpoint": "/home/songrenjie/projects/reID_baseline/model/ft_ResNet50/res50_duke.pth"
        }
    }
)
extracts.add("duke_res50", _reid_extract_config())

# Validate both registries against the project's default config sections.
# What `check_valid` enforces is defined by SearchModules (not visible here).
cfg = get_defaults_cfg()
models.check_valid(cfg["model"])
extracts.check_valid(cfg["extract"])