RE-OWOD/detectron2/utils/store_non_list.py

62 lines
1.8 KiB
Python
Raw Normal View History

2022-01-04 13:35:52 +08:00
import random
from collections import deque
import numpy as np
class Store:
    """Bounded, per-class FIFO storage of arbitrary items.

    One ``collections.deque`` with ``maxlen=items_per_class`` is kept for each
    class index in ``range(total_num_classes)``. Once a class's deque is full,
    appending another item silently drops that class's oldest item.

    Args:
        total_num_classes: Number of class slots to allocate.
        items_per_class: Maximum number of items retained per class.
        shuffle: If True, ``retrieve(class_id)`` for a single class returns
            its items in random order. The ``class_id == -1`` path is never
            shuffled.
    """

    def __init__(self, total_num_classes, items_per_class, shuffle=False):
        self.shuffle = shuffle
        self.items_per_class = items_per_class
        self.total_num_classes = total_num_classes
        self.store = self._empty_store()

    def _empty_store(self):
        """Return a fresh list holding one empty bounded deque per class."""
        return [deque(maxlen=self.items_per_class) for _ in range(self.total_num_classes)]

    def add(self, items, class_ids):
        """Append ``items[i]`` to the deque of class ``class_ids[i]`` for every i.

        ``class_ids`` entries are used directly as list indices into the
        per-class storage. ``items`` is indexed by position, so it must be at
        least as long as ``class_ids``; indexing past its end raises (e.g.
        ``IndexError`` for sequences) after earlier items were already stored.
        """
        for idx, class_id in enumerate(class_ids):
            self.store[class_id].append(items[idx])

    def retrieve(self, class_id):
        """Return the stored items of one class, or of every class.

        Args:
            class_id: A class index, or ``-1`` to retrieve all classes.

        Returns:
            For ``-1``: a list with one new inner list per class, each holding
            that class's stored items unchanged, oldest first, never shuffled.
            Otherwise: a new flat list built by ``list.extend`` over each
            stored item of that class (so each stored item must be iterable
            and its elements are concatenated), shuffled in place when
            ``self.shuffle`` is set.
        """
        if class_id == -1:
            # Copy each deque so callers cannot mutate the internal storage.
            return [list(self.store[i]) for i in range(self.total_num_classes)]

        items = []
        for item in self.store[class_id]:
            # NOTE(review): ``extend`` flattens every stored item into its
            # elements, whereas the ``-1`` branch keeps items whole. Confirm
            # single-class callers really store iterables before changing.
            items.extend(item)
        if self.shuffle:
            random.shuffle(items)
        return items

    def reset(self):
        """Discard all stored items, keeping the same class count and capacity."""
        self.store = self._empty_store()

    def __str__(self):
        """Summarise how many items each class currently holds."""
        per_class = ''.join(
            '\n Class {} --> {} items'.format(idx, len(class_items))
            for idx, class_items in enumerate(self.store)
        )
        return self.__class__.__name__ + '(' + per_class + ' )'

    def __repr__(self):
        return self.__str__()

    def __len__(self):
        """Return the total number of stored items across all classes."""
        return sum(len(class_items) for class_items in self.store)
if __name__ == "__main__":
    # Smoke demo: 10 classes, at most 3 items retained per class.
    demo_store = Store(10, 3)

    # Class 1 receives four items ('a', 'b', 'd', 'f'); with maxlen=3 the
    # oldest one ('a') is evicted.
    demo_items = ('a', 'b', 'c', 'd', 'e', 'f')
    demo_classes = (1, 1, 9, 1, 0, 1)
    demo_store.add(demo_items, demo_classes)
    demo_store.add(('h',), (4,))

    # Per-class contents, then the per-class count summary.
    per_class_items = demo_store.retrieve(-1)
    print(per_class_items)
    print(demo_store)