linting code

parent 2a9f44af9b
commit 5ce43185b1

.flake8 | 18
@@ -1,4 +1,18 @@
 [flake8]
-ignore = E261, E501, W293
+ignore =
+    # At least two spaces before inline comment
+    E261,
+    # Line lengths are recommended to be no greater than 79 characters
+    E501,
+    # Missing whitespace around arithmetic operator
+    E226,
+    # Blank line contains whitespace
+    W293,
+    # Do not use bare 'except'
+    E722,
+    # Line break after binary operator
+    W504,
+    # isort found an import in the wrong position
+    I001
 max-line-length = 79
-exclude = __init__.py, build
+exclude = __init__.py, build, torchreid/metrics/rank_cylib/
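
The multiline `ignore =` form lets each suppressed code carry its own comment. As a rough illustration of what this config tolerates, here is a hypothetical snippet (not from the repository) that flake8 accepts under the list above, even though the bare except would normally be flagged as E722:

    # Hypothetical example: flake8 reports nothing here because
    # E722 (bare 'except') is on the ignore list above.
    def read_first_line(path):
        try:
            with open(path) as f:
                return f.readline()
        except:  # would be E722 without the ignore entry
            return None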

@@ -57,7 +57,7 @@ templates_path = ['_templates']
 # You can specify multiple suffix as a list of string:
 #
 source_suffix = ['.rst', '.md']
-#source_suffix = '.rst'
+# source_suffix = '.rst'
 source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser'}
 
 # The master toctree document.

linter.sh | 10

@@ -1,3 +1,11 @@
+echo "Running isort"
 isort -y -sp .
+echo "Done"
 
-yapf -i -r -vv . -e build
+echo "Running yapf"
+yapf -i -r -vv -e build .
+echo "Done"
+
+echo "Running flake8"
+flake8 .
+echo "Done"

@@ -1,10 +1,7 @@
 from __future__ import division, print_function, absolute_import
-import time
-import datetime
 import torch
 from torch.nn import functional as F
 
-from torchreid import metrics
 from torchreid.utils import open_all_layers, open_specified_layers
 from torchreid.engine import Engine
 from torchreid.losses import TripletLoss, CrossEntropyLoss

@@ -1,4 +1,3 @@
-import os
 import sys
 import copy
 import time

@@ -15,8 +14,8 @@ from torchreid.utils import (
 
 from dml import ImageDMLEngine
 from default_config import (
-    imagedata_kwargs, optimizer_kwargs, videodata_kwargs, engine_run_kwargs,
-    get_default_config, lr_scheduler_kwargs
+    imagedata_kwargs, optimizer_kwargs, engine_run_kwargs, get_default_config,
+    lr_scheduler_kwargs
 )
 
 

@@ -10,6 +10,24 @@ NORM_AFFINE = False # enable affine transformations for normalization layer
 ##########
 # Basic layers
 ##########
+class IBN(nn.Module):
+    """Instance + Batch Normalization."""
+
+    def __init__(self, num_channels):
+        super(IBN, self).__init__()
+        half1 = int(num_channels / 2)
+        self.half = half1
+        half2 = num_channels - half1
+        self.IN = nn.InstanceNorm2d(half1, affine=NORM_AFFINE)
+        self.BN = nn.BatchNorm2d(half2, affine=NORM_AFFINE)
+
+    def forward(self, x):
+        split = torch.split(x, self.half, 1)
+        out1 = self.IN(split[0].contiguous())
+        out2 = self.BN(split[1].contiguous())
+        return torch.cat((out1, out2), 1)
+
+
 class ConvLayer(nn.Module):
     """Convolution layer (conv + bn + relu)."""
 
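
As a usage note, the new IBN block normalizes the first half of the channels with instance normalization and the rest with batch normalization, so the channel count is unchanged. A minimal, self-contained sketch (the class body is copied from the hunk above; tensor sizes are illustrative):

    import torch
    import torch.nn as nn

    NORM_AFFINE = False  # as set at the top of the file

    class IBN(nn.Module):
        """Instance + Batch Normalization."""

        def __init__(self, num_channels):
            super(IBN, self).__init__()
            half1 = int(num_channels / 2)
            self.half = half1
            half2 = num_channels - half1
            self.IN = nn.InstanceNorm2d(half1, affine=NORM_AFFINE)
            self.BN = nn.BatchNorm2d(half2, affine=NORM_AFFINE)

        def forward(self, x):
            split = torch.split(x, self.half, 1)
            out1 = self.IN(split[0].contiguous())
            out2 = self.BN(split[1].contiguous())
            return torch.cat((out1, out2), 1)

    ibn = IBN(num_channels=64)
    x = torch.randn(8, 64, 32, 16)  # (batch, channels, height, width)
    print(ibn(x).shape)  # torch.Size([8, 64, 32, 16]): channels preserved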

@@ -1,11 +1,6 @@
 from __future__ import division, print_function, absolute_import
-import time
-import datetime
 
 from torchreid import metrics
-from torchreid.utils import (
-    AverageMeter, open_all_layers, open_specified_layers
-)
 from torchreid.engine import Engine
 from torchreid.losses import CrossEntropyLoss
 

@@ -58,7 +53,7 @@ class ImageSoftmaxNASEngine(Engine):
             lmda = self.init_lmda
         else:
             lmda = self.init_lmda * self.lmda_decay_rate**(
-                epoch // self.lmda_decay_step
+                self.epoch // self.lmda_decay_step
            )
         if lmda < self.min_lmda:
             lmda = self.min_lmda
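
The second hunk repairs a stale variable reference: the epoch counter is stored on the engine as self.epoch, so the bare name epoch would not resolve here. A standalone sketch of the step-decay schedule this code implements, with illustrative values standing in for the engine attributes:

    # Illustrative values; in the engine these are self.init_lmda,
    # self.lmda_decay_rate, self.lmda_decay_step and self.min_lmda.
    init_lmda, lmda_decay_rate, lmda_decay_step, min_lmda = 10.0, 0.5, 20, 0.01

    for epoch in (0, 20, 40, 60):
        lmda = init_lmda * lmda_decay_rate**(epoch // lmda_decay_step)
        lmda = max(lmda, min_lmda)  # clamp at the floor value
        print(epoch, lmda)  # 10.0, 5.0, 2.5, 1.25: halves every 20 epochs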

@@ -1,4 +1,3 @@
-import os
 import sys
 import time
 import os.path as osp

setup.py | 3

@@ -45,8 +45,7 @@ def get_requirements(filename='requirements.txt'):
 setup(
     name='torchreid',
     version=find_version(),
-    description=
-    'A library for deep learning person re-identification in PyTorch',
+    description='A library for deep learning person re-ID in PyTorch',
     author='Kaiyang Zhou',
     license='MIT',
     long_description=readme(),

@@ -70,7 +70,7 @@ class DataManager(object):
         """Returns the number of training cameras."""
         return self._num_train_cams
 
-    def return_query_and_gallery_by_name(self, name):
+    def fetch_qg(self, name):
         """Returns query and gallery of a test dataset, each containing
         tuples of (img_path(s), pid, camid).
 

@@ -238,9 +238,9 @@ class Dataset(object):
             ' gallery | {:5d} | {:7d} | {:9d}\n' \
             ' ----------------------------------------\n' \
             ' items: images/tracklets for image/video dataset\n'.format(
                 num_train_pids, len(self.train), num_train_cams,
                 num_query_pids, len(self.query), num_query_cams,
                 num_gallery_pids, len(self.gallery), num_gallery_cams
             )
 
         return msg

@@ -151,10 +151,8 @@ class CUHK03(ImageDataset):
                 img_paths = _process_images(
                     camp[pid, :], campid, pid, imgs_dir
                 )
-                assert len(img_paths
-                ) > 0, 'campid{}-pid{} has no images'.format(
-                    campid, pid
-                )
+                assert len(img_paths) > 0, \
+                    'campid{}-pid{} has no images'.format(campid, pid)
                 meta_data.append((campid + 1, pid + 1, img_paths))
                 print(
                     '- done camera pair {} with {} identities'.format(

@@ -61,7 +61,8 @@ class DukeMTMCreID(ImageDataset):
             pid, camid = map(int, pattern.search(img_path).groups())
             assert 1 <= camid <= 8
             camid -= 1 # index starts from 0
-            if relabel: pid = pid2label[pid]
+            if relabel:
+                pid = pid2label[pid]
             data.append((img_path, pid, camid))
 
         return data
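
Splitting `if relabel: pid = pid2label[pid]` onto two lines satisfies flake8's E701 (multiple statements on one line). For context, pid2label maps raw person IDs to contiguous training labels; a tiny illustration of the pattern, with made-up IDs:

    pids = [23, 7, 42, 7]  # raw person IDs as parsed from file names
    pid2label = {pid: label for label, pid in enumerate(sorted(set(pids)))}
    print([pid2label[pid] for pid in pids])  # [1, 0, 2, 0]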

@@ -3,7 +3,7 @@ import os.path as osp
 
 from ..dataset import ImageDataset
 
-##### Log #####
+# Log
 # 22.01.2019
 # - add v2
 # - v1 and v2 differ in dir names

@@ -83,10 +83,8 @@ class VIPeR(ImageDataset):
         np.random.shuffle(order)
         train_idxs = order[:num_train_pids]
         test_idxs = order[num_train_pids:]
-        assert not bool(
-            set(train_idxs)
-            & set(test_idxs)
-        ), 'Error: train and test overlap'
+        assert not bool(set(train_idxs) & set(test_idxs)), \
+            'Error: train and test overlap'
 
         train = []
         for pid, idx in enumerate(train_idxs):

@@ -98,7 +98,8 @@ class Mars(VideoDataset):
             if pid == -1:
                 continue # junk images are just ignored
             assert 1 <= camid <= 6
-            if relabel: pid = pid2label[pid]
+            if relabel:
+                pid = pid2label[pid]
             camid -= 1 # index starts from 0
             img_names = names[start_index - 1:end_index]
 

@@ -4,7 +4,9 @@ import random
 from collections import deque
 import torch
 from PIL import Image
-from torchvision.transforms import *
+from torchvision.transforms import (
+    Resize, Compose, ToTensor, Normalize, ColorJitter, RandomHorizontalFlip
+)
 
 
 class Random2DTranslation(object):

@@ -279,8 +281,13 @@ def build_transforms(
         transform_tr += [RandomHorizontalFlip()]
 
     if 'random_crop' in transforms:
-        print('+ random crop (enlarge to {}x{} and ' \
-              'crop {}x{})'.format(int(round(height*1.125)), int(round(width*1.125)), height, width))
+        print(
+            '+ random crop (enlarge to {}x{} and '
+            'crop {}x{})'.format(
+                int(round(height * 1.125)), int(round(width * 1.125)), height,
+                width
+            )
+        )
         transform_tr += [Random2DTranslation(height, width)]
 
     if 'random_patch' in transforms:
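
Replacing the star import with an explicit list makes each torchvision dependency visible and satisfies flake8's wildcard-import checks (F403/F401). A small sketch composing the now-explicit transforms; the image size and normalization statistics are illustrative, not taken from this diff:

    from PIL import Image
    from torchvision.transforms import Compose, Normalize, Resize, ToTensor

    transform = Compose([
        Resize((256, 128)),  # (height, width), a typical person re-ID shape
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = Image.new('RGB', (128, 256))  # dummy stand-in for a real image
    print(transform(img).shape)  # torch.Size([3, 256, 128])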

@@ -428,8 +428,7 @@ class Engine(object):
         if visrank:
             visualize_ranked_results(
                 distmat,
-                self.datamanager.
-                return_query_and_gallery_by_name(dataset_name),
+                self.datamanager.fetch_qg(dataset_name),
                 self.datamanager.data_type,
                 width=self.datamanager.width,
                 height=self.datamanager.height,

@@ -1,6 +1,4 @@
 from __future__ import division, print_function, absolute_import
-import time
-import datetime
 
 from torchreid import metrics
 from torchreid.losses import CrossEntropyLoss

@@ -1,6 +1,4 @@
 from __future__ import division, print_function, absolute_import
-import time
-import datetime
 
 from torchreid import metrics
 from torchreid.losses import TripletLoss, CrossEntropyLoss

@@ -57,8 +57,9 @@ def euclidean_squared_distance(input1, input2):
         torch.Tensor: distance matrix.
     """
     m, n = input1.size(0), input2.size(0)
-    distmat = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) + \
-        torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
+    mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
+    mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
+    distmat = mat1 + mat2
     distmat.addmm_(1, -2, input1, input2.t())
     return distmat
 
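
Splitting the squared-norm terms into mat1 and mat2 removes the backslash continuation while keeping the math unchanged: ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a^T*b, with the cross term folded in by addmm_. A self-contained sketch; note it uses the modern keyword form of addmm_, whereas the diff keeps the older positional form:

    import torch

    def euclidean_squared_distance(input1, input2):
        m, n = input1.size(0), input2.size(0)
        mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
        mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        distmat = mat1 + mat2
        # modern keyword equivalent of distmat.addmm_(1, -2, input1, input2.t())
        distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
        return distmat

    a = torch.tensor([[0., 0.], [1., 1.]])
    b = torch.tensor([[0., 1.]])
    print(euclidean_squared_distance(a, b))  # tensor([[1.], [1.]])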

@@ -300,7 +300,8 @@ class HACNN(nn.Module):
         theta = torch.zeros(theta_i.size(0), 2, 3)
         theta[:, :, :2] = scale_factors
         theta[:, :, -1] = theta_i
-        if self.use_gpu: theta = theta.cuda()
+        if self.use_gpu:
+            theta = theta.cuda()
         return theta
 
     def forward(self, x):

@@ -249,9 +249,9 @@ class Block8(nn.Module):
         return out
 
 
-##################### Model Definition #########################
+# ----------------
+# Model Definition
+# ----------------
 class InceptionResNetV2(nn.Module):
     """Inception-ResNet-V2.
 

@@ -260,7 +260,7 @@ def init_pretrained_weights(model, model_url):
 def mlfn(num_classes, loss='softmax', pretrained=True, **kwargs):
     model = MLFN(num_classes, loss, **kwargs)
     if pretrained:
-        #init_pretrained_weights(model, model_urls['imagenet'])
+        # init_pretrained_weights(model, model_urls['imagenet'])
         import warnings
         warnings.warn(
             'The imagenet pretrained weights need to be manually downloaded from {}'

@@ -246,7 +246,7 @@ def mobilenetv2_x1_0(num_classes, loss, pretrained=True, **kwargs):
         **kwargs
     )
     if pretrained:
-        #init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
+        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_0'])
         import warnings
         warnings.warn(
             'The imagenet pretrained weights need to be manually downloaded from {}'

@@ -265,7 +265,7 @@ def mobilenetv2_x1_4(num_classes, loss, pretrained=True, **kwargs):
         **kwargs
     )
     if pretrained:
-        #init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
+        # init_pretrained_weights(model, model_urls['mobilenetv2_x1_4'])
         import warnings
         warnings.warn(
             'The imagenet pretrained weights need to be manually downloaded from {}'

@@ -26,7 +26,7 @@ Code imported from https://github.com/Cadene/pretrained-models.pytorch
 pretrained_settings = {
     'nasnetamobile': {
         'imagenet': {
-            #'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar',
+            # 'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar',
             'url':
             'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth',
             'input_space': 'RGB',

@@ -45,7 +45,8 @@ class Bottleneck(nn.Module):
         assert stride in [1, 2], 'Warning: stride must be either 1 or 2'
         self.stride = stride
         mid_channels = out_channels // 4
-        if stride == 2: out_channels -= in_channels
+        if stride == 2:
+            out_channels -= in_channels
         # group conv is not applied to first conv1x1 at stage 2
         num_groups_conv1x1 = num_groups if group_conv1x1 else 1
         self.conv1 = nn.Conv2d(

@@ -71,7 +72,8 @@ class Bottleneck(nn.Module):
             mid_channels, out_channels, 1, groups=num_groups, bias=False
         )
         self.bn3 = nn.BatchNorm2d(out_channels)
-        if stride == 2: self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)
+        if stride == 2:
+            self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)
 
     def forward(self, x):
         out = F.relu(self.bn1(self.conv1(x)))

@@ -187,7 +189,7 @@ def init_pretrained_weights(model, model_url):
 def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs):
     model = ShuffleNet(num_classes, loss, **kwargs)
     if pretrained:
-        #init_pretrained_weights(model, model_urls['imagenet'])
+        # init_pretrained_weights(model, model_urls['imagenet'])
         import warnings
         warnings.warn(
             'The imagenet pretrained weights need to be manually downloaded from {}'

@@ -5,7 +5,6 @@ from __future__ import division, absolute_import
 import torch
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
-from torch.utils import model_zoo as model_zoo
 
 __all__ = ['squeezenet1_0', 'squeezenet1_1', 'squeezenet1_0_fc512']
 

@@ -32,7 +32,7 @@ Convolution
 def hook_convNd(m, x, y):
     k = torch.prod(torch.Tensor(m.kernel_size)).item()
     cin = m.in_channels
-    flops_per_ele = k * cin #+ (k*cin-1)
+    flops_per_ele = k * cin # + (k*cin-1)
     if m.bias is not None:
         flops_per_ele += 1
     flops = flops_per_ele * y.numel() / m.groups
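
hook_convNd estimates one multiply per kernel tap and input channel for every output element, adds one for the bias, then scales by the output size and divides by the group count. A runnable sketch of attaching such a hook to a layer; the layer and input sizes are illustrative:

    import torch
    import torch.nn as nn

    def hook_convNd(m, x, y):
        k = torch.prod(torch.Tensor(m.kernel_size)).item()
        cin = m.in_channels
        flops_per_ele = k * cin  # + (k*cin-1) if additions were counted too
        if m.bias is not None:
            flops_per_ele += 1
        flops = flops_per_ele * y.numel() / m.groups
        print(int(flops))

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    conv.register_forward_hook(hook_convNd)
    conv(torch.randn(1, 3, 16, 16))  # prints 57344 = (9*3 + 1) * 8*16*16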
@ -200,7 +200,7 @@ Linear
|
||||||
|
|
||||||
|
|
||||||
def hook_linear(m, x, y):
|
def hook_linear(m, x, y):
|
||||||
flops_per_ele = m.in_features #+ (m.in_features-1)
|
flops_per_ele = m.in_features # + (m.in_features-1)
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
flops_per_ele += 1
|
flops_per_ele += 1
|
||||||
flops = flops_per_ele * y.numel()
|
flops = flops_per_ele * y.numel()
|
||||||
|
|