59 lines
1.1 KiB
Python
59 lines
1.1 KiB
Python
from enum import Enum
|
|
from typing import Union
|
|
|
|
import torch
|
|
|
|
|
|
class Format(str, Enum):
    """Axis layout of a feature tensor.

    Letters name the tensor axes in order: C is the channel axis, H and W the
    spatial axes, and L the H and W axes flattened into a single axis. N is the
    leading axis (conventionally the batch).

    Each member's value equals its name. Because of the ``str`` mixin, members
    compare equal to plain strings (``Format.NHWC == "NHWC"``), and
    ``Format("NHWC")`` looks a member up by value.
    """

    # 4-D layouts: channels-first, then channels-last.
    NCHW = "NCHW"
    NHWC = "NHWC"
    # 3-D layouts with the spatial axes flattened into one.
    NCL = "NCL"
    NLC = "NLC"
|
|
|
|
|
|
# Type alias for layout arguments: either a Format member or its plain string
# value (e.g. 'NHWC').
FormatT = Union[str, Format]
|
|
|
|
|
|
def get_spatial_dim(fmt: FormatT):
    """Return the indices of the spatial axes for layout ``fmt``.

    ``fmt`` may be a ``Format`` member or its string value. It is normalized
    with ``Format(fmt)``, which raises ``ValueError`` for an unknown string.

    Returns a tuple of axis indices: a single axis for the flattened
    ``NLC``/``NCL`` layouts, and the H and W axes for ``NHWC``/``NCHW``.
    """
    layout = Format(fmt)
    spatial_axes = {
        Format.NLC: (1,),
        Format.NCL: (2,),
        Format.NHWC: (1, 2),
    }
    # NCHW is the only member not listed; its H and W sit at positions 2 and 3.
    return spatial_axes.get(layout, (2, 3))
|
|
|
|
|
|
def get_channel_dim(fmt: FormatT):
    """Return the index of the channel (C) axis for layout ``fmt``.

    ``fmt`` may be a ``Format`` member or its string value. It is normalized
    with ``Format(fmt)``, which raises ``ValueError`` for an unknown string.
    """
    layout = Format(fmt)
    # Channels-first layouts keep C right after the leading axis.
    if layout in (Format.NCHW, Format.NCL):
        return 1
    # Channels-last layouts put C on the final axis: 3 for NHWC, 2 for NLC.
    return 3 if layout is Format.NHWC else 2
|
|
|
|
|
|
def nchw_to(x: torch.Tensor, fmt: FormatT) -> torch.Tensor:
    """Rearrange a tensor from ``NCHW`` layout to the layout ``fmt``.

    Args:
        x: Tensor in NCHW layout. Presumably 4-D (N, C, H, W); this is not
            checked here, but ``permute(0, 2, 3, 1)`` requires exactly 4 dims.
        fmt: Target layout, either a ``Format`` member or its string value.

    Returns:
        ``x`` rearranged to ``fmt``. For ``NCHW`` the input is returned
        unchanged. ``permute``/``transpose`` return views, so the result may
        be non-contiguous.

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    # Normalize (like get_spatial_dim/get_channel_dim do) so an unrecognized
    # string fails loudly instead of silently returning the NCHW input.
    fmt = Format(fmt)
    if fmt is Format.NHWC:
        x = x.permute(0, 2, 3, 1)
    elif fmt is Format.NLC:
        # (N, C, H, W) -> (N, C, H*W) -> (N, H*W, C)
        x = x.flatten(2).transpose(1, 2)
    elif fmt is Format.NCL:
        # (N, C, H, W) -> (N, C, H*W)
        x = x.flatten(2)
    return x
|
|
|
|
|
|
def nhwc_to(x: torch.Tensor, fmt: FormatT) -> torch.Tensor:
    """Rearrange a tensor from ``NHWC`` layout to the layout ``fmt``.

    Args:
        x: Tensor in NHWC layout. Presumably 4-D (N, H, W, C); this is not
            checked here, but ``permute(0, 3, 1, 2)`` requires exactly 4 dims.
        fmt: Target layout, either a ``Format`` member or its string value.

    Returns:
        ``x`` rearranged to ``fmt``. For ``NHWC`` the input is returned
        unchanged. ``permute``/``transpose`` return views, so the result may
        be non-contiguous.

    Raises:
        ValueError: If ``fmt`` is not a valid ``Format`` value.
    """
    # Normalize (like get_spatial_dim/get_channel_dim do) so an unrecognized
    # string fails loudly instead of silently returning the NHWC input.
    fmt = Format(fmt)
    if fmt is Format.NCHW:
        x = x.permute(0, 3, 1, 2)
    elif fmt is Format.NLC:
        # (N, H, W, C) -> (N, H*W, C)
        x = x.flatten(1, 2)
    elif fmt is Format.NCL:
        # (N, H, W, C) -> (N, H*W, C) -> (N, C, H*W)
        x = x.flatten(1, 2).transpose(1, 2)
    return x
|