Implement default class names (#1609)

pull/1619/head
Glenn Jocher 2020-12-06 12:41:37 +01:00 committed by GitHub
parent efa7a915d8
commit d929bb656c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 4 deletions

View File

@ -1,16 +1,16 @@
import argparse import argparse
import logging import logging
import math
import sys import sys
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
sys.path.append('./') # to run '$ python *.py' files in subdirectories import math
logger = logging.getLogger(__name__)
import torch import torch
import torch.nn as nn import torch.nn as nn
sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape
from models.experimental import MixConv2d, CrossConv, C3 from models.experimental import MixConv2d, CrossConv, C3
from utils.autoanchor import check_anchor_order from utils.autoanchor import check_anchor_order
@ -82,6 +82,7 @@ class Model(nn.Module):
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc)) logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
self.yaml['nc'] = nc # override yaml value self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
# Build strides, anchors # Build strides, anchors