mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
model definition update
This commit is contained in:
parent
715cb08b10
commit
5bee686649
@ -45,13 +45,15 @@ class Detect(nn.Module):
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, model_yaml='yolov5s.yaml'): # cfg, number of classes, depth-width gains
|
||||
def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
|
||||
super(Model, self).__init__()
|
||||
with open(model_yaml) as f:
|
||||
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
||||
if nc:
|
||||
self.md['nc'] = nc # override yaml value
|
||||
|
||||
# Define model
|
||||
self.model, self.save, ch = parse_model(self.md, ch=[3]) # model, savelist, ch_out
|
||||
self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
||||
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
|
||||
|
||||
# Build strides, anchors
|
||||
|
Loading…
x
Reference in New Issue
Block a user