mirror of https://github.com/JDAI-CV/fast-reid.git
update training instruction
Summary: update dataset configuration and training instructionpull/150/head
parent
727a746831
commit
8879db3fba
|
@ -12,3 +12,25 @@ Then you should set the pretrain model path in `configs/Base-bagtricks.yml`.
|
|||
```bash
|
||||
cd fastreid/evaluation/rank_cylib; make all
|
||||
```
|
||||
|
||||
## Training & Evaluation in Command Line
|
||||
|
||||
We provide a script in "tools/train_net.py", that is made to train all the configs provided in fastreid.
|
||||
You may want to use it as a reference to write your own training script.
|
||||
|
||||
To train a model with "train_net.py", first setup up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
|
||||
```
|
||||
|
||||
The configs are made for 1-GPU training.
|
||||
|
||||
To evaluate a model's performance, use
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml \
|
||||
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
|
||||
```
|
||||
|
||||
For more options, see `./train_net.py -h`.
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
Fastreid has buildin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. Under this directory, fastreid expects to find datasets in the structure described below.
|
||||
|
||||
You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `./datasets` relative to your current working directory.
|
||||
You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `datasets/` relative to your current working directory.
|
||||
|
||||
The model zoo contains configs and models that use these buildin datasets.
|
||||
The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these buildin datasets.
|
||||
|
||||
## Expected dataset structure for Market1501
|
||||
|
||||
|
@ -20,4 +20,24 @@ datasets/
|
|||
|
||||
## Expected dataset structure for DukeMTMC
|
||||
|
||||
1. Download datasets to `datasets/`
|
||||
2. Extract dataset. The dataset structure would like:
|
||||
|
||||
```bash
|
||||
datasets/
|
||||
DukeMTMC-reID/
|
||||
bounding_box_train/
|
||||
bounding_box_test/
|
||||
```
|
||||
|
||||
## Expected dataset structure for MSMT17
|
||||
|
||||
1. Download datasets to `datasets/`
|
||||
2. Extract dataset. The dataset structure would like:
|
||||
|
||||
```bash
|
||||
datasets/
|
||||
MSMT17_V2/
|
||||
mask_train_v2/
|
||||
mask_test_v2/
|
||||
```
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch._six import container_abcs, string_classes, int_classes
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -13,13 +14,15 @@ from .common import CommDataset
|
|||
from .datasets import DATASET_REGISTRY
|
||||
from .transforms import build_transforms
|
||||
|
||||
_root = os.getenv("FASTREID_DATASETS", "datasets")
|
||||
|
||||
|
||||
def build_reid_train_loader(cfg):
|
||||
train_transforms = build_transforms(cfg, is_train=True)
|
||||
|
||||
train_items = list()
|
||||
for d in cfg.DATASETS.NAMES:
|
||||
dataset = DATASET_REGISTRY.get(d)(combineall=cfg.DATASETS.COMBINEALL)
|
||||
dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
|
||||
dataset.show_train()
|
||||
train_items.extend(dataset.train)
|
||||
|
||||
|
@ -50,7 +53,7 @@ def build_reid_train_loader(cfg):
|
|||
def build_reid_test_loader(cfg, dataset_name):
|
||||
test_transforms = build_transforms(cfg, is_train=False)
|
||||
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)()
|
||||
dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
|
||||
dataset.show_test()
|
||||
test_items = dataset.query + dataset.gallery
|
||||
|
||||
|
|
|
@ -19,3 +19,5 @@ from .msmt17 import MSMT17
|
|||
from .veri import VeRi
|
||||
from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
|
||||
from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
|
||||
|
||||
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
|
||||
|
|
|
@ -5,11 +5,8 @@
|
|||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
class Dataset(object):
|
||||
|
|
|
@ -95,7 +95,7 @@ class Checkpointer(object):
|
|||
if not path:
|
||||
# no checkpoint provided
|
||||
self.logger.info(
|
||||
"No checkpoint found. Initializing model from scratch"
|
||||
"No checkpoint found. Training model from scratch"
|
||||
)
|
||||
return {}
|
||||
self.logger.info("Loading checkpoint from {}".format(path))
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
#!/usr/bin/env python
|
||||
# encoding: utf-8
|
||||
"""
|
||||
@author: sherlock
|
||||
|
|
Loading…
Reference in New Issue