mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
22 lines
550 B
Python
22 lines
550 B
Python
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class FReLU(nn.Module):
|
||
|
|
||
|
def __init__(self, in_channel):
|
||
|
super().__init__()
|
||
|
self.depthwise_conv_bn = nn.Sequential(
|
||
|
nn.Conv2d(
|
||
|
in_channel,
|
||
|
in_channel,
|
||
|
3,
|
||
|
padding=1,
|
||
|
groups=in_channel,
|
||
|
bias=False), nn.BatchNorm2d(in_channel))
|
||
|
|
||
|
def forward(self, x):
|
||
|
funnel_x = self.depthwise_conv_bn(x)
|
||
|
return torch.max(x, funnel_x)
|