rename class autoShape -> AutoShape (#3173)
* rename class autoShape -> AutoShape follow other class naming convention * rename class autoShape -> AutoShape follow other classes' naming convention * rename class autoShape -> AutoShapepull/3104/head
parent
17b0f71538
commit
be86c21c73
|
@ -223,18 +223,18 @@ class NMS(nn.Module):
|
|||
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
|
||||
|
||||
|
||||
class autoShape(nn.Module):
|
||||
class AutoShape(nn.Module):
|
||||
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
||||
conf = 0.25 # NMS confidence threshold
|
||||
iou = 0.45 # NMS IoU threshold
|
||||
classes = None # (optional list) filter by class
|
||||
|
||||
def __init__(self, model):
|
||||
super(autoShape, self).__init__()
|
||||
super(AutoShape, self).__init__()
|
||||
self.model = model.eval()
|
||||
|
||||
def autoshape(self):
|
||||
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape()
|
||||
print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -215,9 +215,9 @@ class Model(nn.Module):
|
|||
self.model = self.model[:-1] # remove
|
||||
return self
|
||||
|
||||
def autoshape(self): # add autoShape module
|
||||
logger.info('Adding autoShape... ')
|
||||
m = autoShape(self) # wrap model
|
||||
def autoshape(self): # add AutoShape module
|
||||
logger.info('Adding AutoShape... ')
|
||||
m = AutoShape(self) # wrap model
|
||||
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
|
||||
return m
|
||||
|
||||
|
|
Loading…
Reference in New Issue