import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    Multiplies its input by a trainable factor that is initialised to
    ``scale`` and stored as a ``torch.float`` ``nn.Parameter``.

    Args:
        scale: Initial value of the scale factor. A Python number yields a
            0-d parameter; anything ``torch.as_tensor`` accepts (including an
            existing tensor) is converted to ``torch.float`` and copied.
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # as_tensor + detach().clone() gives a fresh, graph-free float tensor
        # that never aliases the caller's data, and (unlike
        # torch.tensor(existing_tensor)) does not emit a copy-construct warning
        # when ``scale`` is already a tensor.
        self.scale = nn.Parameter(
            torch.as_tensor(scale, dtype=torch.float).detach().clone()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied (with broadcasting) by the learnable scale."""
        return x * self.scale