Add BTD and LAL (#4)

pull/509/head
Han Chong 2025-02-06 09:39:35 +08:00 committed by GitHub
parent d00d2f5822
commit c9406038a2
3 changed files with 208 additions and 5 deletions

View File

@ -11,7 +11,7 @@ import torch
from torch.utils.data import Sampler
from .datasets import ImageNet, ImageNet22k, ImageShipID, ImageShipID_Extra, ImageShipID_20P, ImageShipID_40P, ImageShipID_60P, ImageShipID_80P, ImageShipOOD
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler, ShardedInfiniteBalancedSampler
logger = logging.getLogger("dinov2")
@ -23,6 +23,7 @@ class SamplerType(Enum):
INFINITE = 2
SHARDED_INFINITE = 3
SHARDED_INFINITE_NEW = 4
SHARDED_INFINITE_BALANCED = 5
def _make_bool_str(b: bool) -> str:
@ -133,6 +134,7 @@ def _make_sampler(
seed: int = 0,
size: int = -1,
advance: int = 0,
**kwargs,
) -> Optional[Sampler]:
sample_count = len(dataset)
@ -183,6 +185,17 @@ def _make_sampler(
seed=seed,
drop_last=False,
)
elif type == SamplerType.SHARDED_INFINITE_BALANCED:
logger.info("sampler: sharded infinite balanced")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
return ShardedInfiniteBalancedSampler(
labels=dataset.get_targets(),
mode=kwargs["balanced_sampler_mode"],
shuffle=shuffle,
seed=seed,
advance=advance,
)
logger.info("sampler: none")
return None
@ -204,6 +217,7 @@ def make_data_loader(
drop_last: bool = True,
persistent_workers: bool = False,
collate_fn: Optional[Callable[[List[T]], Any]] = None,
**kwargs,
):
"""
Creates a data loader with the specified parameters.
@ -229,6 +243,7 @@ def make_data_loader(
seed=seed,
size=sampler_size,
advance=sampler_advance,
**kwargs,
)
logger.info("using PyTorch data loader")

View File

@ -4,7 +4,7 @@
# found in the LICENSE file in the root directory of this source tree.
import itertools
from typing import Any, Optional
from typing import Any, Optional, List, Union, Iterator
import warnings
import numpy as np
@ -227,3 +227,147 @@ class ShardedInfiniteSampler(Sampler):
)
yield from iterable
self._iter_count += 1
class ShardedInfiniteBalancedSampler(Sampler):
def __init__(
self,
labels: List[int],
mode: Union[str, int] = "downsampling",
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
use_new_shuffle_tensor_slice: bool = False,
):
"""
A sharded infinite sampler that can optionally perform balanced (stratified) class sampling.
Args:
labels: List of class labels for the dataset.
mode: "downsampling", "upsampling", or an integer specifying the number of samples per class per cycle.
- "downsampling": each class is sampled using the count equal to the minimum available samples.
- "upsampling": each class is sampled using the count equal to the maximum available samples.
shuffle: Whether to shuffle the balanced samples each cycle.
seed: Random seed for shuffling.
start: Shard start index (defaults to global rank).
step: Shard step (defaults to global size).
advance: Number of initial samples to skip.
use_new_shuffle_tensor_slice: Switch to the alternate shuffle slice function.
"""
super().__init__(labels)
self._labels = np.array(labels)
self._seed = seed
self._shuffle = shuffle
# Sharding info.
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._advance = advance
# Choose the slice function.
self._shuffle_tensor_slice_fn = (
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
)
# Cycle count.
self._iter_count = 0
# Balanced sampling configuration.
self._unique_labels = np.unique(self._labels)
self._lbl2idx = {lbl: np.where(self._labels == lbl)[0] for lbl in self._unique_labels}
self._sorted_labels = sorted(self._unique_labels)
# Determine samples per class.
if isinstance(mode, str):
if mode == "downsampling":
self._samples_per_class = min(len(idxs) for idxs in self._lbl2idx.values())
elif mode == "upsampling":
self._samples_per_class = max(len(idxs) for idxs in self._lbl2idx.values())
else:
raise ValueError(f"mode='{mode}' must be 'downsampling', 'upsampling', or an integer.")
elif isinstance(mode, int):
self._samples_per_class = mode
else:
raise ValueError(f"mode must be str or int, got {type(mode)}.")
# The size of one balanced "cycle" (nominal epoch length).
self._sample_count = self._samples_per_class * len(self._sorted_labels)
def __iter__(self) -> Iterator[int]:
"""
Main entry point for iteration.
- Possibly skips entire cycles if 'advance' is large.
- Then yields from either _iterator() or _shuffled_iterator(),
slicing off the first 'advance' items.
"""
# Skip entire cycles if _advance is large.
iter_count = self._advance // self._sample_count
if iter_count > 0:
self._advance -= iter_count * self._sample_count
self._iter_count += iter_count
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self) -> Iterator[int]:
"""
Non-shuffling infinite iterator.
Builds a balanced set each cycle (down/up sample) without final shuffling, then shards it.
"""
rng = np.random.default_rng(self._seed)
while True:
indices = []
for lbl in self._sorted_labels:
idxs = self._lbl2idx[lbl]
replace = (self._samples_per_class > len(idxs))
chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
indices.extend(chosen)
# Shard for this rank.
for idx in indices[self._start::self._step]:
yield int(idx)
self._iter_count += 1
def _shuffled_iterator(self) -> Iterator[int]:
"""
Shuffling infinite iterator.
Each cycle picks new balanced subsets from each class, shuffles them,
and shards them.
"""
# Torch generator for slicing.
generator = torch.Generator()
while True:
seed = _make_seed(self._seed, self._start, self._iter_count)
generator.manual_seed(seed)
# Create a single NumPy RNG for the entire cycle.
rng = np.random.RandomState(seed)
indices_np = []
for lbl in self._sorted_labels:
idxs = self._lbl2idx[lbl]
replace = (self._samples_per_class > len(idxs))
chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
indices_np.extend(chosen)
# Convert to torch tensor.
cycle_size = len(indices_np)
dtype = _get_torch_dtype(cycle_size)
indices_tensor = torch.tensor(indices_np, dtype=dtype)
# Final shuffle and shard using the slice function.
iterable = self._shuffle_tensor_slice_fn(
tensor=indices_tensor,
start=self._start,
step=self._step,
generator=generator
)
yield from iterable
self._iter_count += 1
def __len__(self) -> int:
"""
Returns the nominal size of one 'cycle': samples_per_class * number of classes.
"""
return self._sample_count
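
As a hedged illustration of the class above (not part of the commit), the sketch below builds the sampler on a toy imbalanced label list; the import path is an assumption, and start/step are pinned so no distributed setup is needed:

```python
# Illustration only (not part of the commit); the import path is an assumption.
import itertools

from dinov2.data.samplers import ShardedInfiniteBalancedSampler

labels = [0] * 100 + [1] * 10        # class 0 is ten times more frequent than class 1
sampler = ShardedInfiniteBalancedSampler(
    labels=labels,
    mode="upsampling",               # 100 draws per class per cycle; class 1 is drawn with replacement
    shuffle=True,
    seed=0,
    start=0,                         # pin the shard to rank 0 of 1 (no torch.distributed needed)
    step=1,
)

print(len(sampler))                  # 200 == samples_per_class (100) * number of classes (2)
one_cycle = list(itertools.islice(iter(sampler), len(sampler)))
# one_cycle contains exactly 100 indices from each class, in shuffled order;
# iteration is infinite, so islice is needed to stop after one cycle.
```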

View File

@ -4,6 +4,7 @@
# found in the LICENSE file in the root directory of this source tree.
import argparse
from collections import defaultdict
from functools import partial
import json
import logging
@ -14,6 +15,7 @@ from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
@ -310,6 +312,29 @@ def evaluate_linear_classifiers(
return results_dict
def get_cls_num_list(labels):
counter = defaultdict(int)
for label in labels:
counter[label] += 1
labels = list(counter.keys())
labels.sort()
cls_num_list = [counter[label] for label in labels]
return cls_num_list
class LogitAdjustedLoss(nn.Module):
def __init__(self, cls_num_list, tau=1.0):
super().__init__()
cls_num_ratio = cls_num_list / torch.sum(cls_num_list)
log_cls_num = torch.log(cls_num_ratio)
self.log_cls_num = log_cls_num
self.tau = tau
def forward(self, logit, target):
logit_adjusted = logit + self.tau * self.log_cls_num.unsqueeze(0)
return F.cross_entropy(logit_adjusted, target)
def eval_linear(
*,
feature_model,
@ -329,6 +354,7 @@ def eval_linear(
resume=True,
classifier_fpath=None,
val_class_mapping=None,
**kwargs,
):
checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
@ -339,6 +365,13 @@ def eval_linear(
metric_logger = MetricLogger(delimiter=" ")
header = "Training"
if kwargs.get('logit_adjusted_loss', False):
cls_num_list = get_cls_num_list(train_data_loader.dataset.get_targets())
cls_num_list = torch.Tensor(cls_num_list).to("cuda")
train_loss_fn = LogitAdjustedLoss(cls_num_list)
else:
train_loss_fn = nn.CrossEntropyLoss()
for data, labels in metric_logger.log_every(
train_data_loader,
10,
@ -352,7 +385,7 @@ def eval_linear(
features = feature_model(data)
outputs = linear_classifiers(features)
losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
losses = {f"loss_{k}": train_loss_fn(v, labels) for k, v in outputs.items()}
loss = sum(losses.values())
# compute the gradients
@ -480,6 +513,7 @@ def run_eval_linear(
test_class_mapping_fpaths=[None],
val_metric_type=MetricType.MEAN_ACCURACY,
test_metric_types=None,
**kwargs,
):
seed = 0
@ -497,8 +531,13 @@ def run_eval_linear(
transform=train_transform,
)
training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
sampler_type = SamplerType.SHARDED_INFINITE
# sampler_type = SamplerType.INFINITE
if kwargs.get('balanced_sampler', False):
sampler_type = SamplerType.SHARDED_INFINITE_BALANCED
balanced_sampler_mode = kwargs['balanced_sampler_mode']
else:
balanced_sampler_mode = None
sampler_type = SamplerType.SHARDED_INFINITE
# sampler_type = SamplerType.INFINITE
n_last_blocks_list = [1, 4]
n_last_blocks = max(n_last_blocks_list)
@ -529,6 +568,7 @@ def run_eval_linear(
sampler_advance=start_iter,
drop_last=True,
persistent_workers=True,
balanced_sampler_mode=balanced_sampler_mode,
)
val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
@ -568,6 +608,7 @@ def run_eval_linear(
resume=resume,
val_class_mapping=val_class_mapping,
classifier_fpath=classifier_fpath,
**kwargs,
)
results_dict = {}
if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
@ -614,6 +655,9 @@ def main(args):
test_metric_types=args.test_metric_types,
val_class_mapping_fpath=args.val_class_mapping_fpath,
test_class_mapping_fpaths=args.test_class_mapping_fpaths,
logit_adjusted_loss=args.logit_adjusted_loss,
balanced_sampler=args.balanced_sampler,
balanced_sampler_mode=args.balanced_sampler_mode,
)
return 0
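
For reference on the loss half of this change: LogitAdjustedLoss above follows the logit-adjustment recipe from the long-tailed classification literature, adding tau * log(class prior) to each logit before standard cross-entropy so that frequent classes need a larger margin to win. A minimal sanity-check sketch (not part of the commit), assuming the class from the diff above is in scope:

```python
# Sanity-check sketch only (not part of the commit). The adjusted logit for class y is
#   logit_y + tau * log(n_y / sum_j n_j)
# so with uninformative (all-zero) logits, targets from rare classes contribute the
# largest per-sample loss terms.
import torch

cls_num_list = torch.tensor([900.0, 90.0, 10.0])   # toy long-tailed class counts
criterion = LogitAdjustedLoss(cls_num_list, tau=1.0)

logits = torch.zeros(4, 3)                          # batch of 4, 3 classes, no signal
targets = torch.tensor([0, 1, 2, 2])
loss = criterion(logits, targets)
```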