From 9fd16d6da9f95c2ba114be64146cf8139115c7ae Mon Sep 17 00:00:00 2001 From: Soran Ghaderi Date: Sat, 5 Oct 2024 17:54:50 +0100 Subject: [PATCH] refactor: update pyproject.toml using actions 4 --- .github/workflows/tag-release.yml | 86 +++++++------------------------ pyproject.toml | 4 -- torchebm/core/energy_function.py | 6 --- torchebm/core/sampler.py | 2 - torchebm/models/base_model.py | 3 +- 5 files changed, 22 insertions(+), 79 deletions(-) diff --git a/.github/workflows/tag-release.yml b/.github/workflows/tag-release.yml index 5f0ef3e..d834a95 100644 --- a/.github/workflows/tag-release.yml +++ b/.github/workflows/tag-release.yml @@ -1,52 +1,9 @@ -#name: Tag Release -# -#on: -# push: -# branches: -# - master -# -#jobs: -# tag: -# runs-on: ubuntu-latest -# -# steps: -# - name: Checkout code -# uses: actions/checkout@v3 -# -# - name: Get latest tag -# id: get_latest_tag -# run: | -# git fetch --tags -# TAG=$(git tag --sort=-creatordate | head -n 1) -# echo "Latest tag is $TAG" -# echo "::set-output name=latest::$TAG" -# -# - name: Determine new version -# id: new_version -# run: | -# latest_tag="${{ steps.get_latest_tag.outputs.latest }}" -# if [ -z "$latest_tag" ]; then -# new_version="v0.1.0" -# else -# # Increment the patch version -# new_version=$(echo $latest_tag | awk -F. -v OFS=. '{$NF++;print}') -# fi -# echo "New version is $new_version" -# echo "::set-output name=new_version::$new_version" -# -# - name: Create new tag -# run: | -# git tag "${{ steps.new_version.outputs.new_version }}" -# git push origin "${{ steps.new_version.outputs.new_version }}" - #name: Release and Publish # #on: # push: # branches: # - master # Adjust as necessary -# release: -# types: [published] # Trigger on release published # #jobs: # release: @@ -69,10 +26,10 @@ # run: | # latest_tag="${{ steps.get_latest_tag.outputs.latest }}" # if [ -z "$latest_tag" ]; then -# new_version="v0.1.0" +# new_version="0.1.0" # No 'v' prefix for version number # else # # Increment the patch version -# new_version=$(echo $latest_tag | awk -F. -v OFS=. '{$NF++;print}') +# new_version=$(echo $latest_tag | awk -F. -v OFS=. '{$NF++;print}' | sed 's/^v//') # fi # echo "New version is $new_version" # echo "::set-output name=new_version::$new_version" @@ -83,8 +40,8 @@ # # - name: Create new tag # run: | -# git tag "${{ steps.new_version.outputs.new_version }}" -# git push origin "${{ steps.new_version.outputs.new_version }}" +# git tag "v${{ steps.new_version.outputs.new_version }}" +# git push origin "v${{ steps.new_version.outputs.new_version }}" # # - name: Set up Python # uses: actions/setup-python@v4 @@ -99,14 +56,6 @@ # - name: Set PYTHONPATH # run: echo "PYTHONPATH=$(pwd)/torchebm" >> $GITHUB_ENV # -# -# - name: Install package -# run: | -# pip install . -# -# - name: Set PYTHONPATH -# run: echo "PYTHONPATH=$(pwd)/torchebm" >> $GITHUB_ENV -# # - name: Install package # run: | # pip install . @@ -127,13 +76,14 @@ # password: ${{ secrets.PYPI_API_TOKEN_EBM }} - name: Release and Publish on: push: branches: - master # Adjust as necessary + paths: + - 'pyproject.toml' # Trigger on changes to pyproject.toml jobs: release: @@ -143,6 +93,13 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Get current version from pyproject.toml + id: get_version + run: | + version=$(grep '^version =' pyproject.toml | sed 's/version = //; s/"//g') + echo "Current version is $version" + echo "::set-output name=current_version::$version" + - name: Get latest tag id: get_latest_tag run: | @@ -155,18 +112,15 @@ jobs: id: new_version run: | latest_tag="${{ steps.get_latest_tag.outputs.latest }}" - if [ -z "$latest_tag" ]; then - new_version="0.1.0" # No 'v' prefix for version number + current_version="${{ steps.get_version.outputs.current_version }}" + + if [ "$latest_tag" == "$current_version" ]; then + echo "No new version detected." + exit 1 # Exit if no new version is found else - # Increment the patch version - new_version=$(echo $latest_tag | awk -F. -v OFS=. '{$NF++;print}' | sed 's/^v//') + echo "New version detected: $current_version" + echo "::set-output name=new_version::$current_version" fi - echo "New version is $new_version" - echo "::set-output name=new_version::$new_version" - - - name: Update pyproject.toml version - run: | - sed -i "s/^version = .*/version = \"${{ steps.new_version.outputs.new_version }}\"/" pyproject.toml - name: Create new tag run: | diff --git a/pyproject.toml b/pyproject.toml index 9dffbe6..fce1dde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,15 +5,11 @@ build-backend = "setuptools.build_meta" [project] name = "torchEBM" # Replace with your package name version = "0.1.0" -#dynamic = ["version"] description = "Components and algorithms for energy-based models" readme = "README.md" license = { file = "LICENSE" } authors = [{ name="Soran Ghaderi", email="soran.gdr.cs@gmail.com" }] dependencies = [ - # List your dependencies here, e.g., "numpy", "torch" ] [project.urls] homepage = "https://github.com/soran-ghaderi/torchebm" - - diff --git a/torchebm/core/energy_function.py b/torchebm/core/energy_function.py index e74ca42..0a03e40 100644 --- a/torchebm/core/energy_function.py +++ b/torchebm/core/energy_function.py @@ -9,9 +9,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @abstractmethod def gradient(self, x: torch.Tensor) -> torch.Tensor: pass - - def cuda_forward(self, x: torch.Tensor) -> torch.Tensor: - return self.forward(x) # Default to PyTorch implementation - - def cuda_gradient(self, x: torch.Tensor) -> torch.Tensor: - return self.gradient(x) \ No newline at end of file diff --git a/torchebm/core/sampler.py b/torchebm/core/sampler.py index 005afaf..6de500b 100644 --- a/torchebm/core/sampler.py +++ b/torchebm/core/sampler.py @@ -9,5 +9,3 @@ class Sampler(ABC): def sample(self, energy_function: EnergyFunction, initial_state: torch.Tensor, num_steps: int) -> torch.Tensor: pass - def cuda_sample(self, energy_function: EnergyFunction, initial_state: torch.Tensor, num_steps: int) -> torch.Tensor: - return self.sample(energy_function, initial_state, num_steps) # Default to PyTorch implementation diff --git a/torchebm/models/base_model.py b/torchebm/models/base_model.py index c954e49..4fc86ec 100644 --- a/torchebm/models/base_model.py +++ b/torchebm/models/base_model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import torch -from ..core import EnergyFunction, Sampler +from torchebm.core.energy_function import EnergyFunction +from torchebm.core.sampler import Sampler class BaseModel(ABC): def __init__(self, energy_function: EnergyFunction, sampler: Sampler):