Skip to content

Commit

Permalink
Merge pull request #21 from nrbnlulu/fix-generic-dependencies
Browse files Browse the repository at this point in the history
fix Support nested generics #20
  • Loading branch information
ThirVondukr authored Jan 6, 2025
2 parents 21ced5b + 1aecea7 commit a4a7b50
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 7 deletions.
Empty file added aioinject/_features/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions aioinject/_features/generics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import functools
import types
import typing as t
from types import GenericAlias
from typing import TYPE_CHECKING, Any, TypeGuard


if TYPE_CHECKING:
from aioinject.providers import Dependency


def _is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]:
# we currently don't support tuple, list, dict, set, type
return isinstance(
type_,
types.GenericAlias | t._GenericAlias, # type: ignore[attr-defined] # noqa: SLF001
) and t.get_origin(type_) not in (tuple, list, dict, set, type)


def _get_orig_bases(type_: type) -> tuple[type, ...] | None:
return getattr(type_, "__orig_bases__", None)


def _get_generic_arguments(type_: Any) -> list[t.TypeVar] | None:
"""
Returns generic arguments of given class, e.g. Class[T] would return [~T]
"""
if _is_generic_alias(type_):
args = t.get_args(type_)
return [arg for arg in args if isinstance(arg, t.TypeVar)]
return None


@functools.lru_cache
def _get_generic_args_map(type_: type[object]) -> dict[str, type[object]]:
if _is_generic_alias(type_):
args = type_.__args__
params: dict[str, Any] = {
param.__name__: param
for param in type_.__origin__.__parameters__ # type: ignore[attr-defined]
}
# TODO(Doctor, nrbnlulu): Tests pass with strct=True, is this needed?
return dict(zip(params, args, strict=False))

args_map = {}
if orig_bases := _get_orig_bases(type_):
# find the generic parent
for base in orig_bases:
if _is_generic_alias(base):
args = base.__args__
if params := {
param.__name__: param
for param in getattr(base.__origin__, "__parameters__", ())
}:
args_map.update(
dict(zip(params, args, strict=True)),
)
return args_map


@functools.lru_cache
def get_generic_parameter_map(
provided_type: type[object],
dependencies: tuple[Dependency[Any], ...],
) -> dict[str, type[object]]:
args_map = _get_generic_args_map(provided_type) # type: ignore[arg-type]
result = {}
for dependency in dependencies:
if args_map and (
generic_arguments := _get_generic_arguments(dependency.type_)
):
# This is a generic type, we need to resolve the type arguments
# and pass them to the provider.
resolved_args = [
args_map[arg.__name__] for arg in generic_arguments
]
# We can use `[]` when we drop support for 3.10
result[dependency.name] = dependency.type_.__getitem__(
*resolved_args
)
return result
19 changes: 13 additions & 6 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing_extensions import Self

from aioinject._features.generics import get_generic_parameter_map
from aioinject._store import InstanceStore, NotInCache
from aioinject._types import AnyCtx, T
from aioinject.extensions import (
Expand Down Expand Up @@ -79,13 +80,19 @@ async def resolve(
if (cached := store.get(provider)) is not NotInCache.sentinel:
return cached

dependencies = {}
for dependency in provider.resolve_dependencies(
self._container.type_context,
):
dependencies[dependency.name] = await self.resolve(
type_=dependency.type_,
provider_dependencies = provider.resolve_dependencies(
context=self._container.type_context
)
dependencies_map = get_generic_parameter_map(
type_, # type: ignore[arg-type]
provider_dependencies,
)
dependencies = {
dependency.name: await self.resolve(
dependencies_map.get(dependency.name, dependency.type_)
)
for dependency in provider_dependencies
}

if provider.lifetime is DependencyLifetime.singleton:
async with store.lock(provider) as should_provide:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ ignore = [
"EXE",
"ISC001", # ruff format conflict
"COM812", # ruff format conflict
"TD003", # Missing TODO link
"FIX002", # TODO in code
]


Expand Down
125 changes: 125 additions & 0 deletions tests/features/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,128 @@ async def test_resolve_generics(
async with container.context() as ctx:
instance = await ctx.resolve(type_)
assert isinstance(instance, instanceof)


class NestedGenericService(Generic[T]):
def __init__(self, service: T) -> None:
self.service = service


MEANING_OF_LIFE_INT = 42
MEANING_OF_LIFE_STR = "42"


class Something:
def __init__(self) -> None:
self.a = MEANING_OF_LIFE_INT


async def test_nested_generics() -> None:
container = Container()
container.register(
Scoped(NestedGenericService[WithGenericDependency[Something]]),
Scoped(WithGenericDependency[Something]),
Scoped(Something),
Object(MEANING_OF_LIFE_INT),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(
NestedGenericService[WithGenericDependency[Something]]
)
assert isinstance(instance, NestedGenericService)
assert isinstance(instance.service, WithGenericDependency)
assert isinstance(instance.service.dependency, Something)
assert instance.service.dependency.a == MEANING_OF_LIFE_INT


class NestedUnresolvedGeneric(Generic[T]):
def __init__(self, service: WithGenericDependency[T]) -> None:
self.service = service


async def test_nested_unresolved_generic() -> None:
container = Container()
container.register(
Scoped(NestedUnresolvedGeneric[int]),
Scoped(WithGenericDependency[int]),
Object(42),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(NestedUnresolvedGeneric[int])
assert isinstance(instance, NestedUnresolvedGeneric)
assert isinstance(instance.service, WithGenericDependency)
assert instance.service.dependency == MEANING_OF_LIFE_INT


async def test_nested_unresolved_concrete_generic() -> None:
class GenericImpl(NestedUnresolvedGeneric[str]):
pass

container = Container()
container.register(
Scoped(GenericImpl),
Scoped(WithGenericDependency[str]),
Object(42),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(GenericImpl)
assert isinstance(instance, GenericImpl)
assert isinstance(instance.service, WithGenericDependency)
assert instance.service.dependency == "42"


async def test_partially_resolved_generic() -> None:
K = TypeVar("K")

class TwoGeneric(Generic[T, K]):
def __init__(
self, a: WithGenericDependency[T], b: WithGenericDependency[K]
) -> None:
self.a = a
self.b = b

class UsesTwoGeneric(Generic[T]):
def __init__(self, service: TwoGeneric[T, str]) -> None:
self.service = service

container = Container()
container.register(
Scoped(UsesTwoGeneric[int]),
Scoped(TwoGeneric[int, str]),
Scoped(WithGenericDependency[int]),
Scoped(WithGenericDependency[str]),
Object(MEANING_OF_LIFE_INT),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(UsesTwoGeneric[int])
assert isinstance(instance, UsesTwoGeneric)
assert isinstance(instance.service, TwoGeneric)
assert isinstance(instance.service.a, WithGenericDependency)
assert isinstance(instance.service.b, WithGenericDependency)
assert instance.service.a.dependency == MEANING_OF_LIFE_INT
assert instance.service.b.dependency == MEANING_OF_LIFE_STR


async def test_can_resolve_generic_class_without_parameters() -> None:
class GenericClass(Generic[T]):
def __init__(self, a: int) -> None:
self.a = a

def so_generic(self) -> T: # pragma: no cover
raise NotImplementedError

container = Container()
container.register(Scoped(GenericClass), Object(MEANING_OF_LIFE_INT))

async with container.context() as ctx:
instance = await ctx.resolve(GenericClass)
assert isinstance(instance, GenericClass)
assert instance.a == MEANING_OF_LIFE_INT
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a4a7b50

Please sign in to comment.