Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python recipe render improvements #57

Merged
merged 14 commits into from
Jan 7, 2025
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ jobs:
- py311
- py310
- py39
- py38
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected]
Expand Down
22,198 changes: 8,335 additions & 13,863 deletions pixi.lock

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build_sdist = "pixi run python -m build --sdist"

[dependencies]
python = ">=3.8"
build = ">=0.7.0,<0.8"
python-build = ">=1.2.2.post1,<2"
rattler-build = ">=0.18.1,<1"
conda-build = ">=24.3.0,<25.0"
conda = ">=4.2"
Expand Down Expand Up @@ -60,14 +60,10 @@ python = "3.10.*"
[feature.py39.dependencies]
python = "3.9.*"

[feature.py38.dependencies]
python = "3.8.*"

[environments]
py312 = { features = ["py312", "tests"] }
py311 = ["py311", "tests"]
py310 = ["py310", "tests"]
py39 = ["py39", "tests"]
py38 = ["py38", "tests"]
lint = { features = ["lint"], no-default-feature = true }
type-checking = { features = ["type-checking"] }
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "rattler-build-conda-compat"
description = "A package for exposing rattler-build API for conda-smithy"
version = "1.2.2"
version = "1.3.0"
readme = "README.md"
authors = [{ name = "Nichita Morcotilo", email = "[email protected]" }]
license = { file = "LICENSE.txt" }
Expand Down
40 changes: 25 additions & 15 deletions src/rattler_build_conda_compat/jinja/jinja.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from __future__ import annotations

from typing import Any, TypedDict
from typing import Any, Mapping, TypedDict

import jinja2
from jinja2.sandbox import SandboxedEnvironment

