style(configs): put all config files together

put all config files into one place for easily control,
and add tools for put train_net.py which almost the same in
different projects
This commit is contained in:
liaoxingyu 2020-04-29 16:18:54 +08:00
parent e38a799b63
commit ec19bcc1d3
17 changed files with 536 additions and 209 deletions

31
MODEL_ZOO.md Normal file
View File

@ -0,0 +1,31 @@
# FastReID Model Zoo and Baselines
## Introduction
## Market1501 Baselines
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks(R50) | ImageNet | 93.6% | 85.1% | 58.1% |
| BagTricks(R50-ibn) | ImageNet | 94.8% | 87.3% | 63.5% |
| AGW(R50) | ImageNet | 94.9% | 87.4% | 63.1% |
| stronger baseline(R50-ibn) | ImageNet | 95.5% | 88.4% | 65.8% |
## DukeMTMC Baseline
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks(R50) | ImageNet | 86.1% | 75.9% | 38.7% |
| BagTricks(R50-ibn) | ImageNet | 89.0% | 78.8% | 43.6% |
| AGW(R50) | ImageNet | 88.9% | 79.1% | 43.2% |
| stronger baseline(R50-ibn) | ImageNet | 91.3% | 81.6% | 47.6% |
### MSMT17 Baseline
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks(R50) | ImageNet | 70.4% | 47.5% | 9.6% |
| BagTricks(R50-ibn) | ImageNet | 76.9% | 55.0% | 13.5% |
| AGW(R50) | ImageNet | 75.6% | 52.6% | 11.9% |
| stronger baseline(R50-ibn) | ImageNet | 84.2% | 61.5% | 15.7% |

View File

