Add BTD and LAL (#4)
parent d00d2f5822
commit c9406038a2
@@ -11,7 +11,7 @@ import torch
 from torch.utils.data import Sampler
 
 from .datasets import ImageNet, ImageNet22k, ImageShipID, ImageShipID_Extra, ImageShipID_20P, ImageShipID_40P, ImageShipID_60P, ImageShipID_80P, ImageShipOOD
-from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
+from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler, ShardedInfiniteBalancedSampler
 
 
 logger = logging.getLogger("dinov2")

@@ -23,6 +23,7 @@ class SamplerType(Enum):
     INFINITE = 2
     SHARDED_INFINITE = 3
     SHARDED_INFINITE_NEW = 4
+    SHARDED_INFINITE_BALANCED = 5
 
 
 def _make_bool_str(b: bool) -> str:

@@ -133,6 +134,7 @@ def _make_sampler(
     seed: int = 0,
     size: int = -1,
     advance: int = 0,
+    **kwargs,
 ) -> Optional[Sampler]:
     sample_count = len(dataset)
 

@@ -183,6 +185,17 @@ def _make_sampler(
             seed=seed,
             drop_last=False,
         )
+    elif type == SamplerType.SHARDED_INFINITE_BALANCED:
+        logger.info("sampler: sharded infinite balanced")
+        if size > 0:
+            raise ValueError("sampler size > 0 is invalid")
+        return ShardedInfiniteBalancedSampler(
+            labels=dataset.get_targets(),
+            mode=kwargs["balanced_sampler_mode"],
+            shuffle=shuffle,
+            seed=seed,
+            advance=advance,
+        )
 
     logger.info("sampler: none")
     return None

@@ -204,6 +217,7 @@ def make_data_loader(
     drop_last: bool = True,
     persistent_workers: bool = False,
     collate_fn: Optional[Callable[[List[T]], Any]] = None,
+    **kwargs,
 ):
     """
     Creates a data loader with the specified parameters.

@@ -229,6 +243,7 @@ def make_data_loader(
         seed=seed,
         size=sampler_size,
         advance=sampler_advance,
+        **kwargs,
     )
 
     logger.info("using PyTorch data loader")

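The loaders hunks above thread arbitrary keyword arguments from `make_data_loader` down to `_make_sampler`, which is how `balanced_sampler_mode` reaches the new sampler branch. A minimal usage sketch (not part of the patch), assuming the dinov2-style data API (`make_dataset`, `make_data_loader`, `SamplerType`) and a dataset that exposes `get_targets()`; the dataset string and sizes are placeholders:

```python
# Sketch: request the balanced sampler through the kwargs that make_data_loader
# now forwards to _make_sampler (dataset string and batch size are illustrative).
from dinov2.data import SamplerType, make_data_loader, make_dataset

train_dataset = make_dataset(dataset_str="ImageShipID:split=TRAIN", transform=None)

train_loader = make_data_loader(
    dataset=train_dataset,
    batch_size=128,
    num_workers=8,
    shuffle=True,
    seed=0,
    sampler_type=SamplerType.SHARDED_INFINITE_BALANCED,
    balanced_sampler_mode="downsampling",  # consumed by ShardedInfiniteBalancedSampler(mode=...)
)
```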
@@ -4,7 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Any, Optional
+from typing import Any, Optional, List, Union, Iterator
 import warnings
 
 import numpy as np

@@ -227,3 +227,147 @@ class ShardedInfiniteSampler(Sampler):
             )
             yield from iterable
             self._iter_count += 1
+
+
+class ShardedInfiniteBalancedSampler(Sampler):
+    def __init__(
+        self,
+        labels: List[int],
+        mode: Union[str, int] = "downsampling",
+        shuffle: bool = False,
+        seed: int = 0,
+        start: Optional[int] = None,
+        step: Optional[int] = None,
+        advance: int = 0,
+        use_new_shuffle_tensor_slice: bool = False,
+    ):
+        """
+        A sharded infinite sampler that can optionally perform balanced (stratified) class sampling.
+
+        Args:
+            labels: List of class labels for the dataset.
+            mode: "downsampling", "upsampling", or an integer specifying the number of samples per class per cycle.
+                - "downsampling": each class is sampled using the count equal to the minimum available samples.
+                - "upsampling": each class is sampled using the count equal to the maximum available samples.
+            shuffle: Whether to shuffle the balanced samples each cycle.
+            seed: Random seed for shuffling.
+            start: Shard start index (defaults to global rank).
+            step: Shard step (defaults to global size).
+            advance: Number of initial samples to skip.
+            use_new_shuffle_tensor_slice: Switch to the alternate shuffle slice function.
+        """
+        super().__init__(labels)
+        self._labels = np.array(labels)
+        self._seed = seed
+        self._shuffle = shuffle
+
+        # Sharding info.
+        self._start = distributed.get_global_rank() if start is None else start
+        self._step = distributed.get_global_size() if step is None else step
+        self._advance = advance
+
+        # Choose the slice function.
+        self._shuffle_tensor_slice_fn = (
+            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
+        )
+
+        # Cycle count.
+        self._iter_count = 0
+
+        # Balanced sampling configuration.
+        self._unique_labels = np.unique(self._labels)
+        self._lbl2idx = {lbl: np.where(self._labels == lbl)[0] for lbl in self._unique_labels}
+        self._sorted_labels = sorted(self._unique_labels)
+
+        # Determine samples per class.
+        if isinstance(mode, str):
+            if mode == "downsampling":
+                self._samples_per_class = min(len(idxs) for idxs in self._lbl2idx.values())
+            elif mode == "upsampling":
+                self._samples_per_class = max(len(idxs) for idxs in self._lbl2idx.values())
+            else:
+                raise ValueError(f"mode='{mode}' must be 'downsampling', 'upsampling', or an integer.")
+        elif isinstance(mode, int):
+            self._samples_per_class = mode
+        else:
+            raise ValueError(f"mode must be str or int, got {type(mode)}.")
+
+        # The size of one balanced "cycle" (nominal epoch length).
+        self._sample_count = self._samples_per_class * len(self._sorted_labels)
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Main entry point for iteration.
+        - Possibly skips entire cycles if 'advance' is large.
+        - Then yields from either _iterator() or _shuffled_iterator(),
+          slicing off the first 'advance' items.
+        """
+        # Skip entire cycles if _advance is large.
+        iter_count = self._advance // self._sample_count
+        if iter_count > 0:
+            self._advance -= iter_count * self._sample_count
+            self._iter_count += iter_count
+
+        if self._shuffle:
+            iterator = self._shuffled_iterator()
+        else:
+            iterator = self._iterator()
+
+        yield from itertools.islice(iterator, self._advance, None)
+
+    def _iterator(self) -> Iterator[int]:
+        """
+        Non-shuffling infinite iterator.
+        Builds a balanced set each cycle (down/up sample) without final shuffling, then shards it.
+        """
+        rng = np.random.default_rng(self._seed)
+        while True:
+            indices = []
+            for lbl in self._sorted_labels:
+                idxs = self._lbl2idx[lbl]
+                replace = (self._samples_per_class > len(idxs))
+                chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
+                indices.extend(chosen)
+            # Shard for this rank.
+            for idx in indices[self._start::self._step]:
+                yield int(idx)
+            self._iter_count += 1
+
+    def _shuffled_iterator(self) -> Iterator[int]:
+        """
+        Shuffling infinite iterator.
+        Each cycle picks new balanced subsets from each class, shuffles them,
+        and shards them.
+        """
+        # Torch generator for slicing.
+        generator = torch.Generator()
+        while True:
+            seed = _make_seed(self._seed, self._start, self._iter_count)
+            generator.manual_seed(seed)
+            # Create a single NumPy RNG for the entire cycle.
+            rng = np.random.RandomState(seed)
+            indices_np = []
+            for lbl in self._sorted_labels:
+                idxs = self._lbl2idx[lbl]
+                replace = (self._samples_per_class > len(idxs))
+                chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
+                indices_np.extend(chosen)
+            # Convert to torch tensor.
+            cycle_size = len(indices_np)
+            dtype = _get_torch_dtype(cycle_size)
+            indices_tensor = torch.tensor(indices_np, dtype=dtype)
+            # Final shuffle and shard using the slice function.
+            iterable = self._shuffle_tensor_slice_fn(
+                tensor=indices_tensor,
+                start=self._start,
+                step=self._step,
+                generator=generator
+            )
+            yield from iterable
+            self._iter_count += 1
+
+    def __len__(self) -> int:
+        """
+        Returns the nominal size of one 'cycle': samples_per_class * number of classes.
+        """
+        return self._sample_count

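A standalone sketch (not part of the patch) of what `ShardedInfiniteBalancedSampler` yields for a toy, imbalanced label list. `start`/`step` are passed explicitly so the example does not depend on `torch.distributed` being initialised; the import path follows the `dinov2/data/samplers.py` layout suggested by the hunk above:

```python
import itertools
from collections import Counter

from dinov2.data.samplers import ShardedInfiniteBalancedSampler

labels = [0] * 8 + [1] * 3 + [2] * 5  # three classes, heavily imbalanced

sampler = ShardedInfiniteBalancedSampler(
    labels=labels,
    mode="downsampling",  # 3 samples per class per cycle -> len(sampler) == 9
    shuffle=True,
    seed=0,
    start=0,  # single shard
    step=1,
)

# One balanced cycle: exactly 3 indices per class, in shuffled order.
one_cycle = list(itertools.islice(iter(sampler), len(sampler)))
print(Counter(labels[i] for i in one_cycle))  # Counter({0: 3, 1: 3, 2: 3})
```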
@@ -4,6 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.
 
 import argparse
+from collections import defaultdict
 from functools import partial
 import json
 import logging

@@ -14,6 +15,7 @@ from typing import List, Optional
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
 

@@ -310,6 +312,29 @@ def evaluate_linear_classifiers(
     return results_dict
 
 
+def get_cls_num_list(labels):
+    counter = defaultdict(int)
+    for label in labels:
+        counter[label] += 1
+    labels = list(counter.keys())
+    labels.sort()
+    cls_num_list = [counter[label] for label in labels]
+    return cls_num_list
+
+
+class LogitAdjustedLoss(nn.Module):
+    def __init__(self, cls_num_list, tau=1.0):
+        super().__init__()
+        cls_num_ratio = cls_num_list / torch.sum(cls_num_list)
+        log_cls_num = torch.log(cls_num_ratio)
+        self.log_cls_num = log_cls_num
+        self.tau = tau
+
+    def forward(self, logit, target):
+        logit_adjusted = logit + self.tau * self.log_cls_num.unsqueeze(0)
+        return F.cross_entropy(logit_adjusted, target)
+
+
 def eval_linear(
     *,
     feature_model,

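`LogitAdjustedLoss` is a logit-adjusted cross-entropy for long-tailed data: it adds `tau * log(class prior)` to the logits before the standard loss, so over-predicting frequent classes is penalised and rare classes receive a larger gradient push. A small self-contained sketch of the effect (not part of the patch), with made-up class counts:

```python
import torch
import torch.nn.functional as F

cls_num_list = torch.tensor([900.0, 90.0, 10.0])          # toy long-tailed class counts
log_prior = torch.log(cls_num_list / cls_num_list.sum())  # log of [0.9, 0.09, 0.01]

logits = torch.tensor([[2.0, 2.0, 2.0]])  # the model is indifferent between classes
target = torch.tensor([2])                # the true class is the rarest one

plain = F.cross_entropy(logits, target)                                    # ~1.10
adjusted = F.cross_entropy(logits + 1.0 * log_prior.unsqueeze(0), target)  # ~4.61
print(plain.item(), adjusted.item())  # the adjusted loss is larger for the rare class
```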
@@ -329,6 +354,7 @@ def eval_linear(
     resume=True,
     classifier_fpath=None,
     val_class_mapping=None,
+    **kwargs,
 ):
     checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
     start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1

@@ -339,6 +365,13 @@ def eval_linear(
     metric_logger = MetricLogger(delimiter="  ")
     header = "Training"
 
+    if kwargs.get('logit_adjusted_loss', False):
+        cls_num_list = get_cls_num_list(train_data_loader.dataset.get_targets())
+        cls_num_list = torch.Tensor(cls_num_list).to("cuda")
+        train_loss_fn = LogitAdjustedLoss(cls_num_list)
+    else:
+        train_loss_fn = nn.CrossEntropyLoss()
+
     for data, labels in metric_logger.log_every(
         train_data_loader,
         10,

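A tiny sketch (not part of the patch) of the data flow wired up above: `get_cls_num_list` turns the training targets into per-class counts, which parameterise `LogitAdjustedLoss`. The toy targets and the `dinov2.eval.linear` import path are assumptions based on the surrounding hunks:

```python
import torch

from dinov2.eval.linear import LogitAdjustedLoss, get_cls_num_list

targets = [0, 0, 0, 1, 2, 2]              # stand-in for dataset.get_targets()
cls_num_list = get_cls_num_list(targets)  # -> [3, 1, 2], ordered by class id
train_loss_fn = LogitAdjustedLoss(torch.Tensor(cls_num_list))
```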
@@ -352,7 +385,7 @@ def eval_linear(
         features = feature_model(data)
         outputs = linear_classifiers(features)
 
-        losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
+        losses = {f"loss_{k}": train_loss_fn(v, labels) for k, v in outputs.items()}
         loss = sum(losses.values())
 
         # compute the gradients

@@ -480,6 +513,7 @@ def run_eval_linear(
     test_class_mapping_fpaths=[None],
     val_metric_type=MetricType.MEAN_ACCURACY,
     test_metric_types=None,
+    **kwargs,
 ):
     seed = 0
 

@@ -497,8 +531,13 @@ def run_eval_linear(
         transform=train_transform,
     )
     training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
-    sampler_type = SamplerType.SHARDED_INFINITE
-    # sampler_type = SamplerType.INFINITE
+    if kwargs.get('balanced_sampler', False):
+        sampler_type = SamplerType.SHARDED_INFINITE_BALANCED
+        balanced_sampler_mode = kwargs['balanced_sampler_mode']
+    else:
+        balanced_sampler_mode = None
+        sampler_type = SamplerType.SHARDED_INFINITE
+        # sampler_type = SamplerType.INFINITE
 
     n_last_blocks_list = [1, 4]
     n_last_blocks = max(n_last_blocks_list)

@@ -529,6 +568,7 @@ def run_eval_linear(
         sampler_advance=start_iter,
         drop_last=True,
         persistent_workers=True,
+        balanced_sampler_mode=balanced_sampler_mode,
     )
     val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
 

@@ -568,6 +608,7 @@ def run_eval_linear(
         resume=resume,
         val_class_mapping=val_class_mapping,
         classifier_fpath=classifier_fpath,
+        **kwargs,
     )
     results_dict = {}
     if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:

@@ -614,6 +655,9 @@ def main(args):
         test_metric_types=args.test_metric_types,
         val_class_mapping_fpath=args.val_class_mapping_fpath,
         test_class_mapping_fpaths=args.test_class_mapping_fpaths,
+        logit_adjusted_loss=args.logit_adjusted_loss,
+        balanced_sampler=args.balanced_sampler,
+        balanced_sampler_mode=args.balanced_sampler_mode,
     )
     return 0
 
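`main()` above forwards three argparse fields that are not defined anywhere in this diff. A hypothetical sketch of how they might be registered; the flag names and defaults are assumptions, only the attribute names (`logit_adjusted_loss`, `balanced_sampler`, `balanced_sampler_mode`) are implied by the call:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--logit-adjusted-loss", dest="logit_adjusted_loss", action="store_true")
parser.add_argument("--balanced-sampler", dest="balanced_sampler", action="store_true")
parser.add_argument(
    "--balanced-sampler-mode",
    dest="balanced_sampler_mode",
    default="downsampling",
    help="'downsampling', 'upsampling', or an integer count per class",
)
```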