mirror of https://github.com/WongKinYiu/yolov7.git
Update general.py
parent a3c802d72e
commit be1109c92a
utils/general.py (142 additions)
@@ -22,6 +22,14 @@ from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

from utils.torch_utils import is_parallel
from torch.nn import functional as F
from detectron2.structures.masks import BitMasks
from detectron2.structures import Boxes
from detectron2.layers.roi_align import ROIAlign
from detectron2.utils.memory import retry_if_cuda_oom
from detectron2.layers import paste_masks_in_image

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
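Note: the detectron2 imports above back the mask-merging NMS added further down. The pooler argument of non_max_suppression_mask_conf is expected to behave like a detectron2 ROIPooler, i.e. it is called with a list of feature maps and a list of Boxes. A minimal construction sketch, assuming detectron2 is installed and that the bases come from a stride-4 feature map; the output size, scale and sampling ratio are illustrative values, not taken from this commit:

from detectron2.modeling.poolers import ROIPooler  # not part of this diff; shown for context only

pooler = ROIPooler(output_size=56,          # side length of the pooled bases / masks (illustrative)
                   scales=(0.25,),          # assumes the bases live on a 1/4-resolution feature map
                   sampling_ratio=1,
                   pooler_type="ROIAlignV2")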
@@ -43,6 +51,24 @@ def init_seeds(seed=0):
    init_torch_seeds(seed)


def merge_bases(rois, coeffs, attn_r, num_b, location_to_inds=None):
    # merge predictions
    # N = coeffs.size(0)
    if location_to_inds is not None:
        rois = rois[location_to_inds]
    N, B, H, W = rois.size()
    if coeffs.dim() != 4:
        coeffs = coeffs.view(N, num_b, attn_r, attn_r)
    # NA = coeffs.shape[1] // B
    coeffs = F.interpolate(coeffs, (H, W),
                           mode="bilinear").softmax(dim=1)
    # coeffs = coeffs.view(N, -1, B, H, W)
    # rois = rois[:, None, ...].repeat(1, NA, 1, 1, 1)
    # masks_preds, _ = (rois * coeffs).sum(dim=2)  # c.max(dim=1)
    masks_preds = (rois * coeffs).sum(dim=1)
    return masks_preds


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
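For reference, a small shape sketch of merge_bases (the shapes are assumed and not part of the commit): rois stacks the pooled bases for each detection, coeffs carries the flattened attention coefficients, and the result is one merged mask per detection.

import torch

N, num_b, attn_r, H, W = 5, 5, 14, 56, 56         # hypothetical sizes
rois = torch.rand(N, num_b, H, W)                 # pooled bases, one stack per detection
coeffs = torch.rand(N, num_b * attn_r * attn_r)   # flattened attention maps
masks = merge_bases(rois, coeffs, attn_r, num_b)  # (N, H, W): softmax-weighted sum over the bases
probs = masks.view(N, -1).sigmoid()               # per-pixel probabilities, as used in the NMS below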
@@ -795,6 +821,122 @@ def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes
    return output


def non_max_suppression_mask_conf(prediction, attn, bases, pooler, hyp, conf_thres=0.1, iou_thres=0.6, merge=False, classes=None, agnostic=False, mask_iou=None, vote=False):
    # Runs NMS on inference results and merges per-instance masks from the predicted bases and attention coefficients
    if prediction.dtype is torch.float16:
        prediction = prediction.float()  # to FP32

    nc = prediction[0].shape[1] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)

    t = time.time()
    output = [None] * prediction.shape[0]
    output_mask = [None] * prediction.shape[0]
    output_mask_score = [None] * prediction.shape[0]
    output_ac = [None] * prediction.shape[0]
    output_ab = [None] * prediction.shape[0]

    def RMS_contrast(masks):
        mu = torch.mean(masks, dim=-1, keepdim=True)
        return torch.sqrt(torch.mean((masks - mu) ** 2, dim=-1, keepdim=True))

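    # Per-image loop: keep confident candidates, build instance masks from bases + attention, rescore, then run NMS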
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Attention coefficients and shared bases for the surviving candidates
        a = attn[xi][xc[xi]]
        base = bases[xi]

        # Pool the bases inside each candidate box and merge them into one mask per candidate
        bboxes = Boxes(box)
        pooled_bases = pooler([base[None]], [bboxes])

        pred_masks = merge_bases(pooled_bases, a, hyp["attn_resolution"], hyp["num_base"]).view(a.shape[0], -1).sigmoid()

        if mask_iou is not None:
            mask_score = mask_iou[xi][xc[xi]][..., None]
        else:
            # No predicted mask IoU: score each mask by the geometric mean of its per-pixel confidence max(p, 1 - p)
            temp = pred_masks.clone()
            temp[temp < 0.5] = 1 - temp[temp < 0.5]
            mask_score = torch.exp(torch.log(temp).mean(dim=-1, keepdims=True))  # torch.mean(temp, dim=-1, keepdims=True)

        x[:, 5:] *= x[:, 4:5] * mask_score  # conf = obj_conf * cls_conf * mask_score

        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
            mask_score = mask_score[i]
            if attn is not None:
                pred_masks = pred_masks[i]
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # If none remain process next image
        n = x.shape[0]  # number of boxes
        if not n:
            continue

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        # scores *= mask_score
        i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]

        # Mask voting: for each kept box, collect the masks/boxes of its top-scoring (up to 10) overlapping candidates
        all_candidates = []
        all_boxes = []
        if vote:
            ious = box_iou(boxes[i], boxes) > iou_thres
            for iou in ious:
                selected_masks = pred_masks[iou]
                k = min(10, selected_masks.shape[0])
                _, tfive = torch.topk(scores[iou], k)
                all_candidates.append(pred_masks[iou][tfive])
                all_boxes.append(x[iou, :4][tfive])
                #exit()

        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            try:  # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
            except:  # possible CUDA error https://github.com/ultralytics/yolov3/issues/1139
                print(x, i, x.shape, i.shape)
                pass

        output[xi] = x[i]
        output_mask_score[xi] = mask_score[i]
        output_ac[xi] = all_candidates
        output_ab[xi] = all_boxes
        if attn is not None:
            output_mask[xi] = pred_masks[i]
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output, output_mask, output_mask_score, output_ac, output_ab


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
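A hedged end-to-end sketch of how the new NMS variant might be called and how the detectron2 helpers imported above could paste the pooled-resolution masks back to image size. The model output unpacking, tensor shapes, image size and the mask_resolution key are assumptions for illustration, not taken from this commit:

# Assumed shapes: prediction (B, N, 5 + nc), attn (B, N, num_base * attn_r * attn_r),
# bases (B, num_base, Hf, Wf); pooler as sketched above; hyp provides attn_resolution / num_base.
output, output_mask, output_mask_score, _, _ = non_max_suppression_mask_conf(
    prediction, attn, bases, pooler, hyp, conf_thres=0.25, iou_thres=0.65)

img_h, img_w = 640, 640  # hypothetical network input size
for det, masks in zip(output, output_mask):
    if det is None:
        continue
    r = hyp["mask_resolution"]  # hypothetical key: side length of the pooled masks
    pooled = masks.view(-1, r, r)
    # Paste each kept mask into the full image, retrying on CPU if CUDA runs out of memory
    full = retry_if_cuda_oom(paste_masks_in_image)(pooled, Boxes(det[:, :4]), (img_h, img_w), threshold=0.5)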