Always some torchscript issues
parent
528faa0e04
commit
437d344e03
|
@ -3,6 +3,8 @@ import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class SpaceToDepth(nn.Module):
|
class SpaceToDepth(nn.Module):
|
||||||
|
bs: torch.jit.Final[int]
|
||||||
|
|
||||||
def __init__(self, block_size=4):
|
def __init__(self, block_size=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert block_size == 4
|
assert block_size == 4
|
||||||
|
@ -12,7 +14,7 @@ class SpaceToDepth(nn.Module):
|
||||||
N, C, H, W = x.size()
|
N, C, H, W = x.size()
|
||||||
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
|
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
|
||||||
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
|
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
|
||||||
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
|
x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -66,6 +66,8 @@ class MlpWithDepthwiseConv(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
|
fused_attn: torch.jit.Final[bool]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim,
|
dim,
|
||||||
|
|
Loading…
Reference in New Issue