Merge pull request #33 from Tekno-H/master

Export model to ONNX format
pull/38/head
Denis 2022-12-20 10:51:11 -08:00 committed by GitHub
commit 0d8c20251e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 408 additions and 0 deletions

196
cflow-to-onnx.py 100644
View File

@ -0,0 +1,196 @@
import os, time
import random, math
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from skimage.measure import label, regionprops
from tqdm import tqdm
from visualize import *
from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
from utils import *
from custom_datasets import *
from custom_models import *
from config import get_args
from config import get_args
from train import train
from test import test
from test_single import test_single
# Element-wise log-sigmoid; applied to per-dimension log-likelihoods in forward()
log_theta = torch.nn.LogSigmoid()
# Usage: model = CFlow(c); model = load_weights2(model, checkpoint_path)
def load_weights2(model, filename):
    """Load encoder/decoder weights from a CFlow checkpoint into ``model``.

    The checkpoint is a dict with keys ``'encoder_state_dict'`` (a single
    state dict) and ``'decoder_state_dict'`` (a list of state dicts, one per
    pooling-layer decoder, in the same order as ``model.decoders``).

    Returns the same ``model`` instance for chaining.
    """
    path = os.path.join(filename)
    # map_location='cpu' lets GPU-trained checkpoints load on CPU-only hosts;
    # load_state_dict then copies the values onto each module's own device.
    state = torch.load(path, map_location='cpu')
    # strict=False preserved from the original. NOTE(review): this silently
    # ignores missing/unexpected keys — confirm the checkpoint truly matches.
    model.encoder.load_state_dict(state['encoder_state_dict'], strict=False)
    # Plain loop instead of the original side-effecting list comprehension
    # (which also shadowed the outer `state` variable).
    for decoder, decoder_state in zip(model.decoders, state['decoder_state_dict']):
        decoder.load_state_dict(decoder_state, strict=False)
    print('Loading weights from {}'.format(filename))
    return model
class CFlow(torch.nn.Module):
    """CFlow anomaly-detection model wrapper used for ONNX export.

    Builds the (frozen) feature encoder plus one normalizing-flow decoder per
    pooling layer. ``forward`` runs the encoder (which fills the module-global
    ``activation`` dict via forward hooks), then scores every spatial "fiber"
    of each pooled feature map with its decoder.

    Returns (heights, widths, per-layer log-likelihood lists).
    """

    def __init__(self, c):
        super(CFlow, self).__init__()
        self.c = c  # keep the config; forward() needs device/condition_vec/dec_arch
        L = c.pool_layers
        self.encoder, self.pool_layers, self.pool_dims = load_encoder_arch(c, L)
        self.encoder = self.encoder.to(c.device).eval()
        # ModuleList (not a plain Python list) so that .to()/.eval()/state_dict()
        # actually see the decoders as sub-modules.
        self.decoders = torch.nn.ModuleList(
            load_decoder_arch(c, pool_dim).to(c.device) for pool_dim in self.pool_dims
        )
        params = []
        for decoder in self.decoders:
            params += list(decoder.parameters())
        # optimizer (unused during export; kept for interface compatibility)
        self.optimizer = torch.optim.Adam(params, lr=c.lr)
        self.N = 256  # fiber (spatial position) batch size

    def forward(self, x):
        c = self.c  # was a module-global lookup in the original — use the stored config
        P = c.condition_vec
        for decoder in self.decoders:
            decoder.eval()
        height = []
        width = []
        test_dist = [[] for _ in self.pool_layers]
        test_loss = 0.0
        test_count = 0
        _ = self.encoder(x)  # populates `activation` via the encoder's forward hooks
        with torch.no_grad():
            for l, layer in enumerate(self.pool_layers):
                e = activation[layer]  # BxCxHxW pooled feature map
                B, C, H, W = e.size()
                S = H * W
                E = B * S
                # record per-layer spatial stats (original's `if i == 0` was always true)
                height.append(H)
                width.append(W)
                # conditional positional encoding, one P-vector per spatial position
                p = positionalencoding2d(P, H, W).to(c.device).unsqueeze(0).repeat(B, 1, 1, 1)
                c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P)  # BHWxP
                e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C)  # BHWxC
                decoder = self.decoders[l]
                FIB = E // self.N + int(E % self.N > 0)  # number of fiber batches
                for f in range(FIB):
                    if f < (FIB - 1):
                        idx = torch.arange(f * self.N, (f + 1) * self.N)
                    else:
                        idx = torch.arange(f * self.N, E)  # last (possibly short) batch
                    c_p = c_r[idx]  # NxP condition vectors
                    e_p = e_r[idx]  # NxC feature fibers
                    if 'cflow' in c.dec_arch:
                        z, log_jac_det = decoder(e_p, [c_p, ])  # conditional flow
                    else:
                        z, log_jac_det = decoder(e_p)
                    decoder_log_prob = get_logp(C, z, log_jac_det)
                    log_prob = decoder_log_prob / C  # likelihood per dim
                    loss = -log_theta(log_prob)
                    test_loss += t2np(loss.sum())
                    test_count += len(loss)
                    test_dist[l] = test_dist[l] + log_prob.detach().cpu().tolist()
        return height, width, test_dist
