model.yaml nc inherited from dataset.yaml
parent
1fdaa4987b
commit
bb3c346916
|
@ -52,7 +52,8 @@ class Model(nn.Module):
|
|||
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||
|
||||
# Define model
|
||||
if nc:
|
||||
if nc and nc != self.md['nc']:
|
||||
print('Overriding %s nc=%g with nc=%g' % (model_cfg, self.md['nc'], nc))
|
||||
self.md['nc'] = nc # override yaml value
|
||||
self.model, self.save = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
||||
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
||||
|
|
3
train.py
3
train.py
|
@ -77,8 +77,7 @@ def train(hyp):
|
|||
os.remove(f)
|
||||
|
||||
# Create model
|
||||
model = Model(opt.cfg).to(device)
|
||||
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
||||
model = Model(opt.cfg, nc=data_dict['nc']).to(device)
|
||||
|
||||
# Image sizes
|
||||
gs = int(max(model.stride)) # grid size (max stride)
|
||||
|
|
Loading…
Reference in New Issue