diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py
index fa2b4d9..0ca6ec0 100644
--- a/fastreid/evaluation/reid_evaluation.py
+++ b/fastreid/evaluation/reid_evaluation.py
@@ -5,6 +5,7 @@
 """
 import copy
 import logging
+import itertools
 from collections import OrderedDict
 
 import numpy as np
@@ -28,50 +29,54 @@ class ReidEvaluator(DatasetEvaluator):
         self._num_query = num_query
         self._output_dir = output_dir
 
-        self.features = []
-        self.pids = []
-        self.camids = []
+        self._cpu_device = torch.device('cpu')
+
+        self._predictions = []
 
     def reset(self):
-        self.features = []
-        self.pids = []
-        self.camids = []
+        self._predictions = []
 
     def process(self, inputs, outputs):
-        self.pids.extend(inputs["targets"])
-        self.camids.extend(inputs["camids"])
-        self.features.append(outputs.cpu())
+        prediction = {
+            'feats': outputs.to(self._cpu_device, torch.float32),
+            'pids': inputs['targets'].to(self._cpu_device),
+            'camids': inputs['camids'].to(self._cpu_device)
+
+        }
+        self._predictions.append(prediction)
 
     def evaluate(self):
         if comm.get_world_size() > 1:
             comm.synchronize()
-            features = comm.gather(self.features)
-            features = sum(features, [])
+            predictions = comm.gather(self._predictions, dst=0)
+            predictions = list(itertools.chain(*predictions))
 
-            pids = comm.gather(self.pids)
-            pids = sum(pids, [])
+            if not comm.is_main_process():
+                return {}
 
-            camids = comm.gather(self.camids)
-            camids = sum(camids, [])
-
-            # fmt: off
-            if not comm.is_main_process(): return {}
-            # fmt: on
         else:
-            features = self.features
-            pids = self.pids
-            camids = self.camids
+            predictions = self._predictions
+
+        features = []
+        pids = []
+        camids = []
+        for prediction in predictions:
+            features.append(prediction['feats'])
+            pids.append(prediction['pids'])
+            camids.append(prediction['camids'])
 
         features = torch.cat(features, dim=0)
+        pids = torch.cat(pids, dim=0).numpy()
+        camids = torch.cat(camids, dim=0).numpy()
         # query feature, person ids and camera ids
         query_features = features[:self._num_query]
-        query_pids = np.asarray(pids[:self._num_query])
-        query_camids = np.asarray(camids[:self._num_query])
+        query_pids = pids[:self._num_query]
+        query_camids = camids[:self._num_query]
 
         # gallery features, person ids and camera ids
         gallery_features = features[self._num_query:]
-        gallery_pids = np.asarray(pids[self._num_query:])
-        gallery_camids = np.asarray(camids[self._num_query:])
+        gallery_pids = pids[self._num_query:]
+        gallery_camids = camids[self._num_query:]
 
         self._results = OrderedDict()