101 lines
3.1 KiB
Python
101 lines
3.1 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Trains a model using one or more GPUs.
|
|
"""
|
|
from multiprocessing import Process
|
|
|
|
import caffe
|
|
|
|
|
|
def train(solver, snapshot, gpus, timing=False):
    """Launch one solver process per GPU and wait for all of them.

    solver   -- solver proto definition
    snapshot -- solver snapshot to restore
    gpus     -- list of device ids
    timing   -- show timing info for compute and communications
    """
    # NCCL uses a uid to identify a session; every worker receives the
    # same one.
    session_uid = caffe.NCCL.new_uid()

    caffe.init_log()
    caffe.log('Using devices %s' % str(gpus))

    # One worker per device; the rank is the worker's index into gpus.
    workers = [
        Process(target=solve,
                args=(solver, snapshot, gpus, timing, session_uid, rank))
        for rank in range(len(gpus))
    ]
    for worker in workers:
        worker.daemon = True
        worker.start()

    # Block until every worker has finished.
    for worker in workers:
        worker.join()
|
|
|
|
|
|
def time(solver, nccl):
    """Attach timing instrumentation to a solver and log periodic reports.

    Creates one forward and one backward caffe.Timer per layer of
    solver.net, plus one timer for the whole solver step and one for the
    allreduce, registers them through the net and solver callback hooks,
    and logs a report through caffe.log every solver.param.display
    iterations.

    solver -- solver whose net and callbacks are instrumented
    nccl   -- NCCL object; registered as a solver callback after the
              "total" timing callbacks and before the "allreduce" ones
    """
    total = caffe.Timer()
    allrd = caffe.Timer()
    num_layers = len(solver.net.layers)
    fprop = [caffe.Timer() for _ in range(num_layers)]
    bprop = [caffe.Timer() for _ in range(num_layers)]
    display = solver.param.display

    def show_time():
        # A display interval of 0 is Caffe's "no periodic display" setting
        # (the solver loop guards its own modulo the same way). Without
        # this guard the modulo below raises ZeroDivisionError.
        if not display or solver.iter % display != 0:
            return

        names = solver.net._layer_names
        # Leading empty entry makes the joined report start with a newline.
        report = ['']
        for i in range(num_layers):
            report.append('forw %3d %8s : %.2f' % (i, names[i], fprop[i].ms))
        # Backward timings are listed from the last layer to the first.
        for i in reversed(range(num_layers)):
            report.append('back %3d %8s : %.2f' % (i, names[i], bprop[i].ms))
        report.append('solver total: %.2f' % total.ms)
        report.append('allreduce: %.2f' % allrd.ms)
        caffe.log('\n'.join(report) + '\n')

    # Per-layer hooks: the callback receives the layer index.
    solver.net.before_forward(lambda layer: fprop[layer].start())
    solver.net.after_forward(lambda layer: fprop[layer].stop())
    solver.net.before_backward(lambda layer: bprop[layer].start())
    solver.net.after_backward(lambda layer: bprop[layer].stop())

    # Registration order matters: the allreduce timer is started by the
    # callback added before nccl and stopped by the one added after it.
    # NOTE(review): presumably the two callables run at iteration start and
    # once gradients are ready -- confirm against pycaffe's add_callback.
    solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start()))
    solver.add_callback(nccl)
    solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time()))
|
|
|
|
|
|
def solve(proto, snapshot, gpus, timing, uid, rank):
    """Run a single solver bound to gpus[rank].

    proto    -- solver proto definition passed to caffe.SGDSolver
    snapshot -- solver snapshot to restore; skipped when None or empty
    gpus     -- list of device ids, one per solver
    timing   -- when true, the rank-0 solver is instrumented with time()
    uid      -- NCCL uid passed to caffe.NCCL together with the solver
    rank     -- index of this solver within gpus
    """
    # Per-process Caffe configuration.
    caffe.set_mode_gpu()
    caffe.set_device(gpus[rank])
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)
    caffe.set_multiprocess(True)

    solver = caffe.SGDSolver(proto)
    # Truthiness covers both a missing (None) and an empty snapshot path.
    if snapshot:
        solver.restore(snapshot)

    nccl = caffe.NCCL(solver, uid)
    nccl.bcast()

    # Only the rank-0 solver is instrumented, and only when timing was
    # requested; time() registers nccl itself. Everyone else registers the
    # NCCL callback directly.
    if not timing or rank != 0:
        solver.add_callback(nccl)
    else:
        time(solver, nccl)

    if solver.param.layer_wise_reduce:
        solver.net.after_backward(nccl)
    solver.step(solver.param.max_iter)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line and start training.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--solver",
        required=True,
        help="Solver proto definition.",
    )
    cli.add_argument(
        "--snapshot",
        help="Solver snapshot to restore.",
    )
    cli.add_argument(
        "--gpus",
        type=int,
        nargs='+',
        default=[0],
        help="List of device ids.",
    )
    cli.add_argument(
        "--timing",
        action='store_true',
        help="Show timing info.",
    )
    options = cli.parse_args()

    train(options.solver, options.snapshot, options.gpus, options.timing)
|