fast-reid/fastreid/utils/history_buffer.py

72 lines
2.2 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import List, Optional, Tuple

import numpy as np
class HistoryBuffer:
    """
    Track a series of scalar values and provide access to smoothed values over a
    window or the global average of the series.

    Each entry is stored as a ``(value, iteration)`` pair. At most
    ``max_length`` entries are retained; the running global average is
    maintained separately so it still accounts for evicted entries.
    """

    def __init__(self, max_length: int = 1000000):
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, old
                values will be removed. Expected to be a positive integer.
        """
        self._max_length: int = max_length
        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
        self._count: int = 0  # total number of updates, including evicted ones
        self._global_avg: float = 0

    def update(self, value: float, iteration: Optional[float] = None) -> None:
        """
        Add a new scalar value produced at certain iteration. If the length
        of the buffer exceeds self._max_length, the oldest element will be
        removed from the buffer.

        Args:
            value: the scalar to record.
            iteration: the iteration at which `value` was produced. When
                omitted, the number of updates seen so far is used.
        """
        if iteration is None:
            iteration = self._count
        if len(self._data) == self._max_length:
            # Evict the oldest entry to make room for the new one.
            self._data.pop(0)
        self._data.append((value, iteration))

        self._count += 1
        # Incremental mean: avoids keeping a running sum of every value seen.
        self._global_avg += (value - self._global_avg) / self._count

    def latest(self) -> float:
        """
        Return the latest scalar value added to the buffer.

        Raises:
            IndexError: if the buffer is empty.
        """
        return self._data[-1][0]

    def median(self, window_size: int) -> float:
        """
        Return the median of the latest `window_size` values in the buffer.

        Args:
            window_size: number of most recent values to consider; must be
                positive. If it exceeds the buffer length, all buffered
                values are used.

        Raises:
            ValueError: if `window_size` is not positive.
        """
        return np.median(self._window_values(window_size))

    def avg(self, window_size: int) -> float:
        """
        Return the mean of the latest `window_size` values in the buffer.

        Args:
            window_size: number of most recent values to consider; must be
                positive. If it exceeds the buffer length, all buffered
                values are used.

        Raises:
            ValueError: if `window_size` is not positive.
        """
        return np.mean(self._window_values(window_size))

    def global_avg(self) -> float:
        """
        Return the mean of all the elements in the buffer. Note that this
        includes those getting removed due to limited buffer storage.
        """
        return self._global_avg

    def values(self) -> List[Tuple[float, float]]:
        """
        Returns:
            list[(number, iteration)]: content of the current buffer. This is
            the internal list itself, not a copy.
        """
        return self._data

    def _window_values(self, window_size: int) -> List[float]:
        """
        Return the scalar values (iterations dropped) of the latest
        `window_size` entries, oldest first.

        Raises:
            ValueError: if `window_size` is not positive.
        """
        # A plain `self._data[-window_size:]` misbehaves for non-positive
        # sizes: `-0` slices the whole buffer and a negative size skips the
        # oldest entries instead of selecting the newest ones.
        if window_size < 1:
            raise ValueError(
                f"window_size must be a positive integer, got {window_size}"
            )
        return [value for value, _ in self._data[-window_size:]]