mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Check input resolution
This commit is contained in:
parent
81bf0b4033
commit
ff5f6bcd6c
@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
features (List[torch.Tensor]): List of feature maps from each stage
|
features (List[torch.Tensor]): List of feature maps from each stage
|
||||||
"""
|
"""
|
||||||
|
# Check input resolution
|
||||||
|
assert input.shape[2:] == self.input_resolution, \
|
||||||
|
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
||||||
|
"update_resolution the provided method."
|
||||||
# Perform patch embedding
|
# Perform patch embedding
|
||||||
output: torch.Tensor = self.patch_embedding(input)
|
output: torch.Tensor = self.patch_embedding(input)
|
||||||
# Init list to store feature
|
# Init list to store feature
|
||||||
@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
classification (torch.Tensor): Classification of the shape (B, num_classes)
|
classification (torch.Tensor): Classification of the shape (B, num_classes)
|
||||||
"""
|
"""
|
||||||
|
# Check input resolution
|
||||||
|
assert input.shape[2:] == self.input_resolution, \
|
||||||
|
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
||||||
|
"update_resolution the provided method."
|
||||||
# Perform patch embedding
|
# Perform patch embedding
|
||||||
output: torch.Tensor = self.patch_embedding(input)
|
output: torch.Tensor = self.patch_embedding(input)
|
||||||
# Forward pass of each stage
|
# Forward pass of each stage
|
||||||
|
Loading…
x
Reference in New Issue
Block a user