diff --git a/CMakeLists.txt b/CMakeLists.txt index b196c7329be8c815a4a94af06bc6aea820a87e4b..feadb4b8fe8427f6b4b7a1f10d561fc38b225e82 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,11 @@ endif() project(akg C CXX) +if (CMAKE_COMPILER_IS_GNUCC AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3 OR + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.4)) + message(SEND_ERROR "gcc version should be in [7.3, 9.4], while ${CMAKE_CXX_COMPILER_VERSION} is detected") +endif() + set(AKG_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TVM_DIR "${AKG_SOURCE_DIR}/third_party/incubator-tvm") diff --git a/third_party/incubator-tvm/python/tvm/contrib/nvcc.py b/third_party/incubator-tvm/python/tvm/contrib/nvcc.py index ea9b34fa4ca5450e8473f858d3c935895fd281f4..06f0394fe4218726ff0ff8f0c3ec51bbea446792 100644 --- a/third_party/incubator-tvm/python/tvm/contrib/nvcc.py +++ b/third_party/incubator-tvm/python/tvm/contrib/nvcc.py @@ -17,18 +17,37 @@ # pylint: disable=invalid-name # 2020.9.15 - Add default GPU arch: sm_70. # 2020.9.19 - Modify default GPU arch function. -# 2020.10.26- Add the logic of finding AkgReduce library. +# 2020.10.26 - Add the logic of finding AkgReduce library. +# 2022.9.5 - Add the check of gcc version. """Utility to invoke nvcc compiler in the system""" from __future__ import absolute_import as _abs import subprocess import os +import re import warnings from . import util from .. import ndarray as nd from ..api import register_func from .._ffi.base import py_str + +def get_gcc_version(): + """Get the gcc version from environment. + Return + ------ + ver : tuple of int + The version of gcc likes: (7.3.0), (9.4.0) + """ + f = os.popen("gcc --version") + for line in f: + if line[:3] == "gcc": + m = re.findall(r"\d+\.\d+\.\d+", line)[0] + a, b, c = m.split(".") + ver = (int(a), int(b), int(c)) + return ver + return None + def compile_cuda(code, target="ptx", arch=None, @@ -58,6 +77,12 @@ def compile_cuda(code, cubin : bytearray The bytearray of the cubin """ + # supported gcc version for cuda code generation + gcc_version = get_gcc_version() + if gcc_version is None or (7,3,0) > gcc_version or gcc_version > (9,4,0): + raise ValueError("gcc version is {}, not in range [7.3.0, 9.4.0]".format( + gcc_version)) + arch_exception = ["sm_00", None] temp = util.tempdir() if target not in ["cubin", "ptx", "fatbin"]: