mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
hyp evolution force-autoanchor fix
This commit is contained in:
parent
c687d5c129
commit
c8e51812a5
4
train.py
4
train.py
@ -68,10 +68,10 @@ def train(hyp, opt, device, tb_writer=None):
|
|||||||
with torch_distributed_zero_first(rank):
|
with torch_distributed_zero_first(rank):
|
||||||
attempt_download(weights) # download if not found locally
|
attempt_download(weights) # download if not found locally
|
||||||
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
||||||
if 'anchors' in hyp and hyp['anchors']:
|
if hyp.get('anchors'):
|
||||||
ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
|
ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
|
||||||
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
|
model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
|
||||||
exclude = ['anchor'] if opt.cfg else [] # exclude keys
|
exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
|
||||||
state_dict = ckpt['model'].float().state_dict() # to FP32
|
state_dict = ckpt['model'].float().state_dict() # to FP32
|
||||||
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
||||||
model.load_state_dict(state_dict, strict=False) # load
|
model.load_state_dict(state_dict, strict=False) # load
|
||||||
|
Loading…
x
Reference in New Issue
Block a user