mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
add NMS to pretrained pytorch hub models
This commit is contained in:
parent
5a9c5c1d3b
commit
c4cb78570c
@ -10,6 +10,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from models.common import NMS
|
||||
from models.yolo import Model
|
||||
from utils.google_utils import attempt_download
|
||||
|
||||
@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes):
|
||||
state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32
|
||||
state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
|
||||
model.load_state_dict(state_dict, strict=False) # load
|
||||
|
||||
m = NMS()
|
||||
m.f = -1 # from
|
||||
m.i = model.model[-1].i + 1 # index
|
||||
model.model.add_module(name='%s' % m.i, module=m) # add NMS
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
|
@ -3,6 +3,7 @@ import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils.general import non_max_suppression
|
||||
|
||||
|
||||
def autopad(k, p=None): # kernel, padding
|
||||
@ -98,6 +99,19 @@ class Concat(nn.Module):
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class NMS(nn.Module):
|
||||
# Non-Maximum Suppression (NMS) module
|
||||
conf = 0.3 # confidence threshold
|
||||
iou = 0.6 # IoU threshold
|
||||
classes = None # (optional list) filter by class
|
||||
|
||||
def __init__(self, dimension=1):
|
||||
super(NMS, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
|
||||
@staticmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user