requires grad after reset params

pull/9015/head
glennjocher 2022-08-18 02:44:50 +02:00
parent e08d568d39
commit 5c854fab5e
1 changed files with 2 additions and 2 deletions

View File

@ -114,13 +114,13 @@ def train(opt, device):
LOGGER.warning("WARNING: pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model
reshape_classifier_output(model, nc) # update class count
for p in model.parameters():
p.requires_grad = True # for training
for m in model.modules():
if not pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
m.p = opt.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names