update docs

pull/389/head
liaoxingyu 2021-01-22 21:11:19 +08:00
parent 274cd81dab
commit b5c3c0a24d
26 changed files with 105 additions and 18 deletions

View File

@ -28,7 +28,7 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
Learn more at out [documentation](). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
Learn more at out [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
## Model Zoo and Baselines

1
docs/.gitignore vendored 100644
View File

@ -0,0 +1 @@
_build

View File

@ -4,4 +4,14 @@
@contact: sherlockliao01@gmail.com
"""
from .build import build_reid_train_loader, build_reid_test_loader
from . import transforms # isort:skip
from .build import (
build_reid_train_loader,
build_reid_test_loader
)
from .common import CommDataset
# ensure the builtin datasets are registered
from . import datasets, samplers # isort:skip
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -16,10 +16,25 @@ from .common import CommDataset
from .datasets import DATASET_REGISTRY
from .transforms import build_transforms
__all__ = [
"build_reid_train_loader",
"build_reid_test_loader"
]
_root = os.getenv("FASTREID_DATASETS", "datasets")
def build_reid_train_loader(cfg, mapper=None, **kwargs):
"""
Build reid train loader
Args:
cfg : image file path
mapper : one of the supported image modes in PIL, or "BGR"
Returns:
torch.utils.data.DataLoader: a dataloader.
"""
cfg = cfg.clone()
train_items = list()
@ -60,6 +75,19 @@ def build_reid_train_loader(cfg, mapper=None, **kwargs):
def build_reid_test_loader(cfg, dataset_name, mapper=None, **kwargs):
"""
Build reid test loader
Args:
cfg:
dataset_name:
mapper:
**kwargs:
Returns:
"""
cfg = cfg.clone()
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)

View File

@ -13,6 +13,7 @@ def read_image(file_name, format=None):
"""
Read an image into the given format.
Will apply rotation and flipping if the image has such exif information.
Args:
file_name (str): image file path
format (str): one of the supported image modes in PIL, or "BGR"

View File

@ -14,6 +14,9 @@ __all__ = ['AirportALERT', ]
@DATASET_REGISTRY.register()
class AirportALERT(ImageDataset):
"""Airport
"""
dataset_dir = "AirportALERT"
dataset_name = "airport"

View File

@ -7,6 +7,7 @@
import copy
import logging
import os
from tabulate import tabulate
from termcolor import colored
@ -16,6 +17,7 @@ logger = logging.getLogger(__name__)
class Dataset(object):
"""An abstract class representing a Dataset.
This is the base class for ``ImageDataset`` and ``VideoDataset``.
Args:
train (list): contains tuples of (img_path(s), pid, camid).
query (list): contains tuples of (img_path(s), pid, camid).

View File

@ -19,6 +19,8 @@ __all__ = ['CAVIARa',]
@DATASET_REGISTRY.register()
class CAVIARa(ImageDataset):
"""CAVIARa
"""
dataset_dir = "CAVIARa"
dataset_name = "caviara"

View File

@ -15,7 +15,7 @@ from ..datasets import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class cuhkSYSU(ImageDataset):
r"""CUHK SYSU datasets.
"""CUHK SYSU datasets.
The dataset is collected from two sources: street snap and movie.
In street snap, 12,490 images and 6,057 query persons were collected

View File

@ -15,6 +15,8 @@ __all__ = ['iLIDS', ]
@DATASET_REGISTRY.register()
class iLIDS(ImageDataset):
"""iLIDS
"""
dataset_dir = "iLIDS"
dataset_name = "ilids"

View File

@ -15,7 +15,9 @@ __all__ = ['LPW', ]
@DATASET_REGISTRY.register()
class LPW(ImageDataset):
dataset_dir = "pep_256x128"
"""LPW
"""
dataset_dir = "pep_256x128/data_slim"
dataset_name = "lpw"
def __init__(self, root='datasets', **kwargs):

View File

@ -15,6 +15,8 @@ __all__ = ['PeS3D',]
@DATASET_REGISTRY.register()
class PeS3D(ImageDataset):
"""3Dpes
"""
dataset_dir = "3DPeS"
dataset_name = "pes3d"

View File

@ -15,6 +15,8 @@ __all__ = ['PKU', ]
@DATASET_REGISTRY.register()
class PKU(ImageDataset):
"""PKU
"""
dataset_dir = "PKUv1a_128x48"
dataset_name = 'pku'

View File

@ -5,18 +5,18 @@
"""
import os
from scipy.io import loadmat
from glob import glob
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
import pdb
__all__ = ['PRAI', ]
@DATASET_REGISTRY.register()
class PRAI(ImageDataset):
"""PRAI
"""
dataset_dir = "PRAI-1581"
dataset_name = 'prai'
@ -41,4 +41,3 @@ class PRAI(ImageDataset):
camid = self.dataset_name + "_" + img_info[1]
data.append([img_path, pid, camid])
return data

View File

@ -15,6 +15,8 @@ __all__ = ['SAIVT', ]
@DATASET_REGISTRY.register()
class SAIVT(ImageDataset):
"""SAIVT
"""
dataset_dir = "SAIVT-SoftBio"
dataset_name = "saivt"

View File

@ -15,6 +15,8 @@ __all__ = ['SenseReID', ]
@DATASET_REGISTRY.register()
class SenseReID(ImageDataset):
"""Sense reid
"""
dataset_dir = "SenseReID"
dataset_name = "senseid"

View File

@ -14,6 +14,8 @@ __all__ = ['Shinpuhkan', ]
@DATASET_REGISTRY.register()
class Shinpuhkan(ImageDataset):
"""shinpuhkan
"""
dataset_dir = "shinpuhkan"
dataset_name = 'shinpuhkan'

View File

@ -17,6 +17,8 @@ __all__ = ['SYSU_mm', ]
@DATASET_REGISTRY.register()
class SYSU_mm(ImageDataset):
"""sysu mm
"""
dataset_dir = "SYSU-MM01"
dataset_name = "sysumm01"

View File

@ -19,6 +19,8 @@ __all__ = ['Thermalworld',]
@DATASET_REGISTRY.register()
class Thermalworld(ImageDataset):
"""thermal world
"""
dataset_dir = "thermalworld_rgb"
dataset_name = "thermalworld"

View File

@ -6,3 +6,10 @@
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler
from .data_sampler import TrainingSampler, InferenceSampler
__all__ = [
"BalancedIdentitySampler",
"NaiveIdentitySampler",
"TrainingSampler",
"InferenceSampler"
]

View File

@ -4,6 +4,8 @@
@contact: sherlockliao01@gmail.com
"""
from .autoaugment import *
from .autoaugment import AutoAugment
from .build import build_transforms
from .transforms import *
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -149,13 +149,9 @@ class AugMix(object):
np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
# image = np.asarray(image, dtype=np.float32).copy()
# mix = np.zeros_like(image)
mix = np.zeros([image.size[1], image.size[0], 3])
# h, w = image.shape[0], image.shape[1]
for i in range(self.mixture_width):
image_aug = image.copy()
# image_aug = Image.fromarray(image.copy().astype(np.uint8))
depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
for _ in range(depth):
op = np.random.choice(self.augmentations)

View File

@ -4,4 +4,20 @@
@contact: sherlockliao01@gmail.com
"""
from .meta_arch import build_model
from . import losses
from .backbones import (
BACKBONE_REGISTRY,
build_resnet_backbone,
build_backbone,
)
from .heads import (
REID_HEADS_REGISTRY,
build_heads,
EmbeddingHead,
)
from .meta_arch import (
build_model,
META_ARCH_REGISTRY,
)
__all__ = [k for k in globals().keys() if k not in k.startswith("_")]

View File

@ -10,8 +10,7 @@ BACKBONE_REGISTRY = Registry("BACKBONE")
BACKBONE_REGISTRY.__doc__ = """
Registry for backbones, which extract feature maps from images
The registered object must be a callable that accepts two arguments:
1. A :class:`detectron2.config.CfgNode`
2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification.
1. A :class:`fastreid.config.CfgNode`
It must returns an instance of :class:`Backbone`.
"""

View File

@ -8,7 +8,8 @@ from ...utils.registry import Registry
REID_HEADS_REGISTRY = Registry("HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for ROI heads in a generalized R-CNN model.
Registry for reid heads in a baseline model.
ROIHeads take feature maps and region proposals, and
perform per-region computation.
The registered object will be called with `obj(cfg, input_shape)`.

View File

@ -4,7 +4,9 @@
@contact: sherlockliao01@gmail.com
"""
from .circle_loss import *
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss
from .circle_loss import *
__all__ = [k for k in globals().keys() if k not in k.startswith("_")]