mirror of https://github.com/JDAI-CV/fast-reid.git
update docs
parent
274cd81dab
commit
b5c3c0a24d
|
@ -28,7 +28,7 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
|
|||
|
||||
See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
|
||||
|
||||
Learn more at out [documentation](). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
|
||||
Learn more at out [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
|
||||
|
||||
## Model Zoo and Baselines
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
_build
|
|
@ -4,4 +4,14 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .build import build_reid_train_loader, build_reid_test_loader
|
||||
from . import transforms # isort:skip
|
||||
from .build import (
|
||||
build_reid_train_loader,
|
||||
build_reid_test_loader
|
||||
)
|
||||
from .common import CommDataset
|
||||
|
||||
# ensure the builtin datasets are registered
|
||||
from . import datasets, samplers # isort:skip
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
|
|
@ -16,10 +16,25 @@ from .common import CommDataset
|
|||
from .datasets import DATASET_REGISTRY
|
||||
from .transforms import build_transforms
|
||||
|
||||
__all__ = [
|
||||
"build_reid_train_loader",
|
||||
"build_reid_test_loader"
|
||||
]
|
||||
|
||||
_root = os.getenv("FASTREID_DATASETS", "datasets")
|
||||
|
||||
|
||||
def build_reid_train_loader(cfg, mapper=None, **kwargs):
|
||||
"""
|
||||
Build reid train loader
|
||||
|
||||
Args:
|
||||
cfg : image file path
|
||||
mapper : one of the supported image modes in PIL, or "BGR"
|
||||
|
||||
Returns:
|
||||
torch.utils.data.DataLoader: a dataloader.
|
||||
"""
|
||||
cfg = cfg.clone()
|
||||
|
||||
train_items = list()
|
||||
|
@ -60,6 +75,19 @@ def build_reid_train_loader(cfg, mapper=None, **kwargs):
|
|||
|
||||
|
||||
def build_reid_test_loader(cfg, dataset_name, mapper=None, **kwargs):
|
||||
"""
|
||||
Build reid test loader
|
||||
|
||||
Args:
|
||||
cfg:
|
||||
dataset_name:
|
||||
mapper:
|
||||
**kwargs:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
cfg = cfg.clone()
|
||||
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
|
||||
|
|
|
@ -13,6 +13,7 @@ def read_image(file_name, format=None):
|
|||
"""
|
||||
Read an image into the given format.
|
||||
Will apply rotation and flipping if the image has such exif information.
|
||||
|
||||
Args:
|
||||
file_name (str): image file path
|
||||
format (str): one of the supported image modes in PIL, or "BGR"
|
||||
|
|
|
@ -14,6 +14,9 @@ __all__ = ['AirportALERT', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class AirportALERT(ImageDataset):
|
||||
"""Airport
|
||||
|
||||
"""
|
||||
dataset_dir = "AirportALERT"
|
||||
dataset_name = "airport"
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
|
||||
from tabulate import tabulate
|
||||
from termcolor import colored
|
||||
|
||||
|
@ -16,6 +17,7 @@ logger = logging.getLogger(__name__)
|
|||
class Dataset(object):
|
||||
"""An abstract class representing a Dataset.
|
||||
This is the base class for ``ImageDataset`` and ``VideoDataset``.
|
||||
|
||||
Args:
|
||||
train (list): contains tuples of (img_path(s), pid, camid).
|
||||
query (list): contains tuples of (img_path(s), pid, camid).
|
||||
|
|
|
@ -19,6 +19,8 @@ __all__ = ['CAVIARa',]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class CAVIARa(ImageDataset):
|
||||
"""CAVIARa
|
||||
"""
|
||||
dataset_dir = "CAVIARa"
|
||||
dataset_name = "caviara"
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..datasets import DATASET_REGISTRY
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class cuhkSYSU(ImageDataset):
|
||||
r"""CUHK SYSU datasets.
|
||||
"""CUHK SYSU datasets.
|
||||
|
||||
The dataset is collected from two sources: street snap and movie.
|
||||
In street snap, 12,490 images and 6,057 query persons were collected
|
||||
|
|
|
@ -15,6 +15,8 @@ __all__ = ['iLIDS', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class iLIDS(ImageDataset):
|
||||
"""iLIDS
|
||||
"""
|
||||
dataset_dir = "iLIDS"
|
||||
dataset_name = "ilids"
|
||||
|
||||
|
|
|
@ -15,7 +15,9 @@ __all__ = ['LPW', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class LPW(ImageDataset):
|
||||
dataset_dir = "pep_256x128"
|
||||
"""LPW
|
||||
"""
|
||||
dataset_dir = "pep_256x128/data_slim"
|
||||
dataset_name = "lpw"
|
||||
|
||||
def __init__(self, root='datasets', **kwargs):
|
||||
|
|
|
@ -15,6 +15,8 @@ __all__ = ['PeS3D',]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class PeS3D(ImageDataset):
|
||||
"""3Dpes
|
||||
"""
|
||||
dataset_dir = "3DPeS"
|
||||
dataset_name = "pes3d"
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@ __all__ = ['PKU', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class PKU(ImageDataset):
|
||||
"""PKU
|
||||
"""
|
||||
dataset_dir = "PKUv1a_128x48"
|
||||
dataset_name = 'pku'
|
||||
|
||||
|
|
|
@ -5,18 +5,18 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
from scipy.io import loadmat
|
||||
from glob import glob
|
||||
|
||||
from fastreid.data.datasets import DATASET_REGISTRY
|
||||
from fastreid.data.datasets.bases import ImageDataset
|
||||
import pdb
|
||||
|
||||
__all__ = ['PRAI', ]
|
||||
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class PRAI(ImageDataset):
|
||||
"""PRAI
|
||||
"""
|
||||
dataset_dir = "PRAI-1581"
|
||||
dataset_name = 'prai'
|
||||
|
||||
|
@ -41,4 +41,3 @@ class PRAI(ImageDataset):
|
|||
camid = self.dataset_name + "_" + img_info[1]
|
||||
data.append([img_path, pid, camid])
|
||||
return data
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@ __all__ = ['SAIVT', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class SAIVT(ImageDataset):
|
||||
"""SAIVT
|
||||
"""
|
||||
dataset_dir = "SAIVT-SoftBio"
|
||||
dataset_name = "saivt"
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@ __all__ = ['SenseReID', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class SenseReID(ImageDataset):
|
||||
"""Sense reid
|
||||
"""
|
||||
dataset_dir = "SenseReID"
|
||||
dataset_name = "senseid"
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@ __all__ = ['Shinpuhkan', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Shinpuhkan(ImageDataset):
|
||||
"""shinpuhkan
|
||||
"""
|
||||
dataset_dir = "shinpuhkan"
|
||||
dataset_name = 'shinpuhkan'
|
||||
|
||||
|
|
|
@ -17,6 +17,8 @@ __all__ = ['SYSU_mm', ]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class SYSU_mm(ImageDataset):
|
||||
"""sysu mm
|
||||
"""
|
||||
dataset_dir = "SYSU-MM01"
|
||||
dataset_name = "sysumm01"
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@ __all__ = ['Thermalworld',]
|
|||
|
||||
@DATASET_REGISTRY.register()
|
||||
class Thermalworld(ImageDataset):
|
||||
"""thermal world
|
||||
"""
|
||||
dataset_dir = "thermalworld_rgb"
|
||||
dataset_name = "thermalworld"
|
||||
|
||||
|
|
|
@ -6,3 +6,10 @@
|
|||
|
||||
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler
|
||||
from .data_sampler import TrainingSampler, InferenceSampler
|
||||
|
||||
__all__ = [
|
||||
"BalancedIdentitySampler",
|
||||
"NaiveIdentitySampler",
|
||||
"TrainingSampler",
|
||||
"InferenceSampler"
|
||||
]
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .autoaugment import *
|
||||
from .autoaugment import AutoAugment
|
||||
from .build import build_transforms
|
||||
from .transforms import *
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
|
|
@ -149,13 +149,9 @@ class AugMix(object):
|
|||
np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
|
||||
m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
|
||||
|
||||
# image = np.asarray(image, dtype=np.float32).copy()
|
||||
# mix = np.zeros_like(image)
|
||||
mix = np.zeros([image.size[1], image.size[0], 3])
|
||||
# h, w = image.shape[0], image.shape[1]
|
||||
for i in range(self.mixture_width):
|
||||
image_aug = image.copy()
|
||||
# image_aug = Image.fromarray(image.copy().astype(np.uint8))
|
||||
depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
|
||||
for _ in range(depth):
|
||||
op = np.random.choice(self.augmentations)
|
||||
|
|
|
@ -4,4 +4,20 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .meta_arch import build_model
|
||||
from . import losses
|
||||
from .backbones import (
|
||||
BACKBONE_REGISTRY,
|
||||
build_resnet_backbone,
|
||||
build_backbone,
|
||||
)
|
||||
from .heads import (
|
||||
REID_HEADS_REGISTRY,
|
||||
build_heads,
|
||||
EmbeddingHead,
|
||||
)
|
||||
from .meta_arch import (
|
||||
build_model,
|
||||
META_ARCH_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [k for k in globals().keys() if k not in k.startswith("_")]
|
||||
|
|
|
@ -10,8 +10,7 @@ BACKBONE_REGISTRY = Registry("BACKBONE")
|
|||
BACKBONE_REGISTRY.__doc__ = """
|
||||
Registry for backbones, which extract feature maps from images
|
||||
The registered object must be a callable that accepts two arguments:
|
||||
1. A :class:`detectron2.config.CfgNode`
|
||||
2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
|
||||
1. A :class:`fastreid.config.CfgNode`
|
||||
It must returns an instance of :class:`Backbone`.
|
||||
"""
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ from ...utils.registry import Registry
|
|||
|
||||
REID_HEADS_REGISTRY = Registry("HEADS")
|
||||
REID_HEADS_REGISTRY.__doc__ = """
|
||||
Registry for ROI heads in a generalized R-CNN model.
|
||||
Registry for reid heads in a baseline model.
|
||||
|
||||
ROIHeads take feature maps and region proposals, and
|
||||
perform per-region computation.
|
||||
The registered object will be called with `obj(cfg, input_shape)`.
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
from .circle_loss import *
|
||||
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
|
||||
from .focal_loss import focal_loss
|
||||
from .triplet_loss import triplet_loss
|
||||
from .circle_loss import *
|
||||
|
||||
__all__ = [k for k in globals().keys() if k not in k.startswith("_")]
|
||||
|
|
Loading…
Reference in New Issue