59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Draw a graph of the net architecture.
|
|
"""
|
|
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
|
|
from google.protobuf import text_format
|
|
|
|
import caffe
|
|
import caffe.draw
|
|
from caffe.proto import caffe_pb2
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Two positional arguments are required: the input network prototxt file
    and the output image file. ``--rankdir`` (default ``LR``) and ``--phase``
    (default ``ALL``) are optional.
    """
    # Help texts are assembled up front so the add_argument calls stay short.
    rankdir_help = (
        'One of TB (top-bottom, i.e., vertical), '
        'RL (right-left, i.e., horizontal), or another '
        'valid dot option; see '
        'http://www.graphviz.org/doc/info/'
        'attrs.html#k:rankdir'
    )
    phase_help = (
        'Which network phase to draw: can be TRAIN, '
        'TEST, or ALL. If ALL, then all layers are drawn '
        'regardless of phase.'
    )

    # The module docstring doubles as the program description; the formatter
    # appends each option's default value to its help line.
    cli = ArgumentParser(
        description=__doc__,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument('input_net_proto_file', help='Input network prototxt file')
    cli.add_argument('output_image_file', help='Output image file')
    cli.add_argument('--rankdir', default='LR', help=rankdir_help)
    cli.add_argument('--phase', default="ALL", help=phase_help)

    return cli.parse_args()
|
|
|
|
|
|
def main():
    """Read a network prototxt and draw its architecture to an image file.

    The command-line arguments come from ``parse_args()``. The prototxt text
    is merged into a fresh ``NetParameter`` message, the ``--phase`` string is
    mapped to the corresponding ``caffe`` phase constant, and the result is
    handed to ``caffe.draw.draw_net_to_file``.

    Raises:
        ValueError: if ``--phase`` is not one of ``TRAIN``, ``TEST`` or
            ``ALL``.
    """
    args = parse_args()

    net = caffe_pb2.NetParameter()
    # Read through a context manager so the file handle is closed
    # deterministically instead of being left for the garbage collector.
    with open(args.input_net_proto_file) as proto_file:
        text_format.Merge(proto_file.read(), net)

    print('Drawing net to %s' % args.output_image_file)

    # "ALL" keeps phase as None, i.e. no phase constant is passed on.
    phase = None
    if args.phase == "TRAIN":
        phase = caffe.TRAIN
    elif args.phase == "TEST":
        phase = caffe.TEST
    elif args.phase != "ALL":
        raise ValueError("Unknown phase: " + args.phase)

    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
                                phase)
|
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|