# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from torch import nn

from .batch_norm import FrozenBatchNorm2d


class CNNBlockBase(nn.Module):
    """
    Base class for a CNN block described by its input channels, output
    channels and stride.

    The `forward()` method of a block must take and return NCHW tensors. It is
    free to perform arbitrary computation, as long as the result is consistent
    with the channels and stride the block was declared with.

    Attributes:
        in_channels (int): input channels of the block.
        out_channels (int): output channels of the block.
        stride (int): stride of the block.
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        Record the channel and stride specification of the block.

        Subclasses are expected to accept these same arguments in their own
        `__init__`.

        Args:
            in_channels (int): input channels of the block.
            out_channels (int): output channels of the block.
            stride (int): stride of the block.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.

        Every parameter of the block gets `requires_grad=False`, and all
        BatchNorm layers are converted to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False

        # The return value is intentionally not used: the block relies on the
        # conversion being applied to the submodules of `self`.
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self