mirror of https://github.com/JDAI-CV/fast-reid.git
v0.3 update (pull/365/head)
Summary: 1. switch DDP training to the Apex approach; 2. drive the warmup scheduler by iteration and the LR scheduler by epoch; 3. replace the custom random erasing with the torchvision implementation; 4. rename options in the config files.
parent 1b9799f601
commit a327a70f0d
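Point 2 is the core scheduling change: warmup advances every iteration while the main LR schedule advances once per epoch. A minimal sketch of that split using plain PyTorch schedulers (the stand-in model, the 200 iterations per epoch, and the skipped forward/backward are illustrative assumptions; fastreid's own scheduler classes are not used here):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

model = torch.nn.Linear(10, 2)                       # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=3.5e-4)

max_epoch, iters_per_epoch = 60, 200                 # assumed sizes
warmup_iters, warmup_factor, delay_epochs = 2000, 0.1, 30

# Warmup is stepped per iteration: ramp from warmup_factor * lr up to lr.
warmup_sched = LambdaLR(
    optimizer,
    lr_lambda=lambda it: warmup_factor + (1.0 - warmup_factor) * min(it, warmup_iters) / warmup_iters,
)
# The main schedule is stepped once per epoch, after warmup and the delay are over.
lr_sched = CosineAnnealingLR(optimizer, T_max=max_epoch - delay_epochs, eta_min=7.7e-7)

for epoch in range(max_epoch):
    for it in range(iters_per_epoch):
        optimizer.step()                             # forward/backward omitted in this sketch
        if epoch * iters_per_epoch + it < warmup_iters:
            warmup_sched.step()                      # per-iteration warmup
    if (epoch + 1) * iters_per_epoch >= warmup_iters and epoch + 1 >= delay_epochs:
        lr_sched.step()                              # per-epoch annealing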
@@ -1,4 +1,4 @@
-_BASE_: "Base-Strongerbaseline.yml"
+_BASE_: "Base-SBS.yml"

 MODEL:
   META_ARCHITECTURE: 'MGN'

@@ -18,42 +18,50 @@ MODEL:
     CE:
       EPSILON: 0.1
       SCALE: 1.0

     TRI:
       MARGIN: 0.0
       HARD_MINING: True
       NORM_FEAT: False
       SCALE: 1.0

+    COSFACE:
+      MARGIN: 0.35
+      GAMMA: 64
+      SCALE: 1.

 INPUT:
   SIZE_TRAIN: [384, 128]
   SIZE_TEST: [384, 128]

   DO_AUTOAUG: True
+  AUTOAUG_PROB: 0.1

 DATALOADER:
   NUM_INSTANCE: 16

 SOLVER:
   OPT: "Adam"
-  MAX_ITER: 60
+  MAX_EPOCH: 60
   BASE_LR: 0.00035
   BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64

-  SCHED: "WarmupCosineAnnealingLR"
-  DELAY_ITERS: 30
+  SCHED: "CosineAnnealingLR"
+  DELAY_EPOCHS: 30
   ETA_MIN_LR: 0.00000077

-  WARMUP_FACTOR: 0.01
-  WARMUP_ITERS: 10
-  FREEZE_ITERS: 10
+  WARMUP_FACTOR: 0.1
+  WARMUP_ITERS: 2000

-  CHECKPOINT_PERIOD: 30
+  FREEZE_ITERS: 2000
+
+  CHECKPOINT_PERIOD: 20

 TEST:
-  EVAL_PERIOD: 30
+  EVAL_PERIOD: 20
   IMS_PER_BATCH: 128

 CUDNN_BENCHMARK: True
@@ -38,7 +38,6 @@ INPUT:
   REA:
     ENABLED: True
    PROB: 0.5
-    MEAN: [123.675, 116.28, 103.53]
   DO_PAD: True

 DATALOADER:

@@ -48,26 +47,26 @@ DATALOADER:
   NUM_WORKERS: 8

 SOLVER:
   FP16_ENABLED: True
   OPT: "Adam"
-  MAX_ITER: 120
+  MAX_EPOCH: 120
   BASE_LR: 0.00035
   BIAS_LR_FACTOR: 2.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
   IMS_PER_BATCH: 64

-  SCHED: "WarmupMultiStepLR"
+  SCHED: "MultiStepLR"
   STEPS: [40, 90]
   GAMMA: 0.1

-  WARMUP_FACTOR: 0.01
-  WARMUP_ITERS: 10
+  WARMUP_FACTOR: 0.1
+  WARMUP_ITERS: 2000

-  CHECKPOINT_PERIOD: 60
+  CHECKPOINT_PERIOD: 30

 TEST:
   EVAL_PERIOD: 30
   IMS_PER_BATCH: 128

 CUDNN_BENCHMARK: True
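Because WARMUP_ITERS is now counted in iterations while MAX_EPOCH, DELAY_EPOCHS and CHECKPOINT_PERIOD are counted in epochs, the effective warmup length depends on the dataset size. A small worked sketch of the conversion (the 12,936-image figure is the usual Market1501 train split and is an assumption, not something stated in this diff):

ims_per_batch = 64            # SOLVER.IMS_PER_BATCH from the config above
num_train_images = 12936      # assumed: Market1501 training set size
warmup_iters = 2000           # SOLVER.WARMUP_ITERS

iters_per_epoch = num_train_images // ims_per_batch   # 202
warmup_epochs = warmup_iters / iters_per_epoch        # roughly 9.9 epochs of warmup

print(iters_per_epoch, round(warmup_epochs, 1))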
The remaining config changes in this commit are thirteen copies of the same one-line hunk, pointing each child config at the renamed base file:

@@ -1,4 +1,4 @@
-_BASE_: "../Base-Strongerbaseline.yml"
+_BASE_: "../Base-SBS.yml"

The surrounding context differs only in which keys follow the change: nine of the configs continue with MODEL: / BACKBONE:, three are dataset configs (DATASETS: NAMES: ("DukeMTMC",), ("MSMT17",), and ("Market1501",)), and one is a vehicle-style config with INPUT: / SIZE_TRAIN: [256, 256].
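Each of these child configs only changes which base file it inherits from; the `_BASE_` key is resolved when the file is merged into the default config. A hedged usage sketch (the exact config path is illustrative and assumed to exist):

from fastreid.config import get_cfg

cfg = get_cfg()
# The child yaml's _BASE_ entry (e.g. "../Base-SBS.yml") is loaded first,
# then the child's own keys are merged on top of it.
cfg.merge_from_file("configs/Market1501/sbs_R50.yml")   # illustrative path
cfg.freeze()

print(cfg.SOLVER.SCHED, cfg.SOLVER.MAX_EPOCH)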
@@ -2,8 +2,8 @@
 ## Introduction

-This file documents collection of baselines trained with fastreid. All numbers were obtained with 1 NVIDIA P40 GPU.
-The software in use were PyTorch 1.4, CUDA 10.1.
+This file documents collection of baselines trained with fastreid. All numbers were obtained with 1 NVIDIA V100 GPU.
+The software in use were PyTorch 1.6, CUDA 10.1.

 In addition to these official baseline models, you can find more models in [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects).
@@ -86,6 +86,12 @@ _C.MODEL.LOSSES.CE.EPSILON = 0.0
 _C.MODEL.LOSSES.CE.ALPHA = 0.2
 _C.MODEL.LOSSES.CE.SCALE = 1.0

+# Focal Loss options
+_C.MODEL.LOSSES.FL = CN()
+_C.MODEL.LOSSES.FL.ALPHA = 0.25
+_C.MODEL.LOSSES.FL.GAMMA = 2
+_C.MODEL.LOSSES.FL.SCALE = 1.0
+
 # Triplet Loss options
 _C.MODEL.LOSSES.TRI = CN()
 _C.MODEL.LOSSES.TRI.MARGIN = 0.3

@@ -96,14 +102,14 @@ _C.MODEL.LOSSES.TRI.SCALE = 1.0
 # Circle Loss options
 _C.MODEL.LOSSES.CIRCLE = CN()
 _C.MODEL.LOSSES.CIRCLE.MARGIN = 0.25
-_C.MODEL.LOSSES.CIRCLE.ALPHA = 128
+_C.MODEL.LOSSES.CIRCLE.GAMMA = 128
 _C.MODEL.LOSSES.CIRCLE.SCALE = 1.0

-# Focal Loss options
-_C.MODEL.LOSSES.FL = CN()
-_C.MODEL.LOSSES.FL.ALPHA = 0.25
-_C.MODEL.LOSSES.FL.GAMMA = 2
-_C.MODEL.LOSSES.FL.SCALE = 1.0
+# Cosface Loss options
+_C.MODEL.LOSSES.COSFACE = CN()
+_C.MODEL.LOSSES.COSFACE.MARGIN = 0.25
+_C.MODEL.LOSSES.COSFACE.GAMMA = 128
+_C.MODEL.LOSSES.COSFACE.SCALE = 1.0

 # Path to a checkpoint file to be loaded to the model. You can find available models in the model zoo.
 _C.MODEL.WEIGHTS = ""

@@ -131,6 +137,7 @@ _C.INPUT.FLIP_PROB = 0.5
 _C.INPUT.DO_PAD = True
 _C.INPUT.PADDING_MODE = 'constant'
 _C.INPUT.PADDING = 10

 # Random color jitter
 _C.INPUT.CJ = CN()
 _C.INPUT.CJ.ENABLED = False

@@ -139,15 +146,20 @@ _C.INPUT.CJ.BRIGHTNESS = 0.15
 _C.INPUT.CJ.CONTRAST = 0.15
 _C.INPUT.CJ.SATURATION = 0.1
 _C.INPUT.CJ.HUE = 0.1

 # Auto augmentation
 _C.INPUT.DO_AUTOAUG = False
+_C.INPUT.AUTOAUG_PROB = 0.0
+
+# Augmix augmentation
 _C.INPUT.DO_AUGMIX = False
+_C.INPUT.AUGMIX_PROB = 0.0

 # Random Erasing
 _C.INPUT.REA = CN()
 _C.INPUT.REA.ENABLED = False
 _C.INPUT.REA.PROB = 0.5
-_C.INPUT.REA.MEAN = [0.596*255, 0.558*255, 0.497*255]  # [0.485*255, 0.456*255, 0.406*255]
+_C.INPUT.REA.VALUE = [0.485*255, 0.456*255, 0.406*255]
 # Random Patch
 _C.INPUT.RPT = CN()
 _C.INPUT.RPT.ENABLED = False

@@ -182,12 +194,12 @@ _C.DATALOADER.NUM_WORKERS = 8
 _C.SOLVER = CN()

 # AUTOMATIC MIXED PRECISION
-_C.SOLVER.AMP_ENABLED = False
+_C.SOLVER.FP16_ENABLED = False

 # Optimizer
 _C.SOLVER.OPT = "Adam"

-_C.SOLVER.MAX_ITER = 120
+_C.SOLVER.MAX_EPOCH = 120

 _C.SOLVER.BASE_LR = 3e-4
 _C.SOLVER.BIAS_LR_FACTOR = 1.

@@ -199,13 +211,15 @@ _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.

 # Multi-step learning rate options
-_C.SOLVER.SCHED = "WarmupMultiStepLR"
+_C.SOLVER.SCHED = "MultiStepLR"
+
+_C.SOLVER.DELAY_EPOCHS = 0

 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = [30, 55]

 # Cosine annealing learning rate options
 _C.SOLVER.DELAY_ITERS = 0
-_C.SOLVER.ETA_MIN_LR = 3e-7
+_C.SOLVER.ETA_MIN_LR = 1e-7

 # Warmup options
 _C.SOLVER.WARMUP_FACTOR = 0.1

@@ -215,13 +229,13 @@ _C.SOLVER.WARMUP_METHOD = "linear"
 _C.SOLVER.FREEZE_ITERS = 0

 # SWA options
-_C.SOLVER.SWA = CN()
-_C.SOLVER.SWA.ENABLED = False
-_C.SOLVER.SWA.ITER = 10
-_C.SOLVER.SWA.PERIOD = 2
-_C.SOLVER.SWA.LR_FACTOR = 10.
-_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
-_C.SOLVER.SWA.LR_SCHED = False
+# _C.SOLVER.SWA = CN()
+# _C.SOLVER.SWA.ENABLED = False
+# _C.SOLVER.SWA.ITER = 10
+# _C.SOLVER.SWA.PERIOD = 2
+# _C.SOLVER.SWA.LR_FACTOR = 10.
+# _C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
+# _C.SOLVER.SWA.LR_SCHED = False

 _C.SOLVER.CHECKPOINT_PERIOD = 20
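The new COSFACE node carries the margin/scale hyper-parameters a CosFace-style classifier consumes. A generic, self-contained sketch of that logit adjustment — this illustrates the technique only and is not the code added by this commit; the function name and shapes are assumptions:

import torch
import torch.nn.functional as F

def cosface_logits(features, weight, targets, margin=0.25, gamma=128):
    """Cosine logits with an additive margin subtracted from the target class, scaled by gamma."""
    cos = F.linear(F.normalize(features), F.normalize(weight))      # (N, C) cosine similarities
    one_hot = F.one_hot(targets, num_classes=weight.size(0)).float()
    return gamma * (cos - one_hot * margin)

features = torch.randn(4, 512)
weight = torch.randn(751, 512)            # e.g. 751 Market1501 identities (assumed)
targets = torch.randint(0, 751, (4,))
loss = F.cross_entropy(cosface_logits(features, weight, targets), targets)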
@@ -5,11 +5,12 @@
 """

 import os

 import torch
+from torch._six import container_abcs, string_classes, int_classes
 from torch.utils.data import DataLoader
-from fastreid.utils import comm

+from fastreid.utils import comm
 from . import samplers
 from .common import CommDataset
 from .datasets import DATASET_REGISTRY

@@ -18,9 +19,8 @@ from .transforms import build_transforms
 _root = os.getenv("FASTREID_DATASETS", "datasets")


-def build_reid_train_loader(cfg):
+def build_reid_train_loader(cfg, mapper=None):
     cfg = cfg.clone()
-    cfg.defrost()

     train_items = list()
     for d in cfg.DATASETS.NAMES:

@@ -29,10 +29,12 @@ def build_reid_train_loader(cfg):
         dataset.show_train()
         train_items.extend(dataset.train)

-    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
-    cfg.SOLVER.MAX_ITER *= iters_per_epoch
-    train_transforms = build_transforms(cfg, is_train=True)
-    train_set = CommDataset(train_items, train_transforms, relabel=True)
+    if mapper is not None:
+        transforms = mapper
+    else:
+        transforms = build_transforms(cfg, is_train=True)
+
+    train_set = CommDataset(train_items, transforms, relabel=True)

     num_workers = cfg.DATALOADER.NUM_WORKERS
     num_instance = cfg.DATALOADER.NUM_INSTANCE

@@ -40,11 +42,9 @@ def build_reid_train_loader(cfg):

     if cfg.DATALOADER.PK_SAMPLER:
         if cfg.DATALOADER.NAIVE_WAY:
-            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items,
-                                                         cfg.SOLVER.IMS_PER_BATCH, num_instance)
+            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
         else:
-            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items,
-                                                            cfg.SOLVER.IMS_PER_BATCH, num_instance)
+            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
     else:
         data_sampler = samplers.TrainingSampler(len(train_set))
     batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

@@ -61,7 +61,6 @@ def build_reid_train_loader(cfg):

 def build_reid_test_loader(cfg, dataset_name):
     cfg = cfg.clone()
-    cfg.defrost()

     dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
     if comm.is_main_process():
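The new `mapper` argument lets a caller swap in its own transform pipeline instead of the cfg-driven `build_transforms(cfg, is_train=True)`. A hedged usage sketch (the transform choice is arbitrary, and a dataset available under the default `datasets/` root is assumed):

import torchvision.transforms as T
from fastreid.config import get_cfg
from fastreid.data import build_reid_train_loader

cfg = get_cfg()   # default config; assumes the configured dataset exists locally

# Any callable transform can replace the default cfg-driven pipeline.
custom_mapper = T.Compose([
    T.Resize((256, 128)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
train_loader = build_reid_train_loader(cfg, mapper=custom_mapper)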
@@ -5,14 +5,10 @@
 """

 import os
-from scipy.io import loadmat
 from glob import glob

 from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.datasets.bases import ImageDataset
-import pdb
-import random
-import numpy as np

 __all__ = ['PeS3D',]
@@ -20,14 +20,35 @@ def no_index(a, b):
     return [i for i, j in enumerate(a) if j != b]


-class BalancedIdentitySampler(Sampler):
-    def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
-        self.data_source = data_source
-        self.batch_size = batch_size
-        self.num_instances = num_instances
-        self.num_pids_per_batch = batch_size // self.num_instances
+def reorder_index(batch_indices, world_size):
+    r"""Reorder indices of samples to align with DataParallel training.
+    In this order, each process will contain all images for one ID, triplet loss
+    can be computed within each process, and BatchNorm will get a stable result.
+    Args:
+        batch_indices: A batched indices generated by sampler
+        world_size: number of process
+    Returns:
+    """
+    mini_batchsize = len(batch_indices) // world_size
+    reorder_indices = []
+    for i in range(0, mini_batchsize):
+        for j in range(0, world_size):
+            reorder_indices.append(batch_indices[i + j * mini_batchsize])
+    return reorder_indices
+
+
+class BalancedIdentitySampler(Sampler):
+    def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
+        self.data_source = data_source
+        self.num_instances = num_instances
+        self.num_pids_per_batch = mini_batch_size // self.num_instances

-        self.index_pid = defaultdict(list)
+        self._rank = comm.get_rank()
+        self._world_size = comm.get_world_size()
+        self.batch_size = mini_batch_size * self._world_size
+
+        self.index_pid = dict()
         self.pid_cam = defaultdict(list)
         self.pid_index = defaultdict(list)

@@ -60,14 +81,14 @@ class BalancedIdentitySampler(Sampler):

             # If remaining identities cannot be enough for a batch,
             # just drop the remaining parts
-            drop_indices = self.num_identities % self.num_pids_per_batch
+            drop_indices = self.num_identities % (self.num_pids_per_batch * self._world_size)
             if drop_indices: identities = identities[:-drop_indices]

-            ret = []
+            batch_indices = []
             for kid in identities:
                 i = np.random.choice(self.pid_index[self.pids[kid]])
                 _, i_pid, i_cam = self.data_source[i]
-                ret.append(i)
+                batch_indices.append(i)
                 pid_i = self.index_pid[i]
                 cams = self.pid_cam[pid_i]
                 index = self.pid_index[pid_i]

@@ -79,7 +100,7 @@ class BalancedIdentitySampler(Sampler):
                     else:
                         cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
                     for kk in cam_indexes:
-                        ret.append(index[kk])
+                        batch_indices.append(index[kk])
                 else:
                     select_indexes = no_index(index, i)
                     if not select_indexes:

@@ -91,11 +112,11 @@ class BalancedIdentitySampler(Sampler):
                         ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)

                     for kk in ind_indexes:
-                        ret.append(index[kk])
+                        batch_indices.append(index[kk])

-                if len(ret) == self.batch_size:
-                    yield from ret
-                    ret = []
+                if len(batch_indices) == self.batch_size:
+                    yield from reorder_index(batch_indices, self._world_size)
+                    batch_indices = []


 class NaiveIdentitySampler(Sampler):

@@ -108,21 +129,19 @@ class NaiveIdentitySampler(Sampler):
    - batch_size (int): number of examples in a batch.
    """

-    def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
+    def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
         self.data_source = data_source
-        self.batch_size = batch_size
         self.num_instances = num_instances
-        self.num_pids_per_batch = batch_size // self.num_instances
+        self.num_pids_per_batch = mini_batch_size // self.num_instances

-        self.index_pid = defaultdict(list)
-        self.pid_cam = defaultdict(list)
+        self._rank = comm.get_rank()
+        self._world_size = comm.get_world_size()
+        self.batch_size = mini_batch_size * self._world_size
+
         self.pid_index = defaultdict(list)

         for index, info in enumerate(data_source):
             pid = info[1]
-            camid = info[2]
-            self.index_pid[index] = pid
-            self.pid_cam[pid].append(camid)
             self.pid_index[pid].append(index)

         self.pids = sorted(list(self.pid_index.keys()))

@@ -132,9 +151,6 @@ class NaiveIdentitySampler(Sampler):
             seed = comm.shared_random_seed()
         self._seed = int(seed)

-        self._rank = comm.get_rank()
-        self._world_size = comm.get_world_size()
-
     def __iter__(self):
         start = self._rank
         yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

@@ -142,12 +158,12 @@ class NaiveIdentitySampler(Sampler):
     def _infinite_indices(self):
         np.random.seed(self._seed)
         while True:
-            avai_pids = copy.deepcopy(self.pids)
+            avl_pids = copy.deepcopy(self.pids)
             batch_idxs_dict = {}

             batch_indices = []
-            while len(avai_pids) >= self.num_pids_per_batch:
-                selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist()
+            while len(avl_pids) >= self.num_pids_per_batch:
+                selected_pids = np.random.choice(avl_pids, self.num_pids_per_batch, replace=False).tolist()
                 for pid in selected_pids:
                     # Register pid in batch_idxs_dict if not
                     if pid not in batch_idxs_dict:

@@ -157,13 +173,12 @@ class NaiveIdentitySampler(Sampler):
                         np.random.shuffle(idxs)
                         batch_idxs_dict[pid] = idxs

-                    avai_idxs = batch_idxs_dict[pid]
+                    avl_idxs = batch_idxs_dict[pid]
                     for _ in range(self.num_instances):
-                        batch_indices.append(avai_idxs.pop(0))
+                        batch_indices.append(avl_idxs.pop(0))

-                    if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
+                    if len(avl_idxs) < self.num_instances: avl_pids.remove(pid)

-                assert len(batch_indices) == self.batch_size, f"batch indices have wrong " \
-                                                              f"length with {len(batch_indices)}!"
-                yield from batch_indices
-                batch_indices = []
+                if len(batch_indices) == self.batch_size:
+                    yield from reorder_index(batch_indices, self._world_size)
+                    batch_indices = []
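A tiny standalone demonstration of what the new `reorder_index` helper does for two processes; the index values are made up:

def reorder_index(batch_indices, world_size):
    # copy of the committed helper, reproduced here so the snippet runs on its own
    mini_batchsize = len(batch_indices) // world_size
    reorder_indices = []
    for i in range(0, mini_batchsize):
        for j in range(0, world_size):
            reorder_indices.append(batch_indices[i + j * mini_batchsize])
    return reorder_indices

# 2 identities x 4 instances, world_size = 2
batch = [10, 11, 12, 13, 20, 21, 22, 23]
print(reorder_index(batch, 2))
# -> [10, 20, 11, 21, 12, 22, 13, 23]
# After the per-rank slicing in __iter__ (start=rank, step=world_size),
# rank 0 receives 10, 11, 12, 13 and rank 1 receives 20, 21, 22, 23,
# so each process holds all instances of one identity for the triplet loss.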
@@ -494,20 +494,14 @@ def auto_augment_policy(name="original"):

 class AutoAugment:

-    def __init__(self, total_iter):
-        self.total_iter = total_iter
-        self.gamma = 0
+    def __init__(self):
         self.policy = auto_augment_policy()

     def __call__(self, img):
-        if random.uniform(0, 1) > self.gamma:
-            sub_policy = random.choice(self.policy)
-            self.gamma = min(1.0, self.gamma + 1.0 / self.total_iter)
-            for op in sub_policy:
-                img = op(img)
-            return img
-        else:
-            return img
+        sub_policy = random.choice(self.policy)
+        for op in sub_policy:
+            img = op(img)
+        return img


 def auto_augment_transform(config_str, hparams):
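With the iteration-dependent gating (`self.gamma`) removed, the apply probability moves out of AutoAugment and into the transform pipeline, as the build_transforms hunk below shows. A hedged sketch of that wrapping (the import path is an assumption):

import torchvision.transforms as T
from PIL import Image
from fastreid.data.transforms.autoaugment import AutoAugment   # module path assumed

# AutoAugment is now applied with a fixed probability instead of an internal,
# iteration-dependent gamma; torchvision's RandomApply supplies that probability.
autoaug = T.RandomApply([AutoAugment()], p=0.1)

img = Image.new("RGB", (128, 384))   # dummy image
out = autoaug(img)                   # with p=0.1, the policy is usually skipped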
@@ -18,10 +18,11 @@ def build_transforms(cfg, is_train=True):

        # augmix augmentation
        do_augmix = cfg.INPUT.DO_AUGMIX
+       augmix_prob = cfg.INPUT.AUGMIX_PROB

        # auto augmentation
        do_autoaug = cfg.INPUT.DO_AUTOAUG
-       total_iter = cfg.SOLVER.MAX_ITER
+       autoaug_prob = cfg.INPUT.AUTOAUG_PROB

        # horizontal filp
        do_flip = cfg.INPUT.DO_FLIP

@@ -43,29 +44,31 @@ def build_transforms(cfg, is_train=True):
        # random erasing
        do_rea = cfg.INPUT.REA.ENABLED
        rea_prob = cfg.INPUT.REA.PROB
-       rea_mean = cfg.INPUT.REA.MEAN
+       rea_value = cfg.INPUT.REA.VALUE

        # random patch
        do_rpt = cfg.INPUT.RPT.ENABLED
        rpt_prob = cfg.INPUT.RPT.PROB

        if do_autoaug:
-           res.append(AutoAugment(total_iter))
+           res.append(T.RandomApply([AutoAugment()], p=autoaug_prob))
        res.append(T.Resize(size_train, interpolation=3))
        if do_flip:
            res.append(T.RandomHorizontalFlip(p=flip_prob))
        if do_pad:
-           res.extend([T.Pad(padding, padding_mode=padding_mode),
-                       T.RandomCrop(size_train)])
+           res.extend([T.Pad(padding, padding_mode=padding_mode), T.RandomCrop(size_train)])
        if do_cj:
            res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
        if do_augmix:
-           res.append(AugMix())
+           res.append(T.RandomApply([AugMix()], p=augmix_prob))
+
+       res.append(ToTensor())
        if do_rea:
-           res.append(RandomErasing(probability=rea_prob, mean=rea_mean))
+           res.append(T.RandomErasing(p=rea_prob, value=rea_value))
        if do_rpt:
            res.append(RandomPatch(prob_happen=rpt_prob))
    else:
        size_test = cfg.INPUT.SIZE_TEST
        res.append(T.Resize(size_test, interpolation=3))
-   res.append(ToTensor())
+       res.append(ToTensor())
    return T.Compose(res)
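The custom RandomErasing (removed in the next hunk) is replaced by torchvision's tensor-based implementation, which is why ToTensor now precedes it and why the fill is configured as per-channel values via INPUT.REA.VALUE. A minimal sketch of the equivalent torchvision call; the 0-255 value range mirrors the config defaults above, and the random input tensor is a stand-in:

import torch
import torchvision.transforms as T

# torchvision's RandomErasing operates on tensors; fastreid's pipeline keeps pixel
# values on the 0-255 scale, hence the 0-255 fill values in INPUT.REA.VALUE.
erase = T.RandomErasing(p=0.5, value=[0.485 * 255, 0.456 * 255, 0.406 * 255])

img = torch.rand(3, 384, 128) * 255    # stand-in CHW image tensor in 0-255 range
out = erase(img)                       # a random rectangle is filled channel-wise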
@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """

-__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix',]
+__all__ = ['ToTensor', 'RandomPatch', 'AugMix', ]

 import math
 import random

@@ -41,51 +41,6 @@ class ToTensor(object):
         return self.__class__.__name__ + '()'


-class RandomErasing(object):
-    """ Randomly selects a rectangle region in an image and erases its pixels.
-        'Random Erasing Data Augmentation' by Zhong et al.
-        See https://arxiv.org/pdf/1708.04896.pdf
-    Args:
-        probability: The probability that the Random Erasing operation will be performed.
-        sl: Minimum proportion of erased area against input image.
-        sh: Maximum proportion of erased area against input image.
-        r1: Minimum aspect ratio of erased area.
-        mean: Erasing value.
-    """
-
-    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=255 * (0.49735, 0.4822, 0.4465)):
-        self.probability = probability
-        self.mean = mean
-        self.sl = sl
-        self.sh = sh
-        self.r1 = r1
-
-    def __call__(self, img):
-        img = np.asarray(img, dtype=np.float32).copy()
-        if random.uniform(0, 1) > self.probability:
-            return img
-
-        for attempt in range(100):
-            area = img.shape[0] * img.shape[1]
-            target_area = random.uniform(self.sl, self.sh) * area
-            aspect_ratio = random.uniform(self.r1, 1 / self.r1)
-
-            h = int(round(math.sqrt(target_area * aspect_ratio)))
-            w = int(round(math.sqrt(target_area / aspect_ratio)))
-
-            if w < img.shape[1] and h < img.shape[0]:
-                x1 = random.randint(0, img.shape[0] - h)
-                y1 = random.randint(0, img.shape[1] - w)
-                if img.shape[2] == 3:
-                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
-                    img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]
-                    img[x1:x1 + h, y1:y1 + w, 2] = self.mean[2]
-                else:
-                    img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0]
-                return img
-        return img
-
-
 class RandomPatch(object):
     """Random patch data augmentation.
     There is a patch pool that stores randomly extracted pathces from person images.
@@ -11,15 +11,15 @@ since they are meant to represent the "common default behavior" people need in their projects.
 import argparse
 import logging
 import os
+import math
 import sys
 from collections import OrderedDict

 import torch
 import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel

 from fastreid.data import build_reid_test_loader, build_reid_train_loader
-from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
+from fastreid.evaluation import (ReidEvaluator,
                                  inference_on_dataset, print_csv_format)
 from fastreid.modeling.meta_arch import build_model
 from fastreid.solver import build_lr_scheduler, build_optimizer

@@ -31,7 +31,15 @@ from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from fastreid.utils.file_io import PathManager
 from fastreid.utils.logger import setup_logger
 from . import hooks
-from .train_loop import SimpleTrainer
+from .train_loop import TrainerBase, AMPTrainer, SimpleTrainer
+
+try:
+    import apex
+    from apex import amp
+    from apex.parallel import DistributedDataParallel
+except ImportError:
+    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example if you want to"
+                      "train with DDP")

 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]

@@ -158,7 +166,7 @@ class DefaultPredictor:
         return features


-class DefaultTrainer(SimpleTrainer):
+class DefaultTrainer(TrainerBase):
     """
     A trainer with default training logic. Compared to `SimpleTrainer`, it
     contains the following logic in addition:

@@ -196,27 +204,38 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
+        super().__init__()
         logger = logging.getLogger("fastreid")
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
             setup_logger()

         # Assume these objects must be constructed in this order.
         data_loader = self.build_train_loader(cfg)
-        cfg = self.auto_scale_hyperparams(cfg, data_loader)
+        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)

+        optimizer_ckpt = dict(optimizer=optimizer)
+        if cfg.SOLVER.FP16_ENABLED:
+            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+            optimizer_ckpt.update(dict(amp=amp))
+
         # For training, wrap with DDP. But don't need this for inference.
         if comm.get_world_size() > 1:
             # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
             # for part of the parameters is not updated.
-            model = DistributedDataParallel(
-                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
-            )
+            # model = DistributedDataParallel(
+            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
+            # )
+            model = DistributedDataParallel(model, delay_allreduce=True)

-        super().__init__(model, data_loader, optimizer, cfg.SOLVER.AMP_ENABLED)
+        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(
+            model, data_loader, optimizer
+        )

-        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
+        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
+
         # Assume no other objects need to be checkpointed.
         # We can later make it checkpoint the stateful hooks
         self.checkpointer = Checkpointer(

@@ -224,15 +243,20 @@ class DefaultTrainer(SimpleTrainer):
             model,
             cfg.OUTPUT_DIR,
             save_to_disk=comm.is_main_process(),
-            optimizer=optimizer,
-            scheduler=self.scheduler,
+            **optimizer_ckpt,
+            **self.scheduler,
         )
-        self.start_iter = 0
-        if cfg.SOLVER.SWA.ENABLED:
-            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
-        else:
-            self.max_iter = cfg.SOLVER.MAX_ITER
+        self.start_epoch = 0
+
+        # if cfg.SOLVER.SWA.ENABLED:
+        #     self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
+        # else:
+        #     self.max_iter = cfg.SOLVER.MAX_ITER
+
+        self.max_epoch = cfg.SOLVER.MAX_EPOCH
+        self.max_iter = self.max_epoch * self.iters_per_epoch
+        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
+        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
         self.cfg = cfg

         self.register_hooks(self.build_hooks())

@@ -254,7 +278,7 @@ class DefaultTrainer(SimpleTrainer):
         checkpoint = self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)

         if resume and self.checkpointer.has_checkpoint():
-            self.start_iter = checkpoint.get("iteration", -1) + 1
+            self.start_epoch = checkpoint.get("epoch", -1) + 1
             # The checkpoint stores the training iteration that just finished, thus we start
             # at the next iteration (or iter zero if there's no checkpoint).

@@ -276,16 +300,16 @@ class DefaultTrainer(SimpleTrainer):
             hooks.LRScheduler(self.optimizer, self.scheduler),
         ]

-        if cfg.SOLVER.SWA.ENABLED:
-            ret.append(
-                hooks.SWA(
-                    cfg.SOLVER.MAX_ITER,
-                    cfg.SOLVER.SWA.PERIOD,
-                    cfg.SOLVER.SWA.LR_FACTOR,
-                    cfg.SOLVER.SWA.ETA_MIN_LR,
-                    cfg.SOLVER.SWA.LR_SCHED,
-                )
-            )
+        # if cfg.SOLVER.SWA.ENABLED:
+        #     ret.append(
+        #         hooks.SWA(
+        #             cfg.SOLVER.MAX_ITER,
+        #             cfg.SOLVER.SWA.PERIOD,
+        #             cfg.SOLVER.SWA.LR_FACTOR,
+        #             cfg.SOLVER.SWA.ETA_MIN_LR,
+        #             cfg.SOLVER.SWA.LR_SCHED,
+        #         )
+        #     )

         if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(self.model):
             logger.info("Prepare precise BN dataset")

@@ -298,11 +322,8 @@ class DefaultTrainer(SimpleTrainer):
             ))

         if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
-            freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
-            logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
-            ret.append(hooks.FreezeLayer(
+            ret.append(hooks.LayerFreeze(
                 self.model,
-                self.optimizer,
                 cfg.MODEL.FREEZE_LAYERS,
                 cfg.SOLVER.FREEZE_ITERS,
             ))

@@ -358,14 +379,17 @@ class DefaultTrainer(SimpleTrainer):
         Returns:
             OrderedDict of results, if evaluation is enabled. Otherwise None.
         """
-        super().train(self.start_iter, self.max_iter)
+        super().train(self.start_epoch, self.max_epoch, self.iters_per_epoch)
         if comm.is_main_process():
             assert hasattr(
                 self, "_last_eval_results"
             ), "No evaluation results obtained during training!"
             # verify_results(self.cfg, self._last_eval_results)
             return self._last_eval_results

+    def run_step(self):
+        self._trainer.iter = self.iter
+        self._trainer.run_step()
+
     @classmethod
     def build_model(cls, cfg):
         """

@@ -390,11 +414,15 @@ class DefaultTrainer(SimpleTrainer):
         return build_optimizer(cfg, model)

     @classmethod
-    def build_lr_scheduler(cls, cfg, optimizer):
+    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
         """
         It now calls :func:`fastreid.solver.build_lr_scheduler`.
         Overwrite it if you'd like a different scheduler.
         """
+        cfg = cfg.clone()
+        cfg.defrost()
+        cfg.SOLVER.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH - max(
+            math.ceil(cfg.SOLVER.WARMUP_ITERS / iters_per_epoch), cfg.SOLVER.DELAY_EPOCHS)
         return build_lr_scheduler(cfg, optimizer)

     @classmethod

@@ -462,7 +490,7 @@ class DefaultTrainer(SimpleTrainer):
         return results

     @staticmethod
-    def auto_scale_hyperparams(cfg, data_loader):
+    def auto_scale_hyperparams(cfg, num_classes):
         r"""
         This is used for auto-computation actual training iterations,
         because some hyper-param, such as MAX_ITER, means training epochs rather than iters,

@@ -475,7 +503,10 @@ class DefaultTrainer(SimpleTrainer):
         # If you don't hard-code the number of classes, it will compute the number automatically
         if cfg.MODEL.HEADS.NUM_CLASSES == 0:
             output_dir = cfg.OUTPUT_DIR
-            cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
+            cfg.MODEL.HEADS.NUM_CLASSES = num_classes
+            logger = logging.getLogger(__name__)
+            logger.info(f"Auto-scaling the num_classes={cfg.MODEL.HEADS.NUM_CLASSES}")

             # Update the saved config file to make the number of classes valid
             if comm.is_main_process() and output_dir:
                 # Note: some of our scripts may expect the existence of

@@ -484,32 +515,11 @@ class DefaultTrainer(SimpleTrainer):
                 with PathManager.open(path, "w") as f:
                     f.write(cfg.dump())

-        iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
-        cfg.SOLVER.MAX_ITER *= iters_per_epoch
-        cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
-        cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
-        cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
-        for i in range(len(cfg.SOLVER.STEPS)):
-            cfg.SOLVER.STEPS[i] *= iters_per_epoch
-        cfg.SOLVER.SWA.ITER *= iters_per_epoch
-        cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
-
-        ckpt_multiple = cfg.SOLVER.CHECKPOINT_PERIOD / cfg.TEST.EVAL_PERIOD
-        # Evaluation period must be divided by 200 for writing into tensorboard.
-        eval_num_mod = (200 - cfg.TEST.EVAL_PERIOD * iters_per_epoch) % 200
-        cfg.TEST.EVAL_PERIOD = cfg.TEST.EVAL_PERIOD * iters_per_epoch + eval_num_mod
-        # Change checkpoint saving period consistent with evaluation period.
-        cfg.SOLVER.CHECKPOINT_PERIOD = int(cfg.TEST.EVAL_PERIOD * ckpt_multiple)
-
-        logger = logging.getLogger(__name__)
-        logger.info(
-            f"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
-            f"max_Iter={cfg.SOLVER.MAX_ITER}, wamrup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
-            f"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
-            f"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
-            f"eval_Iter={cfg.TEST.EVAL_PERIOD}."
-        )
-
         if frozen: cfg.freeze()

         return cfg


+# Access basic attributes from the underlying trainer
+for _attr in ["model", "data_loader", "optimizer"]:
+    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x)))
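A condensed, hedged sketch of the Apex pattern the trainer now follows: initialize AMP before wrapping the model in apex's DistributedDataParallel, then scale the loss during backward. It assumes apex is installed, a GPU is available, and torch.distributed has already been initialized; the model and batch are stand-ins:

import torch
from apex import amp
from apex.parallel import DistributedDataParallel

model = torch.nn.Linear(512, 751).cuda()            # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=3.5e-4)

# O1 mixed precision, as in DefaultTrainer when SOLVER.FP16_ENABLED is on.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Apex DDP; delay_allreduce=True postpones all gradient all-reduces to the end of
# backward, which also sidesteps the unused-parameter issue noted in the removed comment.
model = DistributedDataParallel(model, delay_allreduce=True)

loss = model(torch.randn(4, 512).cuda()).mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()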
@@ -11,7 +11,7 @@ from collections import Counter

 import torch
 from torch import nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel

 from fastreid.evaluation.testing import flatten_results_dict
 from fastreid.solver import optim

@@ -32,7 +32,7 @@ __all__ = [
     "AutogradProfiler",
     "EvalHook",
     "PreciseBN",
-    "FreezeLayer",
+    "LayerFreeze",
 ]

 """

@@ -45,13 +45,16 @@ class CallbackHook(HookBase):
     Create a hook using callback functions provided by the user.
     """

-    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
+    def __init__(self, *, before_train=None, after_train=None, before_epoch=None, after_epoch=None,
+                 before_step=None, after_step=None):
         """
         Each argument is a function that takes one argument: the trainer.
         """
         self._before_train = before_train
+        self._before_epoch = before_epoch
         self._before_step = before_step
         self._after_step = after_step
+        self._after_epoch = after_epoch
         self._after_train = after_train

     def before_train(self):

@@ -66,6 +69,14 @@ class CallbackHook(HookBase):
         del self._before_train, self._after_train
         del self._before_step, self._after_step

+    def before_epoch(self):
+        if self._before_epoch:
+            self._before_epoch(self.trainer)
+
+    def after_epoch(self):
+        if self._after_epoch:
+            self._after_epoch(self.trainer)
+
     def before_step(self):
         if self._before_step:
             self._before_step(self.trainer)

@@ -167,6 +178,10 @@ class PeriodicWriter(HookBase):
             for writer in self._writers:
                 writer.write()

+    def after_epoch(self):
+        for writer in self._writers:
+            writer.write()
+
     def after_train(self):
         for writer in self._writers:
             writer.close()

@@ -182,11 +197,11 @@ class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
     """

     def before_train(self):
-        self.max_iter = self.trainer.max_iter
+        self.max_epoch = self.trainer.max_epoch

-    def after_step(self):
+    def after_epoch(self):
         # No way to use **kwargs
-        self.step(self.trainer.iter)
+        self.step(self.trainer.epoch)


 class LRScheduler(HookBase):

@@ -226,7 +241,16 @@ class LRScheduler(HookBase):
     def after_step(self):
         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
-        self._scheduler.step()
+
+        next_iter = self.trainer.iter + 1
+        if next_iter < self.trainer.warmup_iters:
+            self._scheduler["warmup_sched"].step()
+
+    def after_epoch(self):
+        next_iter = self.trainer.iter
+        next_epoch = self.trainer.epoch + 1
+        if next_iter >= self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
+            self._scheduler["lr_sched"].step()


 class AutogradProfiler(HookBase):

@@ -331,10 +355,10 @@ class EvalHook(HookBase):
             # Remove extra memory cache of main process due to evaluation
             torch.cuda.empty_cache()

-    def after_step(self):
-        next_iter = self.trainer.iter + 1
-        is_final = next_iter == self.trainer.max_iter
-        if is_final or (self._period > 0 and next_iter % self._period == 0):
+    def after_epoch(self):
+        next_epoch = self.trainer.epoch + 1
+        is_final = next_epoch == self.trainer.max_epoch
+        if is_final or (self._period > 0 and next_epoch % self._period == 0):
             self._do_eval()
         # Evaluation may take different time among workers.
         # A barrier make them start the next iteration together.

@@ -381,9 +405,9 @@ class PreciseBN(HookBase):

         self._data_iter = None

-    def after_step(self):
-        next_iter = self.trainer.iter + 1
-        is_final = next_iter == self.trainer.max_iter
+    def after_epoch(self):
+        next_epoch = self.trainer.epoch + 1
+        is_final = next_epoch == self.trainer.max_epoch
         if is_final:
             self.update_stats()

@@ -414,34 +438,26 @@ class PreciseBN(HookBase):
         update_bn_stats(self._model, data_loader(), self._num_iter)


-class FreezeLayer(HookBase):
-    def __init__(self, model, optimizer, freeze_layers, freeze_iters):
+class LayerFreeze(HookBase):
+    def __init__(self, model, freeze_layers, freeze_iters):
         self._logger = logging.getLogger(__name__)

         if isinstance(model, DistributedDataParallel):
             model = model.module
         self.model = model
-        self.optimizer = optimizer

         self.freeze_layers = freeze_layers
         self.freeze_iters = freeze_iters

-        # Previous parameters freeze status
-        param_freeze = {}
-        for param_group in self.optimizer.param_groups:
-            param_name = param_group['name']
-            param_freeze[param_name] = param_group['freeze']
-        self.param_freeze = param_freeze
-
         self.is_frozen = False

     def before_step(self):
         # Freeze specific layers
-        if self.trainer.iter <= self.freeze_iters and not self.is_frozen:
+        if self.trainer.iter < self.freeze_iters and not self.is_frozen:
             self.freeze_specific_layer()

         # Recover original layers status
-        if self.trainer.iter > self.freeze_iters and self.is_frozen:
+        if self.trainer.iter >= self.freeze_iters and self.is_frozen:
             self.open_all_layer()

     def freeze_specific_layer(self):

@@ -449,25 +465,29 @@ class FreezeLayer(HookBase):
             if not hasattr(self.model, layer):
                 self._logger.info(f'{layer} is not an attribute of the model, will skip this layer')

-        for param_group in self.optimizer.param_groups:
-            param_name = param_group['name']
-            if param_name.split('.')[0] in self.freeze_layers:
-                param_group['freeze'] = True
-
-        # Change BN in freeze layers to eval mode
         for name, module in self.model.named_children():
-            if name in self.freeze_layers: module.eval()
+            if name in self.freeze_layers:
+                # Change BN in freeze layers to eval mode
+                module.eval()
+                for p in module.parameters():
+                    p.requires_grad_(False)

         self.is_frozen = True
+        freeze_layers = ",".join(self.freeze_layers)
+        self._logger.info(f'Freeze layer group "{freeze_layers}" training for {self.freeze_iters:d} iterations')

     def open_all_layer(self):
-        self.model.train()
-        for param_group in self.optimizer.param_groups:
-            param_name = param_group['name']
-            param_group['freeze'] = self.param_freeze[param_name]
+        for name, module in self.model.named_children():
+            if name in self.freeze_layers:
+                module.train()
+                for p in module.parameters():
+                    p.requires_grad_(True)

         self.is_frozen = False

+        freeze_layers = ",".join(self.freeze_layers)
+        self._logger.info(f'Open layer group "{freeze_layers}" training')


 class SWA(HookBase):
     def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ):
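With the new before_epoch/after_epoch slots, an epoch-level hook only has to override those two methods. A hedged sketch of a minimal custom hook in the style of the classes above (the class name and logging are illustrative, not part of the commit):

import logging
from fastreid.engine.train_loop import HookBase

class EpochCounter(HookBase):
    """Log how many iterations each epoch ran, using the trainer's epoch/iter counters."""

    def __init__(self):
        self._logger = logging.getLogger(__name__)
        self._start_iter = 0

    def before_epoch(self):
        self._start_iter = self.trainer.iter

    def after_epoch(self):
        done = self.trainer.iter - self._start_iter
        self._logger.info(f"epoch {self.trainer.epoch} finished after {done} iterations")

# Registered like the built-in hooks, e.g. trainer.register_hooks([EpochCounter()]).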
@@ -7,29 +7,35 @@ https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
 import logging
 import time
 import weakref
+from typing import Dict

 import numpy as np
 import torch
-from torch.cuda import amp
-from torch.nn.parallel import DistributedDataParallel
+from apex import amp
+from apex.parallel import DistributedDataParallel

 import fastreid.utils.comm as comm
-from fastreid.utils.events import EventStorage
+from fastreid.utils.events import EventStorage, get_event_storage

 __all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]

 logger = logging.getLogger(__name__)


 class HookBase:
     """
     Base class for hooks that can be registered with :class:`TrainerBase`.
-    Each hook can implement 4 methods. The way they are called is demonstrated
+    Each hook can implement 6 methods. The way they are called is demonstrated
     in the following snippet:
     .. code-block:: python
         hook.before_train()
-        for iter in range(start_iter, max_iter):
-            hook.before_step()
-            trainer.run_step()
-            hook.after_step()
+        for _ in range(start_epoch, max_epoch):
+            hook.before_epoch()
+            for iter in range(start_iter, max_iter):
+                hook.before_step()
+                trainer.run_step()
+                hook.after_step()
+            hook.after_epoch()
         hook.after_train()
     Notes:
         1. In the hook method, users can access `self.trainer` to access more

@@ -59,6 +65,18 @@ class HookBase:
         """
         pass

+    def before_epoch(self):
+        """
+        Called before each epoch.
+        """
+        pass
+
+    def after_epoch(self):
+        """
+        Called after each epoch.
+        """
+        pass
+
     def before_step(self):
         """
         Called before each iteration.

@@ -106,26 +124,30 @@ class TrainerBase:
             h.trainer = weakref.proxy(self)
         self._hooks.extend(hooks)

-    def train(self, start_iter: int, max_iter: int):
+    def train(self, start_epoch: int, max_epoch: int, iters_per_epoch: int):
         """
         Args:
             start_iter, max_iter (int): See docs above
         """
         logger = logging.getLogger(__name__)
-        logger.info("Starting training from iteration {}".format(start_iter))
+        logger.info("Starting training from epoch {}".format(start_epoch))

-        self.iter = self.start_iter = start_iter
-        self.max_iter = max_iter
+        self.iter = self.start_iter = start_epoch * iters_per_epoch

-        with EventStorage(start_iter) as self.storage:
+        with EventStorage(self.start_iter) as self.storage:
             try:
                 self.before_train()
-                for self.iter in range(start_iter, max_iter):
-                    self.before_step()
-                    self.run_step()
-                    self.after_step()
+                for self.epoch in range(start_epoch, max_epoch):
+                    self.before_epoch()
+                    for _ in range(iters_per_epoch):
+                        self.before_step()
+                        self.run_step()
+                        self.after_step()
+                        self.iter += 1
+                    self.after_epoch()
             except Exception:
                 logger.exception("Exception during training:")
                 raise
             finally:
                 self.after_train()

@@ -134,18 +156,29 @@ class TrainerBase:
             h.before_train()

     def after_train(self):
+        self.storage.iter = self.iter
         for h in self._hooks:
             h.after_train()

+    def before_epoch(self):
+        self.storage.epoch = self.epoch
+
+        for h in self._hooks:
+            h.before_epoch()
+
     def before_step(self):
+        self.storage.iter = self.iter
+
         for h in self._hooks:
             h.before_step()

     def after_step(self):
         for h in self._hooks:
             h.after_step()
-        # this guarantees, that in each hook's after_step, storage.iter == trainer.iter
-        self.storage.step()

+    def after_epoch(self):
+        for h in self._hooks:
+            h.after_epoch()
+
     def run_step(self):
         raise NotImplementedError

@@ -164,7 +197,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """

-    def __init__(self, model, data_loader, optimizer, amp_enabled):
+    def __init__(self, model, data_loader, optimizer):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a

@@ -186,11 +219,6 @@ class SimpleTrainer(TrainerBase):
         self.data_loader = data_loader
         self._data_loader_iter = iter(data_loader)
         self.optimizer = optimizer
-        self.amp_enabled = amp_enabled
-
-        if amp_enabled:
-            # Creates a GradScaler once at the beginning of training.
-            self.scaler = amp.GradScaler()

     def run_step(self):
         """

@@ -208,76 +236,104 @@ class SimpleTrainer(TrainerBase):
        If your want to do something with the heads, you can wrap the model.
        """

-        with amp.autocast(enabled=self.amp_enabled):
-            outs = self.model(data)
+        outs = self.model(data)

-            # Compute loss
-            if isinstance(self.model, DistributedDataParallel):
-                loss_dict = self.model.module.losses(outs)
-            else:
-                loss_dict = self.model.losses(outs)
+        # Compute loss
+        if isinstance(self.model, DistributedDataParallel):
+            loss_dict = self.model.module.losses(outs)
+        else:
+            loss_dict = self.model.losses(outs)

-            losses = sum(loss_dict.values())
-
-        with torch.cuda.stream(torch.cuda.Stream()):
-            metrics_dict = loss_dict
-            metrics_dict["data_time"] = data_time
-            self._write_metrics(metrics_dict)
-            self._detect_anomaly(losses, loss_dict)
+        losses = sum(loss_dict.values())

         """
         If you need accumulate gradients or something similar, you can
         wrap the optimizer with your custom `zero_grad()` method.
         """
         self.optimizer.zero_grad()
+        losses.backward()

-        if self.amp_enabled:
-            self.scaler.scale(losses).backward()
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-        else:
-            losses.backward()
-            """
-            If you need gradient clipping/scaling or other processing, you can
-            wrap the optimizer with your custom `step()` method.
-            """
-            self.optimizer.step()
+        self._write_metrics(loss_dict, data_time)

-    def _detect_anomaly(self, losses, loss_dict):
-        if not torch.isfinite(losses).all():
-            raise FloatingPointError(
-                "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
-                    self.iter, loss_dict
-                )
-            )
+        """
+        If you need gradient clipping/scaling or other processing, you can
+        wrap the optimizer with your custom `step()` method.
+        """
+        self.optimizer.step()

-    def _write_metrics(self, metrics_dict: dict):
+    def _write_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float):
         """
         Args:
-            metrics_dict (dict): dict of scalar metrics
+            loss_dict (dict): dict of scalar losses
+            data_time (float): time taken by the dataloader iteration
         """
-        metrics_dict = {
-            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
-            for k, v in metrics_dict.items()
-        }
-        # gather metrics among all workers for logging
-        # This assumes we do DDP-style training, which is currently the only
-        # supported method in fastreid.
-        all_metrics_dict = comm.gather(metrics_dict)
+        device = next(iter(loss_dict.values())).device
+
+        # Use a new stream so these ops don't wait for DDP or backward
+        with torch.cuda.stream(torch.cuda.Stream() if device.type == "cuda" else None):
+            metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
+            metrics_dict["data_time"] = data_time
+
+            # Gather metrics among all workers for logging
+            # This assumes we do DDP-style training, which is currently the only
+            # supported method in detectron2.
+            all_metrics_dict = comm.gather(metrics_dict)

         if comm.is_main_process():
-            if "data_time" in all_metrics_dict[0]:
-                # data_time among workers can have high variance. The actual latency
-                # caused by data_time is the maximum among workers.
-                data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
-                self.storage.put_scalar("data_time", data_time)
+            storage = get_event_storage()
+
+            # data_time among workers can have high variance. The actual latency
+            # caused by data_time is the maximum among workers.
+            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+            storage.put_scalar("data_time", data_time)

             # average the rest metrics
             metrics_dict = {
                 k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
             }
-            total_losses_reduced = sum(loss for loss in metrics_dict.values())
+            total_losses_reduced = sum(metrics_dict.values())
+            if not np.isfinite(total_losses_reduced):
+                raise FloatingPointError(
+                    f"Loss became infinite or NaN at iteration={self.iter}!\n"
+                    f"loss_dict = {metrics_dict}"
+                )

-            self.storage.put_scalar("total_loss", total_losses_reduced)
+            storage.put_scalar("total_loss", total_losses_reduced)
             if len(metrics_dict) > 1:
-                self.storage.put_scalars(**metrics_dict)
+                storage.put_scalars(**metrics_dict)
+
+
+class AMPTrainer(SimpleTrainer):
+    """
+    Like :class:`SimpleTrainer`, but uses apex automatic mixed precision
+    in the training loop.
+    """
+
+    def run_step(self):
+        """
+        Implement the AMP training logic.
+        """
+        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+
+        start = time.perf_counter()
+        data = next(self._data_loader_iter)
+        data_time = time.perf_counter() - start
+
+        outs = self.model(data)
+
+        # Compute loss
+        if isinstance(self.model, DistributedDataParallel):
+            loss_dict = self.model.module.losses(outs)
+        else:
+            loss_dict = self.model.losses(outs)
+
+        losses = sum(loss_dict.values())
+
+        self.optimizer.zero_grad()
+
+        with amp.scale_loss(losses, self.optimizer) as scaled_loss:
+            scaled_loss.backward()
+
+        self._write_metrics(loss_dict, data_time)
+
+        self.optimizer.step()
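The trainer now keeps both counters: `iter` advances every step, `epoch` every outer loop, and resuming starts from `start_epoch * iters_per_epoch`. A small standalone sketch of that bookkeeping with made-up numbers:

start_epoch, max_epoch, iters_per_epoch = 2, 4, 5     # assumed values

iter_ = start_iter = start_epoch * iters_per_epoch    # resume at iteration 10
for epoch in range(start_epoch, max_epoch):
    for _ in range(iters_per_epoch):
        # run_step() would execute one mini-batch here
        iter_ += 1
    # after_epoch hooks (checkpoint, eval, LR step) fire here, keyed on `epoch`

print(iter_)   # 20 == max_epoch * iters_per_epoch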
@@ -7,7 +7,7 @@
 from .activation import *
 from .arc_softmax import ArcSoftmax
 from .circle_softmax import CircleSoftmax
-from .am_softmax import AMSoftmax
+from .cos_softmax import CosSoftmax
 from .batch_drop import BatchDrop
 from .batch_norm import *
 from .context_block import ContextBlock

@@ -15,5 +15,5 @@ from .frn import FRN, TLU
 from .non_local import Non_local
 from .pooling import *
 from .se_layer import SELayer
-from .splat import SplAtConv2d
+from .splat import SplAtConv2d, DropBlock2D
 from .gather_layer import GatherLayer
@@ -10,14 +10,12 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-__all__ = [
-    "BatchNorm",
-    "IBN",
-    "GhostBatchNorm",
-    "FrozenBatchNorm",
-    "SyncBatchNorm",
-    "get_norm",
-]
+try:
+    from apex import parallel
+except ImportError:
+    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")
+
+__all__ = ["IBN", "get_norm"]


 class BatchNorm(nn.BatchNorm2d):

@@ -30,7 +28,7 @@ class BatchNorm(nn.BatchNorm2d):
         self.bias.requires_grad_(not bias_freeze)


-class SyncBatchNorm(nn.SyncBatchNorm):
+class SyncBatchNorm(parallel.SyncBatchNorm):
     def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                  bias_init=0.0):
         super().__init__(num_features, eps=eps, momentum=momentum)
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from torch.nn import Parameter


-class AMSoftmax(nn.Module):
+class CosSoftmax(nn.Module):
     r"""Implement of large margin cosine distance:
     Args:
         in_feat: size of each input sample
@@ -11,7 +11,7 @@ class Non_local(nn.Module):
         super(Non_local, self).__init__()

         self.in_channels = in_channels
-        self.inter_channels = in_channels // reduc_ratio
+        self.inter_channels = reduc_ratio // reduc_ratio

         self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                            kernel_size=1, stride=1, padding=0)
@@ -19,7 +19,7 @@ class SplAtConv2d(nn.Module):
     def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
                  dilation=(1, 1), groups=1, bias=True,
                  radix=2, reduction_factor=4,
-                 rectify=False, rectify_avg=False, norm_layer=None, num_splits=1,
+                 rectify=False, rectify_avg=False, norm_layer=None,
                  dropblock_prob=0.0, **kwargs):
         super(SplAtConv2d, self).__init__()
         padding = _pair(padding)

@@ -45,7 +45,8 @@ class SplAtConv2d(nn.Module):
         if self.use_bn:
             self.bn1 = get_norm(norm_layer, inter_channels)
         self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality)
-
+        if dropblock_prob > 0.0:
+            self.dropblock = DropBlock2D(dropblock_prob, 3)
         self.rsoftmax = rSoftMax(radix, groups)

     def forward(self, x):

@@ -58,7 +59,10 @@ class SplAtConv2d(nn.Module):

         batch, rchannel = x.shape[:2]
         if self.radix > 1:
-            splited = torch.split(x, rchannel // self.radix, dim=1)
+            if torch.__version__ < '1.5':
+                splited = torch.split(x, int(rchannel // self.radix), dim=1)
+            else:
+                splited = torch.split(x, rchannel // self.radix, dim=1)
             gap = sum(splited)
         else:
             gap = x

@@ -73,7 +77,10 @@ class SplAtConv2d(nn.Module):
         atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

         if self.radix > 1:
-            attens = torch.split(atten, rchannel // self.radix, dim=1)
+            if torch.__version__ < '1.5':
+                attens = torch.split(atten, int(rchannel // self.radix), dim=1)
+            else:
+                attens = torch.split(atten, rchannel // self.radix, dim=1)
             out = sum([att * split for (att, split) in zip(attens, splited)])
         else:
             out = atten * x

@@ -95,3 +102,8 @@ class rSoftMax(nn.Module):
         else:
             x = torch.sigmoid(x)
         return x
+
+
+class DropBlock2D(object):
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
@@ -14,6 +14,7 @@ from torch import nn

from fastreid.layers import get_norm
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)

@@ -463,23 +464,7 @@ def init_pretrained_weights(model, key=''):
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(cached_file)
        )
    else:
        logger.info(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(cached_file)
        )
        if len(discarded_layers) > 0:
            logger.info(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.format(discarded_layers)
            )
    return model_dict


@BACKBONE_REGISTRY.register()

@@ -513,7 +498,6 @@ def build_osnet_backbone(cfg):
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")
            model.load_state_dict(state_dict)
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e

@@ -526,5 +510,15 @@ def build_osnet_backbone(cfg):
        else:
            pretrain_key = "osnet_" + depth

        init_pretrained_weights(model, pretrain_key)
        state_dict = init_pretrained_weights(model, pretrain_key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    return model
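With init_pretrained_weights now returning the merged state dict, the builder loads it with strict=False and logs whatever does not line up. A small self-contained sketch of how load_state_dict(strict=False) reports that (SmallNet and the fabricated checkpoint below are illustrative only):

import torch
import torch.nn as nn


class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.fc = nn.Linear(8, 4)


model = SmallNet()
checkpoint = {"conv.weight": torch.zeros(8, 3, 3, 3), "classifier.weight": torch.zeros(4, 8)}

incompatible = model.load_state_dict(checkpoint, strict=False)
print(incompatible.missing_keys)     # ['conv.bias', 'fc.weight', 'fc.bias']
print(incompatible.unexpected_keys)  # ['classifier.weight']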

@@ -9,12 +9,7 @@ import math
import torch
from torch import nn

from fastreid.layers import (
    IBN,
    Non_local,
    SplAtConv2d,
    get_norm,
)
from fastreid.layers import SplAtConv2d, get_norm, DropBlock2D
from fastreid.utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
from .build import BACKBONE_REGISTRY

@@ -46,18 +41,15 @@ class Bottleneck(nn.Module):
    # pylint: disable=unused-argument
    expansion = 4

    def __init__(self, inplanes, planes, bn_norm, with_ibn=False, stride=1, downsample=None,
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 radix=1, cardinality=1, bottleneck_width=64,
                 avd=False, avd_first=False, dilation=1, is_first=False,
                 rectified_conv=False, rectify_avg=False,
                 dropblock_prob=0.0, last_gamma=False):
                 norm_layer=None, dropblock_prob=0.0, last_gamma=False):
        super(Bottleneck, self).__init__()
        group_width = int(planes * (bottleneck_width / 64.)) * cardinality
        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
        if with_ibn:
            self.bn1 = IBN(group_width, bn_norm)
        else:
            self.bn1 = get_norm(bn_norm, group_width)
        self.bn1 = get_norm(norm_layer, group_width)
        self.dropblock_prob = dropblock_prob
        self.radix = radix
        self.avd = avd and (stride > 1 or is_first)

@@ -67,14 +59,20 @@ class Bottleneck(nn.Module):
            self.avd_layer = nn.AvgPool2d(3, stride, padding=1)
            stride = 1

        if radix > 1:
        if dropblock_prob > 0.0:
            self.dropblock1 = DropBlock2D(dropblock_prob, 3)
            if radix == 1:
                self.dropblock2 = DropBlock2D(dropblock_prob, 3)
            self.dropblock3 = DropBlock2D(dropblock_prob, 3)

        if radix >= 1:
            self.conv2 = SplAtConv2d(
                group_width, group_width, kernel_size=3,
                stride=stride, padding=dilation,
                dilation=dilation, groups=cardinality, bias=False,
                radix=radix, rectify=rectified_conv,
                rectify_avg=rectify_avg,
                norm_layer=bn_norm,
                norm_layer=norm_layer,
                dropblock_prob=dropblock_prob)
        elif rectified_conv:
            from rfconv import RFConv2d

@@ -83,17 +81,17 @@ class Bottleneck(nn.Module):
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False,
                average_mode=rectify_avg)
            self.bn2 = get_norm(bn_norm, group_width)
            self.bn2 = get_norm(norm_layer, group_width)
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride,
                padding=dilation, dilation=dilation,
                groups=cardinality, bias=False)
            self.bn2 = get_norm(bn_norm, group_width)
            self.bn2 = get_norm(norm_layer, group_width)

        self.conv3 = nn.Conv2d(
            group_width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = get_norm(bn_norm, planes * 4)
        self.bn3 = get_norm(norm_layer, planes * 4)

        if last_gamma:
            from torch.nn.init import zeros_

@@ -116,7 +114,7 @@ class Bottleneck(nn.Module):
            out = self.avd_layer(out)

        out = self.conv2(out)
        if self.radix == 1:
        if self.radix == 0:
            out = self.bn2(out)
            if self.dropblock_prob > 0.0:
                out = self.dropblock2(out)

@@ -139,8 +137,8 @@ class Bottleneck(nn.Module):
        return out


class ResNest(nn.Module):
    """ResNet Variants ResNest
class ResNeSt(nn.Module):
    """ResNet Variants
    Parameters
    ----------
    block : Block

@@ -161,15 +159,15 @@ class ResNest(nn.Module):
    """

    # pylint: disable=unused-variable
    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers, radix=1,
                 groups=1,
                 bottleneck_width=64,
    def __init__(self, last_stride, block, layers, radix=1, groups=1, bottleneck_width=64,
                 dilated=False, dilation=1,
                 deep_stem=False, stem_width=64, avg_down=False,
                 rectified_conv=False, rectify_avg=False,
                 avd=False, avd_first=False,
                 final_drop=0.0, dropblock_prob=0,
                 last_gamma=False):
                 last_gamma=False, norm_layer="BN"):
        if last_stride == 1: dilation = 2

        self.cardinality = groups
        self.bottleneck_width = bottleneck_width
        # ResNet-D params

@@ -193,52 +191,51 @@ class ResNest(nn.Module):
        if deep_stem:
            self.conv1 = nn.Sequential(
                conv_layer(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False, **conv_kwargs),
                get_norm(bn_norm, stem_width),
                get_norm(norm_layer, stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
                get_norm(bn_norm, stem_width),
                get_norm(norm_layer, stem_width),
                nn.ReLU(inplace=True),
                conv_layer(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False, **conv_kwargs),
            )
        else:
            self.conv1 = conv_layer(3, 64, kernel_size=7, stride=2, padding=3,
                                    bias=False, **conv_kwargs)
        self.bn1 = get_norm(bn_norm, self.inplanes)
        self.bn1 = get_norm(norm_layer, self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        if dilated or dilation == 4:
            self.layer3 = self._make_layer(block, 256, layers[2], 1, bn_norm, with_ibn=with_ibn,
                                           dilation=2, dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                           dilation=4, dropblock_prob=dropblock_prob)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=4, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        elif dilation == 2:
            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
                                           dilation=1, dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], 1, bn_norm, with_ibn=with_ibn,
                                           dilation=2, dropblock_prob=dropblock_prob)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilation=1, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                           dilation=2, norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        else:
            self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn,
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
            self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn,
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           norm_layer=norm_layer,
                                           dropblock_prob=dropblock_prob)
        self.drop = nn.Dropout(final_drop) if final_drop > 0.0 else None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # fmt: off
        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
        else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
        # fmt: on

    def _make_layer(self, block, planes, blocks, stride=1, bn_norm="BN", with_ibn=False,
                    dilation=1, dropblock_prob=0.0, is_first=True):
    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None,
                    dropblock_prob=0.0, is_first=True):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            down_layers = []

@@ -254,103 +251,54 @@ class ResNest(nn.Module):
            else:
                down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion,
                                             kernel_size=1, stride=stride, bias=False))
            down_layers.append(get_norm(bn_norm, planes * block.expansion))
            down_layers.append(get_norm(norm_layer, planes * block.expansion))
            downsample = nn.Sequential(*down_layers)

        layers = []
        if dilation == 1 or dilation == 2:
            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=1, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                dropblock_prob=dropblock_prob,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        elif dilation == 4:
            layers.append(block(self.inplanes, planes, bn_norm, with_ibn, stride, downsample=downsample,
            layers.append(block(self.inplanes, planes, stride, downsample=downsample,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=2, is_first=is_first, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                dropblock_prob=dropblock_prob,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))
        else:
            raise RuntimeError("=> unknown dilation size: {}".format(dilation))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
            layers.append(block(self.inplanes, planes,
                                radix=self.radix, cardinality=self.cardinality,
                                bottleneck_width=self.bottleneck_width,
                                avd=self.avd, avd_first=self.avd_first,
                                dilation=dilation, rectified_conv=self.rectified_conv,
                                rectify_avg=self.rectify_avg,
                                dropblock_prob=dropblock_prob,
                                norm_layer=norm_layer, dropblock_prob=dropblock_prob,
                                last_gamma=self.last_gamma))

        return nn.Sequential(*layers)

    def _build_nonlocal(self, layers, non_layers, bn_norm):
        self.NL_1 = nn.ModuleList(
            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
        self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
        self.NL_2 = nn.ModuleList(
            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
        self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
        self.NL_3 = nn.ModuleList(
            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
        self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
        self.NL_4 = nn.ModuleList(
            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
        self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        NL1_counter = 0
        if len(self.NL_1_idx) == 0:
            self.NL_1_idx = [-1]
        for i in range(len(self.layer1)):
            x = self.layer1[i](x)
            if i == self.NL_1_idx[NL1_counter]:
                _, C, H, W = x.shape
                x = self.NL_1[NL1_counter](x)
                NL1_counter += 1
        # Layer 2
        NL2_counter = 0
        if len(self.NL_2_idx) == 0:
            self.NL_2_idx = [-1]
        for i in range(len(self.layer2)):
            x = self.layer2[i](x)
            if i == self.NL_2_idx[NL2_counter]:
                _, C, H, W = x.shape
                x = self.NL_2[NL2_counter](x)
                NL2_counter += 1
        # Layer 3
        NL3_counter = 0
        if len(self.NL_3_idx) == 0:
            self.NL_3_idx = [-1]
        for i in range(len(self.layer3)):
            x = self.layer3[i](x)
            if i == self.NL_3_idx[NL3_counter]:
                _, C, H, W = x.shape
                x = self.NL_3[NL3_counter](x)
                NL3_counter += 1
        # Layer 4
        NL4_counter = 0
        if len(self.NL_4_idx) == 0:
            self.NL_4_idx = [-1]
        for i in range(len(self.layer4)):
            x = self.layer4[i](x)
            if i == self.NL_4_idx[NL4_counter]:
                _, C, H, W = x.shape
                x = self.NL_4[NL4_counter](x)
                NL4_counter += 1
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

@@ -368,9 +316,6 @@ def build_resnest_backbone(cfg):
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm = cfg.MODEL.BACKBONE.NORM
    with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
    with_se = cfg.MODEL.BACKBONE.WITH_SE
    with_nl = cfg.MODEL.BACKBONE.WITH_NL
    depth = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

@@ -381,13 +326,6 @@ def build_resnest_backbone(cfg):
        "269x": [3, 30, 48, 8],
    }[depth]

    nl_layers_per_stage = {
        "50x": [0, 2, 3, 0],
        "101x": [0, 2, 3, 0],
        "200x": [0, 2, 3, 0],
        "269x": [0, 2, 3, 0],
    }[depth]

    stem_width = {
        "50x": 32,
        "101x": 64,

@@ -395,10 +333,10 @@ def build_resnest_backbone(cfg):
        "269x": 64,
    }[depth]

    model = ResNest(last_stride, bn_norm, with_ibn, with_nl, Bottleneck, num_blocks_per_stage,
                    nl_layers_per_stage, radix=2, groups=1, bottleneck_width=64,
    model = ResNeSt(last_stride, Bottleneck, num_blocks_per_stage,
                    radix=2, groups=1, bottleneck_width=64,
                    deep_stem=True, stem_width=stem_width, avg_down=True,
                    avd=True, avd_first=False)
                    avd=True, avd_first=False, norm_layer=bn_norm)
    if pretrain:
        # Load pretrain path if specifically
        if pretrain_path:

@@ -39,7 +39,7 @@ class AttrHead(nn.Module):
        if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
        elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'amSoftmax': self.classifier = CosSoftmax(cfg, feat_dim, num_classes)
        else: raise KeyError(f"{cls_type} is not supported!")
        # fmt: on

@@ -55,7 +55,7 @@ class EmbeddingHead(nn.Module):
        if cls_type == 'linear': self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
        elif cls_type == 'arcSoftmax': self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'amSoftmax': self.classifier = AMSoftmax(cfg, feat_dim, num_classes)
        elif cls_type == 'cosSoftmax': self.classifier = CosSoftmax(cfg, feat_dim, num_classes)
        else: raise KeyError(f"{cls_type} is not supported!")
        # fmt: on

@@ -7,4 +7,4 @@
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss
from .circle_loss import circle_loss
from .circle_loss import *

@@ -46,14 +46,6 @@ def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):

    loss = (-targets * log_probs).sum(dim=1)

    """
    # confidence penalty
    conf_penalty = 0.3
    probs = F.softmax(pred_class_logits, dim=1)
    entropy = torch.sum(-probs * log_probs, dim=1)
    loss = torch.clamp_min(loss - conf_penalty * entropy, min=0.)
    """

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)

@@ -9,7 +9,7 @@ import torch.nn.functional as F

from fastreid.utils import comm
from fastreid.layers import GatherLayer
from .utils import concat_all_gather, euclidean_dist, normalize
from .utils import concat_all_gather, euclidean_dist, normalize, cosine_dist


def softmax_weights(dist, mask):

@@ -87,21 +87,21 @@ def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'."""

    if norm_feat: embedding = normalize(embedding, axis=-1)

    # For distributed training, gather all features from different process.
    if comm.get_world_size() > 1:
        all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
        all_targets = concat_all_gather(targets)
    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        all_embedding = embedding
        all_targets = targets
        dist_mat = euclidean_dist(embedding, embedding)
    # For distributed training, gather all features from different process.
    # if comm.get_world_size() > 1:
    #     all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
    #     all_targets = concat_all_gather(targets)
    # else:
    #     all_embedding = embedding
    #     all_targets = targets

    dist_mat = euclidean_dist(all_embedding, all_embedding)

    N, N = dist_mat.size()
    is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t())
    is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
    N = dist_mat.size(0)
    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t())
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t())

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
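The rewritten triplet_loss builds the distance matrix from the per-GPU batch and hands it to hard_example_mining together with the is_pos/is_neg masks. For reference, a hedged, self-contained sketch of what batch-hard mining typically computes (not necessarily fastreid's exact implementation):

import torch


def batch_hard(dist_mat: torch.Tensor, is_pos: torch.Tensor, is_neg: torch.Tensor):
    """is_pos / is_neg are boolean [N, N] masks built from the label matrix as above."""
    dist_ap, _ = torch.max(dist_mat * is_pos.float(), dim=1)           # hardest (farthest) positive per anchor
    dist_an, _ = torch.min(dist_mat + 1e5 * (~is_neg).float(), dim=1)  # hardest (closest) negative per anchor
    return dist_ap, dist_an


targets = torch.tensor([0, 0, 1, 1])
feats = torch.randn(4, 8)
dist = torch.cdist(feats, feats)
is_pos = targets.view(-1, 1).eq(targets.view(1, -1))
is_neg = ~is_pos
print(batch_hard(dist, is_pos, is_neg))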

@@ -5,6 +5,7 @@
"""

import torch
import torch.nn.functional as F


def concat_all_gather(tensor):

@@ -41,9 +42,7 @@ def euclidean_dist(x, y):


def cosine_dist(x, y):
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
    cosine = frac_up / frac_down
    return 1 - cosine
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    dist = 2 - 2 * torch.mm(x, y.t())
    return dist
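The rewritten cosine_dist normalizes both inputs and returns 2 - 2*cos(theta), which is the squared Euclidean distance between unit vectors and exactly twice the previous 1 - cos(theta) value. A quick numerical check of that reading:

import torch
import torch.nn.functional as F

x = F.normalize(torch.randn(5, 16), dim=1)
y = F.normalize(torch.randn(7, 16), dim=1)

new_dist = 2 - 2 * torch.mm(x, y.t())
sq_euclidean = torch.cdist(x, y).pow(2)
old_dist = 1 - torch.mm(x, y.t())

assert torch.allclose(new_dist, sq_euclidean, atol=1e-5)
assert torch.allclose(new_dist, 2 * old_dist, atol=1e-6)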

@@ -89,7 +89,7 @@ class Baseline(nn.Module):
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "CrossEntropyLoss" in loss_names:
            loss_dict['loss_cls'] = cross_entropy_loss(
            loss_dict["loss_cls"] = cross_entropy_loss(
                cls_outputs,
                gt_labels,
                self._cfg.MODEL.LOSSES.CE.EPSILON,

@@ -97,7 +97,7 @@ class Baseline(nn.Module):
            ) * self._cfg.MODEL.LOSSES.CE.SCALE

        if "TripletLoss" in loss_names:
            loss_dict['loss_triplet'] = triplet_loss(
            loss_dict["loss_triplet"] = triplet_loss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.TRI.MARGIN,

@@ -106,11 +106,11 @@ class Baseline(nn.Module):
            ) * self._cfg.MODEL.LOSSES.TRI.SCALE

        if "CircleLoss" in loss_names:
            loss_dict['loss_circle'] = circle_loss(
            loss_dict["loss_circle"] = pairwise_circleloss(
                pred_features,
                gt_labels,
                self._cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                self._cfg.MODEL.LOSSES.CIRCLE.ALPHA,
                self._cfg.MODEL.LOSSES.CIRCLE.GAMMA,
            ) * self._cfg.MODEL.LOSSES.CIRCLE.SCALE

        return loss_dict

@@ -20,7 +20,7 @@ def build_optimizer(cfg, model):
        if "bias" in key:
            lr *= cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
        params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay}]

    solver_opt = cfg.SOLVER.OPT
    # fmt: off

@@ -31,22 +31,36 @@ def build_optimizer(cfg, model):


def build_lr_scheduler(cfg, optimizer):
    scheduler_dict = {}

    if cfg.SOLVER.WARMUP_ITERS > 0:
        warmup_args = {
            "optimizer": optimizer,

            # warmup options
            "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
            "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
            "warmup_method": cfg.SOLVER.WARMUP_METHOD,
        }
        scheduler_dict["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_args)

    scheduler_args = {
        "optimizer": optimizer,

        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,

        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,

        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
        "MultiStepLR": {
            "optimizer": optimizer,
            # multi-step lr scheduler options
            "milestones": cfg.SOLVER.STEPS,
            "gamma": cfg.SOLVER.GAMMA,
        },
        "CosineAnnealingLR": {
            "optimizer": optimizer,
            # cosine annealing lr scheduler options
            "T_max": cfg.SOLVER.MAX_EPOCH,
            "eta_min": cfg.SOLVER.ETA_MIN_LR,
        },

    }
    return getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)

    scheduler_dict["lr_sched"] = getattr(lr_scheduler, cfg.SOLVER.SCHED)(
        **scheduler_args[cfg.SOLVER.SCHED])

    return scheduler_dict
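build_lr_scheduler now returns a dict with a per-iteration warmup_sched and a per-epoch lr_sched, matching the "warmup by iter, lr schedule by epoch" change described in this update. One plausible way a training loop could drive such a pair is sketched below; the loop structure and the two stand-in schedulers are assumptions, not the trainer's actual code.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=3.5e-4)

# stand-ins for the two entries build_lr_scheduler would return
schedulers = {
    "warmup_sched": torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=2000),
    "lr_sched": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60, eta_min=7.7e-7),
}

max_epoch, iters_per_epoch, warmup_iters = 60, 100, 2000
global_iter = 0
for epoch in range(max_epoch):
    for _ in range(iters_per_epoch):
        optimizer.step()                       # placeholder for a real training step
        if global_iter < warmup_iters:
            schedulers["warmup_sched"].step()  # warmup advances once per iteration
        global_iter += 1
    schedulers["lr_sched"].step()              # main schedule advances once per epoch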

@@ -4,34 +4,22 @@
@contact: sherlockliao01@gmail.com
"""

import math
from bisect import bisect_right
from typing import List

import torch
from torch.optim.lr_scheduler import *
from torch.optim.lr_scheduler import _LRScheduler

__all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"]


class WarmupMultiStepLR(_LRScheduler):
class WarmupLR(_LRScheduler):
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            milestones: List[int],
            gamma: float = 0.1,
            warmup_factor: float = 0.001,
            warmup_iters: int = 1000,
            warmup_factor: float = 0.1,
            warmup_iters: int = 10,
            warmup_method: str = "linear",
            last_epoch: int = -1,
            **kwargs,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of" " increasing integers. Got {}", milestones
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method

@@ -42,8 +30,7 @@ class WarmupMultiStepLR(_LRScheduler):
            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
        )
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
            base_lr * warmup_factor for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:

@@ -51,71 +38,6 @@ class WarmupMultiStepLR(_LRScheduler):
        return self.get_lr()


class WarmupCosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            max_iters: int,
            delay_iters: int = 0,
            eta_min_lr: int = 0,
            warmup_factor: float = 0.001,
            warmup_iters: int = 1000,
            warmup_method: str = "linear",
            last_epoch=-1,
            **kwargs
    ):
        self.max_iters = max_iters
        self.delay_iters = delay_iters
        self.eta_min_lr = eta_min_lr
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        assert self.delay_iters >= self.warmup_iters, "Scheduler delay iters must be larger than warmup iters"
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch <= self.warmup_iters:
            warmup_factor = _get_warmup_factor_at_iter(
                self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor,
            )
            return [
                base_lr * warmup_factor for base_lr in self.base_lrs
            ]
        elif self.last_epoch <= self.delay_iters:
            return self.base_lrs

        else:
            return [
                self.eta_min_lr + (base_lr - self.eta_min_lr) *
                (1 + math.cos(
                    math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2
                for base_lr in self.base_lrs]


def _get_warmup_factor_at_iter(
        method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:

@@ -137,7 +59,9 @@ def _get_warmup_factor_at_iter(
    if method == "constant":
        return warmup_factor
    elif method == "linear":
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
        alpha = (1 - iter / warmup_iters) * (1 - warmup_factor)
        return 1 - alpha
    elif method == "exp":
        return warmup_factor ** (1 - iter / warmup_iters)
    else:
        raise ValueError("Unknown warmup method: {}".format(method))

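The rewritten linear branch, 1 - (1 - iter/warmup_iters) * (1 - warmup_factor), is algebraically the same ramp as the old warmup_factor * (1 - alpha) + alpha: both interpolate from warmup_factor at iteration 0 to 1.0 at iteration warmup_iters, while the new "exp" branch decays the factor geometrically toward 1. A short check of the linear equivalence:

# Verify the old and new linear warmup expressions agree at every step.
warmup_factor, warmup_iters = 0.1, 2000

for it in range(warmup_iters + 1):
    alpha = it / warmup_iters
    old = warmup_factor * (1 - alpha) + alpha
    new = 1 - (1 - it / warmup_iters) * (1 - warmup_factor)
    assert abs(old - new) < 1e-12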
@@ -1,5 +1,3 @@
from .lamb import Lamb
from .swa import SWA
from .adam import Adam
from .sgd import SGD

from torch.optim import *

@@ -295,35 +295,35 @@ class PeriodicCheckpointer:
    multiple of period or if `max_iter` is reached.
    """

    def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
    def __init__(self, checkpointer: Any, period: int, max_epoch: int = None):
        """
        Args:
            checkpointer (Any): the checkpointer object used to save
                checkpoints.
            period (int): the period to save checkpoint.
            max_iter (int): maximum number of iterations. When it is reached,
            max_epoch (int): maximum number of epochs. When it is reached,
                a checkpoint named "model_final" will be saved.
        """
        self.checkpointer = checkpointer
        self.period = int(period)
        self.max_iter = max_iter
        self.max_epoch = max_epoch

    def step(self, iteration: int, **kwargs: Any):
    def step(self, epoch: int, **kwargs: Any):
        """
        Perform the appropriate action at the given iteration.
        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            epoch (int): the current epoch, ranged in [0, max_epoch-1].
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        epoch = int(epoch)
        additional_state = {"epoch": epoch}
        additional_state.update(kwargs)
        if (iteration + 1) % self.period == 0:
        if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
            self.checkpointer.save(
                "model_{:07d}".format(iteration), **additional_state
                "model_{:04d}".format(epoch), **additional_state
            )
        if iteration >= self.max_iter - 1:
        if epoch >= self.max_epoch - 1:
            self.checkpointer.save("model_final", **additional_state)

    def save(self, name: str, **kwargs: Any):
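With the checkpointer now counting epochs, step() is presumably called once at the end of each epoch. A hedged usage sketch with a dummy checkpointer (the import path is an assumption based on the fastreid.utils.checkpoint module referenced elsewhere in this commit):

from fastreid.utils.checkpoint import PeriodicCheckpointer  # assumed location of the class above


class DummyCheckpointer:
    """Records the names it would have saved instead of writing files."""
    def __init__(self):
        self.saved = []

    def save(self, name, **kwargs):
        self.saved.append(name)


ckpt = DummyCheckpointer()
periodic = PeriodicCheckpointer(ckpt, period=20, max_epoch=60)

for epoch in range(60):
    # ... one epoch of training ...
    periodic.step(epoch)

print(ckpt.saved)  # ['model_0019', 'model_0039', 'model_final']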

@@ -1,4 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates.
import datetime
import json
import logging

@@ -55,7 +55,7 @@ class JSONWriter(EventWriter):
        [
          {
            "data_time": 0.008433341979980469,
            "iteration": 20,
            "iteration": 19,
            "loss": 1.9228371381759644,
            "loss_box_reg": 0.050025828182697296,
            "loss_classifier": 0.5316952466964722,

@@ -67,7 +67,7 @@ class JSONWriter(EventWriter):
          },
          {
            "data_time": 0.007216215133666992,
            "iteration": 40,
            "iteration": 39,
            "loss": 1.282649278640747,
            "loss_box_reg": 0.06222952902317047,
            "loss_classifier": 0.30682939291000366,

@@ -105,8 +105,9 @@ class JSONWriter(EventWriter):
                if iter <= self._last_write:
                    continue
                to_save[iter][k] = v
        all_iters = sorted(to_save.keys())
        self._last_write = max(all_iters)
        if len(to_save):
            all_iters = sorted(to_save.keys())
            self._last_write = max(all_iters)

        for itr, scalars_per_iter in to_save.items():
            scalars_per_iter["iteration"] = itr

@@ -192,6 +193,12 @@ class CommonMetricPrinter(EventWriter):
    def write(self):
        storage = get_event_storage()
        iteration = storage.iter
        epoch = storage.epoch
        if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc) but not other data,
            # therefore do not write anything after training succeeds, even if this method
            # is called.
            return

        try:
            data_time = storage.history("data_time").avg(20)

@@ -203,7 +210,7 @@ class CommonMetricPrinter(EventWriter):
        eta_string = None
        try:
            iter_time = storage.history("time").global_avg()
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
            eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
            storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        except KeyError:

@@ -213,7 +220,7 @@ class CommonMetricPrinter(EventWriter):
                estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
                    iteration - self._last_write[0]
                )
                eta_seconds = estimate_iter_time * (self._max_iter - iteration)
                eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            self._last_write = (iteration, time.perf_counter())

@@ -229,12 +236,13 @@ class CommonMetricPrinter(EventWriter):

        # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
        self.logger.info(
            " {eta}iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
            " {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
                eta=f"eta: {eta_string} " if eta_string else "",
                epoch=epoch,
                iter=iteration,
                losses=" ".join(
                    [
                        "{}: {:.4g}".format(k, v.median(20))
                        "{}: {:.4g}".format(k, v.median(200))
                        for k, v in storage.histories().items()
                        if "loss" in k
                    ]

@@ -394,17 +402,25 @@ class EventStorage:

    def step(self):
        """
        User should call this function at the beginning of each iteration, to
        notify the storage of the start of a new iteration.
        The storage will then be able to associate the new data with the
        correct iteration number.
        User should either: (1) Call this function to increment storage.iter when needed. Or
        (2) Set `storage.iter` to the correct iteration number before each iteration.
        The storage will then be able to associate the new data with an iteration number.
        """
        self._iter += 1

    @property
    def iter(self):
        """
        Returns:
            int: The current iteration number. When used together with a trainer,
                this is ensured to be the same as trainer.iter.
        """
        return self._iter

    @iter.setter
    def iter(self, val):
        self._iter = int(val)

    @property
    def iteration(self):
        # for backward compatibility

@@ -20,12 +20,13 @@ def weights_init_kaiming(m):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.normal_(m.weight, 1.0, 0.02)
            # nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


@@ -15,3 +15,6 @@ Note that these are research projects, and therefore may not have the same level

External projects in the community that use fastreid:

# Competitions

- [NAIC20]() coming soon, stay tuned.