mirror of https://github.com/JDAI-CV/fast-reid.git

add pytorch to caffe converting

Summary: update deployment README and support pytorch to caffe converting for basemodel

parent 6ed6d25c6c
commit 2f6d999469

@@ -27,6 +27,14 @@ Learn more at our [documentation](). And see [projects/](https://github.com/JDAI
We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).

## Deployment

We provide some examples and scripts to convert fastreid models to Caffe, ONNX and TensorRT format in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).

## License

Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).

## Citing Fastreid

If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
demo/demo.py (17 changed lines)
@@ -17,6 +17,7 @@ from torch.backends import cudnn

sys.path.append('..')

from fastreid.config import get_cfg
from fastreid.utils.file_io import PathManager
from predictor import FeatureExtractionDemo

cudnn.benchmark = True

@@ -54,6 +55,11 @@ def get_parser():
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='demo_output',
        help='path to save features'
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",

@@ -68,16 +74,13 @@ if __name__ == '__main__':
    cfg = setup_cfg(args)
    demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)

    feats = []
    PathManager.mkdirs(args.output)
    if args.input:
        if len(args.input) == 1:
            if PathManager.isdir(args.input[0]):
                args.input = glob.glob(os.path.expanduser(args.input[0]))
        assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input):
            img = cv2.imread(path)
            feat = demo.run_on_image(img)
            feats.append(feat.numpy())
            cos_sim = np.dot(feats[0], feats[1].T).item()
            print('cosine similarity of the first two images is {:.4f}'.format(cos_sim))
            feat = feat.numpy()
            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
@ -0,0 +1,21 @@
|
|||
# The Caffe in PytorchToCaffe Provides some convenient API
|
||||
If there are some problem in parse your prototxt or caffemodel, Please replace
|
||||
the caffe.proto with your own version and compile it with command
|
||||
`protoc --python_out ./ caffe.proto`
|
||||
|
||||
## caffe_net.py
|
||||
Using `from nn_tools.Caffe import caffe_net` to import this model
|
||||
### Prototxt
|
||||
+ `net=caffe_net.Prototxt(file_name)` to open a prototxt file
|
||||
+ `net.init_caffemodel(caffe_cmd_path='caffe')` to generate a caffemodel file in the current work directory \
|
||||
if your `caffe` cmd not in the $PATH, specify your caffe cmd path by the `caffe_cmd_path` kwargs.
|
||||
### Caffemodel
|
||||
+ `net=caffe_net.Caffemodel(file_name)` to open a caffemodel
|
||||
+ `net.save_prototxt(path)` to save the caffemodel to a prototxt file (not containing the weight data)
|
||||
+ `net.get_layer_data(layer_name)` return the numpy ndarray data of the layer
|
||||
+ `net.set_layer_date(layer_name, datas)` specify the data of one layer in the caffemodel .`datas` is normally a list of numpy ndarray `[weights,bias]`
|
||||
+ `net.save(path)` save the changed caffemodel
|
||||
### Functions for both Prototxt and Caffemodel
|
||||
+ `net.add_layer(layer_params,before='',after='')` add a new layer with `Layer_Param` object
|
||||
+ `net.remove_layer_by_name(layer_name)`
|
||||
+ `net.get_layer_by_name(layer_name)` or `net.layer(layer_name)` get the raw Layer object defined in caffe_pb2
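As a quick sanity check of this API, here is a minimal sketch (the `net.caffemodel` path and the `conv1` layer name are assumptions for illustration):

```python
import numpy as np
from nn_tools.Caffe import caffe_net

# open an existing caffemodel and inspect one layer's blobs
net = caffe_net.Caffemodel('net.caffemodel')   # hypothetical path
weights, bias = net.get_layer_data('conv1')    # normally [weights, bias] ndarrays
print(weights.shape, bias.shape)

# scale the weights and write the model back
net.set_layer_data('conv1', [weights * 0.5, bias])
net.save('net_scaled.caffemodel')

# dump the architecture without the weight data
net.save_prototxt('net.prototxt')
```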
File diff suppressed because it is too large
@@ -0,0 +1,35 @@
import lmdb
from Caffe import caffe_pb2 as pb2
import numpy as np

class Read_Caffe_LMDB():
    def __init__(self, path, dtype=np.uint8):

        self.env = lmdb.open(path, readonly=True)
        self.dtype = dtype
        self.txn = self.env.begin()
        self.cursor = self.txn.cursor()

    @staticmethod
    def to_numpy(value, dtype=np.uint8):
        datum = pb2.Datum()
        datum.ParseFromString(value)
        # np.frombuffer replaces the deprecated np.fromstring
        flat_x = np.frombuffer(datum.data, dtype=dtype)
        data = flat_x.reshape(datum.channels, datum.height, datum.width)
        label = datum.label
        return data, label

    def iterator(self):
        while True:
            key, value = self.cursor.key(), self.cursor.value()
            yield self.to_numpy(value, self.dtype)
            if not self.cursor.next():
                return

    def __iter__(self):
        self.cursor.first()
        it = self.iterator()
        return it

    def __len__(self):
        return int(self.env.stat()['entries'])
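A minimal usage sketch for the reader above, assuming a Caffe-style LMDB exists at a hypothetical `./train_lmdb` path and `Read_Caffe_LMDB` is in scope:

```python
# iterate a Caffe LMDB and peek at the first few records
dataset = Read_Caffe_LMDB('./train_lmdb')   # hypothetical path
print('entries:', len(dataset))

for i, (data, label) in enumerate(dataset):
    # data is a (channels, height, width) array of the chosen dtype, label an int
    print(data.shape, data.dtype, label)
    if i == 2:
        break
```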
@@ -0,0 +1,157 @@
from __future__ import absolute_import
from . import caffe_pb2 as pb
import google.protobuf.text_format as text_format
import numpy as np
from .layer_param import Layer_param

class _Net(object):
    def __init__(self):
        self.net = pb.NetParameter()
        self.needChange = {}

    def layer_index(self, layer_name):
        # find a layer's index by name; if the layer is found, return its position in the net, else return -1
        for i, layer in enumerate(self.net.layer):
            if layer.name == layer_name:
                return i

    def add_layer(self, layer_params, before='', after=''):
        # find the position of the before or after layer
        index = -1
        if after != '':
            index = self.layer_index(after) + 1
        if before != '':
            index = self.layer_index(before)
        new_layer = pb.LayerParameter()
        new_layer.CopyFrom(layer_params.param)
        # insert the layer into the layer protolist
        if index != -1:
            self.net.layer.add()
            for i in range(len(self.net.layer) - 1, index, -1):
                self.net.layer[i].CopyFrom(self.net.layer[i - 1])
            self.net.layer[index].CopyFrom(new_layer)
        else:
            self.net.layer.extend([new_layer])

    def remove_layer_by_name(self, layer_name):
        for i, layer in enumerate(self.net.layer):
            if layer.name == layer_name:
                del self.net.layer[i]
                return
        raise AttributeError("cannot find layer %s" % str(layer_name))

    def remove_layer_by_type(self, type_name):
        for i, layer in enumerate(self.net.layer):
            if layer.type == type_name:
                # self.change_layer_bottom(layer.top,layer.bottom)
                s1 = "\"" + layer.top[0] + "\""
                s2 = "\"" + layer.bottom[0] + "\""
                self.needChange[s1] = s2
                del self.net.layer[i]
                return

    def get_layer_by_name(self, layer_name):
        # get the layer by layer_name
        for layer in self.net.layer:
            if layer.name == layer_name:
                return layer
        raise AttributeError("cannot find layer %s" % str(layer_name))

    def save_prototxt(self, path):
        prototxt = pb.NetParameter()
        prototxt.CopyFrom(self.net)
        for layer in prototxt.layer:
            del layer.blobs[:]
        with open(path, 'w') as f:
            string = text_format.MessageToString(prototxt)
            for origin_name in self.needChange.keys():
                string = string.replace(origin_name, self.needChange[origin_name])
            f.write(string)

    def layer(self, layer_name):
        return self.get_layer_by_name(layer_name)

    def layers(self):
        return list(self.net.layer)


class Prototxt(_Net):
    def __init__(self, file_name=''):
        super(Prototxt, self).__init__()
        self.file_name = file_name
        if file_name != '':
            f = open(file_name, 'r')
            text_format.Parse(f.read(), self.net)

    def init_caffemodel(self, caffe_cmd_path='caffe'):
        """
        :param caffe_cmd_path: the shell command of caffe, normally at <path-to-caffe>/build/tools/caffe
        """
        s = pb.SolverParameter()
        s.train_net = self.file_name
        s.max_iter = 0
        s.base_lr = 1
        s.solver_mode = pb.SolverParameter.CPU
        s.snapshot_prefix = './nn'
        with open('/tmp/nn_tools_solver.prototxt', 'w') as f:
            f.write(str(s))
        import os
        os.system('%s train --solver /tmp/nn_tools_solver.prototxt' % caffe_cmd_path)


class Caffemodel(_Net):
    def __init__(self, file_name=''):
        super(Caffemodel, self).__init__()
        # caffe_model dir
        if file_name != '':
            f = open(file_name, 'rb')
            self.net.ParseFromString(f.read())
            f.close()

    def save(self, path):
        with open(path, 'wb') as f:
            f.write(self.net.SerializeToString())

    def add_layer_with_data(self, layer_params, datas, before='', after=''):
        """
        Args:
            layer_params: a Layer_param object
            datas: a fixed-dimension numpy object list
            after: put the layer after a specified layer
            before: put the layer before a specified layer
        """
        self.add_layer(layer_params, before, after)
        new_layer = self.layer(layer_params.name)

        # process blobs
        del new_layer.blobs[:]
        for data in datas:
            new_blob = new_layer.blobs.add()
            for dim in data.shape:
                new_blob.shape.dim.append(dim)
            new_blob.data.extend(data.flatten().astype(float))

    def get_layer_data(self, layer_name):
        layer = self.layer(layer_name)
        datas = []
        for blob in layer.blobs:
            shape = list(blob.shape.dim)
            data = np.array(blob.data).reshape(shape)
            datas.append(data)
        return datas

    def set_layer_data(self, layer_name, datas):
        # datas is normally a list of [weights, bias]
        layer = self.layer(layer_name)
        for blob, data in zip(layer.blobs, datas):
            blob.data[:] = data.flatten()


class Net():
    def __init__(self, *args, **kwargs):
        raise TypeError('the class Net is no longer used, please use Caffemodel or Prototxt instead')
File diff suppressed because one or more lines are too long
@@ -0,0 +1,183 @@
from __future__ import absolute_import
from . import caffe_pb2 as pb
import numpy as np

def pair_process(item, strict_one=True):
    if hasattr(item, '__iter__'):
        for i in item:
            if i != item[0]:
                if strict_one:
                    raise ValueError("numbers in item {} must be the same".format(item))
                else:
                    print("IMPORTANT WARNING: numbers in item {} must be the same".format(item))
        return item[0]
    return item

def pair_reduce(item):
    if hasattr(item, '__iter__'):
        for i in item:
            if i != item[0]:
                return item
        return [item[0]]
    return [item]

class Layer_param():
    def __init__(self, name='', type='', top=(), bottom=()):
        self.param = pb.LayerParameter()
        self.name = self.param.name = name
        self.type = self.param.type = type

        self.top = self.param.top
        self.top.extend(top)
        self.bottom = self.param.bottom
        self.bottom.extend(bottom)

    def fc_param(self, num_output, weight_filler='xavier', bias_filler='constant', has_bias=True):
        if self.type != 'InnerProduct':
            raise TypeError('the layer type must be InnerProduct if you want to set fc param')
        fc_param = pb.InnerProductParameter()
        fc_param.num_output = num_output
        fc_param.weight_filler.type = weight_filler
        fc_param.bias_term = has_bias
        if has_bias:
            fc_param.bias_filler.type = bias_filler
        self.param.inner_product_param.CopyFrom(fc_param)

    def conv_param(self, num_output, kernel_size, stride=(1,), pad=(0,),
                   weight_filler_type='xavier', bias_filler_type='constant',
                   bias_term=True, dilation=None, groups=None):
        """
        Set the convolution parameters; the layer type must be "Convolution" or "Deconvolution".
        Args:
            num_output: an int
            kernel_size: an int list
            stride: an int list
            weight_filler_type: the weight filler type
            bias_filler_type: the bias filler type
        """
        if self.type not in ['Convolution', 'Deconvolution']:
            raise TypeError('the layer type must be Convolution or Deconvolution if you want to set conv param')
        conv_param = pb.ConvolutionParameter()
        conv_param.num_output = num_output
        conv_param.kernel_size.extend(pair_reduce(kernel_size))
        conv_param.stride.extend(pair_reduce(stride))
        conv_param.pad.extend(pair_reduce(pad))
        conv_param.bias_term = bias_term
        conv_param.weight_filler.type = weight_filler_type
        if bias_term:
            conv_param.bias_filler.type = bias_filler_type
        if dilation:
            conv_param.dilation.extend(pair_reduce(dilation))
        if groups:
            conv_param.group = groups
            if groups != 1:
                conv_param.engine = 1
        self.param.convolution_param.CopyFrom(conv_param)

    def norm_param(self, eps):
        """
        Set the L2-normalization parameters for a "Normalize" layer.
        Args:
            eps: the epsilon added to the norm for numerical stability
        """
        l2norm_param = pb.NormalizeParameter()
        l2norm_param.across_spatial = False
        l2norm_param.channel_shared = False
        l2norm_param.eps = eps
        self.param.norm_param.CopyFrom(l2norm_param)

    def permute_param(self, order1, order2, order3, order4):
        """
        Set the axis order for a "Permute" layer.
        Args:
            order1..order4: the permuted axis order
        """
        permute_param = pb.PermuteParameter()
        permute_param.order.extend([order1, order2, order3, order4])
        self.param.permute_param.CopyFrom(permute_param)

    def pool_param(self, type='MAX', kernel_size=2, stride=2, pad=None, ceil_mode=True):
        pool_param = pb.PoolingParameter()
        pool_param.pool = pool_param.PoolMethod.Value(type)
        pool_param.kernel_size = pair_process(kernel_size)
        pool_param.stride = pair_process(stride)
        pool_param.ceil_mode = ceil_mode
        if pad:
            if isinstance(pad, tuple):
                pool_param.pad_h = pad[0]
                pool_param.pad_w = pad[1]
            else:
                pool_param.pad = pad
        self.param.pooling_param.CopyFrom(pool_param)

    def batch_norm_param(self, use_global_stats=0, moving_average_fraction=None, eps=None):
        bn_param = pb.BatchNormParameter()
        bn_param.use_global_stats = use_global_stats
        if moving_average_fraction:
            bn_param.moving_average_fraction = moving_average_fraction
        if eps:
            bn_param.eps = eps
        self.param.batch_norm_param.CopyFrom(bn_param)

    # layer
    # {
    #     name: "upsample_layer"
    #     type: "Upsample"
    #     bottom: "some_input_feature_map"
    #     bottom: "some_input_pool_index"
    #     top: "some_output"
    #     upsample_param {
    #         upsample_h: 224
    #         upsample_w: 224
    #     }
    # }
    def upsample_param(self, size=None, scale_factor=None):
        upsample_param = pb.UpsampleParameter()
        if scale_factor:
            if isinstance(scale_factor, int):
                upsample_param.scale = scale_factor
            else:
                upsample_param.scale_h = scale_factor[0]
                upsample_param.scale_w = scale_factor[1]

        if size:
            if isinstance(size, int):
                upsample_param.upsample_h = size
            else:
                upsample_param.upsample_h = size[0] * scale_factor
                upsample_param.upsample_w = size[1] * scale_factor
        self.param.upsample_param.CopyFrom(upsample_param)

    def add_data(self, *args):
        """Args are data numpy arrays
        """
        del self.param.blobs[:]
        for data in args:
            new_blob = self.param.blobs.add()
            for dim in data.shape:
                new_blob.shape.dim.append(dim)
            new_blob.data.extend(data.flatten().astype(float))

    def set_params_by_dict(self, dic):
        pass

    def copy_from(self, layer_param):
        pass

def set_enum(param, key, value):
    setattr(param, key, param.Value(value))
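To show how these pieces fit together, a small sketch that builds a convolution layer description and appends it, with zero-initialized weights, to an in-memory `Caffemodel` (the blob/layer names and the `nn_tools.Caffe.layer_param` import path are assumptions):

```python
import numpy as np
from nn_tools.Caffe import caffe_net
from nn_tools.Caffe.layer_param import Layer_param

# describe a 16-filter 3x3 Convolution layer reading from a 3-channel 'data' blob
layer = Layer_param(name='conv_new', type='Convolution',
                    bottom=['data'], top=['conv_new_blob'])
layer.conv_param(num_output=16, kernel_size=[3], stride=[1], pad=[1])

net = caffe_net.Caffemodel('')                    # empty in-memory model
weights = np.zeros((16, 3, 3, 3), dtype=np.float32)
bias = np.zeros((16,), dtype=np.float32)
net.add_layer_with_data(layer, [weights, bias])   # layer plus its two blobs
net.save_prototxt('conv_only.prototxt')           # architecture only
```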
@@ -0,0 +1 @@
raise ImportError('the nn_tools.Caffe.net is no longer used, please use nn_tools.Caffe.caffe_net')
@@ -0,0 +1,102 @@
# Deployment

This directory contains:

1. A script that converts a fastreid model to Caffe format.

2. An example that loads an R50 baseline model in Caffe and runs inference.

## Tutorial

This is a tiny example of the steps to convert the baseline `meta_arch` to a Caffe model; if you want to convert a more complex architecture, you need to customize more things.

1. Change `preprocess_image` in `fastreid/modeling/meta_arch/baseline.py` as below

```python
def preprocess_image(self, batched_inputs):
    """
    Normalize and batch the input images.
    """
    # images = [x["images"] for x in batched_inputs]
    # images = batched_inputs["images"]
    images = batched_inputs
    images.sub_(self.pixel_mean).div_(self.pixel_std)
    return images
```

2. Run `caffe_export.py` to get the converted Caffe model,

```bash
python caffe_export.py --config-file "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/config.yaml" --name "baseline_R50" --output "logs/caffe_model" --opts MODEL.WEIGHTS "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/model_final.pth"
```

then you can check the Caffe model and prototxt in `logs/caffe_model`.

3. Change the `prototxt` in the following three steps:

1) Edit `max_pooling` in `baseline_R50.prototxt` like this

```prototxt
layer {
  name: "max_pool1"
  type: "Pooling"
  bottom: "relu_blob1"
  top: "max_pool_blob1"
  pooling_param {
    pool: MAX
    kernel_size: 3
    stride: 2
    pad: 0 # 1
    # ceil_mode: false
  }
}
```

2) Add `avg_pooling` at the right place in `baseline_R50.prototxt`

```prototxt
layer {
  name: "avgpool1"
  type: "Pooling"
  bottom: "relu_blob49"
  top: "avgpool_blob1"
  pooling_param {
    pool: AVE
    global_pooling: true
  }
}
```

3) Change the last layer's `top` name to `output`

```prototxt
layer {
  name: "bn_scale54"
  type: "Scale"
  bottom: "batch_norm_blob54"
  top: "output" # bn_norm_blob54
  scale_param {
    bias_term: true
  }
}
```

4. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter your network `prototxt` to visualize the network.

5. Run `caffe_inference.py` to save the Caffe model's features for the input images

```bash
python caffe_inference.py --model-def "logs/caffe_model/baseline_R50.prototxt" \
--model-weights "logs/caffe_model/baseline_R50.caffemodel" \
--input \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
--output "caffe_R34_output"
```

6. Run `demo/demo.py` to get fastreid model features with the same input images, then compute the cosine similarity between the two models' features to verify whether you converted the Caffe model successfully.
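For reference, a small sketch of that check, assuming both scripts saved per-image `.npy` features into `caffe_R34_output/` and `demo_output/` (directory names follow the examples above):

```python
import os
import numpy as np

caffe_dir, torch_dir = 'caffe_R34_output', 'demo_output'

for name in sorted(os.listdir(caffe_dir)):
    if not name.endswith('.npy'):
        continue
    caffe_feat = np.load(os.path.join(caffe_dir, name)).ravel()
    torch_feat = np.load(os.path.join(torch_dir, name)).ravel()
    # cosine similarity; values close to 1.0 mean the conversion preserved the model
    cos = np.dot(caffe_feat, torch_feat) / (
        np.linalg.norm(caffe_feat) * np.linalg.norm(torch_feat) + 1e-12)
    print('{}: {:.4f}'.format(name, cos))
```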
## Acknowledgements

Thanks to [CPFLAME](https://github.com/CPFLAME), [](), [YuxiangJohn](https://github.com/YuxiangJohn) and []() at the JDAI Model Acceleration Group for their help with PyTorch to Caffe model conversion.
@@ -0,0 +1,74 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

import argparse

import torch
import sys
sys.path.append('../../')

import pytorch_to_caffe
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer


def setup_cfg(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Convert Pytorch to Caffe model")

    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--name",
        default="baseline",
        help="name for converted model"
    )
    parser.add_argument(
        "--output",
        default='caffe_model',
        help='path to save converted caffe model'
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


if __name__ == '__main__':
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    cfg.MODEL.HEADS.POOL_LAYER = "identity"
    cfg.MODEL.BACKBONE.WITH_NL = False

    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    model.cuda()
    model.eval()
    print(model)

    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).cuda()
    PathManager.mkdirs(args.output)
    pytorch_to_caffe.trans_net(model, inputs, args.name)
    pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
    pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")
@@ -0,0 +1,95 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

import caffe
import tqdm
import glob
import os
import cv2
import numpy as np

caffe.set_mode_gpu()

import argparse


def get_parser():
    parser = argparse.ArgumentParser(description="Caffe model inference")

    parser.add_argument(
        "--model-def",
        default="logs/test_caffe/baseline_R50.prototxt",
        help="caffe model prototxt"
    )
    parser.add_argument(
        "--model-weights",
        default="logs/test_caffe/baseline_R50.caffemodel",
        help="caffe model weights"
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='caffe_output',
        help='path to save extracted features'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=384,
        help="height of image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image"
    )
    return parser


def preprocess(image_path, image_height, image_width):
    original_image = cv2.imread(image_path)
    # the model expects RGB inputs
    original_image = original_image[:, :, ::-1]

    # Apply pre-processing to image.
    image = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
    image = image.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
    image = (image - np.array([0.485 * 255, 0.456 * 255, 0.406 * 255]).reshape((1, -1, 1, 1))) / np.array(
        [0.229 * 255, 0.224 * 255, 0.225 * 255]).reshape((1, -1, 1, 1))
    return image


def normalize(nparray, order=2, axis=-1):
    """Normalize an N-D numpy array along the specified axis."""
    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / (norm + np.finfo(np.float32).eps)


if __name__ == "__main__":
    args = get_parser().parse_args()

    net = caffe.Net(args.model_def, args.model_weights, caffe.TEST)
    net.blobs['blob1'].reshape(1, 3, args.height, args.width)

    if not os.path.exists(args.output): os.makedirs(args.output)

    if args.input:
        if os.path.isdir(args.input[0]):
            args.input = glob.glob(os.path.expanduser(args.input[0]))
        assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input):
            image = preprocess(path, args.height, args.width)
            net.blobs['blob1'].data[...] = image
            feat = net.forward()['output']
            feat = normalize(feat[..., 0, 0], axis=1)
            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
@@ -0,0 +1,48 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import sys

import torch
sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup
from fastreid.modeling.meta_arch import build_model
from fastreid.export.tensorflow_export import export_tf_reid_model
from fastreid.export.tf_modeling import TfMetaArch


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    cfg = setup(args)
    cfg.defrost()
    cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
    cfg.MODEL.BACKBONE.DEPTH = 50
    cfg.MODEL.BACKBONE.LAST_STRIDE = 1
    # If use IBN block in backbone
    cfg.MODEL.BACKBONE.WITH_IBN = False
    cfg.MODEL.BACKBONE.PRETRAIN = False

    from torchvision.models import resnet50
    # model = TfMetaArch(cfg)
    model = resnet50(pretrained=False)
    # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
    model.eval()
    dummy_inputs = torch.randn(1, 3, 256, 128)
    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
@@ -0,0 +1,783 @@
import torch
import torch.nn as nn
import traceback
from Caffe import caffe_net
import torch.nn.functional as F
from torch.autograd import Variable
from Caffe import layer_param
from torch.nn.modules.utils import _pair, _list_with_default
import numpy as np

"""
How to support a new layer type:
    layer_name=log.add_layer(layer_type_name)
    top_blobs=log.add_blobs(<output of that layer>)
    layer=caffe_net.Layer_param(xxx)
    <set layer parameters>
    [<layer.add_data(*datas)>]
    log.cnet.add_layer(layer)

Please MUTE the inplace operations to avoid not-found-in-graph errors.

Note: only functions in torch.nn.functional can be converted into Caffe layers.
"""
# TODO: support the inplace output of the layers

class Blob_LOG():
    def __init__(self):
        self.data = {}

    def __setitem__(self, key, value):
        self.data[key] = value

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.data)

NET_INITTED = False

# How the conversion works: every intercepted op records itself into the Caffe net.
class TransLog(object):
    def __init__(self):
        """
        call init() with the input Variables before using it
        """
        self.layers = {}
        self.detail_layers = {}
        self.detail_blobs = {}
        self._blobs = Blob_LOG()
        self._blobs_data = []
        self.cnet = caffe_net.Caffemodel('')
        self.debug = True

    def init(self, inputs):
        """
        :param inputs: a list of input variables
        """
        self.add_blobs(inputs)

    def add_layer(self, name='layer'):
        if name in self.layers:
            return self.layers[name]
        if name not in self.detail_layers.keys():
            self.detail_layers[name] = 0
        self.detail_layers[name] += 1
        name = '{}{}'.format(name, self.detail_layers[name])
        self.layers[name] = name
        if self.debug:
            print("{} was added to layers".format(self.layers[name]))
        return self.layers[name]

    def add_blobs(self, blobs, name='blob', with_num=True):
        rst = []
        for blob in blobs:
            self._blobs_data.append(blob)  # keep a reference so the memory address is not reused
            blob_id = int(id(blob))
            if name not in self.detail_blobs.keys():
                self.detail_blobs[name] = 0
            self.detail_blobs[name] += 1
            if with_num:
                rst.append('{}{}'.format(name, self.detail_blobs[name]))
            else:
                rst.append('{}'.format(name))
            if self.debug:
                print("{}:{} was added to blobs".format(blob_id, rst[-1]))
            # print('Add blob {} : {}'.format(rst[-1].center(21), blob.size()))
            self._blobs[blob_id] = rst[-1]
        return rst

    def blobs(self, var):
        var = id(var)
        # if self.debug:
        #     print("{}:{} getting".format(var, self._blobs[var]))
        try:
            return self._blobs[var]
        except KeyError:
            print("WARNING: cannot find blob {}".format(var))
            return None

log = TransLog()

layer_names = {}

def _conv2d(raw, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    print('conv: ', log.blobs(input))
    x = raw(input, weight, bias, stride, padding, dilation, groups)
    name = log.add_layer(name='conv')
    log.add_blobs([x], name='conv_blob')
    layer = caffe_net.Layer_param(name=name, type='Convolution',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    layer.conv_param(x.size()[1], weight.size()[2:], stride=_pair(stride),
                     pad=_pair(padding), dilation=_pair(dilation), bias_term=bias is not None, groups=groups)
    if bias is not None:
        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
    else:
        layer.param.convolution_param.bias_term = False
        layer.add_data(weight.cpu().data.numpy())
    log.cnet.add_layer(layer)
    return x

def _conv_transpose2d(raw, input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    x = raw(input, weight, bias, stride, padding, output_padding, groups, dilation)
    name = log.add_layer(name='conv_transpose')
    log.add_blobs([x], name='conv_transpose_blob')
    layer = caffe_net.Layer_param(name=name, type='Deconvolution',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    layer.conv_param(x.size()[1], weight.size()[2:], stride=_pair(stride),
                     pad=_pair(padding), dilation=_pair(dilation), bias_term=bias is not None, groups=groups)
    if bias is not None:
        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
    else:
        layer.param.convolution_param.bias_term = False
        layer.add_data(weight.cpu().data.numpy())
    log.cnet.add_layer(layer)
    return x

def _linear(raw, input, weight, bias=None):
    x = raw(input, weight, bias)
    layer_name = log.add_layer(name='fc')
    top_blobs = log.add_blobs([x], name='fc_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='InnerProduct',
                                  bottom=[log.blobs(input)], top=top_blobs)
    layer.fc_param(x.size()[1], has_bias=bias is not None)
    if bias is not None:
        layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
    else:
        layer.add_data(weight.cpu().data.numpy())
    log.cnet.add_layer(layer)
    return x

def _split(raw, tensor, split_size, dim=0):
    # split in pytorch is slice in caffe
    x = raw(tensor, split_size, dim)
    layer_name = log.add_layer('split')
    top_blobs = log.add_blobs(x, name='split_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Slice',
                                  bottom=[log.blobs(tensor)], top=top_blobs)
    slice_num = int(np.floor(tensor.size()[dim] / split_size))
    slice_param = caffe_net.pb.SliceParameter(axis=dim, slice_point=[split_size * i for i in range(1, slice_num)])
    layer.param.slice_param.CopyFrom(slice_param)
    log.cnet.add_layer(layer)
    return x


def _pool(type, raw, input, x, kernel_size, stride, padding, ceil_mode):
    # TODO dilation, ceil_mode, return indices
    layer_name = log.add_layer(name='{}_pool'.format(type))
    top_blobs = log.add_blobs([x], name='{}_pool_blob'.format(type))
    layer = caffe_net.Layer_param(name=layer_name, type='Pooling',
                                  bottom=[log.blobs(input)], top=top_blobs)
    # TODO w,h different kernel, stride and padding
    # processing ceil mode
    layer.pool_param(kernel_size=kernel_size, stride=kernel_size if stride is None else stride,
                     pad=padding, type=type.upper(), ceil_mode=ceil_mode)
    log.cnet.add_layer(layer)
    if ceil_mode == False and stride is not None:
        oheight = (input.size()[2] - _pair(kernel_size)[0] + 2 * _pair(padding)[0]) % (_pair(stride)[0])
        owidth = (input.size()[3] - _pair(kernel_size)[1] + 2 * _pair(padding)[1]) % (_pair(stride)[1])
        if oheight != 0 or owidth != 0:
            caffe_out = raw(input, kernel_size, stride, padding, ceil_mode=True)
            print("WARNING: the output shape mismatches at {}: "
                  "input {} output---Pytorch:{}---Caffe:{}\n"
                  "This is caused by Caffe using ceil mode for pooling while PyTorch uses floor mode.\n"
                  "You can manually add a clip layer in the caffe prototxt if a shape mismatch error occurs in caffe.".format(layer_name, input.size(), x.size(), caffe_out.size()))

def _max_pool2d(raw, input, kernel_size, stride=None, padding=0, dilation=1,
                ceil_mode=False, return_indices=False):
    x = raw(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices)
    _pool('max', raw, input, x, kernel_size, stride, padding, ceil_mode)
    return x

def _avg_pool2d(raw, input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
    x = raw(input, kernel_size, stride, padding, ceil_mode, count_include_pad)
    _pool('ave', raw, input, x, kernel_size, stride, padding, ceil_mode)
    return x

def _adaptive_avg_pool2d(raw, input, output_size):
    _output_size = _list_with_default(output_size, input.size())
    x = raw(input, _output_size)
    if isinstance(_output_size, int):
        out_dim = _output_size
    else:
        out_dim = _output_size[0]
    tmp = max(input.shape[2], input.shape[3])
    stride = tmp // out_dim
    kernel_size = tmp - (out_dim - 1) * stride
    _pool('ave', raw, input, x, kernel_size, stride, 0, False)
    return x

def _max(raw, *args):
    x = raw(*args)
    if len(args) == 1:
        # TODO max in one tensor
        raise NotImplementedError("max over a single tensor is not implemented")
    else:
        bottom_blobs = []
        for arg in args:
            bottom_blobs.append(log.blobs(arg))
        layer_name = log.add_layer(name='max')
        top_blobs = log.add_blobs([x], name='max_blob')
        layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                      bottom=bottom_blobs, top=top_blobs)
        layer.param.eltwise_param.operation = 2
        log.cnet.add_layer(layer)
    return x

def _cat(raw, inputs, dimension=0):
    x = raw(inputs, dimension)
    bottom_blobs = []
    for input in inputs:
        bottom_blobs.append(log.blobs(input))
    layer_name = log.add_layer(name='cat')
    top_blobs = log.add_blobs([x], name='cat_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Concat',
                                  bottom=bottom_blobs, top=top_blobs)
    layer.param.concat_param.axis = dimension
    log.cnet.add_layer(layer)
    return x

def _dropout(raw, input, p=0.5, training=False, inplace=False):
    x = raw(input, p, training, inplace)
    bottom_blobs = [log.blobs(input)]
    layer_name = log.add_layer(name='dropout')
    top_blobs = log.add_blobs([x], name=bottom_blobs[0], with_num=False)
    layer = caffe_net.Layer_param(name=layer_name, type='Dropout',
                                  bottom=bottom_blobs, top=top_blobs)
    layer.param.dropout_param.dropout_ratio = p
    layer.param.include.extend([caffe_net.pb.NetStateRule(phase=0)])  # 1 for test, 0 for train
    log.cnet.add_layer(layer)
    return x

def _threshold(raw, input, threshold, value, inplace=False):
    # for threshold or relu
    if threshold == 0 and value == 0:
        x = raw(input, threshold, value, inplace)
        bottom_blobs = [log.blobs(input)]
        name = log.add_layer(name='relu')
        log.add_blobs([x], name='relu_blob')
        layer = caffe_net.Layer_param(name=name, type='ReLU',
                                      bottom=bottom_blobs, top=[log.blobs(x)])
        log.cnet.add_layer(layer)
        return x
    if value != 0:
        raise NotImplementedError("value != 0 is not implemented in caffe")
    x = raw(input, threshold, value, inplace)
    bottom_blobs = [log.blobs(input)]
    layer_name = log.add_layer(name='threshold')
    top_blobs = log.add_blobs([x], name='threshold_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Threshold',
                                  bottom=bottom_blobs, top=top_blobs)
    layer.param.threshold_param.threshold = threshold
    log.cnet.add_layer(layer)
    return x

def _relu(raw, input, inplace=False):
    # for threshold or relu
    x = raw(input, False)
    name = log.add_layer(name='relu')
    log.add_blobs([x], name='relu_blob')
    layer = caffe_net.Layer_param(name=name, type='ReLU',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

def _prelu(raw, input, weight):
    # for prelu
    x = raw(input, weight)
    bottom_blobs = [log.blobs(input)]
    name = log.add_layer(name='prelu')
    log.add_blobs([x], name='prelu_blob')
    layer = caffe_net.Layer_param(name=name, type='PReLU',
                                  bottom=bottom_blobs, top=[log.blobs(x)])
    if weight.size()[0] == 1:
        layer.param.prelu_param.channel_shared = True
        layer.add_data(weight.cpu().data.numpy()[0])
    else:
        layer.add_data(weight.cpu().data.numpy())
    log.cnet.add_layer(layer)
    return x

def _leaky_relu(raw, input, negative_slope=0.01, inplace=False):
    x = raw(input, negative_slope)
    name = log.add_layer(name='leaky_relu')
    log.add_blobs([x], name='leaky_relu_blob')
    layer = caffe_net.Layer_param(name=name, type='ReLU',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    layer.param.relu_param.negative_slope = negative_slope
    log.cnet.add_layer(layer)
    return x

def _tanh(raw, input):
    # for tanh activation
    x = raw(input)
    name = log.add_layer(name='tanh')
    log.add_blobs([x], name='tanh_blob')
    layer = caffe_net.Layer_param(name=name, type='TanH',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

def _softmax(raw, input, dim=None, _stacklevel=3):
    # for F.softmax
    x = raw(input, dim=dim)
    if dim is None:
        dim = F._get_softmax_dim('softmax', input.dim(), _stacklevel)
    bottom_blobs = [log.blobs(input)]
    name = log.add_layer(name='softmax')
    log.add_blobs([x], name='softmax_blob')
    layer = caffe_net.Layer_param(name=name, type='Softmax',
                                  bottom=bottom_blobs, top=[log.blobs(x)])
    layer.param.softmax_param.axis = dim
    log.cnet.add_layer(layer)
    return x

def _batch_norm(raw, input, running_mean, running_var, weight=None, bias=None,
                training=False, momentum=0.1, eps=1e-5):
    # because running_mean and running_var will be changed by the _batch_norm operation, we save the parameters first

    x = raw(input, running_mean, running_var, weight, bias,
            training, momentum, eps)
    bottom_blobs = [log.blobs(input)]
    layer_name1 = log.add_layer(name='batch_norm')
    top_blobs = log.add_blobs([x], name='batch_norm_blob')
    layer1 = caffe_net.Layer_param(name=layer_name1, type='BatchNorm',
                                   bottom=bottom_blobs, top=top_blobs)
    if running_mean is None or running_var is None:
        # do not use global stats; normalization is performed over the current mini-batch
        layer1.batch_norm_param(use_global_stats=0, eps=eps)
    else:
        layer1.batch_norm_param(use_global_stats=1, eps=eps)
        running_mean_clone = running_mean.clone()
        running_var_clone = running_var.clone()
        layer1.add_data(running_mean_clone.cpu().numpy(), running_var_clone.cpu().numpy(), np.array([1.0]))
    log.cnet.add_layer(layer1)
    if weight is not None and bias is not None:
        layer_name2 = log.add_layer(name='bn_scale')
        layer2 = caffe_net.Layer_param(name=layer_name2, type='Scale',
                                       bottom=top_blobs, top=top_blobs)
        layer2.param.scale_param.bias_term = True
        layer2.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
        log.cnet.add_layer(layer2)
    return x

def _instance_norm(raw, input, running_mean=None, running_var=None, weight=None,
                   bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
    # TODO: the batch size != 1 view operations
    print("WARNING: Instance Normalization is transferred to Caffe using BatchNorm, so the batch size should be 1")
    if running_var is not None or weight is not None:
        # TODO: the affine=True or track_running_stats=True case
        raise NotImplementedError("the affine=True or track_running_stats=True case of InstanceNorm is not implemented")
    x = torch.batch_norm(
        input, weight, bias, running_mean, running_var,
        use_input_stats, momentum, eps, torch.backends.cudnn.enabled)
    bottom_blobs = [log.blobs(input)]
    layer_name1 = log.add_layer(name='instance_norm')
    top_blobs = log.add_blobs([x], name='instance_norm_blob')
    layer1 = caffe_net.Layer_param(name=layer_name1, type='BatchNorm',
                                   bottom=bottom_blobs, top=top_blobs)
    if running_mean is None or running_var is None:
        # do not use global stats; normalization is performed over the current mini-batch
        layer1.batch_norm_param(use_global_stats=0, eps=eps)
        running_mean = torch.zeros(input.size()[1])
        running_var = torch.ones(input.size()[1])
    else:
        layer1.batch_norm_param(use_global_stats=1, eps=eps)
    running_mean_clone = running_mean.clone()
    running_var_clone = running_var.clone()
    layer1.add_data(running_mean_clone.cpu().numpy(), running_var_clone.cpu().numpy(), np.array([1.0]))
    log.cnet.add_layer(layer1)
    if weight is not None and bias is not None:
        layer_name2 = log.add_layer(name='bn_scale')
        layer2 = caffe_net.Layer_param(name=layer_name2, type='Scale',
                                       bottom=top_blobs, top=top_blobs)
        layer2.param.scale_param.bias_term = True
        layer2.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy())
        log.cnet.add_layer(layer2)
    return x


# upsample layer
def _interpolate(raw, input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    # The Upsample parameters include: scale, the output/input size ratio (e.g. 2);
    # scale_h and scale_w, the per-axis size ratios; pad_out_h and pad_out_w, extra
    # output padding in h/w that is only useful when scale is 2; and
    # upsample_h/upsample_w, the exact output size. It is recommended to define the
    # Upsample output size only through upsample_h/upsample_w; the other parameters
    # are all deprecated.
    # for nearest _interpolate
    if mode != "nearest" or align_corners != None:
        raise NotImplementedError("F.interpolate is not fully implemented; only nearest mode without align_corners is supported")
    x = raw(input, size, scale_factor, mode)

    layer_name = log.add_layer(name='upsample')
    top_blobs = log.add_blobs([x], name='upsample_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Upsample',
                                  bottom=[log.blobs(input)], top=top_blobs)

    layer.upsample_param(size=(input.size(2), input.size(3)), scale_factor=scale_factor)
    log.cnet.add_layer(layer)
    return x


# sigmoid layer
def _sigmoid(raw, input):
    # Applies the element-wise function:
    #
    # Sigmoid(x) = 1 / (1 + exp(-x))
    x = raw(input)
    name = log.add_layer(name='sigmoid')
    log.add_blobs([x], name='sigmoid_blob')
    layer = caffe_net.Layer_param(name=name, type='Sigmoid',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

# tanh layer
def _tanh(raw, input):
    # Applies the element-wise function:
    #
    # torch.nn.Tanh
    x = raw(input)
    name = log.add_layer(name='tanh')
    log.add_blobs([x], name='tanh_blob')
    layer = caffe_net.Layer_param(name=name, type='TanH',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

def _hardtanh(raw, input, min_val, max_val, inplace):
    # Applies the element-wise function:
    #
    # torch.nn.ReLU6
    print('relu6: ', log.blobs(input))
    x = raw(input, min_val, max_val)
    name = log.add_layer(name='relu6')
    log.add_blobs([x], name='relu6_blob')
    layer = caffe_net.Layer_param(name=name, type='ReLU6',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

# L2Norm layer
def _l2Norm(raw, input, weight, eps):
    # Applies the element-wise function:
    #
    # L2Norm in vgg_ssd
    x = raw(input, weight, eps)
    name = log.add_layer(name='normalize')
    log.add_blobs([x], name='normalize_blob')
    layer = caffe_net.Layer_param(name=name, type='Normalize',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    layer.norm_param(eps)

    layer.add_data(weight.cpu().data.numpy())
    log.cnet.add_layer(layer)
    return x

def _div(raw, inputs, inputs2):
    x = raw(inputs, inputs2)
    log.add_blobs([x], name='div_blob')
    return x


# ----- for Variable operations --------

def _view(input, *args):
    x = raw_view(input, *args)
    if not NET_INITTED:
        return x
    layer_name = log.add_layer(name='view')
    top_blobs = log.add_blobs([x], name='view_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Reshape',
                                  bottom=[log.blobs(input)], top=top_blobs)
    # TODO: reshape added to nn_tools layer
    dims = list(args)
    dims[0] = 0  # the first dim should be batch_size
    layer.param.reshape_param.shape.CopyFrom(caffe_net.pb.BlobShape(dim=dims))
    log.cnet.add_layer(layer)
    return x

def _mean(input, *args, **kwargs):
    x = raw_mean(input, *args, **kwargs)
    if not NET_INITTED:
        return x
    layer_name = log.add_layer(name='mean')
    top_blobs = log.add_blobs([x], name='mean_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Reduction',
                                  bottom=[log.blobs(input)], top=top_blobs)
    if len(args) == 1:
        dim = args[0]
    elif 'dim' in kwargs:
        dim = kwargs['dim']
    else:
        raise NotImplementedError('mean operation must specify a dim')
    layer.param.reduction_param.operation = 4
    layer.param.reduction_param.axis = dim
    log.cnet.add_layer(layer)
    return x

def _add(input, *args):
    x = raw__add__(input, *args)
    if not NET_INITTED:
        return x
    layer_name = log.add_layer(name='add')
    top_blobs = log.add_blobs([x], name='add_blob')
    if log.blobs(args[0]) is None:
        log.add_blobs([args[0]], name='extra_blob')
    else:
        layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                      bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
        layer.param.eltwise_param.operation = 1  # sum is 1
        log.cnet.add_layer(layer)
    return x

def _iadd(input, *args):
    x = raw__iadd__(input, *args)
    if not NET_INITTED:
        return x
    x = x.clone()
    layer_name = log.add_layer(name='add')
    top_blobs = log.add_blobs([x], name='add_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                  bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
    layer.param.eltwise_param.operation = 1  # sum is 1
    log.cnet.add_layer(layer)
    return x

def _sub(input, *args):
    x = raw__sub__(input, *args)
    if not NET_INITTED:
        return x
    layer_name = log.add_layer(name='sub')
    top_blobs = log.add_blobs([x], name='sub_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                  bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
    layer.param.eltwise_param.operation = 1  # sum is 1
    layer.param.eltwise_param.coeff.extend([1., -1.])
    log.cnet.add_layer(layer)
    return x

def _isub(input, *args):
    x = raw__isub__(input, *args)
    if not NET_INITTED:
        return x
    x = x.clone()
    layer_name = log.add_layer(name='sub')
    top_blobs = log.add_blobs([x], name='sub_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                  bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
    layer.param.eltwise_param.operation = 1  # sum is 1
    log.cnet.add_layer(layer)
    return x

def _mul(input, *args):
    x = raw__mul__(input, *args)
    if not NET_INITTED:
        return x
    layer_name = log.add_layer(name='mul')
    top_blobs = log.add_blobs([x], name='mul_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                  bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
    layer.param.eltwise_param.operation = 0  # product is 0
    log.cnet.add_layer(layer)
    return x

def _imul(input, *args):
    x = raw__imul__(input, *args)
    if not NET_INITTED:
        return x
    x = x.clone()
    layer_name = log.add_layer(name='mul')
    top_blobs = log.add_blobs([x], name='mul_blob')
    layer = caffe_net.Layer_param(name=layer_name, type='Eltwise',
                                  bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs)
    layer.param.eltwise_param.operation = 0  # product is 0
    layer.param.eltwise_param.coeff.extend([1., -1.])  # note: coeff is only used by the SUM operation in Caffe
    log.cnet.add_layer(layer)
    return x


# Permute layer
def _permute(input, *args):
    x = raw__permute__(input, *args)
    name = log.add_layer(name='permute')
    log.add_blobs([x], name='permute_blob')
    layer = caffe_net.Layer_param(name=name, type='Permute',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    order1 = args[0]
    order2 = args[1]
    order3 = args[2]
    order4 = args[3]

    layer.permute_param(order1, order2, order3, order4)
    log.cnet.add_layer(layer)
    return x

# contiguous
def _contiguous(input, *args):
    x = raw__contiguous__(input, *args)
    name = log.add_layer(name='contiguous')
    log.add_blobs([x], name='contiguous_blob')
    layer = caffe_net.Layer_param(name=name, type='NeedRemove',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    log.cnet.add_layer(layer)
    return x

# pow
def _pow(input, *args):
    x = raw__pow__(input, *args)
    log.add_blobs([x], name='pow_blob')
    return x

# sum
def _sum(input, *args):
    x = raw__sum__(input, *args)
    log.add_blobs([x], name='sum_blob')
    return x

# sqrt
def _sqrt(input, *args):
    x = raw__sqrt__(input, *args)
    log.add_blobs([x], name='sqrt_blob')
    return x

# unsqueeze
def _unsqueeze(input, *args):
    x = raw__unsqueeze__(input, *args)
    log.add_blobs([x], name='unsqueeze_blob')
    return x

# expand_as
def _expand_as(input, *args):
    x = raw__expand_as__(input, *args)
    log.add_blobs([x], name='expand_as_blob')
    return x


# Core component: this class wraps the operators in torch's functional API so that
# their inputs, outputs and parameters can be recorded.
class Rp(object):
    def __init__(self, raw, replace, **kwargs):
        # replace the raw function with the replace function
        self.obj = replace
        self.raw = raw

    def __call__(self, *args, **kwargs):
        if not NET_INITTED:
            return self.raw(*args, **kwargs)
        for stack in traceback.walk_stack(None):
            if 'self' in stack[0].f_locals:
                layer = stack[0].f_locals['self']
                if layer in layer_names:
                    log.pytorch_layer_name = layer_names[layer]
                    print(layer_names[layer])
                    break
        out = self.obj(self.raw, *args, **kwargs)
        # if isinstance(out, Variable):
        #     out = [out]
        return out


F.conv2d = Rp(F.conv2d, _conv2d)
F.linear = Rp(F.linear, _linear)
F.relu = Rp(F.relu, _relu)
F.leaky_relu = Rp(F.leaky_relu, _leaky_relu)
F.max_pool2d = Rp(F.max_pool2d, _max_pool2d)
F.avg_pool2d = Rp(F.avg_pool2d, _avg_pool2d)
F.dropout = Rp(F.dropout, _dropout)
F.threshold = Rp(F.threshold, _threshold)
F.prelu = Rp(F.prelu, _prelu)
F.batch_norm = Rp(F.batch_norm, _batch_norm)
F.instance_norm = Rp(F.instance_norm, _instance_norm)
F.softmax = Rp(F.softmax, _softmax)
F.conv_transpose2d = Rp(F.conv_transpose2d, _conv_transpose2d)
F.interpolate = Rp(F.interpolate, _interpolate)
F.sigmoid = Rp(F.sigmoid, _sigmoid)
F.tanh = Rp(F.tanh, _tanh)
F.hardtanh = Rp(F.hardtanh, _hardtanh)
# F.l2norm = Rp(F.l2norm, _l2Norm)

torch.split = Rp(torch.split, _split)
torch.max = Rp(torch.max, _max)
torch.cat = Rp(torch.cat, _cat)
torch.div = Rp(torch.div, _div)

# TODO: other types of the view function
try:
    raw_view = Variable.view
    Variable.view = _view
    raw_mean = Variable.mean
    Variable.mean = _mean
    raw__add__ = Variable.__add__
    Variable.__add__ = _add
    raw__iadd__ = Variable.__iadd__
    Variable.__iadd__ = _iadd
    raw__sub__ = Variable.__sub__
    Variable.__sub__ = _sub
    raw__isub__ = Variable.__isub__
    Variable.__isub__ = _isub
    raw__mul__ = Variable.__mul__
    Variable.__mul__ = _mul
    raw__imul__ = Variable.__imul__
    Variable.__imul__ = _imul
except:
    # for version 0.4.0 and later
    for t in [torch.Tensor]:
        raw_view = t.view
        t.view = _view
        raw_mean = t.mean
        t.mean = _mean
        raw__add__ = t.__add__
        t.__add__ = _add
        raw__iadd__ = t.__iadd__
        t.__iadd__ = _iadd
        raw__sub__ = t.__sub__
        t.__sub__ = _sub
        raw__isub__ = t.__isub__
        t.__isub__ = _isub
        raw__mul__ = t.__mul__
        t.__mul__ = _mul
        raw__imul__ = t.__imul__
        t.__imul__ = _imul
        raw__permute__ = t.permute
        t.permute = _permute
        raw__contiguous__ = t.contiguous
        t.contiguous = _contiguous
        raw__pow__ = t.pow
        t.pow = _pow
        raw__sum__ = t.sum
        t.sum = _sum
        raw__sqrt__ = t.sqrt
        t.sqrt = _sqrt
        raw__unsqueeze__ = t.unsqueeze
        t.unsqueeze = _unsqueeze
        raw__expand_as__ = t.expand_as
        t.expand_as = _expand_as


def trans_net(net, input_var, name='TransferedPytorchModel'):
    print('Starting Transform, This will take a while')
    log.init([input_var])
    log.cnet.net.name = name
    log.cnet.net.input.extend([log.blobs(input_var)])
    log.cnet.net.input_dim.extend(input_var.size())
    global NET_INITTED
    NET_INITTED = True
    for name, layer in net.named_modules():
        layer_names[layer] = name
    print("torch ops name:", layer_names)
    out = net.forward(input_var)
    print('Transform Completed')

def save_prototxt(save_name):
    log.cnet.remove_layer_by_type("NeedRemove")
    log.cnet.save_prototxt(save_name)

def save_caffemodel(save_name):
    log.cnet.save(save_name)
@@ -0,0 +1,5 @@
python caffe_export.py --config-file "/export/home/lxy/cvpalgo-fast-reid/logs/dukemtmc/R34/config.yaml" \
--name "baseline_R34" \
--output "logs/caffe_R34" \
--opts MODEL.WEIGHTS "/export/home/lxy/cvpalgo-fast-reid/logs/dukemtmc/R34/model_final.pth"
@@ -0,0 +1,8 @@
python caffe_inference.py --model-def "logs/caffe_R34/baseline_R34.prototxt" \
--model-weights "logs/caffe_R34/baseline_R34.caffemodel" \
--input \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
--output "caffe_R34_output"