Mirror of https://github.com/ultralytics/yolov5.git (synced 2025-06-03 14:49:29 +08:00)
Merge pull request #245 from yxNONG/patch-2
Unify the check point of single and multi GPU
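Background for the change (a sketch for context, not code from this commit): when a model is wrapped in nn.DataParallel or nn.DistributedDataParallel, every key in its state_dict gains a "module." prefix, so checkpoints written from multi-GPU runs differ from single-GPU ones unless the wrapper's .module is saved instead. A minimal CPU-runnable illustration; the toy model and the 'ckpt.pt' path are made up for the example:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))   # toy model, illustrative only
print(list(model.state_dict())[:2])     # ['0.weight', '0.bias']

dp = nn.DataParallel(model)             # how a multi-GPU run wraps the model
print(list(dp.state_dict())[:2])        # ['module.0.weight', 'module.0.bias']

# Unwrapping before saving produces the same checkpoint keys in both cases
to_save = dp.module if hasattr(dp, 'module') else dp
torch.save(to_save.state_dict(), 'ckpt.pt')   # hypothetical output path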
commit e02a189a3a

train.py
@@ -79,7 +79,6 @@ def train(hyp):
     # Create model
     model = Model(opt.cfg).to(device)
     assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
-    model.names = data_dict['names']

     # Image sizes
     gs = int(max(model.stride))  # grid size (max stride)
@@ -178,6 +177,7 @@ def train(hyp):
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.names = data_dict['names']

     # Class frequency
     labels = np.concatenate(dataset.labels, 0)
@@ -294,7 +294,7 @@ def train(hyp):
                                 batch_size=batch_size,
                                 imgsz=imgsz_test,
                                 save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
-                                model=ema.ema,
+                                model=ema.ema.module if hasattr(model, 'module') else ema.ema,
                                 single_cls=opt.single_cls,
                                 dataloader=testloader)

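In the last hunk above, the weights handed to test.test() go through the same unwrapping step in single- and multi-GPU runs: when the training model is a DP/DDP wrapper (it exposes a .module attribute), the EMA copy is unwrapped before evaluation. A small self-contained sketch of that pattern, using toy stand-ins for model and ema.ema (not code from the repository):

import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                    # stand-in for the YOLOv5 model
if torch.cuda.device_count() > 1:          # multi-GPU runs wrap the model
    model = nn.DataParallel(model)
ema_model = copy.deepcopy(model)           # stand-in for ema.ema

# Pattern from the hunk above: unwrap the EMA copy whenever the training
# model is a DP/DDP wrapper, i.e. whenever it has a .module attribute.
eval_model = ema_model.module if hasattr(model, 'module') else ema_model
print(type(eval_model).__name__)           # 'Linear' in both cases

The remaining hunks are in the torch utilities module that defines ModelEMA and model_info (utils/torch_utils.py in the repository):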
@@ -54,6 +54,11 @@ def time_synchronized():
     return time.time()


+def is_parallel(model):
+    # is model is parallel with DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
 def initialize_weights(model):
     for m in model.modules():
         t = type(m)
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):

     try:  # FLOPS
         from thop import profile
-        macs, _ = profile(model, inputs=(torch.zeros(1, 3, 480, 640),), verbose=False)
-        fs = ', %.1f GFLOPS' % (macs / 1E9 * 2)
+        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
+        fs = ', %.1f GFLOPS' % (flops * 100)  # 640x640 FLOPS
     except:
         fs = ''

@@ -185,7 +190,7 @@ class ModelEMA:
         self.updates += 1
         d = self.decay(self.updates)
         with torch.no_grad():
-            if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
+            if is_parallel(model):
                 msd, esd = model.module.state_dict(), self.ema.module.state_dict()
             else:
                 msd, esd = model.state_dict(), self.ema.state_dict()
@@ -196,7 +201,8 @@ class ModelEMA:
                    v += (1. - d) * msd[k].detach()

     def update_attr(self, model):
-        # Assign attributes (which may change during training)
-        for k in model.__dict__.keys():
-            if not k.startswith('_'):
-                setattr(self.ema, k, getattr(model, k))
+        # Update class attributes
+        ema = self.ema.module if is_parallel(model) else self.ema
+        for k, v in model.__dict__.items():
+            if not k.startswith('_') and k != 'module':
+                setattr(ema, k, v)
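Taken together, the new is_parallel() helper and the rewritten update_attr() keep attributes such as names and hyp attached to the bare EMA module whether or not the training model is wrapped. A minimal self-contained sketch of that behaviour; TinyEMA and the attribute values are illustrative stand-ins, not the repository's ModelEMA:

import copy
import torch.nn as nn

def is_parallel(model):
    # True if model is wrapped in DataParallel or DistributedDataParallel
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

class TinyEMA:
    # Minimal stand-in for ModelEMA: only the attribute-sync part is shown
    def __init__(self, model):
        self.ema = copy.deepcopy(model).eval()

    def update_attr(self, model):
        # Copy public attributes onto the bare EMA module, skipping the DP/DDP wrapper itself
        ema = self.ema.module if is_parallel(model) else self.ema
        for k, v in model.__dict__.items():
            if not k.startswith('_') and k != 'module':
                setattr(ema, k, v)

model = nn.Linear(4, 2)
ema = TinyEMA(model)
model.names = ['person', 'car']    # attribute attached during training, as train.py does
ema.update_attr(model)             # propagate it to the EMA copy
print(ema.ema.names)               # ['person', 'car']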