19 lines
373 B
Python
19 lines
373 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def train_collate_fn(batch):
|
|
imgs, pids, _, _, = zip(*batch)
|
|
pids = torch.tensor(pids, dtype=torch.int64)
|
|
return torch.stack(imgs, dim=0), pids
|
|
|
|
|
|
def val_collate_fn(batch):
|
|
imgs, pids, camids, _ = zip(*batch)
|
|
return torch.stack(imgs, dim=0), pids, camids
|