def init_seeds(seed=0):
    """Seed the python, numpy and torch RNGs (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def main(c):
    """Build a CFlow model, load pretrained weights, and export it to ONNX.

    ``c`` is the argparse/config namespace from get_args(); this function
    overrides several of its fields with hard-coded export settings, then
    traces the model on a dummy 1x3x256x256 input and writes "custom.onnx".
    """
    # model (hard-coded overrides of the CLI arguments)
    c.gpu='0'
    c.enc_arch='mobilenet_v3_large'
    c.inp=256  # NOTE(review): c.img_size below reads c.input_size, not c.inp — confirm which field the config uses
    c.dataset='mvtec'
    c.action_type='norm-train'
    # image
    c.img_size = (c.input_size, c.input_size)  # HxW format
    c.crp_size = (c.input_size, c.input_size)  # HxW format
    if c.dataset == 'stc':
        c.norm_mean, c.norm_std = 3 * [0.5], 3 * [0.225]
    else:
        # ImageNet normalization statistics
        c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    #
    c.img_dims = [3] + list(c.img_size)
    # network hyperparameters
    c.clamp_alpha = 1.9  # see paper equation 2 for explanation
    c.condition_vec = 128
    c.dropout = 0.0  # dropout in s-t-networks
    # dataloader parameters (paths are unused during export but kept for config completeness)
    if c.dataset == 'mvtec':
        c.data_path = './data/MVTec-AD'
    elif c.dataset == 'stc':
        c.data_path = './data/STC/shanghaitech'
    elif c.dataset == 'video':
        c.data_path = c.video_path
    elif c.dataset == 'image':
        c.data_path = c.image_path
    else:
        raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
    # output settings
    c.verbose = True
    c.hide_tqdm_bar = True
    c.save_results = True
    # unsup-train: LR schedule settings (inherited from the training script)
    c.print_freq = 2
    c.temp = 0.5
    c.lr_decay_epochs = [i * c.meta_epochs // 100 for i in [50, 75, 90]]
    print('LR schedule: {}'.format(c.lr_decay_epochs))
    c.lr_decay_rate = 0.1
    c.lr_warm_epochs = 2
    c.lr_warm = True
    c.lr_cosine = True
    if c.lr_warm:
        c.lr_warmup_from = c.lr / 10.0
        if c.lr_cosine:
            # cosine warmup target (same formula as the training schedule)
            eta_min = c.lr * (c.lr_decay_rate ** 3)
            c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
                1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
        else:
            c.lr_warmup_to = c.lr
    ########
    os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
    c.use_cuda = not c.no_cuda and torch.cuda.is_available()
    init_seeds(seed=int(time.time()))
    c.device = torch.device("cuda" if c.use_cuda else "cpu")
    # selected function: build the model and load the pretrained checkpoint
    model = CFlow(c)
    print("Created !!!")
    PATH = 'weights/mvtec_mobilenet_v3_large_freia-cflow_pl3_cb8_inp256_run0_Model_2022-11-08-10:50:39.pt'
    model=load_weights2(model,PATH)
    print("Loaded !")
    model.eval()
    batch_size = 1
    # dummy input matching the export resolution (1x3x256x256)
    x = torch.randn(batch_size, 3, 256, 256).to(c.device)
    out=model(x)  # warm-up / sanity run before tracing
    print("Starting Export !!!")
    torch.onnx.export(
        model,
        x,
        "custom.onnx",
        export_params=True,  # embed the learned weights in the graph
        verbose=True,
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )
if __name__ == '__main__':
    c = get_args()
    main(c)

212
cflow-to-onnx2.py 100644
View File

@ -0,0 +1,212 @@
import os, time
import random, math
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve
from skimage.measure import label, regionprops
from tqdm import tqdm
from visualize import *
from model import load_decoder_arch, load_encoder_arch, positionalencoding2d, activation
from utils import *
from custom_datasets import *
from custom_models import *
from config import get_args
from config import get_args
from train import train
from test import test
from test_single import test_single
# Element-wise log-sigmoid; applied to per-dimension log-likelihoods in Decoder.forward()
log_theta = torch.nn.LogSigmoid()
def load_weights2(model, filename):
    """Load a CFlow checkpoint into a composed Encoder/Decoder ``model``.

    The checkpoint is a dict with keys ``'encoder_state_dict'`` (a single
    state dict) and ``'decoder_state_dict'`` (a list of state dicts, one per
    pool-layer decoder, matching ``model.Decoder_module.decoders`` order).

    Returns the same ``model`` instance for chaining.
    """
    path = os.path.join(filename)
    # map_location='cpu' lets GPU-trained checkpoints load on CPU-only hosts;
    # load_state_dict then copies the values onto each module's own device.
    state = torch.load(path, map_location='cpu')
    # strict=False preserved from the original. NOTE(review): this silently
    # ignores missing/unexpected keys — confirm the checkpoint truly matches.
    model.Encoder_module.encoder.load_state_dict(state['encoder_state_dict'], strict=False)
    # Plain loop instead of the original side-effecting list comprehension
    # (which also shadowed the outer `state` variable).
    for decoder, decoder_state in zip(model.Decoder_module.decoders, state['decoder_state_dict']):
        decoder.load_state_dict(decoder_state, strict=False)
    print('Loading weights from {}'.format(filename))
    return model
# Encoder stage of the split model
class Encoder(torch.nn.Module):
    """Thin wrapper exposing a pre-built feature extractor as a sub-module."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input):
        # Delegate straight to the wrapped network.
        return self.encoder(input)
