Fix unit tests & add torch 2.0 to CI

pull/538/head
humu789 2023-05-24 14:24:06 +08:00
parent c7e05f7cf2
commit 4c9938596c
5 changed files with 8 additions and 9 deletions

View File

@@ -29,7 +29,7 @@ jobs:
strategy:
matrix:
python-version: [3.7]
torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0, 2.0.0]
include:
- torch: 1.8.0
torch_version: 1.8
@@ -73,6 +73,10 @@ jobs:
torch_version: 1.13
torchvision: 0.14.0
python-version: 3.8
- torch: 2.0.0
torch_version: 2.0
torchvision: 0.15.0
python-version: 3.8
steps:
- uses: actions/checkout@v2

View File

@@ -1,16 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import GPTQCompressor
from .custom_autotune import (Autotuner, autotune,
matmul248_kernel_config_pruner)
from .gptq import GPTQMixIn
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
from .quantizer import Quantizer
__all__ = [
'GPTQCompressor',
'Autotuner',
'autotune',
'matmul248_kernel_config_pruner',
'GPTQMixIn',
'GPTQConv2d',
'GPTQLinear',

View File

@@ -12,7 +12,7 @@ try:
import triton
except ImportError:
from mmrazor.utils import get_package_placeholder
triton = get_package_placeholder('please install triton with pip')
triton = get_package_placeholder('triton >= 2.0.0')
class Autotuner(triton.KernelInterface):

View File

@@ -350,7 +350,7 @@ try:
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except: # noqa: E722
print('trioton not installed.')
print('triton not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):

View File

@@ -7,6 +7,6 @@ nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx
pytest
triton
triton==2.0.0
xdoctest >= 0.10.0
yapf