mirror of https://github.com/alibaba/EasyCV.git
17 lines
364 B
Python
17 lines
364 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Scale(nn.Module):
    """Module that multiplies its input by a learnable scale parameter.

    Args:
        scale: Initial value of the scale factor. It is converted to a
            float tensor and registered as a module parameter.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        # Build the initial value first, then wrap it so it is registered
        # as a parameter of this module under the name ``scale``.
        initial_value = torch.tensor(scale, dtype=torch.float)
        self.scale = nn.Parameter(initial_value)

    def forward(self, x):
        """Return ``x`` multiplied by the learned scale."""
        scaled = x * self.scale
        return scaled
|