mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
14 lines
305 B
Python
14 lines
305 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Scale(nn.Module):
    """A learnable scale parameter.

    Multiplies its input by a single trainable factor that is registered as
    an ``nn.Parameter``. Because it is a parameter, it is returned by
    ``parameters()`` and updated by the optimizer.

    Args:
        scale (float): Initial value of the learnable factor.
            Defaults to 1.0.
    """

    def __init__(self, scale: float = 1.0) -> None:
        super().__init__()
        # Stored as float32 (torch.float) regardless of the Python type of
        # ``scale``, so an int initial value still yields a trainable float.
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` multiplied elementwise by the learnable factor.

        The factor is a 0-d tensor for a scalar ``scale``, so it broadcasts
        against ``x`` of any shape.
        """
        return x * self.scale
|