EasyCV/easycv/models/utils/scale.py

17 lines
364 B
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
class Scale(nn.Module):
    """Learnable multiplicative factor applied to the input.

    Args:
        scale: Initial value of the factor (default ``1.0``). It is turned
            into a ``float32`` tensor with ``torch.tensor`` and registered as
            a trainable ``nn.Parameter`` named ``scale``. A plain number
            yields a 0-dim parameter.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        initial = torch.tensor(scale, dtype=torch.float32)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` multiplied by the learnable ``scale`` parameter.

        The product follows PyTorch's usual broadcasting/type-promotion rules.
        """
        scaled = x * self.scale
        return scaled