mirror of https://github.com/JDAI-CV/fast-reid.git

update training instruction

Summary: update dataset configuration and training instruction (pull/150/head)

parent 727a746831
commit 8879db3fba

@@ -12,3 +12,25 @@ Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.

```bash
cd fastreid/evaluation/rank_cylib; make all
```

## Training & Evaluation in Command Line

We provide a script, `tools/train_net.py`, that can train all the configs provided in fastreid.
You may want to use it as a reference to write your own training script.

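As a rough illustration of what such a script looks like, here is a minimal, hedged sketch modeled on the structure of `tools/train_net.py`. The `get_cfg`, `DefaultTrainer`, `default_argument_parser`, and `default_setup` names are assumed from fastreid's detectron2-style engine; check `tools/train_net.py` itself for the exact imports and any extra hooks.

```python
# Rough sketch of a standalone training script (names assumed from the
# detectron2-style engine that tools/train_net.py follows).
import sys

sys.path.append(".")  # make the fastreid package importable from the repo root

from fastreid.config import get_cfg
from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
from fastreid.utils.checkpoint import Checkpointer


def setup(args):
    # Merge the YAML config and any "KEY VALUE" overrides given on the command line.
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


def main(args):
    cfg = setup(args)
    if args.eval_only:
        model = DefaultTrainer.build_model(cfg)
        Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # weights passed via --eval-only MODEL.WEIGHTS ...
        return DefaultTrainer.test(cfg, model)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    main(default_argument_parser().parse_args())
```
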
To train a model with `train_net.py`, first set up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:

```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
```

The configs are made for 1-GPU training.

To evaluate a model's performance, use

```bash
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml \
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
```

For more options, see `./train_net.py -h`.

@@ -2,9 +2,9 @@

Fastreid has builtin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. Under this directory, fastreid expects to find datasets in the structure described below.

-You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `./datasets` relative to your current working directory.
+You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `datasets/` relative to your current working directory.

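For reference, the loader changes later in this commit resolve the variable with a plain `os.getenv` fallback, so the rule stated above boils down to:

```python
import os

# Same lookup that this commit adds to fastreid/data/build.py:
# use FASTREID_DATASETS when it is set, otherwise fall back to the
# "datasets" directory relative to the current working directory.
_root = os.getenv("FASTREID_DATASETS", "datasets")
```
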
-The model zoo contains configs and models that use these builtin datasets.
+The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these builtin datasets.

## Expected dataset structure for Market1501

@@ -20,4 +20,24 @@ datasets/

## Expected dataset structure for DukeMTMC

1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure should look like:

```bash
datasets/
    DukeMTMC-reID/
        bounding_box_train/
        bounding_box_test/
```

## Expected dataset structure for MSMT17

1. Download datasets to `datasets/`
2. Extract dataset. The dataset structure should look like:

```bash
datasets/
    MSMT17_V2/
        mask_train_v2/
        mask_test_v2/
```

@@ -4,6 +4,7 @@

@contact: sherlockliao01@gmail.com
"""

+import os
import torch
from torch._six import container_abcs, string_classes, int_classes
from torch.utils.data import DataLoader

@@ -13,13 +14,15 @@ from .common import CommDataset

from .datasets import DATASET_REGISTRY
from .transforms import build_transforms

+_root = os.getenv("FASTREID_DATASETS", "datasets")


def build_reid_train_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)

    train_items = list()
    for d in cfg.DATASETS.NAMES:
-        dataset = DATASET_REGISTRY.get(d)(combineall=cfg.DATASETS.COMBINEALL)
+        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        dataset.show_train()
        train_items.extend(dataset.train)

@@ -50,7 +53,7 @@ def build_reid_train_loader(cfg):

def build_reid_test_loader(cfg, dataset_name):
    test_transforms = build_transforms(cfg, is_train=False)

-    dataset = DATASET_REGISTRY.get(dataset_name)()
+    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
    dataset.show_test()
    test_items = dataset.query + dataset.gallery

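The effect of the `root=_root` change above is that every registered dataset is now constructed against the `FASTREID_DATASETS` directory. Below is a hedged usage sketch; the `"Market1501"` registry key is assumed from the config names used in the repo, and `combineall=False` simply mirrors the keyword seen in `build_reid_train_loader`.

```python
import os

from fastreid.data.datasets import DATASET_REGISTRY

# Resolve the dataset root the same way build.py does.
_root = os.getenv("FASTREID_DATASETS", "datasets")

# Look up a dataset class by its registered name and build it against _root,
# exactly as build_reid_train_loader / build_reid_test_loader now do.
dataset = DATASET_REGISTRY.get("Market1501")(root=_root, combineall=False)
dataset.show_train()            # print a summary of the training split
train_items = dataset.train     # samples later wrapped in CommDataset
test_items = dataset.query + dataset.gallery
```
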
@@ -19,3 +19,5 @@ from .msmt17 import MSMT17

from .veri import VeRi
from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]

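Since `__all__` now exposes everything the registry imports, adding your own dataset is mostly a matter of registering a class alongside these. A hedged sketch follows; the `ImageDataset` base class and the `(img_path, pid, camid)` sample format are assumptions based on how the bundled datasets appear to be written, not something this diff confirms.

```python
import os.path as osp
from glob import glob

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset  # assumed base class


@DATASET_REGISTRY.register()
class MyDataset(ImageDataset):
    """Toy dataset: every image is named '<pid>_<camid>_<idx>.jpg'."""

    dataset_dir = "my_dataset"

    def __init__(self, root="datasets", **kwargs):
        self.root = root
        data_dir = osp.join(self.root, self.dataset_dir)

        def parse(subdir):
            # Collect (img_path, person id, camera id) tuples for one split.
            items = []
            for img_path in glob(osp.join(data_dir, subdir, "*.jpg")):
                pid, camid, _ = osp.basename(img_path).split("_")
                items.append((img_path, int(pid), int(camid)))
            return items

        train = parse("train")
        query = parse("query")
        gallery = parse("gallery")
        super().__init__(train, query, gallery, **kwargs)
```
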
@@ -5,11 +5,8 @@

"""

import copy
-import os
-
-import numpy as np
-import torch
import logging
+import os


class Dataset(object):

@@ -95,7 +95,7 @@ class Checkpointer(object):

        if not path:
            # no checkpoint provided
            self.logger.info(
-                "No checkpoint found. Initializing model from scratch"
+                "No checkpoint found. Training model from scratch"
            )
            return {}
        self.logger.info("Loading checkpoint from {}".format(path))

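The reworded log message above is what appears when no checkpoint path is configured. For reference, a minimal hedged sketch of driving `Checkpointer.load` directly; the constructor call is assumed, and in practice the model would come from `DefaultTrainer.build_model(cfg)` rather than a stand-in module.

```python
import torch.nn as nn

from fastreid.utils.checkpoint import Checkpointer

# Stand-in model; in a real run this would be the model built from the config.
model = nn.Linear(8, 8)
checkpointer = Checkpointer(model)

# A non-empty path logs "Loading checkpoint from <path>" and restores weights.
# An empty path hits the branch shown above: it logs
# "No checkpoint found. Training model from scratch" and returns {}.
checkpoint = checkpointer.load("")
```
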
@@ -1,3 +1,4 @@

+#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock