fixed intermediate output indices

This commit is contained in:
Dillon Laird 2023-11-22 12:31:31 -08:00 committed by Ross Wightman
parent 4d0737d5fa
commit 63ee54853c

View File

@ -1164,8 +1164,10 @@ class FastVit(nn.Module):
# For segmentation and detection, extract intermediate output
if self.fork_feat:
# add a norm layer for each output
self.out_indices = [0, 2, 4, 6]
# Add a norm layer for each output. self.stages is slightly different than self.network
# in the original code, the PatchEmbed layer is part of self.stages in this code where
# it was part of self.network in the original code. So we do not need to skip out indices.
self.out_indices = [0, 1, 2, 3]
for i_emb, i_layer in enumerate(self.out_indices):
if i_emb == 0 and os.environ.get("FORK_LAST3", None):
"""For RetinaNet, `start_level=1`. The first norm layer will not used.