fast-reid/fastreid/export/tf_modeling.py

21 lines
514 B
Python

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
from ..modeling.backbones import build_backbone
from ..modeling.heads import build_reid_heads
class TfMetaArch(nn.Module):
    """Backbone followed by a re-id head, wrapped as one module.

    Both sub-modules are constructed from ``cfg`` by the project factories
    ``build_backbone`` and ``build_reid_heads``. The ``Tf`` prefix suggests
    this is the graph handed to a TensorFlow export step.
    NOTE(review): confirm against the caller.
    """

    def __init__(self, cfg):
        """Build the backbone and the heads.

        Args:
            cfg: configuration object, passed unchanged to both factories.
        """
        super(TfMetaArch, self).__init__()
        # Sub-modules are registered in assignment order (backbone first,
        # then heads). That order determines the state_dict key order, so
        # keep it when editing.
        self.backbone = build_backbone(cfg)
        self.heads = build_reid_heads(cfg)

    def forward(self, x):
        """Compute ``heads(backbone(x))``.

        Args:
            x: input batch, presumably an image tensor of shape
                (N, C, H, W). TODO: confirm against the backbone.

        Returns:
            Whatever ``self.heads`` returns for the backbone output.
        """
        # The heads receive only the backbone output (no targets). This
        # assumes they accept a single positional argument; verify.
        return self.heads(self.backbone(x))