mmcv/setup.py

229 lines
8.0 KiB
Python
Raw Normal View History

import glob
import os
import platform
import re
import setuptools
from pkg_resources import DistributionNotFound, get_distribution
from setuptools import dist, find_packages, setup

# numpy and Cython are imported at module level right below (and used by
# get_extensions() for numpy.get_include() / cythonize), so make sure they are
# available as build eggs before those imports run.
dist.Distribution().fetch_build_eggs(['Cython', 'numpy>=1.11.1'])
import numpy # NOQA: E402 # isort:skip
from Cython.Build import cythonize # NOQA: E402 # isort:skip
# Default build_ext command; get_extensions() rebinds build_cmd to a
# torch/parrots BuildExtension when torch is importable.
from Cython.Distutils import build_ext as build_cmd # NOQA: E402 # isort:skip
def choose_requirement(primary, secondary):
    """If some version of primary requirement installed, return primary, else
    return secondary.

    Args:
        primary (str): Preferred requirement string, e.g.
            ``'opencv-python-headless>=3'``.
        secondary (str): Fallback requirement string, used when no
            distribution of the primary project can be found.

    Returns:
        str: ``primary`` if its project is installed, otherwise ``secondary``.
    """
    # Keep only the project name: cut at the first whitespace, version
    # operator character (``! < > = ~``), extras bracket or environment-marker
    # separator, so specifiers such as ``~=`` or ``foo[bar]`` don't leak into
    # the name handed to get_distribution().
    name = re.split(r'[\s!<>=~\[;]', primary.strip(), maxsplit=1)[0]
    try:
        get_distribution(name)
    except DistributionNotFound:
        return secondary
    return str(primary)
def readme():
    """Return the full text of ``README.rst`` (decoded as UTF-8).

    The file is looked up relative to the current working directory; its
    content is used as the package's long description.
    """
    with open('README.rst', encoding='utf-8') as readme_file:
        return readme_file.read()
def get_version():
    """Read ``__version__`` by executing ``mmcv/version.py`` directly.

    The version file is executed rather than imported, so the package itself
    does not need to be importable at setup time.

    Returns:
        str: the value bound to ``__version__`` by the version file.

    Raises:
        KeyError: if the version file does not define ``__version__``.
    """
    version_file = 'mmcv/version.py'
    with open(version_file, 'r', encoding='utf-8') as f:
        source = f.read()
    # Execute into an explicit namespace. Relying on ``locals()`` after
    # ``exec`` breaks on Python 3.13+ (PEP 667): inside a function each
    # ``locals()`` call returns a fresh snapshot, so names bound by exec()
    # are not visible there and ``locals()['__version__']`` raises KeyError.
    namespace = {}
    exec(compile(source, version_file, 'exec'), namespace)
    return namespace['__version__']
