diff --git a/src/GraphQL.Conventions/Types/Resolution/Extensions/ReflectionExtensions.cs b/src/GraphQL.Conventions/Types/Resolution/Extensions/ReflectionExtensions.cs index 8be70f5..0207af0 100644 --- a/src/GraphQL.Conventions/Types/Resolution/Extensions/ReflectionExtensions.cs +++ b/src/GraphQL.Conventions/Types/Resolution/Extensions/ReflectionExtensions.cs @@ -30,6 +30,46 @@ public static bool IsNullableType(this TypeInfo type) return type.IsGenericType(typeof(Nullable<>)); } + public static Type GetImplementationInterface(this Type type, Type interfaceType, bool fuseGeneric = true) + { + if (!interfaceType.IsInterface) + return null; + + fuseGeneric &= interfaceType.IsGenericType; + if (type.IsGenericType && fuseGeneric + ? type.GetGenericTypeDefinition() == interfaceType.GetGenericTypeDefinition() + : type == interfaceType) + { + return interfaceType; + } + + while (type is not null) + { + var interfaces = type.GetInterfaces(); + var mayFusedGenericInterface = fuseGeneric + ? interfaces.Select(t => t.IsGenericType ? t.GetGenericTypeDefinition() : t).ToArray() + : interfaces; + + for (int i = 0; i < interfaces.Length; i++) + { + var @interface = interfaces[i]; + if (mayFusedGenericInterface[i] == interfaceType) + return interfaces[i]; + var ni = @interface.GetImplementationInterface(interfaceType, fuseGeneric); + if (ni is not null) + return ni; + } + + type = type.BaseType; + } + + return null; + } + + public static bool IsImplementingInterface(this Type type, Type interfaceType, bool fuseGeneric = true) => + type.GetImplementationInterface(interfaceType, fuseGeneric) is not null; + + public static TypeInfo BaseType(this TypeInfo type) { return type.IsNullableType() @@ -45,7 +85,7 @@ public static TypeInfo TypeParameter(this TypeInfo type) } return type.IsGenericType ? type.GenericTypeArguments.First().GetTypeInfo() - : null; + : type.TypeParameterCollection(); } public static TypeInfo TypeParameter(this GraphTypeInfo type) @@ -53,6 +93,11 @@ public static TypeInfo TypeParameter(this GraphTypeInfo type) return type.TypeRepresentation.TypeParameter(); } + public static TypeInfo TypeParameterCollection(this TypeInfo type) => ( + type.GetImplementationInterface(typeof(ICollection<>)) ?? + type.GetImplementationInterface(typeof(IReadOnlyList<>)) + )?.GetTypeInfo(); + public static bool IsPrimitiveGraphType(this TypeInfo type) { return type.IsPrimitive || @@ -63,8 +108,13 @@ public static bool IsPrimitiveGraphType(this TypeInfo type) public static bool IsEnumerableGraphType(this TypeInfo type) { - return type.IsGenericType(typeof(List<>)) || - type.IsGenericType(typeof(IList<>)) || + if (type.IsImplementingInterface(typeof(IDictionary)) || type.IsImplementingInterface(typeof(IDictionary<,>))) + { + return false; + } + + return type.IsImplementingInterface(typeof(ICollection<>)) || + type.IsImplementingInterface(typeof(IReadOnlyCollection<>)) || type.IsGenericType(typeof(IEnumerable<>)) || (type.IsGenericType && type.DeclaringType == typeof(Enumerable)) || // Handles internal Iterator implementations for LINQ; for reference https://referencesource.microsoft.com/#system.core/System/Linq/Enumerable.cs type.IsArray; @@ -93,16 +143,19 @@ public static TypeInfo GetTypeRepresentation(this TypeInfo typeInfo) { typeInfo = typeInfo.TypeParameter(); } + if (typeInfo.IsGenericType(typeof(IObservable<>))) { typeInfo = typeInfo.TypeParameter(); } + if (typeInfo.IsGenericType(typeof(Nullable<>)) || typeInfo.IsGenericType(typeof(NonNull<>)) || typeInfo.IsGenericType(typeof(Optional<>))) { typeInfo = typeInfo.TypeParameter(); } + return typeInfo; } diff --git a/test/GraphQL.Conventions.Tests/Types/Resolution/Extensions/ReflectionExtensionsTests.cs b/test/GraphQL.Conventions.Tests/Types/Resolution/Extensions/ReflectionExtensionsTests.cs new file mode 100644 index 0000000..dd9b32a --- /dev/null +++ b/test/GraphQL.Conventions.Tests/Types/Resolution/Extensions/ReflectionExtensionsTests.cs @@ -0,0 +1,110 @@ +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Reflection; +using GraphQL.Conventions.Types.Resolution.Extensions; +using Xunit; + +namespace Tests.Types.Resolution.Extensions +{ + public class ReflectionExtensionsTests + { + [Theory] + [MemberData(nameof(IsEnumerableGraphType_Should_Return_True_For_Common_Collection_Types_Data))] + public void IsEnumerableGraphType_Should_Return_True_For_Common_Collection_Types(Type type) + { + Assert.IsTrue(type.GetTypeInfo().IsEnumerableGraphType()); + } + + public static TheoryData IsEnumerableGraphType_Should_Return_True_For_Common_Collection_Types_Data() => new() + { + typeof(IEnumerable<>), + typeof(ConcurrentQueue<>), + typeof(HashSet<>), + typeof(int[]), + typeof(List<>), + typeof(IList<>), + typeof(IReadOnlyList<>), + typeof(IReadOnlyCollection<>), + }; + + [Theory] + [MemberData(nameof(IsEnumerableGraphType_Should_Return_False_For_Common_Dictionary_Types_Data))] + public void IsEnumerableGraphType_Should_Return_False_For_Common_Dictionary_Types(Type type) + { + Assert.IsFalse(type.GetTypeInfo().IsEnumerableGraphType()); + } + + public static TheoryData IsEnumerableGraphType_Should_Return_False_For_Common_Dictionary_Types_Data() => new() + { + typeof(IDictionary<,>), + typeof(IDictionary), + }; + + [Theory] + [MemberData(nameof(GetImplementationInterface_WithoutFuse_AcquireSpecifiedInterfaceOnly_Data))] + public void GetImplementationInterface_WithoutFuse_AcquireSpecifiedInterfaceOnly(Type type, Type assignableInterface) + => Assert.IsTrue(type.IsImplementingInterface(assignableInterface, false)); + + public static TheoryData GetImplementationInterface_WithoutFuse_AcquireSpecifiedInterfaceOnly_Data() => + new() + { + { typeof(ITestInterface1), typeof(ITestInterface1) }, + { typeof(ITestInterface2), typeof(ITestInterface1) }, + { typeof(ITestInterface2), typeof(ITestInterface2) }, + { typeof(TestClass11), typeof(ITestInterface1) }, + { typeof(TestClass12), typeof(ITestInterface1) }, + { typeof(TestClass12), typeof(ITestInterface2) }, + { typeof(TestClass2), typeof(ITestInterface1) }, + { typeof(TestClass2), typeof(ITestInterface1) }, + { typeof(TestClass2), typeof(ITestInterface2) }, + }; + + [Theory] + [MemberData(nameof(GetImplementationInterface_WithFuse_AcquireSpecifiedInterfaceOnly_Data))] + public void GetImplementationInterface_WithFuse_AcquireSpecifiedInterfaceOnly(Type type, Type assignableInterface) => + Assert.IsTrue(type.IsImplementingInterface(assignableInterface)); + + public static TheoryData GetImplementationInterface_WithFuse_AcquireSpecifiedInterfaceOnly_Data() => + new() + { + { typeof(ITestInterface1), typeof(ITestInterface1<>) }, + { typeof(ITestInterface2), typeof(ITestInterface1<>) }, + { typeof(ITestInterface2), typeof(ITestInterface2<>) }, + { typeof(TestClass11), typeof(ITestInterface1<>) }, + // + { typeof(ITestInterface1), typeof(ITestInterface1<>) }, + { typeof(ITestInterface2), typeof(ITestInterface1<>) }, + { typeof(ITestInterface2), typeof(ITestInterface2<>) }, + { typeof(TestClass11), typeof(ITestInterface1<>) }, + // + { typeof(TestClass12), typeof(ITestInterface1<>) }, + { typeof(TestClass12), typeof(ITestInterface2<>) }, + { typeof(TestClass2), typeof(ITestInterface1<>) }, + { typeof(TestClass2), typeof(ITestInterface2<>) }, + }; + + private interface ITestInterface1 + { + // ReSharper disable once UnusedMember.Global + T Item => throw new NotImplementedException(); + } + + private interface ITestInterface2 : ITestInterface1 + { + } + + private class TestClass11 : ITestInterface1 + { + } + + private class TestClass12 : ITestInterface2 + { + } + + private class TestClass2 : ITestInterface2, ITestInterface1 + { + } + } +}