WACV 2022 release

pull/33/head
gudovskiy 2021-07-26 17:57:53 -07:00
commit 6a520d5eeb
17 changed files with 6881 additions and 0 deletions

61
README.md 100644

@@ -0,0 +1,61 @@
## CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows
WACV 2022 preprint: [https://arxiv.org/abs/????.?????](https://arxiv.org/abs/????.?????)
## Abstract
Unsupervised anomaly detection with localization has many practical applications when labeling is infeasible and, moreover, when anomaly examples are completely missing in the training data. While recently proposed models for such a data setup achieve high accuracy metrics, their complexity is a limiting factor for real-time processing. In this paper, we propose a real-time model and analytically derive its relationship to prior methods. Our CFLOW-AD model is based on a conditional normalizing flow framework adopted for anomaly detection with localization. In particular, CFLOW-AD consists of a discriminatively pretrained encoder followed by multi-scale generative decoders, where the latter explicitly estimate the likelihood of the encoded features. Our approach results in a computationally and memory-efficient model: CFLOW-AD is 10x faster and smaller than prior state-of-the-art with the same input setting. Our experiments on the MVTec dataset show that CFLOW-AD outperforms previous methods by 0.36% AUROC in the detection task, and by 1.12% AUROC and 2.5% AUPRO in the localization task, respectively. We open-source our code with fully reproducible experiments.
## BibTeX Citation
If you like our [paper](https://arxiv.org/abs/????.?????) or code, please cite its WACV 2022 preprint using the following BibTeX:
```
@article{cflow_ad,
title={CFLOW-AD: Real-Time Unsupervised Anomaly Detection with Localization via Conditional Normalizing Flows},
author={Gudovskiy, Denis and Ishizaka, Shun and Kozuka, Kazuki and Tsukizawa, Sotaro},
journal={arXiv:????.?????},
year={2021}
}
```
## Installation
- Clone this repository: tested on Python 3.8
- Install [PyTorch](http://pytorch.org/): tested on v1.8
- Install [FrEIA Flows](https://github.com/VLL-HD/FrEIA): tested on [the recent branch](https://github.com/VLL-HD/FrEIA/tree/4e0c6ab42b26ec6e41b1ee2abb1a8b6562752b00)
- Other dependencies are listed in requirements.txt
Install all packages with this command:
```
$ python3 -m pip install -U -r requirements.txt
```
## Datasets
We support the [MVTec AD dataset](https://www.mvtec.com/de/unternehmen/forschung/datasets/mvtec-ad/) for anomaly localization in a factory setting and the [Shanghai Tech Campus (STC)](https://svip-lab.github.io/dataset/campus_dataset.html) dataset with surveillance-camera videos. Please download the datasets from these URLs and extract them to the 'data' folder.
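A quick way to verify the expected layout before training (a minimal sketch; the paths below match the defaults in `main.py` and the folder structure assumed by `custom_datasets/loader.py`):
```
import os
# main.py expects MVTec AD under ./data/MVTec-AD and STC under ./data/STC/shanghaitech
for d in ('./data/MVTec-AD/bottle/train/good',        # MVTec: <class>/<phase>/<type>
          './data/STC/shanghaitech/training/videos'): # STC: raw training videos
    print(d, 'OK' if os.path.isdir(d) else 'MISSING')
```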
## Code Organization
- ./custom_datasets - contains dataloaders for MVTec and STC
- ./custom_models - contains pretrained feature extractors
## Running Experiments
- Run the code by selecting the class name, feature extractor, input size, flow model, etc.
- The commands below should reproduce our reference MVTec results:
```
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name bottle
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name cable
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name capsule
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name carpet
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name grid
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name hazelnut
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name leather
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name metal_nut
python3 main.py --gpu 0 --pro -inp 256 --dataset mvtec --class-name pill
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name screw
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name tile
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name toothbrush
python3 main.py --gpu 0 --pro -inp 128 --dataset mvtec --class-name transistor
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name wood
python3 main.py --gpu 0 --pro -inp 512 --dataset mvtec --class-name zipper
```
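Each run writes a one-line summary (detection AUROC, segmentation AUROC, segmentation AUPRO) to a text file in `./results` (see `save_results` in `custom_models/utils.py`), and `parse_results.py` averages these files over classes and runs. A minimal sketch to inspect the raw summaries directly, assuming the file format produced by `save_results`:
```
import os
RESULT_DIR = './results'
for f in sorted(os.listdir(RESULT_DIR)):  # one '<model>_<run_date>.txt' file per run
    with open(os.path.join(RESULT_DIR, f)) as fp:
        print(f, '->', fp.readline().strip())
```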
## CFLOW-AD Architecture
![CFLOW-AD](./images/fig-cflow.svg)
## Reference CFLOW-AD Results for MVTec
![CFLOW-AD](./images/fig-table.svg)

52
config.py 100644

@@ -0,0 +1,52 @@
from __future__ import print_function
import argparse
__all__ = ['get_args']
def get_args():
parser = argparse.ArgumentParser(description='CFLOW-AD')
parser.add_argument('--dataset', default='mvtec', type=str, metavar='D',
help='dataset name: mvtec/stc (default: mvtec)')
parser.add_argument('--checkpoint', default='', type=str, metavar='D',
help='file with saved checkpoint')
parser.add_argument('-cl', '--class-name', default='none', type=str, metavar='C',
help='class name for MVTec/STC (default: none)')
parser.add_argument('-enc', '--enc-arch', default='wide_resnet50_2', type=str, metavar='A',
help='feature extractor architecture (default: wide_resnet50_2)')
parser.add_argument('-dec', '--dec-arch', default='freia-cflow', type=str, metavar='A',
help='normalizing flow model (default: freia-cflow)')
    parser.add_argument('-pl', '--pool-layers', default=3, type=int, metavar='L',
                        help='number of pooled layers used in NF model (default: 3)')
    parser.add_argument('-cb', '--coupling-blocks', default=8, type=int, metavar='L',
                        help='number of coupling blocks used in NF model (default: 8)')
parser.add_argument('-run', '--run-name', default=0, type=int, metavar='C',
help='name of the run (default: 0)')
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
help='image resize dimensions (default: 256)')
parser.add_argument("--action-type", default='norm-train', type=str, metavar='T',
help='norm-train (default: norm-train)')
parser.add_argument('-bs', '--batch-size', default=32, type=int, metavar='B',
help='train batch size (default: 32)')
parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',
help='learning rate (default: 2e-4)')
parser.add_argument('--meta-epochs', type=int, default=25, metavar='N',
help='number of meta epochs to train (default: 25)')
parser.add_argument('--sub-epochs', type=int, default=8, metavar='N',
help='number of sub epochs to train (default: 8)')
parser.add_argument('--pro', action='store_true', default=False,
help='enables estimation of AUPRO metric')
parser.add_argument('--viz', action='store_true', default=False,
help='saves test data visualizations')
parser.add_argument('--workers', default=4, type=int, metavar='G',
help='number of data loading workers (default: 4)')
parser.add_argument("--gpu", default='0', type=str, metavar='G',
help='GPU device number')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--video-path', default='.', type=str, metavar='D',
help='video file path')
args = parser.parse_args()
return args
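# Illustrative smoke test (not part of the original file): get_args() parses sys.argv,
# so the CLI flags can be stubbed before calling it programmatically.
if __name__ == '__main__':
    import sys
    sys.argv = [sys.argv[0], '--dataset', 'mvtec', '-cl', 'bottle', '-inp', '512', '--pro']
    print(vars(get_args()))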

1
custom_datasets/__init__.py 100644

@@ -0,0 +1 @@
from .loader import MVTecDataset, StcDataset

219
custom_datasets/loader.py 100644

@@ -0,0 +1,219 @@
import os
from PIL import Image
import numpy as np
import torch
from torchvision.io import read_video, write_jpeg
from torch.utils.data import Dataset
from torchvision import transforms as T
__all__ = ('MVTecDataset', 'StcDataset')
# URL = 'ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz'
MVTEC_CLASS_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
STC_CLASS_NAMES = ['01', '02', '03', '04', '05', '06',
                   '07', '08', '09', '10', '11', '12']  # '13' excluded: no ground truth
class StcDataset(Dataset):
def __init__(self, c, is_train=True):
assert c.class_name in STC_CLASS_NAMES, 'class_name: {}, should be in {}'.format(c.class_name, STC_CLASS_NAMES)
self.class_name = c.class_name
self.is_train = is_train
self.cropsize = c.crp_size
#
if is_train:
self.dataset_path = os.path.join(c.data_path, 'training')
self.dataset_vid = os.path.join(self.dataset_path, 'videos')
self.dataset_dir = os.path.join(self.dataset_path, 'frames')
self.dataset_files = sorted([f for f in os.listdir(self.dataset_vid) if f.startswith(self.class_name)])
if not os.path.isdir(self.dataset_dir):
os.mkdir(self.dataset_dir)
done_file = os.path.join(self.dataset_path, 'frames_{}.pt'.format(self.class_name))
print(done_file)
H, W = 480, 856
if os.path.isfile(done_file):
assert torch.load(done_file) == len(self.dataset_files), 'train frames are not processed!'
else:
count = 0
for dataset_file in self.dataset_files:
print(dataset_file)
data = read_video(os.path.join(self.dataset_vid, dataset_file)) # read video file entirely -> mem issue!!!
vid = data[0] # weird read_video that returns byte tensor in format [T,H,W,C]
fps = data[2]['video_fps']
print('video mu/std: {}/{} {}'.format(torch.mean(vid/255.0, (0,1,2)), torch.std(vid/255.0, (0,1,2)), vid.shape))
assert [H, W] == [vid.size(1), vid.size(2)], 'same H/W'
dataset_file_dir = os.path.join(self.dataset_dir, os.path.splitext(dataset_file)[0])
os.mkdir(dataset_file_dir)
count = count + 1
for i, frame in enumerate(vid):
filename = '{0:08d}.jpg'.format(i)
write_jpeg(frame.permute((2, 0, 1)), os.path.join(dataset_file_dir, filename), 80)
torch.save(torch.tensor(count), done_file)
#
self.x, self.y, self.mask = self.load_dataset_folder()
else:
self.dataset_path = os.path.join(c.data_path, 'testing')
self.x, self.y, self.mask = self.load_dataset_folder()
# set transforms
if is_train:
self.transform_x = T.Compose([
T.Resize(c.img_size, Image.ANTIALIAS),
T.RandomRotation(5),
T.CenterCrop(c.crp_size),
T.ToTensor()])
# test:
else:
self.transform_x = T.Compose([
T.Resize(c.img_size, Image.ANTIALIAS),
T.CenterCrop(c.crp_size),
T.ToTensor()])
# mask
self.transform_mask = T.Compose([
T.ToPILImage(),
T.Resize(c.img_size, Image.NEAREST),
T.CenterCrop(c.crp_size),
T.ToTensor()])
self.normalize = T.Compose([T.Normalize(c.norm_mean, c.norm_std)])
def __getitem__(self, idx):
x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
x = Image.open(x).convert('RGB')
x = self.normalize(self.transform_x(x))
if y == 0: #self.is_train:
mask = torch.zeros([1, self.cropsize[0], self.cropsize[1]])
else:
mask = self.transform_mask(mask)
#
return x, y, mask
def __len__(self):
return len(self.x)
def load_dataset_folder(self):
phase = 'train' if self.is_train else 'test'
x, y, mask = list(), list(), list()
img_dir = os.path.join(self.dataset_path, 'frames')
img_types = sorted([f for f in os.listdir(img_dir) if f.startswith(self.class_name)])
gt_frame_dir = os.path.join(self.dataset_path, 'test_frame_mask')
gt_pixel_dir = os.path.join(self.dataset_path, 'test_pixel_mask')
for i, img_type in enumerate(img_types):
print('Folder:', img_type)
# load images
img_type_dir = os.path.join(img_dir, img_type)
img_fpath_list = sorted([os.path.join(img_type_dir, f) for f in os.listdir(img_type_dir) if f.endswith('.jpg')])
x.extend(img_fpath_list)
# labels for every test image
if phase == 'test':
gt_pixel = np.load('{}.npy'.format(os.path.join(gt_pixel_dir, img_type)))
gt_frame = np.load('{}.npy'.format(os.path.join(gt_frame_dir, img_type)))
if i == 0:
m = gt_pixel
y = gt_frame
else:
m = np.concatenate((m, gt_pixel), axis=0)
y = np.concatenate((y, gt_frame), axis=0)
#
mask = [e for e in m] # np.expand_dims(e, axis=0)
assert len(x) == len(y), 'number of x and y should be same'
assert len(x) == len(mask), 'number of x and mask should be same'
else:
mask.extend([None] * len(img_fpath_list))
y.extend([0] * len(img_fpath_list))
#
return list(x), list(y), list(mask)
class MVTecDataset(Dataset):
def __init__(self, c, is_train=True):
assert c.class_name in MVTEC_CLASS_NAMES, 'class_name: {}, should be in {}'.format(c.class_name, MVTEC_CLASS_NAMES)
self.dataset_path = c.data_path
self.class_name = c.class_name
self.is_train = is_train
self.cropsize = c.crp_size
# load dataset
self.x, self.y, self.mask = self.load_dataset_folder()
# set transforms
if is_train:
self.transform_x = T.Compose([
T.Resize(c.img_size, Image.ANTIALIAS),
T.RandomRotation(5),
T.CenterCrop(c.crp_size),
T.ToTensor()])
# test:
else:
self.transform_x = T.Compose([
T.Resize(c.img_size, Image.ANTIALIAS),
T.CenterCrop(c.crp_size),
T.ToTensor()])
# mask
self.transform_mask = T.Compose([
T.Resize(c.img_size, Image.NEAREST),
T.CenterCrop(c.crp_size),
T.ToTensor()])
self.normalize = T.Compose([T.Normalize(c.norm_mean, c.norm_std)])
def __getitem__(self, idx):
x, y, mask = self.x[idx], self.y[idx], self.mask[idx]
#x = Image.open(x).convert('RGB')
x = Image.open(x)
if self.class_name in ['zipper', 'screw', 'grid']: # handle greyscale classes
x = np.expand_dims(np.array(x), axis=2)
x = np.concatenate([x, x, x], axis=2)
x = Image.fromarray(x.astype('uint8')).convert('RGB')
#
x = self.normalize(self.transform_x(x))
#
if y == 0:
mask = torch.zeros([1, self.cropsize[0], self.cropsize[1]])
else:
mask = Image.open(mask)
mask = self.transform_mask(mask)
return x, y, mask
def __len__(self):
return len(self.x)
def load_dataset_folder(self):
phase = 'train' if self.is_train else 'test'
x, y, mask = [], [], []
img_dir = os.path.join(self.dataset_path, self.class_name, phase)
gt_dir = os.path.join(self.dataset_path, self.class_name, 'ground_truth')
img_types = sorted(os.listdir(img_dir))
for img_type in img_types:
# load images
img_type_dir = os.path.join(img_dir, img_type)
if not os.path.isdir(img_type_dir):
continue
img_fpath_list = sorted([os.path.join(img_type_dir, f)
for f in os.listdir(img_type_dir)
if f.endswith('.png')])
x.extend(img_fpath_list)
# load gt labels
if img_type == 'good':
y.extend([0] * len(img_fpath_list))
mask.extend([None] * len(img_fpath_list))
else:
y.extend([1] * len(img_fpath_list))
gt_type_dir = os.path.join(gt_dir, img_type)
img_fname_list = [os.path.splitext(os.path.basename(f))[0] for f in img_fpath_list]
gt_fpath_list = [os.path.join(gt_type_dir, img_fname + '_mask.png')
for img_fname in img_fname_list]
mask.extend(gt_fpath_list)
assert len(x) == len(y), 'number of x and y should be same'
return list(x), list(y), list(mask)
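# Minimal usage sketch (illustrative; the config attributes below are set in main.py):
#   c.data_path, c.class_name = './data/MVTec-AD', 'bottle'
#   c.img_size = c.crp_size = (256, 256)
#   c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
#   dataset = MVTecDataset(c, is_train=True)
#   x, y, mask = dataset[0]  # normalized 3xHxW image tensor, 0/1 label, 1xHxW mask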

3
custom_models/__init__.py 100644

@@ -0,0 +1,3 @@
from .resnet import *
from .mobilenetv3 import *
from .utils import *

276
custom_models/mobilenetv3.py 100644

@@ -0,0 +1,276 @@
import torch
from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, Dict, List, Optional, Sequence
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
model_urls = {
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
}
class SqueezeExcitation(nn.Module):
def __init__(self, input_channels: int, squeeze_factor: int = 4):
super().__init__()
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
def _scale(self, input: Tensor, inplace: bool) -> Tensor:
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.fc1(scale)
scale = self.relu(scale)
scale = self.fc2(scale)
return F.hardsigmoid(scale, inplace=inplace)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input, True)
return scale * input
class InvertedResidualConfig:
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, dilation: int, width_mult: float):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
@staticmethod
def adjust_channels(channels: int, width_mult: float):
return _make_divisible(channels * width_mult, 8)
class InvertedResidual(nn.Module):
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Module] = []
activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
layers.append(se_layer(cnf.expanded_channels))
# project
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.Identity))
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self._is_cn = cnf.stride > 1
def forward(self, input: Tensor) -> Tensor:
result = self.block(input)
if self.use_res_connect:
result += input
return result
class MobileNetV3(nn.Module):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
"""
MobileNet V3 main class
Args:
inverted_residual_setting (List[InvertedResidualConfig]): Network structure
last_channel (int): The number of channels on the penultimate layer
num_classes (int): Number of classes
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
"""
super().__init__()
if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
if block is None:
block = InvertedResidual
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
layers: List[nn.Module] = []
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))
# building inverted residual blocks
for cnf in inverted_residual_setting:
layers.append(block(cnf, norm_layer))
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))
self.features = nn.Sequential(*layers)
#self.avgpool = nn.AdaptiveAvgPool2d(1)
#self.classifier = nn.Sequential(
# nn.Linear(lastconv_output_channels, last_channel),
# nn.Hardswish(inplace=True),
# nn.Dropout(p=0.2, inplace=True),
# nn.Linear(last_channel, num_classes),
#)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.features(x)
# remove extra layers
#x = self.avgpool(x)
#x = torch.flatten(x, 1)
#x = self.classifier(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
# non-public config parameters
reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
dilation = 2 if params.pop('_dilated', False) else 1
width_mult = params.pop('_width_mult', 1.0)
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
if arch == "mobilenet_v3_large":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
else:
raise ValueError("Unsupported model type {}".format(arch))
return inverted_residual_setting, last_channel
def _mobilenet_v3_model(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError("No checkpoint is available for model type {}".format(arch))
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
#model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)

385
custom_models/resnet.py 100644

@@ -0,0 +1,385 @@
import torch
from torch import Tensor
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
PADDING_MODE = 'reflect' # {'zeros', 'reflect', 'replicate', 'circular'}
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, padding_mode = PADDING_MODE, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, padding_mode = PADDING_MODE,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
#self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
#self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# remove extra layers
#x = self.avgpool(x)
#x = torch.flatten(x, 1)
#x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
#model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
return model
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
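# Note (illustrative): model.py instantiates these as pretrained encoders, e.g.
# encoder = wide_resnet50_2(pretrained=True, progress=True), and registers forward
# hooks on layer2/layer3/layer4 so their activations feed the multi-scale CNF decoders.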

73
custom_models/utils.py 100644

@@ -0,0 +1,73 @@
import os, math
import numpy as np
import torch
RESULT_DIR = './results'
WEIGHT_DIR = './weights'
MODEL_DIR = './models'
__all__ = ('save_results', 'save_weights', 'load_weights', 'adjust_learning_rate', 'warmup_learning_rate')
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, model_name, class_name, run_date):
result = '{:.2f},{:.2f},{:.2f} \t\tfor {:s}/{:s}/{:s} at epoch {:d}/{:d}/{:d} for {:s}\n'.format(
det_roc_obs.max_score, seg_roc_obs.max_score, seg_pro_obs.max_score,
det_roc_obs.name, seg_roc_obs.name, seg_pro_obs.name,
det_roc_obs.max_epoch, seg_roc_obs.max_epoch, seg_pro_obs.max_epoch, class_name)
if not os.path.exists(RESULT_DIR):
os.makedirs(RESULT_DIR)
fp = open(os.path.join(RESULT_DIR, '{}_{}.txt'.format(model_name, run_date)), "w")
fp.write(result)
fp.close()
def save_weights(encoder, decoders, model_name, run_date):
if not os.path.exists(WEIGHT_DIR):
os.makedirs(WEIGHT_DIR)
state = {'encoder_state_dict': encoder.state_dict(),
'decoder_state_dict': [decoder.state_dict() for decoder in decoders]}
filename = '{}_{}.pt'.format(model_name, run_date)
path = os.path.join(WEIGHT_DIR, filename)
torch.save(state, path)
print('Saving weights to {}'.format(filename))
def load_weights(encoder, decoders, filename):
    path = os.path.join(WEIGHT_DIR, filename)
    state = torch.load(path)
    encoder.load_state_dict(state['encoder_state_dict'], strict=False)
    # load_state_dict restores weights in place; loop instead of collecting its
    # _IncompatibleKeys return values into the decoders list
    for decoder, decoder_state in zip(decoders, state['decoder_state_dict']):
        decoder.load_state_dict(decoder_state, strict=False)
    print('Loading weights from {}'.format(filename))
def adjust_learning_rate(c, optimizer, epoch):
lr = c.lr
if c.lr_cosine:
eta_min = lr * (c.lr_decay_rate ** 3)
lr = eta_min + (lr - eta_min) * (
1 + math.cos(math.pi * epoch / c.meta_epochs)) / 2
else:
steps = np.sum(epoch >= np.asarray(c.lr_decay_epochs))
if steps > 0:
lr = lr * (c.lr_decay_rate ** steps)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def warmup_learning_rate(c, epoch, batch_id, total_batches, optimizer):
if c.lr_warm and epoch < c.lr_warm_epochs:
p = (batch_id + epoch * total_batches) / \
(c.lr_warm_epochs * total_batches)
lr = c.lr_warmup_from + p * (c.lr_warmup_to - c.lr_warmup_from)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
#
for param_group in optimizer.param_groups:
lrate = param_group['lr']
return lrate
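# Usage note (from main.py/train.py): the schedule attributes consumed above
# (c.lr_cosine, c.lr_decay_rate, c.lr_decay_epochs, c.lr_warm, c.lr_warm_epochs,
# c.lr_warmup_from, c.lr_warmup_to) are set in main.py; train.py calls
# adjust_learning_rate once per meta-epoch and warmup_learning_rate once per batch
# during the warm-up epochs.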

images/fig-cflow.svg 100644
File diff suppressed because one or more lines are too long (SVG image, 1.1 MiB)

4031
images/fig-table.svg 100644

File diff suppressed because it is too large (SVG image, 319 KiB)

93
main.py 100644

@@ -0,0 +1,93 @@
from __future__ import print_function
import os, random, time, math
import numpy as np
import torch
import timm
from timm.data import resolve_data_config
from config import get_args
from train import train
def init_seeds(seed=0):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def main(c):
# model
if c.action_type == 'video-train':
c.model = "{}_{}_{}".format(c.enc_arch, c.dec_arch, c.video_path)
elif c.action_type == 'norm-train' or c.action_type == 'norm-test':
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, c.input_size, c.run_name, c.class_name)
else:
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
# image
if ('vit' in c.enc_arch) or ('efficient' in c.enc_arch):
encoder = timm.create_model(c.enc_arch, pretrained=True)
arch_config = resolve_data_config({}, model=encoder)
        c.norm_mean, c.norm_std = list(arch_config['mean']), list(arch_config['std'])
c.img_size = arch_config['input_size'][1:] # HxW format
c.crp_size = arch_config['input_size'][1:] # HxW format
else:
c.img_size = (c.input_size, c.input_size) # HxW format
c.crp_size = (c.input_size, c.input_size) # HxW format
if c.dataset == 'stc':
c.norm_mean, c.norm_std = 3*[0.5], 3*[0.225]
else:
c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
#
c.img_dims = [3] + list(c.img_size)
# network hyperparameters
c.clamp_alpha = 1.9 # see paper equation 2 for explanation
c.condition_vec = 128
c.dropout = 0.0 # dropout in s-t-networks
# dataloader parameters
if c.dataset == 'mvtec':
c.data_path = './data/MVTec-AD'
elif c.dataset == 'stc':
c.data_path = './data/STC/shanghaitech'
elif c.dataset == 'video':
c.data_path = c.video_path
else:
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
# output settings
c.verbose = True
c.hide_tqdm_bar = True
c.save_results = True
# unsup-train
c.print_freq = 2
c.temp = 0.5
c.lr_decay_epochs = [i*c.meta_epochs//100 for i in [50,75,90]]
print('LR schedule: {}'.format(c.lr_decay_epochs))
c.lr_decay_rate = 0.1
c.lr_warm_epochs = 2
c.lr_warm = True
c.lr_cosine = True
if c.lr_warm:
c.lr_warmup_from = c.lr/10.0
if c.lr_cosine:
eta_min = c.lr * (c.lr_decay_rate ** 3)
c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
else:
c.lr_warmup_to = c.lr
########
os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
c.use_cuda = not c.no_cuda and torch.cuda.is_available()
init_seeds(seed=int(time.time()))
c.device = torch.device("cuda" if c.use_cuda else "cpu")
# selected function:
if c.action_type == 'norm-train':
train(c)
else:
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
if __name__ == '__main__':
c = get_args()
main(c)

178
model.py 100644

@@ -0,0 +1,178 @@
import math
import torch
from torch import nn
from custom_models import *
# FrEIA (https://github.com/VLL-HD/FrEIA/)
import FrEIA.framework as Ff
import FrEIA.modules as Fm
import timm
def positionalencoding2d(D, H, W):
"""
:param D: dimension of the model
:param H: H of the positions
:param W: W of the positions
:return: DxHxW position matrix
"""
if D % 4 != 0:
raise ValueError("Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(D))
P = torch.zeros(D, H, W)
# Each dimension use half of D
D = D // 2
div_term = torch.exp(torch.arange(0.0, D, 2) * -(math.log(1e4) / D))
pos_w = torch.arange(0.0, W).unsqueeze(1)
pos_h = torch.arange(0.0, H).unsqueeze(1)
P[0:D:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
P[1:D:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, H, 1)
P[D::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
P[D+1::2,:, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, W)
return P
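# Shape check (illustrative): positionalencoding2d(128, 32, 32) returns a [128, 32, 32]
# tensor; train.py tiles it to Bx128xHxW so every spatial location carries a fixed
# 128-dim code (c.condition_vec) that conditions the CNF decoder for that feature vector.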
def subnet_fc(dims_in, dims_out):
return nn.Sequential(nn.Linear(dims_in, 2*dims_in), nn.ReLU(), nn.Linear(2*dims_in, dims_out))
def freia_flow_head(c, n_feat):
coder = Ff.SequenceINN(n_feat)
print('NF coder:', n_feat)
for k in range(c.coupling_blocks):
coder.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, affine_clamping=c.clamp_alpha,
global_affine_type='SOFTPLUS', permute_soft=True)
return coder
def freia_cflow_head(c, n_feat):
n_cond = c.condition_vec
coder = Ff.SequenceINN(n_feat)
print('CNF coder:', n_feat)
for k in range(c.coupling_blocks):
coder.append(Fm.AllInOneBlock, cond=0, cond_shape=(n_cond,), subnet_constructor=subnet_fc, affine_clamping=c.clamp_alpha,
global_affine_type='SOFTPLUS', permute_soft=True)
return coder
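# Call convention (see train.py): the conditional coder is evaluated per fiber batch as
#   z, log_jac_det = decoder(e_p, [c_p,])
# where e_p is an NxC batch of encoder features and c_p the matching Nx128 conditions.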
def load_decoder_arch(c, dim_in):
if c.dec_arch == 'freia-flow':
decoder = freia_flow_head(c, dim_in)
elif c.dec_arch == 'freia-cflow':
decoder = freia_cflow_head(c, dim_in)
else:
raise NotImplementedError('{} is not supported NF!'.format(c.dec_arch))
#print(decoder)
return decoder
activation = {}
def get_activation(name):
def hook(model, input, output):
activation[name] = output.detach()
return hook
def load_encoder_arch(c, L):
# encoder pretrained on natural images:
pool_cnt = 0
pool_dims = list()
pool_layers = ['layer'+str(i) for i in range(L)]
if 'resnet' in c.enc_arch:
if c.enc_arch == 'resnet18':
encoder = resnet18(pretrained=True, progress=True)
elif c.enc_arch == 'resnet34':
encoder = resnet34(pretrained=True, progress=True)
elif c.enc_arch == 'resnet50':
encoder = resnet50(pretrained=True, progress=True)
elif c.enc_arch == 'resnext50_32x4d':
encoder = resnext50_32x4d(pretrained=True, progress=True)
elif c.enc_arch == 'wide_resnet50_2':
encoder = wide_resnet50_2(pretrained=True, progress=True)
else:
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
#
if L >= 3:
encoder.layer2.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
pool_dims.append(encoder.layer2[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer2[-1].conv2.out_channels)
pool_cnt = pool_cnt + 1
if L >= 2:
encoder.layer3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
pool_dims.append(encoder.layer3[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer3[-1].conv2.out_channels)
pool_cnt = pool_cnt + 1
if L >= 1:
encoder.layer4.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
pool_dims.append(encoder.layer4[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer4[-1].conv2.out_channels)
pool_cnt = pool_cnt + 1
elif 'vit' in c.enc_arch:
if c.enc_arch == 'vit_base_patch16_224':
encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
elif c.enc_arch == 'vit_base_patch16_384':
encoder = timm.create_model('vit_base_patch16_384', pretrained=True)
else:
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
#
if L >= 3:
encoder.blocks[10].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
pool_cnt = pool_cnt + 1
if L >= 2:
encoder.blocks[2].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
pool_cnt = pool_cnt + 1
if L >= 1:
encoder.blocks[6].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[6].mlp.fc2.out_features)
pool_cnt = pool_cnt + 1
elif 'efficient' in c.enc_arch:
if 'b5' in c.enc_arch:
encoder = timm.create_model(c.enc_arch, pretrained=True)
blocks = [-2, -3, -5]
else:
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
#
if L >= 3:
encoder.blocks[blocks[2]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[blocks[2]][-1].bn3.num_features)
pool_cnt = pool_cnt + 1
if L >= 2:
encoder.blocks[blocks[1]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[blocks[1]][-1].bn3.num_features)
pool_cnt = pool_cnt + 1
if L >= 1:
encoder.blocks[blocks[0]][-1].bn3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder.blocks[blocks[0]][-1].bn3.num_features)
pool_cnt = pool_cnt + 1
elif 'mobile' in c.enc_arch:
if c.enc_arch == 'mobilenet_v3_small':
encoder = mobilenet_v3_small(pretrained=True, progress=True).features
blocks = [-2, -5, -10]
elif c.enc_arch == 'mobilenet_v3_large':
encoder = mobilenet_v3_large(pretrained=True, progress=True).features
blocks = [-2, -5, -11]
else:
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
#
if L >= 3:
encoder[blocks[2]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder[blocks[2]].block[-1][-3].out_channels)
pool_cnt = pool_cnt + 1
if L >= 2:
encoder[blocks[1]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder[blocks[1]].block[-1][-3].out_channels)
pool_cnt = pool_cnt + 1
if L >= 1:
encoder[blocks[0]].block[-1][-3].register_forward_hook(get_activation(pool_layers[pool_cnt]))
pool_dims.append(encoder[blocks[0]].block[-1][-3].out_channels)
pool_cnt = pool_cnt + 1
else:
raise NotImplementedError('{} is not supported architecture!'.format(c.enc_arch))
#
return encoder, pool_layers, pool_dims

89
parse_results.py 100644

@@ -0,0 +1,89 @@
from __future__ import print_function
import argparse, os
import numpy as np
RESULT_DIR = './results'
def get_args():
parser = argparse.ArgumentParser(description='CFLOW-AD')
parser.add_argument('--dataset', default='mvtec', type=str, metavar='D',
help='dataset name: mvtec/stc/video (default: mvtec)')
parser.add_argument('-cl', '--class-name', default='none', type=str, metavar='C',
help='class name for MVTec/STC (default: none)')
parser.add_argument('-enc', '--enc-arch', default='resnet18', type=str, metavar='A',
help='feature extractor architecture (default: resnet18)')
parser.add_argument('-dec', '--dec-arch', default='freia-cflow', type=str, metavar='A',
help='normalizing flow model (default: freia-cflow)')
    parser.add_argument('-pl', '--pool-layers', default=2, type=int, metavar='L',
                        help='number of pooled layers used in NF model (default: 2)')
    parser.add_argument('-cb', '--coupling-blocks', default=8, type=int, metavar='L',
                        help='number of coupling blocks used in NF model (default: 8)')
parser.add_argument('-runs', '--run-count', default=4, type=int, metavar='C',
help='number of runs (default: 4)')
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
help='image resize dimensions (default: 256)')
args = parser.parse_args()
return args
def main(c):
runs = c.run_count
if c.dataset == 'mvtec':
class_names = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
elif c.dataset == 'stc':
class_names = ['01', '02', '03', '04', '05', '06',
'07', '08', '09', '10', '11', '12'] #, '13']
else:
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
#
metrics = ['DET_AUROC', 'SEG_AUROC', 'SEG_AUPRO']
results = np.zeros((len(metrics), len(class_names), runs))
# loop
result_files = os.listdir(RESULT_DIR)
for class_idx, class_name in enumerate(class_names):
for run in range(runs):
if class_name in ['cable', 'capsule', 'hazelnut', 'metal_nut', 'pill']: # AUROC
input_size = 256
elif class_name == 'transistor':
input_size = 128
else:
input_size = c.input_size
#input_size = c.input_size
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, input_size, run, class_name)
#
result_file = list(filter(lambda x: x.startswith(c.model), result_files))
if len(result_file) == 0:
raise NotImplementedError('{} results are not found!'.format(c.model))
elif len(result_file) > 1:
raise NotImplementedError('{} duplicate results are found!'.format(result_file))
else:
result_file = result_file[0]
#
fp = open(os.path.join(RESULT_DIR, result_file), 'r')
lines = fp.readlines()
rline = lines[0].split(' ')[0].split(',')
result = np.array([float(r) for r in rline])
fp.close()
results[:, class_idx, run] = result
#
for i, m in enumerate(metrics):
print('\n{}:'.format(m))
for j, class_name in enumerate(class_names):
print(r"{:.2f}\tiny$\pm${:.2f} for class {}".format(np.mean(results[i, j]), np.std(results[i, j]), class_name))
# \tiny$\pm$
means = np.mean(results[i], 0)
#print(results[i].shape, means.shape)
print(r"{:.2f} for average".format(np.mean(means)))
if __name__ == '__main__':
c = get_args()
main(c)
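# Example invocation (illustrative; flags are defined above and values must match
# the training runs that produced ./results):
#   python3 parse_results.py --dataset mvtec -enc wide_resnet50_2 -pl 3 -cb 8 -inp 512 -runs 4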

11
requirements.txt 100644

@@ -0,0 +1,11 @@
numpy
scipy
scikit-learn
matplotlib
tqdm
scikit-image
av
timm
torch >= 1.6.0
torchvision >= 0.7.0
git+git://github.com/VLL-HD/FrEIA@4e0c6ab42b26ec6e41b1ee2abb1a8b6562752b00

428
train.py 100644

@@ -0,0 +1,428 @@
import os, time, datetime
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from skimage.measure import label, regionprops
from tqdm import tqdm
from visualize import *
from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
from utils import *
from custom_datasets import *
from custom_models import *
OUT_DIR = './viz/'
gamma = 0.0
theta = torch.nn.Sigmoid()
log_theta = torch.nn.LogSigmoid()
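# Note: get_logp (imported via utils, which is not shown in this diff) is assumed to
# return the standard normalizing-flow log-likelihood of each C-dim feature vector:
#   log p(e) = -0.5*C*log(2*pi) - 0.5*||z||^2 + log_jac_det
# Dividing by C below gives a per-dimension likelihood that log-sigmoid maps to a loss.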
def train_meta_epoch(c, epoch, loader, encoder, decoders, optimizer, pool_layers, N):
P = c.condition_vec
L = c.pool_layers
decoders = [decoder.train() for decoder in decoders]
adjust_learning_rate(c, optimizer, epoch)
I = len(loader)
iterator = iter(loader)
for sub_epoch in range(c.sub_epochs):
train_loss = 0.0
train_count = 0
for i in range(I):
# warm-up learning rate
lr = warmup_learning_rate(c, epoch, i+sub_epoch*I, I*c.sub_epochs, optimizer)
# sample batch
try:
image, _, _ = next(iterator)
except StopIteration:
iterator = iter(loader)
image, _, _ = next(iterator)
# encoder prediction
image = image.to(c.device) # single scale
with torch.no_grad():
_ = encoder(image)
# train decoder
e_list = list()
c_list = list()
for l, layer in enumerate(pool_layers):
if 'vit' in c.enc_arch:
e = activation[layer].transpose(1, 2)[...,1:]
e_hw = int(np.sqrt(e.size(2)))
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
else:
e = activation[layer].detach() # BxCxHxW
#
B, C, H, W = e.size()
S = H*W
E = B*S
#
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
perm = torch.randperm(E).to(c.device) # BHW
decoder = decoders[l]
#
FIB = E//N # number of fiber batches
assert FIB > 0, 'MAKE SURE WE HAVE ENOUGH FIBERS, otherwise decrease N or batch-size!'
for f in range(FIB): # per-fiber processing
idx = torch.arange(f*N, (f+1)*N)
c_p = c_r[perm[idx]] # NxP
e_p = e_r[perm[idx]] # NxC
if 'cflow' in c.dec_arch:
z, log_jac_det = decoder(e_p, [c_p,])
else:
z, log_jac_det = decoder(e_p)
#
decoder_log_prob = get_logp(C, z, log_jac_det)
log_prob = decoder_log_prob / C # likelihood per dim
loss = -log_theta(log_prob)
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
train_loss += t2np(loss.sum())
train_count += len(loss)
#
mean_train_loss = train_loss / train_count
if c.verbose:
print('Epoch: {:d}.{:d} \t train loss: {:.4f}, lr={:.6f}'.format(epoch, sub_epoch, mean_train_loss, lr))
#
def test_meta_epoch(c, epoch, loader, encoder, decoders, pool_layers, N):
# test
if c.verbose:
print('\nCompute loss and scores on test set:')
#
P = c.condition_vec
decoders = [decoder.eval() for decoder in decoders]
height = list()
width = list()
image_list = list()
gt_label_list = list()
gt_mask_list = list()
test_dist = [list() for layer in pool_layers]
test_loss = 0.0
test_count = 0
start = time.time()
with torch.no_grad():
for i, (image, label, mask) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
# save
if c.viz:
image_list.extend(t2np(image))
gt_label_list.extend(t2np(label))
gt_mask_list.extend(t2np(mask))
# data
image = image.to(c.device) # single scale
_ = encoder(image) # BxCxHxW
# test decoder
e_list = list()
for l, layer in enumerate(pool_layers):
if 'vit' in c.enc_arch:
e = activation[layer].transpose(1, 2)[...,1:]
e_hw = int(np.sqrt(e.size(2)))
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
else:
e = activation[layer] # BxCxHxW
#
B, C, H, W = e.size()
S = H*W
E = B*S
#
if i == 0: # get stats
height.append(H)
width.append(W)
#
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
#
m = F.interpolate(mask, size=(H, W), mode='nearest')
m_r = m.reshape(B, 1, S).transpose(1, 2).reshape(E, 1) # BHWx1
#
decoder = decoders[l]
FIB = E//N + int(E%N > 0) # number of fiber batches
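# ceil division: unlike training, the tail batch is kept here so that
# every fiber receives a likelihood score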
for f in range(FIB):
if f < (FIB-1):
idx = torch.arange(f*N, (f+1)*N)
else:
idx = torch.arange(f*N, E)
#
c_p = c_r[idx] # NxP
e_p = e_r[idx] # NxC
m_p = m_r[idx] > 0.5 # Nx1 downsampled GT mask per fiber (not used below)
#
if 'cflow' in c.dec_arch:
z, log_jac_det = decoder(e_p, [c_p,])
else:
z, log_jac_det = decoder(e_p)
#
decoder_log_prob = get_logp(C, z, log_jac_det)
log_prob = decoder_log_prob / C # likelihood per dim
loss = -log_theta(log_prob)
test_loss += t2np(loss.sum())
test_count += len(loss)
test_dist[l] = test_dist[l] + log_prob.detach().cpu().tolist()
#
fps = len(loader.dataset) / (time.time() - start)
mean_test_loss = test_loss / test_count
if c.verbose:
print('Epoch: {:d} \t test_loss: {:.4f} and {:.2f} fps'.format(epoch, mean_test_loss, fps))
#
return height, width, image_list, test_dist, gt_label_list, gt_mask_list
def test_meta_fps(c, epoch, loader, encoder, decoders, pool_layers, N):
# test
if c.verbose:
print('\nMeasure encoder and decoder throughput on the test set:')
#
P = c.condition_vec
decoders = [decoder.eval() for decoder in decoders]
height = list()
width = list()
image_list = list()
gt_label_list = list()
gt_mask_list = list()
test_dist = [list() for layer in pool_layers]
test_loss = 0.0
test_count = 0
A = len(loader.dataset)
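# three passes over the test set: (1) warm-up to exclude CUDA init and
# dataloader caching, (2) encoder-only timing, (3) encoder+decoder timing;
# the decoder overhead is the difference between (3) and (2)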
with torch.no_grad():
# warm-up
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
# data
image = image.to(c.device) # single scale
_ = encoder(image) # BxCxHxW
# measure encoder only
torch.cuda.synchronize()
start = time.time()
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
# data
image = image.to(c.device) # single scale
_ = encoder(image) # BxCxHxW
# measure encoder + decoder
torch.cuda.synchronize()
time_enc = time.time() - start
start = time.time()
for i, (image, _, _) in enumerate(tqdm(loader, disable=c.hide_tqdm_bar)):
# data
image = image.to(c.device) # single scale
_ = encoder(image) # BxCxHxW
# test decoder
for l, layer in enumerate(pool_layers):
if 'vit' in c.enc_arch:
e = activation[layer].transpose(1, 2)[...,1:]
e_hw = int(np.sqrt(e.size(2)))
e = e.reshape(-1, e.size(1), e_hw, e_hw) # BxCxHxW
else:
e = activation[layer] # BxCxHxW
#
B, C, H, W = e.size()
S = H*W
E = B*S
#
if i == 0: # get stats
height.append(H)
width.append(W)
#
p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P) # BHWxP
e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C) # BHWxC
#
decoder = decoders[l]
FIB = E//N + int(E%N > 0) # number of fiber batches
for f in range(FIB):
if f < (FIB-1):
idx = torch.arange(f*N, (f+1)*N)
else:
idx = torch.arange(f*N, E)
#
c_p = c_r[idx] # NxP
e_p = e_r[idx] # NxC
#
if 'cflow' in c.dec_arch:
z, log_jac_det = decoder(e_p, [c_p,])
else:
z, log_jac_det = decoder(e_p)
#
torch.cuda.synchronize()
time_all = time.time() - start
fps_enc = A / time_enc
fps_all = A / time_all
print('Encoder/All {:.2f}/{:.2f} fps'.format(fps_enc, fps_all))
#
return height, width, image_list, test_dist, gt_label_list, gt_mask_list
def train(c):
run_date = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
L = c.pool_layers # number of pooled layers
print('Number of pool layers =', L)
encoder, pool_layers, pool_dims = load_encoder_arch(c, L)
encoder = encoder.to(c.device).eval()
#print(encoder)
# NF decoder
decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]
decoders = [decoder.to(c.device) for decoder in decoders]
params = list(decoders[0].parameters())
for l in range(1, L):
params += list(decoders[l].parameters())
# optimizer
optimizer = torch.optim.Adam(params, lr=c.lr)
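# NOTE: only the flow decoders are optimized; the encoder stays frozen
# (eval mode + no_grad) and acts as a fixed feature extractor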
# data
kwargs = {'num_workers': c.workers, 'pin_memory': True} if c.use_cuda else {}
# task data
if c.dataset == 'mvtec':
train_dataset = MVTecDataset(c, is_train=True)
test_dataset = MVTecDataset(c, is_train=False)
elif c.dataset == 'stc':
train_dataset = StcDataset(c, is_train=True)
test_dataset = StcDataset(c, is_train=False)
#elif c.dataset == 'video':
# c.data_path = c.video_path
else:
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
#
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=c.batch_size, shuffle=True, drop_last=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=c.batch_size, shuffle=False, drop_last=False, **kwargs)
N = 256 # number of feature fibers per decoder mini-batch
print('train/test loader length', len(train_loader.dataset), len(test_loader.dataset))
print('train/test loader batches', len(train_loader), len(test_loader))
# stats
det_roc_obs = Score_Observer('DET_AUROC')
seg_roc_obs = Score_Observer('SEG_AUROC')
seg_pro_obs = Score_Observer('SEG_AUPRO')
for epoch in range(c.meta_epochs):
if c.viz:
if c.checkpoint:
load_weights(encoder, decoders, c.checkpoint)
else:
print('Train meta epoch: {}'.format(epoch))
train_meta_epoch(c, epoch, train_loader, encoder, decoders, optimizer, pool_layers, N)
#height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_fps(
# c, epoch, test_loader, encoder, decoders, pool_layers, N)
height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_epoch(
c, epoch, test_loader, encoder, decoders, pool_layers, N)
# test_dist: one list of EHW per-fiber log-likelihoods per pool layer
print('Heights/Widths', height, width)
test_map = [list() for p in pool_layers]
for l, p in enumerate(pool_layers):
test_norm = torch.tensor(test_dist[l], dtype=torch.double) # EHWx1
test_norm -= torch.max(test_norm) # shift likelihoods to (-Inf:0] so exp() cannot overflow
test_prob = torch.exp(test_norm) # convert to probabilities in (0:1]
test_mask = test_prob.reshape(-1, height[l], width[l])
# upsample
test_map[l] = F.interpolate(test_mask.unsqueeze(1),
size=c.crp_size, mode='bilinear', align_corners=True).squeeze().numpy()
# score aggregation
score_map = np.zeros_like(test_map[0])
for l, p in enumerate(pool_layers):
score_map += test_map[l]
score_mask = score_map
# superpixels
super_mask = score_mask.max() - score_mask # invert: high likelihood (normal) -> low anomaly score
# calculate detection AUROC
score_label = np.max(super_mask, axis=(1, 2))
gt_label = np.asarray(gt_label_list, dtype=bool)
det_roc_auc = roc_auc_score(gt_label, score_label)
det_roc_obs.update(100.0*det_roc_auc, epoch)
# calculate segmentation AUROC
gt_mask = np.squeeze(np.asarray(gt_mask_list, dtype=bool), axis=1)
seg_roc_auc = roc_auc_score(gt_mask.flatten(), super_mask.flatten())
seg_roc_obs.update(100.0*seg_roc_auc, epoch)
# calculate segmentation AUPRO
# from https://github.com/YoungGod/DFR:
if c.pro: # and (epoch % 4 == 0): # AUPRO is expensive to compute
max_step = 1000
expect_fpr = 0.3 # default 30%
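# sweep the threshold from max to min anomaly score; at each step record
# the mean per-region overlap (PRO) and the pixel-level FPR, then integrate
# PRO over FPR in [0, 0.3] (rescaled to [0, 1]) to obtain AUPRO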
max_th = super_mask.max()
min_th = super_mask.min()
delta = (max_th - min_th) / max_step
ious_mean = []
ious_std = []
pros_mean = []
pros_std = []
threds = []
fprs = []
binary_score_maps = np.zeros_like(super_mask, dtype=bool)
for step in range(max_step):
thred = max_th - step * delta
# segmentation
binary_score_maps[super_mask <= thred] = 0
binary_score_maps[super_mask > thred] = 1
pro = [] # per region overlap
iou = [] # per image iou
# pro: for each connected GT region, compute the fraction of its pixels covered by the predicted binary map
# iou: for each image, compute intersection-over-union between the GT and predicted binary maps
for i in range(len(binary_score_maps)): # for i th image
# pro (per region level)
label_map = label(gt_mask[i], connectivity=2)
props = regionprops(label_map)
for prop in props:
x_min, y_min, x_max, y_max = prop.bbox # find the bounding box of an anomaly region
cropped_pred_label = binary_score_maps[i][x_min:x_max, y_min:y_max]
# cropped_mask = gt_mask[i][x_min:x_max, y_min:y_max] # bug!
cropped_mask = prop.filled_image # corrected!
intersection = np.logical_and(cropped_pred_label, cropped_mask).astype(np.float32).sum()
pro.append(intersection / prop.area)
# iou (per image level)
intersection = np.logical_and(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum()
union = np.logical_or(binary_score_maps[i], gt_mask[i]).astype(np.float32).sum()
if gt_mask[i].any(): # skip images without anomalous pixels (IoU undefined)
iou.append(intersection / union)
# average the per-image and per-region metrics at this threshold step
ious_mean.append(np.array(iou).mean())
#print("per image mean iou:", np.array(iou).mean())
ious_std.append(np.array(iou).std())
pros_mean.append(np.array(pro).mean())
pros_std.append(np.array(pro).std())
# fpr for pro-auc
gt_masks_neg = ~gt_mask
fpr = np.logical_and(gt_masks_neg, binary_score_maps).sum() / gt_masks_neg.sum()
fprs.append(fpr)
threds.append(thred)
# as array
threds = np.array(threds)
pros_mean = np.array(pros_mean)
pros_std = np.array(pros_std)
fprs = np.array(fprs)
ious_mean = np.array(ious_mean)
ious_std = np.array(ious_std)
# best per image iou
best_miou = ious_mean.max()
#print(f"Best IOU: {best_miou:.4f}")
# default 30% fpr vs pro, pro_auc
idx = fprs <= expect_fpr # indices of FPRs below expect_fpr (default 0.3)
fprs_selected = fprs[idx]
fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1]
pros_mean_selected = pros_mean[idx]
seg_pro_auc = auc(fprs_selected, pros_mean_selected)
seg_pro_obs.update(100.0*seg_pro_auc, epoch)
# export visualizations
if c.viz:
precision, recall, thresholds = precision_recall_curve(gt_label, score_label)
a = 2 * precision * recall
b = precision + recall
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
det_threshold = thresholds[np.argmax(f1)]
print('Optimal DET Threshold: {:.2f}'.format(det_threshold))
precision, recall, thresholds = precision_recall_curve(gt_mask.flatten(), super_mask.flatten())
a = 2 * precision * recall
b = precision + recall
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
seg_threshold = thresholds[np.argmax(f1)]
print('Optimal SEG Threshold: {:.2f}'.format(seg_threshold))
export_groundtruth(c, test_image_list, gt_mask)
export_scores(c, test_image_list, super_mask, seg_threshold)
export_test_images(c, test_image_list, gt_mask, super_mask, seg_threshold)
export_hist(c, gt_mask, super_mask, seg_threshold)
#save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
elif c.save_results:
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
save_weights(encoder, decoders, c.model, run_date) # keep a checkpoint for later --viz runs

40
utils.py 100644

@ -0,0 +1,40 @@
import numpy as np
import torch
import torch.nn.functional as F
np.random.seed(0)
_GCONST_ = -0.9189385332046727 # -ln(sqrt(2*pi))
class Score_Observer:
def __init__(self, name):
self.name = name
self.max_epoch = 0
self.max_score = 0.0
self.last = 0.0
def update(self, score, epoch, print_score=True):
self.last = score
if epoch == 0 or score > self.max_score:
self.max_score = score
self.max_epoch = epoch
if print_score:
self.print_score()
def print_score(self):
print('{:s}: \t last: {:.2f} \t max: {:.2f} \t epoch_max: {:d}'.format(
self.name, self.last, self.max_score, self.max_epoch))
def t2np(tensor):
'''pytorch tensor -> numpy array'''
return tensor.detach().cpu().numpy() if tensor is not None else None
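# change of variables for a normalizing flow z = f(e) with a standard normal base:
# log p_E(e) = log p_Z(z) + log|det J| = C*(-0.5*ln(2*pi)) - 0.5*||z||^2 + log|det J|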
def get_logp(C, z, logdet_J):
logp = C * _GCONST_ - 0.5*torch.sum(z**2, 1) + logdet_J
return logp
def rescale(x):
return (x - x.min()) / (x.max() - x.min())

147
visualize.py 100644

@ -0,0 +1,147 @@
import os
import datetime
import numpy as np
from skimage import morphology
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
import matplotlib
from utils import *
OUT_DIR = './viz/'
norm = matplotlib.colors.Normalize(vmin=0.0, vmax=255.0)
cm = 1/2.54
dpi = 300
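# NOTE: matplotlib figsize is specified in inches; multiplying by cm (=1/2.54)
# converts centimeter sizes to inches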
def denormalization(x, norm_mean, norm_std):
mean = np.array(norm_mean)
std = np.array(norm_std)
x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)
return x
def export_hist(c, gts, scores, threshold):
print('Exporting histogram...')
plt.rcParams.update({'font.size': 4})
image_dirs = os.path.join(OUT_DIR, c.model)
os.makedirs(image_dirs, exist_ok=True)
Y = scores.flatten()
Y_label = gts.flatten()
fig = plt.figure(figsize=(4*cm, 4*cm), dpi=dpi)
ax = plt.Axes(fig, [0., 0., 1., 1.])
fig.add_axes(ax)
plt.hist([Y[Y_label==1], Y[Y_label==0]], 500, density=True, color=['r', 'g'], label=['ANO', 'TYP'], alpha=0.75, histtype='barstacked')
image_file = os.path.join(image_dirs, 'hist_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
fig.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
plt.close()
def export_groundtruth(c, test_img, gts):
image_dirs = os.path.join(OUT_DIR, c.model, 'gt_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
# images
if not os.path.isdir(image_dirs):
print('Exporting ground truth...')
os.makedirs(image_dirs, exist_ok=True)
num = len(test_img)
kernel = morphology.disk(4)
for i in range(num):
img = test_img[i]
img = denormalization(img, c.norm_mean, c.norm_std)
# gts
gt_mask = gts[i].astype(np.float64)
gt_mask = morphology.opening(gt_mask, kernel)
gt_mask = (255.0*gt_mask).astype(np.uint8)
gt_img = mark_boundaries(img, gt_mask, color=(1, 0, 0), mode='thick')
#
fig = plt.figure(figsize=(2*cm, 2*cm), dpi=dpi)
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(gt_img)
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
fig.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
plt.close()
def export_scores(c, test_img, scores, threshold):
image_dirs = os.path.join(OUT_DIR, c.model, 'sc_images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
# images
if not os.path.isdir(image_dirs):
print('Exporting scores...')
os.makedirs(image_dirs, exist_ok=True)
num = len(test_img)
kernel = morphology.disk(4)
scores_norm = 1.0/scores.max()
for i in range(num):
img = test_img[i]
img = denormalization(img, c.norm_mean, c.norm_std)
# scores
score_mask = np.zeros_like(scores[i])
score_mask[scores[i] > threshold] = 1.0
score_mask = morphology.opening(score_mask, kernel)
score_mask = (255.0*score_mask).astype(np.uint8)
score_img = mark_boundaries(img, score_mask, color=(1, 0, 0), mode='thick')
score_map = (255.0*scores[i]*scores_norm).astype(np.uint8)
#
fig_img, ax_img = plt.subplots(2, 1, figsize=(2*cm, 4*cm))
for ax_i in ax_img:
ax_i.axes.xaxis.set_visible(False)
ax_i.axes.yaxis.set_visible(False)
ax_i.spines['top'].set_visible(False)
ax_i.spines['right'].set_visible(False)
ax_i.spines['bottom'].set_visible(False)
ax_i.spines['left'].set_visible(False)
#
plt.subplots_adjust(hspace = 0.1, wspace = 0.1)
ax_img[0].imshow(img, cmap='gray', interpolation='none')
ax_img[0].imshow(score_map, cmap='jet', norm=norm, alpha=0.5, interpolation='none')
ax_img[1].imshow(score_img)
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
fig_img.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
plt.close()
def export_test_images(c, test_img, gts, scores, threshold):
image_dirs = os.path.join(OUT_DIR, c.model, 'images_' + datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))
# images
if not os.path.isdir(image_dirs):
print('Exporting images...')
os.makedirs(image_dirs, exist_ok=True)
num = len(test_img)
font = {'family': 'serif', 'color': 'black', 'weight': 'normal', 'size': 8}
kernel = morphology.disk(4)
scores_norm = 1.0/scores.max()
for i in range(num):
img = test_img[i]
img = denormalization(img, c.norm_mean, c.norm_std)
# gts
gt_mask = gts[i].astype(np.float64)
print('GT:', i, gt_mask.sum())
gt_mask = morphology.opening(gt_mask, kernel)
gt_mask = (255.0*gt_mask).astype(np.uint8)
gt_img = mark_boundaries(img, gt_mask, color=(1, 0, 0), mode='thick')
# scores
score_mask = np.zeros_like(scores[i])
score_mask[scores[i] > threshold] = 1.0
print('SC:', i, score_mask.sum())
score_mask = morphology.opening(score_mask, kernel)
score_mask = (255.0*score_mask).astype(np.uint8)
score_img = mark_boundaries(img, score_mask, color=(1, 0, 0), mode='thick')
score_map = (255.0*scores[i]*scores_norm).astype(np.uint8)
#
fig_img, ax_img = plt.subplots(3, 1, figsize=(2*cm, 6*cm))
for ax_i in ax_img:
ax_i.axes.xaxis.set_visible(False)
ax_i.axes.yaxis.set_visible(False)
ax_i.spines['top'].set_visible(False)
ax_i.spines['right'].set_visible(False)
ax_i.spines['bottom'].set_visible(False)
ax_i.spines['left'].set_visible(False)
#
plt.subplots_adjust(hspace = 0.1, wspace = 0.1)
ax_img[0].imshow(gt_img)
ax_img[1].imshow(score_map, cmap='jet', norm=norm)
ax_img[2].imshow(score_img)
image_file = os.path.join(image_dirs, '{:08d}'.format(i))
fig_img.savefig(image_file, dpi=dpi, format='svg', bbox_inches = 'tight', pad_inches = 0.0)
plt.close()