fix unit tests & add torch 2.0 to CI

pull/538/head
humu789 2023-05-24 14:24:06 +08:00
parent c7e05f7cf2
commit 4c9938596c
5 changed files with 8 additions and 9 deletions

View File

@ -29,7 +29,7 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: [3.7] python-version: [3.7]
torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0, 2.0.0]
include: include:
- torch: 1.8.0 - torch: 1.8.0
torch_version: 1.8 torch_version: 1.8
@ -73,6 +73,10 @@ jobs:
torch_version: 1.13 torch_version: 1.13
torchvision: 0.14.0 torchvision: 0.14.0
python-version: 3.8 python-version: 3.8
- torch: 2.0.0
torch_version: 2.0
torchvision: 0.15.0
python-version: 3.8
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@ -1,16 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .compressor import GPTQCompressor from .compressor import GPTQCompressor
from .custom_autotune import (Autotuner, autotune,
matmul248_kernel_config_pruner)
from .gptq import GPTQMixIn from .gptq import GPTQMixIn
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
from .quantizer import Quantizer from .quantizer import Quantizer
__all__ = [ __all__ = [
'GPTQCompressor', 'GPTQCompressor',
'Autotuner',
'autotune',
'matmul248_kernel_config_pruner',
'GPTQMixIn', 'GPTQMixIn',
'GPTQConv2d', 'GPTQConv2d',
'GPTQLinear', 'GPTQLinear',

View File

@ -12,7 +12,7 @@ try:
import triton import triton
except ImportError: except ImportError:
from mmrazor.utils import get_package_placeholder from mmrazor.utils import get_package_placeholder
triton = get_package_placeholder('please install triton with pip') triton = get_package_placeholder('triton >= 2.0.0')
class Autotuner(triton.KernelInterface): class Autotuner(triton.KernelInterface):

View File

@ -350,7 +350,7 @@ try:
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
except: # noqa: E722 except: # noqa: E722
print('trioton not installed.') print('triton not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):

View File

@ -7,6 +7,6 @@ nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet. numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx onnx
pytest pytest
triton triton==2.0.0
xdoctest >= 0.10.0 xdoctest >= 0.10.0
yapf yapf