import os, time
import random, math
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from skimage.measure import label, regionprops
from tqdm import tqdm

from visualize import *
from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
from utils import *
from custom_datasets import *
from custom_models import *
from config import get_args
from train import train
from test import test
from test_single import test_single

# log-sigmoid used to map per-dimension log-likelihoods to a loss
log_theta = torch.nn.LogSigmoid()

def load_weights2(model, filename):
    # Load pretrained encoder and per-layer decoder weights into the CFlow wrapper.
    path = os.path.join(filename)
    state = torch.load(path)
    model.Encoder_module.encoder.load_state_dict(state['encoder_state_dict'], strict=False)
    for decoder, decoder_state in zip(model.Decoder_module.decoders, state['decoder_state_dict']):
        decoder.load_state_dict(decoder_state, strict=False)
    print('Loading weights from {}'.format(filename))
    return model


# Define the encoder
class Encoder(torch.nn.Module):
    def __init__(self, encoder):
        super(Encoder, self).__init__()
        self.encoder = encoder

    def forward(self, input):
        return self.encoder(input)


# Define the decoder
class Decoder(torch.nn.Module):
    def __init__(self, c, decoders):
        super(Decoder, self).__init__()
        self.c = c
        self.decoders = decoders
        L = c.pool_layers
        params = list(self.decoders[0].parameters())
        for l in range(1, L):
            params += list(self.decoders[l].parameters())
        # optimizer
        self.optimizer = torch.optim.Adam(params, lr=self.c.lr)
        self.N = 256  # fiber batch size

    def forward(self, pool_layers):
        P = self.c.condition_vec
        # run the decoders in eval mode
        self.decoders = [decoder.eval() for decoder in self.decoders]
        height = list()
        width = list()
        i = 0
        test_dist = [list() for layer in pool_layers]
        test_loss = 0.0
        test_count = 0
        start = time.time()
        with torch.no_grad():
            for l, layer in enumerate(pool_layers):
                e = activation[layer]  # BxCxHxW
                B, C, H, W = e.size()
                S = H * W
                E = B * S
                if i == 0:  # get stats
                    height.append(H)
                    width.append(W)
                # positional encoding as the conditioning vector, repeated per batch element
                p = positionalencoding2d(P, H, W).to(self.c.device).unsqueeze(0).repeat(B, 1, 1, 1)
                c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P)  # BHWxP
                e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C)  # BHWxC

                decoder = self.decoders[l]
                FIB = E // self.N + int(E % self.N > 0)  # number of fiber batches
                for f in range(FIB):
                    if f < (FIB - 1):
                        idx = torch.arange(f * self.N, (f + 1) * self.N)
                    else:
                        idx = torch.arange(f * self.N, E)
                    c_p = c_r[idx]  # NxP
                    e_p = e_r[idx]  # NxC
                    # m_p = m_r[idx] > 0.5  # Nx1
                    z, log_jac_det = decoder(e_p, [c_p, ])
                    decoder_log_prob = get_logp(C, z, log_jac_det)
                    log_prob = decoder_log_prob / C  # likelihood per dim
                    loss = -log_theta(log_prob)
                    test_loss += t2np(loss.sum())
                    test_count += len(loss)
                    test_dist[l] = test_dist[l] + log_prob.detach().cpu().tolist()
        return height, width, test_dist


# Define the base class for the model
class CFlow(torch.nn.Module):
    def __init__(self, c, encoder, decoders, pool_layers):
        super(CFlow, self).__init__()
        self.pool_layers = pool_layers
        self.Encoder_module = Encoder(encoder)
        self.Decoder_module = Decoder(c, decoders)

    def forward(self, enc_input):
        # the encoder pass populates the shared `activation` store that the decoders read
        _ = self.Encoder_module(enc_input)
        height, width, test_dist = self.Decoder_module(self.pool_layers)
        return height, width, test_dist


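# A minimal sketch (not part of the original script) of how the per-layer
# log-likelihoods returned by CFlow.forward could be aggregated into a single
# anomaly map. The function name, the plain averaging across layers, and the
# bilinear upsampling choice are assumptions for illustration only.
def likelihoods_to_anomaly_map(height, width, test_dist, out_size=256):
    score_map = None
    for l, dist in enumerate(test_dist):
        # reshape the flat per-pixel log-likelihoods back to a BxHxW grid
        norm = torch.tensor(dist, dtype=torch.float).reshape(-1, height[l], width[l])
        # higher likelihood means more normal, so negate to get an anomaly score
        layer_score = -F.interpolate(norm.unsqueeze(1), size=out_size,
                                     mode='bilinear', align_corners=True).squeeze(1)
        score_map = layer_score if score_map is None else score_map + layer_score
    return score_map / len(test_dist)

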
def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Main
def main(c):
    # model
    c.gpu = '0'
    c.enc_arch = 'mobilenet_v3_large'
    c.inp = 256
    c.dataset = 'mvtec'
    c.action_type = 'norm-train'

    # image
    c.img_size = (c.input_size, c.input_size)  # HxW format
    c.crp_size = (c.input_size, c.input_size)  # HxW format
    if c.dataset == 'stc':
        c.norm_mean, c.norm_std = 3 * [0.5], 3 * [0.225]
    else:
        c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    c.img_dims = [3] + list(c.img_size)

    # network hyperparameters
    c.clamp_alpha = 1.9  # see paper equation 2 for explanation
    c.condition_vec = 128
    c.dropout = 0.0  # dropout in s-t-networks

    # dataloader parameters
    if c.dataset == 'mvtec':
        c.data_path = './data/MVTec-AD'
    elif c.dataset == 'stc':
        c.data_path = './data/STC/shanghaitech'
    elif c.dataset == 'video':
        c.data_path = c.video_path
    elif c.dataset == 'image':
        c.data_path = c.image_path
    else:
        raise NotImplementedError('{} is not a supported dataset!'.format(c.dataset))

    # output settings
    c.verbose = True
    c.hide_tqdm_bar = True
    c.save_results = True

    # unsup-train
    c.print_freq = 2
    c.temp = 0.5
    c.lr_decay_epochs = [i * c.meta_epochs // 100 for i in [50, 75, 90]]
    print('LR schedule: {}'.format(c.lr_decay_epochs))
    c.lr_decay_rate = 0.1
    c.lr_warm_epochs = 2
    c.lr_warm = True
    c.lr_cosine = True
    if c.lr_warm:
        c.lr_warmup_from = c.lr / 10.0
        if c.lr_cosine:
            eta_min = c.lr * (c.lr_decay_rate ** 3)
            c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
                1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
        else:
            c.lr_warmup_to = c.lr

    ########
    os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
    c.use_cuda = not c.no_cuda and torch.cuda.is_available()
    init_seeds(seed=int(time.time()))
    c.device = torch.device("cuda" if c.use_cuda else "cpu")

    # Create the encoder and decoder networks
    L = c.pool_layers
    encoder, pool_layers, pool_dims = load_encoder_arch(c, L)
    encoder = encoder.to(c.device).eval()
    decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]
    decoders = [decoder.to(c.device) for decoder in decoders]

    # Initialize the base class
    model = CFlow(c, encoder, decoders, pool_layers)

    PATH = 'weights/mvtec_mobilenet_v3_large_freia-cflow_pl3_cb8_inp256_run0_Model_2022-11-08-10:50:39.pt'
    model = load_weights2(model, PATH)
    print("Loaded!")
    model.eval()

    # Export the model to ONNX with a dummy input
    batch_size = 1
    x = torch.randn(batch_size, 3, 256, 256).to(c.device)
    out = model(x)  # sanity forward pass before export
    torch.onnx.export(
        model,
        x,
        "custom-d.onnx",
        export_params=True,
        verbose=True,
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )

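    # Optional sanity check (a minimal sketch, not part of the original script):
    # it assumes the `onnxruntime` package is installed and simply runs the
    # exported graph on the same dummy input.
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession("custom-d.onnx", providers=["CPUExecutionProvider"])
        onnx_out = sess.run(None, {"input": x.cpu().numpy()})
        print("ONNX session ran, {} output tensor(s)".format(len(onnx_out)))
    except ImportError:
        print("onnxruntime not installed; skipping the exported-model check")
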
if __name__ == '__main__':
    c = get_args()
    main(c)