add assert for DistributedRandomIdentitySampler and refine docstrings for DistributedRandomIdentitySampler & PKSampler
parent da70ef9c95
commit 84870b771d
@@ -14,24 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
-from collections import defaultdict
-import numpy as np
+
 import copy
 import random
+from collections import defaultdict
 
+import numpy as np
 from paddle.io import DistributedBatchSampler, Sampler
 
 
 class DistributedRandomIdentitySampler(DistributedBatchSampler):
-    """
-    Randomly sample N identities, then for each identity,
-    randomly sample K instances, therefore batch size is N*K.
+    """Randomly sample N identities, then for each identity,
+    randomly sample K instances, therefore batch size equals to N * K.
     Args:
-    - data_source (list): list of (img_path, pid, camid).
-    - num_instances (int): number of instances per identity in a batch.
-    - batch_size (int): number of examples in a batch.
+        dataset(Dataset): Dataset which contains list of (img_path, pid, camid))
+        batch_size (int): batch size
+        num_instances (int): number of instance(s) within an class
+        drop_last (bool): whether to discard the data at the end
     """
 
     def __init__(self, dataset, batch_size, num_instances, drop_last, **args):
+        assert batch_size % num_instances == 0, \
+            f"batch_size({batch_size}) must be divisible by num_instances({num_instances}) when using DistributedRandomIdentitySampler"
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_instances = num_instances
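For context, the added assert simply guards the N * K relationship described in the refined docstring. Below is a minimal construction sketch, assuming a placeholder ReID dataset variable and the usual PaddleClas import path (neither is taken from this commit):

# Assumed import path; in PaddleClas the sampler is exposed from ppcls.data.dataloader.
from ppcls.data.dataloader import DistributedRandomIdentitySampler

# `reid_train_dataset` is a hypothetical placeholder for a Dataset whose samples
# are (img_path, pid, camid) tuples, as the refined docstring describes.
sampler = DistributedRandomIdentitySampler(
    dataset=reid_train_dataset,
    batch_size=64,      # N * K: 16 identities * 4 instances per identity
    num_instances=4,    # K: instances sampled per identity
    drop_last=True)

# With the new assert, a config such as batch_size=60, num_instances=8 now fails
# immediately at construction time (60 % 8 != 0) with an explicit error message.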
@@ -14,27 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
-from collections import defaultdict
-import numpy as np
-import random
-from paddle.io import DistributedBatchSampler
+
+from collections import defaultdict
+
+import numpy as np
+from paddle.io import DistributedBatchSampler
 from ppcls.utils import logger
 
 
 class PKSampler(DistributedBatchSampler):
-    """
-    First, randomly sample P identities.
-    Then for each identity randomly sample K instances.
-    Therefore batch size is P*K, and the sampler called PKSampler.
-    Args:
-        dataset (paddle.io.Dataset): list of (img_path, pid, cam_id).
-        sample_per_id(int): number of instances per identity in a batch.
-        batch_size (int): number of examples in a batch.
-        shuffle(bool): whether to shuffle indices order before generating
-            batch indices. Default False.
-    """
+    """First, randomly sample P identities.
+    Then for each identity randomly sample K instances.
+    Therefore batch size equals to P * K, and the sampler called PKSampler.
+
+    Args:
+        dataset (Dataset): Dataset which contains list of (img_path, pid, camid))
+        batch_size (_type_): batch size
+        sample_per_id (_type_): number of instance(s) within an class
+        shuffle (bool, optional): _description_. Defaults to True.
+        drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
+        sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+    """
     def __init__(self,
                  dataset,
                  batch_size,
@@ -42,10 +42,9 @@ class PKSampler(DistributedBatchSampler):
                  shuffle=True,
                  drop_last=True,
                  sample_method="sample_avg_prob"):
-        super().__init__(
-            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
-            "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
+            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
         assert hasattr(self.dataset,
                        "labels"), "Dataset must have labels attribute."
         self.sample_per_label = sample_per_id
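Likewise for PKSampler, a hedged configuration sketch under the tightened assert; the dataset variable is a placeholder and the module path is an assumption, since the commit only shows the class body:

# Assumed module path for the file touched in the second and third hunks.
from ppcls.data.dataloader.pk_sampler import PKSampler

# `train_dataset` is a hypothetical Dataset; per the second assert it must also
# expose a `labels` attribute so instances can be grouped by identity.
sampler = PKSampler(
    dataset=train_dataset,
    batch_size=64,                     # P * K: 8 identities * 8 instances each
    sample_per_id=8,                   # K: instances drawn per identity
    shuffle=True,
    drop_last=True,
    sample_method="sample_avg_prob")   # default value per the refined docstring

# A mismatched value such as sample_per_id=7 now produces the more informative message:
# "PKSampler configs error, sample_per_id(7) must be a divisor of batch_size(64)."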