new nc=len(names) check
parent
cb527d3af9
commit
e16e9e43e1
4
train.py
4
train.py
|
@ -76,7 +76,7 @@ def train(hyp):
|
|||
os.remove(f)
|
||||
|
||||
# Create model
|
||||
model = Model(opt.cfg, nc=data_dict['nc']).to(device)
|
||||
model = Model(opt.cfg, nc=nc).to(device)
|
||||
|
||||
# Image sizes
|
||||
gs = int(max(model.stride)) # grid size (max stride)
|
||||
|
@ -177,7 +177,7 @@ def train(hyp):
|
|||
model.hyp = hyp # attach hyperparameters to model
|
||||
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
||||
model.names = data_dict['names']
|
||||
model.names = names
|
||||
|
||||
# Class frequency
|
||||
labels = np.concatenate(dataset.labels, 0)
|
||||
|
|
Loading…
Reference in New Issue