DCL/utils/utils.py


import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

class LossRecord(object):
    """Keeps a running average of per-sample loss values."""
    def __init__(self, batch_size):
        self.rec_loss = 0
        self.count = 0
        self.batch_size = batch_size

    def update(self, loss):
        if isinstance(loss, list):
            # A list of batch losses: average them down to a per-sample value.
            avg_loss = sum(loss)
            avg_loss /= (len(loss) * self.batch_size)
            self.rec_loss += avg_loss
            self.count += 1
        elif isinstance(loss, float):
            self.rec_loss += loss / self.batch_size
            self.count += 1

    def get_val(self, init=False):
        pop_loss = self.rec_loss / self.count
        if init:
            # Optionally reset the accumulator after reading.
            self.rec_loss = 0
            self.count = 0
        return pop_loss
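
# Illustrative usage of LossRecord (added sketch, not part of the original file):
#   rec = LossRecord(batch_size=16)
#   rec.update(0.73)                      # a single batch loss (float)
#   rec.update([0.51, 0.62, 0.47])        # or a list of batch losses
#   running_avg = rec.get_val(init=True)  # read the running mean, then reset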

def weights_normal_init(model, dev=0.01):
    """Initializes Conv2d/Linear weights from N(0, dev); accepts a module or a list of modules."""
    if isinstance(model, list):
        for m in model:
            weights_normal_init(m, dev)
    else:
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.weight.data.normal_(0.0, dev)

def clip_gradient(model, clip_norm):
    """Rescales all gradients so that their global L2 norm does not exceed clip_norm."""
    totalnorm = 0.0
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            modulenorm = p.grad.data.norm()
            totalnorm += modulenorm.item() ** 2
    totalnorm = totalnorm ** 0.5
    norm = clip_norm / max(totalnorm, clip_norm)
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            p.grad.mul_(norm)
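
# Illustrative usage of clip_gradient (added sketch): call it between backward() and
# the optimizer step, e.g.
#   loss.backward()
#   clip_gradient(model, clip_norm=10.0)
#   optimizer.step()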

def Linear(in_features, out_features, bias=True):
    """nn.Linear whose weight (and bias) are initialized from U(-0.1, 0.1)."""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.uniform_(-0.1, 0.1)
    if bias:
        m.bias.data.uniform_(-0.1, 0.1)
    return m
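
# Illustrative usage of Linear (added sketch): a drop-in replacement for nn.Linear
# whenever the uniform initialization above is wanted; dimensions are hypothetical:
#   proj = Linear(2048, 200)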

class convolution(nn.Module):
    """k x k convolution followed by (optional) BatchNorm and ReLU."""
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(convolution, self).__init__()
        pad = (k - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn)
        self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.conv(x)
        bn = self.bn(conv)
        relu = self.relu(bn)
        return relu

class fully_connected(nn.Module):
    """Linear layer followed by (optional) BatchNorm1d and ReLU."""
    def __init__(self, inp_dim, out_dim, with_bn=True):
        super(fully_connected, self).__init__()
        self.with_bn = with_bn
        self.linear = nn.Linear(inp_dim, out_dim)
        if self.with_bn:
            self.bn = nn.BatchNorm1d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        linear = self.linear(x)
        bn = self.bn(linear) if self.with_bn else linear
        relu = self.relu(bn)
        return relu

class residual(nn.Module):
    """Basic 3x3 residual block with a 1x1 projection skip when the shape changes.
    Note: the k and with_bn arguments are accepted for interface consistency but unused."""
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(residual, self).__init__()
        self.conv1 = nn.Conv2d(inp_dim, out_dim, (3, 3), padding=(1, 1), stride=(stride, stride), bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_dim, out_dim, (3, 3), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)
        self.skip = nn.Sequential(
            nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
            nn.BatchNorm2d(out_dim)
        ) if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)
        conv2 = self.conv2(relu1)
        bn2 = self.bn2(conv2)
        skip = self.skip(x)
        return self.relu(bn2 + skip)
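
# Minimal smoke test (added sketch, not part of the original file): exercises the
# building blocks on random input; the expected output shapes are noted inline.
if __name__ == '__main__':
    x = torch.randn(2, 3, 32, 32)
    block = residual(3, 3, 16, stride=2)
    print(block(x).shape)                    # torch.Size([2, 16, 16, 16])
    conv = convolution(3, 3, 8)
    print(conv(x).shape)                     # torch.Size([2, 8, 32, 32])
    fc = fully_connected(16, 8)
    print(fc(torch.randn(4, 16)).shape)      # torch.Size([4, 8])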