add assert for DistributedRandomIdentitySampler and refine docstrings for DistributedRandomIdentitySampler & PKSampler

pull/2213/head
HydrogenSulfate 2022-08-18 20:16:55 +08:00
parent da70ef9c95
commit 84870b771d
2 changed files with 29 additions and 27 deletions

@@ -14,24 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
-from collections import defaultdict
-import numpy as np
 import copy
 import random
+from collections import defaultdict
+import numpy as np
 from paddle.io import DistributedBatchSampler, Sampler


 class DistributedRandomIdentitySampler(DistributedBatchSampler):
-    """
-    Randomly sample N identities, then for each identity,
-    randomly sample K instances, therefore batch size is N*K.
+    """Randomly sample N identities, then for each identity,
+    randomly sample K instances, therefore batch size equals N * K.
     Args:
-    - data_source (list): list of (img_path, pid, camid).
-    - num_instances (int): number of instances per identity in a batch.
-    - batch_size (int): number of examples in a batch.
+        dataset (Dataset): Dataset which contains list of (img_path, pid, camid).
+        batch_size (int): batch size.
+        num_instances (int): number of instances within a class.
+        drop_last (bool): whether to discard the data at the end.
     """

     def __init__(self, dataset, batch_size, num_instances, drop_last, **args):
+        assert batch_size % num_instances == 0, \
+            f"batch_size({batch_size}) must be divisible by num_instances({num_instances}) when using DistributedRandomIdentitySampler"
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_instances = num_instances
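The new assert enforces the sampler's core contract up front: each batch holds batch_size / num_instances identities with num_instances examples apiece, so the division must be exact. A minimal usage sketch under that constraint; my_reid_dataset is a hypothetical dataset of (img_path, pid, camid) records:

    from paddle.io import DataLoader

    # 64 = 16 identities x 4 instances each; 64 % 4 == 0, so the assert passes.
    # num_instances=5 would now fail fast with the new f-string message.
    sampler = DistributedRandomIdentitySampler(
        dataset=my_reid_dataset,  # hypothetical (img_path, pid, camid) dataset
        batch_size=64,
        num_instances=4,
        drop_last=True)
    loader = DataLoader(my_reid_dataset, batch_sampler=sampler)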

@@ -14,27 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
-from collections import defaultdict
-import numpy as np
 import random
-from paddle.io import DistributedBatchSampler
+from collections import defaultdict
+import numpy as np
+from paddle.io import DistributedBatchSampler
 from ppcls.utils import logger


 class PKSampler(DistributedBatchSampler):
-    """
-    First, randomly sample P identities.
-    Then for each identity randomly sample K instances.
-    Therefore batch size is P*K, and the sampler called PKSampler.
-    Args:
-        dataset (paddle.io.Dataset): list of (img_path, pid, cam_id).
-        sample_per_id(int): number of instances per identity in a batch.
-        batch_size (int): number of examples in a batch.
-        shuffle(bool): whether to shuffle indices order before generating
-            batch indices. Default False.
-    """
+    """First, randomly sample P identities.
+    Then for each identity randomly sample K instances.
+    Therefore batch size equals P * K, and the sampler is called PKSampler.
+    Args:
+        dataset (Dataset): Dataset which contains list of (img_path, pid, camid).
+        batch_size (int): batch size.
+        sample_per_id (int): number of instances within a class.
+        shuffle (bool, optional): whether to shuffle indices order before generating batch indices. Defaults to True.
+        drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
+        sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+    """

     def __init__(self,
                  dataset,
                  batch_size,
@@ -42,10 +42,9 @@ class PKSampler(DistributedBatchSampler):
                  shuffle=True,
                  drop_last=True,
                  sample_method="sample_avg_prob"):
-        super().__init__(
-            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
-            "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
+            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
         assert hasattr(self.dataset,
                        "labels"), "Dataset must have labels attribute."
         self.sample_per_label = sample_per_id
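Analogously for PKSampler, a sketch of a call that satisfies both asserts: batch_size divisible by sample_per_id, and a dataset exposing the labels attribute the second assert checks for (my_reid_dataset is again hypothetical):

    # P * K batch: 64 = 16 identities (P) x 4 instances per identity (K).
    sampler = PKSampler(
        dataset=my_reid_dataset,  # hypothetical; must provide a .labels attribute
        batch_size=64,
        sample_per_id=4,
        shuffle=True,
        drop_last=True,
        sample_method="sample_avg_prob")  # default probability-list method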