from torch import nn as nn
from torch.nn import functional as F

from .adaptive_avgmax_pool import SelectAdaptivePool2d


class ClassifierHead(nn.Module):
    """Classifier Head w/ configurable global pooling and dropout."""

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        if num_classes > 0:
            self.fc = nn.Linear(in_chs * self.global_pool.feat_mult(), num_classes, bias=True)
        else:
            self.fc = nn.Identity()

    def forward(self, x):
        x = self.global_pool(x).flatten(1)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x