Skip to content

Commit

Permalink
tests for py3.10 pass
Browse files Browse the repository at this point in the history
  • Loading branch information
nrbnlulu committed Dec 29, 2024
1 parent a195d68 commit 4554e0f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 45 deletions.
36 changes: 20 additions & 16 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,23 @@ def register(self, provider: Provider[Any]) -> None:

def is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]:
# we currently don't support tuple, list, dict, set, type
return isinstance(type_, types.GenericAlias | t._GenericAlias) and t.get_origin(type_) not in (tuple, list, dict, set, type) # type: ignore[reportAttributeAccessIssue] # noqa: SLF001
return isinstance(
type_,
types.GenericAlias | t._GenericAlias, # type: ignore[attr-defined] # noqa: SLF001
) and t.get_origin(type_) not in (tuple, list, dict, set, type)


def get_orig_bases(type_: type) -> tuple[type, ...] | None:
return getattr(type_, "__orig_bases__", None)


def get_typevars(type_: Any) -> list[t.TypeVar] | None:
if is_generic_alias(type_):
args = t.get_args(type_)
return [
arg
for arg in args
if isinstance(arg, t.TypeVar)
]
return [arg for arg in args if isinstance(arg, t.TypeVar)]
return None


class InjectionContext(_BaseInjectionContext[ContextExtension]):
async def resolve( # noqa: C901, PLR0912
self,
Expand All @@ -104,7 +107,7 @@ async def resolve( # noqa: C901, PLR0912
if is_generic_alias(type_):
type_is_generic = True
args = type_.__args__
params = type_.__origin__.__parameters__
params = type_.__origin__.__parameters__ # type: ignore[attr-defined]
for param, arg in zip(params, args, strict=False):
args_map[param.__name__] = arg
elif orig_bases := get_orig_bases(type_):
Expand All @@ -113,26 +116,27 @@ async def resolve( # noqa: C901, PLR0912
for base in orig_bases:
if is_generic_alias(base):
args = base.__args__
if params := getattr(base.__origin__, "__parameters__", None):
if params := getattr(
base.__origin__, "__parameters__", None
):
for param, arg in zip(params, args, strict=False):
args_map[param.__name__] = arg
if not args_map:
# type may be generic though user didn't provide any type parameters
type_is_generic = False


for dependency in provider.resolve_dependencies(
self._container.type_context,
):
if type_is_generic and (args:= get_typevars(dependency.type_)):
if type_is_generic and (
dep_args := get_typevars(dependency.type_)
):
# This is a generic type, we need to resolve the type arguments
# and pass them to the provider.
resolved_args = [
args_map[arg.__name__]
for arg in
args
]
resolved_type = dependency.type_[*resolved_args]
resolved_args = [args_map[arg.__name__] for arg in dep_args]
origin = t.get_origin(dependency.type_)
assert origin is not None # noqa: S101
resolved_type = origin.__class_getitem__(*resolved_args)
dependencies[dependency.name] = await self.resolve(
type_=resolved_type,
)
Expand Down
76 changes: 47 additions & 29 deletions tests/features/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,43 +72,58 @@ class NestedGenericService(Generic[T]):
def __init__(self, service: T) -> None:
self.service = service


MEANING_OF_LIFE_INT = 42
MEANING_OF_LIFE_STR = "42"


class Something:
def __init__(self) -> None:
self.a = MEANING_OF_LIFE_INT


async def test_nested_generics() -> None:
container = Container()
container.register(
Scoped(NestedGenericService[WithGenericDependency[Something]]),
Scoped(WithGenericDependency[Something]),
Scoped(Something),
Object(MEANING_OF_LIFE_INT),
Object("42"))
Scoped(WithGenericDependency[Something]),
Scoped(Something),
Object(MEANING_OF_LIFE_INT),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(NestedGenericService[WithGenericDependency[Something]])
instance = await ctx.resolve(
NestedGenericService[WithGenericDependency[Something]]
)
assert isinstance(instance, NestedGenericService)
assert isinstance(instance.service, WithGenericDependency)
assert isinstance(instance.service.dependency, Something)
assert instance.service.dependency.a == MEANING_OF_LIFE_INT


IS_PY_312 = sys.version_info >= (3, 12)
skip_ifnot_312 = pytest.mark.skipif(not IS_PY_312, reason="Python 3.12+ required")


def skip_ifnot_312(reasone: str) -> pytest.MarkDecorator:
return pytest.mark.skipif(
not IS_PY_312, reason=f"Python 3.12+ required: {reasone}"
)


class TestNestedUnresolvedGeneric(Generic[T]):
def __init__(self, service: WithGenericDependency[T]) -> None:
def __init__(self, service: WithGenericDependency[T]) -> None:
self.service = service


async def test_nested_unresolved_generic() -> None:
container = Container()
container.register(Scoped(TestNestedUnresolvedGeneric[int]),
Scoped(WithGenericDependency[int]),
Object(42),
Object("42"))
container.register(
Scoped(TestNestedUnresolvedGeneric[int]),
Scoped(WithGenericDependency[int]),
Object(42),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(TestNestedUnresolvedGeneric[int])
Expand All @@ -117,19 +132,17 @@ async def test_nested_unresolved_generic() -> None:
assert instance.service.dependency == MEANING_OF_LIFE_INT





async def test_nested_unresolved_concrete_generic() -> None:
class GenericImpl(TestNestedUnresolvedGeneric[str]):
pass


container = Container()
container.register(Scoped(GenericImpl),
Scoped(WithGenericDependency[str]),
Object(42),
Object("42"))
container.register(
Scoped(GenericImpl),
Scoped(WithGenericDependency[str]),
Object(42),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(GenericImpl)
Expand All @@ -138,26 +151,32 @@ class GenericImpl(TestNestedUnresolvedGeneric[str]):
assert instance.service.dependency == "42"



@skip_ifnot_312(
"Partially concrete generics will raise TypeError in prior versions"
)
async def test_partially_resolved_generic() -> None:
K = TypeVar("K")

class TwoGeneric(Generic[T, K]):
def __init__(self, a: WithGenericDependency[T], b: WithGenericDependency[K]) -> None:
def __init__(
self, a: WithGenericDependency[T], b: WithGenericDependency[K]
) -> None:
self.a = a
self.b = b


class UsesTwoGeneric(Generic[T]):
def __init__(self, service: TwoGeneric[T, str]) -> None:
self.service = service

container = Container()
container.register(Scoped(UsesTwoGeneric[int]),
Scoped(TwoGeneric[int, str]),
Scoped(WithGenericDependency[int]),
Scoped(WithGenericDependency[str]),
Object(MEANING_OF_LIFE_INT),
Object("42"))
container.register(
Scoped(UsesTwoGeneric[int]),
Scoped(TwoGeneric[int, str]),
Scoped(WithGenericDependency[int]),
Scoped(WithGenericDependency[str]),
Object(MEANING_OF_LIFE_INT),
Object("42"),
)

async with container.context() as ctx:
instance = await ctx.resolve(UsesTwoGeneric[int])
Expand All @@ -169,7 +188,6 @@ def __init__(self, service: TwoGeneric[T, str]) -> None:
assert instance.service.b.dependency == MEANING_OF_LIFE_STR



async def test_can_resolve_generic_class_without_parameters() -> None:
class GenericClass(Generic[T]):
def __init__(self, a: int) -> None:
Expand Down

0 comments on commit 4554e0f

Please sign in to comment.