Assert messages added

pull/2164/head
user-miner1 2024-04-30 10:10:02 +03:00
parent 6de529bb3c
commit 740f4983b3
1 changed files with 4 additions and 4 deletions

View File

@ -632,7 +632,7 @@ class ConvNeXtBlock(nn.Module):
def window_partition(x, window_size: List[int]): def window_partition(x, window_size: List[int]):
B, H, W, C = x.shape B, H, W, C = x.shape
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
_assert(W % window_size[1] == 0, '') _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows return windows
@ -650,7 +650,7 @@ def window_reverse(windows, window_size: List[int], img_size: List[int]):
def grid_partition(x, grid_size: List[int]): def grid_partition(x, grid_size: List[int]):
B, H, W, C = x.shape B, H, W, C = x.shape
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
_assert(W % grid_size[1] == 0, '') _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
return windows return windows
@ -816,7 +816,7 @@ class ParallelPartitionAttention(nn.Module):
def window_partition_nchw(x, window_size: List[int]): def window_partition_nchw(x, window_size: List[int]):
B, C, H, W = x.shape B, C, H, W = x.shape
_assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
_assert(W % window_size[1] == 0, '') _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
return windows return windows
@ -834,7 +834,7 @@ def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
def grid_partition_nchw(x, grid_size: List[int]): def grid_partition_nchw(x, grid_size: List[int]):
B, C, H, W = x.shape B, C, H, W = x.shape
_assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
_assert(W % grid_size[1] == 0, '') _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
return windows return windows