@ -11,13 +11,12 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
3. Install dependencies:
- [pytorch 1.0.0+](https://pytorch.org/)
- torchvision
- tensorboard
- [yacs](https://github.com/rbgirshick/yacs)
4. Prepare dataset
Create a directory to store reid datasets under projects, for example
```bash
cd fast-reid/projects/StrongBaseline
cd fast-reid
mkdir datasets
```
@ -34,7 +33,7 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
5. Prepare pretrained model.
If you use origin ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download pretrain model in [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S). And then you can put it in `~/.cache/torch/checkpoints` or anywhere you like.
Then you should set the pretrain model path in `configs/baseline_market1501.yml`.
Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.
6. compile with cython to accelerate evalution
@ -44,28 +43,4 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
## Model Zoo and Baselines
### Market1501 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 93.6% | 85.1% | 58.1% |
| BagTricks + Ibn-a | ImageNet | 94.8% | 87.3% | 63.5% |
| AGW | ImageNet | 94.9% | 87.4% | 63.1% |
### DukeMTMC dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 86.1% | 75.9% | 38.7% |
| BagTricks + Ibn-a | ImageNet | 89.0% | 78.8% | 43.6% |
| AGW | ImageNet | 88.9% | 79.1% | 43.2% |
### MSMT17 dataset
| Method | Pretrained | Rank@1 | mAP | mINP |
| :---: | :---: | :---: |:---: | :---: |
| BagTricks | ImageNet | 70.4% | 47.5% | 9.6% |
| BagTricks + Ibn-a | ImageNet | 76.9% | 55.0% | 13.5% |
| AGW | ImageNet | 75.6% | 52.6% | 11.9% |
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).

71
configs/Base-AGW.yml Normal file
View File

@ -0,0 +1,71 @@
MODEL:
META_ARCHITECTURE: 'Baseline'
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_NL: True
PRETRAIN: True
HEADS:
NAME: "BNneckHead"
POOL_LAYER: "gempool"
CLS_LAYER: "linear"
NUM_CLASSES: 702
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
CE:
EPSILON: 0.1
SCALE: 1.0
TRI:
MARGIN: 0.0
HARD_MINING: False
USE_COSINE_DIST: False
SCALE: 1.0
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
INPUT:
SIZE_TRAIN: [256, 128]
SIZE_TEST: [256, 128]
REA:
ENABLED: True
PROB: 0.5
MEAN: [123.675, 116.28, 103.53]
DO_PAD: True
DATALOADER:
PK_SAMPLER: True
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "Adam"
MAX_ITER: 18000
BASE_LR: 0.00035
BIAS_LR_FACTOR: 2.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 64
STEPS: [8000, 14000]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 2000
LOG_PERIOD: 200
CHECKPOINT_PERIOD: 6000
TEST:
EVAL_PERIOD: 2000
IMS_PER_BATCH: 512
CUDNN_BENCHMARK: True
OUTPUT_DIR: "logs"

View File

@ -0,0 +1,84 @@
MODEL:
META_ARCHITECTURE: 'Baseline'
OPEN_LAYERS: "heads"
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_IBN: False
WITH_NL: False
PRETRAIN: True
HEADS:
NAME: "BNneckHead"
CLS_LAYER: "circle"
POOL_LAYER: "gempool"
SCALE: 64
MARGIN: 0.35
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss",)
CE:
EPSILON: 0.1
SCALE: 1.0
TRI:
MARGIN: 0.0
HARD_MINING: True
NORM_FEAT: False
USE_COSINE_DIST: False
SCALE: 1.0
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
DO_AUTOAUG: True
REA:
ENABLED: True
PROB: 0.5
MEAN: [123.675, 116.28, 103.53]
DO_PAD: True
DATALOADER:
PK_SAMPLER: True
NUM_INSTANCE: 16
NUM_WORKERS: 16
SOLVER:
OPT: "Adam"
MAX_ITER: 18000
BASE_LR: 0.00035
BIAS_LR_FACTOR: 2.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0
IMS_PER_BATCH: 64
SCHED: "DelayedCosineAnnealingLR"
DELAY_ITERS: 2000
ETA_MIN_LR: 0.00000077
STEPS: [8000, 14000]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 2000
FREEZE_ITERS: 2000
LOG_PERIOD: 200
CHECKPOINT_PERIOD: 6000
TEST:
EVAL_PERIOD: 2000
IMS_PER_BATCH: 512
CUDNN_BENCHMARK: True
OUTPUT_DIR: "logs/dukemtmc/softmax"

View File

@ -0,0 +1,73 @@
MODEL:
META_ARCHITECTURE: 'Baseline'
OPEN_LAYERS: ""
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
LAST_STRIDE: 1
WITH_IBN: False
PRETRAIN: True
HEADS:
NAME: "BNneckHead"
CLS_LAYER: "linear"
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
CE:
EPSILON: 0.1
SCALE: 1.0
TRI:
MARGIN: 0.0
HARD_MINING: False
USE_COSINE_DIST: False
SCALE: 1.0
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
INPUT:
SIZE_TRAIN: [384, 128]
SIZE_TEST: [384, 128]
REA:
ENABLED: True
PROB: 0.5
MEAN: [123.675, 116.28, 103.53]
DO_PAD: True
DATALOADER:
PK_SAMPLER: True
NUM_INSTANCE: 4
NUM_WORKERS: 16
SOLVER:
OPT: "Adam"
MAX_ITER: 18000
BASE_LR: 0.00035
BIAS_LR_FACTOR: 2.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.
IMS_PER_BATCH: 128
STEPS: [8000, 14000]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 2000
LOG_PERIOD: 200
CHECKPOINT_PERIOD: 2000
TEST:
EVAL_PERIOD: 2000
IMS_PER_BATCH: 512
CUDNN_BENCHMARK: True
OUTPUT_DIR: "logs/dukemtmc/softmax"

18
configs/DukeMTMC/AGW.yml Normal file
View File

@ -0,0 +1,18 @@
_BASE_: "../Base-AGW.yml"
MODEL:
HEADS:
NUM_CLASSES: 702
SOLVER:
MAX_ITER: 23000
STEPS: [10000, 18000]
WARMUP_ITERS: 2500
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/agw"

View File

@ -0,0 +1,18 @@
_BASE_: "../Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 702
SOLVER:
MAX_ITER: 12500
STEPS: [5000, 9000]
WARMUP_ITERS: 1250
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/resnet50_baseline"

23
configs/DukeMTMC/sbs.yml Normal file
View File

@ -0,0 +1,23 @@
_BASE_: "../Base-Strongerbaseline.yml"
MODEL:
BACKBONE:
NAME: "build_resnest_backbone"
WITH_IBN: False
WITH_NL: True
PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
HEADS:
NUM_CLASSES: 702
SOLVER:
MAX_ITER: 5
DELAY_ITERS: 8000
WARMUP_ITERS: 2600
FREEZE_ITERS: 2600
DATASETS:
NAMES: ("DukeMTMC",)
TESTS: ("DukeMTMC",)
OUTPUT_DIR: "logs/dukemtmc/test_prebn"

22
configs/MSMT17/AGW.yml Normal file
View File

@ -0,0 +1,22 @@
_BASE_: "../Base-AGW.yml"
MODEL:
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 42000
STEPS: [19000, 33000]
WARMUP_ITERS: 4700
CHECKPOINT_PERIOD: 5000
TEST:
EVAL_PERIOD: 5000
OUTPUT_DIR: "logs/msmt17/agw"

View File

@ -0,0 +1,21 @@
_BASE_: "../Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 42000
STEPS: [19000, 33000]
WARMUP_ITERS: 4700
CHECKPOINT_PERIOD: 5000
TEST:
EVAL_PERIOD: 5000
OUTPUT_DIR: "logs/msmt17/bagtricks"

29
configs/MSMT17/sbs.yml Normal file
View File

@ -0,0 +1,29 @@
_BASE_: "../Base-Strongerbaseline.yml"
MODEL:
BACKBONE:
NAME: "build_resnest_backbone"
WITH_IBN: False
WITH_NL: True
PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
HEADS:
NUM_CLASSES: 1041
DATASETS:
NAMES: ("MSMT17",)
TESTS: ("MSMT17",)
SOLVER:
MAX_ITER: 29000
DELAY_ITERS: 14000
WARMUP_ITERS: 4700
FREEZE_ITERS: 4700
CHECKPOINT_PERIOD: 4000
TEST:
EVAL_PERIOD: 4000
OUTPUT_DIR: "logs/msmt17/resnest-nl-gem-circle_s64m0.35_loss-cos_delay-autoaug"

View File

@ -0,0 +1,18 @@
_BASE_: "../Base-AGW.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
SOLVER:
MAX_ITER: 18000
STEPS: [8000, 14000]
WARMUP_ITERS: 2000
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)
OUTPUT_DIR: "logs/market1501/agw"

