mirror of https://github.com/JDAI-CV/fast-reid.git
72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
|
|
from collections import deque
from itertools import islice
from typing import Deque, List, Optional, Tuple

import numpy as np
|
|
|
|
|
|
class HistoryBuffer:
    """
    Track a series of scalar values and provide access to smoothed values over a
    window or the global average of the series.

    Values are kept as ``(value, iteration)`` pairs in a bounded deque, so
    dropping the oldest pair once the buffer is full costs O(1) instead of the
    O(len) element shift that ``list.pop(0)`` would cost on every update.
    """

    def __init__(self, max_length: int = 1000000):
        """
        Args:
            max_length: maximal number of values that can be stored in the
                buffer. When the capacity of the buffer is exhausted, old
                values will be removed. ``None`` leaves the buffer unbounded;
                ``0`` stores no values (the global average is still tracked).
        """
        self._max_length: int = max_length
        # deque(maxlen=...) evicts the oldest pair automatically on append.
        self._data: Deque[Tuple[float, float]] = deque(maxlen=max_length)  # (value, iteration) pairs
        self._count: int = 0
        self._global_avg: float = 0

    def update(self, value: float, iteration: Optional[float] = None):
        """
        Add a new scalar value produced at certain iteration. If the length
        of the buffer exceeds self._max_length, the oldest element will be
        removed from the buffer.

        Args:
            value: the scalar to record.
            iteration: iteration the value belongs to. Defaults to the number
                of values recorded before this one.
        """
        if iteration is None:
            iteration = self._count
        # When the deque is full, append() drops the oldest pair in O(1).
        self._data.append((value, iteration))

        self._count += 1
        # Incremental running mean over every value ever added, including
        # values that have since been evicted from the bounded buffer.
        self._global_avg += (value - self._global_avg) / self._count

    def latest(self):
        """
        Return the latest scalar value added to the buffer.

        Raises:
            IndexError: if the buffer is empty.
        """
        return self._data[-1][0]

    def _recent_values(self, window_size: int) -> List[float]:
        """
        Return the scalar values selected by ``window_size``, oldest first.

        The selection matches slicing a list with ``[-window_size:]``.
        """
        if window_size > 0:
            # Walk only the tail, which is O(window_size), instead of copying
            # the whole (possibly huge) buffer.
            tail = list(islice(reversed(self._data), window_size))
            tail.reverse()  # restore chronological order
            return [v for v, _ in tail]
        # Non-positive sizes keep plain slice semantics. For example, a window
        # of 0 selects the whole buffer because ``[-0:]`` is ``[0:]``.
        return [v for v, _ in list(self._data)[-window_size:]]

    def median(self, window_size: int):
        """
        Return the median of the latest `window_size` values in the buffer.
        """
        return np.median(self._recent_values(window_size))

    def avg(self, window_size: int):
        """
        Return the mean of the latest `window_size` values in the buffer.
        """
        return np.mean(self._recent_values(window_size))

    def global_avg(self):
        """
        Return the mean of all the elements in the buffer. Note that this
        includes those getting removed due to limited buffer storage.
        """
        return self._global_avg

    def values(self):
        """
        Returns:
            list[(number, iteration)]: a snapshot copy of the current buffer
            content, oldest first.
        """
        return list(self._data)
|