[projects] add NAS code

pull/265/head
kaiyangzhou 2019-11-08 13:00:39 +00:00
parent 240f5f6e7b
commit 727de0728f
7 changed files with 1433 additions and 1 deletion

View File: README.md

@@ -1 +1,36 @@
Architecture search code for OSNet-AIN. Coming soon ...
# Differentiable NAS for OSNet-AIN
## Introduction
This repository contains the neural architecture search (NAS) code (based on [Torchreid](https://arxiv.org/abs/1910.10093)) for [OSNet-AIN](https://arxiv.org/abs/1910.06827), an extension of [OSNet](https://arxiv.org/abs/1905.00953) that achieves strong performance on cross-domain person re-identification (re-ID) benchmarks *without using any target data*. OSNet-AIN builds on the idea of using [instance normalisation](https://arxiv.org/abs/1607.08022) (IN) layers to eliminate instance-specific contrast and thus learn domain-generalisable representations, an idea inspired by [style transfer](https://arxiv.org/abs/1703.06868) works that use IN to remove image styles. However, for a particular computer vision task (person re-ID in our case), it remains unclear where IN layers should be inserted into a CNN to maximise the performance gain. To address this problem, OSNet-AIN learns the optimal OSNet+IN design from data using a differentiable NAS algorithm. For technical details, please refer to our paper at https://arxiv.org/abs/1910.06827.
<div align="center">
<img src="https://drive.google.com/uc?export=view&id=1EyO1cD8wikh86YLOR3CFdaGKmoM0-3iX" width="500px" />
</div>
## Training
Assuming the re-ID data is stored under `$DATA`, run
```
python main.py --config-file nas.yaml --root $DATA
```
The default config was designed for 8 Tesla V100 32GB GPUs. You can modify the batch size based on your device memory.
**Note** that the test result obtained at the end of the search is not meaningful, as it comes from the search model rather than the final architecture. Do not rely on it to judge model performance. Instead, construct the found architecture in `osnet_child.py` and evaluate the model on the re-ID datasets.
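To evaluate the found architecture, a minimal sketch might look as follows (the checkpoint path and the number of identities are placeholders; `load_pretrained_weights` is imported from `torchreid.utils`):
```
import osnet_child
from torchreid.utils import load_pretrained_weights

# build the searched child network and load trained weights for evaluation
model = osnet_child.build_model('osnet_ain_x1_0', num_classes=751)  # e.g. 751 train identities in Market-1501
load_pretrained_weights(model, 'log/osnet_ain/model.pth.tar')  # hypothetical checkpoint path
model.eval()
```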
## Citation
If you find this code useful in your research, please consider citing the following papers.
```
@article{zhou2019learning,
title={Learning Generalisable Omni-Scale Representations for Person Re-Identification},
author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
journal={arXiv preprint arXiv:1910.06827},
year={2019}
}
@inproceedings{zhou2019osnet,
title={Omni-Scale Feature Learning for Person Re-Identification},
author={Zhou, Kaiyang and Yang, Yongxin and Cavallaro, Andrea and Xiang, Tao},
booktitle={ICCV},
year={2019}
}
```

View File: default_config.py

@@ -0,0 +1,209 @@
from yacs.config import CfgNode as CN
def get_default_config():
cfg = CN()
# model
cfg.model = CN()
cfg.model.name = 'resnet50'
cfg.model.pretrained = True # automatically load pretrained model weights if available
cfg.model.load_weights = '' # path to model weights
cfg.model.resume = '' # path to checkpoint for resume training
# NAS
cfg.nas = CN()
cfg.nas.mc_iter = 1 # Monte Carlo sampling
cfg.nas.init_lmda = 10. # initial lambda value
cfg.nas.min_lmda = 1. # minimum lambda value
cfg.nas.lmda_decay_step = 20 # decay step for lambda
cfg.nas.lmda_decay_rate = 0.5 # decay rate for lambda
cfg.nas.fixed_lmda = False # keep lambda unchanged
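# note: lmda is the Gumbel-softmax temperature used by the NAS engine; it is
# annealed as init_lmda * lmda_decay_rate ** (epoch // lmda_decay_step) and
# clipped at min_lmda (see ImageSoftmaxNASEngine.train in softmax_nas.py)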
# data
cfg.data = CN()
cfg.data.type = 'image'
cfg.data.root = 'reid-data'
cfg.data.sources = ['market1501']
cfg.data.targets = ['market1501']
cfg.data.workers = 4 # number of data loading workers
cfg.data.split_id = 0 # split index
cfg.data.height = 256 # image height
cfg.data.width = 128 # image width
cfg.data.combineall = False # combine train, query and gallery for training
cfg.data.transforms = ['random_flip'] # data augmentation
cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean
cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std
cfg.data.save_dir = 'log' # path to save log
# specific datasets
cfg.market1501 = CN()
cfg.market1501.use_500k_distractors = False # add 500k distractors to the gallery set for market1501
cfg.cuhk03 = CN()
cfg.cuhk03.labeled_images = False # use labeled images, if False, use detected images
cfg.cuhk03.classic_split = False # use classic split by Li et al. CVPR14
cfg.cuhk03.use_metric_cuhk03 = False # use cuhk03's metric for evaluation
# sampler
cfg.sampler = CN()
cfg.sampler.train_sampler = 'RandomSampler'
cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler
# video reid setting
cfg.video = CN()
cfg.video.seq_len = 15 # number of images to sample in a tracklet
cfg.video.sample_method = 'evenly' # how to sample images from a tracklet
cfg.video.pooling_method = 'avg' # how to pool features over a tracklet
# train
cfg.train = CN()
cfg.train.optim = 'adam'
cfg.train.lr = 0.0003
cfg.train.weight_decay = 5e-4
cfg.train.max_epoch = 60
cfg.train.start_epoch = 0
cfg.train.batch_size = 32
cfg.train.fixbase_epoch = 0 # number of epochs to fix base layers
cfg.train.open_layers = ['classifier'] # layers for training while keeping others frozen
cfg.train.staged_lr = False # set different lr to different layers
cfg.train.new_layers = ['classifier'] # newly added layers with default lr
cfg.train.base_lr_mult = 0.1 # learning rate multiplier for base layers
cfg.train.lr_scheduler = 'single_step'
cfg.train.stepsize = [20] # stepsize to decay learning rate
cfg.train.gamma = 0.1 # learning rate decay multiplier
cfg.train.print_freq = 20 # print frequency
cfg.train.seed = 1 # random seed
# optimizer
cfg.sgd = CN()
cfg.sgd.momentum = 0.9 # momentum factor for sgd and rmsprop
cfg.sgd.dampening = 0. # dampening for momentum
cfg.sgd.nesterov = False # Nesterov momentum
cfg.rmsprop = CN()
cfg.rmsprop.alpha = 0.99 # smoothing constant
cfg.adam = CN()
cfg.adam.beta1 = 0.9 # exponential decay rate for first moment
cfg.adam.beta2 = 0.999 # exponential decay rate for second moment
# loss
cfg.loss = CN()
cfg.loss.name = 'softmax'
cfg.loss.softmax = CN()
cfg.loss.softmax.label_smooth = True # use label smoothing regularizer
cfg.loss.triplet = CN()
cfg.loss.triplet.margin = 0.3 # distance margin
cfg.loss.triplet.weight_t = 1. # weight to balance hard triplet loss
cfg.loss.triplet.weight_x = 0. # weight to balance cross entropy loss
# test
cfg.test = CN()
cfg.test.batch_size = 100
cfg.test.dist_metric = 'euclidean' # distance metric, ['euclidean', 'cosine']
cfg.test.normalize_feature = False # normalize feature vectors before computing distance
cfg.test.ranks = [1, 5, 10, 20] # cmc ranks
cfg.test.evaluate = False # test only
cfg.test.eval_freq = -1 # evaluation frequency (-1 means to only test after training)
cfg.test.start_eval = 0 # start to evaluate after a specific epoch
cfg.test.rerank = False # use person re-ranking
cfg.test.visrank = False # visualize ranked results (only available when cfg.test.evaluate=True)
cfg.test.visrank_topk = 10 # top-k ranks to visualize
cfg.test.visactmap = False # visualize CNN activation maps
return cfg
def imagedata_kwargs(cfg):
return {
'root': cfg.data.root,
'sources': cfg.data.sources,
'targets': cfg.data.targets,
'height': cfg.data.height,
'width': cfg.data.width,
'transforms': cfg.data.transforms,
'norm_mean': cfg.data.norm_mean,
'norm_std': cfg.data.norm_std,
'use_gpu': cfg.use_gpu,
'split_id': cfg.data.split_id,
'combineall': cfg.data.combineall,
'batch_size_train': cfg.train.batch_size,
'batch_size_test': cfg.test.batch_size,
'workers': cfg.data.workers,
'num_instances': cfg.sampler.num_instances,
'train_sampler': cfg.sampler.train_sampler,
# image
'cuhk03_labeled': cfg.cuhk03.labeled_images,
'cuhk03_classic_split': cfg.cuhk03.classic_split,
'market1501_500k': cfg.market1501.use_500k_distractors,
}
def videodata_kwargs(cfg):
return {
'root': cfg.data.root,
'sources': cfg.data.sources,
'targets': cfg.data.targets,
'height': cfg.data.height,
'width': cfg.data.width,
'transforms': cfg.data.transforms,
'norm_mean': cfg.data.norm_mean,
'norm_std': cfg.data.norm_std,
'use_gpu': cfg.use_gpu,
'split_id': cfg.data.split_id,
'combineall': cfg.data.combineall,
'batch_size_train': cfg.train.batch_size,
'batch_size_test': cfg.test.batch_size,
'workers': cfg.data.workers,
'num_instances': cfg.sampler.num_instances,
'train_sampler': cfg.sampler.train_sampler,
# video
'seq_len': cfg.video.seq_len,
'sample_method': cfg.video.sample_method
}
def optimizer_kwargs(cfg):
return {
'optim': cfg.train.optim,
'lr': cfg.train.lr,
'weight_decay': cfg.train.weight_decay,
'momentum': cfg.sgd.momentum,
'sgd_dampening': cfg.sgd.dampening,
'sgd_nesterov': cfg.sgd.nesterov,
'rmsprop_alpha': cfg.rmsprop.alpha,
'adam_beta1': cfg.adam.beta1,
'adam_beta2': cfg.adam.beta2,
'staged_lr': cfg.train.staged_lr,
'new_layers': cfg.train.new_layers,
'base_lr_mult': cfg.train.base_lr_mult
}
def lr_scheduler_kwargs(cfg):
return {
'lr_scheduler': cfg.train.lr_scheduler,
'stepsize': cfg.train.stepsize,
'gamma': cfg.train.gamma,
'max_epoch': cfg.train.max_epoch
}
def engine_run_kwargs(cfg):
return {
'save_dir': cfg.data.save_dir,
'max_epoch': cfg.train.max_epoch,
'start_epoch': cfg.train.start_epoch,
'fixbase_epoch': cfg.train.fixbase_epoch,
'open_layers': cfg.train.open_layers,
'start_eval': cfg.test.start_eval,
'eval_freq': cfg.test.eval_freq,
'test_only': cfg.test.evaluate,
'print_freq': cfg.train.print_freq,
'dist_metric': cfg.test.dist_metric,
'normalize_feature': cfg.test.normalize_feature,
'visrank': cfg.test.visrank,
'visrank_topk': cfg.test.visrank_topk,
'use_metric_cuhk03': cfg.cuhk03.use_metric_cuhk03,
'ranks': cfg.test.ranks,
'rerank': cfg.test.rerank,
'visactmap': cfg.test.visactmap
}

View File: main.py

@@ -0,0 +1,109 @@
import sys
import os
import os.path as osp
import time
import argparse
import torch
import torch.nn as nn
from default_config import (
get_default_config, imagedata_kwargs, videodata_kwargs,
optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs
)
import torchreid
from torchreid.utils import (
Logger, set_random_seed, check_isfile, resume_from_checkpoint,
load_pretrained_weights, compute_model_complexity, collect_env_info
)
import osnet_search as osnet_models
from softmax_nas import ImageSoftmaxNASEngine
def reset_config(cfg, args):
if args.root:
cfg.data.root = args.root
if args.sources:
cfg.data.sources = args.sources
if args.targets:
cfg.data.targets = args.targets
if args.transforms:
cfg.data.transforms = args.transforms
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--config-file', type=str, default='', help='path to config file')
parser.add_argument('-s', '--sources', type=str, nargs='+', help='source datasets (delimited by space)')
parser.add_argument('-t', '--targets', type=str, nargs='+', help='target datasets (delimited by space)')
parser.add_argument('--transforms', type=str, nargs='+', help='data augmentation')
parser.add_argument('--root', type=str, default='', help='path to data root')
parser.add_argument('--gpu-devices', type=str, default='', help='comma-separated gpu device ids, e.g. "0,1" (sets CUDA_VISIBLE_DEVICES)')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, help='Modify config options using the command-line')
args = parser.parse_args()
cfg = get_default_config()
cfg.use_gpu = torch.cuda.is_available()
if args.config_file:
cfg.merge_from_file(args.config_file)
reset_config(cfg, args)
cfg.merge_from_list(args.opts)
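# config precedence (low to high): hard-coded defaults -> --config-file ->
# named CLI arguments (via reset_config) -> trailing 'opts' key-value pairs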
set_random_seed(cfg.train.seed)
if cfg.use_gpu and args.gpu_devices:
# if gpu_devices is not specified, all available gpus will be used
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
log_name = 'test.log' if cfg.test.evaluate else 'train.log'
log_name += time.strftime('-%Y-%m-%d-%H-%M-%S')
sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name))
print('Show configuration\n{}\n'.format(cfg))
print('Collecting env info ...')
print('** System info **\n{}\n'.format(collect_env_info()))
if cfg.use_gpu:
torch.backends.cudnn.benchmark = True
datamanager = torchreid.data.ImageDataManager(**imagedata_kwargs(cfg))
print('Building model: {}'.format(cfg.model.name))
model = osnet_models.build_model(cfg.model.name, num_classes=datamanager.num_train_pids)
num_params, flops = compute_model_complexity(model, (1, 3, cfg.data.height, cfg.data.width))
print('Model complexity: params={:,} flops={:,}'.format(num_params, flops))
if cfg.use_gpu:
model = nn.DataParallel(model).cuda()
optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg))
scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg))
if cfg.model.resume and check_isfile(cfg.model.resume):
cfg.train.start_epoch = resume_from_checkpoint(cfg.model.resume, model, optimizer=optimizer)
print('Building NAS engine')
engine = ImageSoftmaxNASEngine(
datamanager,
model,
optimizer,
scheduler=scheduler,
use_gpu=cfg.use_gpu,
label_smooth=cfg.loss.softmax.label_smooth,
mc_iter=cfg.nas.mc_iter,
init_lmda=cfg.nas.init_lmda,
min_lmda=cfg.nas.min_lmda,
lmda_decay_step=cfg.nas.lmda_decay_step,
lmda_decay_rate=cfg.nas.lmda_decay_rate,
fixed_lmda=cfg.nas.fixed_lmda
)
engine.run(**engine_run_kwargs(cfg))
print('*** Display the found architecture ***')
if cfg.use_gpu:
model.module.build_child_graph()
else:
model.build_child_graph()
if __name__ == '__main__':
main()

View File: nas.yaml

@@ -0,0 +1,44 @@
model:
name: 'osnet_nas'
pretrained: False
nas:
mc_iter: 1
init_lmda: 10.
min_lmda: 1.
lmda_decay_step: 20
lmda_decay_rate: 0.5
fixed_lmda: False
data:
type: 'image'
sources: ['msmt17']
targets: ['market1501']
height: 256
width: 128
combineall: True
transforms: ['random_flip', 'color_jitter']
save_dir: 'log/osnet_nas'
loss:
name: 'softmax'
softmax:
label_smooth: True
train:
optim: 'sgd'
lr: 0.1
max_epoch: 120
batch_size: 512
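# batch_size 512 assumes 8 Tesla V100 32GB GPUs (see README); reduce it to fit your GPU memory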
fixbase_epoch: 0
open_layers: ['classifier']
lr_scheduler: 'cosine'
test:
batch_size: 300
dist_metric: 'cosine'
normalize_feature: False
evaluate: False
eval_freq: -1
rerank: False
visactmap: False

View File: osnet_child.py

@@ -0,0 +1,447 @@
from __future__ import absolute_import
from __future__ import division
import torch
from torch import nn
from torch.nn import functional as F
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
padding=padding, bias=False, groups=groups)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=True)
else:
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0,
bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1, bn=True):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
self.bn = None
if bn:
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1,
bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
return self.relu(x)
class LightConvStream(nn.Module):
"""Lightweight convolution stream."""
def __init__(self, in_channels, out_channels, depth):
super(LightConvStream, self).__init__()
assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(depth)
layers = []
layers += [LightConv3x3(in_channels, out_channels)]
for i in range(depth-1):
layers += [LightConv3x3(out_channels, out_channels)]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(self, in_channels, num_gates=None, return_gates=False,
gate_activation='sigmoid', reduction=16, layer_norm=False):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1, bias=True, padding=0)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels//reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(in_channels//reduction, num_gates, kernel_size=1, bias=True, padding=0)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlock, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
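# a single ChannelGate is shared across all T streams (the unified aggregation gate)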
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
class OSBlockINv1(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv1, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
x3 = self.IN(x3) # IN inside residual
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
class OSBlockINv2(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv2, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
out = self.IN(out) # IN outside residual
return F.relu(out)
class OSBlockINv3(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv3, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN_in = nn.InstanceNorm2d(out_channels, affine=True)
self.IN_out = nn.InstanceNorm2d(out_channels, affine=True)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
x3 = self.IN_in(x3) # IN inside residual
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
out = self.IN_out(out) # IN outside residual
return F.relu(out)
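# The three INv* variants above differ only in where instance normalization is
# applied: v1 inside the residual branch (before the addition), v2 after the
# addition, and v3 in both places. Together with the plain OSBlock they form
# the candidate set searched over in osnet_search.py.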
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. arXiv preprint, 2019.
"""
def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax',
conv1_IN=True, **kwargs):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
self.loss = loss
self.feature_dim = feature_dim
# convolutional backbone
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=conv1_IN)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1])
self.pool2 = nn.Sequential(Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2))
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2])
self.pool3 = nn.Sequential(Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2))
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3])
self.conv5 = Conv1x1(channels[3], channels[3])
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# fully connected layer
self.fc = self._construct_fc_layer(self.feature_dim, channels[3], dropout_p=None)
# identity classification layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, blocks, layer, in_channels, out_channels):
layers = []
layers += [blocks[0](in_channels, out_channels)]
for i in range(1, len(blocks)):
layers += [blocks[i](out_channels, out_channels)]
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or fc_dims<0:
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.InstanceNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.pool3(x)
x = self.conv4(x)
return self.conv5(x)
def forward(self, x, return_featuremaps=False, **kwargs):
x = self.featuremaps(x)
if return_featuremaps:
return x
v = self.global_avgpool(x)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
##########
# Instantiation
##########
def osnet_ain_x1_0(num_classes=1000, pretrained=True, loss='softmax', **kwargs):
model = OSNet(
num_classes,
blocks=[
[OSBlockINv1, OSBlockINv1],
[OSBlock, OSBlockINv1],
[OSBlockINv1, OSBlock]
],
layers=[2, 2, 2],
channels=[64, 256, 384, 512],
loss=loss,
conv1_IN=True,
**kwargs
)
return model
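# the blocks list above hard-codes one searched choice (OSBlock vs OSBlockINv*)
# per position; edit it to construct a different architecture found by the search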
__models = {
'osnet_ain_x1_0': osnet_ain_x1_0
}
def build_model(name, num_classes=100):
avai_models = list(__models.keys())
if name not in avai_models:
raise KeyError('Unknown model: {}. Must be one of {}'.format(name, avai_models))
return __models[name](num_classes=num_classes)

View File: osnet_search.py

@@ -0,0 +1,477 @@
from __future__ import absolute_import
from __future__ import division
import torch
from torch import nn
from torch.nn import functional as F
EPS = 1e-12
NORM_AFFINE = False # whether normalization layers use learnable affine parameters
##########
# Basic layers
##########
class ConvLayer(nn.Module):
"""Convolution layer (conv + bn + relu)."""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, IN=False):
super(ConvLayer, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
padding=padding, bias=False, groups=groups)
if IN:
self.bn = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
else:
self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1(nn.Module):
"""1x1 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1, ibn=False):
super(Conv1x1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0,
bias=False, groups=groups)
if ibn:
self.bn = IBN(out_channels)
else:
self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class Conv1x1Linear(nn.Module):
"""1x1 convolution + bn (w/o non-linearity)."""
def __init__(self, in_channels, out_channels, stride=1, bn=True):
super(Conv1x1Linear, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 1, stride=stride, padding=0, bias=False)
self.bn = None
if bn:
self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
return x
class Conv3x3(nn.Module):
"""3x3 convolution + bn + relu."""
def __init__(self, in_channels, out_channels, stride=1, groups=1):
super(Conv3x3, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1,
bias=False, groups=groups)
self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return self.relu(x)
class LightConv3x3(nn.Module):
"""Lightweight 3x3 convolution.
1x1 (linear) + dw 3x3 (nonlinear).
"""
def __init__(self, in_channels, out_channels):
super(LightConv3x3, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, groups=out_channels)
self.bn = nn.BatchNorm2d(out_channels, affine=NORM_AFFINE)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.bn(x)
return self.relu(x)
class LightConvStream(nn.Module):
"""Lightweight convolution stream."""
def __init__(self, in_channels, out_channels, depth):
super(LightConvStream, self).__init__()
assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(depth)
layers = []
layers += [LightConv3x3(in_channels, out_channels)]
for i in range(depth-1):
layers += [LightConv3x3(out_channels, out_channels)]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
##########
# Building blocks for omni-scale feature learning
##########
class ChannelGate(nn.Module):
"""A mini-network that generates channel-wise gates conditioned on input tensor."""
def __init__(self, in_channels, num_gates=None, return_gates=False,
gate_activation='sigmoid', reduction=16, layer_norm=False):
super(ChannelGate, self).__init__()
if num_gates is None:
num_gates = in_channels
self.return_gates = return_gates
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(in_channels, in_channels//reduction, kernel_size=1, bias=True, padding=0)
self.norm1 = None
if layer_norm:
self.norm1 = nn.LayerNorm((in_channels//reduction, 1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(in_channels//reduction, num_gates, kernel_size=1, bias=True, padding=0)
if gate_activation == 'sigmoid':
self.gate_activation = nn.Sigmoid()
elif gate_activation == 'relu':
self.gate_activation = nn.ReLU(inplace=True)
elif gate_activation == 'linear':
self.gate_activation = None
else:
raise RuntimeError("Unknown gate activation: {}".format(gate_activation))
def forward(self, x):
input = x
x = self.global_avgpool(x)
x = self.fc1(x)
if self.norm1 is not None:
x = self.norm1(x)
x = self.relu(x)
x = self.fc2(x)
if self.gate_activation is not None:
x = self.gate_activation(x)
if self.return_gates:
return x
return input * x
class OSBlock(nn.Module):
"""Omni-scale feature learning block."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlock, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
class OSBlockINv1(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv1, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
x3 = self.IN(x3) # IN inside residual
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
return F.relu(out)
class OSBlockINv2(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv2, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
out = self.IN(out) # IN outside residual
return F.relu(out)
class OSBlockINv3(nn.Module):
"""Omni-scale feature learning block with instance normalization."""
def __init__(self, in_channels, out_channels, reduction=4, T=4, **kwargs):
super(OSBlockINv3, self).__init__()
assert T >= 1
assert out_channels>=reduction and out_channels%reduction==0
mid_channels = out_channels // reduction
self.conv1 = Conv1x1(in_channels, mid_channels)
self.conv2 = nn.ModuleList()
for t in range(1, T+1):
self.conv2 += [LightConvStream(mid_channels, mid_channels, t)]
self.gate = ChannelGate(mid_channels)
self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
self.downsample = None
if in_channels != out_channels:
self.downsample = Conv1x1Linear(in_channels, out_channels)
self.IN_in = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
self.IN_out = nn.InstanceNorm2d(out_channels, affine=NORM_AFFINE)
def forward(self, x):
identity = x
x1 = self.conv1(x)
x2 = 0
for conv2_t in self.conv2:
x2_t = conv2_t(x1)
x2 = x2 + self.gate(x2_t)
x3 = self.conv3(x2)
x3 = self.IN_in(x3) # inside residual
if self.downsample is not None:
identity = self.downsample(identity)
out = x3 + identity
out = self.IN_out(out) # IN outside residual
return F.relu(out)
class NASBlock(nn.Module):
"""Neural architecture search layer."""
def __init__(self, in_channels, out_channels, search_space=None):
super(NASBlock, self).__init__()
self._is_child_graph = False
self.search_space = search_space
if self.search_space is None:
raise ValueError('search_space is None')
self.os_block = nn.ModuleList()
for block in self.search_space:
self.os_block += [block(in_channels, out_channels)]
self.weights = nn.Parameter(torch.ones(len(self.search_space)))
def build_child_graph(self):
if self._is_child_graph:
raise RuntimeError('build_child_graph() can only be called once')
idx = self.weights.data.max(dim=0)[1].item()
self.os_block = self.os_block[idx]
self.weights = None
self._is_child_graph = True
return self.search_space[idx]
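# forward() below relaxes the discrete block choice with Gumbel-softmax sampling:
#   g_i = -log(-log(u_i)),  u_i ~ Uniform(0, 1)
#   p_i = exp((log w_i + g_i) / lmda) / sum_j exp((log w_j + g_j) / lmda)
# and returns sum_i p_i * block_i(x); small lmda pushes p towards one-hot
# (a discrete choice), large lmda gives a smooth mixture over candidate blocks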
def forward(self, x, lmda=1.):
if self._is_child_graph:
return self.os_block(x)
uniform = torch.rand_like(self.weights)
gumbel = - torch.log(- torch.log(uniform + EPS))
nonneg_weights = F.relu(self.weights)
logits = torch.log(nonneg_weights + EPS) + gumbel
exp = torch.exp(logits / lmda)
weights_softmax = exp / (exp.sum() + EPS)
output = 0
for i, weight in enumerate(weights_softmax):
output = output + weight * self.os_block[i](x)
return output
##########
# Network architecture
##########
class OSNet(nn.Module):
"""Omni-Scale Network.
Reference:
- Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
- Zhou et al. Learning Generalisable Omni-Scale Representations
for Person Re-Identification. arXiv preprint, 2019.
"""
def __init__(self, num_classes, blocks, layers, channels, feature_dim=512, loss='softmax',
search_space=None, **kwargs):
super(OSNet, self).__init__()
num_blocks = len(blocks)
assert num_blocks == len(layers)
assert num_blocks == len(channels) - 1
# no matter what loss is specified, the model only returns the ID predictions
self.loss = loss
self.feature_dim = feature_dim
# convolutional backbone
self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=True)
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
self.conv2 = self._make_layer(blocks[0], layers[0], channels[0], channels[1], search_space)
self.pool2 = nn.Sequential(Conv1x1(channels[1], channels[1]), nn.AvgPool2d(2, stride=2))
self.conv3 = self._make_layer(blocks[1], layers[1], channels[1], channels[2], search_space)
self.pool3 = nn.Sequential(Conv1x1(channels[2], channels[2]), nn.AvgPool2d(2, stride=2))
self.conv4 = self._make_layer(blocks[2], layers[2], channels[2], channels[3], search_space)
self.conv5 = Conv1x1(channels[3], channels[3])
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
# fully connected layer
self.fc = self._construct_fc_layer(self.feature_dim, channels[3], dropout_p=None)
# identity classification layer
self.classifier = nn.Linear(self.feature_dim, num_classes)
def _make_layer(self, block, layer, in_channels, out_channels, search_space):
layers = nn.ModuleList()
layers += [block(in_channels, out_channels, search_space=search_space)]
for i in range(1, layer):
layers += [block(out_channels, out_channels, search_space=search_space)]
return layers
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
if fc_dims is None or fc_dims<0:
self.feature_dim = input_dim
return None
if isinstance(fc_dims, int):
fc_dims = [fc_dims]
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim, affine=NORM_AFFINE))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def build_child_graph(self):
print('Building child graph')
for i, conv in enumerate(self.conv2):
block = conv.build_child_graph()
print('- conv2-{} Block={}'.format(i+1, block.__name__))
for i, conv in enumerate(self.conv3):
block = conv.build_child_graph()
print('- conv3-{} Block={}'.format(i+1, block.__name__))
for i, conv in enumerate(self.conv4):
block = conv.build_child_graph()
print('- conv4-{} Block={}'.format(i+1, block.__name__))
def featuremaps(self, x, lmda):
x = self.conv1(x)
x = self.maxpool(x)
for conv in self.conv2:
x = conv(x, lmda)
x = self.pool2(x)
for conv in self.conv3:
x = conv(x, lmda)
x = self.pool3(x)
for conv in self.conv4:
x = conv(x, lmda)
return self.conv5(x)
def forward(self, x, lmda=1., return_featuremaps=False):
# lmda (float): temperature parameter for concrete distribution
x = self.featuremaps(x, lmda)
if return_featuremaps:
return x
v = self.global_avgpool(x)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
return self.classifier(v)
##########
# Instantiation
##########
def osnet_nas(num_classes=1000, loss='softmax', **kwargs):
# standard size (width x1.0)
return OSNet(
num_classes,
blocks=[NASBlock, NASBlock, NASBlock],
layers=[2, 2, 2],
channels=[64, 256, 384, 512],
loss=loss,
search_space=[OSBlock, OSBlockINv1, OSBlockINv2, OSBlockINv3],
**kwargs
)
__NAS_models = {
'osnet_nas': osnet_nas
}
def build_model(name, num_classes=100):
avai_models = list(__NAS_models.keys())
if name not in avai_models:
raise KeyError('Unknown model: {}. Must be one of {}'.format(name, avai_models))
return __NAS_models[name](num_classes=num_classes)
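# usage sketch (assumed flow, mirroring main.py):
#   model = build_model('osnet_nas', num_classes=1000)
#   logits = model(images, lmda=10.)   # supernet forward with a soft block mixture
#   model.build_child_graph()          # afterwards, freeze the argmax block in every NASBlock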

View File: softmax_nas.py

@@ -0,0 +1,111 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import time
import datetime
import torch
import torchreid
from torchreid.engine import Engine
from torchreid.losses import CrossEntropyLoss
from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers
from torchreid import metrics
class ImageSoftmaxNASEngine(Engine):
def __init__(self, datamanager, model, optimizer, scheduler=None, use_gpu=False,
label_smooth=True, mc_iter=1, init_lmda=1., min_lmda=1., lmda_decay_step=20,
lmda_decay_rate=0.5, fixed_lmda=False):
super(ImageSoftmaxNASEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu)
self.mc_iter = mc_iter
self.init_lmda = init_lmda
self.min_lmda = min_lmda
self.lmda_decay_step = lmda_decay_step
self.lmda_decay_rate = lmda_decay_rate
self.fixed_lmda = fixed_lmda
self.criterion = CrossEntropyLoss(
num_classes=self.datamanager.num_train_pids,
use_gpu=self.use_gpu,
label_smooth=label_smooth
)
def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10):
losses = AverageMeter()
accs = AverageMeter()
batch_time = AverageMeter()
data_time = AverageMeter()
self.model.train()
if (epoch+1)<=fixbase_epoch and open_layers is not None:
print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch))
open_specified_layers(self.model, open_layers)
else:
open_all_layers(self.model)
num_batches = len(trainloader)
end = time.time()
for batch_idx, data in enumerate(trainloader):
data_time.update(time.time() - end)
imgs, pids = self._parse_data_for_train(data)
if self.use_gpu:
imgs = imgs.cuda()
pids = pids.cuda()
# anneal the softmax (Gumbel) temperature for the concrete distribution
if self.fixed_lmda or self.lmda_decay_step==-1:
lmda = self.init_lmda
else:
lmda = self.init_lmda * self.lmda_decay_rate ** (epoch // self.lmda_decay_step)
if lmda < self.min_lmda:
lmda = self.min_lmda
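# each Monte Carlo iteration re-draws the Gumbel noise in NASBlock.forward,
# so mc_iter > 1 performs several updates per batch, each on a different
# sampled architecture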
for k in range(self.mc_iter):
outputs = self.model(imgs, lmda=lmda)
loss = self._compute_loss(self.criterion, outputs, pids)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
batch_time.update(time.time() - end)
losses.update(loss.item(), pids.size(0))
accs.update(metrics.accuracy(outputs, pids)[0].item())
if (batch_idx+1) % print_freq == 0:
# estimate remaining time
eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches)
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
print('Epoch: [{0}/{1}][{2}/{3}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
'Lr {lr:.6f}\t'
'eta {eta}'.format(
epoch+1, max_epoch, batch_idx+1, len(trainloader),
batch_time=batch_time,
data_time=data_time,
loss=losses,
acc=accs,
lr=self.optimizer.param_groups[0]['lr'],
eta=eta_str
)
)
if self.writer is not None:
n_iter = epoch * num_batches + batch_idx
self.writer.add_scalar('Train/Time', batch_time.avg, n_iter)
self.writer.add_scalar('Train/Data', data_time.avg, n_iter)
self.writer.add_scalar('Train/Loss', losses.avg, n_iter)
self.writer.add_scalar('Train/Acc', accs.avg, n_iter)
self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter)
end = time.time()
if self.scheduler is not None:
self.scheduler.step()