from rattler_build_conda_compat.jinja.filters import _bool, _split, _version_to_build_string
from rattler_build_conda_compat.jinja.objects import (
_stub_compatible_pin,
_stub_is_linux,
_stub_is_unix,
_stub_is_win,
_stub_match,
_stub_subpackage_pin,
_StubEnv,
Expand All @@ -24,7 +21,7 @@ class RecipeWithContext(TypedDict, total=False):
context: dict[str, str]


def jinja_env() -> SandboxedEnvironment:
def jinja_env(variant_config: Mapping[str, str] | None = None) -> SandboxedEnvironment:
"""
Create a `rattler-build` specific Jinja2 environment with modified syntax.
Target platform, build platform, and mpi are set to linux-64 by default.
Expand All @@ -42,6 +39,20 @@ def jinja_env() -> SandboxedEnvironment:
env_obj = _StubEnv()

# inject rattler-build recipe functions in jinja environment
if not variant_config:
variant_config = {"target_platform": "linux-64", "build_platform": "linux-64", "mpi": "mpi"}

extra_vars = {}
target_platform = variant_config["target_platform"]
if target_platform != "noarch":
# set `linux` / `win`
extra_vars[target_platform.split("-")[0]] = True

if target_platform.startswith("win"):
extra_vars["unix"] = False
else:
extra_vars["unix"] = True

env.globals.update(
{
"compiler": lambda x: x + "_compiler_stub",
Expand All @@ -51,14 +62,11 @@ def jinja_env() -> SandboxedEnvironment:
"cdt": lambda *args, **kwargs: "cdt_stub", # noqa: ARG005
"env": env_obj,
"match": _stub_match,
"is_unix": _stub_is_unix,
"is_win": _stub_is_win,
"is_linux": _stub_is_linux,
"unix": True,
"linux": True,
"target_platform": "linux-64",
"build_platform": "linux-64",
"mpi": "mpi",
"is_unix": lambda x: not x.startswith("win"),
"is_win": lambda x: x.startswith("win"),
"is_linux": lambda x: x.startswith("linux"),
**extra_vars,
**variant_config,
}
)

Expand Down Expand Up @@ -89,7 +97,9 @@ def load_recipe_context(context: dict[str, str], jinja_env: jinja2.Environment)
return context


def render_recipe_with_context(recipe_content: RecipeWithContext) -> dict[str, Any]:
def render_recipe_with_context(
recipe_content: RecipeWithContext, variant_config: Mapping[str, str] | None = None
) -> dict[str, Any]:
"""
Render the recipe using known values from context section.
Unknown values are not evaluated and are kept as it is.
Expand All @@ -106,7 +116,7 @@ def render_recipe_with_context(recipe_content: RecipeWithContext) -> dict[str, A
>>>
```
"""
env = jinja_env()
env = jinja_env(variant_config)
context = recipe_content.get("context", {})
# render out the context section and retrieve dictionary
context_variables = load_recipe_context(context, env)
Expand Down
6 changes: 3 additions & 3 deletions src/rattler_build_conda_compat/modify_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import hashlib
import logging
import re
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, MutableMapping

import requests

from rattler_build_conda_compat.jinja.jinja import jinja_env, load_recipe_context
from rattler_build_conda_compat.recipe_sources import Source, get_all_sources
from rattler_build_conda_compat.recipe_sources import get_all_sources
from rattler_build_conda_compat.yaml import _dump_yaml_to_string, _yaml_object

if TYPE_CHECKING:
Expand Down Expand Up @@ -90,7 +90,7 @@ def _has_jinja_version(url: str) -> bool:
return re.search(pattern, url) is not None


def update_hash(source: Source, url: str, hash_: Hash | None) -> None:
def update_hash(source: MutableMapping[str, Any], url: str, hash_: Hash | None) -> None:
"""
Update the sha256 hash in the source dictionary.

Expand Down
117 changes: 100 additions & 17 deletions src/rattler_build_conda_compat/recipe_sources.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
from __future__ import annotations

import sys
import typing
from typing import Any, List, TypedDict, Union
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Any, List, Union, cast

from rattler_build_conda_compat.jinja.jinja import (
RecipeWithContext,
jinja_env,
load_recipe_context,
)
from rattler_build_conda_compat.loader import _eval_selector
from rattler_build_conda_compat.variant_config import variant_combinations

from .conditional_list import ConditionalList, visit_conditional_list

if sys.version_info < (3, 11):
from typing_extensions import NotRequired
else:
from typing import NotRequired

if typing.TYPE_CHECKING:
from collections.abc import Iterator, Mapping
from collections.abc import Iterator


OptionalUrlList = Union[str, List[str], None]


class Source(TypedDict):
url: NotRequired[str | list[str]]
sha256: NotRequired[str]
md5: NotRequired[str]
@dataclass(frozen=True)
class Source:
url: str | list[str]
template: str | list[str]
context: dict[str, str] | None = None
sha256: str | None = None
md5: str | None = None

def __getitem__(self, key: str) -> str | list[str] | None:
return self.__dict__[key]

def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return (self.url, self.sha256, self.md5) == (other.url, other.sha256, other.md5)

def get_all_sources(recipe: Mapping[Any, Any]) -> Iterator[Source]:
def __hash__(self) -> int:
return hash((tuple(self.url), self.sha256, self.md5))


def get_all_sources(recipe: MutableMapping[str, Any]) -> Iterator[MutableMapping[str, Any]]:
"""
Get all sources from the recipe. This can be from a list of sources,
a single source, or conditional and its branches.
Expand All @@ -37,30 +56,39 @@ def get_all_sources(recipe: Mapping[Any, Any]) -> Iterator[Source]:
A list of source objects.
"""
sources = recipe.get("source", None)
sources = typing.cast(ConditionalList[Source], sources)
sources = typing.cast(ConditionalList[MutableMapping[str, Any]], sources)

# Try getting all url top-level sources
if sources is not None:
source_list = visit_conditional_list(sources, None)
for source in source_list:
yield source

cache_output = recipe.get("cache", None)
if cache_output is not None:
sources = cache_output.get("source", None)
sources = typing.cast(ConditionalList[MutableMapping[str, Any]], sources)
if sources is not None:
source_list = visit_conditional_list(sources, None)
for source in source_list:
yield source

outputs = recipe.get("outputs", None)
if outputs is None:
return

outputs = visit_conditional_list(outputs, None)
for output in outputs:
sources = output.get("source", None)
sources = typing.cast(ConditionalList[Source], sources)
sources = typing.cast(ConditionalList[MutableMapping[str, Any]], sources)
if sources is None:
continue
source_list = visit_conditional_list(sources, None)
for source in source_list:
yield source


def get_all_url_sources(recipe: Mapping[Any, Any]) -> Iterator[str]:
def get_all_url_sources(recipe: MutableMapping[str, Any]) -> Iterator[str]:
"""
Get all url sources from the recipe. This can be from a list of sources,
a single source, or conditional and its branches.
Expand All @@ -74,9 +102,64 @@ def get_all_url_sources(recipe: Mapping[Any, Any]) -> Iterator[str]:
A list of URLs.
"""

def get_first_url(source: Mapping[str, Any]) -> str:
def get_first_url(source: MutableMapping[str, Any]) -> str:
if isinstance(source["url"], list):
return source["url"][0]
return source["url"]

return (get_first_url(source) for source in get_all_sources(recipe) if "url" in source)


def render_all_sources(
recipe: RecipeWithContext,
variants: list[dict[str, list[str]]],
override_version: str | None = None,
) -> set[Source]:
"""
This function should render _all_ URL sources from the given recipe and with the given variants.
Variants can be loaded with the `variant_config.variant_combinations` module.
Optionally, you can override the version in the recipe context to render URLs with a different version.
"""

def render(template: str | list[str], context: dict[str, str]) -> str | list[str]:
if isinstance(template, list):
return [cast(str, render(t, context)) for t in template]
template = env.from_string(template)
return template.render(context_variables)

if override_version is not None:
recipe["context"]["version"] = override_version

final_sources = set()
for v in variants:
combinations = variant_combinations(v)
for combination in combinations:
env = jinja_env(combination)

context = recipe.get("context", {})
# render out the context section and retrieve dictionary
context_variables = load_recipe_context(context, env)

# now evaluate the if / else statements
sources = recipe.get("source")
if sources:
if not isinstance(sources, list):
sources = [sources]

for elem in visit_conditional_list(
sources,
lambda x, combination=combination: _eval_selector(x, combination), # type: ignore[misc]
):
# we need to explicitly cast here
elem_dict = typing.cast(dict[str, Any], elem)
if "url" in elem_dict:
as_url = Source(
url=render(elem_dict["url"], context_variables),
template=elem_dict["url"],
sha256=elem_dict.get("sha256"),
md5=elem_dict.get("md5"),
context=context_variables,
)
final_sources.add(as_url)

return final_sources
47 changes: 47 additions & 0 deletions src/rattler_build_conda_compat/variant_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from itertools import product


def variant_combinations(data: dict[str, list[str]]) -> list[dict[str, str]]:
"""
This function takes a "variant" configuration dictionary that gets expanded into multiple build matrices.

Arguments:
----------
* `data` - A dictionary with keys as the variant names and values as the possible values.
* `zip_keys` - A list of lists of keys that should be zipped together.

Returns:
--------
A list of dictionaries that represent the different combinations of the variant configuration
"""
zip_keys = data.pop("zip_keys", [])
# Separate the keys that need to be zipped from the rest
zip_keys_flat = [item for sublist in zip_keys for item in sublist]
other_keys = [key for key in data if key not in zip_keys_flat]

# Create combinations for non-zipped keys
other_combinations = list(product(*[data[key] for key in other_keys]))

# Create zipped combinations
zipped_combinations = [list(zip(*[data[key] for key in zip_group])) for zip_group in zip_keys]

# Combine zipped combinations
zipped_product = list(product(*zipped_combinations))

# Combine all results into dictionaries
final_combinations = []
for other_combo in other_combinations:
for zipped_combo in zipped_product:
combined = {}
# Add non-zipped items
for key, value in zip(other_keys, other_combo):
combined[key] = str(value)
# Add zipped items
for zip_group, zip_values in zip(zip_keys, zipped_combo):
for key, value in zip(zip_group, zip_values):
combined[key] = str(value)
final_combinations.append(combined)

return final_combinations
Loading
Loading