Feature visualization update (#3920)
* Feature visualization update * Save to jpg (faster) * Save to pngpull/3923/head
parent
61047a2b4f
commit
87b094bcbc
|
@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
|
|||
classes=None, # filter by class: --class 0, or --class 0 2 3
|
||||
agnostic_nms=False, # class-agnostic NMS
|
||||
augment=False, # augmented inference
|
||||
visualize=False, # visualize features
|
||||
update=False, # update all models
|
||||
project='runs/detect', # save results to project/name
|
||||
name='exp', # save results to project/name
|
||||
|
@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)
|
|||
|
||||
# Inference
|
||||
t1 = time_synchronized()
|
||||
pred = model(img, augment=augment)[0]
|
||||
pred = model(img,
|
||||
augment=augment,
|
||||
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
|
||||
|
||||
# Apply NMS
|
||||
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
|
||||
|
@ -201,6 +204,7 @@ def parse_opt():
|
|||
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
|
||||
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
|
||||
parser.add_argument('--augment', action='store_true', help='augmented inference')
|
||||
parser.add_argument('--visualize', action='store_true', help='visualize features')
|
||||
parser.add_argument('--update', action='store_true', help='update all models')
|
||||
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
|
||||
parser.add_argument('--name', default='exp', help='save results to project/name')
|
||||
|
|
|
@ -117,11 +117,10 @@ class Model(nn.Module):
|
|||
self.info()
|
||||
logger.info('')
|
||||
|
||||
def forward(self, x, augment=False, profile=False):
|
||||
def forward(self, x, augment=False, profile=False, visualize=False):
|
||||
if augment:
|
||||
return self.forward_augment(x) # augmented inference, None
|
||||
else:
|
||||
return self.forward_once(x, profile) # single-scale inference, train
|
||||
return self.forward_once(x, profile, visualize) # single-scale inference, train
|
||||
|
||||
def forward_augment(self, x):
|
||||
img_size = x.shape[-2:] # height, width
|
||||
|
@ -136,7 +135,7 @@ class Model(nn.Module):
|
|||
y.append(yi)
|
||||
return torch.cat(y, 1), None # augmented inference, train
|
||||
|
||||
def forward_once(self, x, profile=False, feature_vis=False):
|
||||
def forward_once(self, x, profile=False, visualize=False):
|
||||
y, dt = [], [] # outputs
|
||||
for m in self.model:
|
||||
if m.f != -1: # if not from previous layer
|
||||
|
@ -155,8 +154,8 @@ class Model(nn.Module):
|
|||
x = m(x) # run
|
||||
y.append(x if m.i in self.save else None) # save output
|
||||
|
||||
if feature_vis and m.type == 'models.common.SPP':
|
||||
feature_visualization(x, m.type, m.i)
|
||||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
|
||||
if profile:
|
||||
logger.info('%.1fms total' % sum(dt))
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# Plotting utils
|
||||
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import math
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -15,7 +15,6 @@ import seaborn as sn
|
|||
import torch
|
||||
import yaml
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from torchvision import transforms
|
||||
|
||||
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
|
||||
from utils.metrics import fitness
|
||||
|
@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
|
|||
fig.savefig(Path(save_dir) / 'results.png', dpi=200)
|
||||
|
||||
|
||||
def feature_visualization(x, module_type, stage, n=64):
|
||||
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
|
||||
"""
|
||||
x: Features to be visualized
|
||||
module_type: Module type
|
||||
stage: Module stage within model
|
||||
n: Maximum number of feature maps to plot
|
||||
save_dir: Directory to save results
|
||||
"""
|
||||
batch, channels, height, width = x.shape # batch, channels, height, width
|
||||
if height > 1 and width > 1:
|
||||
project, name = 'runs/features', 'exp'
|
||||
save_dir = increment_path(Path(project) / name) # increment run
|
||||
save_dir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
if 'Detect' not in module_type:
|
||||
batch, channels, height, width = x.shape # batch, channels, height, width
|
||||
if height > 1 and width > 1:
|
||||
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
||||
|
||||
plt.figure(tight_layout=True)
|
||||
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
|
||||
n = min(n, len(blocks))
|
||||
for i in range(n):
|
||||
feature = transforms.ToPILImage()(blocks[i].squeeze())
|
||||
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
|
||||
ax.axis('off')
|
||||
plt.imshow(feature) # cmap='gray'
|
||||
plt.figure(tight_layout=True)
|
||||
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
|
||||
n = min(n, channels) # number of plots
|
||||
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
|
||||
for i in range(n):
|
||||
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
||||
ax[i].axis('off')
|
||||
|
||||
f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
|
||||
print(f'Saving {save_dir / f}...')
|
||||
plt.savefig(save_dir / f, dpi=300)
|
||||
print(f'Saving {save_dir / f}... ({n}/{channels})')
|
||||
plt.savefig(save_dir / f, dpi=300)
|
||||
|
|
Loading…
Reference in New Issue