Add BTD and LAL (#4)
parent d00d2f5822
commit c9406038a2
Data loaders (SamplerType, _make_sampler, make_data_loader):

@@ -11,7 +11,7 @@ import torch
 from torch.utils.data import Sampler
 
 from .datasets import ImageNet, ImageNet22k, ImageShipID, ImageShipID_Extra, ImageShipID_20P, ImageShipID_40P, ImageShipID_60P, ImageShipID_80P, ImageShipOOD
-from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
+from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler, ShardedInfiniteBalancedSampler
 
 
 logger = logging.getLogger("dinov2")
@@ -23,6 +23,7 @@ class SamplerType(Enum):
     INFINITE = 2
     SHARDED_INFINITE = 3
     SHARDED_INFINITE_NEW = 4
+    SHARDED_INFINITE_BALANCED = 5
 
 
 def _make_bool_str(b: bool) -> str:
@@ -133,6 +134,7 @@ def _make_sampler(
     seed: int = 0,
     size: int = -1,
     advance: int = 0,
+    **kwargs,
 ) -> Optional[Sampler]:
     sample_count = len(dataset)
 
@@ -183,6 +185,17 @@ def _make_sampler(
             seed=seed,
             drop_last=False,
         )
+    elif type == SamplerType.SHARDED_INFINITE_BALANCED:
+        logger.info("sampler: sharded infinite balanced")
+        if size > 0:
+            raise ValueError("sampler size > 0 is invalid")
+        return ShardedInfiniteBalancedSampler(
+            labels=dataset.get_targets(),
+            mode=kwargs["balanced_sampler_mode"],
+            shuffle=shuffle,
+            seed=seed,
+            advance=advance,
+        )
 
     logger.info("sampler: none")
     return None
@@ -204,6 +217,7 @@ def make_data_loader(
     drop_last: bool = True,
     persistent_workers: bool = False,
     collate_fn: Optional[Callable[[List[T]], Any]] = None,
+    **kwargs,
 ):
     """
     Creates a data loader with the specified parameters.
@@ -229,6 +243,7 @@ def make_data_loader(
         seed=seed,
         size=sampler_size,
         advance=sampler_advance,
+        **kwargs,
     )
 
     logger.info("using PyTorch data loader")
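For context, a minimal usage sketch (not part of this commit) of how the new sampler type might be requested through make_data_loader. The dataset string, batch size, and worker count are placeholders; it assumes the public dinov2.data helpers (make_dataset, make_data_loader, SamplerType) and a dataset that exposes get_targets(), as the new branch in _make_sampler requires.

# Hypothetical usage sketch: balanced_sampler_mode is forwarded via **kwargs
# through make_data_loader down to _make_sampler, which builds the new sampler.
from dinov2.data import SamplerType, make_data_loader, make_dataset

train_dataset = make_dataset(dataset_str="ImageShipID:split=TRAIN")  # placeholder dataset string
data_loader = make_data_loader(
    dataset=train_dataset,
    batch_size=256,
    num_workers=8,
    shuffle=True,
    seed=0,
    sampler_type=SamplerType.SHARDED_INFINITE_BALANCED,
    drop_last=True,
    balanced_sampler_mode="downsampling",  # or "upsampling", or an int per-class count
)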
Samplers module (ShardedInfiniteBalancedSampler added):

@@ -4,7 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Any, Optional
+from typing import Any, Optional, List, Union, Iterator
 import warnings
 
 import numpy as np
@@ -227,3 +227,147 @@ class ShardedInfiniteSampler(Sampler):
             )
             yield from iterable
             self._iter_count += 1
+
+
+class ShardedInfiniteBalancedSampler(Sampler):
+    def __init__(
+        self,
+        labels: List[int],
+        mode: Union[str, int] = "downsampling",
+        shuffle: bool = False,
+        seed: int = 0,
+        start: Optional[int] = None,
+        step: Optional[int] = None,
+        advance: int = 0,
+        use_new_shuffle_tensor_slice: bool = False,
+    ):
+        """
+        A sharded infinite sampler that can optionally perform balanced (stratified) class sampling.
+
+        Args:
+            labels: List of class labels for the dataset.
+            mode: "downsampling", "upsampling", or an integer specifying the number of samples per class per cycle.
+                - "downsampling": each class is sampled using the count equal to the minimum available samples.
+                - "upsampling": each class is sampled using the count equal to the maximum available samples.
+            shuffle: Whether to shuffle the balanced samples each cycle.
+            seed: Random seed for shuffling.
+            start: Shard start index (defaults to global rank).
+            step: Shard step (defaults to global size).
+            advance: Number of initial samples to skip.
+            use_new_shuffle_tensor_slice: Switch to the alternate shuffle slice function.
+        """
+        super().__init__(labels)
+        self._labels = np.array(labels)
+        self._seed = seed
+        self._shuffle = shuffle
+
+        # Sharding info.
+        self._start = distributed.get_global_rank() if start is None else start
+        self._step = distributed.get_global_size() if step is None else step
+        self._advance = advance
+
+        # Choose the slice function.
+        self._shuffle_tensor_slice_fn = (
+            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
+        )
+
+        # Cycle count.
+        self._iter_count = 0
+
+        # Balanced sampling configuration.
+        self._unique_labels = np.unique(self._labels)
+        self._lbl2idx = {lbl: np.where(self._labels == lbl)[0] for lbl in self._unique_labels}
+        self._sorted_labels = sorted(self._unique_labels)
+
+        # Determine samples per class.
+        if isinstance(mode, str):
+            if mode == "downsampling":
+                self._samples_per_class = min(len(idxs) for idxs in self._lbl2idx.values())
+            elif mode == "upsampling":
+                self._samples_per_class = max(len(idxs) for idxs in self._lbl2idx.values())
+            else:
+                raise ValueError(f"mode='{mode}' must be 'downsampling', 'upsampling', or an integer.")
+        elif isinstance(mode, int):
+            self._samples_per_class = mode
+        else:
+            raise ValueError(f"mode must be str or int, got {type(mode)}.")
+
+        # The size of one balanced "cycle" (nominal epoch length).
+        self._sample_count = self._samples_per_class * len(self._sorted_labels)
+
+    def __iter__(self) -> Iterator[int]:
+        """
+        Main entry point for iteration.
+        - Possibly skips entire cycles if 'advance' is large.
+        - Then yields from either _iterator() or _shuffled_iterator(),
+          slicing off the first 'advance' items.
+        """
+        # Skip entire cycles if _advance is large.
+        iter_count = self._advance // self._sample_count
+        if iter_count > 0:
+            self._advance -= iter_count * self._sample_count
+            self._iter_count += iter_count
+
+        if self._shuffle:
+            iterator = self._shuffled_iterator()
+        else:
+            iterator = self._iterator()
+
+        yield from itertools.islice(iterator, self._advance, None)
+
+    def _iterator(self) -> Iterator[int]:
+        """
+        Non-shuffling infinite iterator.
+        Builds a balanced set each cycle (down/up sample) without final shuffling, then shards it.
+        """
+        rng = np.random.default_rng(self._seed)
+        while True:
+            indices = []
+            for lbl in self._sorted_labels:
+                idxs = self._lbl2idx[lbl]
+                replace = self._samples_per_class > len(idxs)
+                chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
+                indices.extend(chosen)
+            # Shard for this rank.
+            for idx in indices[self._start :: self._step]:
+                yield int(idx)
+            self._iter_count += 1
+
+    def _shuffled_iterator(self) -> Iterator[int]:
+        """
+        Shuffling infinite iterator.
+        Each cycle picks new balanced subsets from each class, shuffles them,
+        and shards them.
+        """
+        # Torch generator for slicing.
+        generator = torch.Generator()
+        while True:
+            seed = _make_seed(self._seed, self._start, self._iter_count)
+            generator.manual_seed(seed)
+            # Create a single NumPy RNG for the entire cycle.
+            rng = np.random.RandomState(seed)
+            indices_np = []
+            for lbl in self._sorted_labels:
+                idxs = self._lbl2idx[lbl]
+                replace = self._samples_per_class > len(idxs)
+                chosen = rng.choice(idxs, self._samples_per_class, replace=replace)
+                indices_np.extend(chosen)
+            # Convert to torch tensor.
+            cycle_size = len(indices_np)
+            dtype = _get_torch_dtype(cycle_size)
+            indices_tensor = torch.tensor(indices_np, dtype=dtype)
+            # Final shuffle and shard using the slice function.
+            iterable = self._shuffle_tensor_slice_fn(
+                tensor=indices_tensor,
+                start=self._start,
+                step=self._step,
+                generator=generator,
+            )
+            yield from iterable
+            self._iter_count += 1
+
+    def __len__(self) -> int:
+        """
+        Returns the nominal size of one 'cycle': samples_per_class * number of classes.
+        """
+        return self._sample_count
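A small standalone check (not part of the commit) of the balanced cycle size and sampling order on a toy label list, with sharding disabled by passing start=0, step=1 explicitly. It assumes the class lives in dinov2.data.samplers alongside the existing samplers.

# Hypothetical sketch: three classes with 4 / 2 / 1 samples; "downsampling" draws
# one index per class per cycle, so one cycle has length 3.
import itertools
from dinov2.data.samplers import ShardedInfiniteBalancedSampler

toy_labels = [0, 0, 0, 0, 1, 1, 2]
sampler = ShardedInfiniteBalancedSampler(
    labels=toy_labels,
    mode="downsampling",   # per-class count = min class size = 1
    shuffle=False,
    seed=0,
    start=0,               # single-process sharding for the demo
    step=1,
)
print(len(sampler))                                  # 3 == samples_per_class * num_classes
print(list(itertools.islice(iter(sampler), 6)))      # two balanced cycles of dataset indices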
Linear evaluation script (logit-adjusted loss and balanced sampling options):

@@ -4,6 +4,7 @@
 # found in the LICENSE file in the root directory of this source tree.
 
 import argparse
+from collections import defaultdict
 from functools import partial
 import json
 import logging
@@ -14,6 +15,7 @@ from typing import List, Optional
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
 
@@ -310,6 +312,29 @@ def evaluate_linear_classifiers(
     return results_dict
 
 
+def get_cls_num_list(labels):
+    counter = defaultdict(int)
+    for label in labels:
+        counter[label] += 1
+    labels = list(counter.keys())
+    labels.sort()
+    cls_num_list = [counter[label] for label in labels]
+    return cls_num_list
+
+
+class LogitAdjustedLoss(nn.Module):
+    def __init__(self, cls_num_list, tau=1.0):
+        super().__init__()
+        cls_num_ratio = cls_num_list / torch.sum(cls_num_list)
+        log_cls_num = torch.log(cls_num_ratio)
+        self.log_cls_num = log_cls_num
+        self.tau = tau
+
+    def forward(self, logit, target):
+        logit_adjusted = logit + self.tau * self.log_cls_num.unsqueeze(0)
+        return F.cross_entropy(logit_adjusted, target)
+
+
 def eval_linear(
     *,
     feature_model,
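For reference, a tiny numeric sketch (not part of the commit) of what LogitAdjustedLoss computes: plain cross-entropy after shifting each class logit by tau times the log class prior, so a rare-class target incurs a larger loss under a long-tailed prior. The class counts and logits below are made up.

# Hypothetical numeric check of the logit adjustment used above.
import torch
import torch.nn.functional as F

cls_num_list = torch.tensor([900.0, 90.0, 10.0])      # long-tailed class counts
log_prior = torch.log(cls_num_list / cls_num_list.sum())

logits = torch.tensor([[2.0, 1.0, 0.5]])
target = torch.tensor([2])                             # the rarest class

plain = F.cross_entropy(logits, target)
adjusted = F.cross_entropy(logits + 1.0 * log_prior.unsqueeze(0), target)
print(plain.item(), adjusted.item())                   # the adjusted loss is larger for the rare class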
@@ -329,6 +354,7 @@ def eval_linear(
     resume=True,
     classifier_fpath=None,
     val_class_mapping=None,
+    **kwargs,
 ):
     checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
     start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
@@ -339,6 +365,13 @@ def eval_linear(
     metric_logger = MetricLogger(delimiter="  ")
     header = "Training"
 
+    if kwargs.get('logit_adjusted_loss', False):
+        cls_num_list = get_cls_num_list(train_data_loader.dataset.get_targets())
+        cls_num_list = torch.Tensor(cls_num_list).to("cuda")
+        train_loss_fn = LogitAdjustedLoss(cls_num_list)
+    else:
+        train_loss_fn = nn.CrossEntropyLoss()
+
     for data, labels in metric_logger.log_every(
         train_data_loader,
         10,
@@ -352,7 +385,7 @@ def eval_linear(
         features = feature_model(data)
         outputs = linear_classifiers(features)
 
-        losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
+        losses = {f"loss_{k}": train_loss_fn(v, labels) for k, v in outputs.items()}
         loss = sum(losses.values())
 
         # compute the gradients
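A brief sketch (not part of the commit) of the loop body above: the selected train_loss_fn is applied independently to every linear-classifier head and the per-head losses are summed. The head names and tensor shapes below are placeholders.

# Hypothetical stand-in for outputs = linear_classifiers(features).
import torch
import torch.nn as nn

train_loss_fn = nn.CrossEntropyLoss()                  # or LogitAdjustedLoss(cls_num_list)
labels = torch.randint(0, 10, (8,))
outputs = {
    "classifier_lr_0_0001": torch.randn(8, 10),
    "classifier_lr_0_0010": torch.randn(8, 10),
}
losses = {f"loss_{k}": train_loss_fn(v, labels) for k, v in outputs.items()}
loss = sum(losses.values())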
@@ -480,6 +513,7 @@ def run_eval_linear(
     test_class_mapping_fpaths=[None],
     val_metric_type=MetricType.MEAN_ACCURACY,
     test_metric_types=None,
+    **kwargs,
 ):
     seed = 0
 
@@ -497,6 +531,11 @@ def run_eval_linear(
         transform=train_transform,
     )
    training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
+    if kwargs.get('balanced_sampler', False):
+        sampler_type = SamplerType.SHARDED_INFINITE_BALANCED
+        balanced_sampler_mode = kwargs['balanced_sampler_mode']
+    else:
+        balanced_sampler_mode = None
     sampler_type = SamplerType.SHARDED_INFINITE
     # sampler_type = SamplerType.INFINITE
 
|
@ -529,6 +568,7 @@ def run_eval_linear(
|
||||||
sampler_advance=start_iter,
|
sampler_advance=start_iter,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
|
balanced_sampler_mode=balanced_sampler_mode,
|
||||||
)
|
)
|
||||||
val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
|
val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
|
||||||
|
|
||||||
|
@ -568,6 +608,7 @@ def run_eval_linear(
|
||||||
resume=resume,
|
resume=resume,
|
||||||
val_class_mapping=val_class_mapping,
|
val_class_mapping=val_class_mapping,
|
||||||
classifier_fpath=classifier_fpath,
|
classifier_fpath=classifier_fpath,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
results_dict = {}
|
results_dict = {}
|
||||||
if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
|
if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
|
||||||
|
@@ -614,6 +655,9 @@ def main(args):
         test_metric_types=args.test_metric_types,
         val_class_mapping_fpath=args.val_class_mapping_fpath,
         test_class_mapping_fpaths=args.test_class_mapping_fpaths,
+        logit_adjusted_loss=args.logit_adjusted_loss,
+        balanced_sampler=args.balanced_sampler,
+        balanced_sampler_mode=args.balanced_sampler_mode,
     )
     return 0
 
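The argparse changes are not shown in this diff; the attributes read in main() suggest flags roughly like the following, where the flag names, defaults, and help strings are assumptions.

# Hypothetical CLI sketch implied by args.logit_adjusted_loss, args.balanced_sampler,
# and args.balanced_sampler_mode; the real parser setup is outside this diff.
import argparse

parser = argparse.ArgumentParser(description="Linear evaluation")
parser.add_argument("--logit-adjusted-loss", action="store_true",
                    help="Train the linear heads with the logit-adjusted loss (LAL)")
parser.add_argument("--balanced-sampler", action="store_true",
                    help="Use ShardedInfiniteBalancedSampler for the training data loader")
parser.add_argument("--balanced-sampler-mode", type=str, default="downsampling",
                    help="'downsampling', 'upsampling', or an integer count per class")
args = parser.parse_args(["--balanced-sampler", "--balanced-sampler-mode", "upsampling"])
print(args.logit_adjusted_loss, args.balanced_sampler, args.balanced_sampler_mode)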