parent
597b8a61c7
commit
b8f42c70fa
|
@ -0,0 +1,36 @@
|
|||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(
|
||||
type='ResNetV1c',
|
||||
depth=50,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 1, 1),
|
||||
strides=(1, 2, 2, 2),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
neck=dict(
|
||||
type='FPN',
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
out_channels=256,
|
||||
num_outs=4),
|
||||
decode_head=dict(
|
||||
type='FPNHead',
|
||||
in_channels=[256, 256, 256, 256],
|
||||
in_index=[0, 1, 2, 3],
|
||||
feature_strides=[4, 8, 16, 32],
|
||||
channels=128,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
|
||||
# model training and testing settings
|
||||
train_cfg = dict()
|
||||
test_cfg = dict(mode='whole')
|
|
@ -0,0 +1,30 @@
|
|||
# Panoptic Feature Pyramid Networks
|
||||
|
||||
## Introduction
|
||||
```
|
||||
@article{Kirillov_2019,
|
||||
title={Panoptic Feature Pyramid Networks},
|
||||
ISBN={9781728132938},
|
||||
url={http://dx.doi.org/10.1109/CVPR.2019.00656},
|
||||
DOI={10.1109/cvpr.2019.00656},
|
||||
journal={2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||
publisher={IEEE},
|
||||
author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr},
|
||||
year={2019},
|
||||
month={Jun}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| FPN | R-50 | 512x1024 | 80000 | 2.8 | 13.54 | 74.52 | 76.08 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes_20200717_021437-94018a0d.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes-20200717_021437.log.json) |
|
||||
| FPN | R-101 | 512x1024 | 80000 | 3.9 | 10.29 | 75.80 | 77.40 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes-20200717_012416.log.json) |
|
||||
|
||||
### ADE20K
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|
||||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| FPN | R-50 | 512x512 | 160000 | 4.9 | 55.77 | 37.49 | 39.09 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k_20200718_131734-5b5a6ab9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k-20200718_131734.log.json) |
|
||||
| FPN | R-101 | 512x512 | 160000 | 5.9 | 40.58 | 39.35 | 40.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k_20200718_131734-306b5004.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k-20200718_131734.log.json) |
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './fpn_r50_512x1024_80k_cityscapes.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,2 @@
|
|||
_base_ = './fpn_r50_512x512_160k_ade20k.py'
|
||||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
|
|
@ -0,0 +1,4 @@
|
|||
_base_ = [
|
||||
'../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
|
||||
]
|
|
@ -0,0 +1,5 @@
|
|||
_base_ = [
|
||||
'../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
model = dict(decode_head=dict(num_classes=150))
|
|
@ -3,6 +3,7 @@ from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
|
|||
build_head, build_loss, build_segmentor)
|
||||
from .decode_heads import * # noqa: F401,F403
|
||||
from .losses import * # noqa: F401,F403
|
||||
from .necks import * # noqa: F401,F403
|
||||
from .segmentors import * # noqa: F401,F403
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -4,6 +4,7 @@ from .cc_head import CCHead
|
|||
from .da_head import DAHead
|
||||
from .enc_head import EncHead
|
||||
from .fcn_head import FCNHead
|
||||
from .fpn_head import FPNHead
|
||||
from .gc_head import GCHead
|
||||
from .nl_head import NLHead
|
||||
from .ocr_head import OCRHead
|
||||
|
@ -16,5 +17,5 @@ from .uper_head import UPerHead
|
|||
__all__ = [
|
||||
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
||||
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
||||
'EncHead', 'DepthwiseSeparableFCNHead'
|
||||
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
This head is the implementation of `Semantic FPN
|
||||
<https://arxiv.org/abs/1901.02446>`_.
|
||||
|
||||
Args:
|
||||
feature_strides (tuple[int]): The strides for input feature maps.
|
||||
stack_lateral. All strides suppose to be power of 2. The first
|
||||
one is of largest resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_strides, **kwargs):
|
||||
super(FPNHead, self).__init__(
|
||||
input_transform='multiple_select', **kwargs)
|
||||
assert len(feature_strides) == len(self.in_channels)
|
||||
assert min(feature_strides) == feature_strides[0]
|
||||
self.feature_strides = feature_strides
|
||||
|
||||
self.scale_heads = nn.ModuleList()
|
||||
for i in range(len(feature_strides)):
|
||||
head_length = max(
|
||||
1,
|
||||
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
|
||||
scale_head = []
|
||||
for k in range(head_length):
|
||||
scale_head.append(
|
||||
ConvModule(
|
||||
self.in_channels[i] if k == 0 else self.channels,
|
||||
self.channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg))
|
||||
if feature_strides[i] != feature_strides[0]:
|
||||
scale_head.append(
|
||||
nn.Upsample(
|
||||
scale_factor=2,
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners))
|
||||
self.scale_heads.append(nn.Sequential(*scale_head))
|
||||
|
||||
def forward(self, inputs):
|
||||
|
||||
x = self._transform_inputs(inputs)
|
||||
|
||||
output = self.scale_heads[0](x[0])
|
||||
for i in range(1, len(self.feature_strides)):
|
||||
# non inplace
|
||||
output = output + resize(
|
||||
self.scale_heads[i](x[i]),
|
||||
size=output.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
|
||||
output = self.cls_seg(output)
|
||||
return output
|
|
@ -0,0 +1,3 @@
|
|||
from .fpn import FPN
|
||||
|
||||
__all__ = ['FPN']
|
|
@ -0,0 +1,212 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, xavier_init
|
||||
|
||||
from ..builder import NECKS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
class FPN(nn.Module):
|
||||
"""Feature Pyramid Network.
|
||||
|
||||
This is an implementation of - Feature Pyramid Networks for Object
|
||||
Detection (https://arxiv.org/abs/1612.03144)
|
||||
|
||||
Args:
|
||||
in_channels (List[int]): Number of input channels per scale.
|
||||
out_channels (int): Number of output channels (used at each scale)
|
||||
num_outs (int): Number of output scales.
|
||||
start_level (int): Index of the start input backbone level used to
|
||||
build the feature pyramid. Default: 0.
|
||||
end_level (int): Index of the end input backbone level (exclusive) to
|
||||
build the feature pyramid. Default: -1, which means the last level.
|
||||
add_extra_convs (bool | str): If bool, it decides whether to add conv
|
||||
layers on top of the original feature maps. Default to False.
|
||||
If True, its actual mode is specified by `extra_convs_on_inputs`.
|
||||
If str, it specifies the source feature map of the extra convs.
|
||||
Only the following options are allowed
|
||||
|
||||
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
|
||||
- 'on_lateral': Last feature map after lateral convs.
|
||||
- 'on_output': The last output feature map after fpn convs.
|
||||
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
|
||||
on the original feature from the backbone. If True,
|
||||
it is equivalent to `add_extra_convs='on_input'`. If False, it is
|
||||
equivalent to set `add_extra_convs='on_output'`. Default to True.
|
||||
relu_before_extra_convs (bool): Whether to apply relu before the extra
|
||||
conv. Default: False.
|
||||
no_norm_on_lateral (bool): Whether to apply norm on lateral.
|
||||
Default: False.
|
||||
conv_cfg (dict): Config dict for convolution layer. Default: None.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
||||
act_cfg (str): Config dict for activation layer in ConvModule.
|
||||
Default: None.
|
||||
upsample_cfg (dict): Config dict for interpolate layer.
|
||||
Default: `dict(mode='nearest')`
|
||||
|
||||
Example:
|
||||
>>> import torch
|
||||
>>> in_channels = [2, 3, 5, 7]
|
||||
>>> scales = [340, 170, 84, 43]
|
||||
>>> inputs = [torch.rand(1, c, s, s)
|
||||
... for c, s in zip(in_channels, scales)]
|
||||
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
|
||||
>>> outputs = self.forward(inputs)
|
||||
>>> for i in range(len(outputs)):
|
||||
... print(f'outputs[{i}].shape = {outputs[i].shape}')
|
||||
outputs[0].shape = torch.Size([1, 11, 340, 340])
|
||||
outputs[1].shape = torch.Size([1, 11, 170, 170])
|
||||
outputs[2].shape = torch.Size([1, 11, 84, 84])
|
||||
outputs[3].shape = torch.Size([1, 11, 43, 43])
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_outs,
|
||||
start_level=0,
|
||||
end_level=-1,
|
||||
add_extra_convs=False,
|
||||
extra_convs_on_inputs=False,
|
||||
relu_before_extra_convs=False,
|
||||
no_norm_on_lateral=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=None,
|
||||
act_cfg=None,
|
||||
upsample_cfg=dict(mode='nearest')):
|
||||
super(FPN, self).__init__()
|
||||
assert isinstance(in_channels, list)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_ins = len(in_channels)
|
||||
self.num_outs = num_outs
|
||||
self.relu_before_extra_convs = relu_before_extra_convs
|
||||
self.no_norm_on_lateral = no_norm_on_lateral
|
||||
self.fp16_enabled = False
|
||||
self.upsample_cfg = upsample_cfg.copy()
|
||||
|
||||
if end_level == -1:
|
||||
self.backbone_end_level = self.num_ins
|
||||
assert num_outs >= self.num_ins - start_level
|
||||
else:
|
||||
# if end_level < inputs, no extra level is allowed
|
||||
self.backbone_end_level = end_level
|
||||
assert end_level <= len(in_channels)
|
||||
assert num_outs == end_level - start_level
|
||||
self.start_level = start_level
|
||||
self.end_level = end_level
|
||||
self.add_extra_convs = add_extra_convs
|
||||
assert isinstance(add_extra_convs, (str, bool))
|
||||
if isinstance(add_extra_convs, str):
|
||||
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
|
||||
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
|
||||
elif add_extra_convs: # True
|
||||
if extra_convs_on_inputs:
|
||||
# For compatibility with previous release
|
||||
# TODO: deprecate `extra_convs_on_inputs`
|
||||
self.add_extra_convs = 'on_input'
|
||||
else:
|
||||
self.add_extra_convs = 'on_output'
|
||||
|
||||
self.lateral_convs = nn.ModuleList()
|
||||
self.fpn_convs = nn.ModuleList()
|
||||
|
||||
for i in range(self.start_level, self.backbone_end_level):
|
||||
l_conv = ConvModule(
|
||||
in_channels[i],
|
||||
out_channels,
|
||||
1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
fpn_conv = ConvModule(
|
||||
out_channels,
|
||||
out_channels,
|
||||
3,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
|
||||
self.lateral_convs.append(l_conv)
|
||||
self.fpn_convs.append(fpn_conv)
|
||||
|
||||
# add extra conv layers (e.g., RetinaNet)
|
||||
extra_levels = num_outs - self.backbone_end_level + self.start_level
|
||||
if self.add_extra_convs and extra_levels >= 1:
|
||||
for i in range(extra_levels):
|
||||
if i == 0 and self.add_extra_convs == 'on_input':
|
||||
in_channels = self.in_channels[self.backbone_end_level - 1]
|
||||
else:
|
||||
in_channels = out_channels
|
||||
extra_fpn_conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
inplace=False)
|
||||
self.fpn_convs.append(extra_fpn_conv)
|
||||
|
||||
# default init_weights for conv(msra) and norm in ConvModule
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
xavier_init(m, distribution='uniform')
|
||||
|
||||
def forward(self, inputs):
|
||||
assert len(inputs) == len(self.in_channels)
|
||||
|
||||
# build laterals
|
||||
laterals = [
|
||||
lateral_conv(inputs[i + self.start_level])
|
||||
for i, lateral_conv in enumerate(self.lateral_convs)
|
||||
]
|
||||
|
||||
# build top-down path
|
||||
used_backbone_levels = len(laterals)
|
||||
for i in range(used_backbone_levels - 1, 0, -1):
|
||||
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
|
||||
# it cannot co-exist with `size` in `F.interpolate`.
|
||||
if 'scale_factor' in self.upsample_cfg:
|
||||
laterals[i - 1] += F.interpolate(laterals[i],
|
||||
**self.upsample_cfg)
|
||||
else:
|
||||
prev_shape = laterals[i - 1].shape[2:]
|
||||
laterals[i - 1] += F.interpolate(
|
||||
laterals[i], size=prev_shape, **self.upsample_cfg)
|
||||
|
||||
# build outputs
|
||||
# part 1: from original levels
|
||||
outs = [
|
||||
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
|
||||
]
|
||||
# part 2: add extra levels
|
||||
if self.num_outs > len(outs):
|
||||
# use max pool to get more levels on top of outputs
|
||||
# (e.g., Faster R-CNN, Mask R-CNN)
|
||||
if not self.add_extra_convs:
|
||||
for i in range(self.num_outs - used_backbone_levels):
|
||||
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
|
||||
# add conv layers on top of original feature maps (RetinaNet)
|
||||
else:
|
||||
if self.add_extra_convs == 'on_input':
|
||||
extra_source = inputs[self.backbone_end_level - 1]
|
||||
elif self.add_extra_convs == 'on_lateral':
|
||||
extra_source = laterals[-1]
|
||||
elif self.add_extra_convs == 'on_output':
|
||||
extra_source = outs[-1]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
|
||||
for i in range(used_backbone_levels + 1, self.num_outs):
|
||||
if self.relu_before_extra_convs:
|
||||
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
|
||||
else:
|
||||
outs.append(self.fpn_convs[i](outs[-1]))
|
||||
return tuple(outs)
|
|
@ -153,6 +153,10 @@ def test_encnet_forward():
|
|||
'encnet/encnet_r50-d8_512x1024_40k_cityscapes.py')
|
||||
|
||||
|
||||
def test_sem_fpn_forward():
|
||||
_test_encoder_decoder_forward('sem_fpn/fpn_r50_512x1024_80k_cityscapes.py')
|
||||
|
||||
|
||||
def get_world_size(process_group):
|
||||
|
||||
return 1
|
||||
|
|
|
@ -6,8 +6,7 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.utils.parrots_wrapper import SyncBatchNorm
|
||||
|
||||
from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
|
||||
DepthwiseSeparableASPPHead,
|
||||
DepthwiseSeparableFCNHead, EncHead,
|
||||
DepthwiseSeparableASPPHead, EncHead,
|
||||
FCNHead, GCHead, NLHead, OCRHead,
|
||||
PSAHead, PSPHead, UPerHead)
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
@ -540,37 +539,3 @@ def test_dw_aspp_head():
|
|||
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
|
||||
outputs = head(inputs)
|
||||
assert outputs.shape == (1, head.num_classes, 45, 45)
|
||||
|
||||
|
||||
def test_sep_fcn_head():
|
||||
# test sep_fcn_head with concat_input=False
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=128,
|
||||
channels=128,
|
||||
concat_input=False,
|
||||
num_classes=19,
|
||||
in_index=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||
x = [torch.rand(2, 128, 32, 32)]
|
||||
output = head(x)
|
||||
assert output.shape == (2, head.num_classes, 32, 32)
|
||||
assert not head.concat_input
|
||||
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
assert head.conv_seg.kernel_size == (1, 1)
|
||||
|
||||
head = DepthwiseSeparableFCNHead(
|
||||
in_channels=64,
|
||||
channels=64,
|
||||
concat_input=True,
|
||||
num_classes=19,
|
||||
in_index=-1,
|
||||
norm_cfg=dict(type='BN', requires_grad=True, momentum=0.01))
|
||||
x = [torch.rand(3, 64, 32, 32)]
|
||||
output = head(x)
|
||||
assert output.shape == (3, head.num_classes, 32, 32)
|
||||
assert head.concat_input
|
||||
from mmseg.ops.separable_conv_module import DepthwiseSeparableConvModule
|
||||
assert isinstance(head.convs[0], DepthwiseSeparableConvModule)
|
||||
assert isinstance(head.convs[1], DepthwiseSeparableConvModule)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
import torch
|
||||
|
||||
from mmseg.models import FPN
|
||||
|
||||
|
||||
def test_fpn():
|
||||
in_channels = [256, 512, 1024, 2048]
|
||||
inputs = [
|
||||
torch.randn(1, c, 56 // 2**i, 56 // 2**i)
|
||||
for i, c in enumerate(in_channels)
|
||||
]
|
||||
|
||||
fpn = FPN(in_channels, 256, len(in_channels))
|
||||
outputs = fpn(inputs)
|
||||
assert outputs[0].shape == torch.Size([1, 256, 56, 56])
|
||||
assert outputs[1].shape == torch.Size([1, 256, 28, 28])
|
||||
assert outputs[2].shape == torch.Size([1, 256, 14, 14])
|
||||
assert outputs[3].shape == torch.Size([1, 256, 7, 7])
|
Loading…
Reference in New Issue