mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Fix]: Fix distributed setting for shape bias (#689)
* [Fix]: Fix dist bug * [Fix]: Fix lint * [Feature]: Add download log info * [Fix]: Fix lint
This commit is contained in:
parent
0f634de85b
commit
d3f57edf79
@ -6,6 +6,7 @@ from typing import List, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.dist.utils import get_rank
|
||||
from mmengine.evaluator import BaseMetric
|
||||
|
||||
from mmselfsup.registry import METRICS
|
||||
@ -75,7 +76,8 @@ class ShapeBiasMetric(BaseMetric):
|
||||
self.csv_dir = csv_dir
|
||||
self.model_name = model_name
|
||||
self.dataset_name = dataset_name
|
||||
self.csv_path = self.create_csv()
|
||||
if get_rank() == 0:
|
||||
self.csv_path = self.create_csv()
|
||||
|
||||
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data samples.
|
||||
@ -126,7 +128,7 @@ class ShapeBiasMetric(BaseMetric):
|
||||
os.remove(csv_path)
|
||||
directory = osp.dirname(csv_path)
|
||||
if not osp.exists(directory):
|
||||
os.makedirs(directory)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
with open(csv_path, 'w') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([
|
||||
@ -161,7 +163,8 @@ class ShapeBiasMetric(BaseMetric):
|
||||
Returns:
|
||||
dict: A dict of metrics.
|
||||
"""
|
||||
self.dump_results_to_csv(results)
|
||||
if get_rank() == 0:
|
||||
self.dump_results_to_csv(results)
|
||||
metrics = dict()
|
||||
metrics['accuracy/top1'] = np.mean([
|
||||
result['pred_category'][0] == result['gt_category']
|
||||
|
@ -8,6 +8,7 @@ import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from mmengine.logging import MMLogger
|
||||
from utils import FormatStrFormatter, ShapeBias
|
||||
|
||||
# global default boundary settings for thin gray transparent
|
||||
@ -46,12 +47,24 @@ parser.add_argument(
|
||||
help= # noqa
|
||||
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
'--delete-icons',
|
||||
action='store_true',
|
||||
help='whether to delete the icons after plotting')
|
||||
|
||||
humans = [
|
||||
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
|
||||
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
|
||||
]
|
||||
|
||||
icon_names = [
|
||||
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
|
||||
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
|
||||
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
|
||||
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
|
||||
'knife.png', 'response_icons_vertical.png', 'truck.png'
|
||||
]
|
||||
|
||||
|
||||
def read_csvs(csv_dir: str) -> pd.DataFrame:
|
||||
"""Reads all csv files in a directory and returns a single dataframe.
|
||||
@ -83,7 +96,6 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
|
||||
mpl.rcParams['font.serif'] = ['Times New Roman']
|
||||
|
||||
plt.figure(figsize=(9, 7))
|
||||
|
||||
df = read_csvs(args.csv_dir)
|
||||
|
||||
fontsize = 15
|
||||
@ -218,19 +230,27 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
icon_names = [
|
||||
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
|
||||
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
|
||||
'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
|
||||
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
|
||||
'knife.png', 'response_icons_vertical.png', 'truck.png'
|
||||
]
|
||||
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
|
||||
os.makedirs(ICONS_DIR, exist_ok=True)
|
||||
def check_icons() -> bool:
|
||||
"""Check if icons are present, if not download them."""
|
||||
if not osp.exists(ICONS_DIR):
|
||||
return False
|
||||
for icon_name in icon_names:
|
||||
url = osp.join(root_url, icon_name)
|
||||
os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url))
|
||||
if not osp.exists(osp.join(ICONS_DIR, icon_name)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if not check_icons():
|
||||
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
|
||||
os.makedirs(ICONS_DIR, exist_ok=True)
|
||||
MMLogger.get_current_instance().info(
|
||||
f'Downloading icons to {ICONS_DIR}')
|
||||
for icon_name in icon_names:
|
||||
url = osp.join(root_url, icon_name)
|
||||
os.system('wget -O {} {}'.format(
|
||||
osp.join(ICONS_DIR, icon_name), url))
|
||||
|
||||
args = parser.parse_args()
|
||||
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
|
||||
@ -260,5 +280,5 @@ if __name__ == '__main__':
|
||||
args.plotting_names.append('Humans')
|
||||
|
||||
plot_shape_bias_matrixplot(args)
|
||||
|
||||
os.system('rm -rf {}'.format(ICONS_DIR))
|
||||
if args.delete_icons:
|
||||
os.system('rm -rf {}'.format(ICONS_DIR))
|
||||
|
Loading…
x
Reference in New Issue
Block a user