requires grad after reset params
parent
e08d568d39
commit
5c854fab5e
|
@ -114,13 +114,13 @@ def train(opt, device):
|
|||
LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
|
||||
model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model
|
||||
reshape_classifier_output(model, nc) # update class count
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
for m in model.modules():
|
||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
|
||||
m.p = opt.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
model = model.to(device)
|
||||
names = trainloader.dataset.classes # class names
|
||||
model.names = names # attach class names
|
||||
|
|
Loading…
Reference in New Issue