PyRetri/search/search_modules/extract_dict.py

125 lines
2.4 KiB
Python
Raw Permalink Normal View History

2020-04-02 14:00:49 +08:00
# -*- coding: utf-8 -*-
from utils.search_modules import SearchModules
2020-04-15 14:44:22 +08:00
from pyretri.config import get_defaults_cfg
2020-04-02 14:00:49 +08:00
# Two separate SearchModules registries that share search-space keys:
# ``models`` holds backbone model configurations and ``extracts`` holds the
# matching feature-extraction configurations.
models, extracts = SearchModules(), SearchModules()
# Aggregators evaluated for every backbone. A tuple is used so the shared
# default cannot be mutated; each registry entry receives its own list copy.
_AGGREGATOR_NAMES = ("Crow", "GAP", "GMP", "GeM", "SPoC")

# NOTE(review): machine-specific location of the Places365 / Hybrid1365
# ResNet-50 weights; adjust for the local environment.
_PLACES365_MODEL_DIR = "/data/places365_model"


def _add_search_pair(key, arch, checkpoint, extractor, aggregators=_AGGREGATOR_NAMES):
    """Register a model entry and its matching extract entry under ``key``.

    Keeping both registrations in one call guarantees that ``models`` and
    ``extracts`` always contain the same search-space keys.

    Args:
        key: search-space name used in both the ``models`` and ``extracts``
            registries.
        arch: model name; also used as the key of that model's options dict.
        checkpoint: value stored as ``load_checkpoint`` for ``arch``.
        extractor: extractor name; also used as the key of that extractor's
            options dict.
        aggregators: iterable of aggregator names to evaluate; defaults to
            ``_AGGREGATOR_NAMES``.

    New dicts and lists are built on every call, so no two registry entries
    share a mutable object.
    """
    models.add(
        key,
        {
            "name": arch,
            arch: {
                "load_checkpoint": checkpoint,
            },
        },
    )
    extracts.add(
        key,
        {
            "extractor": {
                "name": extractor,
                extractor: {
                    "extract_features": ["all"],
                },
            },
            "splitter": {
                "name": "Identity",
            },
            "aggregators": {
                "names": list(aggregators),
            },
        },
    )


_add_search_pair("imagenet_vgg16", "vgg16", "torchvision://vgg16", "VggSeries")
_add_search_pair("imagenet_res50", "resnet50", "torchvision://resnet50", "ResSeries")
_add_search_pair(
    "places365_res50",
    "resnet50",
    _PLACES365_MODEL_DIR + "/res50_places365.pt",
    "ResSeries",
)
_add_search_pair(
    "hybrid1365_res50",
    "resnet50",
    _PLACES365_MODEL_DIR + "/res50_hybrid1365.pt",
    "ResSeries",
)
# Check each registry against its section of the project's default config:
# model entries against cfg["model"], then extract entries against
# cfg["extract"]. NOTE(review): presumably check_valid raises on keys the
# defaults do not define; confirm in utils.search_modules.
cfg = get_defaults_cfg()
for _registry, _section in ((models, "model"), (extracts, "extract")):
    _registry.check_valid(cfg[_section])
del _registry, _section  # keep the loop variables out of the module namespace