# Decoder stage: scores the encoder's pooled activations with normalizing flows
class Decoder(torch.nn.Module):
    """Runs the per-layer normalizing-flow decoders over pooled activations.

    The feature maps are NOT taken from this module's input: forward()
    reads them from the module-global ``activation`` dict, which the encoder
    fills via forward hooks when it runs. ``forward`` therefore only receives
    the list of pool-layer names to look up.
    """
    def __init__(self, c, decoders):
        super(Decoder, self).__init__()
        self.c = c
        # NOTE(review): a plain Python list is not registered as sub-modules,
        # so model.eval()/.to()/state_dict() will not reach these decoders.
        self.decoders = decoders
        L = c.pool_layers
        params = list(self.decoders[0].parameters())
        for l in range(1, L):
            params += list(self.decoders[l].parameters())
        # optimizer (unused for export; carried over from the training code)
        self.optimizer = torch.optim.Adam(params, lr=self.c.lr)
        # fiber (spatial position) batch size
        self.N = 256
    def forward(self, pool_layers):
        """Return (heights, widths, per-layer log-likelihood lists)."""
        P = self.c.condition_vec
        # print(self.decoders)
        self.decoders = [decoder.eval() for decoder in self.decoders]
        height = list()
        width = list()
        # `i` is never incremented, so the `if i == 0` below is always true and
        # height/width are appended once per layer.
        i = 0
        test_dist = [list() for layer in pool_layers]
        test_loss = 0.0
        test_count = 0
        start = time.time()
        with torch.no_grad():
            for l, layer in enumerate(pool_layers):
                e = activation[layer]  # BxCxHxW pooled feature map (set by encoder hooks)
                #
                B, C, H, W = e.size()
                S = H * W
                E = B * S
                #
                if i == 0:  # get stats
                    height.append(H)
                    width.append(W)
                # conditional positional encoding, one P-vector per spatial position
                p = positionalencoding2d(P, H, W).to(self.c.device).unsqueeze(0).repeat(B, 1, 1, 1)
                c_r = p.reshape(B, P, S).transpose(1, 2).reshape(E, P)  # BHWxP
                e_r = e.reshape(B, C, S).transpose(1, 2).reshape(E, C)  # BHWxC
                decoder = self.decoders[l]
                FIB = E // self.N + int(E % self.N > 0)  # number of fiber batches
                for f in range(FIB):
                    if f < (FIB - 1):
                        idx = torch.arange(f * self.N, (f + 1) * self.N)
                    else:
                        # last (possibly short) batch
                        idx = torch.arange(f * self.N, E)
                    #
                    c_p = c_r[idx]  # NxP condition vectors
                    e_p = e_r[idx]  # NxC feature fibers
                    # m_p = m_r[idx] > 0.5 # Nx1
                    #
                    z, log_jac_det = decoder(e_p, [c_p, ])
                    #
                    decoder_log_prob = get_logp(C, z, log_jac_det)
                    log_prob = decoder_log_prob / C  # likelihood per dim
                    loss = -log_theta(log_prob)
                    test_loss += t2np(loss.sum())
                    test_count += len(loss)
                    test_dist[l] = test_dist[l] + log_prob.detach().cpu().tolist()
        return height, width, test_dist
