# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Multiply the input by a single learnable scale parameter.

    The factor is stored as an ``nn.Parameter`` named ``scale`` so it is
    registered with the module and returned by ``parameters()``.
    """

    def __init__(self, scale: float = 1.0) -> None:
        """Create the module.

        Args:
            scale: Initial value of the learnable factor. Defaults to 1.0,
                which makes the freshly built module an identity mapping.
        """
        super().__init__()
        # Stored as float32 regardless of the Python type of ``scale``.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the learnable ``scale`` parameter."""
        return x * self.scale