14 lines
305 B
Python
Raw Normal View History

2020-06-16 00:05:18 +08:00
import torch
import torch.nn as nn
class Scale(nn.Module):
    """A learnable scale parameter.

    The module holds a single ``nn.Parameter`` named ``scale`` and
    multiplies its input by that parameter in ``forward``.

    Args:
        scale: Initial value of the scale factor. It is converted to a
            ``torch.float`` tensor before being wrapped in a parameter.
            Defaults to 1.0.
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # Wrapping the value in nn.Parameter registers it with the module,
        # so it is returned by ``parameters()`` and saved in ``state_dict()``.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied by the learnable scale factor."""
        return x * self.scale