Merge pull request #4 from AICradle/feature/safex-reid-training

CLean up scripts
pull/462/head
Joshua Newman 2021-07-30 10:21:41 -06:00 committed by GitHub
commit c087521d24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 18 additions and 54 deletions

1
.gitignore vendored
View File

@ -137,6 +137,7 @@ dmypy.json
# ReID
reid-data/
log/
saved-models/
model-zoo/

View File

View File

@ -3,11 +3,9 @@ import logging
import numpy as np
from argparse import ArgumentParser
from tqdm import tqdm
from torchreid.utils import FeatureExtractor
from torchreid.utils.tools import read_jsonl
logging.basicConfig(level=logging.INFO)
@ -28,12 +26,13 @@ class NumpyEncoder(json.JSONEncoder):
def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int):
extractor = FeatureExtractor(
model_name="resnet50",
model_name="resnet18",
model_path=model_fp,
device=device
)
manifest_entries = read_jsonl(in_fp)
# Todo (Josh) speed this up my batching
with open(out_fp, "w") as f:
for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):

View File

@ -17,34 +17,22 @@ logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
targets='market1501',
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
def main(data_dir, save_dir, model_fp):
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
sources="market1501",
targets="market1501",
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
split_id=1
transforms=["random_flip", "random_crop"]
)
model = torchreid.models.build_model(
name='resnet50',
name="resnet18",
num_classes=datamanager.num_train_pids,
loss='triplet',
loss="triplet",
pretrained=True
)
@ -52,13 +40,13 @@ def main(data_dir, save_dir):
optimizer = torchreid.optim.build_optimizer(
model,
optim='adam',
optim="adam",
lr=0.0003
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer,
lr_scheduler='single_step',
lr_scheduler="single_step",
stepsize=20
)
@ -71,7 +59,7 @@ def main(data_dir, save_dir):
)
start_epoch = torchreid.utils.resume_from_checkpoint(
'/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15',
model_fp,
model,
optimizer
)
@ -89,11 +77,13 @@ def main(data_dir, save_dir):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--model_fp", required=True)
parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/", required=False)
args = parser.parse_args()
main(
model_fp=args.model_path,
data_dir=args.data_dir,
save_dir=args.save_dir
)

View File

@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
root=data_dir,
sources='market1501',
targets='market1501',
height=256,
@ -29,19 +28,6 @@ def main(data_dir, save_dir):
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
split_id=0,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
)
model = torchreid.models.build_model(
name='osnet_ain_x1_0',
num_classes=datamanager.num_train_pids,

View File

@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
root=data_dir,
sources='market1501',
targets='market1501',
height=256,
@ -29,20 +28,9 @@ def main(data_dir, save_dir):
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
)
model = torchreid.models.build_model(
name='resnet50',
name='resnet18',
num_classes=datamanager.num_train_pids,
loss='triplet',
pretrained=True

View File

@ -40,7 +40,7 @@ class FeatureExtractor(object):
extractor = FeatureExtractor(
model_name='osnet_x1_0',
model_path='a/b/c/model.pth.tar',
model_fp='a/b/c/model.pth.tar',
device='cuda'
)