mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
91 lines
2.0 KiB
Python
91 lines
2.0 KiB
Python
|
from argparse import ArgumentParser
|
||
|
|
||
|
import torchreid
|
||
|
|
||
|
import json
|
||
|
import logging
|
||
|
|
||
|
|
||
|
def read_jsonl(in_file_path):
    """Lazily yield one decoded JSON value per non-blank line of a file.

    Args:
        in_file_path: Path of the JSON Lines file to read.

    Yields:
        The value returned by ``json.loads`` for each non-blank line, in
        file order.

    Raises:
        OSError: If the file cannot be opened. Because this is a generator,
            the error surfaces on the first iteration, not at call time.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the locale's default encoding.
    with open(in_file_path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a stray newline at the end of the
            # file) instead of failing inside json.loads.
            if line.strip():
                yield json.loads(line)
|
||
|
|
||
|
|
||
|
# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(__name__)

# Configure the root logger so that DEBUG-and-above records are emitted.
logging.basicConfig(level=logging.DEBUG)
|
||
|
|
||
|
|
||
|
def main(data_dir, save_dir):
    """Train an ``osnet_ain_x1_0`` re-identification model with triplet loss.

    Builds the torchreid data manager, model, optimizer, learning-rate
    scheduler and triplet engine with fixed hyper-parameters, then runs
    training with periodic evaluation. The model is kept on the CPU.

    Args:
        data_dir: Root directory handed to the data manager as ``root``.
        save_dir: Output directory handed to ``engine.run`` as ``save_dir``.
    """
    # The model is placed on the CPU below, so the data manager and the
    # engine are told not to use the GPU either. Left at their defaults
    # (use_gpu=True) the engine would, on a CUDA-capable machine, move each
    # batch to the GPU while the model's weights stay on the CPU.
    use_gpu = False

    # NOTE(review): 'safex_carla_simulation' is not registered anywhere in
    # this file; presumably it is registered with torchreid elsewhere —
    # confirm.
    # NOTE(review): no train_sampler is given, so the data manager's default
    # sampler is used; torchreid's triplet-loss examples pass
    # train_sampler='RandomIdentitySampler' — confirm this is intended.
    datamanager = torchreid.data.ImageDataManager(
        root=data_dir,
        sources=['safex_carla_simulation'],
        height=50,
        width=50,
        batch_size_train=32,
        batch_size_test=100,
        transforms=['random_flip', 'color_jitter'],
        use_gpu=use_gpu,
    )

    # The classifier size follows the number of training identities that the
    # data manager reports.
    model = torchreid.models.build_model(
        name='osnet_ain_x1_0',
        num_classes=datamanager.num_train_pids,
        loss='triplet',
        pretrained=False,
    )

    model = model.cpu()

    optimizer = torchreid.optim.build_optimizer(
        model,
        optim='adam',
        lr=0.0003,
    )

    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20,
    )

    engine = torchreid.engine.ImageTripletEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        label_smooth=True,
        use_gpu=use_gpu,
    )

    engine.run(
        save_dir=save_dir,
        max_epoch=60,
        eval_freq=10,
        print_freq=10,
        test_only=False,
    )
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: both directories are optional and share the
    # same default location.
    cli = ArgumentParser()
    for flag in ("--data_dir", "--save_dir"):
        cli.add_argument(flag, default="./reid_out/", required=False)

    options = cli.parse_args()
    main(data_dir=options.data_dir, save_dir=options.save_dir)
|