Assert messages added
parent
6de529bb3c
commit
740f4983b3
|
@ -632,7 +632,7 @@ class ConvNeXtBlock(nn.Module):
|
|||
def window_partition(x, window_size: List[int]):
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
|
||||
_assert(W % window_size[1] == 0, '')
|
||||
_assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
|
||||
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
|
||||
return windows
|
||||
|
@ -650,7 +650,7 @@ def window_reverse(windows, window_size: List[int], img_size: List[int]):
|
|||
def grid_partition(x, grid_size: List[int]):
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
|
||||
_assert(W % grid_size[1] == 0, '')
|
||||
_assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
|
||||
x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
|
||||
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
|
||||
return windows
|
||||
|
@ -816,7 +816,7 @@ class ParallelPartitionAttention(nn.Module):
|
|||
def window_partition_nchw(x, window_size: List[int]):
|
||||
B, C, H, W = x.shape
|
||||
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
|
||||
_assert(W % window_size[1] == 0, '')
|
||||
_assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
|
||||
x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
|
||||
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
|
||||
return windows
|
||||
|
@ -834,7 +834,7 @@ def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
|
|||
def grid_partition_nchw(x, grid_size: List[int]):
|
||||
B, C, H, W = x.shape
|
||||
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
|
||||
_assert(W % grid_size[1] == 0, '')
|
||||
_assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
|
||||
x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
|
||||
windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
|
||||
return windows
|
||||
|
|
Loading…
Reference in New Issue