commit
c087521d24
|
@ -137,6 +137,7 @@ dmypy.json
|
|||
|
||||
# ReID
|
||||
reid-data/
|
||||
|
||||
log/
|
||||
saved-models/
|
||||
model-zoo/
|
||||
|
|
|
@ -3,11 +3,9 @@ import logging
|
|||
import numpy as np
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from torchreid.utils import FeatureExtractor
|
||||
|
||||
from torchreid.utils.tools import read_jsonl
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -28,12 +26,13 @@ class NumpyEncoder(json.JSONEncoder):
|
|||
|
||||
def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int):
|
||||
extractor = FeatureExtractor(
|
||||
model_name="resnet50",
|
||||
model_name="resnet18",
|
||||
model_path=model_fp,
|
||||
device=device
|
||||
)
|
||||
|
||||
manifest_entries = read_jsonl(in_fp)
|
||||
# Todo (Josh) speed this up my batching
|
||||
with open(out_fp, "w") as f:
|
||||
for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):
|
||||
|
||||
|
|
|
@ -17,34 +17,22 @@ logging.basicConfig(level=logging.DEBUG)
|
|||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(data_dir, save_dir):
|
||||
"""
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
sources='market1501',
|
||||
targets='market1501',
|
||||
height=256,
|
||||
width=128,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip', 'random_crop']
|
||||
)
|
||||
"""
|
||||
def main(data_dir, save_dir, model_fp):
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root=data_dir,
|
||||
sources=['safex_carla_simulation'],
|
||||
sources="market1501",
|
||||
targets="market1501",
|
||||
height=256,
|
||||
width=128,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip'],
|
||||
split_id=1
|
||||
transforms=["random_flip", "random_crop"]
|
||||
)
|
||||
|
||||
model = torchreid.models.build_model(
|
||||
name='resnet50',
|
||||
name="resnet18",
|
||||
num_classes=datamanager.num_train_pids,
|
||||
loss='triplet',
|
||||
loss="triplet",
|
||||
pretrained=True
|
||||
)
|
||||
|
||||
|
@ -52,13 +40,13 @@ def main(data_dir, save_dir):
|
|||
|
||||
optimizer = torchreid.optim.build_optimizer(
|
||||
model,
|
||||
optim='adam',
|
||||
optim="adam",
|
||||
lr=0.0003
|
||||
)
|
||||
|
||||
scheduler = torchreid.optim.build_lr_scheduler(
|
||||
optimizer,
|
||||
lr_scheduler='single_step',
|
||||
lr_scheduler="single_step",
|
||||
stepsize=20
|
||||
)
|
||||
|
||||
|
@ -71,7 +59,7 @@ def main(data_dir, save_dir):
|
|||
)
|
||||
|
||||
start_epoch = torchreid.utils.resume_from_checkpoint(
|
||||
'/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15',
|
||||
model_fp,
|
||||
model,
|
||||
optimizer
|
||||
)
|
||||
|
@ -89,11 +77,13 @@ def main(data_dir, save_dir):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--model_fp", required=True)
|
||||
parser.add_argument("--data_dir", default="./reid_out/", required=False)
|
||||
parser.add_argument("--save_dir", default="./reid_out/", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(
|
||||
model_fp=args.model_path,
|
||||
data_dir=args.data_dir,
|
||||
save_dir=args.save_dir
|
||||
)
|
||||
|
|
|
@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def main(data_dir, save_dir):
|
||||
"""
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
root=data_dir,
|
||||
sources='market1501',
|
||||
targets='market1501',
|
||||
height=256,
|
||||
|
@ -29,19 +28,6 @@ def main(data_dir, save_dir):
|
|||
batch_size_test=100,
|
||||
transforms=['random_flip', 'random_crop']
|
||||
)
|
||||
"""
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root=data_dir,
|
||||
sources=['safex_carla_simulation'],
|
||||
height=256,
|
||||
width=128,
|
||||
split_id=0,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip'],
|
||||
|
||||
)
|
||||
|
||||
model = torchreid.models.build_model(
|
||||
name='osnet_ain_x1_0',
|
||||
num_classes=datamanager.num_train_pids,
|
||||
|
|
|
@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def main(data_dir, save_dir):
|
||||
"""
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
root=data_dir,
|
||||
sources='market1501',
|
||||
targets='market1501',
|
||||
height=256,
|
||||
|
@ -29,20 +28,9 @@ def main(data_dir, save_dir):
|
|||
batch_size_test=100,
|
||||
transforms=['random_flip', 'random_crop']
|
||||
)
|
||||
"""
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root=data_dir,
|
||||
sources=['safex_carla_simulation'],
|
||||
height=256,
|
||||
width=128,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip'],
|
||||
|
||||
)
|
||||
|
||||
model = torchreid.models.build_model(
|
||||
name='resnet50',
|
||||
name='resnet18',
|
||||
num_classes=datamanager.num_train_pids,
|
||||
loss='triplet',
|
||||
pretrained=True
|
||||
|
|
|
@ -40,7 +40,7 @@ class FeatureExtractor(object):
|
|||
|
||||
extractor = FeatureExtractor(
|
||||
model_name='osnet_x1_0',
|
||||
model_path='a/b/c/model.pth.tar',
|
||||
model_fp='a/b/c/model.pth.tar',
|
||||
device='cuda'
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue