-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ignore more things * update setup.py * init pyproject.toml * update README * remove section seperations * remove unused variable * add docstring * read version from megablocks/__init__.py * fix reading repo version * add type hints * add classifiers, better long-description * update url * exclude more packages * add python_requires * update * use Composer's .gitignore * use Composer's pyproject.toml + my changes * remove my stk fork * remove composer specific * better error msg * fix typo * add correct versions of stanford-stk, grouped_gemm; add packaging * test in my GA * use my fork for testing * __init__.py only has __version * change GA to defaults * restore __init__.py * update readme * update readme
- Loading branch information
1 parent
28cd821
commit 52186a6
Showing
3 changed files
with
123 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright 2022 MegaBlocks Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""The MegaBlocks Version.""" | ||
|
||
__version__ = '0.5.1' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,77 @@ | ||
# Copyright 2024 MosaicML MegaBlocks authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""MegaBlocks package setup.""" | ||
|
||
import os | ||
import warnings | ||
from typing import Any, Dict, Mapping | ||
|
||
import torch | ||
from setuptools import find_packages, setup | ||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
||
if os.environ.get("TORCH_CUDA_ARCH_LIST"): | ||
# Let PyTorch builder to choose device to target for. | ||
device_capability = "" | ||
else: | ||
device_capability = torch.cuda.get_device_capability() | ||
device_capability = f"{device_capability[0]}{device_capability[1]}" | ||
|
||
nvcc_flags = [ | ||
"--ptxas-options=-v", | ||
"--optimize=2", | ||
] | ||
if device_capability: | ||
nvcc_flags.append( | ||
f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}" | ||
) | ||
|
||
ext_modules = [ | ||
CUDAExtension( | ||
"megablocks_ops", | ||
["csrc/ops.cu"], | ||
include_dirs=["csrc"], | ||
extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, | ||
) | ||
# We require torch in setup.py to build cpp extensions "ahead of time" | ||
# More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html | ||
try: | ||
import torch | ||
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, | ||
CUDAExtension,) | ||
except ModuleNotFoundError as e: | ||
raise ModuleNotFoundError( | ||
"No module named 'torch'. `torch` is required to install `MegaBlocks`." | ||
) from e | ||
|
||
|
||
_PACKAGE_NAME = 'megablocks' | ||
_PACKAGE_DIR = 'megablocks' | ||
_REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) | ||
_PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) | ||
|
||
# Read the package version | ||
# We can't use `.__version__` from the library since it's not installed yet | ||
version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') | ||
with open(version_path, encoding='utf-8') as f: | ||
version_globals: Dict[str, Any] = {} | ||
version_locals: Mapping[str, object] = {} | ||
content = f.read() | ||
exec(content, version_globals, version_locals) | ||
repo_version = version_locals['__version__'] | ||
|
||
|
||
with open('README.md', 'r', encoding='utf-8') as fh: | ||
long_description = fh.read() | ||
|
||
# Hide the content between <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN --> and | ||
# <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END --> tags in the README | ||
while True: | ||
start_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN -->' | ||
end_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END -->' | ||
start = long_description.find(start_tag) | ||
end = long_description.find(end_tag) | ||
if start == -1: | ||
assert end == -1, 'there should be a balanced number of start and ends' | ||
break | ||
else: | ||
assert end != -1, 'there should be a balanced number of start and ends' | ||
long_description = long_description[:start] + \ | ||
long_description[end + len(end_tag):] | ||
|
||
|
||
classifiers = [ | ||
'Programming Language :: Python :: 3', | ||
'Programming Language :: Python :: 3.9', | ||
'Programming Language :: Python :: 3.10', | ||
'Programming Language :: Python :: 3.11', | ||
'License :: OSI Approved :: BSD License', | ||
'Operating System :: Unix', | ||
] | ||
|
||
install_requires = [ | ||
'numpy>=1.21.5,<2.1.0', | ||
'packaging>=21.3.0,<24.2', | ||
'torch>=2.3.0,<2.4', | ||
'triton>=2.1.0', | ||
'stanford-stk @ git+https://[email protected]/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', | ||
'packaging>=21.3.0,<24.2', | ||
] | ||
|
||
extra_deps = {} | ||
|
@@ -61,23 +98,62 @@ | |
if key not in {'testing'} | ||
}) | ||
|
||
|
||
cmdclass = {} | ||
ext_modules = [] | ||
|
||
# Only install CUDA extensions if available | ||
if 'cu' in torch.__version__ and CUDA_HOME is not None: | ||
|
||
cmdclass = {'build_ext': BuildExtension} | ||
nvcc_flags = ['--ptxas-options=-v', '--optimize=2'] | ||
|
||
if os.environ.get('TORCH_CUDA_ARCH_LIST'): | ||
# Let PyTorch builder to choose device to target for. | ||
device_capability = '' | ||
else: | ||
device_capability_tuple = torch.cuda.get_device_capability() | ||
device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' | ||
|
||
if device_capability: | ||
nvcc_flags.append( | ||
f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' | ||
) | ||
|
||
ext_modules = [ | ||
CUDAExtension( | ||
'megablocks_ops', | ||
['csrc/ops.cu'], | ||
include_dirs=['csrc'], | ||
extra_compile_args={ | ||
'cxx': ['-fopenmp'], | ||
'nvcc': nvcc_flags | ||
}, | ||
) | ||
] | ||
elif CUDA_HOME is None: | ||
warnings.warn( | ||
'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + | ||
'Please install CUDA and ensure that the CUDA_HOME environment ' + | ||
'variable points to the installation location.') | ||
else: | ||
warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') | ||
|
||
|
||
setup( | ||
name="megablocks", | ||
version="0.5.1", | ||
author="Trevor Gale", | ||
author_email="[email protected]", | ||
description="MegaBlocks", | ||
long_description=open('README.md').read(), | ||
name=_PACKAGE_NAME, | ||
version=repo_version, | ||
author='Trevor Gale', | ||
author_email='[email protected]', | ||
description='MegaBlocks', | ||
long_description=long_description, | ||
long_description_content_type='text/markdown', | ||
url="https://github.com/stanford-futuredata/megablocks", | ||
classifiers=[ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: BSD License", | ||
"Operating System :: Unix", | ||
], | ||
packages=find_packages(), | ||
url='https://github.com/databricks/megablocks', | ||
classifiers=classifiers, | ||
packages=find_packages(exclude=['tests*', 'third_party*', 'yamls*', 'exp*', '.github*']), | ||
ext_modules=ext_modules, | ||
cmdclass={"build_ext": BuildExtension}, | ||
cmdclass=cmdclass, | ||
install_requires=install_requires, | ||
extras_require=extra_deps, | ||
python_requires='>=3.9', | ||
) |