[Fix]: Fix distributed setting for shape bias (#689)

* [Fix]: Fix dist bug

* [Fix]: Fix lint

* [Feature]: Add download log info

* [Fix]: Fix lint
This commit is contained in:
Yuan Liu 2023-02-08 14:23:59 +08:00 committed by GitHub
parent 0f634de85b
commit d3f57edf79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 18 deletions

View File

@ -6,6 +6,7 @@ from typing import List, Sequence
import numpy as np
import torch
from mmengine.dist.utils import get_rank
from mmengine.evaluator import BaseMetric
from mmselfsup.registry import METRICS
@ -75,7 +76,8 @@ class ShapeBiasMetric(BaseMetric):
self.csv_dir = csv_dir
self.model_name = model_name
self.dataset_name = dataset_name
self.csv_path = self.create_csv()
if get_rank() == 0:
self.csv_path = self.create_csv()
def process(self, data_batch, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples.
@ -126,7 +128,7 @@ class ShapeBiasMetric(BaseMetric):
os.remove(csv_path)
directory = osp.dirname(csv_path)
if not osp.exists(directory):
os.makedirs(directory)
os.makedirs(directory, exist_ok=True)
with open(csv_path, 'w') as f:
writer = csv.writer(f)
writer.writerow([
@ -161,7 +163,8 @@ class ShapeBiasMetric(BaseMetric):
Returns:
dict: A dict of metrics.
"""
self.dump_results_to_csv(results)
if get_rank() == 0:
self.dump_results_to_csv(results)
metrics = dict()
metrics['accuracy/top1'] = np.mean([
result['pred_category'][0] == result['gt_category']

View File

@ -8,6 +8,7 @@ import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mmengine.logging import MMLogger
from utils import FormatStrFormatter, ShapeBias
# global default boundary settings for thin gray transparent
@ -46,12 +47,24 @@ parser.add_argument(
help= # noqa
'the plotting names for the plots of each model, and they should be in the same order as model_names' # noqa: E501
)
parser.add_argument(
'--delete-icons',
action='store_true',
help='whether to delete the icons after plotting')
humans = [
'subject-01', 'subject-02', 'subject-03', 'subject-04', 'subject-05',
'subject-06', 'subject-07', 'subject-08', 'subject-09', 'subject-10'
]
icon_names = [
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png', 'clock.png',
'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
def read_csvs(csv_dir: str) -> pd.DataFrame:
"""Reads all csv files in a directory and returns a single dataframe.
@ -83,7 +96,6 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
mpl.rcParams['font.serif'] = ['Times New Roman']
plt.figure(figsize=(9, 7))
df = read_csvs(args.csv_dir)
fontsize = 15
@ -218,19 +230,27 @@ def plot_shape_bias_matrixplot(args, analysis=ShapeBias()) -> None:
plt.close()
if __name__ == '__main__':
icon_names = [
'airplane.png', 'response_icons_vertical_reverse.png', 'bottle.png',
'car.png', 'oven.png', 'elephant.png', 'dog.png', 'boat.png',
'clock.png', 'chair.png', 'keyboard.png', 'bird.png', 'bicycle.png',
'response_icons_horizontal.png', 'cat.png', 'bear.png', 'colorbar.pdf',
'knife.png', 'response_icons_vertical.png', 'truck.png'
]
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
def check_icons() -> bool:
"""Check if icons are present, if not download them."""
if not osp.exists(ICONS_DIR):
return False
for icon_name in icon_names:
url = osp.join(root_url, icon_name)
os.system('wget -O {} {}'.format(osp.join(ICONS_DIR, icon_name), url))
if not osp.exists(osp.join(ICONS_DIR, icon_name)):
return False
return True
if __name__ == '__main__':
if not check_icons():
root_url = 'https://github.com/bethgelab/model-vs-human/raw/master/assets/icons' # noqa: E501
os.makedirs(ICONS_DIR, exist_ok=True)
MMLogger.get_current_instance().info(
f'Downloading icons to {ICONS_DIR}')
for icon_name in icon_names:
url = osp.join(root_url, icon_name)
os.system('wget -O {} {}'.format(
osp.join(ICONS_DIR, icon_name), url))
args = parser.parse_args()
assert len(args.model_names) * 3 == len(args.colors), 'Number of colors \
@ -260,5 +280,5 @@ if __name__ == '__main__':
args.plotting_names.append('Humans')
plot_shape_bias_matrixplot(args)
os.system('rm -rf {}'.format(ICONS_DIR))
if args.delete_icons:
os.system('rm -rf {}'.format(ICONS_DIR))