mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix offline_evaluate index error (#630)
* [Fix] Fix offline eval dataset index error. * update * update
This commit is contained in:
parent
f2b0540f58
commit
a9a575866f
@ -104,12 +104,6 @@ class Evaluator:
|
||||
"""
|
||||
|
||||
# support chunking iterable objects
|
||||
if data is not None:
|
||||
assert len(data_samples) == len(data), (
|
||||
'outputs and data should have the same length, but got '
|
||||
f'outputs length: {len(data_samples)} '
|
||||
f'data length: {len(data)}')
|
||||
|
||||
def get_chunks(seq: Iterator, chunk_size=1):
|
||||
stop = False
|
||||
while not stop:
|
||||
@ -123,10 +117,17 @@ class Evaluator:
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
if data is not None:
|
||||
assert len(data_samples) == len(data), (
|
||||
'data_samples and data should have the same length, but got '
|
||||
f'data_samples length: {len(data_samples)} '
|
||||
f'data length: {len(data)}')
|
||||
data = get_chunks(iter(data), chunk_size)
|
||||
|
||||
size = 0
|
||||
for output_chunk in get_chunks(iter(data_samples), chunk_size):
|
||||
if data is not None:
|
||||
data_chunk = pseudo_collate(data[size:size + chunk_size])
|
||||
data_chunk = pseudo_collate(next(data)) # type: ignore
|
||||
else:
|
||||
data_chunk = None
|
||||
size += len(output_chunk)
|
||||
|
@ -247,7 +247,7 @@ class TestEvaluator(TestCase):
|
||||
all_data = [dict() for _ in range(9)]
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
'outputs and data should have the same length'):
|
||||
'data_samples and data should have the same length'):
|
||||
evaluator.offline_evaluate(all_predictions, all_data)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
|
||||
|
Loading…
x
Reference in New Issue
Block a user