Mirror of https://github.com/JDAI-CV/fast-reid.git (synced 2025-06-03 14:50:47 +08:00)

finish v0.2 ddp training

parent 5ae2cff47e
commit fec7abc461

demo/demo.py — 12 changes
@@ -20,6 +20,9 @@ from fastreid.config import get_cfg
from fastreid.utils.file_io import PathManager
from predictor import FeatureExtractionDemo
# import some modules added in project like this below
# from projects.PartialReID.partialreid import *
cudnn.benchmark = True

@@ -40,12 +43,7 @@ def get_parser():
help="path to config file",
)
parser.add_argument(
'--device',
default='cuda: 1',
help='CUDA device to use'
)
parser.add_argument(
'--parallel',
"--parallel",
action='store_true',
help='If use multiprocess for feature extraction.'
)

@@ -72,7 +70,7 @@ def get_parser():
if __name__ == '__main__':
args = get_parser().parse_args()
cfg = setup_cfg(args)
demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
PathManager.mkdirs(args.output)
if args.input:
@@ -21,7 +21,7 @@ except RuntimeError:
class FeatureExtractionDemo(object):
def __init__(self, cfg, device='cuda:0', parallel=False):
def __init__(self, cfg, parallel=False):
"""
Args:
cfg (CfgNode):

@@ -35,7 +35,7 @@ class FeatureExtractionDemo(object):
self.num_gpus = torch.cuda.device_count()
self.predictor = AsyncPredictor(cfg, self.num_gpus)
else:
self.predictor = DefaultPredictor(cfg, device)
self.predictor = DefaultPredictor(cfg)
def run_on_image(self, original_image):
"""

@@ -51,6 +51,8 @@ class FeatureExtractionDemo(object):
original_image = original_image[:, :, ::-1]
# Apply pre-processing to image.
image = cv2.resize(original_image, tuple(self.cfg.INPUT.SIZE_TEST[::-1]), interpolation=cv2.INTER_CUBIC)
# Make shape with a new batch dimension which is adapted for
# network input
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))[None]
predictions = self.predictor(image)
return predictions

@@ -68,16 +70,16 @@ class FeatureExtractionDemo(object):
if cnt >= buffer_size:
batch = batch_data.popleft()
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()
while len(batch_data):
batch = batch_data.popleft()
predictions = self.predictor.get()
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()
else:
for batch in data_loader:
predictions = self.predictor(batch["images"])
yield predictions, batch['targets'].numpy(), batch['camid'].numpy()
yield predictions, batch["targets"].numpy(), batch["camid"].numpy()

class AsyncPredictor:

@@ -90,15 +92,14 @@ class AsyncPredictor:
pass
class _PredictWorker(mp.Process):
def __init__(self, cfg, device, task_queue, result_queue):
def __init__(self, cfg, task_queue, result_queue):
self.cfg = cfg
self.device = device
self.task_queue = task_queue
self.result_queue = result_queue
super().__init__()
def run(self):
predictor = DefaultPredictor(self.cfg, self.device)
predictor = DefaultPredictor(self.cfg)
while True:
task = self.task_queue.get()

@@ -120,9 +121,11 @@ class AsyncPredictor:
self.result_queue = mp.Queue(maxsize=num_workers * 3)
self.procs = []
for gpuid in range(max(num_gpus, 1)):
device = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
cfg = cfg.clone()
cfg.defrost()
cfg.MODEL.DEVICE = "cuda: {}".format(gpuid) if num_gpus > 0 else "cpu"
self.procs.append(
AsyncPredictor._PredictWorker(cfg, device, self.task_queue, self.result_queue)
AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
)
self.put_idx = 0
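The two hunks above drop the explicit `device` argument from demo.py and predictor.py: the predictor now places the model according to `cfg.MODEL.DEVICE`, and the parallel path spins up one worker per GPU. A minimal usage sketch against the new signature; the config and image paths below are placeholders, not part of this commit:

# Hedged sketch of the new FeatureExtractionDemo API.
import cv2
from fastreid.config import get_cfg
from predictor import FeatureExtractionDemo

cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")  # placeholder config path
cfg.merge_from_list(["MODEL.DEVICE", "cuda:0"])              # device now comes from the config
cfg.freeze()

demo = FeatureExtractionDemo(cfg, parallel=False)            # no device= keyword anymore
img = cv2.imread("query.jpg")                                # placeholder image path
feat = demo.run_on_image(img)                                # extracted feature for this image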
@@ -42,11 +42,6 @@ def get_parser():
metavar="FILE",
help="path to config file",
)
parser.add_argument(
'--device',
default='cuda: 1',
help='CUDA device to use'
)
parser.add_argument(
'--parallel',
action='store_true',

@@ -101,7 +96,7 @@ if __name__ == '__main__':
logger = setup_logger()
cfg = setup_cfg(args)
test_loader, num_query = build_reid_test_loader(cfg, args.dataset_name)
demo = FeatureExtractionDemo(cfg, device=args.device, parallel=args.parallel)
demo = FeatureExtractionDemo(cfg, parallel=args.parallel)
logger.info("Start extracting image features")
feats = []
@@ -20,8 +20,9 @@ _C = CN()
# MODEL
# -----------------------------------------------------------------------------
_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = 'Baseline'
_C.MODEL.OPEN_LAYERS = ['']
_C.MODEL.FREEZE_LAYERS = ['']
# ---------------------------------------------------------------------------- #
# Backbone options

@@ -57,7 +58,7 @@ _C.MODEL.HEADS.NORM = "BN"
# Mini-batch split of Ghost BN
_C.MODEL.HEADS.NORM_SPLIT = 1
# Number of identity
_C.MODEL.HEADS.NUM_CLASSES = 751
_C.MODEL.HEADS.NUM_CLASSES = 0
# Input feature dimension
_C.MODEL.HEADS.IN_FEAT = 2048
# Reduction dimension in head

@@ -74,7 +75,6 @@ _C.MODEL.HEADS.CLS_LAYER = "linear" # "arcface" or "circle"
_C.MODEL.HEADS.MARGIN = 0.15
_C.MODEL.HEADS.SCALE = 128
# ---------------------------------------------------------------------------- #
# REID LOSSES options
# ---------------------------------------------------------------------------- #

@@ -115,7 +115,7 @@ _C.MODEL.WEIGHTS = ""
_C.MODEL.PIXEL_MEAN = [0.485*255, 0.456*255, 0.406*255]
# Values to be used for image normalization
_C.MODEL.PIXEL_STD = [0.229*255, 0.224*255, 0.225*255]
#
# -----------------------------------------------------------------------------
# INPUT

@@ -194,10 +194,10 @@ _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
# Multi-step learning rate options
_C.SOLVER.SCHED = "WarmupMultiStepLR"
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.STEPS = (30, 55)
_C.SOLVER.STEPS = [30, 55]
# Cosine annealing learning rate options
_C.SOLVER.DELAY_ITERS = 100
_C.SOLVER.DELAY_ITERS = 0
_C.SOLVER.ETA_MIN_LR = 3e-7
# Warmup options

@@ -210,15 +210,14 @@ _C.SOLVER.FREEZE_ITERS = 0
# SWA options
_C.SOLVER.SWA = CN()
_C.SOLVER.SWA.ENABLED = False
_C.SOLVER.SWA.ITER = 0
_C.SOLVER.SWA.PERIOD = 10
_C.SOLVER.SWA.ITER = 10
_C.SOLVER.SWA.PERIOD = 2
_C.SOLVER.SWA.LR_FACTOR = 10.
_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
_C.SOLVER.SWA.LR_SCHED = False
_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.LOG_PERIOD = 30
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
@@ -12,3 +12,4 @@ __all__ = [k for k in globals().keys() if not k.startswith("_")]
# but still make them available here
from .hooks import *
from .defaults import *
from .launch import *
@@ -11,20 +11,22 @@ since they are meant to represent the "common default behavior" people need in t
import argparse
import logging
import os
import sys
from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from fastreid.data import build_reid_test_loader, build_reid_train_loader
from fastreid.evaluation import (DatasetEvaluator, ReidEvaluator,
inference_on_dataset, print_csv_format)
from fastreid.modeling.meta_arch import build_model
from fastreid.layers.sync_bn import patch_replication_callback
from fastreid.solver import build_lr_scheduler, build_optimizer
from fastreid.utils import comm
from fastreid.utils.env import seed_all_rng
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.collect_env import collect_env_info
from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.logger import setup_logger

@@ -36,7 +38,7 @@ __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "Defa
def default_argument_parser():
"""
Create a parser with some common arguments used by detectron2 users.
Create a parser with some common arguments used by fastreid users.
Returns:
argparse.ArgumentParser:
"""

@@ -48,17 +50,17 @@ def default_argument_parser():
help="whether to attempt to resume from the checkpoint directory",
)
parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
# parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
# parser.add_argument("--num-machines", type=int, default=1)
# parser.add_argument(
# "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
# )
parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
parser.add_argument(
"--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
)
# PyTorch still may leave orphan processes in multi-gpu training.
# Therefore we use a deterministic way to obtain port,
# so that users are aware of orphan processes by seeing the port occupied.
# port = 2 ** 15 + 2 ** 14 + hash(os.getuid()) % 2 ** 14
# parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
parser.add_argument("--dist-url", default="tcp://127.0.0.1:{}".format(port))
parser.add_argument(
"opts",
help="Modify config options using the command-line",

@@ -87,7 +89,7 @@ def default_setup(cfg, args):
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
# logger.info("Environment info:\n" + collect_env_info())
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":

@@ -106,6 +108,9 @@ def default_setup(cfg, args):
f.write(cfg.dump())
logger.info("Full config saved to {}".format(os.path.abspath(path)))
# make sure each worker has a different, yet deterministic seed if specified
seed_all_rng()
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):

@@ -128,17 +133,14 @@ class DefaultPredictor:
outputs = pred(inputs)
"""
def __init__(self, cfg, device='cpu'):
def __init__(self, cfg):
self.cfg = cfg.clone() # cfg can be modified by model
self.cfg.defrost()
self.cfg.MODEL.BACKBONE.PRETRAIN = False
self.device = device
self.model = build_model(self.cfg)
self.model.to(device)
self.model.eval()
checkpointer = Checkpointer(self.model)
checkpointer.load(cfg.MODEL.WEIGHTS)
Checkpointer(self.model).load(cfg.MODEL.WEIGHTS)
def __call__(self, image):
"""

@@ -147,9 +149,8 @@ class DefaultPredictor:
Returns:
predictions (torch.tensor): the output features of the model
"""
inputs = {"images": image}
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
image = image.to(self.device)
inputs = {"images": image}
predictions = self.model(inputs)
# Normalize feature to compute cosine distance
pred_feat = F.normalize(predictions)

@@ -196,20 +197,24 @@ class DefaultTrainer(SimpleTrainer):
cfg (CfgNode):
"""
self.cfg = cfg
logger = logging.getLogger(__name__)
logger = logging.getLogger("fastreid")
if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for fastreid
setup_logger()
# Assume these objects must be constructed in this order.
data_loader = self.build_train_loader(cfg)
cfg = self.auto_scale_hyperparams(cfg, data_loader)
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
logger.info('Prepare training set')
data_loader = self.build_train_loader(cfg)
# For training, wrap with DP. But don't need this for inference.
model = DataParallel(model)
if cfg.MODEL.BACKBONE.NORM == "syncBN":
# Monkey-patching with syncBN
patch_replication_callback(model)
model = model.cuda()
# For training, wrap with DDP. But don't need this for inference.
if comm.get_world_size() > 1:
# ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
# for part of the parameters is not updated.
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
)
super().__init__(model, data_loader, optimizer)
self.scheduler = self.build_lr_scheduler(cfg, optimizer)

@@ -218,8 +223,8 @@ class DefaultTrainer(SimpleTrainer):
self.checkpointer = Checkpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
self.data_loader.dataset,
cfg.OUTPUT_DIR,
save_to_disk=comm.is_main_process(),
optimizer=optimizer,
scheduler=self.scheduler,
)

@@ -246,12 +251,10 @@ class DefaultTrainer(SimpleTrainer):
# Reinitialize dataloader iter because when we update dataset person identity dict
# to resume training, DataLoader won't update this dictionary when using multiprocess
# because of the function scope.
self._data_loader_iter = iter(self.data_loader)
self.start_iter = checkpoint.get("iteration", -1) if resume else -1
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
self.start_iter += 1
if resume and self.checkpointer.has_checkpoint():
self.start_iter = checkpoint.get("iteration", -1) + 1
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration (or iter zero if there's no checkpoint).
def build_hooks(self):
"""

@@ -292,31 +295,38 @@ class DefaultTrainer(SimpleTrainer):
cfg.TEST.PRECISE_BN.NUM_ITER,
))
if cfg.MODEL.OPEN_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
open_layers = ",".join(cfg.MODEL.OPEN_LAYERS)
logger.info(f'Open "{open_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iters')
if cfg.MODEL.FREEZE_LAYERS != [''] and cfg.SOLVER.FREEZE_ITERS > 0:
freeze_layers = ",".join(cfg.MODEL.FREEZE_LAYERS)
logger.info(f'Freeze layer group "{freeze_layers}" training for {cfg.SOLVER.FREEZE_ITERS:d} iterations')
ret.append(hooks.FreezeLayer(
self.model,
cfg.MODEL.OPEN_LAYERS,
self.optimizer,
cfg.MODEL.FREEZE_LAYERS,
cfg.SOLVER.FREEZE_ITERS,
))
# Do PreciseBN before checkpointer, because it updates the model and need to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
# if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
if comm.is_main_process():
ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
if comm.is_main_process():
self._last_eval_results = self.test(self.cfg, self.model)
torch.cuda.empty_cache()
return self._last_eval_results
else:
return None
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), cfg.SOLVER.LOG_PERIOD))
if comm.is_main_process():
# run writers in the end, so that evaluation metrics are written
ret.append(hooks.PeriodicWriter(self.build_writers(), 200))
return ret
def build_writers(self):

@@ -351,9 +361,12 @@ class DefaultTrainer(SimpleTrainer):
OrderedDict of results, if evaluation is enabled. Otherwise None.
"""
super().train(self.start_iter, self.max_iter)
# if hasattr(self, "_last_eval_results") and comm.is_main_process():
# verify_results(self.cfg, self._last_eval_results)
# return self._last_eval_results
if comm.is_main_process():
assert hasattr(
self, "_last_eval_results"
), "No evaluation results obtained during training!"
# verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
@classmethod
def build_model(cls, cfg):

@@ -365,7 +378,7 @@ class DefaultTrainer(SimpleTrainer):
"""
model = build_model(cfg)
logger = logging.getLogger(__name__)
logger.info("Model:\n{}".format(model))
# logger.info("Model:\n{}".format(model))
return model
@classmethod

@@ -394,6 +407,8 @@ class DefaultTrainer(SimpleTrainer):
It now calls :func:`fastreid.data.build_detection_train_loader`.
Overwrite it if you'd like a different data loader.
"""
logger = logging.getLogger(__name__)
logger.info("Prepare training set")
return build_reid_train_loader(cfg)
@classmethod

@@ -433,7 +448,7 @@ class DefaultTrainer(SimpleTrainer):
results = OrderedDict()
for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
logger.info(f'prepare test set')
logger.info("Prepare testing set")
data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
# When evaluators are passed in as arguments,
# implicitly assume that evaluators can be created before data_loader.

@@ -451,15 +466,53 @@ class DefaultTrainer(SimpleTrainer):
continue
results_i = inference_on_dataset(model, data_loader, evaluator)
results[dataset_name] = results_i
if comm.is_main_process():
assert isinstance(
results_i, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results_i
)
logger.info("Evaluation results for {} in csv format:".format(dataset_name))
print_csv_format(results_i)
if len(results) == 1:
results = list(results.values())[0]
if comm.is_main_process():
assert isinstance(
results, dict
), "Evaluator must return a dict on the main process. Got {} instead.".format(
results
)
print_csv_format(results)
if len(results) == 1: results = list(results.values())[0]
return results
@staticmethod
def auto_scale_hyperparams(cfg, data_loader):
r"""
This is used for auto-computation actual training iterations,
because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
so we need to convert specific hyper-param to training iterations.
"""
cfg = cfg.clone()
frozen = cfg.is_frozen()
cfg.defrost()
iters_per_epoch = len(data_loader.dataset) // (cfg.SOLVER.IMS_PER_BATCH * comm.get_world_size())
cfg.MODEL.HEADS.NUM_CLASSES = data_loader.dataset.num_classes
cfg.SOLVER.MAX_ITER *= iters_per_epoch
cfg.SOLVER.WARMUP_ITERS *= iters_per_epoch
cfg.SOLVER.FREEZE_ITERS *= iters_per_epoch
cfg.SOLVER.DELAY_ITERS *= iters_per_epoch
for i in range(len(cfg.SOLVER.STEPS)):
cfg.SOLVER.STEPS[i] *= iters_per_epoch
cfg.SOLVER.SWA.ITER *= iters_per_epoch
cfg.SOLVER.SWA.PERIOD *= iters_per_epoch
cfg.SOLVER.CHECKPOINT_PERIOD *= iters_per_epoch
cfg.TEST.EVAL_PERIOD *= iters_per_epoch
logger = logging.getLogger(__name__)
logger.info(
f"Auto-scaling the config to num_classes={cfg.MODEL.HEADS.NUM_CLASSES}, "
f"max_Iter={cfg.SOLVER.MAX_ITER}, wamrup_Iter={cfg.SOLVER.WARMUP_ITERS}, "
f"freeze_Iter={cfg.SOLVER.FREEZE_ITERS}, delay_Iter={cfg.SOLVER.DELAY_ITERS}, "
f"step_Iter={cfg.SOLVER.STEPS}, ckpt_Iter={cfg.SOLVER.CHECKPOINT_PERIOD}, "
f"eval_Iter={cfg.TEST.EVAL_PERIOD}."
)
if frozen: cfg.freeze()
return cfg
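The new `auto_scale_hyperparams` staticmethod above converts epoch-based settings into iterations with `iters_per_epoch = len(dataset) // (IMS_PER_BATCH * world_size)` and then multiplies MAX_ITER, WARMUP_ITERS, STEPS, the SWA and checkpoint periods, and EVAL_PERIOD by it. A small worked example of that arithmetic; all numbers below are illustrative, not taken from the commit:

# Worked example of the iteration scaling in auto_scale_hyperparams (made-up numbers).
dataset_size = 12936        # len(data_loader.dataset)
ims_per_batch = 64          # cfg.SOLVER.IMS_PER_BATCH
world_size = 4              # number of DDP processes

iters_per_epoch = dataset_size // (ims_per_batch * world_size)    # 12936 // 256 = 50

max_epochs = 120            # cfg.SOLVER.MAX_ITER before scaling (interpreted as epochs)
steps_epochs = [30, 55]     # cfg.SOLVER.STEPS before scaling

max_iter = max_epochs * iters_per_epoch               # 6000 iterations
steps = [s * iters_per_epoch for s in steps_epochs]   # [1500, 2750]
print(iters_per_epoch, max_iter, steps)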
@@ -4,7 +4,6 @@
import datetime
import itertools
import logging
import warnings
import os
import tempfile
import time

@@ -12,16 +11,17 @@ from collections import Counter
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from .train_loop import HookBase
from fastreid.solver import optim
from fastreid.evaluation.testing import flatten_results_dict
from fastreid.solver import optim
from fastreid.utils import comm
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fastreid.utils.events import EventStorage, EventWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
from fastreid.utils.timer import Timer
from .train_loop import HookBase
__all__ = [
"CallbackHook",

@@ -308,7 +308,6 @@ class EvalHook(HookBase):
"""
self._period = eval_period
self._func = eval_function
self._done_eval_at_last = False
def _do_eval(self):
results = self._func()

@@ -329,22 +328,16 @@
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
torch.cuda.empty_cache()
def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_eval()
if is_final:
self._done_eval_at_last = True
# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
def after_train(self):
if not self._done_eval_at_last:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func

@@ -418,27 +411,24 @@ class PreciseBN(HookBase):
update_bn_stats(self._model, data_loader(), self._num_iter)

class LRFinder(HookBase):
pass

class FreezeLayer(HookBase):
def __init__(self, model, open_layer_names, freeze_iters):
def __init__(self, model, optimizer, freeze_layers, freeze_iters):
self._logger = logging.getLogger(__name__)
if isinstance(model, nn.DataParallel):
if isinstance(model, DistributedDataParallel):
model = model.module
self.model = model
self.optimizer = optimizer
self.freeze_layers = freeze_layers
self.freeze_iters = freeze_iters
self.open_layer_names = open_layer_names
# previous requires grad status
param_grad = {}
for name, param in self.model.named_parameters():
param_grad[name] = param.requires_grad
self.param_grad = param_grad
# Previous parameters freeze status
param_freeze = {}
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
param_freeze[param_name] = param_group['freeze']
self.param_freeze = param_freeze
def before_step(self):
# Freeze specific layers

@@ -450,28 +440,28 @@ class FreezeLayer(HookBase):
self.open_all_layer()
def freeze_specific_layer(self):
for layer in self.open_layer_names:
for layer in self.freeze_layers:
if not hasattr(self.model, layer):
self._logger.info(f'"{layer}" is not an attribute of the model, will skip this layer')
self._logger.info(f'{layer} is not an attribute of the model, will skip this layer')
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
if param_name.split('.')[0] in self.freeze_layers:
param_group['freeze'] = True
# Change BN in freeze layers to eval mode
for name, module in self.model.named_children():
if name in self.open_layer_names:
module.train()
for p in module.parameters():
p.requires_grad = True
else:
module.eval()
for p in module.parameters():
p.requires_grad = False
if name in self.freeze_layers: module.eval()
def open_all_layer(self):
self.model.train()
for name, param in self.model.named_parameters():
param.requires_grad = self.param_grad[name]
for param_group in self.optimizer.param_groups:
param_name = param_group['name']
param_group['freeze'] = self.param_freeze[param_name]

class SWA(HookBase):
def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False,):
def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False, ):
self.swa_start = swa_start
self.swa_freq = swa_freq
self.swa_lr_factor = swa_lr_factor
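The reworked `FreezeLayer` hook above reads `param_group['name']` and `param_group['freeze']` from the optimizer, so whatever builds the optimizer has to attach those keys to each param group. A hedged sketch of param groups shaped that way; the learning rate is a placeholder and this is not necessarily how fastreid's `build_optimizer` constructs them:

# Sketch only: param groups carrying the 'name' and 'freeze' keys that
# FreezeLayer.freeze_specific_layer() / open_all_layer() expect.
import torch

def build_param_groups(model, base_lr=3.5e-4):
    param_groups = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        param_groups.append({
            "params": [param],
            "lr": base_lr,
            "name": name,      # e.g. "backbone.conv1.weight"; split('.')[0] gives the layer group
            "freeze": False,   # toggled to True by FreezeLayer during cfg.SOLVER.FREEZE_ITERS
        })
    return param_groups

# optimizer = torch.optim.Adam(build_param_groups(model))
# Extra keys such as 'name' and 'freeze' are preserved inside optimizer.param_groups.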
fastreid/engine/launch.py — 103 lines (new file)

@@ -0,0 +1,103 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""
# based on:
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fastreid.utils import comm
__all__ = ["launch"]

def _find_free_port():
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
sock.bind(("", 0))
port = sock.getsockname()[1]
sock.close()
# NOTE: there is still a chance the port could be taken by other processes.
return port

def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
"""
Launch multi-gpu or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine`) on each machine.
Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of GPUs per machine
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
args (tuple): arguments passed to main_func
"""
world_size = num_machines * num_gpus_per_machine
if world_size > 1:
# https://github.com/pytorch/pytorch/pull/14391
# TODO prctl in spawned processes
if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.spawn(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
daemon=False,
)
else:
main_func(*args)

def _distributed_worker(
local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
global_rank = machine_rank * num_gpus_per_machine + local_rank
try:
dist.init_process_group(
backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
# synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
assert num_gpus_per_machine <= torch.cuda.device_count()
torch.cuda.set_device(local_rank)
# Setup the local process group (which contains ranks within the same machine)
assert comm._LOCAL_PROCESS_GROUP is None
num_machines = world_size // num_gpus_per_machine
for i in range(num_machines):
ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
comm._LOCAL_PROCESS_GROUP = pg
main_func(*args)
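The new `launch` helper added above pairs with the restored `--num-gpus` / `--num-machines` / `--machine-rank` / `--dist-url` arguments in `default_argument_parser`. A minimal driver sketch of how a training script might call it; the `main` body here is a stand-in, not the project's actual training entry point:

# Hedged sketch: wiring default_argument_parser() into launch().
from fastreid.engine import default_argument_parser, launch

def main(args):
    # placeholder worker body; a real script would build cfg and a DefaultTrainer here
    print("worker started with", args)

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )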
@@ -5,10 +5,12 @@ https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/tra
"""
import logging
import numpy as np
import time
import weakref
import numpy as np
import torch
import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage

@@ -194,12 +196,12 @@ class SimpleTrainer(TrainerBase):
"""
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
"""
If your want to do something with the heads, you can wrap the model.
"""
outputs = self.model(data)
loss_dict = self.model.module.losses(outputs)
losses = sum(loss for loss in loss_dict.values())
loss_dict = self.model(data)
losses = sum(loss_dict.values())
self._detect_anomaly(losses, loss_dict)
metrics_dict = loss_dict

@@ -238,23 +240,22 @@ class SimpleTrainer(TrainerBase):
}
# gather metrics among all workers for logging
# This assumes we do DDP-style training, which is currently the only
# supported method in detectron2.
# supported method in fastreid.
all_metrics_dict = comm.gather(metrics_dict)
# if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
if comm.is_main_process():
if "data_time" in all_metrics_dict[0]:
# data_time among workers can have high variance. The actual latency
# caused by data_time is the maximum among workers.
data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
self.storage.put_scalar("data_time", data_time)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)
# average the rest metrics
metrics_dict = {
k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
}
total_losses_reduced = sum(loss for loss in metrics_dict.values())
self.storage.put_scalar("total_loss", total_losses_reduced)
if len(metrics_dict) > 1:
self.storage.put_scalars(**metrics_dict)
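After this change `SimpleTrainer.run_step` calls `loss_dict = self.model(data)` directly, so in training mode the (possibly DDP-wrapped) model's forward is expected to return a dict of scalar losses rather than raw outputs plus a separate `losses()` call. A toy sketch of a module honoring that contract, purely illustrative and not fastreid's actual meta-architecture:

# Toy model whose forward returns a loss dict in training mode,
# matching losses = sum(loss_dict.values()) in SimpleTrainer.run_step.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyReidModel(nn.Module):
    def __init__(self, feat_dim=128, num_classes=10):
        super().__init__()
        self.embed = nn.Linear(32, feat_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, batched_inputs):
        feats = self.embed(batched_inputs["images"])
        if not self.training:
            return feats                      # inference: just features
        logits = self.classifier(feats)
        return {                              # training: dict of scalar losses
            "loss_cls": F.cross_entropy(logits, batched_inputs["targets"]),
        }

model = ToyReidModel()
data = {"images": torch.randn(4, 32), "targets": torch.randint(0, 10, (4,))}
loss_dict = model(data)
losses = sum(loss_dict.values())
losses.backward()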
@@ -6,7 +6,7 @@ from contextlib import contextmanager
import torch
from ..utils.logger import log_every_n_seconds
from fastreid.utils.logger import log_every_n_seconds

class DatasetEvaluator:

@@ -28,12 +28,12 @@ class DatasetEvaluator:
def preprocess_inputs(self, inputs):
pass
def process(self, output):
def process(self, inputs, outputs):
"""
Process an input/output pair.
Args:
input: the input that's used to call the model.
output: the return value of `model(input)`
inputs: the inputs that's used to call the model.
outputs: the return value of `model(input)`
"""
pass

@@ -95,11 +95,10 @@ def inference_on_dataset(model, data_loader, evaluator):
Returns:
The return value of `evaluator.evaluate()`
"""
# num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
logger = logging.getLogger(__name__)
logger.info("Start inference on {} images".format(len(data_loader.dataset)))
total = len(data_loader) # inference data loader must have a fixed length
total = len(data_loader.dataset) # inference data loader must have a fixed length
evaluator.reset()
num_warmup = min(5, total - 1)

@@ -116,7 +115,7 @@ def inference_on_dataset(model, data_loader, evaluator):
if torch.cuda.is_available():
torch.cuda.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
evaluator.process(outputs)
evaluator.process(inputs, outputs)
idx += 1
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
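`DatasetEvaluator.process` now receives the input batch as well as the model outputs, and `inference_on_dataset` calls `evaluator.process(inputs, outputs)` accordingly (the `ReidEvaluator` hunk below pulls `targets` and `camid` from the inputs). A bare-bones evaluator sketch against the new two-argument signature, assuming the batch dict keys used elsewhere in this commit:

# Minimal evaluator using the new process(inputs, outputs) signature.
import torch
from collections import OrderedDict
from fastreid.evaluation import DatasetEvaluator

class FeatureNormEvaluator(DatasetEvaluator):
    """Toy example: reports the mean L2 norm of the extracted features."""
    def reset(self):
        self.norms = []

    def process(self, inputs, outputs):
        # inputs: the batch dict fed to the model (e.g. with "targets" and "camid")
        # outputs: the feature tensor the model returned for that batch
        self.norms.append(outputs.norm(dim=1).cpu())

    def evaluate(self):
        return OrderedDict(mean_feat_norm=torch.cat(self.norms).mean().item())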
@@ -154,10 +154,8 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
all_cmc = np.asarray(all_cmc).astype(np.float32)
all_cmc = all_cmc.sum(0) / num_valid_q
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
return all_cmc, mAP, mINP
return all_cmc, all_AP, all_INP

def evaluate_py(
@@ -14,8 +14,8 @@ import torch.nn.functional as F
from .evaluator import DatasetEvaluator
from .query_expansion import aqe
from .rank import evaluate_rank
from .roc import evaluate_roc
from .rerank import re_ranking
from .roc import evaluate_roc
logger = logging.getLogger(__name__)

@@ -35,10 +35,10 @@ class ReidEvaluator(DatasetEvaluator):
self.pids = []
self.camids = []
def process(self, outputs):
self.features.append(outputs[0].cpu())
self.pids.extend(outputs[1].cpu().numpy())
self.camids.extend(outputs[2].cpu().numpy())
def process(self, inputs, outputs):
self.pids.extend(inputs["targets"].numpy())
self.camids.extend(inputs["camid"].numpy())
self.features.append(outputs.cpu())
@staticmethod
def cal_dist(metric: str, query_feat: torch.tensor, gallery_feat: torch.tensor):

@@ -92,13 +92,12 @@ class ReidEvaluator(DatasetEvaluator):
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r - 1]
self._results['R-1'] = cmc[0]
self._results['mAP'] = mAP
self._results['mINP'] = mINP
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
fprs = [1e-4, 1e-3, 1e-2]
for i in range(len(fprs)):
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
# tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
# fprs = [1e-4, 1e-3, 1e-2]
# for i in range(len(fprs)):
#     self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
return copy.deepcopy(self._results)
@@ -15,10 +15,16 @@ def print_csv_format(results):
results (OrderedDict[dict]): task_name -> {metric -> score}
"""
assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed
task = list(results.keys())[0]
metrics = [k for k in results[task]]
logger = logging.getLogger(__name__)
logger.info('----------------------------------------')
logger.info("Evaluation results in csv format:")
logger.info("Metric: " + ", ".join([k for k in metrics]))
for task, res in results.items():
logger.info("Task: {}".format(task))
logger.info("{:.1%}".format(res))
logger.info(f"{task}: " + ", ".join(["{:.1%}".format(v) for v in res.values()]))
logger.info('----------------------------------------')

def verify_results(cfg, results):
fastreid/export/demo.py — 166 lines (new file)

@@ -0,0 +1,166 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from collections import defaultdict
import argparse
import json
import os
import sys
import time
import cv2
import numpy as np
import torch
from torch.backends import cudnn
from fastreid.modeling import build_model
from fastreid.utils.checkpoint import Checkpointer
from fastreid.config import get_cfg
cudnn.benchmark = True

class Reid(object):

def __init__(self, config_file):
cfg = get_cfg()
cfg.merge_from_file(config_file)
cfg.defrost()
cfg.MODEL.WEIGHTS = 'projects/bjzProject/logs/bjz/arcface_adam/model_final.pth'
model = build_model(cfg)
Checkpointer(model).resume_or_load(cfg.MODEL.WEIGHTS)
model.cuda()
model.eval()
self.model = model
# self.model = torch.jit.load("reid_model.pt")
# self.model.eval()
# self.model.cuda()
example = torch.rand(1, 3, 256, 128)
example = example.cuda()
traced_script_module = torch.jit.trace_module(model, {'inference': example})
traced_script_module.save("reid_feat_extractor.pt")

@classmethod
def preprocess(cls, img_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (128, 256))
img = img / 255.0
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = img.transpose((2, 0, 1)).astype(np.float32)
img = img[np.newaxis, :, :, :]
data = torch.from_numpy(img).cuda().float()
return data

@torch.no_grad()
def demo(self, img_path):
data = self.preprocess(img_path)
output = self.model.inference(data)
feat = output.cpu().data.numpy()
return feat

# @torch.no_grad()
# def extract_feat(self, dataloader):
# prefetcher = test_data_prefetcher(dataloader)
# feats = []
# labels = []
# batch = prefetcher.next()
# num_count = 0
# while batch[0] is not None:
# img, pid, camid = batch
# feat = self.model(img)
# feats.append(feat.cpu())
# labels.extend(np.asarray(pid))
#
# # if num_count > 2:
# # break
# batch = prefetcher.next()
# # num_count += 1
#
# feats = torch.cat(feats, dim=0)
# id_feats = defaultdict(list)
# for f, i in zip(feats, labels):
# id_feats[i].append(f)
# all_feats = []
# label_names = []
# for i in id_feats:
# all_feats.append(torch.stack(id_feats[i], dim=0).mean(dim=0))
# label_names.append(i)
#
# label_names = np.asarray(label_names)
# all_feats = torch.stack(all_feats, dim=0) # (n, 2048)
# all_feats = F.normalize(all_feats, p=2, dim=1)
# np.save('feats.npy', all_feats.cpu())
# np.save('labels.npy', label_names)
# cos = torch.mm(all_feats, all_feats.t()).numpy() # (n, n)
# cos -= np.eye(all_feats.shape[0])
# f = open('check_cross_folder_similarity.txt', 'w')
# for i in range(len(label_names)):
# sim_indx = np.argwhere(cos[i] > 0.5)[:, 0]
# sim_name = label_names[sim_indx]
# write_str = label_names[i] + ' '
# # f.write(label_names[i]+'\t')
# for n in sim_name:
# write_str += (n + ' ')
# # f.write(n+'\t')
# f.write(write_str+'\n')
#
#
# def prepare_gt(self, json_file):
# feat = []
# label = []
# with open(json_file, 'r') as f:
# total = json.load(f)
# for index in total:
# label.append(index)
# feat.append(np.array(total[index]))
# time_label = [int(i[0:10]) for i in label]
#
# return np.array(feat), np.array(label), np.array(time_label)

def compute_topk(self, k, feat, feats, label):
# num_gallery = feats.shape[0]
# new_feat = np.tile(feat,[num_gallery,1])
norm_feat = np.sqrt(np.sum(np.square(feat), axis=-1))
norm_feats = np.sqrt(np.sum(np.square(feats), axis=-1))
matrix = np.sum(np.multiply(feat, feats), axis=-1)
dist = matrix / np.multiply(norm_feat, norm_feats)
# print('feat:',feat.shape)
# print('feats:',feats.shape)
# print('label:',label.shape)
# print('dist:',dist.shape)
index = np.argsort(-dist)
# print('index:',index.shape)
result = []
for i in range(min(feats.shape[0], k)):
print(dist[index[i]])
result.append(label[index[i]])
return result

if __name__ == '__main__':
reid_sys = Reid(config_file='../../projects/bjzProject/configs/bjz.yml')
img_path = '/export/home/lxy/beijingStationReID/reid_model/demo_imgs/003740_c5s2_1561733125170.000000.jpg'
feat = reid_sys.demo(img_path)
feat_extractor = torch.jit.load('reid_feat_extractor.pt')
data = reid_sys.preprocess(img_path)
feat2 = feat_extractor.inference(data)
from ipdb import set_trace; set_trace()
# imgs = os.listdir(img_path)
# feats = {}
# for i in range(len(imgs)):
# feat = reid.demo(os.path.join(img_path, imgs[i]))
# feats[imgs[i]] = feat
# feat = reid.demo(os.path.join(img_path, 'crop_img0.jpg'))
# out1 = feats['dog.jpg']
# out2 = feats['kobe2.jpg']
# innerProduct = np.dot(out1, out2.T)
# cosineSimilarity = innerProduct / (np.linalg.norm(out1, ord=2) * np.linalg.norm(out2, ord=2))
# print(f'cosine similarity is {cosineSimilarity[0][0]:.4f}')
@@ -10,8 +10,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from .sync_bn import SynchronizedBatchNorm2d
__all__ = [
"BatchNorm",
"IBN",

@@ -32,12 +30,14 @@ class BatchNorm(nn.BatchNorm2d):
self.bias.requires_grad_(not bias_freeze)

class SyncBatchNorm(SynchronizedBatchNorm2d):
class SyncBatchNorm(nn.SyncBatchNorm):
def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
bias_init=0.0):
super().__init__(num_features, eps=eps, momentum=momentum, weight_freeze=weight_freeze, bias_freeze=bias_freeze)
super().__init__(num_features, eps=eps, momentum=momentum)
if weight_init is not None: self.weight.data.fill_(weight_init)
if bias_init is not None: self.bias.data.fill_(bias_init)
self.weight.requires_grad_(not weight_freeze)
self.bias.requires_grad_(not bias_freeze)

class IBN(nn.Module):
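With this hunk the project's `SyncBatchNorm` wraps `torch.nn.SyncBatchNorm` instead of the vendored `SynchronizedBatchNorm2d`, which is why the whole `fastreid/layers/sync_bn` package can be deleted further below. Native SyncBatchNorm only synchronizes statistics when a distributed process group is initialized (the DDP path added in this commit); standalone it behaves like ordinary BatchNorm. A small sketch of the new layer's freeze options in isolation, not wired into a fastreid backbone:

# Sketch: construct the nn.SyncBatchNorm-based layer with a frozen affine weight.
import torch.nn as nn

class SyncBatchNorm(nn.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1,
                 weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: self.weight.data.fill_(weight_init)
        if bias_init is not None: self.bias.data.fill_(bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)

bn = SyncBatchNorm(64, weight_freeze=True)   # affine weight frozen, as in the diff
print(bn.weight.requires_grad)               # False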
@@ -4,15 +4,11 @@
@contact: sherlockliao01@gmail.com
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from fastreid.utils.one_hot import one_hot

class Circle(nn.Module):
def __init__(self, cfg, in_feat, num_classes):

@@ -34,7 +30,7 @@ class Circle(nn.Module):
s_p = self._s * alpha_p * (sim_mat - delta_p)
s_n = self._s * alpha_n * (sim_mat - delta_n)
targets = one_hot(targets, self._num_classes)
targets = F.one_hot(targets, num_classes=self._num_classes)
pred_class_logits = targets * s_p + (1.0 - targets) * s_n
@@ -1,13 +0,0 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
@ -1,395 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : batchnorm.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
try:
|
||||
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
||||
except ImportError:
|
||||
ReduceAddCoalesced = Broadcast = None
|
||||
|
||||
try:
|
||||
from jactorch.parallel.comm import SyncMaster
|
||||
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
|
||||
except ImportError:
|
||||
from .comm import SyncMaster
|
||||
from .replicate import DataParallelWithCallback
|
||||
|
||||
__all__ = [
|
||||
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
|
||||
'patch_sync_batchnorm', 'convert_model'
|
||||
]
|
||||
|
||||
|
||||
def _sum_ft(tensor):
|
||||
"""sum over the first and last dimention"""
|
||||
return tensor.sum(dim=0).sum(dim=-1)
|
||||
|
||||
|
||||
def _unsqueeze_ft(tensor):
|
||||
"""add new dimensions at the front and the tail"""
|
||||
return tensor.unsqueeze(0).unsqueeze(-1)
|
||||
|
||||
|
||||
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
||||
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
||||
|
||||
|
||||
class _SynchronizedBatchNorm(_BatchNorm):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, weight_freeze=False, bias_freeze=False, affine=True):
|
||||
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
|
||||
|
||||
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
||||
self.weight.requires_grad_(not weight_freeze)
|
||||
self.bias.requires_grad_(not bias_freeze)
|
||||
|
||||
self._sync_master = SyncMaster(self._data_parallel_master)
|
||||
|
||||
self._is_parallel = False
|
||||
self._parallel_id = None
|
||||
self._slave_pipe = None
|
||||
|
||||
def forward(self, input):
|
||||
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
||||
if not (self._is_parallel and self.training):
|
||||
return F.batch_norm(
|
||||
input, self.running_mean, self.running_var, self.weight, self.bias,
|
||||
self.training, self.momentum, self.eps)
|
||||
|
||||
# Resize the input to (B, C, -1).
|
||||
input_shape = input.size()
|
||||
input = input.view(input.size(0), self.num_features, -1)
|
||||
|
||||
# Compute the sum and square-sum.
|
||||
sum_size = input.size(0) * input.size(2)
|
||||
input_sum = _sum_ft(input)
|
||||
input_ssum = _sum_ft(input ** 2)
|
||||
|
||||
# Reduce-and-broadcast the statistics.
|
||||
if self._parallel_id == 0:
|
||||
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||
else:
|
||||
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||
|
||||
# Compute the output.
|
||||
if self.affine:
|
||||
# MJY:: Fuse the multiplication for speed.
|
||||
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
||||
else:
|
||||
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
||||
|
||||
# Reshape it.
|
||||
return output.view(input_shape)
|
||||
|
||||
def __data_parallel_replicate__(self, ctx, copy_id):
|
||||
self._is_parallel = True
|
||||
self._parallel_id = copy_id
|
||||
|
||||
# parallel_id == 0 means master device.
|
||||
if self._parallel_id == 0:
|
||||
ctx.sync_master = self._sync_master
|
||||
else:
|
||||
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
||||
|
||||
def _data_parallel_master(self, intermediates):
|
||||
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
||||
|
||||
# Always using same "device order" makes the ReduceAdd operation faster.
|
||||
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
||||
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
||||
|
||||
to_reduce = [i[1][:2] for i in intermediates]
|
||||
to_reduce = [j for i in to_reduce for j in i] # flatten
|
||||
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
||||
|
||||
sum_size = sum([i[1].sum_size for i in intermediates])
|
||||
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
||||
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
||||
|
||||
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
||||
|
||||
outputs = []
|
||||
for i, rec in enumerate(intermediates):
|
||||
outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
|
||||
|
||||
return outputs
|
||||
|
||||
def _compute_mean_std(self, sum_, ssum, size):
|
||||
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
||||
also maintains the moving average on the master device."""
|
||||
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
||||
mean = sum_ / size
|
||||
sumvar = ssum - sum_ * mean
|
||||
unbias_var = sumvar / (size - 1)
|
||||
bias_var = sumvar / size
|
||||
|
||||
if hasattr(torch, 'no_grad'):
|
||||
with torch.no_grad():
|
||||
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
||||
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
||||
else:
|
||||
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
||||
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
||||
|
||||
return mean, bias_var.clamp(self.eps) ** -0.5
|
||||
|
||||
|
||||
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
||||
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
||||
mini-batch.
|
||||
|
||||
.. math::
|
||||
|
||||
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||
|
||||
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
||||
standard-deviation are reduced across all devices during training.
|
||||
|
||||
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||
training, PyTorch's implementation normalize the tensor on each device using
|
||||
the statistics only on that device, which accelerated the computation and
|
||||
is also easy to implement, but the statistics might be inaccurate.
|
||||
Instead, in this synchronized version, the statistics will be computed
|
||||
over all training samples distributed on multiple devices.
|
||||
|
||||
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
||||
as the built-in PyTorch implementation.
|
||||
|
||||
The mean and standard-deviation are calculated per-dimension over
|
||||
the mini-batches and gamma and beta are learnable parameter vectors
|
||||
of size C (where C is the input size).
|
||||
|
||||
During training, this layer keeps a running estimate of its computed mean
|
||||
and variance. The running sum is kept with a default momentum of 0.1.
|
||||
|
||||
During evaluation, this running mean/variance is used for normalization.
|
||||
|
||||
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
||||
|
||||
Args:
|
||||
num_features: num_features from an expected input of size
|
||||
`batch_size x num_features [x width]`
|
||||
eps: a value added to the denominator for numerical stability.
|
||||
Default: 1e-5
|
||||
momentum: the value used for the running_mean and running_var
|
||||
computation. Default: 0.1
|
||||
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||
affine parameters. Default: ``True``
|
||||
|
||||
Shape::
|
||||
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
||||
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
||||
|
||||
Examples:
|
||||
>>> # With Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm1d(100)
|
||||
>>> # Without Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
||||
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
||||
>>> output = m(input)
|
||||
"""
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
if input.dim() != 2 and input.dim() != 3:
|
||||
raise ValueError('expected 2D or 3D input (got {}D input)'
|
||||
.format(input.dim()))
|
||||
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
||||
|
||||
|
||||
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)

class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)

@contextlib.contextmanager
def patch_sync_batchnorm():
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
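# Usage sketch for the context manager above (illustrative only; the layer
# sizes are placeholders). Modules built inside the `with` block pick up the
# synchronized BatchNorm classes; the stock nn.BatchNorm*d classes are
# restored on exit.
#     >>> import torch.nn as nn
#     >>> with patch_sync_batchnorm():
#     ...     model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
#     ...                           nn.BatchNorm2d(16),  # -> SynchronizedBatchNorm2d
#     ...                           nn.ReLU())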


def convert_model(module):
    """Traverse the input module and its child modules recursively and replace
    every instance of torch.nn.modules.batchnorm.BatchNorm*N*d with
    SynchronizedBatchNorm*N*d.

    Args:
        module: the input module that needs to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
@ -1,74 +0,0 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()

@ -1,137 +0,0 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
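# Behaviour sketch (illustrative, not taken from the original file): get()
# blocks until another thread calls put(), and each result is fetched exactly
# once.
#     >>> import threading
#     >>> future = FutureResult()
#     >>> threading.Timer(0.1, lambda: future.put(42)).start()
#     >>> future.get()
#     42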


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as the data parallel wrapper triggers a callback on each module, every slave device should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
      collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines the message to be passed
      back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
                message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Note that, as all module copies are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
    of any slave copy.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object to add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
@ -1,29 +0,0 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)

@ -256,7 +256,7 @@ def build_resnet_backbone(cfg):
    if pretrain:
        if not with_ibn:
            try:
                state_dict = torch.load(pretrain_path)['model']
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                # Remove module.encoder in name
                new_state_dict = {}
                for k in state_dict:
@ -270,7 +270,7 @@ def build_resnet_backbone(cfg):
                state_dict = model_zoo.load_url(model_urls[depth])
                logger.info("Loading pretrained model from torchvision")
        else:
            state_dict = torch.load(pretrain_path)['state_dict']  # ibn-net
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['state_dict']  # ibn-net
            # Remove module in name
            new_state_dict = {}
            for k in state_dict:

@ -4,7 +4,6 @@
@contact: sherlockliao01@gmail.com
"""

from .build_losses import reid_losses
from .cross_entroy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .metric_loss import *

@ -1,20 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

from .. import losses as Loss


def reid_losses(cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
    loss_dict = {}
    for loss_name in cfg.MODEL.LOSSES.NAME:
        loss = getattr(Loss, loss_name)(cfg)(pred_class_logits, global_features, gt_classes)
        loss_dict.update(loss)
    # rename
    named_loss_dict = {}
    for name in loss_dict.keys():
        named_loss_dict[prefix + name] = loss_dict[name]
    del loss_dict
    return named_loss_dict
@ -1,46 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn


class CenterLoss(nn.Module):
    """Center loss.
    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes, self.feat_dim = num_classes, feat_dim

        if use_gpu: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else: self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size,).
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        classes = classes.to(x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
        return loss
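# Usage sketch (illustrative; batch size, feature dimension and label values
# below are placeholders):
#     >>> criterion = CenterLoss(num_classes=751, feat_dim=2048, use_gpu=False)
#     >>> features = torch.randn(16, 2048)        # (batch_size, feat_dim)
#     >>> labels = torch.randint(0, 751, (16,))   # (batch_size,)
#     >>> loss = criterion(features, labels)      # scalar: average squared distance to each sample's own center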
@ -20,33 +20,31 @@ class CrossEntropyLoss(object):
        self._alpha = cfg.MODEL.LOSSES.CE.ALPHA
        self._scale = cfg.MODEL.LOSSES.CE.SCALE

        self._topk = (1,)

    def _log_accuracy(self, pred_class_logits, gt_classes):
    @staticmethod
    def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
        """
        Log the accuracy metrics to EventStorage.
        """
        bsz = pred_class_logits.size(0)
        maxk = max(self._topk)
        maxk = max(topk)
        _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
        pred_class = pred_class.t()
        correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))

        ret = []
        for k in self._topk:
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
            ret.append(correct_k.mul_(1. / bsz))

        storage = get_event_storage()
        storage.put_scalar("cls_accuracy", ret[0])

    def __call__(self, pred_class_logits, _, gt_classes):
    def __call__(self, pred_class_logits, gt_classes):
        """
        Compute the softmax cross entropy loss for box classification.
        Returns:
            scalar Tensor
        """
        self._log_accuracy(pred_class_logits, gt_classes)
        if self._eps >= 0:
            smooth_param = self._eps
        else:
@ -61,6 +59,4 @@ class CrossEntropyLoss(object):
        targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))

        loss = (-targets * log_probs).mean(0).sum()
        return {
            "loss_cls": loss * self._scale,
        }
        return loss * self._scale
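# Shape sketch of the smoothed cross entropy above (illustrative; eps, the
# class count and the off-target fill value are placeholders -- the exact fill
# used by this class lives outside this hunk):
#     >>> logits = torch.randn(4, 10)
#     >>> gt = torch.randint(0, 10, (4,))
#     >>> log_probs = F.log_softmax(logits, dim=1)
#     >>> targets = torch.full_like(log_probs, 0.1 / 9)   # assumed uniform off-target mass
#     >>> targets.scatter_(1, gt.unsqueeze(1), 1 - 0.1)
#     >>> loss = (-targets * log_probs).mean(0).sum()     # scalar, then scaled by self._scale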

@ -7,8 +7,6 @@
import torch
import torch.nn.functional as F

from fastreid.utils.one_hot import one_hot


# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
@ -49,9 +47,7 @@ def focal_loss(
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot = one_hot(
        target, num_classes=input.shape[1],
        dtype=input.dtype)
    target_one_hot = F.one_hot(target, num_classes=input.shape[1])

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1., gamma)

@ -6,6 +6,7 @@

import torch
import torch.nn.functional as F
from fastreid.utils import comm

__all__ = [
    "TripletLoss",
@ -13,6 +14,21 @@ __all__ = [
]


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs an all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
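# Shape sketch (illustrative; assumes an initialized torch.distributed process
# group): with world size W and a local tensor of shape (B, D), the result is
# the (W * B, D) concatenation of every process's copy, detached from autograd.
#     >>> local = torch.randn(64, 2048)          # (B, D) on each process
#     >>> gathered = concat_all_gather(local)    # (W * 64, 2048), no gradient
# The losses below only call it when comm.get_world_size() > 1.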


def normalize(x, axis=-1):
    """Normalizing to unit length along the specified dimension.
    Args:
@ -54,9 +70,9 @@ def softmax_weights(dist, mask):
def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and negative sample.
    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        labels: pytorch LongTensor, with shape [N]
        return_inds: whether to return the indices. Save time if `False`(?)
        dist_mat: pair wise distance between samples, shape [N, M]
        is_pos: positive index with shape [N, M]
        is_neg: negative index with shape [N, M]
    Returns:
        dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, distance(anchor, negative); shape [N]
@ -69,7 +85,6 @@ def hard_example_mining(dist_mat, is_pos, is_neg):
    """

    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    # `dist_ap` means distance(anchor, positive)
@ -98,12 +113,13 @@ def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the weighted positive and negative sample.
    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        is_pos:
        is_neg:
    Returns:
        dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)

    is_pos = is_pos.float()
    is_neg = is_neg.float()
@ -130,15 +146,23 @@ class TripletLoss(object):
        self._scale = cfg.MODEL.LOSSES.TRI.SCALE
        self._hard_mining = cfg.MODEL.LOSSES.TRI.HARD_MINING

    def __call__(self, _, global_features, targets):
    def __call__(self, embedding, targets):
        if self._normalize_feature:
            global_features = normalize(global_features, axis=-1)
            embedding = normalize(embedding, axis=-1)

        dist_mat = euclidean_dist(global_features, global_features)
        # For distributed training, gather all features from different process.
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        N = dist_mat.size(0)
        is_pos = targets.expand(N, N).eq(targets.expand(N, N).t())
        is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
        dist_mat = euclidean_dist(embedding, all_embedding)

        N, M = dist_mat.size()
        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())

        if self._hard_mining:
            dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
@ -153,9 +177,7 @@ class TripletLoss(object):
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)

        return {
            "loss_triplet": loss * self._scale,
        }
        return loss * self._scale
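# Mask shape sketch for the distributed branch above (illustrative values;
# runs without torch.distributed):
#     >>> N = 4                                         # local batch size
#     >>> targets = torch.randint(0, 3, (N,))
#     >>> all_targets = torch.cat([targets, targets])   # as if world_size == 2
#     >>> M = all_targets.numel()
#     >>> is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
#     >>> is_pos.shape
#     torch.Size([4, 8])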


class CircleLoss(object):
@ -165,18 +187,24 @@ class CircleLoss(object):
        self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
        self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA

    def __call__(self, _, global_features, targets):
        global_features = F.normalize(global_features, dim=1)
    def __call__(self, embedding, targets):
        embedding = F.normalize(embedding, dim=1)

        sim_mat = torch.matmul(global_features, global_features.t())
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        N = sim_mat.size(0)
        is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() - torch.eye(N).to(sim_mat.device)
        is_pos = is_pos.bool()
        is_neg = targets.expand(N, N).ne(targets.expand(N, N).t())
        dist_mat = torch.matmul(embedding, all_embedding.t())

        s_p = sim_mat[is_pos].contiguous().view(N, -1)
        s_n = sim_mat[is_neg].contiguous().view(N, -1)
        N, M = dist_mat.size()
        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())

        s_p = dist_mat[is_pos].contiguous().view(N, -1)
        s_n = dist_mat[is_neg].contiguous().view(N, -1)

        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
        alpha_n = F.relu(s_n.detach() + self.m)
@ -188,6 +216,4 @@

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

        return {
            "loss_circle": loss * self._scale,
        }
        return loss * self._scale

@ -9,4 +9,4 @@ from .build import META_ARCH_REGISTRY, build_model

# import all the meta_arch, so they will be registered
from .baseline import Baseline
from .mgn import MGN
from .mgn import MGN

@ -10,7 +10,7 @@ from torch import nn
from fastreid.layers import GeneralizedMeanPoolingP, AdaptiveAvgMaxPool2d, FastGlobalAvgPool2d
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_reid_heads
from fastreid.modeling.losses import reid_losses
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY


@ -18,9 +18,11 @@ from .build import META_ARCH_REGISTRY
class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
        self._cfg = cfg
        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        self.register_buffer("pixel_mean", torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
        self.register_buffer("pixel_std", torch.tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))

        # backbone
        self.backbone = build_backbone(cfg)

@ -44,35 +46,50 @@ class Baseline(nn.Module):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        if not self.training:
            pred_feat = self.inference(batched_inputs)
            try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
            except Exception: return pred_feat

        images = self.preprocess_image(batched_inputs)
        targets = batched_inputs["targets"].long()
        features = self.backbone(images)

        # training
        features = self.backbone(images)  # (bs, 2048, 16, 8)
        return self.heads(features, targets)
        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"].long().to(self.device)

    def inference(self, batched_inputs):
        assert not self.training
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)  # (bs, 2048, 16, 8)
        pred_feat = self.heads(features)
        return pred_feat
            # PreciseBN flag
            if targets.sum() < 0:
                # If do preciseBN on different dataset, the number of classes in new dataset
                # may be larger than that in the original dataset, so the circle/arcface will
                # throw an error. We just set all the targets to 0 to avoid this situation.
                targets.zero_()
                self.heads(features, targets)
                # We just skip loss computation, because targets are all 0, and triplet loss
                # will go wrong when targets are not in PK sampler way.
                return
            cls_outputs, features = self.heads(features, targets)
            losses = self.losses(cls_outputs, features, targets)
            return losses
        else:
            pred_features = self.heads(features)
            return pred_features

    def preprocess_image(self, batched_inputs):
        """
        Normalize and batch the input images.
        """
        # images = [x["images"] for x in batched_inputs]
        images = batched_inputs["images"]
        images = batched_inputs["images"].to(self.device)
        # images = batched_inputs
        images.sub_(self.pixel_mean).div_(self.pixel_std)
        return images

    def losses(self, outputs):
        logits, feat, targets = outputs
        return reid_losses(self._cfg, logits, feat, targets)
    def losses(self, cls_outputs, pred_features, gt_labels):
        loss_dict = {}
        loss_names = self._cfg.MODEL.LOSSES.NAME

        if "CrossEntropyLoss" in loss_names:
            loss_dict['loss_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)

        if "TripletLoss" in loss_names:
            loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)

        if "CircleLoss" in loss_names:
            loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)

        return loss_dict
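# Consumption sketch (illustrative; `model`, `data` and `optimizer` are
# assumed to exist and follow the fastreid training-loop conventions):
#     >>> loss_dict = model(data)               # training mode: the dict built above
#     >>> total_loss = sum(loss_dict.values())
#     >>> optimizer.zero_grad()
#     >>> total_loss.backward()
#     >>> optimizer.step()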

@ -3,8 +3,9 @@
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch

from ...utils.registry import Registry
from fastreid.utils.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")  # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
@ -20,4 +21,6 @@ def build_model(cfg):
    Note that it does not load any weights from ``cfg``.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    return META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    return model
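# Usage sketch (illustrative; the config file path is a placeholder):
#     >>> from fastreid.config import get_cfg
#     >>> from fastreid.modeling.meta_arch import build_model
#     >>> cfg = get_cfg()
#     >>> cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")
#     >>> model = build_model(cfg)   # registry lookup, then moved to cfg.MODEL.DEVICE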
|
||||
|
@ -12,7 +12,7 @@ from fastreid.layers import GeneralizedMeanPoolingP, get_norm, AdaptiveAvgMaxPoo
|
||||
from fastreid.modeling.backbones import build_backbone
|
||||
from fastreid.modeling.backbones.resnet import Bottleneck
|
||||
from fastreid.modeling.heads import build_reid_heads
|
||||
from fastreid.modeling.losses import reid_losses, CrossEntropyLoss
|
||||
from fastreid.modeling.losses import CrossEntropyLoss, TripletLoss
|
||||
from fastreid.utils.weight_init import weights_init_kaiming
|
||||
from .build import META_ARCH_REGISTRY
|
||||
|
||||
@ -21,9 +21,10 @@ from .build import META_ARCH_REGISTRY
|
||||
class MGN(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
|
||||
self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(1, -1, 1, 1))
|
||||
self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(1, -1, 1, 1))
|
||||
self._cfg = cfg
|
||||
|
||||
# backbone
|
||||
bn_norm = cfg.MODEL.BACKBONE.NORM
|
||||
@ -116,129 +117,113 @@ class MGN(nn.Module):
|
||||
return self.pixel_mean.device
|
||||
|
||||
def forward(self, batched_inputs):
|
||||
if not self.training:
|
||||
pred_feat = self.inference(batched_inputs)
|
||||
try: return pred_feat, batched_inputs["targets"], batched_inputs["camid"]
|
||||
except Exception: return pred_feat
|
||||
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
targets = batched_inputs["targets"].long()
|
||||
|
||||
# Training
|
||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||
|
||||
# branch1
|
||||
b1_feat = self.b1(features)
|
||||
b1_pool_feat = self.b1_pool(b1_feat)
|
||||
b1_logits, b1_pool_feat, _ = self.b1_head(b1_pool_feat, targets)
|
||||
|
||||
# branch2
|
||||
b2_feat = self.b2(features)
|
||||
# global
|
||||
b2_pool_feat = self.b2_pool(b2_feat)
|
||||
b2_logits, b2_pool_feat, _ = self.b2_head(b2_pool_feat, targets)
|
||||
|
||||
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
|
||||
# part1
|
||||
b21_pool_feat = self.b21_pool(b21_feat)
|
||||
b21_logits, b21_pool_feat, _ = self.b21_head(b21_pool_feat, targets)
|
||||
# part2
|
||||
b22_pool_feat = self.b22_pool(b22_feat)
|
||||
b22_logits, b22_pool_feat, _ = self.b22_head(b22_pool_feat, targets)
|
||||
|
||||
# branch3
|
||||
b3_feat = self.b3(features)
|
||||
# global
|
||||
b3_pool_feat = self.b3_pool(b3_feat)
|
||||
b3_logits, b3_pool_feat, _ = self.b3_head(b3_pool_feat, targets)
|
||||
|
||||
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
|
||||
# part1
|
||||
b31_pool_feat = self.b31_pool(b31_feat)
|
||||
b31_logits, b31_pool_feat, _ = self.b31_head(b31_pool_feat, targets)
|
||||
# part2
|
||||
b32_pool_feat = self.b32_pool(b32_feat)
|
||||
b32_logits, b32_pool_feat, _ = self.b32_head(b32_pool_feat, targets)
|
||||
# part3
|
||||
b33_pool_feat = self.b33_pool(b33_feat)
|
||||
b33_logits, b33_pool_feat, _ = self.b33_head(b33_pool_feat, targets)
|
||||
|
||||
return (b1_logits, b2_logits, b3_logits, b21_logits, b22_logits, b31_logits, b32_logits, b33_logits), \
|
||||
(b1_pool_feat, b2_pool_feat, b3_pool_feat,
|
||||
if self.training:
|
||||
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
|
||||
targets = batched_inputs["targets"].long().to(self.device)
|
||||
|
||||
if targets.sum() < 0:
|
||||
targets.zero_()
|
||||
self.b1_head(b1_pool_feat, targets)
|
||||
self.b2_head(b2_pool_feat, targets)
|
||||
self.b21_head(b21_pool_feat, targets)
|
||||
self.b22_head(b22_pool_feat, targets)
|
||||
self.b3_head(b3_pool_feat, targets)
|
||||
self.b31_head(b31_pool_feat, targets)
|
||||
self.b32_head(b32_pool_feat, targets)
|
||||
self.b33_head(b33_pool_feat, targets)
|
||||
return
|
||||
|
||||
b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
|
||||
b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
|
||||
b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
|
||||
b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
|
||||
b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
|
||||
b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
|
||||
b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
|
||||
b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
|
||||
losses = self.losses(
|
||||
b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
|
||||
b1_pool_feat, b2_pool_feat, b3_pool_feat,
|
||||
torch.cat((b21_pool_feat, b22_pool_feat), dim=1),
|
||||
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)), \
|
||||
targets
|
||||
torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1),
|
||||
targets,
|
||||
)
|
||||
return losses
|
||||
else:
|
||||
b1_pool_feat = self.b1_head(b1_pool_feat)
|
||||
b2_pool_feat = self.b2_head(b2_pool_feat)
|
||||
b21_pool_feat = self.b21_head(b21_pool_feat)
|
||||
b22_pool_feat = self.b22_head(b22_pool_feat)
|
||||
b3_pool_feat = self.b3_head(b3_pool_feat)
|
||||
b31_pool_feat = self.b31_head(b31_pool_feat)
|
||||
b32_pool_feat = self.b32_head(b32_pool_feat)
|
||||
b33_pool_feat = self.b33_head(b33_pool_feat)
|
||||
|
||||
def inference(self, batched_inputs):
|
||||
assert not self.training
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
features = self.backbone(images) # (bs, 2048, 16, 8)
|
||||
|
||||
# branch1
|
||||
b1_feat = self.b1(features)
|
||||
b1_pool_feat = self.b1_pool(b1_feat)
|
||||
b1_pool_feat = self.b1_head(b1_pool_feat)
|
||||
|
||||
# branch2
|
||||
b2_feat = self.b2(features)
|
||||
# global
|
||||
b2_pool_feat = self.b2_pool(b2_feat)
|
||||
b2_pool_feat = self.b2_head(b2_pool_feat)
|
||||
|
||||
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
|
||||
# part1
|
||||
b21_pool_feat = self.b21_pool(b21_feat)
|
||||
b21_pool_feat = self.b21_head(b21_pool_feat)
|
||||
# part2
|
||||
b22_pool_feat = self.b22_pool(b22_feat)
|
||||
b22_pool_feat = self.b22_head(b22_pool_feat)
|
||||
|
||||
# branch3
|
||||
b3_feat = self.b3(features)
|
||||
# global
|
||||
b3_pool_feat = self.b3_pool(b3_feat)
|
||||
b3_pool_feat = self.b3_head(b3_pool_feat)
|
||||
|
||||
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
|
||||
# part1
|
||||
b31_pool_feat = self.b31_pool(b31_feat)
|
||||
b31_pool_feat = self.b31_head(b31_pool_feat)
|
||||
# part2
|
||||
b32_pool_feat = self.b32_pool(b32_feat)
|
||||
b32_pool_feat = self.b32_head(b32_pool_feat)
|
||||
# part3
|
||||
b33_pool_feat = self.b33_pool(b33_feat)
|
||||
b33_pool_feat = self.b33_head(b33_pool_feat)
|
||||
|
||||
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
|
||||
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
|
||||
return pred_feat
|
||||
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
|
||||
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
|
||||
return pred_feat
|
||||
|
||||
def preprocess_image(self, batched_inputs):
|
||||
"""
|
||||
Normalize and batch the input images.
|
||||
"""
|
||||
# images = [x["images"] for x in batched_inputs]
|
||||
images = batched_inputs["images"]
|
||||
images = batched_inputs["images"].to(self.device)
|
||||
# images = batched_inputs
|
||||
images.sub_(self.pixel_mean).div_(self.pixel_std)
|
||||
return images
|
||||
|
||||
def losses(self, outputs):
|
||||
logits, feats, targets = outputs
|
||||
def losses(self, b1_logits, b2_logits, b21_logits, b22_logits, b3_logits, b31_logits, b32_logits, b33_logits,
|
||||
b1_pool_feat, b2_pool_feat, b3_pool_feat, b22_pool_feat, b33_pool_feat, gt_labels):
|
||||
loss_dict = {}
|
||||
loss_dict.update(reid_losses(self._cfg, logits[0], feats[0], targets, 'b1_'))
|
||||
loss_dict.update(reid_losses(self._cfg, logits[1], feats[1], targets, 'b2_'))
|
||||
loss_dict.update(reid_losses(self._cfg, logits[2], feats[2], targets, 'b3_'))
|
||||
loss_dict.update(reid_losses(self._cfg, logits[3], feats[3], targets, 'b21_'))
|
||||
loss_dict.update(reid_losses(self._cfg, logits[5], feats[4], targets, 'b31_'))
|
||||
loss_names = self._cfg.MODEL.LOSSES.NAME
|
||||
|
||||
if "CrossEntropyLoss" in loss_names:
|
||||
loss_dict['loss_cls_b1'] = CrossEntropyLoss(self._cfg)(b1_logits, gt_labels)
|
||||
loss_dict['loss_cls_b2'] = CrossEntropyLoss(self._cfg)(b2_logits, gt_labels)
|
||||
loss_dict['loss_cls_b21'] = CrossEntropyLoss(self._cfg)(b21_logits, gt_labels)
|
||||
loss_dict['loss_cls_b22'] = CrossEntropyLoss(self._cfg)(b22_logits, gt_labels)
|
||||
loss_dict['loss_cls_b3'] = CrossEntropyLoss(self._cfg)(b3_logits, gt_labels)
|
||||
loss_dict['loss_cls_b31'] = CrossEntropyLoss(self._cfg)(b31_logits, gt_labels)
|
||||
loss_dict['loss_cls_b32'] = CrossEntropyLoss(self._cfg)(b32_logits, gt_labels)
|
||||
loss_dict['loss_cls_b33'] = CrossEntropyLoss(self._cfg)(b33_logits, gt_labels)
|
||||
|
||||
if "TripletLoss" in loss_names:
|
||||
loss_dict['loss_triplet_b1'] = TripletLoss(self._cfg)(b1_pool_feat, gt_labels)
|
||||
loss_dict['loss_triplet_b2'] = TripletLoss(self._cfg)(b2_pool_feat, gt_labels)
|
||||
loss_dict['loss_triplet_b3'] = TripletLoss(self._cfg)(b3_pool_feat, gt_labels)
|
||||
loss_dict['loss_triplet_b22'] = TripletLoss(self._cfg)(b22_pool_feat, gt_labels)
|
||||
loss_dict['loss_triplet_b33'] = TripletLoss(self._cfg)(b33_pool_feat, gt_labels)
|
||||
|
||||
part_ce_loss = [
|
||||
(CrossEntropyLoss(self._cfg)(logits[4], None, targets), 'b22_'),
|
||||
(CrossEntropyLoss(self._cfg)(logits[6], None, targets), 'b32_'),
|
||||
(CrossEntropyLoss(self._cfg)(logits[7], None, targets), 'b33_')
|
||||
]
|
||||
named_ce_loss = {}
|
||||
for item in part_ce_loss:
|
||||
named_ce_loss[item[1] + [*item[0]][0]] = [*item[0].values()][0]
|
||||
loss_dict.update(named_ce_loss)
|
||||
return loss_dict
|
||||
|
@ -20,7 +20,7 @@ def build_optimizer(cfg, model):
        if "bias" in key:
            lr *= cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
        params += [{"name": key, "params": [value], "lr": lr, "weight_decay": weight_decay, "freeze": False}]

    solver_opt = cfg.SOLVER.OPT
    if hasattr(optim, solver_opt):
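# Note on the new "freeze" entry (sketch; the exact freeze/unfreeze policy is
# handled outside this hunk): the customized optimizers in fastreid.solver.optim
# skip frozen parameter groups inside step(), e.g.
#     >>> for group in self.param_groups:
#     ...     for p in group['params']:
#     ...         if p.grad is None or group['freeze']:
#     ...             continue
# so a group can presumably be frozen at runtime by setting group["freeze"] = True.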
|
||||
|
@ -4,13 +4,14 @@
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import math
|
||||
from bisect import bisect_right
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
__all__ = ["WarmupMultiStepLR", "DelayedScheduler"]
|
||||
__all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"]
|
||||
|
||||
|
||||
class WarmupMultiStepLR(_LRScheduler):
|
||||
@ -50,6 +51,71 @@ class WarmupMultiStepLR(_LRScheduler):
|
||||
return self.get_lr()
|
||||
|
||||
|
||||
class WarmupCosineAnnealingLR(_LRScheduler):
|
||||
r"""Set the learning rate of each parameter group using a cosine annealing
|
||||
schedule, where :math:`\eta_{max}` is set to the initial lr and
|
||||
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
|
||||
|
||||
.. math::
|
||||
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
|
||||
\cos(\frac{T_{cur}}{T_{max}}\pi))
|
||||
|
||||
When last_epoch=-1, sets initial lr as lr.
|
||||
|
||||
It has been proposed in
|
||||
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
|
||||
implements the cosine annealing part of SGDR, and not the restarts.
|
||||
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
T_max (int): Maximum number of iterations.
|
||||
eta_min (float): Minimum learning rate. Default: 0.
|
||||
last_epoch (int): The index of last epoch. Default: -1.
|
||||
|
||||
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
|
||||
https://arxiv.org/abs/1608.03983
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
max_iters: int,
|
||||
delay_iters: int = 0,
|
||||
eta_min_lr: int = 0,
|
||||
warmup_factor: float = 0.001,
|
||||
warmup_iters: int = 1000,
|
||||
warmup_method: str = "linear",
|
||||
last_epoch=-1,
|
||||
**kwargs
|
||||
):
|
||||
self.max_iters = max_iters
|
||||
self.delay_iters = delay_iters
|
||||
self.eta_min_lr = eta_min_lr
|
||||
self.warmup_factor = warmup_factor
|
||||
self.warmup_iters = warmup_iters
|
||||
self.warmup_method = warmup_method
|
||||
assert self.delay_iters >= self.warmup_iters, "Scheduler delay iters must be larger than warmup iters"
|
||||
super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self) -> List[float]:
|
||||
if self.last_epoch <= self.warmup_iters:
|
||||
warmup_factor = _get_warmup_factor_at_iter(
|
||||
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor,
|
||||
)
|
||||
return [
|
||||
base_lr * warmup_factor for base_lr in self.base_lrs
|
||||
]
|
||||
elif self.last_epoch <= self.delay_iters:
|
||||
return self.base_lrs
|
||||
|
||||
else:
|
||||
return [
|
||||
self.eta_min_lr + (base_lr - self.eta_min_lr) *
|
||||
(1 + math.cos(
|
||||
math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2
|
||||
for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
def _get_warmup_factor_at_iter(
|
||||
method: str, iter: int, warmup_iters: int, warmup_factor: float
|
||||
) -> float:
|
||||
@ -75,49 +141,3 @@ def _get_warmup_factor_at_iter(
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
else:
|
||||
raise ValueError("Unknown warmup method: {}".format(method))
|
||||
|
||||
|
||||
class DelayedScheduler(_LRScheduler):
|
||||
""" Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
|
||||
Args:
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
delay_iters: number of epochs to keep the initial lr until starting applying the scheduler
|
||||
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, delay_iters, after_scheduler, warmup_factor, warmup_iters, warmup_method):
|
||||
self.delay_epochs = delay_iters
|
||||
self.after_scheduler = after_scheduler
|
||||
self.finished = False
|
||||
self.warmup_factor = warmup_factor
|
||||
self.warmup_iters = warmup_iters
|
||||
self.warmup_method = warmup_method
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch >= self.delay_epochs:
|
||||
if not self.finished:
|
||||
self.after_scheduler.base_lrs = self.base_lrs
|
||||
self.finished = True
|
||||
return self.after_scheduler.get_lr()
|
||||
|
||||
warmup_factor = _get_warmup_factor_at_iter(
|
||||
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
|
||||
)
|
||||
return [base_lr * warmup_factor for base_lr in self.base_lrs]
|
||||
|
||||
def step(self, epoch=None):
|
||||
if self.finished:
|
||||
if epoch is None:
|
||||
self.after_scheduler.step(None)
|
||||
else:
|
||||
self.after_scheduler.step(epoch - self.delay_epochs)
|
||||
else:
|
||||
return super(DelayedScheduler, self).step(epoch)
|
||||
|
||||
|
||||
def DelayedCosineAnnealingLR(optimizer, delay_iters, max_iters, eta_min_lr, warmup_factor,
|
||||
warmup_iters, warmup_method, **kwargs, ):
|
||||
cosine_annealing_iters = max_iters - delay_iters
|
||||
base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_iters, eta_min_lr)
|
||||
return DelayedScheduler(optimizer, delay_iters, base_scheduler, warmup_factor, warmup_iters, warmup_method)
|
||||
|
@ -1,11 +1,5 @@
from .lamb import Lamb
from .lars import LARS
from .lookahead import Lookahead, LookaheadAdam
from .novograd import Novograd
from .over9000 import Over9000, RangerLars
from .radam import RAdam, PlainRAdam, AdamW
from .ralamb import Ralamb
from .ranger import Ranger
from .swa import SWA
from .adam import Adam
from .sgd import SGD

from torch.optim import *

115  fastreid/solver/optim/adam.py  Normal file
@ -0,0 +1,115 @@
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: liaoxingyu5@jd.com
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
r"""Implements Adam algorithm.
|
||||
|
||||
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
.. _On the Convergence of Adam and Beyond:
|
||||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, amsgrad=False):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad)
|
||||
super(Adam, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(Adam, self).__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None or group['freeze']:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
amsgrad = group['amsgrad']
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
if amsgrad:
|
||||
max_exp_avg_sq = state['max_exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
grad.add_(group['weight_decay'], p.data)
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
|
||||
else:
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
return loss
|
@ -68,7 +68,7 @@ class Lamb(Optimizer):
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
if p.grad is None or group['freeze']:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
|
@ -1,101 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

# based on:
# https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py

import torch
from torch.optim.optimizer import Optimizer


class LARS(Optimizer):
    """
    :class:`LARS` is a pytorch implementation of both the scaling and clipping variants of LARC,
    in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
    local learning rate for each individual parameter. The algorithm is designed to improve
    convergence of large batch training.

    See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
    In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
    of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARS(optim)
    ```
    It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
    ```
    model = ...
    optim = torch.optim.Adam(model.parameters(), lr=...)
    optim = LARS(optim)
    optim = apex.fp16_utils.FP16_Optimizer(optim)
    ```
    Args:
        optimizer: Pytorch optimizer to wrap and modify learning rate for.
        trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
        clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
        eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
    """

    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        # calculate adaptive lr + weight decay
                        adaptive_lr = self.trust_coefficient * (param_norm) / (
                                grad_norm + param_norm * weight_decay + self.eps)

                        # clip learning rate for LARC
                        if self.clip:
                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                            adaptive_lr = min(adaptive_lr / group['lr'], 1)

                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
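A brief usage sketch for the LARC/LARS wrapper above, assuming a plain torchvision ResNet and an SGD base optimizer (both are illustrative choices, not part of this commit):

import torch
import torchvision

model = torchvision.models.resnet18()
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
# clip=True caps each parameter's lr at min(local_lr, base lr); clip=False scales gradients by local_lr.
optimizer = LARS(base, trust_coefficient=0.02, clip=False)

loss = model(torch.randn(2, 3, 224, 224)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()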
@ -1,104 +0,0 @@
####
# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch
# Original paper: https://arxiv.org/abs/1907.08610
####
# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py

""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
from collections import defaultdict

import torch
from torch.optim import Adam
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        # print(self.k)
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied.')
            state_dict['slow_state'] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)


def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
    adam = Adam(params, *args, **kwargs)
    return Lookahead(adam, alpha, k)
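A short usage sketch for the LookaheadAdam factory above; the toy model and hyper-parameters are illustrative:

import torch
import torch.nn as nn

net = nn.Linear(128, 10)
opt = LookaheadAdam(net.parameters(), alpha=0.5, k=6, lr=3e-4)
for _ in range(12):
    opt.zero_grad()
    net(torch.randn(4, 128)).pow(2).mean().backward()
    opt.step()          # every k-th step also pulls the fast weights toward the slow buffer
opt.sync_lookahead()    # final sync before evaluation or checkpointing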
@ -1,229 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.optim.optimizer import Optimizer
import math


class AdamW(Optimizer):
    """Implements AdamW algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))

        return loss


class Novograd(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)

        super(Novograd, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Novograd, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(-group['lr'], exp_avg)

        return loss
@ -1,19 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

from .lookahead import Lookahead
from .ralamb import Ralamb


# RAdam + LARS + LookAHead

# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20

def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    ralamb = Ralamb(params, *args, **kwargs)
    return Lookahead(ralamb, alpha, k)


RangerLars = Over9000
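A minimal sketch of the Over9000 / RangerLars factory above; the model and hyper-parameters are illustrative:

import torch.nn as nn

model = nn.Linear(64, 8)
# Ralamb performs the RAdam + LARS update; Lookahead wraps it with alpha/k slow weights.
optimizer = Over9000(model.parameters(), alpha=0.5, k=6, lr=1e-3, weight_decay=1e-4)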
@ -1,255 +0,0 @@
####
# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam
# Paper: https://arxiv.org/abs/1908.03265
####

import math

import torch
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        buffer=[[None, None, None] for _ in range(10)])
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = group['buffer'][int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    else:
                        step_size = -1
                    buffered[2] = step_size

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


class PlainRAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        self.degenerated_to_sgd = degenerated_to_sgd
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(-step_size, exp_avg)
                    p.data.copy_(p_data_fp32)

        return loss


class AdamW(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, warmup=warmup)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['warmup'] > state['step']:
                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
                else:
                    scheduled_lr = group['lr']

                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)

                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)

                p.data.copy_(p_data_fp32)

        return loss
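A brief sketch of the warmup-aware AdamW variant defined at the end of this file; the model and the 500-step warmup are illustrative:

import torch.nn as nn

model = nn.Linear(32, 4)
# For the first `warmup` steps the effective lr ramps linearly from ~1e-8 up to lr.
opt = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2, warmup=500)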
@ -1,103 +0,0 @@
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####

import torch, math
from torch.optim.optimizer import Optimizer


# RAdam + LARS
class Ralamb(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)
                else:
                    radam_step.add_(-radam_step_size * group['lr'], exp_avg)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                if N_sma >= 5:
                    p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)
                else:
                    p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)

                p.data.copy_(p_data_fp32)

        return loss
@ -1,177 +0,0 @@
# Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.

# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# and/or
# https://github.com/lessw2020/Best-Deep-Learning-Optimizers

# Ranger has now been used to capture 12 records on the FastAI leaderboard.

# This version = 20.4.11

# Credits:
# Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
# RAdam --> https://github.com/LiyuanLucasLiu/RAdam
# Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
# Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610

# summary of changes:
# 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
# full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
# supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
# changes 8/31/19 - fix references to *self*.N_sma_threshold;
# changed eps to 1e-5 as better default than 1e-8.

import math

import torch
from torch.optim.optimizer import Optimizer


class Ranger(Optimizer):

    def __init__(self, params, lr=1e-3,  # lr
                 alpha=0.5, k=6, N_sma_threshhold=5,  # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
                 use_gc=True, gc_conv_only=False
                 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
                 ):

        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # parameter comments:
        # beta1 (momentum) of .95 seems to work better than .90...
        # N_sma_threshold of 5 seems better in testing than 4.
        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
                        eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params

        self.alpha = alpha
        self.k = k

        # radam buffer for state
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.use_gc = use_gc

        # level of gradient centralization
        self.gc_gradient_threshold = 3 if gc_conv_only else 1

        print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if (self.use_gc and self.gc_gradient_threshold == 1):
            print(f"GC applied to both conv and fc layers")
        elif (self.use_gc and self.gc_gradient_threshold == 3):
            print(f"GC applied to conv layers only")

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
        # Uncomment if you need to use the actual closure...

        # if closure is not None:
        #     loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError('Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:  # if first time to run...init dictionary with our desired entries
                    # if self.first_run_check==0:
                    #     self.first_run_check=1
                    #     print("Initializing slow buffer...should not see this at load from saved model!")
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC operation for Conv layers and FC layers
                if grad.dim() > self.gc_gradient_threshold:
                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']  # get access to slow param tensor
                    slow_p.add_(self.alpha, p.data - slow_p)  # (fast weights - slow weights) * alpha
                    p.data.copy_(slow_p)  # copy interpolated weights to RAdam param tensor

        return loss
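A short usage sketch for the Ranger optimizer above; the small conv net is illustrative:

import torch.nn as nn

cnn = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1))
# gc_conv_only=True restricts gradient centralization to conv-style (dim > 3) gradients.
opt = Ranger(cnn.parameters(), lr=1e-3, alpha=0.5, k=6, use_gc=True, gc_conv_only=True)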
115 fastreid/solver/optim/sgd.py Normal file
@ -0,0 +1,115 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None or group['freeze']:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

        return loss
@ -11,7 +11,6 @@ from typing import Any
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel

@ -27,7 +26,6 @@ class Checkpointer(object):
    def __init__(
            self,
            model: nn.Module,
            dataset: Dataset = None,
            save_dir: str = "",
            *,
            save_to_disk: bool = True,
@ -47,7 +45,6 @@ class Checkpointer(object):
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            model = model.module
        self.model = model
        self.dataset = dataset
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
@ -65,8 +62,6 @@ class Checkpointer(object):

        data = {}
        data["model"] = self.model.state_dict()
        if self.dataset is not None:
            data["pid_dict"] = self.dataset.pid_dict
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)
@ -104,9 +99,6 @@ class Checkpointer(object):
        assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

        checkpoint = self._load_file(path)
        if self.dataset is not None:
            self.logger.info("Loading dataset pid dictionary from {}".format(path))
            self._load_dataset_pid_dict(checkpoint)
        self._load_model(checkpoint)
        for key, obj in self.checkpointables.items():
            if key in checkpoint:
@ -191,10 +183,6 @@ class Checkpointer(object):
        """
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_dataset_pid_dict(self, checkpoint: Any):
        checkpoint_pid_dict = checkpoint.pop("pid_dict")
        self.dataset.update_pid_dict(checkpoint_pid_dict)

    def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.
158 fastreid/utils/collect_env.py Normal file
@ -0,0 +1,158 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

# based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py
import importlib
import os
import re
import subprocess
import sys
from collections import defaultdict

import PIL
import numpy as np
import torch
import torchvision
from tabulate import tabulate

__all__ = ["collect_env_info"]


def collect_torch_env():
    try:
        import torch.__config__

        return torch.__config__.show()
    except ImportError:
        # compatible with older versions of pytorch
        from torch.utils.collect_env import get_pretty_env_info

        return get_pretty_env_info()


def get_env_module():
    var_name = "DETECTRON2_ENV_MODULE"
    return var_name, os.environ.get(var_name, "<not set>")


def detect_compute_compatibility(CUDA_HOME, so_file):
    try:
        cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
        if os.path.isfile(cuobjdump):
            output = subprocess.check_output(
                "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
            )
            output = output.decode("utf-8").strip().split("\n")
            sm = []
            for line in output:
                line = re.findall(r"\.sm_[0-9]*\.", line)[0]
                sm.append(line.strip("."))
            sm = sorted(set(sm))
            return ", ".join(sm)
        else:
            return so_file + "; cannot find cuobjdump"
    except Exception:
        # unhandled failure
        return so_file


def collect_env_info():
    has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
    torch_version = torch.__version__

    # NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in
    # theory detectron2 should be made runnable with only the corresponding runtimes
    from torch.utils.cpp_extension import CUDA_HOME

    has_rocm = False
    if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5):
        from torch.utils.cpp_extension import ROCM_HOME

        if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
            has_rocm = True
    has_cuda = has_gpu and (not has_rocm)

    data = []
    data.append(("sys.platform", sys.platform))
    data.append(("Python", sys.version.replace("\n", "")))
    data.append(("numpy", np.__version__))

    try:
        import fastreid  # noqa

        data.append(
            ("fastreid", fastreid.__version__ + " @" + os.path.dirname(fastreid.__file__))
        )
    except ImportError:
        data.append(("fastreid", "failed to import"))

    data.append(get_env_module())
    data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
    data.append(("PyTorch debug build", torch.version.debug))

    data.append(("GPU available", has_gpu))
    if has_gpu:
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
        for name, devids in devices.items():
            data.append(("GPU " + ",".join(devids), name))

        if has_rocm:
            data.append(("ROCM_HOME", str(ROCM_HOME)))
        else:
            data.append(("CUDA_HOME", str(CUDA_HOME)))

            cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
            if cuda_arch_list:
                data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
    data.append(("Pillow", PIL.__version__))

    try:
        data.append(
            (
                "torchvision",
                str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
            )
        )
        if has_cuda:
            try:
                torchvision_C = importlib.util.find_spec("torchvision._C").origin
                msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
                data.append(("torchvision arch flags", msg))
            except ImportError:
                data.append(("torchvision._C", "failed to find"))
    except AttributeError:
        data.append(("torchvision", "unknown"))

    try:
        import fvcore

        data.append(("fvcore", fvcore.__version__))
    except ImportError:
        pass

    try:
        import cv2

        data.append(("cv2", cv2.__version__))
    except ImportError:
        pass
    env_str = tabulate(data) + "\n"
    env_str += collect_torch_env()
    return env_str


if __name__ == "__main__":
    try:
        import detectron2  # noqa
    except ImportError:
        print(collect_env_info())
    else:
        from fastreid.utils.collect_env import collect_env_info

        print(collect_env_info())
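A minimal sketch of how collect_env_info is typically dumped into the training log; the logger name is illustrative:

import logging

logger = logging.getLogger("fastreid")
logger.info("Environment info:\n" + collect_env_info())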
119 fastreid/utils/env.py Normal file
@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
import importlib.util
import logging
import numpy as np
import os
import random
import sys
from datetime import datetime
import torch

__all__ = ["seed_all_rng"]


TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""


def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.
    Args:
        seed (int): if None, will use a strong random seed.
    """
    if seed is None:
        seed = (
                os.getpid()
                + int(datetime.now().strftime("%S%f"))
                + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed {}".format(seed))
    np.random.seed(seed)
    torch.set_rng_state(torch.manual_seed(seed).get_state())
    random.seed(seed)


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
def _import_file(module_name, file_path, make_importable=False):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if make_importable:
        sys.modules[module_name] = module
    return module


def _configure_libraries():
    """
    Configurations for some libraries.
    """
    # An environment option to disable `import cv2` globally,
    # in case it leads to negative performance impact
    disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
    if disable_cv2:
        sys.modules["cv2"] = None
    else:
        # Disable opencl in opencv since its interaction with cuda often has negative effects
        # This envvar is supported after OpenCV 3.4.0
        os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
        try:
            import cv2

            if int(cv2.__version__.split(".")[0]) >= 3:
                cv2.ocl.setUseOpenCL(False)
        except ImportError:
            pass

    def get_version(module, digit=2):
        return tuple(map(int, module.__version__.split(".")[:digit]))

    # fmt: off
    assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
    import yaml
    assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
    # fmt: on


_ENV_SETUP_DONE = False


def setup_environment():
    """Perform environment setup work. The default setup is a no-op, but this
    function allows the user to specify a Python source file or a module in
    the $DETECTRON2_ENV_MODULE environment variable, that performs
    custom setup work that may be necessary to their computing environment.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        return
    _ENV_SETUP_DONE = True

    _configure_libraries()

    custom_module_path = os.environ.get("FASTREID_ENV_MODULE")

    if custom_module_path:
        setup_custom_environment(custom_module_path)
    else:
        # The default setup is a no-op
        pass


def setup_custom_environment(custom_module):
    """
    Load custom environment setup by importing a Python source file or a
    module, and run the setup function.
    """
    if custom_module.endswith(".py"):
        module = _import_file("fastreid.utils.env.custom_module", custom_module)
    else:
        module = importlib.import_module(custom_module)
    assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    module.setup_environment()
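A minimal sketch of seeding each distributed worker with seed_all_rng; the base seed and rank offset are illustrative:

base_seed, rank = 12345, 0
seed_all_rng(base_seed + rank)   # deterministic but different per worker
seed_all_rng(None)               # or derive a strong seed from pid/time/urandom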
@ -1,58 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: liaoxingyu5@jd.com
"""

from typing import Optional

import torch


# based on:
# https://github.com/kornia/kornia/blob/master/kornia/utils/one_hot.py


def one_hot(labels: torch.Tensor,
            num_classes: int,
            dtype: Optional[torch.dtype] = None, ) -> torch.Tensor:
    # eps: Optional[float] = 1e-6) -> torch.Tensor:
    r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor.
    Args:
        labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`,
            where N is batch size. Each value is an integer
            representing correct classification.
        num_classes (int): number of classes in labels.
        device (Optional[torch.device]): the desired device of returned tensor.
            Default: if None, uses the current device for the default tensor type
            (see torch.set_default_tensor_type()). device will be the CPU for CPU
            tensor types and the current CUDA device for CUDA tensor types.
        dtype (Optional[torch.dtype]): the desired data type of returned
            tensor. Default: if None, infers data type from values.
    Returns:
        torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`,
    Examples::
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> one_hot(labels, num_classes=3)
        tensor([[[[1., 0.],
                  [0., 1.]],
                 [[0., 1.],
                  [0., 0.]],
                 [[0., 0.],
                  [1., 0.]]]]
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not labels.dtype == torch.int64:
        raise ValueError(
            "labels must be of the same dtype torch.int64. Got: {}".format(
                labels.dtype))
    if num_classes < 1:
        raise ValueError("The number of classes must be bigger than one."
                         " Got: {}".format(num_classes))
    device = labels.device
    shape = labels.shape
    one_hot = torch.zeros(shape[0], num_classes, *shape[1:],
                          device=device, dtype=dtype)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
@ -57,9 +57,7 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        # Change targets to zero to avoid error in circle(arcface) loss
        # which will use targets in forward
        inputs['targets'].zero_()
        inputs['targets'].fill_(-1)
        with torch.no_grad():  # No need to backward
            model(inputs)
        for i, bn in enumerate(bn_layers):
@ -63,7 +63,6 @@ if __name__ == '__main__':

    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    model.cuda()
    model.eval()
    print(model)

@ -1,48 +0,0 @@
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import sys

import torch

sys.path.append('../..')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup
from fastreid.modeling.meta_arch import build_model
from fastreid.export.tensorflow_export import export_tf_reid_model
from fastreid.export.tf_modeling import TfMetaArch


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    # cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    cfg = setup(args)
    cfg.defrost()
    cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
    cfg.MODEL.BACKBONE.DEPTH = 50
    cfg.MODEL.BACKBONE.LAST_STRIDE = 1
    # If use IBN block in backbone
    cfg.MODEL.BACKBONE.WITH_IBN = False
    cfg.MODEL.BACKBONE.PRETRAIN = False

    from torchvision.models import resnet50

    # model = TfMetaArch(cfg)
    model = resnet50(pretrained=False)
    # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
    model.eval()
    dummy_inputs = torch.randn(1, 3, 256, 128)
    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
@ -5,16 +5,14 @@
@contact: sherlockliao01@gmail.com
"""

import os
import logging
import os
import sys

sys.path.append('.')

from torch import nn

from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastreid.engine import hooks
from fastreid.evaluation import ReidEvaluator
@ -43,15 +41,14 @@ def setup(args):
def main(args):
    cfg = setup(args)

    logger = logging.getLogger('fastreid.' + __name__)
    if args.eval_only:
        logger = logging.getLogger("fastreid.trainer")
        cfg.defrost()
        cfg.MODEL.BACKBONE.PRETRAIN = False
        model = Trainer.build_model(cfg)
        model = nn.DataParallel(model)
        model = model.cuda()

        Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS)  # load trained model
        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

        if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(model):
            prebn_cfg = cfg.clone()
            prebn_cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
@ -75,4 +72,11 @@ def main(args):
if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    main(args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )