From 91ff63118434fc427dbf4fd693b16b9e8a4b436d Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Mon, 31 May 2021 17:27:14 +0800 Subject: [PATCH] Minor changes Some minor changes, such as class name changing, remove extra blank line, etc. --- demo/README.md | 17 +--- demo/demo.py | 12 ++- fastreid/data/transforms/build.py | 4 +- fastreid/layers/__init__.py | 8 +- fastreid/layers/any_softmax.py | 2 +- fastreid/layers/batch_drop.py | 32 -------- fastreid/layers/pooling.py | 4 +- fastreid/modeling/losses/triplet_loss.py | 2 +- .../FastAttr/fastattr/modeling/attr_head.py | 2 +- projects/FastClas/fastclas/__init__.py | 2 + projects/FastClas/fastclas/bee_ant.py | 1 + projects/FastClas/fastclas/dataset.py | 20 +++-- projects/FastClas/fastclas/trainer.py | 82 +++++++++++++++++++ projects/FastClas/train_net.py | 74 +++-------------- projects/FastTune/configs/search_trial.yml | 3 - .../PartialReID/configs/partial_market.yml | 5 +- projects/PartialReID/partialreid/dsr_head.py | 2 +- tools/deploy/onnx_export.py | 1 - 18 files changed, 137 insertions(+), 136 deletions(-) delete mode 100644 fastreid/layers/batch_drop.py create mode 100644 projects/FastClas/fastclas/trainer.py diff --git a/demo/README.md b/demo/README.md index dc4daee..572e51f 100644 --- a/demo/README.md +++ b/demo/README.md @@ -2,18 +2,9 @@ We provide a command line tool to run a simple demo of builtin models. -You can run this command to get rank visualization results by cosine similarites between different images. +You can run this command to get cosine similarites between different images -```shell script -python3 demo/visualize_result.py --config-file logs/dukemtmc/mgn_R50-ibn/config.yaml \ ---parallel --vis-label --dataset-name 'DukeMTMC' --output logs/mgn_duke_vis \ ---opts MODEL.WEIGHTS logs/dukemtmc/mgn_R50-ibn/model_final.pth -``` - -You can also run this command to extract image features. - -```shell script -python3 demo/demo.py --config-file logs/dukemtmc/sbs_R50/config.yaml \ ---parallel --input tools/deploy/test_data/*.jpg --output sbs_R50_feat \ ---opts MODEL.WEIGHTS logs/dukemtmc/sbs_R50/model_final.pth +```bash +cd demo/ +sh run_demo.sh ``` \ No newline at end of file diff --git a/demo/demo.py b/demo/demo.py index 23405b6..26f330f 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -9,6 +9,7 @@ import glob import os import sys +import torch.nn.functional as F import cv2 import numpy as np import tqdm @@ -23,7 +24,7 @@ from fastreid.utils.file_io import PathManager from predictor import FeatureExtractionDemo # import some modules added in project like this below -# sys.path.append('../projects/PartialReID') +# sys.path.append("projects/PartialReID") # from partialreid import * cudnn.benchmark = True @@ -72,6 +73,13 @@ def get_parser(): return parser +def postprocess(features): + # Normalize feature to compute cosine distance + features = F.normalize(features) + features = features.cpu().data.numpy() + return features + + if __name__ == '__main__': args = get_parser().parse_args() cfg = setup_cfg(args) @@ -85,5 +93,5 @@ if __name__ == '__main__': for path in tqdm.tqdm(args.input): img = cv2.imread(path) feat = demo.run_on_image(img) - feat = feat.numpy() + feat = postprocess(feat) np.save(os.path.join(args.output, os.path.basename(path).split('.')[0] + '.npy'), feat) diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py index 1404143..5ca06a7 100644 --- a/fastreid/data/transforms/build.py +++ b/fastreid/data/transforms/build.py @@ -78,8 +78,8 @@ def build_transforms(cfg, is_train=True): if do_cj: res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob)) if do_affine: - res.append(T.RandomAffine(degrees=0, translate=None, scale=[0.9, 1.1], shear=None, resample=False, - fillcolor=128)) + res.append(T.RandomAffine(degrees=10, translate=None, scale=[0.9, 1.1], shear=0.1, resample=False, + fillcolor=0)) if do_augmix: res.append(AugMix(prob=augmix_prob)) res.append(ToTensor()) diff --git a/fastreid/layers/__init__.py b/fastreid/layers/__init__.py index 1f575b4..e68064c 100644 --- a/fastreid/layers/__init__.py +++ b/fastreid/layers/__init__.py @@ -5,11 +5,15 @@ """ from .activation import * -from .batch_drop import BatchDrop from .batch_norm import * from .context_block import ContextBlock +from .drop import DropPath, DropBlock2d, drop_block_2d, drop_path from .frn import FRN, TLU +from .gather_layer import GatherLayer +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .non_local import Non_local from .se_layer import SELayer from .splat import SplAtConv2d, DropBlock2D -from .gather_layer import GatherLayer +from .weight_init import ( + trunc_normal_, variance_scaling_, lecun_normal_, weights_init_kaiming, weights_init_classifier +) diff --git a/fastreid/layers/any_softmax.py b/fastreid/layers/any_softmax.py index 9d643bd..3d8392d 100644 --- a/fastreid/layers/any_softmax.py +++ b/fastreid/layers/any_softmax.py @@ -23,7 +23,7 @@ class Linear(nn.Module): self.m = margin def forward(self, logits, targets): - return logits + return logits.mul_(self.s) def extra_repr(self): return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}" diff --git a/fastreid/layers/batch_drop.py b/fastreid/layers/batch_drop.py deleted file mode 100644 index 5c25697..0000000 --- a/fastreid/layers/batch_drop.py +++ /dev/null @@ -1,32 +0,0 @@ -# encoding: utf-8 -""" -@author: liaoxingyu -@contact: sherlockliao01@gmail.com -""" - -import random - -from torch import nn - - -class BatchDrop(nn.Module): - """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py - batch drop mask - """ - - def __init__(self, h_ratio, w_ratio): - super(BatchDrop, self).__init__() - self.h_ratio = h_ratio - self.w_ratio = w_ratio - - def forward(self, x): - if self.training: - h, w = x.size()[-2:] - rh = round(self.h_ratio * h) - rw = round(self.w_ratio * w) - sx = random.randint(0, h - rh) - sy = random.randint(0, w - rw) - mask = x.new_ones(x.size()) - mask[:, :, sx:sx + rh, sy:sy + rw] = 0 - x = x * mask - return x diff --git a/fastreid/layers/pooling.py b/fastreid/layers/pooling.py index ff8bdb2..2f6c603 100644 --- a/fastreid/layers/pooling.py +++ b/fastreid/layers/pooling.py @@ -61,7 +61,7 @@ class GeneralizedMeanPooling(nn.Module): be the same as that of the input. """ - def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs): + def __init__(self, norm=3, output_size=(1, 1), eps=1e-6, *args, **kwargs): super(GeneralizedMeanPooling, self).__init__() assert norm > 0 self.p = float(norm) @@ -82,7 +82,7 @@ class GeneralizedMeanPoolingP(GeneralizedMeanPooling): """ Same, but norm is trainable """ - def __init__(self, norm=3, output_size=1, eps=1e-6, *args, **kwargs): + def __init__(self, norm=3, output_size=(1, 1), eps=1e-6, *args, **kwargs): super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) self.p = nn.Parameter(torch.ones(1) * norm) diff --git a/fastreid/modeling/losses/triplet_loss.py b/fastreid/modeling/losses/triplet_loss.py index b1cd11d..5275008 100644 --- a/fastreid/modeling/losses/triplet_loss.py +++ b/fastreid/modeling/losses/triplet_loss.py @@ -42,7 +42,7 @@ def hard_example_mining(dist_mat, is_pos, is_neg): dist_ap, _ = torch.max(dist_mat * is_pos, dim=1) # `dist_an` means distance(anchor, negative) # both `dist_an` and `relative_n_inds` with shape [N] - dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 99999999., dim=1) + dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1) return dist_ap, dist_an diff --git a/projects/FastAttr/fastattr/modeling/attr_head.py b/projects/FastAttr/fastattr/modeling/attr_head.py index 621de8a..dc7972e 100644 --- a/projects/FastAttr/fastattr/modeling/attr_head.py +++ b/projects/FastAttr/fastattr/modeling/attr_head.py @@ -10,7 +10,7 @@ from torch import nn from fastreid.modeling.heads import EmbeddingHead from fastreid.modeling.heads.build import REID_HEADS_REGISTRY -from fastreid.utils.weight_init import weights_init_kaiming +from fastreid.layers.weight_init import weights_init_kaiming @REID_HEADS_REGISTRY.register() diff --git a/projects/FastClas/fastclas/__init__.py b/projects/FastClas/fastclas/__init__.py index c1a8bc1..cfabb8c 100644 --- a/projects/FastClas/fastclas/__init__.py +++ b/projects/FastClas/fastclas/__init__.py @@ -5,4 +5,6 @@ """ from .bee_ant import * +from .distracted_driver import * from .dataset import ClasDataset +from .trainer import ClasTrainer diff --git a/projects/FastClas/fastclas/bee_ant.py b/projects/FastClas/fastclas/bee_ant.py index d4a9e7d..dbcd4de 100644 --- a/projects/FastClas/fastclas/bee_ant.py +++ b/projects/FastClas/fastclas/bee_ant.py @@ -10,6 +10,7 @@ import os from fastreid.data.datasets import DATASET_REGISTRY from fastreid.data.datasets.bases import ImageDataset + __all__ = ["Hymenoptera"] diff --git a/projects/FastClas/fastclas/dataset.py b/projects/FastClas/fastclas/dataset.py index e80f5bb..4681a41 100644 --- a/projects/FastClas/fastclas/dataset.py +++ b/projects/FastClas/fastclas/dataset.py @@ -12,18 +12,22 @@ from fastreid.data.data_utils import read_image class ClasDataset(Dataset): """Image Person ReID Dataset""" - def __init__(self, img_items, transform=None): + def __init__(self, img_items, transform=None, idx_to_class=None): self.img_items = img_items self.transform = transform - classes = set() - for i in img_items: - classes.add(i[1]) + if idx_to_class is not None: + self.idx_to_class = idx_to_class + self.class_to_idx = {clas_name: int(i) for i, clas_name in self.idx_to_class.items()} + self.classes = sorted(list(self.idx_to_class.values())) + else: + classes = set() + for i in img_items: + classes.add(i[1]) - self.classes = list(classes) - self.classes.sort() - self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)} - self.idx_to_class = {idx: clas for clas, idx in self.class_to_idx.items()} + self.classes = sorted(list(classes)) + self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)} + self.idx_to_class = {idx: clas for clas, idx in self.class_to_idx.items()} def __len__(self): return len(self.img_items) diff --git a/projects/FastClas/fastclas/trainer.py b/projects/FastClas/fastclas/trainer.py new file mode 100644 index 0000000..bf0b76c --- /dev/null +++ b/projects/FastClas/fastclas/trainer.py @@ -0,0 +1,82 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import json +import logging +import os + +from fastreid.data.build import _root +from fastreid.data.build import build_reid_train_loader, build_reid_test_loader +from fastreid.data.datasets import DATASET_REGISTRY +from fastreid.data.transforms import build_transforms +from fastreid.engine import DefaultTrainer +from fastreid.evaluation.clas_evaluator import ClasEvaluator +from fastreid.utils import comm +from fastreid.utils.checkpoint import PathManager +from .dataset import ClasDataset + + +class ClasTrainer(DefaultTrainer): + idx2class = None + + @classmethod + def build_train_loader(cls, cfg): + """ + Returns: + iterable + It now calls :func:`fastreid.data.build_reid_train_loader`. + Overwrite it if you'd like a different data loader. + """ + logger = logging.getLogger("fastreid.clas_dataset") + logger.info("Prepare training set") + + train_items = list() + for d in cfg.DATASETS.NAMES: + data = DATASET_REGISTRY.get(d)(root=_root) + if comm.is_main_process(): + data.show_train() + train_items.extend(data.train) + transforms = build_transforms(cfg, is_train=True) + train_set = ClasDataset(train_items, transforms) + cls.idx2class = train_set.idx_to_class + + data_loader = build_reid_train_loader(cfg, train_set=train_set) + return data_loader + + @classmethod + def build_test_loader(cls, cfg, dataset_name): + """ + Returns: + iterable + It now calls :func:`fastreid.data.build_reid_test_loader`. + Overwrite it if you'd like a different data loader. + """ + data = DATASET_REGISTRY.get(dataset_name)(root=_root) + if comm.is_main_process(): + data.show_test() + transforms = build_transforms(cfg, is_train=False) + + test_set = ClasDataset(data.query, transforms, cls.idx2class) + data_loader, _ = build_reid_test_loader(cfg, test_set=test_set) + return data_loader + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_dir=None): + data_loader = cls.build_test_loader(cfg, dataset_name) + return data_loader, ClasEvaluator(cfg, output_dir) + + @staticmethod + def auto_scale_hyperparams(cfg, num_classes): + cfg = DefaultTrainer.auto_scale_hyperparams(cfg, num_classes) + + # Save index to class dictionary + output_dir = cfg.OUTPUT_DIR + if comm.is_main_process() and output_dir: + path = os.path.join(output_dir, "idx2class.json") + with PathManager.open(path, "w") as f: + json.dump(ClasTrainer.idx2class, f) + + return cfg diff --git a/projects/FastClas/train_net.py b/projects/FastClas/train_net.py index 4d779da..82c1f47 100644 --- a/projects/FastClas/train_net.py +++ b/projects/FastClas/train_net.py @@ -14,75 +14,11 @@ sys.path.append('.') from fastreid.config import get_cfg from fastreid.engine import default_argument_parser, default_setup, launch -from fastreid.data.build import build_reid_train_loader, build_reid_test_loader -from fastreid.evaluation.clas_evaluator import ClasEvaluator from fastreid.utils.checkpoint import Checkpointer, PathManager -from fastreid.utils import comm -from fastreid.engine import DefaultTrainer -from fastreid.data.datasets import DATASET_REGISTRY -from fastreid.data.transforms import build_transforms -from fastreid.data.build import _root from fastclas import * -class ClasTrainer(DefaultTrainer): - - @classmethod - def build_train_loader(cls, cfg): - """ - Returns: - iterable - It now calls :func:`fastreid.data.build_reid_train_loader`. - Overwrite it if you'd like a different data loader. - """ - logger = logging.getLogger("fastreid.clas_dataset") - logger.info("Prepare training set") - - train_items = list() - for d in cfg.DATASETS.NAMES: - data = DATASET_REGISTRY.get(d)(root=_root) - if comm.is_main_process(): - data.show_train() - train_items.extend(data.train) - - transforms = build_transforms(cfg, is_train=True) - train_set = ClasDataset(train_items, transforms) - - data_loader = build_reid_train_loader(cfg, train_set=train_set) - - # Save index to class dictionary - output_dir = cfg.OUTPUT_DIR - if comm.is_main_process() and output_dir: - path = os.path.join(output_dir, "idx2class.json") - with PathManager.open(path, "w") as f: - json.dump(train_set.idx_to_class, f) - - return data_loader - - @classmethod - def build_test_loader(cls, cfg, dataset_name): - """ - Returns: - iterable - It now calls :func:`fastreid.data.build_reid_test_loader`. - Overwrite it if you'd like a different data loader. - """ - - data = DATASET_REGISTRY.get(dataset_name)(root=_root) - if comm.is_main_process(): - data.show_test() - transforms = build_transforms(cfg, is_train=False) - test_set = ClasDataset(data.query, transforms) - data_loader, _ = build_reid_test_loader(cfg, test_set=test_set) - return data_loader - - @classmethod - def build_evaluator(cls, cfg, dataset_name, output_dir=None): - data_loader = cls.build_test_loader(cfg, dataset_name) - return data_loader, ClasEvaluator(cfg, output_dir) - - def setup(args): """ Create configs and perform basic setups. @@ -105,6 +41,16 @@ def main(args): Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model + try: + output_dir = os.path.dirname(cfg.MODEL.WEIGHTS) + path = os.path.join(output_dir, "idx2class.json") + with PathManager.open(path, 'r') as f: + idx2class = json.load(f) + ClasTrainer.idx2class = idx2class + except: + logger = logging.getLogger("fastreid.fastclas") + logger.info(f"Cannot find idx2class dict in {os.path.dirname(cfg.MODEL.WEIGHTS)}") + res = ClasTrainer.test(cfg, model) return res diff --git a/projects/FastTune/configs/search_trial.yml b/projects/FastTune/configs/search_trial.yml index fe67d5c..b960e0a 100644 --- a/projects/FastTune/configs/search_trial.yml +++ b/projects/FastTune/configs/search_trial.yml @@ -55,9 +55,6 @@ INPUT: PADDING: ENABLED: True - FLIP: - ENABLED: True - DATALOADER: SAMPLER_TRAIN: NaiveIdentitySampler NUM_INSTANCE: 16 diff --git a/projects/PartialReID/configs/partial_market.yml b/projects/PartialReID/configs/partial_market.yml index a561311..51f8c74 100644 --- a/projects/PartialReID/configs/partial_market.yml +++ b/projects/PartialReID/configs/partial_market.yml @@ -26,7 +26,7 @@ MODEL: TRI: MARGIN: 0.3 SCALE: 1.0 - HARD_MINING: True + HARD_MINING: False DATASETS: NAMES: ("Market1501",) @@ -44,7 +44,6 @@ DATALOADER: NUM_INSTANCE: 4 NUM_WORKERS: 8 - SOLVER: AMP: ENABLED: False @@ -71,4 +70,4 @@ TEST: CUDNN_BENCHMARK: True -OUTPUT_DIR: "projects/PartialReID/logs/test_partial" \ No newline at end of file +OUTPUT_DIR: projects/PartialReID/logs/test_partial \ No newline at end of file diff --git a/projects/PartialReID/partialreid/dsr_head.py b/projects/PartialReID/partialreid/dsr_head.py index 57b03ce..784bb30 100644 --- a/projects/PartialReID/partialreid/dsr_head.py +++ b/projects/PartialReID/partialreid/dsr_head.py @@ -11,7 +11,7 @@ from torch import nn from fastreid.layers import * from fastreid.modeling.heads import EmbeddingHead from fastreid.modeling.heads.build import REID_HEADS_REGISTRY -from fastreid.utils.weight_init import weights_init_kaiming +from fastreid.layers.weight_init import weights_init_kaiming class OcclusionUnit(nn.Module): diff --git a/tools/deploy/onnx_export.py b/tools/deploy/onnx_export.py index 108c2de..de5d0f7 100644 --- a/tools/deploy/onnx_export.py +++ b/tools/deploy/onnx_export.py @@ -28,7 +28,6 @@ from fastreid.utils.logger import setup_logger # sys.path.append("projects/FastDistill") # from fastdistill import * - setup_logger(name="fastreid") logger = logging.getLogger("fastreid.onnx_export")