fix ut & add torch2.0 for ci
parent
c7e05f7cf2
commit
4c9938596c
|
@ -29,7 +29,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-version: [3.7]
|
python-version: [3.7]
|
||||||
torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
|
torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0, 2.0.0]
|
||||||
include:
|
include:
|
||||||
- torch: 1.8.0
|
- torch: 1.8.0
|
||||||
torch_version: 1.8
|
torch_version: 1.8
|
||||||
|
@ -73,6 +73,10 @@ jobs:
|
||||||
torch_version: 1.13
|
torch_version: 1.13
|
||||||
torchvision: 0.14.0
|
torchvision: 0.14.0
|
||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
|
- torch: 2.0.0
|
||||||
|
torch_version: 2.0.0
|
||||||
|
torchvision: 0.15.0
|
||||||
|
python-version: 3.8
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
|
@ -1,16 +1,11 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .compressor import GPTQCompressor
|
from .compressor import GPTQCompressor
|
||||||
from .custom_autotune import (Autotuner, autotune,
|
|
||||||
matmul248_kernel_config_pruner)
|
|
||||||
from .gptq import GPTQMixIn
|
from .gptq import GPTQMixIn
|
||||||
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
|
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
|
||||||
from .quantizer import Quantizer
|
from .quantizer import Quantizer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'GPTQCompressor',
|
'GPTQCompressor',
|
||||||
'Autotuner',
|
|
||||||
'autotune',
|
|
||||||
'matmul248_kernel_config_pruner',
|
|
||||||
'GPTQMixIn',
|
'GPTQMixIn',
|
||||||
'GPTQConv2d',
|
'GPTQConv2d',
|
||||||
'GPTQLinear',
|
'GPTQLinear',
|
||||||
|
|
|
@ -12,7 +12,7 @@ try:
|
||||||
import triton
|
import triton
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from mmrazor.utils import get_package_placeholder
|
from mmrazor.utils import get_package_placeholder
|
||||||
triton = get_package_placeholder('please install triton with pip')
|
triton = get_package_placeholder('triton >= 2.0.0')
|
||||||
|
|
||||||
|
|
||||||
class Autotuner(triton.KernelInterface):
|
class Autotuner(triton.KernelInterface):
|
||||||
|
|
|
@ -350,7 +350,7 @@ try:
|
||||||
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
|
||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
except: # noqa: E722
|
except: # noqa: E722
|
||||||
print('trioton not installed.')
|
print('triton not installed.')
|
||||||
|
|
||||||
|
|
||||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||||
|
|
|
@ -7,6 +7,6 @@ nbformat
|
||||||
numpy < 1.24.0 # A temporary solution for tests with mmdet.
|
numpy < 1.24.0 # A temporary solution for tests with mmdet.
|
||||||
onnx
|
onnx
|
||||||
pytest
|
pytest
|
||||||
triton
|
triton==2.0.0
|
||||||
xdoctest >= 0.10.0
|
xdoctest >= 0.10.0
|
||||||
yapf
|
yapf
|
||||||
|
|
Loading…
Reference in New Issue