Check input resolution
parent
81bf0b4033
commit
ff5f6bcd6c
|
@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module):
|
|||
Returns:
|
||||
features (List[torch.Tensor]): List of feature maps from each stage
|
||||
"""
|
||||
# Check input resolution
|
||||
assert input.shape[2:] == self.input_resolution, \
|
||||
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
||||
"update_resolution the provided method."
|
||||
# Perform patch embedding
|
||||
output: torch.Tensor = self.patch_embedding(input)
|
||||
# Init list to store feature
|
||||
|
@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module):
|
|||
Returns:
|
||||
classification (torch.Tensor): Classification of the shape (B, num_classes)
|
||||
"""
|
||||
# Check input resolution
|
||||
assert input.shape[2:] == self.input_resolution, \
|
||||
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
||||
"update_resolution the provided method."
|
||||
# Perform patch embedding
|
||||
output: torch.Tensor = self.patch_embedding(input)
|
||||
# Forward pass of each stage
|
||||
|
|
Loading…
Reference in New Issue