finish v0.2 ddp training

liaoxingyu 2020-07-06 16:57:43 +08:00
parent 5ae2cff47e
commit fec7abc461
54 changed files with 1257 additions and 2320 deletions

View File

@ -20,6 +20,9 @@ from fastreid.config import get_cfg
from fastreid.utils.file_io import PathManager
from predictor import FeatureExtractionDemo
# import project-specific modules like the example below
# from projects.PartialReID.partialreid import *
cudnn.benchmark = True
@ -40,12 +43,7 @@ def get_parser():
help="path to config file",
)
parser.add_argument(
'--device',
default='cuda: 1',
help='CUDA device to use'
)
parser.add_argument(
'--parallel',
"--parallel",
action='store_true',
help='If use multiprocess for feature extraction.'
)
@ -72,7 +70,7 @@ def get_parser():
if __name__ == '__main__':
args = get_parser().parse_args()
cfg = setup_cfg(args)
demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
PathManager.mkdirs(args.output)
if args.input:

View File

@ -21,7 +21,7 @@ except RuntimeError:
class FeatureExtractionDemo(object):
def __init__(self, cfg, device='cuda:0', parallel=False):
def __init__(self, cfg, parallel=False):
"""
Args:
cfg (CfgNode):
@ -35,7 +35,7 @@ class FeatureExtractionDemo(object):
self.num_gpus = torch.cuda.device_count()
self.predictor = AsyncPredictor(cfg, self.num_gpus)
else:
self.predictor = DefaultPredictor(cfg, device)
self.predictor = DefaultPredictor(cfg)
def run_on_image(self, original_image):
"""
@ -51,6 +51,8 @@ class FeatureExtractionDemo(object):
original_image = original_image[:, :, ::-1]
# Apply pre-processing to image.
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
# Add a batch dimension so the tensor shape matches the network input
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
predictions = self.predictor(image)
return predictions
@ -68,16 +70,16 @@ class FeatureExtractionDemo(object):
if cnt >= buffer_size:
batch = batch_data.popleft()
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()
while len(batch_data):
batch = batch_data.popleft()
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()
else:
for batch in data_loader:
predictions = self.predictor(batch["images"])
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()
class AsyncPredictor:
@ -90,15 +92,14 @@ class AsyncPredictor:
pass
class _PredictWorker(mp.Process):
def __init__(self, cfg, device, task_queue, result_queue):
def __init__(self, cfg, task_queue, result_queue):
self.cfg = cfg
self.device = device
self.task_queue = task_queue
self.result_queue = result_queue
super().__init__()
def run(self):
predictor = DefaultPredictor(self.cfg, self.device)
predictor = DefaultPredictor(self.cfg)
while True:
task = self.task_queue.get()
@ -120,9 +121,11 @@ class AsyncPredictor:
self.result_queue = mp.Queue(maxsize=num_workers * 3)
self.procs = []
for gpuid in range(max(num_gpus, 1)):
device = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
cfg = cfg.clone()
cfg.defrost()
cfg.MODEL.DEVICE = "cuda: {}".format(gpuid) if num_gpus > 0 else "cpu"
self.procs.append(
AsyncPredictor._PredictWorker(cfg, device, self.task_queue, self.result_queue)
AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
)
self.put_idx = 0
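With the device argument removed, a minimal usage sketch of the updated demo API looks like the following; the config path, weights path and image path are hypothetical placeholders, not files from this commit.

import cv2

from fastreid.config import get_cfg
from predictor import FeatureExtractionDemo

cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")  # hypothetical config path
cfg.MODEL.WEIGHTS = "model_final.pth"  # hypothetical checkpoint
cfg.freeze()

demo = FeatureExtractionDemo(cfg, parallel=False)  # device now comes from cfg.MODEL.DEVICE
feat = demo.run_on_image(cv2.imread("query.jpg"))  # hypothetical BGR image, as run_on_image expects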

View File

@ -42,11 +42,6 @@ def get_parser():
metavar="FILE",
help="path to config file",
)
parser.add_argument(
'--device',
default='cuda: 1',
help='CUDA device to use'
)
parser.add_argument(
'--parallel',
action='store_true',
@ -101,7 +96,7 @@ if __name__ == '__main__':
logger = setup_logger()
cfg = setup_cfg(args)
test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
logger.info("Start extracting image features")
feats = []

View File

@ -20,8 +20,9 @@ _C = CN()
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = 'Baseline'
_C.MODEL.OPEN_LAYERS = ['']
_C.MODEL.FREEZE_LAYERS = ['']
# ---------------------------------------------------------------------------- #
# Backbone options
@ -57,7 +58,7 @@ _C.MODEL.HEADS.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.HEADS.NORM_SPLIT = 1
# Number of identity
_C.MODEL.HEADS.NUM_CLASSES = 751
_C.MODEL.HEADS.NUM_CLASSES = 0
# Input feature dimension
_C.MODEL.HEADS.IN_FEAT = 2048
# Reduction dimension in head
@ -74,7 +75,6 @@ _C.MODEL.HEADS.CLS_LAYER = "linear" # "arcface" or "circle"
_C.MODEL.HEADS.MARGIN = 0.15
_C.MODEL.HEADS.SCALE = 128
# ---------------------------------------------------------------------------- #
# REID LOSSES options
# ---------------------------------------------------------------------------- #
@ -115,7 +115,7 @@ _C.MODEL.WEIGHTS = ""
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
# Values to be used for image normalization
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
#
# -----------------------------------------------------------------------------
# INPUT
@ -194,10 +194,10 @@ _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
# Multi-step learning rate options
_C.SOLVER.SCHED = "WarmupMultiStepLR"
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)
_C.SOLVER.STEPS = [30, 55]
# Cosine annealing learning rate options
_C.SOLVER.DELAY_ITERS = 100
_C.SOLVER.DELAY_ITERS = 0
_C.SOLVER.ETA_MIN_LR = 3e-7
# Warmup options
@ -210,15 +210,14 @@ _C.SOLVER.FREEZE_ITERS = 0
# SWA options
_C.SOLVER.SWA = CN()
_C.SOLVER.SWA.ENABLED = False
_C.SOLVER.SWA.ITER = 0
_C.SOLVER.SWA.PERIOD = 10
_C.SOLVER.SWA.ITER = 10
_C.SOLVER.SWA.PERIOD = 2
_C.SOLVER.SWA.LR_FACTOR = 10.
_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
_C.SOLVER.SWA.LR_SCHED = False
_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.LOG_PERIOD = 30
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch

View File

@ -12,3 +12,4 @@ __all__ = [k for k in globals().keys() if not k.startswith("_")]
# but still make them available here
from .hooks import *
from .defaults import *
from .launch import *

View File

@ -11,20 +11,22 @@ since they are meant to represent the "common default behavior" people need in t
import argparse
import logging
import os
import sys
from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
inference_on_dataset, print_csv_format)
from fastreid.modeling.meta_arch import build_model
from fastreid.layers.sync_bn import patch_replication_callback
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
from fastreid.utils.env import seed_all_rng
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.collect_env import collect_env_info
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger
@ -36,7 +38,7 @@ __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "Defa
def default_argument_parser():
"""
Create a parser with some common arguments used by detectron2 users.
Create a parser with some common arguments used by fastreid users.
Returns:
argparse.ArgumentParser:
"""
@ -48,17 +50,17 @@ def default_argument_parser():
help="whether to attempt to resume from the checkpoint directory",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
# parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
# parser.add_argument("--num-machines", type=int, default=1)
# parser.add_argument(
# "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
# )
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid()) % 2 ** 14
# parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
parser.add_argument(
"opts",
help="Modify config options using the command-line",
@ -87,7 +89,7 @@ def default_setup(cfg, args):
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
# logger.info("Environment info:\n" + collect_env_info())
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
@ -106,6 +108,9 @@ def default_setup(cfg, args):
f.write(cfg.dump())
logger.info("Full config saved to {}".format(os.path.abspath(path)))
# make sure each worker has a different, yet deterministic seed if specified
seed_all_rng()
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
@ -128,17 +133,14 @@ class DefaultPredictor:
outputs = pred(inputs)
"""
def __init__(self, cfg, device='cpu'):
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.cfg.defrost()
self.cfg.MODEL.BACKBONE.PRETRAIN = False
self.device = device
self.model = build_model(self.cfg)
self.model.to(device)
self.model.eval()
checkpointer = Checkpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)
def __call__(self, image):
"""
@ -147,9 +149,8 @@ class DefaultPredictor:
Returns:
predictions (torch.tensor): the output features of the model
"""
inputs = {"images": image}
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
image = image.to(self.device)
inputs = {"images": image}
predictions = self.model(inputs)
# Normalize feature to compute cosine distance
pred_feat = F.normalize(predictions)
@ -196,20 +197,24 @@ class DefaultTrainer(SimpleTrainer):
cfg (CfgNode):
"""
self.cfg = cfg
logger = logging.getLogger(__name__)
logger = logging.getLogger("fastreid")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for fastreid
setup_logger()
# Assume these objects must be constructed in this order.
data_loader = self.build_train_loader(cfg)
cfg = self.auto_scale_hyperparams(cfg, data_loader)
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('Prepare training set')
data_loader = self.build_train_loader(cfg)
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
if cfg.MODEL.BACKBONE.NORM == "syncBN":
# Monkey-patching with syncBN
patch_replication_callback(model)
model = model.cuda()
# For training, wrap with DDP. But don't need this for inference.
if comm.get_world_size() > 1:
# Refer to https://github.com/pytorch/pytorch/issues/22049: set `find_unused_parameters=True`
# if some of the parameters are not updated.
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
)
super().__init__(model, data_loader, optimizer)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
@ -218,8 +223,8 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
save_to_disk=comm.is_main_process(),
optimizer=optimizer,
scheduler=self.scheduler,
)
@ -246,12 +251,10 @@ class DefaultTrainer(SimpleTrainer):
# Reinitialize dataloader iter because when we update dataset person identity dict
# to resume training, DataLoader won't update this dictionary when using multiprocess
# because of the function scope.
self._data_loader_iter = iter(self.data_loader)
self.start_iter = checkpoint.get("iteration", -1) if resume else -1
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1
if resume and self.checkpointer.has_checkpoint():
self.start_iter = checkpoint.get("iteration", -1) + 1
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
def build_hooks(self):
"""
@ -292,31 +295,38 @@ class DefaultTrainer(SimpleTrainer):
cfg.TEST.PRECISE_BN.NUM_ITER,
))
if cfg.MODEL.OPEN_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
open_layers = ",".join(cfg.MODEL.OPEN_LAYERS)
logger.info(f'Open "{open_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iters')
if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
ret.append(hooks.FreezeLayer(
self.model,
cfg.MODEL.OPEN_LAYERS,
self.optimizer,
cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS,
))
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
# if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
if comm.is_main_process():
self._last_eval_results = self.test(self.cfg, self.model)
torch.cuda.empty_cache()
return self._last_eval_results
else:
return None
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), 200))
return ret
def build_writers(self):
@ -351,9 +361,12 @@ class DefaultTrainer(SimpleTrainer):
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
# if hasattr(self, "_last_eval_results") and comm.is_main_process():
# verify_results(self.cfg, self._last_eval_results)
# return self._last_eval_results
if comm.is_main_process():
assert hasattr(
self, "_last_eval_results"
), "No evaluation results obtained during training!"
# verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
@classmethod
def build_model(cls, cfg):
@ -365,7 +378,7 @@ class DefaultTrainer(SimpleTrainer):
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
# logger.info("Model:\n{}".format(model))
return model
@classmethod
@ -394,6 +407,8 @@ class DefaultTrainer(SimpleTrainer):
It now calls :func:`fastreid.data.build_detection_train_loader`.
Overwrite it if you'd like a different data loader.
"""
logger = logging.getLogger(__name__)
logger.info("Prepare training set")
return build_reid_train_loader(cfg)
@classmethod
@ -433,7 +448,7 @@ class DefaultTrainer(SimpleTrainer):
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
logger.info(f'prepare test set')
logger.info("Prepare testing set")
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.
@ -451,15 +466,53 @@ class DefaultTrainer(SimpleTrainer):
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
if comm.is_main_process():
assert isinstance(
results, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results
)
print_csv_format(results)
if len(results) == 1: results = list(results.values())[0]
return results
@staticmethod
def auto_scale_hyperparams(cfg, data_loader):
r"""
This is used to automatically compute the actual number of training iterations,
because some hyper-parameters, such as MAX_ITER, are specified in training epochs rather than iterations,
so they need to be converted to iterations here.
"""
cfg = cfg.clone()
frozen = cfg.is_frozen()
cfg.defrost()
iters_per_epoch = len(data_loader.dataset) // (cfg.SOLVER.IMS_PER_BATCH * comm.get_world_size())
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
cfg.SOLVER.MAX_ITER *= iters_per_epoch
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
for i in range(len(cfg.SOLVER.STEPS)):
cfg.SOLVER.STEPS[i] *= iters_per_epoch
cfg.SOLVER.SWA.ITER *= iters_per_epoch
cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
cfg.TEST.EVAL_PERIOD *= iters_per_epoch
logger = logging.getLogger(__name__)
logger.info(
f"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
f"max_Iter={cfg.SOLVER.MAX_ITER}, wamrup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
f"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
f"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
f"eval_Iter={cfg.TEST.EVAL_PERIOD}."
)
if frozen: cfg.freeze()
return cfg
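As a worked example of the epoch-to-iteration conversion above (all numbers are hypothetical):

# Auto-scaling arithmetic with made-up numbers, mirroring auto_scale_hyperparams.
dataset_size, ims_per_batch, world_size = 12800, 64, 2
iters_per_epoch = dataset_size // (ims_per_batch * world_size)  # 100
max_iter = 120 * iters_per_epoch                                # MAX_ITER of 120 epochs -> 12000 iterations
steps = [s * iters_per_epoch for s in [30, 55]]                 # STEPS [30, 55] -> [3000, 5500]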

View File

@ -4,7 +4,6 @@
import datetime
import itertools
import logging
import warnings
import os
import tempfile
import time
@ -12,16 +11,17 @@ from collections import Counter
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from .train_loop import HookBase
from fastreid.solver import optim
from fastreid.evaluation.testing import flatten_results_dict
from fastreid.solver import optim
from fastreid.utils import comm
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fastreid.utils.events import EventStorage, EventWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
from fastreid.utils.timer import Timer
from .train_loop import HookBase
__all__ = [
"CallbackHook",
@ -308,7 +308,6 @@ class EvalHook(HookBase):
"""
self._period = eval_period
self._func = eval_function
self._done_eval_at_last = False
def _do_eval(self):
results = self._func()
@ -329,22 +328,16 @@ class EvalHook(HookBase):
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
torch.cuda.empty_cache()
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_eval()
if is_final:
self._done_eval_at_last = True
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
def after_train(self):
if not self._done_eval_at_last:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
@ -418,27 +411,24 @@ class PreciseBN(HookBase):
update_bn_stats(self._model, data_loader(), self._num_iter)
class LRFinder(HookBase):
pass
class FreezeLayer(HookBase):
def __init__(self, model, open_layer_names, freeze_iters):
def __init__(self, model, optimizer, freeze_layers, freeze_iters):
self._logger = logging.getLogger(__name__)
if isinstance(model, nn.DataParallel):
if isinstance(model, DistributedDataParallel):
model = model.module
self.model = model
self.optimizer = optimizer
self.freeze_layers = freeze_layers
self.freeze_iters = freeze_iters
self.open_layer_names = open_layer_names
# previous requires grad status
param_grad = {}
for name, param in self.model.named_parameters():
param_grad[name] = param.requires_grad
self.param_grad = param_grad
# Previous parameters freeze status
param_freeze = {}
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
param_freeze[param_name] = param_group['freeze']
self.param_freeze = param_freeze
def before_step(self):
# Freeze specific layers
@ -450,28 +440,28 @@ class FreezeLayer(HookBase):
self.open_all_layer()
def freeze_specific_layer(self):
for layer in self.open_layer_names:
for layer in self.freeze_layers:
if not hasattr(self.model, layer):
self._logger.info(f'"{layer}" is not an attribute of the model, will skip this layer')
self._logger.info(f'{layer} is not an attribute of the model, will skip this layer')
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
if param_name.split('.')[0] in self.freeze_layers:
param_group['freeze'] = True
# Change BN in freeze layers to eval mode
for name, module in self.model.named_children():
if name in self.open_layer_names:
module.train()
for p in module.parameters():
p.requires_grad = True
else:
module.eval()
for p in module.parameters():
p.requires_grad = False
if name in self.freeze_layers: module.eval()
def open_all_layer(self):
self.model.train()
for name, param in self.model.named_parameters():
param.requires_grad = self.param_grad[name]
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
param_group['freeze'] = self.param_freeze[param_name]
class SWA(HookBase):
def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False,):
def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ):
self.swa_start = swa_start
self.swa_freq = swa_freq
self.swa_lr_factor = swa_lr_factor
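FreezeLayer now freezes layer groups by toggling a per-param-group "freeze" flag on the optimizer (and switching the matching modules to eval mode) instead of flipping requires_grad. The sketch below is only an illustration, under the assumption that the optimizer skips groups flagged as frozen; it is not code from this commit.

import torch

class FreezeAwareSGD(torch.optim.SGD):
    # Minimal illustration of an optimizer step that honors param_group["freeze"].
    def step(self, closure=None):
        frozen = [(g, g["lr"]) for g in self.param_groups if g.get("freeze", False)]
        for g, _ in frozen:
            g["lr"] = 0.0  # zero learning rate: frozen groups get no parameter update this step
        loss = super().step(closure)
        for g, lr in frozen:
            g["lr"] = lr   # restore the original learning rates
        return loss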

fastreid/engine/launch.py (new file, 103 lines added)
View File

@ -0,0 +1,103 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fastreid.utils import comm
__all__ = ["launch"]
def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port
def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
"""
Launch multi-gpu or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of GPUs per machine
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
args (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes
if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
daemon=False,
)
else:
main_func(*args)
def _distributed_worker(
local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
global_rank = machine_rank * num_gpus_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
assert num_gpus_per_machine <= torch.cuda.device_count()
torch.cuda.set_device(local_rank)
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
main_func(*args)
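For reference, a typical way to wire launch() to a training entry point, mirroring the detectron2 pattern this file is based on; main(args) is a hypothetical per-process function, and the flag names come from default_argument_parser.

from fastreid.engine import default_argument_parser, launch

def main(args):
    # hypothetical per-process entry: build cfg, create DefaultTrainer, run trainer.train()
    ...

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )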

View File

@ -5,10 +5,12 @@ https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/tra
"""
import logging
import numpy as np
import time
import weakref
import numpy as np
import torch
import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage
@ -194,12 +196,12 @@ class SimpleTrainer(TrainerBase):
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If you want to do something with the heads, you can wrap the model.
"""
outputs = self.model(data)
loss_dict = self.model.module.losses(outputs)
losses = sum(loss for loss in loss_dict.values())
loss_dict = self.model(data)
losses = sum(loss_dict.values())
self._detect_anomaly(losses, loss_dict)
metrics_dict = loss_dict
@ -238,23 +240,22 @@ class SimpleTrainer(TrainerBase):
}
# gather metrics among all workers for logging
# This assumes we do DDP-style training, which is currently the only
# supported method in detectron2.
# supported method in fastreid.
all_metrics_dict = comm.gather(metrics_dict)
# if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)
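run_step now expects the model itself to return the loss dict in training mode (instead of calling model.module.losses on the outputs), so the whole loss computation stays inside the DDP-wrapped forward. A toy sketch of that contract, as an assumption rather than code from this commit:

import torch.nn as nn
import torch.nn.functional as F

class ToyMetaArch(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Linear(16, 32)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, batched_inputs):
        feats = self.backbone(batched_inputs["images"])
        if self.training:
            logits = self.classifier(feats)
            # Returning losses from forward() keeps them inside the DDP-wrapped module,
            # so gradient synchronization sees the full computation graph.
            return {"loss_cls": F.cross_entropy(logits, batched_inputs["targets"])}
        return feats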

View File

@ -6,7 +6,7 @@ from contextlib import contextmanager
import torch
from ..utils.logger import log_every_n_seconds
from fastreid.utils.logger import log_every_n_seconds
class DatasetEvaluator:
@ -28,12 +28,12 @@ class DatasetEvaluator:
def preprocess_inputs(self, inputs):
pass
def process(self, output):
def process(self, inputs, outputs):
"""
Process an input/output pair.
Args:
input: the input that's used to call the model.
output: the return value of `model(input)`
inputs: the inputs that's used to call the model.
outputs: the return value of `model(input)`
"""
pass
@ -95,11 +95,10 @@ def inference_on_dataset(model, data_loader, evaluator):
Returns:
The return value of `evaluator.evaluate()`
"""
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader) # inference data loader must have a fixed length
total = len(data_loader.dataset) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)
@ -116,7 +115,7 @@ def inference_on_dataset(model, data_loader, evaluator):
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(outputs)
evaluator.process(inputs, outputs)
idx += 1
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
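Since evaluators now receive both inputs and outputs, a minimal custom evaluator following the new process(inputs, outputs) signature could look like this (a hypothetical example, not part of the commit):

from fastreid.evaluation import DatasetEvaluator

class FeatureNormEvaluator(DatasetEvaluator):
    # Hypothetical evaluator that tracks the mean L2 norm of the extracted features.
    def reset(self):
        self._norms = []

    def process(self, inputs, outputs):
        # `inputs` is the batch dict fed to the model; `outputs` are the returned features.
        self._norms.extend(outputs.norm(dim=1).cpu().tolist())

    def evaluate(self):
        return {"mean_feat_norm": sum(self._norms) / max(len(self._norms), 1)}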

View File

@ -154,10 +154,8 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
all_cmc = np.asarray(all_cmc).astype(np.float32)
all_cmc = all_cmc.sum(0) / num_valid_q
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
return all_cmc, mAP, mINP
return all_cmc, all_AP, all_INP
def evaluate_py(

View File

@ -14,8 +14,8 @@ import torch.nn.functional as F
from .evaluator import DatasetEvaluator
from .query_expansion import aqe
from .rank import evaluate_rank
from .roc import evaluate_roc
from .rerank import re_ranking
from .roc import evaluate_roc
logger = logging.getLogger(__name__)
@ -35,10 +35,10 @@ class ReidEvaluator(DatasetEvaluator):
self.pids = []
self.camids = []
def process(self, outputs):
self.features.append(outputs[0].cpu())
self.pids.extend(outputs[1].cpu().numpy())
self.camids.extend(outputs[2].cpu().numpy())
def process(self, inputs, outputs):
self.pids.extend(inputs["targets"].numpy())
self.camids.extend(inputs["camid"].numpy())
self.features.append(outputs.cpu())
@staticmethod
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):
@ -92,13 +92,12 @@ class ReidEvaluator(DatasetEvaluator):
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r - 1]
self._results['R-1'] = cmc[0]
self._results['mAP'] = mAP
self._results['mINP'] = mINP
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
fprs = [1e-4, 1e-3, 1e-2]
for i in range(len(fprs)):
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
# tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
# fprs = [1e-4, 1e-3, 1e-2]
# for i in range(len(fprs)):
# self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
return copy.deepcopy(self._results)

View File

@ -15,10 +15,16 @@ def print_csv_format(results):
results (OrderedDict[dict]): task_name -> {metric -> score}
"""
assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed
task = list(results.keys())[0]
metrics = [k for k in results[task]]
logger = logging.getLogger(__name__)
logger.info('----------------------------------------')
logger.info("Evaluation results in csv format:")
logger.info("Metric: " + ", ".join([k for k in metrics]))
for task, res in results.items():
logger.info("Task: {}".format(task))
logger.info("{:.1%}".format(res))
logger.info(f"{task}: " + ", ".join(["{:.1%}".format(v) for v in res.values()]))
logger.info('----------------------------------------')
def verify_results(cfg, results):

fastreid/export/demo.py (new file, 166 lines added)
View File

@ -0,0 +1,166 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
import sys
import time
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer
from fastreid.config import get_cfg
cudnn.benchmark = True
class Reid(object):
def __init__(self, config_file):
cfg = get_cfg()
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
model = build_model(cfg)
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
self.model = model
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
example = torch.rand(1, 3, 256, 128)
example = example.cuda()
traced_script_module = torch.jit.trace_module(model, {'inference': example})
traced_script_module.save("reid_feat_extractor.pt")
@classmethod
def preprocess(cls, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 256))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
return data
@torch.no_grad()
def demo(self, img_path):
data = self.preprocess(img_path)
output = self.model.inference(data)
feat = output.cpu().data.numpy()
return feat
# @torch.no_grad()
# def extract_feat(self, dataloader):
# prefetcher = test_data_prefetcher(dataloader)
# feats = []
# labels = []
# batch = prefetcher.next()
# num_count = 0
# while batch[0] is not None:
# img, pid, camid = batch
# feat = self.model(img)
# feats.append(feat.cpu())
# labels.extend(np.asarray(pid))
#
# # if num_count > 2:
# # break
# batch = prefetcher.next()
# # num_count += 1
#
# feats = torch.cat(feats, dim=0)
# id_feats = defaultdict(list)
# for f, i in zip(feats, labels):
# id_feats[i].append(f)
# all_feats = []
# label_names = []
# for i in id_feats:
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
# label_names.append(i)
#
# label_names = np.asarray(label_names)
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
# all_feats = F.normalize(all_feats, p=2, dim=1)
# np.save('feats.npy', all_feats.cpu())
# np.save('labels.npy', label_names)
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
# cos -= np.eye(all_feats.shape[0])
# f = open('check_cross_folder_similarity.txt', 'w')
# for i in range(len(label_names)):
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
# sim_name = label_names[sim_indx]
# write_str = label_names[i] + ' '
# # f.write(label_names[i]+'\t')
# for n in sim_name:
# write_str += (n + ' ')
# # f.write(n+'\t')
# f.write(write_str+'\n')
#
#
# def prepare_gt(self, json_file):
# feat = []
# label = []
# with open(json_file, 'r') as f:
# total = json.load(f)
# for index in total:
# label.append(index)
# feat.append(np.array(total[index]))
# time_label = [int(i[0:10]) for i in label]
#
# return np.array(feat), np.array(label), np.array(time_label)
def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result
if __name__ == '__main__':
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
feat = reid_sys.demo(img_path)
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
data = reid_sys.preprocess(img_path)
feat2 = feat_extractor.inference(data)
from ipdb import set_trace; set_trace()
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')

View File

@ -10,8 +10,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from .sync_bn import SynchronizedBatchNorm2d
__all__ = [
"BatchNorm",
"IBN",
@ -32,12 +30,14 @@ class BatchNorm(nn.BatchNorm2d):
self.bias.requires_grad_(not bias_freeze)
class SyncBatchNorm(SynchronizedBatchNorm2d):
class SyncBatchNorm(nn.SyncBatchNorm):
def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
bias_init=0.0):
super().__init__(num_features, eps=eps, momentum=momentum, weight_freeze=weight_freeze, bias_freeze=bias_freeze)
super().__init__(num_features, eps=eps, momentum=momentum)
if weight_init is not None: self.weight.data.fill_(weight_init)
if bias_init is not None: self.bias.data.fill_(bias_init)
self.weight.requires_grad_(not weight_freeze)
self.bias.requires_grad_(not bias_freeze)
class IBN(nn.Module):

View File

@ -4,15 +4,11 @@
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from fastreid.utils.one_hot import one_hot
class Circle(nn.Module):
def __init__(self, cfg, in_feat, num_classes):
@ -34,7 +30,7 @@ class Circle(nn.Module):
s_p = self._s * alpha_p * (sim_mat - delta_p)
s_n = self._s * alpha_n * (sim_mat - delta_n)
targets = one_hot(targets, self._num_classes)
targets = F.one_hot(targets, num_classes=self._num_classes)
pred_class_logits = targets * s_p + (1.0 - targets) * s_n

View File

@ -1,13 +0,0 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback

View File

@ -1,395 +0,0 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import contextlib
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
try:
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
ReduceAddCoalesced = Broadcast = None
try:
from jactorch.parallel.comm import SyncMaster
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
from .comm import SyncMaster
from .replicate import DataParallelWithCallback
__all__ = [
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
'patch_sync_batchnorm', 'convert_model'
]
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dimensions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, weight_freeze=False, bias_freeze=False, affine=True):
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self.weight.requires_grad_(not weight_freeze)
self.bias.requires_grad_(not bias_freeze)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
if hasattr(torch, 'no_grad'):
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
else:
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, bias_var.clamp(self.eps) ** -0.5
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape::
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
@contextlib.contextmanager
def patch_sync_batchnorm():
import torch.nn as nn
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
nn.BatchNorm1d = SynchronizedBatchNorm1d
nn.BatchNorm2d = SynchronizedBatchNorm2d
nn.BatchNorm3d = SynchronizedBatchNorm3d
yield
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
def convert_model(module):
"""Traverse the input module and its child recursively
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
to SynchronizedBatchNorm*N*d
Args:
module: the input module needs to be convert to SyncBN model
Examples:
>>> import torch.nn as nn
>>> import torchvision
>>> # m is a standard pytorch model
>>> m = torchvision.models.resnet18(True)
>>> m = nn.DataParallel(m)
>>> # after convert, m is using SyncBN
>>> m = convert_model(m)
"""
if isinstance(module, torch.nn.DataParallel):
mod = module.module
mod = convert_model(mod)
mod = DataParallelWithCallback(mod)
return mod
mod = module
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.BatchNorm3d],
[SynchronizedBatchNorm1d,
SynchronizedBatchNorm2d,
SynchronizedBatchNorm3d]):
if isinstance(module, pth_module):
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_model(child))
return mod

View File

@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : batchnorm_reimpl.py
# Author : acgtyrant
# Date : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import torch
import torch.nn as nn
import torch.nn.init as init
__all__ = ['BatchNorm2dReimpl']
class BatchNorm2dReimpl(nn.Module):
"""
A re-implementation of batch normalization, used for testing the numerical
stability.
Author: acgtyrant
See also:
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.weight = nn.Parameter(torch.empty(num_features))
self.bias = nn.Parameter(torch.empty(num_features))
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_running_stats(self):
self.running_mean.zero_()
self.running_var.fill_(1)
def reset_parameters(self):
self.reset_running_stats()
init.uniform_(self.weight)
init.zeros_(self.bias)
def forward(self, input_):
batchsize, channels, height, width = input_.size()
numel = batchsize * height * width
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
sum_ = input_.sum(1)
sum_of_square = input_.pow(2).sum(1)
mean = sum_ / numel
sumvar = sum_of_square - sum_ * mean
self.running_mean = (
(1 - self.momentum) * self.running_mean
+ self.momentum * mean.detach()
)
unbias_var = sumvar / (numel - 1)
self.running_var = (
(1 - self.momentum) * self.running_var
+ self.momentum * unbias_var.detach()
)
bias_var = sumvar / numel
inv_std = 1 / (bias_var + self.eps).pow(0.5)
output = (
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()

View File

@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result has\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, all slave devices should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
and passed to a registered callback.
- After receiving the messages, the master device should gather the information and determine the message to pass
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually is the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
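The `FutureResult` above is the one-shot, one-to-one pipe that `SyncMaster` builds its master-slave exchange on. A self-contained sketch (illustrative, not repo code) of the same put/get hand-off between two threads:
import threading

class _Future:
    # same pattern as FutureResult: a condition variable guarding a single result slot
    def __init__(self):
        self._result = None
        self._cond = threading.Condition()

    def put(self, value):
        with self._cond:
            self._result = value
            self._cond.notify()

    def get(self):
        with self._cond:
            while self._result is None:
                self._cond.wait()
            value, self._result = self._result, None
            return value

future = _Future()
worker = threading.Thread(target=lambda: future.put(sum(range(10))))
worker.start()
print(future.get())   # 45, printed once the worker thread has delivered its result
worker.join()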

View File

@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Note that, as all replicated modules are isomorphic, we assign each sub-module a context
(shared among the multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called before the callback
of any slave copy.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
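A rough, CPU-only sketch of the callback contract described above: any sub-module defining `__data_parallel_replicate__` is invoked once per replica, with a context object shared across the copies at the same position. The module and helper below are made up for illustration; deepcopy stands in for `replicate`:
import copy
import torch.nn as nn

class Ctx:
    pass

class BNWithHook(nn.BatchNorm2d):
    def __data_parallel_replicate__(self, ctx, copy_id):
        # copy_id == 0 is the master replica; every copy receives the same ctx object
        self._sync_ctx, self._copy_id = ctx, copy_id

def run_replication_callbacks(replicas):
    # mirrors execute_replication_callbacks: one shared context per sub-module position
    n = len(list(replicas[0].modules()))
    ctxs = [Ctx() for _ in range(n)]
    for copy_id, replica in enumerate(replicas):
        for j, m in enumerate(replica.modules()):
            if hasattr(m, "__data_parallel_replicate__"):
                m.__data_parallel_replicate__(ctxs[j], copy_id)

master = nn.Sequential(nn.Conv2d(3, 8, 3), BNWithHook(8))
replicas = [master, copy.deepcopy(master)]
run_replication_callbacks(replicas)
print(replicas[0][1]._sync_ctx is replicas[1][1]._sync_ctx)  # True: the copies share one context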

View File

@@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import torch
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, x, y):
adiff = float((x - y).abs().max())
if (y == 0).all():
rdiff = 'NaN'
else:
rdiff = float((adiff / y).abs().max())
message = (
'Tensor close check failed\n'
'adiff={}\n'
'rdiff={}\n'
).format(adiff, rdiff)
self.assertTrue(torch.allclose(x, y), message)

View File

@@ -256,7 +256,7 @@ def build_resnet_backbone(cfg):
if pretrain:
if not with_ibn:
try:
state_dict = torch.load(pretrain_path)['model']
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
# Remove module.encoder in name
new_state_dict = {}
for k in state_dict:
@@ -270,7 +270,7 @@ def build_resnet_backbone(cfg):
state_dict = model_zoo.load_url(model_urls[depth])
logger.info("Loading pretrained model from torchvision")
else:
state_dict = torch.load(pretrain_path)['state_dict'] # ibn-net
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['state_dict'] # ibn-net
# Remove module in name
new_state_dict = {}
for k in state_dict:
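The hunk above is cut off inside the rename loop; a hedged sketch of the usual pattern it implements. The prefix and the tiny model below are placeholders, and only the map_location call mirrors the change:
import torch
import torch.nn as nn

# Stand-in for a checkpoint saved from a wrapped model, whose keys carry a "module.encoder." prefix.
wrapped = {"module.encoder." + k: v for k, v in nn.Linear(4, 2).state_dict().items()}

new_state_dict = {k.replace("module.encoder.", "", 1): v for k, v in wrapped.items()}
print(list(new_state_dict))      # ['weight', 'bias']

# The checkpoint itself is first loaded onto CPU so DDP workers don't all map it to GPU 0, e.g.
# state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']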

View File

@@ -4,7 +4,6 @@
@contact: sherlockliao01@gmail.com
"""
from .build_losses import reid_losses
from .cross_entroy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .metric_loss import *

View File

@@ -1,20 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
from .. import losses as Loss
def reid_losses(cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
loss_dict = {}
for loss_name in cfg.MODEL.LOSSES.NAME:
loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
loss_dict.update(loss)
# rename
named_loss_dict = {}
for name in loss_dict.keys():
named_loss_dict[prefix + name] = loss_dict[name]
del loss_dict
return named_loss_dict

View File

@@ -1,46 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
class CenterLoss(nn.Module):
"""Center loss.
Reference:
Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
Args:
num_classes (int): number of classes.
feat_dim (int): feature dimension.
"""
def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
super(CenterLoss, self).__init__()
self.num_classes,self.feat_dim = num_classes, feat_dim
if use_gpu: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
else: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
def forward(self, x, labels):
"""
Args:
x: feature matrix with shape (batch_size, feat_dim).
labels: ground truth labels with shape (batch_size,).
"""
assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
batch_size = x.size(0)
distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
distmat.addmm_(1, -2, x, self.centers.t())
classes = torch.arange(self.num_classes).long()
classes = classes.to(x.device)
labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
mask = labels.eq(classes.expand(batch_size, self.num_classes))
dist = distmat * mask.float()
loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
return loss
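For reference, the distance matrix built above is the expansion ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2 x_i.c_j; a small sketch with toy sizes, using the keyword form of addmm_ instead of the deprecated positional `addmm_(1, -2, ...)`:
import torch

batch_size, feat_dim, num_classes = 4, 8, 3
x = torch.randn(batch_size, feat_dim)
centers = torch.randn(num_classes, feat_dim)

distmat = (
    x.pow(2).sum(dim=1, keepdim=True).expand(batch_size, num_classes)
    + centers.pow(2).sum(dim=1, keepdim=True).expand(num_classes, batch_size).t()
)
distmat.addmm_(x, centers.t(), beta=1, alpha=-2)   # distmat = distmat - 2 * x @ centers.T

print(torch.allclose(distmat, torch.cdist(x, centers).pow(2), atol=1e-5))  # True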

View File

@@ -20,33 +20,31 @@ class CrossEntropyLoss(object):
self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
self._scale = cfg.MODEL.LOSSES.CE.SCALE
self._topk = (1,)
def _log_accuracy(self, pred_class_logits, gt_classes):
@staticmethod
def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
"""
Log the accuracy metrics to EventStorage.
"""
bsz = pred_class_logits.size(0)
maxk = max(self._topk)
maxk = max(topk)
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
pred_class = pred_class.t()
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
ret = []
for k in self._topk:
for k in topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / bsz))
storage = get_event_storage()
storage.put_scalar("cls_accuracy", ret[0])
def __call__(self, pred_class_logits, _, gt_classes):
def __call__(self, pred_class_logits, gt_classes):
"""
Compute the softmax cross entropy loss for box classification.
Returns:
scalar Tensor
"""
self._log_accuracy(pred_class_logits, gt_classes)
if self._eps >= 0:
smooth_param = self._eps
else:
@@ -61,6 +59,4 @@ class CrossEntropyLoss(object):
targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
loss = (-targets * log_probs).mean(0).sum()
return {
"loss_cls": loss * self._scale,
}
return loss * self._scale
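The refactored `__call__` now returns a plain scaled scalar instead of a dict. A worked sketch of the label-smoothing step visible in the hunk; eps, the class count, and the choice to spread eps uniformly over the non-target classes are illustrative assumptions, since the fill value lies outside the hunk:
import torch
import torch.nn.functional as F

eps, num_classes = 0.1, 5                                        # smooth_param and class count are illustrative
logits = torch.randn(4, num_classes)
gt_classes = torch.tensor([0, 2, 1, 4])

log_probs = F.log_softmax(logits, dim=1)
targets = torch.full_like(log_probs, eps / (num_classes - 1))    # assumed off-target mass
targets.scatter_(1, gt_classes.unsqueeze(1), 1 - eps)            # ground-truth class gets 1 - eps
loss = (-targets * log_probs).mean(0).sum()                      # same reduction as in the hunk
print(loss)                                                      # later multiplied by cfg.MODEL.LOSSES.CE.SCALE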

View File

@@ -7,8 +7,6 @@
import torch
import torch.nn.functional as F
from fastreid.utils.one_hot import one_hot
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
@@ -49,9 +47,7 @@ def focal_loss(
input_soft = F.softmax(input, dim=1)
# create the labels one hot tensor
target_one_hot = one_hot(
target, num_classes=input.shape[1],
dtype=input.dtype)
target_one_hot = F.one_hot(target, num_classes=input.shape[1])
# compute the actual focal loss
weight = torch.pow(-input_soft + 1., gamma)
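A short sketch of the `F.one_hot` switch shown above, combined with the focal weight (1 - p)^gamma from the last line of the hunk; gamma and the shapes are illustrative, the built-in one-hot is cast to the input dtype for clarity, and alpha weighting is left out:
import torch
import torch.nn.functional as F

gamma = 2.0
input = torch.randn(4, 5)                       # raw logits
target = torch.tensor([1, 0, 3, 2])

input_soft = F.softmax(input, dim=1)
target_one_hot = F.one_hot(target, num_classes=input.shape[1]).to(input.dtype)
weight = torch.pow(1.0 - input_soft, gamma)     # down-weight well-classified examples
focal = -weight * torch.log(input_soft) * target_one_hot
print(focal.sum(dim=1))                         # per-sample focal loss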

View File

@@ -6,6 +6,7 @@
import torch
import torch.nn.functional as F
from fastreid.utils import comm
__all__ = [
"TripletLoss",
@@ -13,6 +14,21 @@ __all__ = [
]
# utils
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
def normalize(x, axis=-1):
"""Normalizing to unit length along the specified dimension.
Args:
@@ -54,9 +70,9 @@ def softmax_weights(dist, mask):
def hard_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the hardest positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
labels: pytorch LongTensor, with shape [N]
return_inds: whether to return the indices. Save time if `False`(?)
dist_mat: pair wise distance between samples, shape [N, M]
is_pos: positive index with shape [N, M]
is_neg: negative index with shape [N, M]
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
@@ -69,7 +85,6 @@ def hard_example_mining(dist_mat, is_pos, is_neg):
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
N = dist_mat.size(0)
# `dist_ap` means distance(anchor, positive)
@@ -98,12 +113,13 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the weighted positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
is_pos:
is_neg:
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
"""
assert len(dist_mat.size()) == 2
assert dist_mat.size(0) == dist_mat.size(1)
is_pos = is_pos.float()
is_neg = is_neg.float()
@@ -130,15 +146,23 @@ class TripletLoss(object):
self._scale = cfg.MODEL.LOSSES.TRI.SCALE
self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING
def __call__(self, _, global_features, targets):
def __call__(self, embedding, targets):
if self._normalize_feature:
global_features = normalize(global_features, axis=-1)
embedding = normalize(embedding, axis=-1)
dist_mat = euclidean_dist(global_features, global_features)
# For distributed training, gather all features from different process.
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
N = dist_mat.size(0)
is_pos = targets.expand(N, N).eq(targets.expand(N, N).t())
is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
dist_mat = euclidean_dist(embedding, all_embedding)
N, M = dist_mat.size()
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
if self._hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
@@ -153,9 +177,7 @@ class TripletLoss(object):
loss = F.soft_margin_loss(dist_an - dist_ap, y)
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
return {
"loss_triplet": loss * self._scale,
}
return loss * self._scale
class CircleLoss(object):
@@ -165,18 +187,24 @@ class CircleLoss(object):
self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
def __call__(self, _, global_features, targets):
global_features = F.normalize(global_features, dim=1)
def __call__(self, embedding, targets):
embedding = F.normalize(embedding, dim=1)
sim_mat = torch.matmul(global_features, global_features.t())
if comm.get_world_size() > 1:
all_embedding = concat_all_gather(embedding)
all_targets = concat_all_gather(targets)
else:
all_embedding = embedding
all_targets = targets
N = sim_mat.size(0)
is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
is_pos = is_pos.bool()
is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
dist_mat = torch.matmul(embedding, all_embedding.t())
s_p = sim_mat[is_pos].contiguous().view(N, -1)
s_n = sim_mat[is_neg].contiguous().view(N, -1)
N, M = dist_mat.size()
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
s_p = dist_mat[is_pos].contiguous().view(N, -1)
s_n = dist_mat[is_neg].contiguous().view(N, -1)
alpha_p = F.relu(-s_p.detach() + 1 + self.m)
alpha_n = F.relu(s_n.detach() + self.m)
@@ -188,6 +216,4 @@ class CircleLoss(object):
loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
return {
"loss_circle": loss * self._scale,
}
return loss * self._scale
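A single-process toy illustration of the new [N, M] masks used by both losses above, where the N local embeddings are compared against the M gathered ones. Duplicating the batch stands in for `concat_all_gather` over two ranks, and the two mining lines are a simplified stand-in for `hard_example_mining`:
import torch

embedding = torch.randn(4, 16)                      # local batch
targets = torch.tensor([0, 0, 1, 2])

all_embedding = torch.cat([embedding, embedding])   # pretend gather over 2 ranks
all_targets = torch.cat([targets, targets])

N, M = embedding.size(0), all_embedding.size(0)
is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())

dist_mat = torch.cdist(embedding, all_embedding)                 # [N, M], same role as euclidean_dist
dist_ap = torch.max(dist_mat * is_pos.float(), dim=1)[0]         # hardest positive per anchor
dist_an = torch.min(dist_mat + 1e6 * is_pos.float(), dim=1)[0]   # hardest negative per anchor
print(is_pos.shape, dist_ap.shape, dist_an.shape)                # [4, 8], [4], [4]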

View File

@@ -9,4 +9,4 @@ from .build import META_ARCH_REGISTRY, build_model
# import all the meta_arch, so they will be registered
from .baseline import Baseline
from .mgn import MGN
from .mgn import MGN

View File

@@ -10,7 +10,7 @@ from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY
@@ -18,9 +18,11 @@ from .build import META_ARCH_REGISTRY
class Baseline(nn.Module):
def __init__(self, cfg):
super().__init__()
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
# backbone
self.backbone = build_backbone(cfg)
@@ -44,35 +46,50 @@ class Baseline(nn.Module):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(batched_inputs)
try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except Exception: return pred_feat
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
features = self.backbone(images)
# training
features = self.backbone(images) # (bs, 2048, 16, 8)
return self.heads(features, targets)
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
pred_feat = self.heads(features)
return pred_feat
# PreciseBN flag
if targets.sum() < 0:
# If we run precise BN on a different dataset, the number of classes in the new dataset
# may be larger than that in the original dataset, so the circle/arcface heads would
# throw an error. We simply set all the targets to 0 to avoid this situation.
targets.zero_()
self.heads(features, targets)
# We skip the loss computation here because the targets are all 0, and the triplet loss
# breaks down when the batch is not sampled in the PK (identity-balanced) way.
return
cls_outputs, features = self.heads(features, targets)
losses = self.losses(cls_outputs, features, targets)
return losses
else:
pred_features = self.heads(features)
return pred_features
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images = batched_inputs["images"].to(self.device)
# images = batched_inputs
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feat, targets = outputs
return reid_losses(self._cfg, logits, feat, targets)
def losses(self, cls_outputs, pred_features, gt_labels):
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
if "TripletLoss" in loss_names:
loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
if "CircleLoss" in loss_names:
loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
return loss_dict
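A rough sketch, with torch's built-in losses standing in for the fastreid ones, of how a name-driven loss dict like the one returned above is assembled and then consumed; the trainer typically backpropagates the sum of its values:
import torch
import torch.nn.functional as F

def losses(cls_outputs, pred_features, gt_labels, loss_names=("CrossEntropyLoss", "TripletLoss")):
    # toy stand-in for the cfg.MODEL.LOSSES.NAME-driven dispatch above
    loss_dict = {}
    if "CrossEntropyLoss" in loss_names:
        loss_dict["loss_cls"] = F.cross_entropy(cls_outputs, gt_labels)
    if "TripletLoss" in loss_names:
        loss_dict["loss_triplet"] = F.triplet_margin_loss(
            pred_features[0::3], pred_features[1::3], pred_features[2::3])
    return loss_dict

cls_outputs = torch.randn(6, 10, requires_grad=True)
pred_features = torch.randn(6, 128, requires_grad=True)
gt_labels = torch.randint(0, 10, (6,))

loss_dict = losses(cls_outputs, pred_features, gt_labels)
sum(loss_dict.values()).backward()          # the trainer backpropagates the summed dict
print({k: round(v.item(), 3) for k, v in loss_dict.items()})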

View File

@@ -3,8 +3,9 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from ...utils.registry import Registry
from fastreid.utils.registry import Registry
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
@@ -20,4 +21,6 @@ def build_model(cfg):
Note that it does not load any weights from ``cfg``.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
return META_ARCH_REGISTRY.get(meta_arch)(cfg)
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
return model
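The registry indirection used by `build_model`, sketched with a plain dict and a dummy architecture; the class name and cfg keys below are invented for illustration (in the repo, architectures are presumably registered on the shared `META_ARCH_REGISTRY` via a register decorator):
import torch.nn as nn

META_ARCH_REGISTRY = {}                      # stand-in for fastreid.utils.registry.Registry

def register(cls):
    META_ARCH_REGISTRY[cls.__name__] = cls
    return cls

@register
class TinyBaseline(nn.Module):               # hypothetical architecture for illustration
    def __init__(self, cfg):
        super().__init__()
        self.fc = nn.Linear(8, cfg["num_classes"])

def build_model(cfg):
    model = META_ARCH_REGISTRY[cfg["meta_arch"]](cfg)
    model.to(cfg["device"])                  # mirrors model.to(torch.device(cfg.MODEL.DEVICE))
    return model

model = build_model({"meta_arch": "TinyBaseline", "num_classes": 4, "device": "cpu"})
print(model)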

View File

@@ -12,7 +12,7 @@ from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPoo
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses, CrossEntropyLoss
from fastreid.modeling.losses import CrossEntropyLoss, TripletLoss
from fastreid.utils.weight_init import weights_init_kaiming
from .build import META_ARCH_REGISTRY
@@ -21,9 +21,10 @@ class MGN(nn.Module):
class MGN(nn.Module):
def __init__(self, cfg):
super().__init__()
self._cfg = cfg
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
self._cfg = cfg
# backbone
bn_norm = cfg.MODEL.BACKBONE.NORM
@@ -116,129 +117,113 @@ class MGN(nn.Module):
return self.pixel_mean.device
def forward(self, batched_inputs):
if not self.training:
pred_feat = self.inference(batched_inputs)
try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
except Exception: return pred_feat
images = self.preprocess_image(batched_inputs)
targets = batched_inputs["targets"].long()
# Training
features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1
b1_feat = self.b1(features)
b1_pool_feat = self.b1_pool(b1_feat)
b1_logits, b1_pool_feat, _ = self.b1_head(b1_pool_feat, targets)
# branch2
b2_feat = self.b2(features)
# global
b2_pool_feat = self.b2_pool(b2_feat)
b2_logits, b2_pool_feat, _ = self.b2_head(b2_pool_feat, targets)
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
# part1
b21_pool_feat = self.b21_pool(b21_feat)
b21_logits, b21_pool_feat, _ = self.b21_head(b21_pool_feat, targets)
# part2
b22_pool_feat = self.b22_pool(b22_feat)
b22_logits, b22_pool_feat, _ = self.b22_head(b22_pool_feat, targets)
# branch3
b3_feat = self.b3(features)
# global
b3_pool_feat = self.b3_pool(b3_feat)
b3_logits, b3_pool_feat, _ = self.b3_head(b3_pool_feat, targets)
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
# part1
b31_pool_feat = self.b31_pool(b31_feat)
b31_logits, b31_pool_feat, _ = self.b31_head(b31_pool_feat, targets)
# part2
b32_pool_feat = self.b32_pool(b32_feat)
b32_logits, b32_pool_feat, _ = self.b32_head(b32_pool_feat, targets)
# part3
b33_pool_feat = self.b33_pool(b33_feat)
b33_logits, b33_pool_feat, _ = self.b33_head(b33_pool_feat, targets)
return (b1_logits, b2_logits, b3_logits, b21_logits, b22_logits, b31_logits, b32_logits, b33_logits), \
(b1_pool_feat, b2_pool_feat, b3_pool_feat,
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
if targets.sum() < 0:
targets.zero_()
self.b1_head(b1_pool_feat, targets)
self.b2_head(b2_pool_feat, targets)
self.b21_head(b21_pool_feat, targets)
self.b22_head(b22_pool_feat, targets)
self.b3_head(b3_pool_feat, targets)
self.b31_head(b31_pool_feat, targets)
self.b32_head(b32_pool_feat, targets)
self.b33_head(b33_pool_feat, targets)
return
b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
losses = self.losses(
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
b1_pool_feat, b2_pool_feat, b3_pool_feat,
torch.cat((b21_pool_feat, b22_pool_feat), dim=1),
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
targets
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1),
targets,
)
return losses
else:
b1_pool_feat = self.b1_head(b1_pool_feat)
b2_pool_feat = self.b2_head(b2_pool_feat)
b21_pool_feat = self.b21_head(b21_pool_feat)
b22_pool_feat = self.b22_head(b22_pool_feat)
b3_pool_feat = self.b3_head(b3_pool_feat)
b31_pool_feat = self.b31_head(b31_pool_feat)
b32_pool_feat = self.b32_head(b32_pool_feat)
b33_pool_feat = self.b33_head(b33_pool_feat)
def inference(self, batched_inputs):
assert not self.training
images = self.preprocess_image(batched_inputs)
features = self.backbone(images) # (bs, 2048, 16, 8)
# branch1
b1_feat = self.b1(features)
b1_pool_feat = self.b1_pool(b1_feat)
b1_pool_feat = self.b1_head(b1_pool_feat)
# branch2
b2_feat = self.b2(features)
# global
b2_pool_feat = self.b2_pool(b2_feat)
b2_pool_feat = self.b2_head(b2_pool_feat)
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
# part1
b21_pool_feat = self.b21_pool(b21_feat)
b21_pool_feat = self.b21_head(b21_pool_feat)
# part2
b22_pool_feat = self.b22_pool(b22_feat)
b22_pool_feat = self.b22_head(b22_pool_feat)
# branch3
b3_feat = self.b3(features)
# global
b3_pool_feat = self.b3_pool(b3_feat)
b3_pool_feat = self.b3_head(b3_pool_feat)
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
# part1
b31_pool_feat = self.b31_pool(b31_feat)
b31_pool_feat = self.b31_head(b31_pool_feat)
# part2
b32_pool_feat = self.b32_pool(b32_feat)
b32_pool_feat = self.b32_head(b32_pool_feat)
# part3
b33_pool_feat = self.b33_pool(b33_feat)
b33_pool_feat = self.b33_head(b33_pool_feat)
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
# images = [x["images"] for x in batched_inputs]
images = batched_inputs["images"]
images = batched_inputs["images"].to(self.device)
# images = batched_inputs
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs):
logits, feats, targets = outputs
def losses(self, b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, gt_labels):
loss_dict = {}
loss_dict.update(reid_losses(self._cfg, logits[0], feats[0], targets, 'b1_'))
loss_dict.update(reid_losses(self._cfg, logits[1], feats[1], targets, 'b2_'))
loss_dict.update(reid_losses(self._cfg, logits[2], feats[2], targets, 'b3_'))
loss_dict.update(reid_losses(self._cfg, logits[3], feats[3], targets, 'b21_'))
loss_dict.update(reid_losses(self._cfg, logits[5], feats[4], targets, 'b31_'))
loss_names = self._cfg.MODEL.LOSSES.NAME
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)
loss_dict['loss_cls_b21'] = CrossEntropyLoss(self._cfg)(b21_logits, gt_labels)
loss_dict['loss_cls_b22'] = CrossEntropyLoss(self._cfg)(b22_logits, gt_labels)
loss_dict['loss_cls_b3'] = CrossEntropyLoss(self._cfg)(b3_logits, gt_labels)
loss_dict['loss_cls_b31'] = CrossEntropyLoss(self._cfg)(b31_logits, gt_labels)
loss_dict['loss_cls_b32'] = CrossEntropyLoss(self._cfg)(b32_logits, gt_labels)
loss_dict['loss_cls_b33'] = CrossEntropyLoss(self._cfg)(b33_logits, gt_labels)
if "TripletLoss" in loss_names:
loss_dict['loss_triplet_b1'] = TripletLoss(self._cfg)(b1_pool_feat, gt_labels)
loss_dict['loss_triplet_b2'] = TripletLoss(self._cfg)(b2_pool_feat, gt_labels)
loss_dict['loss_triplet_b3'] = TripletLoss(self._cfg)(b3_pool_feat, gt_labels)
loss_dict['loss_triplet_b22'] = TripletLoss(self._cfg)(b22_pool_feat, gt_labels)
loss_dict['loss_triplet_b33'] = TripletLoss(self._cfg)(b33_pool_feat, gt_labels)
part_ce_loss = [
(CrossEntropyLoss(self._cfg)(logits[4], None, targets), 'b22_'),
(CrossEntropyLoss(self._cfg)(logits[6], None, targets), 'b32_'),
(CrossEntropyLoss(self._cfg)(logits[7], None, targets), 'b33_')
]
named_ce_loss = {}
for item in part_ce_loss:
named_ce_loss[item[1] + [*item[0]][0]] = [*item[0].values()][0]
loss_dict.update(named_ce_loss)
return loss_dict

View File

@@ -20,7 +20,7 @@ def build_optimizer(cfg, model):
if "bias" in key:
lr *= cfg.SOLVER.BIAS_LR_FACTOR
weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]
solver_opt = cfg.SOLVER.OPT
if hasattr(optim, solver_opt):
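The param groups now carry extra `name` and `freeze` fields; a small sketch of how such groups are assembled and handed to an optimizer. Torch optimizers keep unknown group keys, and the bundled Adam/Lamb check `group['freeze']` to skip frozen groups; the bias factor values below are illustrative:
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
base_lr, weight_decay = 0.01, 5e-4

params = []
for key, value in model.named_parameters():
    if not value.requires_grad:
        continue
    lr, wd = base_lr, weight_decay
    if "bias" in key:
        lr, wd = base_lr * 2.0, 0.0          # illustrative BIAS_LR_FACTOR / WEIGHT_DECAY_BIAS values
    params += [{"name": key, "params": [value], "lr": lr, "weight_decay": wd, "freeze": False}]

opt = torch.optim.SGD(params, lr=base_lr)    # extra "name"/"freeze" keys are preserved in the groups
for group in opt.param_groups:
    print(group["name"], group["lr"], group["freeze"])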

View File

@@ -4,13 +4,14 @@
@contact: sherlockliao01@gmail.com
"""
import math
from bisect import bisect_right
from typing import List
import torch
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
from torch.optim.lr_scheduler import _LRScheduler
__all__ = ["WarmupMultiStepLR", "DelayedScheduler"]
__all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"]
class WarmupMultiStepLR(_LRScheduler):
@@ -50,6 +51,71 @@ class WarmupMultiStepLR(_LRScheduler):
return self.get_lr()
class WarmupCosineAnnealingLR(_LRScheduler):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
\cos(\frac{T_{cur}}{T_{max}}\pi))
When last_epoch=-1, sets initial lr as lr.
It has been proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
implements the cosine annealing part of SGDR, and not the restarts.
Args:
optimizer (Optimizer): Wrapped optimizer.
max_iters (int): Maximum number of iterations.
delay_iters (int): Number of iterations to hold the base lr after warmup. Default: 0.
eta_min_lr (float): Minimum learning rate. Default: 0.
warmup_factor (float): lr multiplier at the start of warmup. Default: 0.001.
warmup_iters (int): Number of warmup iterations. Default: 1000.
warmup_method (str): Warmup method, e.g. "linear". Default: "linear".
last_epoch (int): The index of last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
max_iters: int,
delay_iters: int = 0,
eta_min_lr: float = 0,
warmup_factor: float = 0.001,
warmup_iters: int = 1000,
warmup_method: str = "linear",
last_epoch=-1,
**kwargs
):
self.max_iters = max_iters
self.delay_iters = delay_iters
self.eta_min_lr = eta_min_lr
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
assert self.delay_iters >= self.warmup_iters, "Scheduler delay iters must be no smaller than warmup iters"
super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
if self.last_epoch <= self.warmup_iters:
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor,
)
return [
base_lr * warmup_factor for base_lr in self.base_lrs
]
elif self.last_epoch <= self.delay_iters:
return self.base_lrs
else:
return [
self.eta_min_lr + (base_lr - self.eta_min_lr) *
(1 + math.cos(
math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2
for base_lr in self.base_lrs]
def _get_warmup_factor_at_iter(
method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
@@ -75,49 +141,3 @@ def _get_warmup_factor_at_iter(
return warmup_factor * (1 - alpha) + alpha
else:
raise ValueError("Unknown warmup method: {}".format(method))
class DelayedScheduler(_LRScheduler):
""" Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
Args:
optimizer (Optimizer): Wrapped optimizer.
delay_iters: number of epochs to keep the initial lr until starting applying the scheduler
after_scheduler: after delay_iters, use this scheduler (e.g. ReduceLROnPlateau)
"""
def __init__(self, optimizer, delay_iters, after_scheduler, warmup_factor, warmup_iters, warmup_method):
self.delay_epochs = delay_iters
self.after_scheduler = after_scheduler
self.finished = False
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super().__init__(optimizer)
def get_lr(self):
if self.last_epoch >= self.delay_epochs:
if not self.finished:
self.after_scheduler.base_lrs = self.base_lrs
self.finished = True
return self.after_scheduler.get_lr()
warmup_factor = _get_warmup_factor_at_iter(
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
)
return [base_lr * warmup_factor for base_lr in self.base_lrs]
def step(self, epoch=None):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.delay_epochs)
else:
return super(DelayedScheduler, self).step(epoch)
def DelayedCosineAnnealingLR(optimizer, delay_iters, max_iters, eta_min_lr, warmup_factor,
warmup_iters, warmup_method, **kwargs, ):
cosine_annealing_iters = max_iters - delay_iters
base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_iters, eta_min_lr)
return DelayedScheduler(optimizer, delay_iters, base_scheduler, warmup_factor, warmup_iters, warmup_method)
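To make the shape of the new WarmupCosineAnnealingLR schedule concrete (linear warmup, a flat stretch until delay_iters, then cosine decay to eta_min_lr), a standalone sketch of the same piecewise formula with arbitrary iteration counts:
import math

base_lr, warmup_iters, delay_iters, max_iters = 0.1, 10, 30, 100
warmup_factor, eta_min = 0.001, 1e-5

def lr_at(it):
    if it <= warmup_iters:                                    # linear warmup
        alpha = it / warmup_iters
        return base_lr * (warmup_factor * (1 - alpha) + alpha)
    if it <= delay_iters:                                     # hold the base lr
        return base_lr
    progress = (it - delay_iters) / (max_iters - delay_iters) # cosine decay after the delay
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2

for it in (0, 5, 10, 20, 30, 65, 100):
    print(it, round(lr_at(it), 5))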

View File

@@ -1,11 +1,5 @@
from .lamb import Lamb
from .lars import LARS
from .lookahead import Lookahead, LookaheadAdam
from .novograd import Novograd
from .over9000 import Over9000, RangerLars
from .radam import RAdam, PlainRAdam, AdamW
from .ralamb import Ralamb
from .ranger import Ranger
from .swa import SWA
from .adam import Adam
from .sgd import SGD
from torch.optim import *

View File

@@ -0,0 +1,115 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
import torch
import math
from torch.optim.optimizer import Optimizer
class Adam(Optimizer):
r"""Implements Adam algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
super(Adam, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None or group['freeze']:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss

View File

@@ -68,7 +68,7 @@ class Lamb(Optimizer):
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
if p.grad is None or group['freeze']:
continue
grad = p.grad.data
if grad.is_sparse:

View File

@@ -1,101 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py
import torch
from torch.optim.optimizer import Optimizer
class LARS(Optimizer):
"""
:class:`LARS` is a pytorch implementation of both the scaling and clipping variants of LARC,
in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
local learning rate for each individual parameter. The algorithm is designed to improve
convergence of large batch training.
See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARS(optim)
```
It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
```
model = ...
optim = torch.optim.Adam(model.parameters(), lr=...)
optim = LARS(optim)
optim = apex.fp16_utils.FP16_Optimizer(optim)
```
Args:
optimizer: Pytorch optimizer to wrap and modify learning rate for.
trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
"""
def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
self.param_groups = optimizer.param_groups
self.optim = optimizer
self.trust_coefficient = trust_coefficient
self.eps = eps
self.clip = clip
def __getstate__(self):
return self.optim.__getstate__()
def __setstate__(self, state):
self.optim.__setstate__(state)
def __repr__(self):
return self.optim.__repr__()
def state_dict(self):
return self.optim.state_dict()
def load_state_dict(self, state_dict):
self.optim.load_state_dict(state_dict)
def zero_grad(self):
self.optim.zero_grad()
def add_param_group(self, param_group):
self.optim.add_param_group(param_group)
def step(self):
with torch.no_grad():
weight_decays = []
for group in self.optim.param_groups:
# absorb weight decay control from optimizer
weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
weight_decays.append(weight_decay)
group['weight_decay'] = 0
for p in group['params']:
if p.grad is None:
continue
param_norm = torch.norm(p.data)
grad_norm = torch.norm(p.grad.data)
if param_norm != 0 and grad_norm != 0:
# calculate adaptive lr + weight decay
adaptive_lr = self.trust_coefficient * (param_norm) / (
grad_norm + param_norm * weight_decay + self.eps)
# clip learning rate for LARC
if self.clip:
# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
adaptive_lr = min(adaptive_lr / group['lr'], 1)
p.grad.data += weight_decay * p.data
p.grad.data *= adaptive_lr
self.optim.step()
# return weight decay control to optimizer
for i, group in enumerate(self.optim.param_groups):
group['weight_decay'] = weight_decays[i]
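The heart of the (now removed) LARS wrapper is the trust-ratio local learning rate; a one-parameter sketch of that computation, with illustrative values:
import torch

trust_coefficient, eps, lr, weight_decay = 0.02, 1e-8, 0.1, 1e-4
p = torch.randn(64, 64)
grad = torch.randn(64, 64) * 0.01

param_norm, grad_norm = p.norm(), grad.norm()
adaptive_lr = trust_coefficient * param_norm / (grad_norm + param_norm * weight_decay + eps)
clipped = min((adaptive_lr / lr).item(), 1.0)     # clip mode: effective lr = min(local_lr, optimizer lr)
print(float(adaptive_lr), clipped)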

View File

@@ -1,104 +0,0 @@
####
# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch
# Original paper: https://arxiv.org/abs/1907.08610
####
# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
from collections import defaultdict
import torch
from torch.optim import Adam
from torch.optim.optimizer import Optimizer
class Lookahead(Optimizer):
def __init__(self, base_optimizer, alpha=0.5, k=6):
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
self.base_optimizer = base_optimizer
self.param_groups = self.base_optimizer.param_groups
self.defaults = base_optimizer.defaults
self.defaults.update(defaults)
self.state = defaultdict(dict)
# manually add our defaults to the param groups
for name, default in defaults.items():
for group in self.param_groups:
group.setdefault(name, default)
def update_slow(self, group):
for fast_p in group["params"]:
if fast_p.grad is None:
continue
param_state = self.state[fast_p]
if 'slow_buffer' not in param_state:
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
param_state['slow_buffer'].copy_(fast_p.data)
slow = param_state['slow_buffer']
slow.add_(group['lookahead_alpha'], fast_p.data - slow)
fast_p.data.copy_(slow)
def sync_lookahead(self):
for group in self.param_groups:
self.update_slow(group)
def step(self, closure=None):
# print(self.k)
# assert id(self.param_groups) == id(self.base_optimizer.param_groups)
loss = self.base_optimizer.step(closure)
for group in self.param_groups:
group['lookahead_step'] += 1
if group['lookahead_step'] % group['lookahead_k'] == 0:
self.update_slow(group)
return loss
def state_dict(self):
fast_state_dict = self.base_optimizer.state_dict()
slow_state = {
(id(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()
}
fast_state = fast_state_dict['state']
param_groups = fast_state_dict['param_groups']
return {
'state': fast_state,
'slow_state': slow_state,
'param_groups': param_groups,
}
def load_state_dict(self, state_dict):
fast_state_dict = {
'state': state_dict['state'],
'param_groups': state_dict['param_groups'],
}
self.base_optimizer.load_state_dict(fast_state_dict)
# We want to restore the slow state, but share param_groups reference
# with base_optimizer. This is a bit redundant but least code
slow_state_new = False
if 'slow_state' not in state_dict:
print('Loading state_dict from optimizer without Lookahead applied.')
state_dict['slow_state'] = defaultdict(dict)
slow_state_new = True
slow_state_dict = {
'state': state_dict['slow_state'],
'param_groups': state_dict['param_groups'], # this is pointless but saves code
}
super(Lookahead, self).load_state_dict(slow_state_dict)
self.param_groups = self.base_optimizer.param_groups # make both ref same container
if slow_state_new:
# reapply defaults to catch missing lookahead specific ones
for name, default in self.defaults.items():
for group in self.param_groups:
group.setdefault(name, default)
def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
adam = Adam(params, *args, **kwargs)
return Lookahead(adam, alpha, k)
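The slow-weight rule of the (now removed) Lookahead wrapper is an interpolation toward the fast weights every k steps; a toy sketch of that update, with illustrative alpha and k:
import torch

alpha, k = 0.5, 6
fast = torch.zeros(3)
slow = fast.clone()                          # slow buffer initialised from the fast weights

for step in range(1, 13):
    fast += torch.tensor([0.1, 0.2, 0.3])    # pretend the base optimizer moved the fast weights
    if step % k == 0:
        slow += alpha * (fast - slow)        # slow.add_(alpha, fast_p.data - slow) in the wrapper
        fast.copy_(slow)                     # fast weights are reset to the slow weights
print(fast, slow)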

View File

@@ -1,229 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.optim.optimizer import Optimizer
import math
class AdamW(Optimizer):
"""Implements AdamW algorithm.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))
return loss
class Novograd(Optimizer):
"""
Implements Novograd algorithm.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging: gradient averaging
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
"""
def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
weight_decay=0, grad_averaging=False, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad)
super(Novograd, self).__init__(params, defaults)
def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
norm = torch.sum(torch.pow(grad, 2))
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
grad.div_(denom)
if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
p.data.add_(-group['lr'], exp_avg)
return loss

View File

@@ -1,19 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
from .lookahead import Lookahead
from .ralamb import Ralamb
# RAdam + LARS + LookAHead
# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
ralamb = Ralamb(params, *args, **kwargs)
return Lookahead(ralamb, alpha, k)
RangerLars = Over9000

View File

@@ -1,255 +0,0 @@
####
# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam
# Paper: https://arxiv.org/abs/1908.03265
####
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
p_data_fp32.add_(-step_size * group['lr'], exp_avg)
p.data.copy_(p_data_fp32)
return loss
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
# more conservative since it's an approximated value
if N_sma >= 5:
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
p.data.copy_(p_data_fp32)
elif self.degenerated_to_sgd:
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
step_size = group['lr'] / (1 - beta1 ** state['step'])
p_data_fp32.add_(-step_size, exp_avg)
p.data.copy_(p_data_fp32)
return loss
class AdamW(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, warmup=warmup)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if group['warmup'] > state['step']:
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
else:
scheduled_lr = group['lr']
step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
p.data.copy_(p_data_fp32)
return loss
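For reference, the buffered step_size computed in RAdam.step above (and inline in PlainRAdam.step) is the variance-rectification term from the RAdam paper by Liu et al., with both bias corrections folded in. Writing m_hat_t = m_t / (1 - beta1^t) and v_hat_t = v_t / (1 - beta2^t), the update performed is, up to the exact placement of eps:

\[
\rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad
\rho_t = \rho_\infty - \frac{2\,t\,\beta_2^{\,t}}{1-\beta_2^{\,t}}, \qquad
r_t = \sqrt{\frac{(\rho_t-4)(\rho_t-2)\,\rho_\infty}{(\rho_\infty-4)(\rho_\infty-2)\,\rho_t}}
\]
\[
\theta_t =
\begin{cases}
\theta_{t-1} - \alpha\, r_t\, \dfrac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}, & \rho_t \ge 5 \ \text{(the code's conservative threshold)} \\[6pt]
\theta_{t-1} - \alpha\, \hat m_t, & \text{otherwise, when degenerated\_to\_sgd is enabled}
\end{cases}
\]

Non-zero weight decay is applied beforehand as \(\theta \leftarrow \theta - \alpha\lambda\theta\); AdamW below uses the same decoupled decay but replaces the rectification with the usual bias-corrected Adam step plus an optional linear lr warmup.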

View File

@@ -1,103 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
import torch, math
from torch.optim.optimizer import Optimizer
# RAdam + LARS
class Ralamb(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
super(Ralamb, self).__init__(params, defaults)
def __setstate__(self, state):
super(Ralamb, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Ralamb does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
# Decay the first and second moment running average coefficient
# m_t
exp_avg.mul_(beta1).add_(1 - beta1, grad)
# v_t
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
state['step'] += 1
buffered = self.buffer[int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, radam_step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
radam_step_size = 1.0 / (1 - beta1 ** state['step'])
buffered[2] = radam_step_size
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
# more conservative since it's an approximated value
radam_step = p_data_fp32.clone()
if N_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)
else:
radam_step.add_(-radam_step_size * group['lr'], exp_avg)
radam_norm = radam_step.pow(2).sum().sqrt()
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
if weight_norm == 0 or radam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / radam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = radam_norm
state['trust_ratio'] = trust_ratio
if N_sma >= 5:
p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)
else:
p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)
p.data.copy_(p_data_fp32)
return loss
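For reference, the trust_ratio computed above is the LARS-style layer-wise scaling applied on top of the RAdam step; note that, as written, radam_norm is the norm of the tentatively updated parameter rather than of the update itself:

\[
\texttt{trust\_ratio} =
\begin{cases}
\dfrac{\min(\lVert\theta\rVert_2,\ 10)}{\lVert\theta - \Delta\theta\rVert_2}, & \text{if both norms are non-zero} \\[6pt]
1, & \text{otherwise}
\end{cases}
\qquad
\theta \leftarrow \theta - \texttt{trust\_ratio}\cdot\Delta\theta
\]

where \(\Delta\theta = \texttt{lr}\cdot\texttt{radam\_step\_size}\cdot m_t/(\sqrt{v_t}+\epsilon)\) when N_sma >= 5, and \(\texttt{lr}\cdot\texttt{radam\_step\_size}\cdot m_t\) otherwise.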

View File

@@ -1,177 +0,0 @@
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# and/or
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers
# Ranger has now been used to capture 12 records on the FastAI leaderboard.
# This version = 20.4.11
# Credits:
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
# Lookahead paper --> M. Zhang, G. Hinton https://arxiv.org/abs/1907.08610
# summary of changes:
# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
# changed eps to 1e-5 as better default than 1e-8.
import math
import torch
from torch.optim.optimizer import Optimizer
class Ranger(Optimizer):
def __init__(self, params, lr=1e-3, # lr
alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
use_gc=True, gc_conv_only=False
# Gradient centralization on or off, applied to conv layers only or conv + fc layers
):
# parameter checks
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
if not lr > 0:
raise ValueError(f'Invalid Learning Rate: {lr}')
if not eps > 0:
raise ValueError(f'Invalid eps: {eps}')
# parameter comments:
# beta1 (momentum) of .95 seems to work better than .90...
# N_sma_threshold of 5 seems better in testing than 4.
# In both cases, it is worth testing on your dataset (.90 vs .95, 4 vs 5) to see which works best for you.
# prep defaults and init torch.optim base
defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
# adjustable threshold
self.N_sma_threshhold = N_sma_threshhold
# look ahead params
self.alpha = alpha
self.k = k
# radam buffer for state
self.radam_buffer = [[None, None, None] for ind in range(10)]
# gc on or off
self.use_gc = use_gc
# level of gradient centralization
self.gc_gradient_threshold = 3 if gc_conv_only else 1
print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
if (self.use_gc and self.gc_gradient_threshold == 1):
print(f"GC applied to both conv and fc layers")
elif (self.use_gc and self.gc_gradient_threshold == 3):
print(f"GC applied to conv layers only")
def __setstate__(self, state):
print("set state called")
super(Ranger, self).__setstate__(state)
def step(self, closure=None):
loss = None
# note - below is commented out because I have other work that passes back the loss as a float, and thus not a callable closure.
# Uncomment if you need to use the actual closure...
# if closure is not None:
# loss = closure()
# Evaluate averages and grad, update param tensors
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Ranger optimizer does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p] # get state dict for this param
if len(state) == 0: # if first time to run...init dictionary with our desired entries
# if self.first_run_check==0:
# self.first_run_check=1
# print("Initializing slow buffer...should not see this at load from saved model!")
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
# look ahead weight storage now in state dict
state['slow_buffer'] = torch.empty_like(p.data)
state['slow_buffer'].copy_(p.data)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
# begin computations
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
# GC operation for Conv layers and FC layers
if grad.dim() > self.gc_gradient_threshold:
grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
state['step'] += 1
# compute variance mov avg
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
# compute mean moving avg
exp_avg.mul_(beta1).add_(1 - beta1, grad)
buffered = self.radam_buffer[int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
if N_sma > self.N_sma_threshhold:
step_size = math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
step_size = 1.0 / (1 - beta1 ** state['step'])
buffered[2] = step_size
if group['weight_decay'] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
# apply lr
if N_sma > self.N_sma_threshhold:
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
else:
p_data_fp32.add_(-step_size * group['lr'], exp_avg)
p.data.copy_(p_data_fp32)
# integrated look ahead...
# we do it at the param level instead of group level
if state['step'] % group['k'] == 0:
slow_p = state['slow_buffer'] # get access to slow param tensor
slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
return loss
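A minimal usage sketch for the Ranger class above; the toy model, batch shapes and training loop are illustrative placeholders rather than code from this repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy linear classifier and random batches, purely for illustration.
model = nn.Linear(128, 10)
opt = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6, use_gc=True)

for _ in range(12):  # with k=6, the slow (lookahead) weights are synced twice in this loop
    x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
    opt.zero_grad()
    F.cross_entropy(model(x), y).backward()
    opt.step()  # RAdam update + gradient centralization on >1-D grads + periodic lookahead sync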

View File

@@ -0,0 +1,115 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
import torch
from torch.optim.optimizer import Optimizer, required
class SGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD, self).__init__(params, defaults)
def __setstate__(self, state):
super(SGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
for p in group['params']:
# "freeze" is an optional per-group flag supplied by the caller (e.g. the optimizer builder);
# default to False so plain parameter groups also work.
if p.grad is None or group.get('freeze', False):
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf
p.data.add_(-group['lr'], d_p)
return loss
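Since step() consults an optional per-group 'freeze' flag, callers that want it must put that key into each parameter group (fastreid's optimizer builder is presumably what supplies it); below is a minimal sketch of one way to do that, with a purely illustrative model and freezing rule:

import torch
import torch.nn as nn

# Hypothetical parameter-group construction carrying the per-group "freeze" flag
# consulted by SGD.step(); frozen groups are skipped entirely during the update.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 10))
param_groups = []
for name, p in model.named_parameters():
    param_groups.append({
        "params": [p],
        "freeze": name.startswith("0."),  # e.g. keep the first Linear layer fixed
    })
opt = SGD(param_groups, lr=0.01, momentum=0.9)

opt.zero_grad()
model(torch.randn(8, 64)).sum().backward()
opt.step()  # groups flagged freeze=True are skipped; the rest take a momentum-SGD step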

View File

@@ -11,7 +11,6 @@ from typing import Any
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel
@@ -27,7 +26,6 @@ class Checkpointer(object):
def __init__(
self,
model: nn.Module,
dataset: Dataset = None,
save_dir: str = "",
*,
save_to_disk: bool = True,
@@ -47,7 +45,6 @@ class Checkpointer(object):
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.dataset = dataset
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
@@ -65,8 +62,6 @@ class Checkpointer(object):
data = {}
data["model"] = self.model.state_dict()
if self.dataset is not None:
data["pid_dict"] = self.dataset.pid_dict
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
@@ -104,9 +99,6 @@ class Checkpointer(object):
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
if self.dataset is not None:
self.logger.info("Loading dataset pid dictionary from {}".format(path))
self._load_dataset_pid_dict(checkpoint)
self._load_model(checkpoint)
for key, obj in self.checkpointables.items():
if key in checkpoint:
@@ -191,10 +183,6 @@ class Checkpointer(object):
"""
return torch.load(f, map_location=torch.device("cpu"))
def _load_dataset_pid_dict(self, checkpoint: Any):
checkpoint_pid_dict = checkpoint.pop("pid_dict")
self.dataset.update_pid_dict(checkpoint_pid_dict)
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.

View File

@@ -0,0 +1,158 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py
import importlib
import os
import re
import subprocess
import sys
from collections import defaultdict
import PIL
import numpy as np
import torch
import torchvision
from tabulate import tabulate
__all__ = ["collect_env_info"]
def collect_torch_env():
try:
import torch.__config__
return torch.__config__.show()
except ImportError:
# compatible with older versions of pytorch
from torch.utils.collect_env import get_pretty_env_info
return get_pretty_env_info()
def get_env_module():
var_name = "DETECTRON2_ENV_MODULE"
return var_name, os.environ.get(var_name, "<not set>")
def detect_compute_compatibility(CUDA_HOME, so_file):
try:
cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
if os.path.isfile(cuobjdump):
output = subprocess.check_output(
"'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
)
output = output.decode("utf-8").strip().split("\n")
sm = []
for line in output:
line = re.findall(r"\.sm_[0-9]*\.", line)[0]
sm.append(line.strip("."))
sm = sorted(set(sm))
return ", ".join(sm)
else:
return so_file + "; cannot find cuobjdump"
except Exception:
# unhandled failure
return so_file
def collect_env_info():
has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM
torch_version = torch.__version__
# NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in
# theory detectron2 should be made runnable with only the corresponding runtimes
from torch.utils.cpp_extension import CUDA_HOME
has_rocm = False
if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5):
from torch.utils.cpp_extension import ROCM_HOME
if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
has_rocm = True
has_cuda = has_gpu and (not has_rocm)
data = []
data.append(("sys.platform", sys.platform))
data.append(("Python", sys.version.replace("\n", "")))
data.append(("numpy", np.__version__))
try:
import fastreid # noqa
data.append(
("fastreid", fastreid.__version__ + " @" + os.path.dirname(fastreid.__file__))
)
except ImportError:
data.append(("fastreid", "failed to import"))
data.append(get_env_module())
data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
data.append(("PyTorch debug build", torch.version.debug))
data.append(("GPU available", has_gpu))
if has_gpu:
devices = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
for name, devids in devices.items():
data.append(("GPU " + ",".join(devids), name))
if has_rocm:
data.append(("ROCM_HOME", str(ROCM_HOME)))
else:
data.append(("CUDA_HOME", str(CUDA_HOME)))
cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if cuda_arch_list:
data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
data.append(("Pillow", PIL.__version__))
try:
data.append(
(
"torchvision",
str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
)
)
if has_cuda:
try:
torchvision_C = importlib.util.find_spec("torchvision._C").origin
msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
data.append(("torchvision arch flags", msg))
except ImportError:
data.append(("torchvision._C", "failed to find"))
except AttributeError:
data.append(("torchvision", "unknown"))
try:
import fvcore
data.append(("fvcore", fvcore.__version__))
except ImportError:
pass
try:
import cv2
data.append(("cv2", cv2.__version__))
except ImportError:
pass
env_str = tabulate(data) + "\n"
env_str += collect_torch_env()
return env_str
if __name__ == "__main__":
try:
import detectron2 # noqa
except ImportError:
print(collect_env_info())
else:
from fastreid.utils.collect_env import collect_env_info
print(collect_env_info())

fastreid/utils/env.py
View File

@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
import importlib.util
import logging
import numpy as np
import os
import random
import sys
from datetime import datetime
import torch
__all__ = ["seed_all_rng"]
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""
def seed_all_rng(seed=None):
"""
Set the random seed for the RNG in torch, numpy and python.
Args:
seed (int): if None, will use a strong random seed.
"""
if seed is None:
seed = (
os.getpid()
+ int(datetime.now().strftime("%S%f"))
+ int.from_bytes(os.urandom(2), "big")
)
logger = logging.getLogger(__name__)
logger.info("Using a generated random seed {}".format(seed))
np.random.seed(seed)
torch.set_rng_state(torch.manual_seed(seed).get_state())
random.seed(seed)
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
def _import_file(module_name, file_path, make_importable=False):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if make_importable:
sys.modules[module_name] = module
return module
def _configure_libraries():
"""
Configurations for some libraries.
"""
# An environment option to disable `import cv2` globally,
# in case it leads to negative performance impact
disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
if disable_cv2:
sys.modules["cv2"] = None
else:
# Disable opencl in opencv since its interaction with cuda often has negative effects
# This envvar is supported after OpenCV 3.4.0
os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
try:
import cv2
if int(cv2.__version__.split(".")[0]) >= 3:
cv2.ocl.setUseOpenCL(False)
except ImportError:
pass
def get_version(module, digit=2):
return tuple(map(int, module.__version__.split(".")[:digit]))
# fmt: off
assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
import yaml
assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
# fmt: on
_ENV_SETUP_DONE = False
def setup_environment():
"""Perform environment setup work. The default setup is a no-op, but this
function allows the user to specify a Python source file or a module in
the $FASTREID_ENV_MODULE environment variable, which performs
custom setup work that may be necessary for their computing environment.
"""
global _ENV_SETUP_DONE
if _ENV_SETUP_DONE:
return
_ENV_SETUP_DONE = True
_configure_libraries()
custom_module_path = os.environ.get("FASTREID_ENV_MODULE")
if custom_module_path:
setup_custom_environment(custom_module_path)
else:
# The default setup is a no-op
pass
def setup_custom_environment(custom_module):
"""
Load custom environment setup by importing a Python source file or a
module, and run the setup function.
"""
if custom_module.endswith(".py"):
module = _import_file("fastreid.utils.env.custom_module", custom_module)
else:
module = importlib.import_module(custom_module)
assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
"Custom environment module defined in {} does not have the "
"required callable attribute 'setup_environment'."
).format(custom_module)
module.setup_environment()
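For reference, a minimal sketch of a custom environment module that setup_custom_environment can load; the file name and body are illustrative assumptions, and the only real requirement (per the assert above) is a callable named setup_environment:

# my_env_setup.py -- hypothetical module pointed to by FASTREID_ENV_MODULE
import os

def setup_environment():
    # Example: keep math libraries single-threaded on shared training machines.
    os.environ.setdefault("OMP_NUM_THREADS", "1")

Setting FASTREID_ENV_MODULE=/path/to/my_env_setup.py (or a dotted module name) makes this run once, the first time setup_environment() above is called.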

View File

@@ -1,58 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
from typing import Optional
import torch
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py
def one_hot(labels: torch.Tensor,
num_classes: int,
dtype: Optional[torch.dtype] = None, ) -> torch.Tensor:
# eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
where N is batch size. Each value is an integer
representing correct classification.
num_classes (int): number of classes in labels.
dtype (Optional[torch.dtype]): the desired data type of returned
tensor. Default: if None, infers data type from values.
Returns:
torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
Examples::
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> one_hot(labels, num_classes=3)
tensor([[[[1., 0.],
[0., 1.]],
[[0., 1.],
[0., 0.]],
[[0., 0.],
[1., 0.]]]])
"""
if not torch.is_tensor(labels):
raise TypeError("Input labels type is not a torch.Tensor. Got {}"
.format(type(labels)))
if not labels.dtype == torch.int64:
raise ValueError(
"labels must be of the same dtype torch.int64. Got: {}".format(
labels.dtype))
if num_classes < 1:
raise ValueError("The number of classes must be bigger than one."
" Got: {}".format(num_classes))
device = labels.device
shape = labels.shape
one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

View File

@@ -57,9 +57,7 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
# Change targets to zero to avoid error in circle(arcface) loss
# which will use targets in forward
inputs['targets'].zero_()
inputs['targets'].fill_(-1)
with torch.no_grad(): # No need to backward
model(inputs)
for i, bn in enumerate(bn_layers):

View File

@@ -63,7 +63,6 @@ if __name__ == '__main__':
model = build_model(cfg)
Checkpointer(model).load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
print(model)

View File

@@ -1,48 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import sys
import torch
sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup
from fastreid.modeling.meta_arch import build_model
from fastreid.export.tensorflow_export import export_tf_reid_model
from fastreid.export.tf_modeling import TfMetaArch
def setup(args):
"""
Create configs and perform basic setups.
"""
cfg = get_cfg()
# cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
default_setup(cfg, args)
return cfg
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
cfg = setup(args)
cfg.defrost()
cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
cfg.MODEL.BACKBONE.DEPTH = 50
cfg.MODEL.BACKBONE.LAST_STRIDE = 1
# If use IBN block in backbone
cfg.MODEL.BACKBONE.WITH_IBN = False
cfg.MODEL.BACKBONE.PRETRAIN = False
from torchvision.models import resnet50
# model = TfMetaArch(cfg)
model = resnet50(pretrained=False)
# model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
model.eval()
dummy_inputs = torch.randn(1, 3, 256, 128)
export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')

View File

@@ -5,16 +5,14 @@
@contact: sherlockliao01@gmail.com
"""
import os
import logging
import os
import sys
sys.path.append('.')
from torch import nn
from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.engine import hooks
from fastreid.evaluation import ReidEvaluator
@@ -43,15 +41,14 @@ def setup(args):
def main(args):
cfg = setup(args)
logger = logging.getLogger('fastreid.' + __name__)
if args.eval_only:
logger = logging.getLogger("fastreid.trainer")
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
model = Trainer.build_model(cfg)
model = nn.DataParallel(model)
model = model.cuda()
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS) # load trained model
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(model):
prebn_cfg = cfg.clone()
prebn_cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
@@ -75,4 +72,11 @@ def main(args):
if __name__ == "__main__":
args = default_argument_parser().parse_args()
print("Command Line Args:", args)
main(args)
launch(
main,
args.num_gpus,
num_machines=args.num_machines,
machine_rank=args.machine_rank,
dist_url=args.dist_url,
args=(args,),
)