diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 88928727..7314ca7b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,7 +29,7 @@ jobs: strategy: matrix: python-version: [3.7] - torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] + torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0, 2.0.0] include: - torch: 1.8.0 torch_version: 1.8 @@ -73,6 +73,10 @@ jobs: torch_version: 1.13 torchvision: 0.14.0 python-version: 3.8 + - torch: 2.0.0 + torch_version: 2.0 + torchvision: 0.15.0 + python-version: 3.8 steps: - uses: actions/checkout@v2 diff --git a/mmrazor/implementations/quantization/gptq/__init__.py b/mmrazor/implementations/quantization/gptq/__init__.py index c1f4b226..4981c801 100644 --- a/mmrazor/implementations/quantization/gptq/__init__.py +++ b/mmrazor/implementations/quantization/gptq/__init__.py @@ -1,16 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .compressor import GPTQCompressor -from .custom_autotune import (Autotuner, autotune, - matmul248_kernel_config_pruner) from .gptq import GPTQMixIn from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear from .quantizer import Quantizer __all__ = [ 'GPTQCompressor', - 'Autotuner', - 'autotune', - 'matmul248_kernel_config_pruner', 'GPTQMixIn', 'GPTQConv2d', 'GPTQLinear', diff --git a/mmrazor/implementations/quantization/gptq/custom_autotune.py b/mmrazor/implementations/quantization/gptq/custom_autotune.py index c41ba8bb..1bc0d7d5 100644 --- a/mmrazor/implementations/quantization/gptq/custom_autotune.py +++ b/mmrazor/implementations/quantization/gptq/custom_autotune.py @@ -12,7 +12,7 @@ try: import triton except ImportError: from mmrazor.utils import get_package_placeholder - triton = get_package_placeholder('please install triton with pip') + triton = get_package_placeholder('triton >= 2.0.0') class Autotuner(triton.KernelInterface): diff --git a/mmrazor/implementations/quantization/gptq/ops.py 
b/mmrazor/implementations/quantization/gptq/ops.py index 590febc0..b8c13941 100644 --- a/mmrazor/implementations/quantization/gptq/ops.py +++ b/mmrazor/implementations/quantization/gptq/ops.py @@ -350,7 +350,7 @@ try: c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) tl.store(c_ptrs, accumulator, mask=c_mask) except: # noqa: E722 - print('trioton not installed.') + print('triton not installed.') def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): diff --git a/requirements/tests.txt b/requirements/tests.txt index dacfe1c2..b025f5a6 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -7,6 +7,6 @@ nbformat numpy < 1.24.0 # A temporary solution for tests with mmdet. onnx pytest -triton +triton==2.0.0 xdoctest >= 0.10.0 yapf