WACV 2022 release
commit
6a520d5eeb
|
@ -0,0 +1,61 @@
|
|||
## CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows
|
||||
WACV 2022 preprint:[https://arxiv.org/abs/????.?????](https://arxiv.org/abs/????.?????)
|
||||
|
||||
## Abstract
|
||||
Unsupervised anomaly detection with localization has many practical applications when labeling is infeasible and, moreover, when anomaly examples are completely missing in the train data. While recently proposed models for such data setup achieve high accuracy metrics, their complexity is a limiting factor for real-time processing. In this paper, we propose a real-time model and analytically derive its relationship to prior methods. Our CFLOW-AD model is based on a conditional normalizing flow framework adopted for anomaly detection with localization. In particular, CFLOW-AD consists of a discriminatively pretrained encoder followed by a multi-scale generative decoders where the latter explicitly estimate likelihood of the encoded features. Our approach results in a computationally and memory-efficient model: CFLOW-AD is faster and smaller by a factor of 10x than prior state-of-the-art with the same input setting. Our experiments on the MVTec dataset show that CFLOW-AD outperforms previous methods by 0.36% AUROC in detection task, by 1.12% AUROC and 2.5% AUPRO in localization task, respectively. We open-source our code with fully reproducible experiments.
|
||||
|
||||
## BibTex Citation
|
||||
If you like our [paper](https://arxiv.org/abs/????.?????) or code, please cite its WACV 2022 preprint using the following BibTex:
|
||||
```
|
||||
@article{cflow_ad,
|
||||
title={CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows},
|
||||
author={Gudovskiy, Denis and Ishizaka, Shun and Kozuka, Kazuki and Tsukizawa, Sotaro},
|
||||
journal={arXiv:????.?????},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Installation
|
||||
- Clone this repository: tested on Python 3.8
|
||||
- Install [PyTorch](http://pytorch.org/): tested on v1.8
|
||||
- Install [FrEIA Flows](https://github.com/VLL-HD/FrEIA): tested on [the recent branch](https://github.com/VLL-HD/FrEIA/tree/4e0c6ab42b26ec6e41b1ee2abb1a8b6562752b00)
|
||||
- Other dependencies in requirements.txt
|
||||
|
||||
Install all packages with this command:
|
||||
```
|
||||
$ python3 -m pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
## Datasets
|
||||
We support [MVTec AD dataset](https://www.mvtec.com/de/unternehmen/forschung/datasets/mvtec-ad/) for anomaly localization in factory setting and [Shanghai Tech Campus (STC)](https://svip-lab.github.io/dataset/campus_dataset.html) dataset with surveillance camera videos. Please, download dataset from URLs and extract to 'data' folder.
|
||||
|
||||
## Code Organization
|
||||
- ./custom_datasets - contains dataloaders for MVTec and STC
|
||||
- ./custom_models - contains pretrained feature extractors
|
||||
|
||||
## Running Experiments
|
||||
- Run code by selecting class name, feature extractor, input size, flow model etc.
|
||||
- The commands below should reproduce our reference MVTec results:
|
||||
```
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name bottle
|
||||
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name cable
|
||||
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name capsule
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name carpet
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name grid
|
||||
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name hazelnut
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name leather
|
||||
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name metal_nut
|
||||
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name pill
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name screw
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name tile
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name toothbrush
|
||||
python3 main.py --gpu 0 --pro -inp 128 --dataset mvtec --class-name transistor
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name wood
|
||||
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name zipper
|
||||
```
|
||||
|
||||
## CFLOW-AD Architecture
|
||||

|
||||
|
||||
## Reference CFLOW-AD Results for MVTec
|
||||

|
|
@ -0,0 +1,52 @@
|
|||
from __future__ import print_function
|
||||
import argparse
|
||||
|
||||
__all__ = ['get_args']
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='CFLOW-AD')
|
||||
parser.add_argument('--dataset', default='mvtec', type=str, metavar='D',
|
||||
help='dataset name: mvtec/stc (default: mvtec)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='D',
|
||||
help='file with saved checkpoint')
|
||||
parser.add_argument('-cl', '--class-name', default='none', type=str, metavar='C',
|
||||
help='class name for MVTec/STC (default: none)')
|
||||
parser.add_argument('-enc', '--enc-arch', default='wide_resnet50_2', type=str, metavar='A',
|
||||
help='feature extractor architecture (default: wide_resnet50_2)')
|
||||
parser.add_argument('-dec', '--dec-arch', default='freia-cflow', type=str, metavar='A',
|
||||
help='normalizing flow model (default: freia-cflow)')
|
||||
parser.add_argument('-pl', '--pool-layers', default=3, type=int, metavar='L',
|
||||
help='number of layers used in NF model (default: 3)')
|
||||
parser.add_argument('-cb', '--coupling-blocks', default=8, type=int, metavar='L',
|
||||
help='number of layers used in NF model (default: 8)')
|
||||
parser.add_argument('-run', '--run-name', default=0, type=int, metavar='C',
|
||||
help='name of the run (default: 0)')
|
||||
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
|
||||
help='image resize dimensions (default: 256)')
|
||||
parser.add_argument("--action-type", default='norm-train', type=str, metavar='T',
|
||||
help='norm-train (default: norm-train)')
|
||||
parser.add_argument('-bs', '--batch-size', default=32, type=int, metavar='B',
|
||||
help='train batch size (default: 32)')
|
||||
parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',
|
||||
help='learning rate (default: 2e-4)')
|
||||
parser.add_argument('--meta-epochs', type=int, default=25, metavar='N',
|
||||
help='number of meta epochs to train (default: 25)')
|
||||
parser.add_argument('--sub-epochs', type=int, default=8, metavar='N',
|
||||
help='number of sub epochs to train (default: 8)')
|
||||
parser.add_argument('--pro', action='store_true', default=False,
|
||||
help='enables estimation of AUPRO metric')
|
||||
parser.add_argument('--viz', action='store_true', default=False,
|
||||
help='saves test data visualizations')
|
||||
parser.add_argument('--workers', default=4, type=int, metavar='G',
|
||||
help='number of data loading workers (default: 4)')
|
||||
parser.add_argument("--gpu", default='0', type=str, metavar='G',
|
||||
help='GPU device number')
|
||||
parser.add_argument('--no-cuda', action='store_true', default=False,
|
||||
help='disables CUDA training')
|
||||
parser.add_argument('--video-path', default='.', type=str, metavar='D',
|
||||
help='video file path')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
|
@ -0,0 +1 @@
|
|||
from .loader import MVTecDataset, StcDataset
|
|
@ -0,0 +1,219 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.io import read_video, write_jpeg
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms as T
|
||||
|
||||
|
||||
__all__ = ('MVTecDataset', 'StcDataset')
|
||||
|
||||
|
||||
# URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz'
|
||||
MVTEC_CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
|
||||
'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
|
||||
'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
|
||||
|
||||
STC_CLASS_NAMES = ['01', '02', '03', '04', '05', '06',
|
||||
'07', '08', '09', '10', '11', '12'] #, '13' - no ground-truth]
|
||||
|
||||
|
||||
class StcDataset(Dataset):
|
||||
def __init__(self, c, is_train=True):
|
||||
assert c.class_name in STC_CLASS_NAMES, 'class_name: {}, should be in {}'.format(c.class_name, STC_CLASS_NAMES)
|
||||
self.class_name = c.class_name
|
||||
self.is_train = is_train
|
||||
self.cropsize = c.crp_size
|
||||
#
|
||||
if is_train:
|
||||
self.dataset_path = os.path.join(c.data_path, 'training')
|
||||
self.dataset_vid = os.path.join(self.dataset_path, 'videos')
|
||||
self.dataset_dir = os.path.join(self.dataset_path, 'frames')
|
||||
self.dataset_files = sorted([f for f in os.listdir(self.dataset_vid) if f.startswith(self.class_name)])
|
||||
if not os.path.isdir(self.dataset_dir):
|
||||
os.mkdir(self.dataset_dir)
|
||||
done_file = os.path.join(self.dataset_path, 'frames_{}.pt'.format(self.class_name))
|
||||
print(done_file)
|
||||
H, W = 480, 856
|
||||
if os.path.isfile(done_file):
|
||||
assert torch.load(done_file) == len(self.dataset_files), 'train frames are not processed!'
|
||||
else:
|
||||
count = 0
|
||||
for dataset_file in self.dataset_files:
|
||||
print(dataset_file)
|
||||
data = read_video(os.path.join(self.dataset_vid, dataset_file)) # read video file entirely -> mem issue!!!
|
||||
vid = data[0] # weird read_video that returns byte tensor in format [T,H,W,C]
|
||||
fps = data[2]['video_fps']
|
||||
print('video mu/std: {}/{} {}'.format(torch.mean(vid/255.0, (0,1,2)), torch.std(vid/255.0, (0,1,2)), vid.shape))
|
||||
assert [H, W] == [vid.size(1), vid.size(2)], 'same H/W'
|
||||
dataset_file_dir = os.path.join(self.dataset_dir, os.path.splitext(dataset_file)[0])
|
||||
os.mkdir(dataset_file_dir)
|
||||
count = count + 1
|
||||
for i, frame in enumerate(vid):
|
||||
filename = '{0:08d}.jpg'.format(i)
|
||||
write_jpeg(frame.permute((2, 0, 1)), os.path.join(dataset_file_dir, filename), 80)
|
||||
torch.save(torch.tensor(count), done_file)
|
||||
#
|
||||
self.x, self.y, self.mask = self.load_dataset_folder()
|
||||
else:
|
||||
self.dataset_path = os.path.join(c.data_path, 'testing')
|
||||
self.x, self.y, self.mask = self.load_dataset_folder()
|
||||
|
||||
# set transforms
|
||||
if is_train:
|
||||
self.transform_x = T.Compose([
|
||||
T.Resize(c.img_size, Image.ANTIALIAS),
|
||||
T.RandomRotation(5),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
# test:
|
||||
else:
|
||||
self.transform_x = T.Compose([
|
||||
T.Resize(c.img_size, Image.ANTIALIAS),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
# mask
|
||||
self.transform_mask = T.Compose([
|
||||
T.ToPILImage(),
|
||||
T.Resize(c.img_size, Image.NEAREST),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
|
||||
self.normalize = T.Compose([T.Normalize(c.norm_mean, c.norm_std)])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
|
||||
x = Image.open(x).convert('RGB')
|
||||
x = self.normalize(self.transform_x(x))
|
||||
if y == 0: #self.is_train:
|
||||
mask = torch.zeros([1, self.cropsize[0], self.cropsize[1]])
|
||||
else:
|
||||
mask = self.transform_mask(mask)
|
||||
#
|
||||
return x, y, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def load_dataset_folder(self):
|
||||
phase = 'train' if self.is_train else 'test'
|
||||
x, y, mask = list(), list(), list()
|
||||
img_dir = os.path.join(self.dataset_path, 'frames')
|
||||
img_types = sorted([f for f in os.listdir(img_dir) if f.startswith(self.class_name)])
|
||||
gt_frame_dir = os.path.join(self.dataset_path, 'test_frame_mask')
|
||||
gt_pixel_dir = os.path.join(self.dataset_path, 'test_pixel_mask')
|
||||
for i, img_type in enumerate(img_types):
|
||||
print('Folder:', img_type)
|
||||
# load images
|
||||
img_type_dir = os.path.join(img_dir, img_type)
|
||||
img_fpath_list = sorted([os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.jpg')])
|
||||
x.extend(img_fpath_list)
|
||||
# labels for every test image
|
||||
if phase == 'test':
|
||||
gt_pixel = np.load('{}.npy'.format(os.path.join(gt_pixel_dir, img_type)))
|
||||
gt_frame = np.load('{}.npy'.format(os.path.join(gt_frame_dir, img_type)))
|
||||
if i == 0:
|
||||
m = gt_pixel
|
||||
y = gt_frame
|
||||
else:
|
||||
m = np.concatenate((m, gt_pixel), axis=0)
|
||||
y = np.concatenate((y, gt_frame), axis=0)
|
||||
#
|
||||
mask = [e for e in m] # np.expand_dims(e, axis=0)
|
||||
assert len(x) == len(y), 'number of x and y should be same'
|
||||
assert len(x) == len(mask), 'number of x and mask should be same'
|
||||
else:
|
||||
mask.extend([None] * len(img_fpath_list))
|
||||
y.extend([0] * len(img_fpath_list))
|
||||
#
|
||||
return list(x), list(y), list(mask)
|
||||
|
||||
|
||||
class MVTecDataset(Dataset):
|
||||
def __init__(self, c, is_train=True):
|
||||
assert c.class_name in MVTEC_CLASS_NAMES, 'class_name: {}, should be in {}'.format(c.class_name, MVTEC_CLASS_NAMES)
|
||||
self.dataset_path = c.data_path
|
||||
self.class_name = c.class_name
|
||||
self.is_train = is_train
|
||||
self.cropsize = c.crp_size
|
||||
# load dataset
|
||||
self.x, self.y, self.mask = self.load_dataset_folder()
|
||||
# set transforms
|
||||
if is_train:
|
||||
self.transform_x = T.Compose([
|
||||
T.Resize(c.img_size, Image.ANTIALIAS),
|
||||
T.RandomRotation(5),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
# test:
|
||||
else:
|
||||
self.transform_x = T.Compose([
|
||||
T.Resize(c.img_size, Image.ANTIALIAS),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
# mask
|
||||
self.transform_mask = T.Compose([
|
||||
T.Resize(c.img_size, Image.NEAREST),
|
||||
T.CenterCrop(c.crp_size),
|
||||
T.ToTensor()])
|
||||
|
||||
self.normalize = T.Compose([T.Normalize(c.norm_mean, c.norm_std)])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
|
||||
#x = Image.open(x).convert('RGB')
|
||||
x = Image.open(x)
|
||||
if self.class_name in ['zipper', 'screw', 'grid']: # handle greyscale classes
|
||||
x = np.expand_dims(np.array(x), axis=2)
|
||||
x = np.concatenate([x, x, x], axis=2)
|
||||
|
||||
x = Image.fromarray(x.astype('uint8')).convert('RGB')
|
||||
#
|
||||
x = self.normalize(self.transform_x(x))
|
||||
#
|
||||
if y == 0:
|
||||
mask = torch.zeros([1, self.cropsize[0], self.cropsize[1]])
|
||||
else:
|
||||
mask = Image.open(mask)
|
||||
mask = self.transform_mask(mask)
|
||||
|
||||
return x, y, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def load_dataset_folder(self):
|
||||
phase = 'train' if self.is_train else 'test'
|
||||
x, y, mask = [], [], []
|
||||
|
||||
img_dir = os.path.join(self.dataset_path, self.class_name, phase)
|
||||
gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth')
|
||||
|
||||
img_types = sorted(os.listdir(img_dir))
|
||||
for img_type in img_types:
|
||||
|
||||
# load images
|
||||
img_type_dir = os.path.join(img_dir, img_type)
|
||||
if not os.path.isdir(img_type_dir):
|
||||
continue
|
||||
img_fpath_list = sorted([os.path.join(img_type_dir, f)
|
||||
for f in os.listdir(img_type_dir)
|
||||
if f.endswith('.png')])
|
||||
x.extend(img_fpath_list)
|
||||
|
||||
# load gt labels
|
||||
if img_type == 'good':
|
||||
y.extend([0] * len(img_fpath_list))
|
||||
mask.extend([None] * len(img_fpath_list))
|
||||
else:
|
||||
y.extend([1] * len(img_fpath_list))
|
||||
gt_type_dir = os.path.join(gt_dir, img_type)
|
||||
img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list]
|
||||
gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
|
||||
for img_fname in img_fname_list]
|
||||
mask.extend(gt_fpath_list)
|
||||
|
||||
assert len(x) == len(y), 'number of x and y should be same'
|
||||
|
||||
return list(x), list(y), list(mask)
|
|
@ -0,0 +1,3 @@
|
|||
from .resnet import *
|
||||
from .mobilenetv3 import *
|
||||
from .utils import *
|
|
@ -0,0 +1,276 @@
|
|||
import torch
|
||||
|
||||
from functools import partial
|
||||
from torch import nn, Tensor
|
||||
from torch.nn import functional as F
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
|
||||
|
||||
|
||||
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
|
||||
|
||||
|
||||
model_urls = {
|
||||
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
|
||||
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
|
||||
}
|
||||
|
||||
class SqueezeExcitation(nn.Module):
|
||||
|
||||
def __init__(self, input_channels: int, squeeze_factor: int = 4):
|
||||
super().__init__()
|
||||
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
|
||||
self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
|
||||
|
||||
def _scale(self, input: Tensor, inplace: bool) -> Tensor:
|
||||
scale = F.adaptive_avg_pool2d(input, 1)
|
||||
scale = self.fc1(scale)
|
||||
scale = self.relu(scale)
|
||||
scale = self.fc2(scale)
|
||||
return F.hardsigmoid(scale, inplace=inplace)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
scale = self._scale(input, True)
|
||||
return scale * input
|
||||
|
||||
|
||||
class InvertedResidualConfig:
|
||||
|
||||
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
|
||||
activation: str, stride: int, dilation: int, width_mult: float):
|
||||
self.input_channels = self.adjust_channels(input_channels, width_mult)
|
||||
self.kernel = kernel
|
||||
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
|
||||
self.out_channels = self.adjust_channels(out_channels, width_mult)
|
||||
self.use_se = use_se
|
||||
self.use_hs = activation == "HS"
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
|
||||
@staticmethod
|
||||
def adjust_channels(channels: int, width_mult: float):
|
||||
return _make_divisible(channels * width_mult, 8)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
|
||||
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
|
||||
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
|
||||
super().__init__()
|
||||
if not (1 <= cnf.stride <= 2):
|
||||
raise ValueError('illegal stride value')
|
||||
|
||||
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
|
||||
|
||||
layers: List[nn.Module] = []
|
||||
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
|
||||
|
||||
# expand
|
||||
if cnf.expanded_channels != cnf.input_channels:
|
||||
layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
|
||||
norm_layer=norm_layer, activation_layer=activation_layer))
|
||||
|
||||
# depthwise
|
||||
stride = 1 if cnf.dilation > 1 else cnf.stride
|
||||
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
|
||||
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
|
||||
norm_layer=norm_layer, activation_layer=activation_layer))
|
||||
if cnf.use_se:
|
||||
layers.append(se_layer(cnf.expanded_channels))
|
||||
|
||||
# project
|
||||
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
|
||||
activation_layer=nn.Identity))
|
||||
|
||||
self.block = nn.Sequential(*layers)
|
||||
self.out_channels = cnf.out_channels
|
||||
self._is_cn = cnf.stride > 1
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
result = self.block(input)
|
||||
if self.use_res_connect:
|
||||
result += input
|
||||
return result
|
||||
|
||||
|
||||
class MobileNetV3(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inverted_residual_setting: List[InvertedResidualConfig],
|
||||
last_channel: int,
|
||||
num_classes: int = 1000,
|
||||
block: Optional[Callable[..., nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
"""
|
||||
MobileNet V3 main class
|
||||
|
||||
Args:
|
||||
inverted_residual_setting (List[InvertedResidualConfig]): Network structure
|
||||
last_channel (int): The number of channels on the penultimate layer
|
||||
num_classes (int): Number of classes
|
||||
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
|
||||
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if not inverted_residual_setting:
|
||||
raise ValueError("The inverted_residual_setting should not be empty")
|
||||
elif not (isinstance(inverted_residual_setting, Sequence) and
|
||||
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
|
||||
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
|
||||
|
||||
if block is None:
|
||||
block = InvertedResidual
|
||||
|
||||
if norm_layer is None:
|
||||
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
|
||||
|
||||
layers: List[nn.Module] = []
|
||||
|
||||
# building first layer
|
||||
firstconv_output_channels = inverted_residual_setting[0].input_channels
|
||||
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
|
||||
activation_layer=nn.Hardswish))
|
||||
|
||||
# building inverted residual blocks
|
||||
for cnf in inverted_residual_setting:
|
||||
layers.append(block(cnf, norm_layer))
|
||||
|
||||
# building last several layers
|
||||
lastconv_input_channels = inverted_residual_setting[-1].out_channels
|
||||
lastconv_output_channels = 6 * lastconv_input_channels
|
||||
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
|
||||
norm_layer=norm_layer, activation_layer=nn.Hardswish))
|
||||
|
||||
self.features = nn.Sequential(*layers)
|
||||
#self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
#self.classifier = nn.Sequential(
|
||||
# nn.Linear(lastconv_output_channels, last_channel),
|
||||
# nn.Hardswish(inplace=True),
|
||||
# nn.Dropout(p=0.2, inplace=True),
|
||||
# nn.Linear(last_channel, num_classes),
|
||||
#)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.ones_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
x = self.features(x)
|
||||
# remove extra layers
|
||||
#x = self.avgpool(x)
|
||||
#x = torch.flatten(x, 1)
|
||||
#x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
|
||||
# non-public config parameters
|
||||
reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
|
||||
dilation = 2 if params.pop('_dilated', False) else 1
|
||||
width_mult = params.pop('_width_mult', 1.0)
|
||||
|
||||
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
|
||||
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
|
||||
|
||||
if arch == "mobilenet_v3_large":
|
||||
inverted_residual_setting = [
|
||||
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
|
||||
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
|
||||
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
|
||||
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
|
||||
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
|
||||
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
|
||||
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
|
||||
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
|
||||
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
|
||||
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
|
||||
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
|
||||
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
|
||||
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
|
||||
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
|
||||
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
|
||||
]
|
||||
last_channel = adjust_channels(1280 // reduce_divider) # C5
|
||||
elif arch == "mobilenet_v3_small":
|
||||
inverted_residual_setting = [
|
||||
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
|
||||
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
|
||||
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
|
||||
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
|
||||
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
|
||||
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
|
||||
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
|
||||
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
|
||||
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
|
||||
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
|
||||
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
|
||||
]
|
||||
last_channel = adjust_channels(1024 // reduce_divider) # C5
|
||||
else:
|
||||
raise ValueError("Unsupported model type {}".format(arch))
|
||||
|
||||
return inverted_residual_setting, last_channel
|
||||
|
||||
|
||||
def _mobilenet_v3_model(
|
||||
arch: str,
|
||||
inverted_residual_setting: List[InvertedResidualConfig],
|
||||
last_channel: int,
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
):
|
||||
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
|
||||
if pretrained:
|
||||
if model_urls.get(arch, None) is None:
|
||||
raise ValueError("No checkpoint is available for model type {}".format(arch))
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||
#model.load_state_dict(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
|
||||
"""
|
||||
Constructs a large MobileNetV3 architecture from
|
||||
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
arch = "mobilenet_v3_large"
|
||||
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
|
||||
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
|
||||
"""
|
||||
Constructs a small MobileNetV3 architecture from
|
||||
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
arch = "mobilenet_v3_small"
|
||||
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
|
||||
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
|
|
@ -0,0 +1,385 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from .utils import load_state_dict_from_url
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
|
||||
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
|
||||
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
|
||||
}
|
||||
|
||||
PADDING_MODE = 'reflect' # {'zeros', 'reflect', 'replicate', 'circular'}
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, padding_mode = PADDING_MODE, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
num_classes: int = 1000,
|
||||
zero_init_residual: bool = False,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, padding_mode = PADDING_MODE,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
|
||||
#self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
#self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
# remove extra layers
|
||||
#x = self.avgpool(x)
|
||||
#x = torch.flatten(x, 1)
|
||||
#x = self.fc(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(
|
||||
arch: str,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
) -> ResNet:
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||
#model.load_state_dict(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
|
@ -0,0 +1,73 @@
|
|||
import os, math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
RESULT_DIR = './results'
|
||||
WEIGHT_DIR = './weights'
|
||||
MODEL_DIR = './models'
|
||||
|
||||
__all__ = ('save_results', 'save_weights', 'load_weights', 'adjust_learning_rate', 'warmup_learning_rate')
|
||||
|
||||
try:
|
||||
from torch.hub import load_state_dict_from_url
|
||||
except ImportError:
|
||||
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
||||
|
||||
|
||||
def save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, model_name, class_name, run_date):
|
||||
result = '{:.2f},{:.2f},{:.2f} \t\tfor {:s}/{:s}/{:s} at epoch {:d}/{:d}/{:d} for {:s}\n'.format(
|
||||
det_roc_obs.max_score, seg_roc_obs.max_score, seg_pro_obs.max_score,
|
||||
det_roc_obs.name, seg_roc_obs.name, seg_pro_obs.name,
|
||||
det_roc_obs.max_epoch, seg_roc_obs.max_epoch, seg_pro_obs.max_epoch, class_name)
|
||||
if not os.path.exists(RESULT_DIR):
|
||||
os.makedirs(RESULT_DIR)
|
||||
fp = open(os.path.join(RESULT_DIR, '{}_{}.txt'.format(model_name, run_date)), "w")
|
||||
fp.write(result)
|
||||
fp.close()
|
||||
|
||||
|
||||
def save_weights(encoder, decoders, model_name, run_date):
|
||||
if not os.path.exists(WEIGHT_DIR):
|
||||
os.makedirs(WEIGHT_DIR)
|
||||
state = {'encoder_state_dict': encoder.state_dict(),
|
||||
'decoder_state_dict': [decoder.state_dict() for decoder in decoders]}
|
||||
filename = '{}_{}.pt'.format(model_name, run_date)
|
||||
path = os.path.join(WEIGHT_DIR, filename)
|
||||
torch.save(state, path)
|
||||
print('Saving weights to {}'.format(filename))
|
||||
|
||||
|
||||
def load_weights(encoder, decoders, filename):
|
||||
path = os.path.join(WEIGHT_DIR, filename)
|
||||
state = torch.load(path)
|
||||
encoder.load_state_dict(state['encoder_state_dict'], strict=False)
|
||||
decoders = [decoder.load_state_dict(state, strict=False) for decoder, state in zip(decoders, state['decoder_state_dict'])]
|
||||
print('Loading weights from {}'.format(filename))
|
||||
|
||||
|
||||
def adjust_learning_rate(c, optimizer, epoch):
|
||||
lr = c.lr
|
||||
if c.lr_cosine:
|
||||
eta_min = lr * (c.lr_decay_rate ** 3)
|
||||
lr = eta_min + (lr - eta_min) * (
|
||||
1 + math.cos(math.pi * epoch / c.meta_epochs)) / 2
|
||||
else:
|
||||
steps = np.sum(epoch >= np.asarray(c.lr_decay_epochs))
|
||||
if steps > 0:
|
||||
lr = lr * (c.lr_decay_rate ** steps)
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
|
||||
def warmup_learning_rate(c, epoch, batch_id, total_batches, optimizer):
|
||||
if c.lr_warm and epoch < c.lr_warm_epochs:
|
||||
p = (batch_id + epoch * total_batches) / \
|
||||
(c.lr_warm_epochs * total_batches)
|
||||
lr = c.lr_warmup_from + p * (c.lr_warmup_to - c.lr_warmup_from)
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
#
|
||||
for param_group in optimizer.param_groups:
|
||||
lrate = param_group['lr']
|
||||
return lrate
|
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 1.1 MiB |
File diff suppressed because it is too large
Load Diff
After Width: | Height: | Size: 319 KiB |
|
@ -0,0 +1,93 @@
|
|||
from __future__ import print_function
|
||||
import os, random, time, math
|
||||
import numpy as np
|
||||
import torch
|
||||
import timm
|
||||
from timm.data import resolve_data_config
|
||||
from config import get_args
|
||||
from train import train
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def main(c):
|
||||
# model
|
||||
if c.action_type == 'video-train':
|
||||
c.model = "{}_{}_{}".format(c.enc_arch, c.dec_arch, c.video_path)
|
||||
elif c.action_type == 'norm-train' or c.action_type == 'norm-test':
|
||||
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
|
||||
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, c.input_size, c.run_name, c.class_name)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
|
||||
# image
|
||||
if ('vit' in c.enc_arch) or ('efficient' in c.enc_arch):
|
||||
encoder = timm.create_model(c.enc_arch, pretrained=True)
|
||||
arch_config = resolve_data_config({}, model=encoder)
|
||||
c.norm_mean, c.norm_std = list(arch_config['mean']), list(arch_config['mean'])
|
||||
c.img_size = arch_config['input_size'][1:] # HxW format
|
||||
c.crp_size = arch_config['input_size'][1:] # HxW format
|
||||
else:
|
||||
c.img_size = (c.input_size, c.input_size) # HxW format
|
||||
c.crp_size = (c.input_size, c.input_size) # HxW format
|
||||
if c.dataset == 'stc':
|
||||
c.norm_mean, c.norm_std = 3*[0.5], 3*[0.225]
|
||||
else:
|
||||
c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
#
|
||||
c.img_dims = [3] + list(c.img_size)
|
||||
# network hyperparameters
|
||||
c.clamp_alpha = 1.9 # see paper equation 2 for explanation
|
||||
c.condition_vec = 128
|
||||
c.dropout = 0.0 # dropout in s-t-networks
|
||||
# dataloader parameters
|
||||
if c.dataset == 'mvtec':
|
||||
c.data_path = './data/MVTec-AD'
|
||||
elif c.dataset == 'stc':
|
||||
c.data_path = './data/STC/shanghaitech'
|
||||
elif c.dataset == 'video':
|
||||
c.data_path = c.video_path
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
|
||||
# output settings
|
||||
c.verbose = True
|
||||
c.hide_tqdm_bar = True
|
||||
c.save_results = True
|
||||
# unsup-train
|
||||
c.print_freq = 2
|
||||
c.temp = 0.5
|
||||
c.lr_decay_epochs = [i*c.meta_epochs//100 for i in [50,75,90]]
|
||||
print('LR schedule: {}'.format(c.lr_decay_epochs))
|
||||
c.lr_decay_rate = 0.1
|
||||
c.lr_warm_epochs = 2
|
||||
c.lr_warm = True
|
||||
c.lr_cosine = True
|
||||
if c.lr_warm:
|
||||
c.lr_warmup_from = c.lr/10.0
|
||||
if c.lr_cosine:
|
||||
eta_min = c.lr * (c.lr_decay_rate ** 3)
|
||||
c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
|
||||
1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
|
||||
else:
|
||||
c.lr_warmup_to = c.lr
|
||||
########
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
|
||||
c.use_cuda = not c.no_cuda and torch.cuda.is_available()
|
||||
init_seeds(seed=int(time.time()))
|
||||
c.device = torch.device("cuda" if c.use_cuda else "cpu")
|
||||
# selected function:
|
||||
if c.action_type == 'norm-train':
|
||||
train(c)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
c = get_args()
|
||||
main(c)
|
||||
|
|
@ -0,0 +1,178 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from custom_models import *
|
||||
# FrEIA (https://github.com/VLL-HD/FrEIA/)
|
||||
import FrEIA.framework as Ff
|
||||
import FrEIA.modules as Fm
|
||||
import timm
|
||||
|
||||
|
||||
def positionalencoding2d(D, H, W):
|
||||
"""
|
||||
:param D: dimension of the model
|
||||
:param H: H of the positions
|
||||
:param W: W of the positions
|
||||
:return: DxHxW position matrix
|
||||
"""
|
||||
if D % 4 != 0:
|
||||
raise ValueError("Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(D))
|
||||
P = torch.zeros(D, H, W)
|
||||
# Each dimension use half of D
|
||||
D = D // 2
|
||||
div_term = torch.exp(torch.arange(0.0, D, 2) * -(math.log(1e4) / D))
|
||||
pos_w = torch.arange(0.0, W).unsqueeze(1)
|
||||
pos_h = torch.arange(0.0, H).unsqueeze(1)
|
||||
P[0:D:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
|
||||
P[1:D:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
|
||||
P[D::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
|
||||
P[D+1::2,:, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
|
||||
return P
|
||||
|
||||
|
||||
def subnet_fc(dims_in, dims_out):
|
||||
return nn.Sequential(nn.Linear(dims_in, 2*dims_in), nn.ReLU(), nn.Linear(2*dims_in, dims_out))
|
||||
|
||||
|
||||
def freia_flow_head(c, n_feat):
|
||||
coder = Ff.SequenceINN(n_feat)
|
||||
print('NF coder:', n_feat)
|
||||
for k in range(c.coupling_blocks):
|
||||
coder.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, affine_clamping=c.clamp_alpha,
|
||||
global_affine_type='SOFTPLUS', permute_soft=True)
|
||||
return coder
|
||||
|
||||
|
||||
def freia_cflow_head(c, n_feat):
|
||||
n_cond = c.condition_vec
|
||||
coder = Ff.SequenceINN(n_feat)
|
||||
print('CNF coder:', n_feat)
|
||||
for k in range(c.coupling_blocks):
|
||||
coder.append(Fm.AllInOneBlock, cond=0, cond_shape=(n_cond,), subnet_constructor=subnet_fc, affine_clamping=c.clamp_alpha,
|
||||
global_affine_type='SOFTPLUS', permute_soft=True)
|
||||
return coder
|
||||
|
||||
|
||||
def load_decoder_arch(c, dim_in):
|
||||
if c.dec_arch == 'freia-flow':
|
||||
decoder = freia_flow_head(c, dim_in)
|
||||
elif c.dec_arch == 'freia-cflow':
|
||||
decoder = freia_cflow_head(c, dim_in)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported NF!'.format(c.dec_arch))
|
||||
#print(decoder)
|
||||
return decoder
|
||||
|
||||
|
||||
activation = {}
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activation[name] = output.detach()
|
||||
return hook
|
||||
|
||||
|
||||
def load_encoder_arch(c, L):
|
||||
# encoder pretrained on natural images:
|
||||
pool_cnt = 0
|
||||
pool_dims = list()
|
||||
pool_layers = ['layer'+str(i) for i in range(L)]
|
||||
if 'resnet' in c.enc_arch:
|
||||
if c.enc_arch == 'resnet18':
|
||||
encoder = resnet18(pretrained=True, progress=True)
|
||||
elif c.enc_arch == 'resnet34':
|
||||
encoder = resnet34(pretrained=True, progress=True)
|
||||
elif c.enc_arch == 'resnet50':
|
||||
encoder = resnet50(pretrained=True, progress=True)
|
||||
elif c.enc_arch == 'resnext50_32x4d':
|
||||
encoder = resnext50_32x4d(pretrained=True, progress=True)
|
||||
elif c.enc_arch == 'wide_resnet50_2':
|
||||
encoder = wide_resnet50_2(pretrained=True, progress=True)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
|
||||
#
|
||||
if L >= 3:
|
||||
encoder.layer2.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
if 'wide' in c.enc_arch:
|
||||
pool_dims.append(encoder.layer2[-1].conv3.out_channels)
|
||||
else:
|
||||
pool_dims.append(encoder.layer2[-1].conv2.out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 2:
|
||||
encoder.layer3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
if 'wide' in c.enc_arch:
|
||||
pool_dims.append(encoder.layer3[-1].conv3.out_channels)
|
||||
else:
|
||||
pool_dims.append(encoder.layer3[-1].conv2.out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 1:
|
||||
encoder.layer4.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
if 'wide' in c.enc_arch:
|
||||
pool_dims.append(encoder.layer4[-1].conv3.out_channels)
|
||||
else:
|
||||
pool_dims.append(encoder.layer4[-1].conv2.out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
elif 'vit' in c.enc_arch:
|
||||
if c.enc_arch == 'vit_base_patch16_224':
|
||||
encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
|
||||
elif c.enc_arch == 'vit_base_patch16_384':
|
||||
encoder = timm.create_model('vit_base_patch16_384', pretrained=True)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
|
||||
#
|
||||
if L >= 3:
|
||||
encoder.blocks[10].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 2:
|
||||
encoder.blocks[2].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 1:
|
||||
encoder.blocks[6].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
elif 'efficient' in c.enc_arch:
|
||||
if 'b5' in c.enc_arch:
|
||||
encoder = timm.create_model(c.enc_arch, pretrained=True)
|
||||
blocks = [-2, -3, -5]
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
|
||||
#
|
||||
if L >= 3:
|
||||
encoder.blocks[blocks[2]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[blocks[2]][-1].bn3.num_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 2:
|
||||
encoder.blocks[blocks[1]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[blocks[1]][-1].bn3.num_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 1:
|
||||
encoder.blocks[blocks[0]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder.blocks[blocks[0]][-1].bn3.num_features)
|
||||
pool_cnt = pool_cnt + 1
|
||||
elif 'mobile' in c.enc_arch:
|
||||
if c.enc_arch == 'mobilenet_v3_small':
|
||||
encoder = mobilenet_v3_small(pretrained=True, progress=True).features
|
||||
blocks = [-2, -5, -10]
|
||||
elif c.enc_arch == 'mobilenet_v3_large':
|
||||
encoder = mobilenet_v3_large(pretrained=True, progress=True).features
|
||||
blocks = [-2, -5, -11]
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
|
||||
#
|
||||
if L >= 3:
|
||||
encoder[blocks[2]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder[blocks[2]].block[-1][-3].out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 2:
|
||||
encoder[blocks[1]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder[blocks[1]].block[-1][-3].out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
if L >= 1:
|
||||
encoder[blocks[0]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
|
||||
pool_dims.append(encoder[blocks[0]].block[-1][-3].out_channels)
|
||||
pool_cnt = pool_cnt + 1
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
|
||||
#
|
||||
return encoder, pool_layers, pool_dims
|
|
@ -0,0 +1,89 @@
|
|||
from __future__ import print_function
|
||||
import argparse, os
|
||||
import numpy as np
|
||||
|
||||
RESULT_DIR = './results'
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='CFLOW-AD')
|
||||
parser.add_argument('--dataset', default='mvtec', type=str, metavar='D',
|
||||
help='dataset name: mvtec/stc/video (default: mvtec)')
|
||||
parser.add_argument('-cl', '--class-name', default='none', type=str, metavar='C',
|
||||
help='class name for MVTec/STC (default: none)')
|
||||
parser.add_argument('-enc', '--enc-arch', default='resnet18', type=str, metavar='A',
|
||||
help='feature extractor architecture (default: resnet18)')
|
||||
parser.add_argument('-dec', '--dec-arch', default='freia-cflow', type=str, metavar='A',
|
||||
help='normalizing flow model (default: freia-cflow)')
|
||||
parser.add_argument('-pl', '--pool-layers', default=2, type=int, metavar='L',
|
||||
help='number of layers used in NF model (default: 2)')
|
||||
parser.add_argument('-cb', '--coupling-blocks', default=8, type=int, metavar='L',
|
||||
help='number of layers used in NF model (default: 8)')
|
||||
parser.add_argument('-runs', '--run-count', default=4, type=int, metavar='C',
|
||||
help='number of runs (default: 4)')
|
||||
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
|
||||
help='image resize dimensions (default: 256)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
def main(c):
|
||||
runs = c.run_count
|
||||
if c.dataset == 'mvtec':
|
||||
class_names = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
|
||||
'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
|
||||
'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
|
||||
elif c.dataset == 'stc':
|
||||
class_names = ['01', '02', '03', '04', '05', '06',
|
||||
'07', '08', '09', '10', '11', '12'] #, '13']
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
|
||||
#
|
||||
metrics = ['DET_AUROC', 'SEG_AUROC', 'SEG_AUPRO']
|
||||
results = np.zeros((len(metrics), len(class_names), runs))
|
||||
# loop
|
||||
result_files = os.listdir(RESULT_DIR)
|
||||
for class_idx, class_name in enumerate(class_names):
|
||||
for run in range(runs):
|
||||
if class_name in ['cable', 'capsule', 'hazelnut', 'metal_nut', 'pill']: # AUROC
|
||||
input_size = 256
|
||||
elif class_name == 'transistor':
|
||||
input_size = 128
|
||||
else:
|
||||
input_size = c.input_size
|
||||
|
||||
#input_size = c.input_size
|
||||
|
||||
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
|
||||
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, input_size, run, class_name)
|
||||
#
|
||||
result_file = list(filter(lambda x: x.startswith(c.model), result_files))
|
||||
if len(result_file) == 0:
|
||||
raise NotImplementedError('{} results are not found!'.format(c.model))
|
||||
elif len(result_file) > 1:
|
||||
raise NotImplementedError('{} duplicate results are found!'.format(result_file))
|
||||
else:
|
||||
result_file = result_file[0]
|
||||
#
|
||||
fp = open(os.path.join(RESULT_DIR, result_file), 'r')
|
||||
lines = fp.readlines()
|
||||
rline = lines[0].split(' ')[0].split(',')
|
||||
result = np.array([float(r) for r in rline])
|
||||
fp.close()
|
||||
results[:, class_idx, run] = result
|
||||
#
|
||||
for i, m in enumerate(metrics):
|
||||
print('\n{}:'.format(m))
|
||||
for j, class_name in enumerate(class_names):
|
||||
print(r"{:.2f}\tiny$\pm${:.2f} for class {}".format(np.mean(results[i, j]), np.std(results[i, j]), class_name))
|
||||
# \tiny$\pm$
|
||||
means = np.mean(results[i], 0)
|
||||
#print(results[i].shape, means.shape)
|
||||
print(r"{:.2f} for average".format(np.mean(means)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
c = get_args()
|
||||
main(c)
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
matplotlib
|
||||
tqdm
|
||||
scikit-image
|
||||
av
|
||||
timm
|
||||
torch >= 1.6.0
|
||||
torchvision >= 0.7.0
|
||||
git+git://github.com/VLL-HD/FrEIA@4e0c6ab42b26ec6e41b1ee2abb1a8b6562752b00
|
|
@ -0,0 +1,428 @@
|
|||
import os, time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
|
||||
from skimage.measure import label, regionprops
|
||||
from tqdm import tqdm
|
||||
from visualize import *
|
||||
from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
|
||||
from utils import *
|
||||
from custom_datasets import *
|
||||
from custom_models import *
|
||||
|
||||
OUT_DIR = './viz/'
|
||||
|
||||
gamma = 0.0
|
||||
theta = torch.nn.Sigmoid()
|
||||
log_theta = torch.nn.LogSigmoid()
|
||||
|
||||
|
||||
def train_meta_epoch(c, epoch, loader, encoder, decoders, optimizer, pool_layers, N):
|
||||
P = c.condition_vec
|
||||
L = c.pool_layers
|
||||
decoders = [decoder.train() for decoder in decoders]
|
||||
adjust_learning_rate(c, optimizer, epoch)
|
||||
I = len(loader)
|
||||
iterator = iter(loader)
|
||||
for sub_epoch in range(c.sub_epochs):
|
||||
train_loss = 0.0
|
||||
train_count = 0
|
||||
for i in range(I):
|
||||
# warm-up learning rate
|
||||
lr = warmup_learning_rate(c, epoch, i+sub_epoch*I, I*c.sub_epochs, optimizer)
|
||||
# sample batch
|
||||
try:
|
||||
image, _, _ = next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(loader)
|
||||
image, _, _ = next(iterator)
|
||||
# encoder prediction
|
||||
image = image.to(c.device) # single scale
|
||||
with torch.no_grad():
|
||||
_ = encoder(image)
|
||||
# train decoder
|
||||
e_list = list()
|
||||
c_list = list()
|
||||
for l, layer in enumerate(pool_layers):
|
||||
if 'vit' in c.enc_arch:
|
||||
e = activation[layer].transpose(1, 2)[...,1:]
|
||||
e_hw = int(np.sqrt(e.size(2)))
|
||||
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
|
||||
else:
|
||||
e = activation[layer].detach() # BxCxHxW
|
||||
#
|
||||
B, C, H, W = e.size()
|
||||
S = H*W
|
||||
E = B*S
|
||||
#
|
||||
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
|
||||
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
|
||||
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
|
||||
perm = torch.randperm(E).to(c.device) # BHW
|
||||
decoder = decoders[l]
|
||||
#
|
||||
FIB = E//N # number of fiber batches
|
||||
assert FIB > 0, 'MAKE SURE WE HAVE ENOUGH FIBERS, otherwise decrease N or batch-size!'
|
||||
for f in range(FIB): # per-fiber processing
|
||||
idx = torch.arange(f*N, (f+1)*N)
|
||||
c_p = c_r[perm[idx]] # NxP
|
||||
e_p = e_r[perm[idx]] # NxC
|
||||
if 'cflow' in c.dec_arch:
|
||||
z, log_jac_det = decoder(e_p, [c_p,])
|
||||
else:
|
||||
z, log_jac_det = decoder(e_p)
|
||||
#
|
||||
decoder_log_prob = get_logp(C, z, log_jac_det)
|
||||
log_prob = decoder_log_prob / C # likelihood per dim
|
||||
loss = -log_theta(log_prob)
|
||||
optimizer.zero_grad()
|
||||
loss.mean().backward()
|
||||
optimizer.step()
|
||||
train_loss += t2np(loss.sum())
|
||||
train_count += len(loss)
|
||||
#
|
||||
mean_train_loss = train_loss / train_count
|
||||
if c.verbose:
|
||||
print('Epoch: {:d}.{:d} \t train loss: {:.4f}, lr={:.6f}'.format(epoch, sub_epoch, mean_train_loss, lr))
|
||||
#
|
||||
|
||||
|
||||
def test_meta_epoch(c, epoch, loader, encoder, decoders, pool_layers, N):
|
||||
# test
|
||||
if c.verbose:
|
||||
print('\nCompute loss and scores on test set:')
|
||||
#
|
||||
P = c.condition_vec
|
||||
decoders = [decoder.eval() for decoder in decoders]
|
||||
height = list()
|
||||
width = list()
|
||||
image_list = list()
|
||||
gt_label_list = list()
|
||||
gt_mask_list = list()
|
||||
test_dist = [list() for layer in pool_layers]
|
||||
test_loss = 0.0
|
||||
test_count = 0
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
for i, (image, label, mask) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
|
||||
# save
|
||||
if c.viz:
|
||||
image_list.extend(t2np(image))
|
||||
gt_label_list.extend(t2np(label))
|
||||
gt_mask_list.extend(t2np(mask))
|
||||
# data
|
||||
image = image.to(c.device) # single scale
|
||||
_ = encoder(image) # BxCxHxW
|
||||
# test decoder
|
||||
e_list = list()
|
||||
for l, layer in enumerate(pool_layers):
|
||||
if 'vit' in c.enc_arch:
|
||||
e = activation[layer].transpose(1, 2)[...,1:]
|
||||
e_hw = int(np.sqrt(e.size(2)))
|
||||
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
|
||||
else:
|
||||
e = activation[layer] # BxCxHxW
|
||||
#
|
||||
B, C, H, W = e.size()
|
||||
S = H*W
|
||||
E = B*S
|
||||
#
|
||||
if i == 0: # get stats
|
||||
height.append(H)
|
||||
width.append(W)
|
||||
#
|
||||
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
|
||||
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
|
||||
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
|
||||
#
|
||||
m = F.interpolate(mask, size=(H, W), mode='nearest')
|
||||
m_r = m.reshape(B, 1, S).transpose(1, 2).reshape(E, 1) # BHWx1
|
||||
#
|
||||
decoder = decoders[l]
|
||||
FIB = E//N + int(E%N > 0) # number of fiber batches
|
||||
for f in range(FIB):
|
||||
if f < (FIB-1):
|
||||
idx = torch.arange(f*N, (f+1)*N)
|
||||
else:
|
||||
idx = torch.arange(f*N, E)
|
||||
#
|
||||
c_p = c_r[idx] # NxP
|
||||
e_p = e_r[idx] # NxC
|
||||
m_p = m_r[idx] > 0.5 # Nx1
|
||||
#
|
||||
if 'cflow' in c.dec_arch:
|
||||
z, log_jac_det = decoder(e_p, [c_p,])
|
||||
else:
|
||||
z, log_jac_det = decoder(e_p)
|
||||
#
|
||||
decoder_log_prob = get_logp(C, z, log_jac_det)
|
||||
log_prob = decoder_log_prob / C # likelihood per dim
|
||||
loss = -log_theta(log_prob)
|
||||
test_loss += t2np(loss.sum())
|
||||
test_count += len(loss)
|
||||
test_dist[l] = test_dist[l] + log_prob.detach().cpu().tolist()
|
||||
#
|
||||
fps = len(loader.dataset) / (time.time() - start)
|
||||
mean_test_loss = test_loss / test_count
|
||||
if c.verbose:
|
||||
print('Epoch: {:d} \t test_loss: {:.4f} and {:.2f} fps'.format(epoch, mean_test_loss, fps))
|
||||
#
|
||||
return height, width, image_list, test_dist, gt_label_list, gt_mask_list
|
||||
|
||||
|
||||
def test_meta_fps(c, epoch, loader, encoder, decoders, pool_layers, N):
|
||||
# test
|
||||
if c.verbose:
|
||||
print('\nCompute loss and scores on test set:')
|
||||
#
|
||||
P = c.condition_vec
|
||||
decoders = [decoder.eval() for decoder in decoders]
|
||||
height = list()
|
||||
width = list()
|
||||
image_list = list()
|
||||
gt_label_list = list()
|
||||
gt_mask_list = list()
|
||||
test_dist = [list() for layer in pool_layers]
|
||||
test_loss = 0.0
|
||||
test_count = 0
|
||||
A = len(loader.dataset)
|
||||
with torch.no_grad():
|
||||
# warm-up
|
||||
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
|
||||
# data
|
||||
image = image.to(c.device) # single scale
|
||||
_ = encoder(image) # BxCxHxW
|
||||
# measure encoder only
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
|
||||
# data
|
||||
image = image.to(c.device) # single scale
|
||||
_ = encoder(image) # BxCxHxW
|
||||
# measure encoder + decoder
|
||||
torch.cuda.synchronize()
|
||||
time_enc = time.time() - start
|
||||
start = time.time()
|
||||
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
|
||||
# data
|
||||
image = image.to(c.device) # single scale
|
||||
_ = encoder(image) # BxCxHxW
|
||||
# test decoder
|
||||
e_list = list()
|
||||
for l, layer in enumerate(pool_layers):
|
||||
if 'vit' in c.enc_arch:
|
||||
e = activation[layer].transpose(1, 2)[...,1:]
|
||||
e_hw = int(np.sqrt(e.size(2)))
|
||||
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
|
||||
else:
|
||||
e = activation[layer] # BxCxHxW
|
||||
#
|
||||
B, C, H, W = e.size()
|
||||
S = H*W
|
||||
E = B*S
|
||||
#
|
||||
if i == 0: # get stats
|
||||
height.append(H)
|
||||
width.append(W)
|
||||
#
|
||||
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
|
||||
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
|
||||
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
|
||||
#
|
||||
decoder = decoders[l]
|
||||
FIB = E//N + int(E%N > 0) # number of fiber batches
|
||||
for f in range(FIB):
|
||||
if f < (FIB-1):
|
||||
idx = torch.arange(f*N, (f+1)*N)
|
||||
else:
|
||||
idx = torch.arange(f*N, E)
|
||||
#
|
||||
c_p = c_r[idx] # NxP
|
||||
e_p = e_r[idx] # NxC
|
||||
#
|
||||
if 'cflow' in c.dec_arch:
|
||||
z, log_jac_det = decoder(e_p, [c_p,])
|
||||
else:
|
||||
z, log_jac_det = decoder(e_p)
|
||||
#
|
||||
torch.cuda.synchronize()
|
||||
time_all = time.time() - start
|
||||
fps_enc = A / time_enc
|
||||
fps_all = A / time_all
|
||||
print('Encoder/All {:.2f}/{:.2f} fps'.format(fps_enc, fps_all))
|
||||
#
|
||||
return height, width, image_list, test_dist, gt_label_list, gt_mask_list
|
||||
|
||||
|
||||
def train(c):
|
||||
run_date = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
|
||||
L = c.pool_layers # number of pooled layers
|
||||
print('Number of pool layers =', L)
|
||||
encoder, pool_layers, pool_dims = load_encoder_arch(c, L)
|
||||
encoder = encoder.to(c.device).eval()
|
||||
#print(encoder)
|
||||
# NF decoder
|
||||
decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]
|
||||
decoders = [decoder.to(c.device) for decoder in decoders]
|
||||
params = list(decoders[0].parameters())
|
||||
for l in range(1, L):
|
||||
params += list(decoders[l].parameters())
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(params, lr=c.lr)
|
||||
# data
|
||||
kwargs = {'num_workers': c.workers, 'pin_memory': True} if c.use_cuda else {}
|
||||
# task data
|
||||
if c.dataset == 'mvtec':
|
||||
train_dataset = MVTecDataset(c, is_train=True)
|
||||
test_dataset = MVTecDataset(c, is_train=False)
|
||||
elif c.dataset == 'stc':
|
||||
train_dataset = StcDataset(c, is_train=True)
|
||||
test_dataset = StcDataset(c, is_train=False)
|
||||
#elif c.dataset == 'video':
|
||||
# c.data_path = c.video_path
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
|
||||
#
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=c.batch_size, shuffle=True, drop_last=True, **kwargs)
|
||||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=c.batch_size, shuffle=False, drop_last=False, **kwargs)
|
||||
N = 256 # hyperparameter that increases batch size for the decoder model by N
|
||||
print('train/test loader length', len(train_loader.dataset), len(test_loader.dataset))
|
||||
print('train/test loader batches', len(train_loader), len(test_loader))
|
||||
# stats
|
||||
det_roc_obs = Score_Observer('DET_AUROC')
|
||||
seg_roc_obs = Score_Observer('SEG_AUROC')
|
||||
seg_pro_obs = Score_Observer('SEG_AUPRO')
|
||||
for epoch in range(c.meta_epochs):
|
||||
if c.viz:
|
||||
if c.checkpoint:
|
||||
load_weights(encoder, decoders, c.checkpoint)
|
||||
else:
|
||||
print('Train meta epoch: {}'.format(epoch))
|
||||
train_meta_epoch(c, epoch, train_loader, encoder, decoders, optimizer, pool_layers, N)
|
||||
|
||||
#height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_fps(
|
||||
# c, epoch, test_loader, encoder, decoders, pool_layers, N)
|
||||
|
||||
height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_epoch(
|
||||
c, epoch, test_loader, encoder, decoders, pool_layers, N)
|
||||
|
||||
# PxEHW
|
||||
print('Heights/Widths', height, width)
|
||||
test_map = [list() for p in pool_layers]
|
||||
for l, p in enumerate(pool_layers):
|
||||
test_norm = torch.tensor(test_dist[l], dtype=torch.double) # EHWx1
|
||||
test_norm-= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant
|
||||
test_prob = torch.exp(test_norm) # convert to probs in range [0:1]
|
||||
test_mask = test_prob.reshape(-1, height[l], width[l])
|
||||
#print('Prob shape:', test_prob.shape, test_prob.min(), test_prob.max())
|
||||
test_mask = test_prob.reshape(-1, height[l], width[l])
|
||||
# upsample
|
||||
test_map[l] = F.interpolate(test_mask.unsqueeze(1),
|
||||
size=c.crp_size, mode='bilinear', align_corners=True).squeeze().numpy()
|
||||
# score aggregation
|
||||
score_map = np.zeros_like(test_map[0])
|
||||
for l, p in enumerate(pool_layers):
|
||||
score_map += test_map[l]
|
||||
score_mask = score_map
|
||||
# superpixels
|
||||
super_mask = score_mask.max() - score_mask # /score_mask.max() # normality score to anomaly score
|
||||
# calculate detection AUROC
|
||||
score_label = np.max(super_mask, axis=(1, 2))
|
||||
gt_label = np.asarray(gt_label_list, dtype=np.bool)
|
||||
det_roc_auc = roc_auc_score(gt_label, score_label)
|
||||
det_roc_obs.update(100.0*det_roc_auc, epoch)
|
||||
# calculate segmentation AUROC
|
||||
gt_mask = np.squeeze(np.asarray(gt_mask_list, dtype=np.bool), axis=1)
|
||||
seg_roc_auc = roc_auc_score(gt_mask.flatten(), super_mask.flatten())
|
||||
seg_roc_obs.update(100.0*seg_roc_auc, epoch)
|
||||
# calculate segmentation AUPRO
|
||||
# from https://github.com/YoungGod/DFR:
|
||||
if c.pro: # and (epoch % 4 == 0): # AUPRO is expensive to compute
|
||||
max_step = 1000
|
||||
expect_fpr = 0.3 # default 30%
|
||||
max_th = super_mask.max()
|
||||
min_th = super_mask.min()
|
||||
delta = (max_th - min_th) / max_step
|
||||
ious_mean = []
|
||||
ious_std = []
|
||||
pros_mean = []
|
||||
pros_std = []
|
||||
threds = []
|
||||
fprs = []
|
||||
binary_score_maps = np.zeros_like(super_mask, dtype=np.bool)
|
||||
for step in range(max_step):
|
||||
thred = max_th - step * delta
|
||||
# segmentation
|
||||
binary_score_maps[super_mask <= thred] = 0
|
||||
binary_score_maps[super_mask > thred] = 1
|
||||
pro = [] # per region overlap
|
||||
iou = [] # per image iou
|
||||
# pro: find each connected gt region, compute the overlapped pixels between the gt region and predicted region
|
||||
# iou: for each image, compute the ratio, i.e. intersection/union between the gt and predicted binary map
|
||||
for i in range(len(binary_score_maps)): # for i th image
|
||||
# pro (per region level)
|
||||
label_map = label(gt_mask[i], connectivity=2)
|
||||
props = regionprops(label_map)
|
||||
for prop in props:
|
||||
x_min, y_min, x_max, y_max = prop.bbox # find the bounding box of an anomaly region
|
||||
cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
|
||||
# cropped_mask = gt_mask[i][x_min:x_max, y_min:y_max] # bug!
|
||||
cropped_mask = prop.filled_image # corrected!
|
||||
intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum()
|
||||
pro.append(intersection / prop.area)
|
||||
# iou (per image level)
|
||||
intersection = np.logical_and(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum()
|
||||
union = np.logical_or(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum()
|
||||
if gt_mask[i].any() > 0: # when the gt have no anomaly pixels, skip it
|
||||
iou.append(intersection / union)
|
||||
# against steps and average metrics on the testing data
|
||||
ious_mean.append(np.array(iou).mean())
|
||||
#print("per image mean iou:", np.array(iou).mean())
|
||||
ious_std.append(np.array(iou).std())
|
||||
pros_mean.append(np.array(pro).mean())
|
||||
pros_std.append(np.array(pro).std())
|
||||
# fpr for pro-auc
|
||||
gt_masks_neg = ~gt_mask
|
||||
fpr = np.logical_and(gt_masks_neg, binary_score_maps).sum() / gt_masks_neg.sum()
|
||||
fprs.append(fpr)
|
||||
threds.append(thred)
|
||||
# as array
|
||||
threds = np.array(threds)
|
||||
pros_mean = np.array(pros_mean)
|
||||
pros_std = np.array(pros_std)
|
||||
fprs = np.array(fprs)
|
||||
ious_mean = np.array(ious_mean)
|
||||
ious_std = np.array(ious_std)
|
||||
# best per image iou
|
||||
best_miou = ious_mean.max()
|
||||
#print(f"Best IOU: {best_miou:.4f}")
|
||||
# default 30% fpr vs pro, pro_auc
|
||||
idx = fprs <= expect_fpr # find the indexs of fprs that is less than expect_fpr (default 0.3)
|
||||
fprs_selected = fprs[idx]
|
||||
fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1]
|
||||
pros_mean_selected = pros_mean[idx]
|
||||
seg_pro_auc = auc(fprs_selected, pros_mean_selected)
|
||||
seg_pro_obs.update(100.0*seg_pro_auc, epoch)
|
||||
# export visualuzations
|
||||
if c.viz:
|
||||
precision, recall, thresholds = precision_recall_curve(gt_label, score_label)
|
||||
a = 2 * precision * recall
|
||||
b = precision + recall
|
||||
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
|
||||
det_threshold = thresholds[np.argmax(f1)]
|
||||
print('Optimal DET Threshold: {:.2f}'.format(det_threshold))
|
||||
precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), super_mask.flatten())
|
||||
a = 2 * precision * recall
|
||||
b = precision + recall
|
||||
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
|
||||
seg_threshold = thresholds[np.argmax(f1)]
|
||||
print('Optimal SEG Threshold: {:.2f}'.format(seg_threshold))
|
||||
export_groundtruth(c, test_image_list, gt_mask)
|
||||
export_scores(c, test_image_list, super_mask, seg_threshold)
|
||||
export_test_images(c, test_image_list, gt_mask, super_mask, seg_threshold)
|
||||
export_hist(c, gt_mask, super_mask, seg_threshold)
|
||||
#save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
|
||||
elif c.save_results:
|
||||
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
|
||||
save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
|
|
@ -0,0 +1,40 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
np.random.seed(0)
|
||||
_GCONST_ = -0.9189385332046727 # ln(sqrt(2*pi))
|
||||
|
||||
|
||||
class Score_Observer:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.max_epoch = 0
|
||||
self.max_score = 0.0
|
||||
self.last = 0.0
|
||||
|
||||
def update(self, score, epoch, print_score=True):
|
||||
self.last = score
|
||||
if epoch == 0 or score > self.max_score:
|
||||
self.max_score = score
|
||||
self.max_epoch = epoch
|
||||
if print_score:
|
||||
self.print_score()
|
||||
|
||||
def print_score(self):
|
||||
print('{:s}: \t last: {:.2f} \t max: {:.2f} \t epoch_max: {:d}'.format(
|
||||
self.name, self.last, self.max_score, self.max_epoch))
|
||||
|
||||
|
||||
def t2np(tensor):
|
||||
'''pytorch tensor -> numpy array'''
|
||||
return tensor.cpu().data.numpy() if tensor is not None else None
|
||||
|
||||
|
||||
def get_logp(C, z, logdet_J):
|
||||
logp = C * _GCONST_ - 0.5*torch.sum(z**2, 1) + logdet_J
|
||||
return logp
|
||||
|
||||
|
||||
def rescale(x):
|
||||
return (x - x.min()) / (x.max() - x.min())
|
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
import datetime
|
||||
import numpy as np
|
||||
from skimage import morphology
|
||||
from skimage.segmentation import mark_boundaries
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
from utils import *
|
||||
|
||||
OUT_DIR = './viz/'
|
||||
|
||||
norm = matplotlib.colors.Normalize(vmin=0.0, vmax=255.0)
|
||||
cm = 1/2.54
|
||||
dpi = 300
|
||||
|
||||
def denormalization(x, norm_mean, norm_std):
|
||||
mean = np.array(norm_mean)
|
||||
std = np.array(norm_std)
|
||||
x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
|
||||
return x
|
||||
|
||||
|
||||
def export_hist(c, gts, scores, threshold):
|
||||
print('Exporting histogram...')
|
||||
plt.rcParams.update({'font.size': 4})
|
||||
image_dirs = os.path.join(OUT_DIR, c.model)
|
||||
os.makedirs(image_dirs, exist_ok=True)
|
||||
Y = scores.flatten()
|
||||
Y_label = gts.flatten()
|
||||
fig = plt.figure(figsize=(4*cm, 4*cm), dpi=dpi)
|
||||
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
||||
fig.add_axes(ax)
|
||||
plt.hist([Y[Y_label==1], Y[Y_label==0]], 500, density=True, color=['r', 'g'], label=['ANO', 'TYP'], alpha=0.75, histtype='barstacked')
|
||||
image_file = os.path.join(image_dirs, 'hist_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
|
||||
fig.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
|
||||
plt.close()
|
||||
|
||||
def export_groundtruth(c, test_img, gts):
|
||||
image_dirs = os.path.join(OUT_DIR, c.model, 'gt_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
|
||||
# images
|
||||
if not os.path.isdir(image_dirs):
|
||||
print('Exporting grountruth...')
|
||||
os.makedirs(image_dirs, exist_ok=True)
|
||||
num = len(test_img)
|
||||
kernel = morphology.disk(4)
|
||||
for i in range(num):
|
||||
img = test_img[i]
|
||||
img = denormalization(img, c.norm_mean, c.norm_std)
|
||||
# gts
|
||||
gt_mask = gts[i].astype(np.float64)
|
||||
gt_mask = morphology.opening(gt_mask, kernel)
|
||||
gt_mask = (255.0*gt_mask).astype(np.uint8)
|
||||
gt_img = mark_boundaries(img, gt_mask, color=(1, 0, 0), mode='thick')
|
||||
#
|
||||
fig = plt.figure(figsize=(2*cm, 2*cm), dpi=dpi)
|
||||
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
||||
ax.set_axis_off()
|
||||
fig.add_axes(ax)
|
||||
ax.imshow(gt_img)
|
||||
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
|
||||
fig.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
|
||||
plt.close()
|
||||
|
||||
|
||||
def export_scores(c, test_img, scores, threshold):
|
||||
image_dirs = os.path.join(OUT_DIR, c.model, 'sc_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
|
||||
# images
|
||||
if not os.path.isdir(image_dirs):
|
||||
print('Exporting scores...')
|
||||
os.makedirs(image_dirs, exist_ok=True)
|
||||
num = len(test_img)
|
||||
kernel = morphology.disk(4)
|
||||
scores_norm = 1.0/scores.max()
|
||||
for i in range(num):
|
||||
img = test_img[i]
|
||||
img = denormalization(img, c.norm_mean, c.norm_std)
|
||||
# scores
|
||||
score_mask = np.zeros_like(scores[i])
|
||||
score_mask[scores[i] > threshold] = 1.0
|
||||
score_mask = morphology.opening(score_mask, kernel)
|
||||
score_mask = (255.0*score_mask).astype(np.uint8)
|
||||
score_img = mark_boundaries(img, score_mask, color=(1, 0, 0), mode='thick')
|
||||
score_map = (255.0*scores[i]*scores_norm).astype(np.uint8)
|
||||
#
|
||||
fig_img, ax_img = plt.subplots(2, 1, figsize=(2*cm, 4*cm))
|
||||
for ax_i in ax_img:
|
||||
ax_i.axes.xaxis.set_visible(False)
|
||||
ax_i.axes.yaxis.set_visible(False)
|
||||
ax_i.spines['top'].set_visible(False)
|
||||
ax_i.spines['right'].set_visible(False)
|
||||
ax_i.spines['bottom'].set_visible(False)
|
||||
ax_i.spines['left'].set_visible(False)
|
||||
#
|
||||
plt.subplots_adjust(hspace = 0.1, wspace = 0.1)
|
||||
ax_img[0].imshow(img, cmap='gray', interpolation='none')
|
||||
ax_img[0].imshow(score_map, cmap='jet', norm=norm, alpha=0.5, interpolation='none')
|
||||
ax_img[1].imshow(score_img)
|
||||
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
|
||||
fig_img.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
|
||||
plt.close()
|
||||
|
||||
|
||||
def export_test_images(c, test_img, gts, scores, threshold):
|
||||
image_dirs = os.path.join(OUT_DIR, c.model, 'images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
|
||||
cm = 1/2.54
|
||||
# images
|
||||
if not os.path.isdir(image_dirs):
|
||||
print('Exporting images...')
|
||||
os.makedirs(image_dirs, exist_ok=True)
|
||||
num = len(test_img)
|
||||
font = {'family': 'serif', 'color': 'black', 'weight': 'normal', 'size': 8}
|
||||
kernel = morphology.disk(4)
|
||||
scores_norm = 1.0/scores.max()
|
||||
for i in range(num):
|
||||
img = test_img[i]
|
||||
img = denormalization(img, c.norm_mean, c.norm_std)
|
||||
# gts
|
||||
gt_mask = gts[i].astype(np.float64)
|
||||
print('GT:', i, gt_mask.sum())
|
||||
gt_mask = morphology.opening(gt_mask, kernel)
|
||||
gt_mask = (255.0*gt_mask).astype(np.uint8)
|
||||
gt_img = mark_boundaries(img, gt_mask, color=(1, 0, 0), mode='thick')
|
||||
# scores
|
||||
score_mask = np.zeros_like(scores[i])
|
||||
score_mask[scores[i] > threshold] = 1.0
|
||||
print('SC:', i, score_mask.sum())
|
||||
score_mask = morphology.opening(score_mask, kernel)
|
||||
score_mask = (255.0*score_mask).astype(np.uint8)
|
||||
score_img = mark_boundaries(img, score_mask, color=(1, 0, 0), mode='thick')
|
||||
score_map = (255.0*scores[i]*scores_norm).astype(np.uint8)
|
||||
#
|
||||
fig_img, ax_img = plt.subplots(3, 1, figsize=(2*cm, 6*cm))
|
||||
for ax_i in ax_img:
|
||||
ax_i.axes.xaxis.set_visible(False)
|
||||
ax_i.axes.yaxis.set_visible(False)
|
||||
ax_i.spines['top'].set_visible(False)
|
||||
ax_i.spines['right'].set_visible(False)
|
||||
ax_i.spines['bottom'].set_visible(False)
|
||||
ax_i.spines['left'].set_visible(False)
|
||||
#
|
||||
plt.subplots_adjust(hspace = 0.1, wspace = 0.1)
|
||||
ax_img[0].imshow(gt_img)
|
||||
ax_img[1].imshow(score_map, cmap='jet', norm=norm)
|
||||
ax_img[2].imshow(score_img)
|
||||
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
|
||||
fig_img.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
|
||||
plt.close()
|
Loading…
Reference in New Issue