View File

@ -0,0 +1,19 @@
_BASE_: "../Base-bagtricks.yml"
MODEL:
HEADS:
NUM_CLASSES: 751
SOLVER:
MAX_ITER: 18000
STEPS: [8000, 14000]
WARMUP_ITERS: 2000
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)
OUTPUT_DIR: "logs/market1501/bagtricks"

View File

@ -0,0 +1,26 @@
_BASE_: "../Base-Strongerbaseline.yml"
MODEL:
BACKBONE:
NAME: "build_resnest_backbone"
WITH_IBN: False
WITH_NL: True
PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
HEADS:
NUM_CLASSES: 751
SOLVER:
MAX_ITER: 16000
DELAY_ITERS: 8000
WARMUP_ITERS: 2600
FREEZE_ITERS: 2600
DATASETS:
NAMES: ("Market1501",)
TESTS: ("Market1501",)
TEST:
EVAL_PERIOD: 2000
OUTPUT_DIR: "logs/market1501/resnest-nl-gem-circle_s64m0.35_loss-cos_delay-autoaug"

View File

@ -1,166 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
import sys
import time
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer
from fastreid.config import get_cfg
cudnn.benchmark = True
class Reid(object):
def __init__(self, config_file):
cfg = get_cfg()
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
model = build_model(cfg)
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
self.model = model
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
example = torch.rand(1, 3, 256, 128)
example = example.cuda()
traced_script_module = torch.jit.trace_module(model, {'inference': example})
traced_script_module.save("reid_feat_extractor.pt")
@classmethod
def preprocess(cls, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 256))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
return data
@torch.no_grad()
def demo(self, img_path):
data = self.preprocess(img_path)
output = self.model.inference(data)
feat = output.cpu().data.numpy()
return feat
# @torch.no_grad()
# def extract_feat(self, dataloader):
# prefetcher = test_data_prefetcher(dataloader)
# feats = []
# labels = []
# batch = prefetcher.next()
# num_count = 0
# while batch[0] is not None:
# img, pid, camid = batch
# feat = self.model(img)
# feats.append(feat.cpu())
# labels.extend(np.asarray(pid))
#
# # if num_count > 2:
# # break
# batch = prefetcher.next()
# # num_count += 1
#
# feats = torch.cat(feats, dim=0)
# id_feats = defaultdict(list)
# for f, i in zip(feats, labels):
# id_feats[i].append(f)
# all_feats = []
# label_names = []
# for i in id_feats:
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
# label_names.append(i)
#
# label_names = np.asarray(label_names)
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
# all_feats = F.normalize(all_feats, p=2, dim=1)
# np.save('feats.npy', all_feats.cpu())
# np.save('labels.npy', label_names)
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
# cos -= np.eye(all_feats.shape[0])
# f = open('check_cross_folder_similarity.txt', 'w')
# for i in range(len(label_names)):
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
# sim_name = label_names[sim_indx]
# write_str = label_names[i] + ' '
# # f.write(label_names[i]+'\t')
# for n in sim_name:
# write_str += (n + ' ')
# # f.write(n+'\t')
# f.write(write_str+'\n')
#
#
# def prepare_gt(self, json_file):
# feat = []
# label = []
# with open(json_file, 'r') as f:
# total = json.load(f)
# for index in total:
# label.append(index)
# feat.append(np.array(total[index]))
# time_label = [int(i[0:10]) for i in label]
#
# return np.array(feat), np.array(label), np.array(time_label)
def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result
if __name__ == '__main__':
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
feat = reid_sys.demo(img_path)
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
data = reid_sys.preprocess(img_path)
feat2 = feat_extractor.inference(data)
from ipdb import set_trace; set_trace()
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')

