# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional

import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms


class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """
    __name__ = 'dummy'

    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Dict:
        """Return the view interval as a tuple (*vmin*, *vmax*)."""
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> Dict:
        """Return the data interval as a tuple (*vmin*, *vmax*)."""
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9


class TickHelper:
    """A helper class for ticks and tick labels."""
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Create a dummy axis if no axis is set."""
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        self.set_view_interval(vmin, vmax)
        self.set_data_interval(vmin, vmax)


class Formatter(TickHelper):
    """Create a string based on a tick value and location."""
    # some classes want to see all the locs to help format
    # individual ones
    locs = []

    def __call__(self, x: str, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position pos.

        ``pos=None`` indicates an unspecified location.

        This method must be overridden in the derived class.

        Args:
            x (str): The tick value.
            pos (Optional[Any]): The tick position. Defaults to None.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: pd.Series) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Args:
            values (pd.Series): The tick values.

        Returns:
            List[str]: The tick labels.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the position
        unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self.__call__(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string."""
        return ''

    def set_locs(self, locs: List[Any]) -> None:
        """Set the locations of the ticks.

        This method is called before computing the tick labels because some
        formatters need to know all tick locations to do so.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Some classes may want to replace a hyphen for minus with the proper
        Unicode symbol (U+2212) for typographical correctness.

        This is a
        helper method to perform such a replacement when it is enabled via
        :rc:`axes.unicode_minus`.

        Args:
            s (str): The string to replace the hyphen with the Unicode symbol.
        """
        return (s.replace('-', '\N{MINUS SIGN}')
                if mpl.rcParams['axes.unicode_minus'] else s)

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator."""
        pass


class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use mathtext
    to get a Unicode minus by wrapping the format specifier with $ (e.g.
    "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: str, pos: Optional[Any]) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        Args:
            x (str): The value to format.
            pos (Any): The position of the tick. Ignored.
        """
        return self.fmt % x


class ShapeBias:
    """Compute the shape bias of a model.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.
    """
    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid."""
        assert len(df) > 0, 'empty dataframe'

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Args:
            df (pd.DataFrame): The dataframe containing the data.

        Returns:
            Dict[str, float]: The shape bias.
        """
        self._check_dataframe(df)

        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']

        # remove those rows where shape = texture, i.e. no cue conflict present
        df2 = df.loc[df.correct_shape != df.correct_texture]
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)
        shape_bias = fraction_correct_shape / (
            fraction_correct_shape + fraction_correct_texture)

        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird '

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.
        """
        assert type(imagename) is str

        # remove unnecessary words
        a = imagename.split('_')[-1]
        # remove .png etc.
        b = a.split('.')[0]
        # get texture category (last word)
        c = b.split('-')[-1]
        # remove number, e.g. 'bird2' -> 'bird'
        d = ''.join([i for i in c if not i.isdigit()])
        return d