2020-09-02 18:49:39 +08:00

14 lines
305 B
Python

import torch
import torch.nn as nn
class Scale(nn.Module):
    """Multiply the input by a single trainable factor.

    Args:
        scale: Initial value of the factor. It is converted with
            ``torch.tensor(..., dtype=torch.float32)`` and registered as the
            trainable parameter ``self.scale``. A Python number therefore
            yields a 0-d (scalar) parameter.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        # Copy the initial value into a fresh float32 tensor, then register it
        # as a Parameter so the optimizer sees it and it lands in state_dict
        # under the key "scale".
        initial = torch.tensor(scale, dtype=torch.float32)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` multiplied by the learned factor.

        With a 0-d factor, this scales every element of a tensor ``x``.
        Gradients flow to both ``x`` and the factor.
        """
        factor = self.scale
        return x * factor