def parse_requirements(fname='requirements.txt', with_version=True):
    """Parse the package dependencies listed in a requirements file.

    Version specifiers are kept when ``with_version`` is True and stripped
    otherwise. Environment markers (the part after ``;``) are always kept,
    except on Python 3.4.

    Supported line forms:
        * ``name``, ``name<op>version`` with ``op`` one of ``===``, ``~=``,
          ``==``, ``!=``, ``<=``, ``>=``, ``<``, ``>``, optionally followed
          by ``; <marker>``;
        * ``-r <file>``: include another requirements file (path resolved
          relative to the current working directory);
        * ``-e <url>#egg=<name>``: editable install, reported as ``<name>``.
    Blank lines and lines starting with ``#`` are ignored.

    Args:
        fname (str): path to requirements file
        with_version (bool, default=True): if True include version specs

    Returns:
        List[str]: list of requirements items (empty if ``fname`` does not
        exist)

    CommandLine:
        python -c "import setup; print(setup.parse_requirements())"
    """
    import re
    import sys
    from os.path import exists

    require_fpath = fname
    # Longer operators come first so that e.g. '>=' is not split as '>'.
    version_pat = re.compile(r'(===|~=|==|!=|<=|>=|<|>)')

    def parse_line(line):
        """Parse information from a line in a requirements text file."""
        if line.startswith('-r '):
            # Allow specifying requirements in other files
            target = line.split()[1]
            yield from parse_require_file(target)
            return
        info = {'line': line}
        if line.startswith('-e '):
            info['package'] = line.split('#egg=')[1]
        else:
            spec = line
            if ';' in spec:
                # Handle platform specific dependencies. The marker is split
                # off first so operators inside it (``python_version>="3"``)
                # are not mistaken for the package's own version specifier.
                # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
                spec, platform_deps = (
                    part.strip() for part in spec.split(';', 1))
                info['platform_deps'] = platform_deps
            parts = [p.strip() for p in version_pat.split(spec, maxsplit=1)]
            info['package'] = parts[0]
            if len(parts) > 1:
                op, version = parts[1:]
                info['version'] = (op, version)
        yield info

    def parse_require_file(fpath):
        """Yield parsed info for every non-blank, non-comment line."""
        with open(fpath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    yield from parse_line(line)

    def gen_packages_items():
        """Yield one requirement string per parsed entry."""
        if not exists(require_fpath):
            return
        for info in parse_require_file(require_fpath):
            parts = [info['package']]
            if with_version and 'version' in info:
                parts.extend(info['version'])
            # apparently package_deps are broken in 3.4
            if sys.version_info[:2] != (3, 4):
                platform_deps = info.get('platform_deps')
                if platform_deps is not None:
                    parts.append(';' + platform_deps)
            yield ''.join(parts)

    return list(gen_packages_items())
# Each entry is a (primary, secondary) pair: the primary requirement is used
# when some version of it is already installed, otherwise the secondary one.
CHOOSE_INSTALL_REQUIRES = [('opencv-python-headless>=3', 'opencv-python>=3')]

install_requires = parse_requirements()
install_requires += [
    choose_requirement(preferred, fallback)
    for preferred, fallback in CHOOSE_INSTALL_REQUIRES
]
def get_extensions():
    """Build the list of extension modules to compile for mmcv.

    Always includes the Cython optical-flow warping extension
    ``mmcv._flow_warp_ext``. If ``torch`` can be imported, also adds the
    ``mmcv._ext`` ops extension and rebinds the module-level ``build_cmd`` to
    the matching ``BuildExtension``:

    * ``torch.__version__ == 'parrots'``: built with parrots' build helpers
      from ``mmcv/ops/csrc/parrots`` (``cuda=True``);
    * otherwise: built with ``torch.utils.cpp_extension``, as a
      ``CUDAExtension`` when ``torch.cuda.is_available()`` or ``FORCE_CUDA=1``,
      else as a CPU-only ``CppExtension`` from the ``.cpp`` sources.

    Environment variables read:
        MMCV_CUDA_ARGS: extra nvcc flags, passed through as one list item.
        FORCE_CUDA: ``'1'`` selects the CUDA build regardless of
            ``torch.cuda.is_available()`` (torch branch only).
        MAX_JOBS: set to ``'4'`` if unset (torch branch only).

    Returns:
        list: extension objects to pass as ``ext_modules`` to ``setup()``.
    """
    extensions = []

    # On macOS (Darwin), compile and link the flow extension against libc++.
    if platform.system() == 'Darwin':
        extra_compile_args = ['-stdlib=libc++']
        extra_link_args = ['-stdlib=libc++']
    else:
        extra_compile_args = []
        extra_link_args = []

    # Cython optical-flow warping extension (C++ source + .pyx wrapper),
    # compiled against the numpy headers.
    ext_flow = setuptools.Extension(
        name='mmcv._flow_warp_ext',
        sources=[
            './mmcv/video/optflow_warp/flow_warp.cpp',
            './mmcv/video/optflow_warp/flow_warp_module.pyx'
        ],
        include_dirs=[numpy.get_include()],
        language='c++',
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args)
    extensions.extend(cythonize(ext_flow))

    try:
        import torch
        ext_name = 'mmcv._ext'
        if torch.__version__ == 'parrots':
            # parrots backend: use its own build helpers and the
            # csrc/parrots sources.
            from parrots.utils.build_extension import BuildExtension, Extension
            define_macros = [('MMCV_USE_PARROTS', None)]
            op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
            include_path = os.path.abspath('./mmcv/ops/csrc')
            # Optional extra nvcc flags, passed through as a single list item.
            cuda_args = os.getenv('MMCV_CUDA_ARGS')
            ext_ops = Extension(
                name=ext_name,
                sources=op_files,
                include_dirs=[include_path],
                define_macros=define_macros,
                extra_compile_args={
                    'nvcc': [cuda_args] if cuda_args else [],
                    'cxx': [],
                },
                cuda=True)
            extensions.append(ext_ops)
        else:
            from torch.utils.cpp_extension import (BuildExtension,
                                                   CUDAExtension, CppExtension)
            # prevent ninja from using too many resources
            os.environ.setdefault('MAX_JOBS', '4')
            define_macros = []
            # Per-compiler flag dict for torch's build; this rebinds the
            # Darwin flag list used above for the flow extension.
            extra_compile_args = {'cxx': []}

            # CUDA build when torch sees a GPU, or when FORCE_CUDA=1 forces it.
            if (torch.cuda.is_available()
                    or os.getenv('FORCE_CUDA', '0') == '1'):
                define_macros += [('MMCV_WITH_CUDA', None)]
                cuda_args = os.getenv('MMCV_CUDA_ARGS')
                extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
                # Every file in csrc/pytorch (presumably including the .cu
                # kernels -- confirm against the directory contents).
                op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
                extension = CUDAExtension
            else:
                print(f'Compiling {ext_name} without CUDA')
                # CPU-only build: C++ sources only.
                op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp')
                extension = CppExtension

            include_path = os.path.abspath('./mmcv/ops/csrc')
            ext_ops = extension(
                name=ext_name,
                sources=op_files,
                include_dirs=[include_path],
                define_macros=define_macros,
                extra_compile_args=extra_compile_args)
            extensions.append(ext_ops)
        # Replace the default Cython build_ext with the backend's
        # BuildExtension; setup() reads build_cmd for cmdclass['build_ext']
        # after calling this function.
        global build_cmd
        build_cmd = BuildExtension
    except ModuleNotFoundError:
        # NOTE(review): this also catches a missing parrots or
        # torch.utils.cpp_extension module, not only a missing torch, so the
        # message below can be misleading.
        print('Skip building ext ops due to the absence of torch.')
    return extensions
# Keyword arguments are evaluated left to right, so get_extensions() (for
# ext_modules) runs before build_cmd is read (for cmdclass). That order
# matters: get_extensions() may rebind the global build_cmd to a torch or
# parrots BuildExtension.
setup(
    name='mmcv',
    version=get_version(),
    description='OpenMMLab Computer Vision Foundation',
    long_description=readme(),
    keywords='computer vision',
    packages=find_packages(),
    include_package_data=True,
    classifiers=[
        'Development Status :: 4 - Beta',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Topic :: Utilities',
    ],
    url='https://github.com/open-mmlab/mmcv',
    author='MMCV Authors',
    author_email='chenkaidev@gmail.com',
    setup_requires=['pytest-runner'],
    tests_require=['pytest'],
    install_requires=install_requires,
    ext_modules=get_extensions(),
    cmdclass={'build_ext': build_cmd},
    zip_safe=False)