# Base class composing the encoder and decoder stages for export
class CFlow(torch.nn.Module):
    """Full CFlow model: encoder pass followed by flow-decoder scoring.

    Returns the (heights, widths, per-layer log-likelihoods) tuple produced
    by the decoder stage.
    """

    def __init__(self, c, encoder, decoders, pool_layers):
        super().__init__()
        self.pool_layers = pool_layers
        self.Encoder_module = Encoder(encoder)
        self.Decoder_module = Decoder(c, decoders)

    def forward(self, enc_input):
        # The encoder's return value is unused; running it fills the global
        # `activation` dict that the decoder stage reads from.
        _ = self.Encoder_module(enc_input)
        return self.Decoder_module(self.pool_layers)
def init_seeds(seed=0):
    """Seed the python, numpy and torch RNGs (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
# Main entry point
def main(c):
    """Build the split Encoder/Decoder CFlow model, load weights, export ONNX.

    ``c`` is the argparse/config namespace from get_args(); this function
    overrides several of its fields with hard-coded export settings, then
    traces the model on a dummy 1x3x256x256 input and writes "custom-d.onnx".
    """
    # model (hard-coded overrides of the CLI arguments)
    c.gpu = '0'
    c.enc_arch = 'mobilenet_v3_large'
    c.inp = 256  # NOTE(review): c.img_size below reads c.input_size, not c.inp — confirm which field the config uses
    c.dataset = 'mvtec'
    c.action_type = 'norm-train'
    # image
    c.img_size = (c.input_size, c.input_size)  # HxW format
    c.crp_size = (c.input_size, c.input_size)  # HxW format
    if c.dataset == 'stc':
        c.norm_mean, c.norm_std = 3 * [0.5], 3 * [0.225]
    else:
        # ImageNet normalization statistics
        c.norm_mean, c.norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    #
    c.img_dims = [3] + list(c.img_size)
    # network hyperparameters
    c.clamp_alpha = 1.9  # see paper equation 2 for explanation
    c.condition_vec = 128
    c.dropout = 0.0  # dropout in s-t-networks
    # dataloader parameters (paths are unused during export but kept for config completeness)
    if c.dataset == 'mvtec':
        c.data_path = './data/MVTec-AD'
    elif c.dataset == 'stc':
        c.data_path = './data/STC/shanghaitech'
    elif c.dataset == 'video':
        c.data_path = c.video_path
    elif c.dataset == 'image':
        c.data_path = c.image_path
    else:
        raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
    # output settings
    c.verbose = True
    c.hide_tqdm_bar = True
    c.save_results = True
    # unsup-train: LR schedule settings (inherited from the training script)
    c.print_freq = 2
    c.temp = 0.5
    c.lr_decay_epochs = [i * c.meta_epochs // 100 for i in [50, 75, 90]]
    print('LR schedule: {}'.format(c.lr_decay_epochs))
    c.lr_decay_rate = 0.1
    c.lr_warm_epochs = 2
    c.lr_warm = True
    c.lr_cosine = True
    if c.lr_warm:
        c.lr_warmup_from = c.lr / 10.0
        if c.lr_cosine:
            # cosine warmup target (same formula as the training schedule)
            eta_min = c.lr * (c.lr_decay_rate ** 3)
            c.lr_warmup_to = eta_min + (c.lr - eta_min) * (
                1 + math.cos(math.pi * c.lr_warm_epochs / c.meta_epochs)) / 2
        else:
            c.lr_warmup_to = c.lr
    ########
    os.environ['CUDA_VISIBLE_DEVICES'] = c.gpu
    c.use_cuda = not c.no_cuda and torch.cuda.is_available()
    init_seeds(seed=int(time.time()))
    c.device = torch.device("cuda" if c.use_cuda else "cpu")
    # Create the encoder and decoder networks
    L = c.pool_layers
    encoder, pool_layers, pool_dims = load_encoder_arch(c, L)
    encoder = encoder.to(c.device).eval()
    decoders = [load_decoder_arch(c, pool_dim) for pool_dim in pool_dims]
    decoders = [decoder.to(c.device) for decoder in decoders]
    # Initialize the composed model and load the pretrained checkpoint
    model=CFlow(c,encoder,decoders,pool_layers)
    PATH = 'weights/mvtec_mobilenet_v3_large_freia-cflow_pl3_cb8_inp256_run0_Model_2022-11-08-10:50:39.pt'
    model = load_weights2(model, PATH)
    print("Loaded !")
    model.eval()
    batch_size = 1
    # dummy input matching the export resolution (1x3x256x256)
    x = torch.randn(batch_size, 3, 256, 256).to(c.device)
    out = model(x)  # warm-up / sanity run before tracing
    torch.onnx.export(
        model,  # model being traced
        x,
        "custom-d.onnx",
        export_params=True,  # embed the learned weights in the graph
        verbose=True,
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )
if __name__ == '__main__':
    c = get_args()
    main(c)