View File

@ -6,11 +6,12 @@
import torch
import torch.nn.functional as F
from torch.nn import Conv2d, Module, ReLU
from torch import nn
from torch.nn import Conv2d, ReLU
from torch.nn.modules.utils import _pair
class SplAtConv2d(Module):
class SplAtConv2d(nn.Module):
"""Split-Attention Conv2d
"""
@ -36,13 +37,15 @@ class SplAtConv2d(Module):
self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.use_bn = norm_layer is not None
self.bn0 = norm_layer(channels * radix)
if self.use_bn:
self.bn0 = norm_layer(channels * radix)
self.relu = ReLU(inplace=True)
self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
self.bn1 = norm_layer(inter_channels)
if self.use_bn:
self.bn1 = norm_layer(inter_channels)
self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
if dropblock_prob > 0.0:
self.dropblock = DropBlock2D(dropblock_prob, 3)
self.rsoftmax = rSoftMax(radix, groups)
def forward(self, x):
x = self.conv(x)
@ -52,9 +55,9 @@ class SplAtConv2d(Module):
x = self.dropblock(x)
x = self.relu(x)
batch, channel = x.shape[:2]
batch, rchannel = x.shape[:2]
if self.radix > 1:
splited = torch.split(x, channel // self.radix, dim=1)
splited = torch.split(x, rchannel // self.radix, dim=1)
gap = sum(splited)
else:
gap = x
@ -65,15 +68,29 @@ class SplAtConv2d(Module):
gap = self.bn1(gap)
gap = self.relu(gap)
atten = self.fc2(gap).view((batch, self.radix, self.channels))
if self.radix > 1:
atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1)
else:
atten = F.sigmoid(atten, dim=1).view(batch, -1, 1, 1)
atten = self.fc2(gap)
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
if self.radix > 1:
atten = torch.split(atten, channel // self.radix, dim=1)
out = sum([att * split for (att, split) in zip(atten, splited)])
attens = torch.split(atten, rchannel // self.radix, dim=1)
out = sum([att * split for (att, split) in zip(attens, splited)])
else:
out = atten * x
return out.contiguous()
class rSoftMax(nn.Module):
def __init__(self, radix, cardinality):
super().__init__()
self.radix = radix
self.cardinality = cardinality
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x

48
tools/export2tf.py Normal file
View File

@ -0,0 +1,48 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import sys
import torch
sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup
from fastreid.modeling.meta_arch import build_model
from fastreid.export.tensorflow_export import export_tf_reid_model
from fastreid.export.tf_modeling import TfMetaArch
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
# cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
cfg = setup(args)
cfg.defrost()
cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
cfg.MODEL.BACKBONE.DEPTH = 50
cfg.MODEL.BACKBONE.LAST_STRIDE = 1
# If use IBN block in backbone
cfg.MODEL.BACKBONE.WITH_IBN = False
cfg.MODEL.BACKBONE.PRETRAIN = False
from torchvision.models import resnet50
# model = TfMetaArch(cfg)
model = resnet50(pretrained=False)
# model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
model.eval()
dummy_inputs = torch.randn(1, 3, 256, 128)
export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')