Implement default class names (#1609)
parent
efa7a915d8
commit
d929bb656c
|
@ -1,16 +1,16 @@
|
|||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
sys.path.append('./') # to run '$ python *.py' files in subdirectories
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
|
||||
from models.experimental import MixConv2d, CrossConv, C3
|
||||
from utils.autoanchor import check_anchor_order
|
||||
|
@ -82,6 +82,7 @@ class Model(nn.Module):
|
|||
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
|
||||
self.yaml['nc'] = nc # override yaml value
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
|
||||
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
||||
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
||||
|
||||
# Build strides, anchors
|
||||
|
|
Loading…
Reference in New Issue