EasyCV/easycv/models/utils/scale.py

17 lines
364 B
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
class Scale(nn.Module):
    """Multiply the input by a single learnable factor.

    The factor is held in ``self.scale`` as an ``nn.Parameter``, so it is
    registered with the module and returned by ``parameters()``.

    Args:
        scale: Initial value of the factor; converted to a float32 tensor.
            Defaults to 1.0.
    """

    def __init__(self, scale=1.0):
        super().__init__()
        # Build the tensor first, then wrap it so it is tracked as a
        # trainable parameter of this module.
        initial_value = torch.tensor(scale, dtype=torch.float32)
        self.scale = nn.Parameter(initial_value)

    def forward(self, x):
        """Return ``x`` multiplied by the learnable factor."""
        scaled = x * self.scale
        return scaled