# Origin: mmselfsup/openselfsup/datasets/classification.py (Python; listed as 44 lines, 1.4 KiB)

import torch
from openselfsup.utils import print_log
from .registry import DATASETS
from .base import BaseDataset
@DATASETS.register_module
class ClassificationDataset(BaseDataset):
    """Dataset for classification.

    Each item is a dict with the pipeline-transformed image under ``img``
    and the label returned by the data source under ``gt_label``.
    """

    def __init__(self, data_source, pipeline):
        # Construction is delegated entirely to the base class; this class
        # relies on ``self.data_source`` and ``self.pipeline`` being
        # available afterwards (presumably set up by BaseDataset).
        super(ClassificationDataset, self).__init__(data_source, pipeline)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``dict(img=..., gt_label=...)``."""
        img, target = self.data_source.get_sample(idx)
        img = self.pipeline(img)
        return dict(img=img, gt_label=target)

    def evaluate(self, scores, keyword, logger=None, topk=(1, 5)):
        """Compute top-k classification accuracy against the dataset labels.

        Args:
            scores (Tensor): Prediction scores of shape (N, C).
            keyword (str): Prefix used for the result keys and log lines.
            logger: Forwarded to ``print_log``. Nothing is logged when it is
                ``None`` or the string ``'silent'``.
            topk (tuple[int]): The k values to report. A k larger than the
                number of classes C is evaluated over all C predictions.

        Returns:
            dict: Maps ``"{keyword}_top{k}"`` to the accuracy in percent.

        Raises:
            AssertionError: If ``scores`` and the labels differ in length.
        """
        eval_res = {}
        target = torch.LongTensor(self.data_source.labels)
        assert scores.size(0) == target.size(0), \
            "Inconsistent length for results and labels, {} vs {}".format(
                scores.size(0), target.size(0))
        num = scores.size(0)
        # ``Tensor.topk`` raises when k exceeds the size of the dimension, so
        # never request more predictions than there are classes.
        maxk = min(max(topk), scores.size(1))
        _, pred = scores.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # KxN
        for k in topk:
            # ``correct`` derives from a transposed tensor and may be
            # non-contiguous, in which case ``view(-1)`` raises; ``reshape``
            # handles both layouts. Slicing with k > maxk keeps all rows.
            correct_k = correct[:k].reshape(-1).float().sum(0).item()
            acc = correct_k * 100.0 / num
            eval_res["{}_top{}".format(keyword, k)] = acc
            if logger is not None and logger != 'silent':
                print_log(
                    "{}_top{}: {:.03f}".format(keyword, k, acc),
                    logger=logger)
        return eval_res