Skip to content

Commit

Permalink
added versioneer and updated setuptools building
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Feb 23, 2024
1 parent 166b7db commit 12d97aa
Show file tree
Hide file tree
Showing 4 changed files with 758 additions and 26 deletions.
40 changes: 40 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[project]
name = "torchmd-net"
description = "TorchMD-Net package"
authors = [{ name = "Acellera", email = "[email protected]" }]
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
dynamic = ["version"]
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: POSIX :: Linux",
]
dependencies = []

[project.urls]
"Homepage" = "https://github.com/torchmd/torchmd-net"
"Bug Tracker" = "https://github.com/torchmd/torchmd-net/issues"

[project.scripts]
torchmd-train = "torchmdnet.scripts.train:main"

[tool.setuptools.packages.find]
where = [""]
include = ["torchmdnet*"]
namespaces = false

[tool.setuptools.package-data]
torchmdnet = ["extensions/torchmdnet_extensions.so"]

[tool.versioneer]
VCS = "git"
style = "pep440"
versionfile_source = "torchmdnet/_version.py"
versionfile_build = "torchmdnet/_version.py"
tag_prefix = ""
parentdir_prefix = "torchmdnet-"

[build-system]
requires = ["setuptools", "toml", "versioneer[toml]==0.28", "torch<2.2"]
build-backend = "setuptools.build_meta"
48 changes: 22 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,24 @@
# Distributed under the MIT License.
# (See accompanying file README.md file or copy at http://opensource.org/licenses/MIT)

import subprocess
from setuptools import setup, find_packages
from setuptools import setup
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, include_paths, CppExtension
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
include_paths,
CppExtension,
)
import versioneer
import os

try:
version = (
subprocess.check_output(["git", "describe", "--abbrev=0", "--tags"])
.strip()
.decode("utf-8")
)
except:
print("Failed to retrieve the current version, defaulting to 0")
version = "0"
# If CPU_ONLY is defined
force_cpu_only = os.environ.get("CPU_ONLY", None) is not None
use_cuda = torch.cuda._is_compiled() if not force_cpu_only else False


def set_torch_cuda_arch_list():
""" Set the CUDA arch list according to the architectures the current torch installation was compiled for.
"""Set the CUDA arch list according to the architectures the current torch installation was compiled for.
This function is a no-op if the environment variable TORCH_CUDA_ARCH_LIST is already set or if torch was not compiled with CUDA support.
"""
if not os.environ.get("TORCH_CUDA_ARCH_LIST"):
Expand All @@ -32,31 +30,29 @@ def set_torch_cuda_arch_list():
formatted_versions += "+PTX"
os.environ["TORCH_CUDA_ARCH_LIST"] = formatted_versions


set_torch_cuda_arch_list()

extension_root= os.path.join("torchmdnet", "extensions")
neighbor_sources=["neighbors_cpu.cpp"]
extension_root = os.path.join("torchmdnet", "extensions")
neighbor_sources = ["neighbors_cpu.cpp"]
if use_cuda:
neighbor_sources.append("neighbors_cuda.cu")
neighbor_sources = [os.path.join(extension_root, "neighbors", source) for source in neighbor_sources]
neighbor_sources = [
os.path.join(extension_root, "neighbors", source) for source in neighbor_sources
]

ExtensionType = CppExtension if not use_cuda else CUDAExtension
extensions = ExtensionType(
name='torchmdnet.extensions.torchmdnet_extensions',
name="torchmdnet.extensions.torchmdnet_extensions",
sources=[os.path.join(extension_root, "extensions.cpp")] + neighbor_sources,
include_dirs=include_paths(),
define_macros=[('WITH_CUDA', 1)] if use_cuda else [],
define_macros=[("WITH_CUDA", 1)] if use_cuda else [],
)

if __name__ == "__main__":
buildext = BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
setup(
name="torchmd-net",
version=version,
packages=find_packages(),
ext_modules=[extensions],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)},
include_package_data=True,
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
package_data={"torchmdnet": ["extensions/torchmdnet_extensions.so"]},
version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass({"build_ext": buildext}),
)
3 changes: 3 additions & 0 deletions torchmdnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from torchmdnet import _version

__version__ = _version.get_versions()["version"]
Loading

0 comments on commit 12d97aa

Please sign in to comment.