diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index 679fd0fa..ef70ec4a 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -29,115 +29,26 @@ public class CqrsHandlerRegistryGeneratorTests var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; - RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler"); + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, typeof(string)); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( services, typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), typeof(global::TestApp.VisibleHandler)); logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); } - - private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName) - { - var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false); - if (implementationType is null) - return; - - var handlerInterfaces = implementationType.GetInterfaces(); - global::System.Array.Sort(handlerInterfaces, CompareTypes); - - foreach (var handlerInterface in handlerInterfaces) - { - if (!IsSupportedHandlerInterface(handlerInterface)) - continue; - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( - services, - handlerInterface, - implementationType); - logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}."); - } - } - - private static int CompareTypes(global::System.Type left, global::System.Type right) - { - return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right)); - } - - private static bool IsSupportedHandlerInterface(global::System.Type interfaceType) - { - if (!interfaceType.IsGenericType) - return false; - - var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName; - return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2") - || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1") - || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2"); - } - - private static string GetRuntimeTypeDisplayName(global::System.Type type) - { - if (type == typeof(string)) - return "string"; - if (type == typeof(int)) - return "int"; - if (type == typeof(long)) - return "long"; - if (type == typeof(short)) - return "short"; - if (type == typeof(byte)) - return "byte"; - if (type == typeof(bool)) - return "bool"; - if (type == typeof(object)) - return "object"; - if (type == typeof(void)) - return "void"; - if (type == typeof(uint)) - return "uint"; - if (type == typeof(ulong)) - return "ulong"; - if (type == typeof(ushort)) - return "ushort"; - if (type == typeof(sbyte)) - return "sbyte"; - if (type == typeof(float)) - return "float"; - if (type == typeof(double)) - return "double"; - if (type == typeof(decimal)) - return "decimal"; - if (type == typeof(char)) - return "char"; - - if (type.IsArray) - return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]"; - - if (!type.IsGenericType) - return (type.FullName ?? type.Name).Replace('+', '.'); - - var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name; - var arityIndex = genericTypeName.IndexOf('`'); - if (arityIndex >= 0) - genericTypeName = genericTypeName[..arityIndex]; - - genericTypeName = genericTypeName.Replace('+', '.'); - var arguments = type.GetGenericArguments(); - var builder = new global::System.Text.StringBuilder(); - builder.Append(genericTypeName); - builder.Append('<'); - - for (var index = 0; index < arguments.Length; index++) - { - if (index > 0) - builder.Append(", "); - - builder.Append(GetRuntimeTypeDisplayName(arguments[index])); - } - - builder.Append('>'); - return builder.ToString(); - } } """; @@ -175,6 +86,84 @@ public class CqrsHandlerRegistryGeneratorTests """; + private const string HiddenArrayResponseFallbackExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + } + + """; + + private const string HiddenGenericEnvelopeResponseExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1GenericDefinition = registryAssembly.GetType("TestApp.Container+HiddenEnvelope`1", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1GenericDefinition is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1GenericDefinition.MakeGenericType(typeof(string))); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler>."); + } + } + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -444,6 +433,152 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", HiddenImplementationDirectInterfaceRegistrationExpected)); } + /// + /// 验证精确重建路径会递归覆盖隐藏元素类型数组, + /// 使这类 handler interface 也能直接生成 closed service type,而不再退回 GetInterfaces()。 + /// + [Test] + public async Task Generates_Precise_Service_Type_For_Hidden_Array_Type_Arguments() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenHandler : IRequestHandler { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenArrayResponseFallbackExpected)); + } + + /// + /// 验证精确重建路径会递归覆盖隐藏泛型定义, + /// 使“隐藏泛型定义 + 可见/常量型实参”的闭包类型也能直接生成 closed service type。 + /// + [Test] + public async Task Generates_Precise_Service_Type_For_Hidden_Generic_Type_Definitions() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed class Container + { + private sealed class HiddenEnvelope { } + + private sealed record HiddenRequest() : IRequest>; + + private sealed class HiddenHandler : IRequestHandler> { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected)); + } + /// /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 27b01f9f..1119a92c 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -90,36 +90,48 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); var reflectedImplementationRegistrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var preciseReflectedRegistrations = + ImmutableArray.CreateBuilder(handlerInterfaces.Length); foreach (var handlerInterface in handlerInterfaces) { var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); - if (!canReferenceImplementation || !canReferenceHandlerInterface) + if (canReferenceImplementation && canReferenceHandlerInterface) { - if (!canReferenceImplementation && canReferenceHandlerInterface) - { - reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec( - handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - GetLogDisplayName(handlerInterface))); - continue; - } - - // Non-public handlers closed over non-public message types still cannot be expressed purely as - // typeof(...) registrations. Preserve generator hit rate by resolving only the affected - // implementation back from the current assembly instead of asking the runtime registrar to rescan - // the whole assembly. - return new HandlerCandidateAnalysis( + registrations.Add(new HandlerRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), implementationTypeDisplayName, - implementationLogName, - ImmutableArray.Empty, - ImmutableArray.Empty, - GetReflectionTypeMetadataName(type)); + GetLogDisplayName(handlerInterface), + implementationLogName)); + continue; } - registrations.Add(new HandlerRegistrationSpec( - handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + if (!canReferenceImplementation && canReferenceHandlerInterface) + { + reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + GetLogDisplayName(handlerInterface))); + continue; + } + + if (TryCreatePreciseReflectedRegistration( + context.SemanticModel.Compilation, + handlerInterface, + out var preciseReflectedRegistration)) + { + preciseReflectedRegistrations.Add(preciseReflectedRegistration); + continue; + } + + // Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over + // non-public element types. For those rare cases keep the narrow implementation lookup, but let the + // generated registry discover the exact supported interfaces from the implementation type at runtime. + return new HandlerCandidateAnalysis( implementationTypeDisplayName, - GetLogDisplayName(handlerInterface), - implementationLogName)); + implementationLogName, + ImmutableArray.Empty, + ImmutableArray.Empty, + ImmutableArray.Empty, + GetReflectionTypeMetadataName(type)); } return new HandlerCandidateAnalysis( @@ -127,6 +139,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator implementationLogName, registrations.ToImmutable(), reflectedImplementationRegistrations.ToImmutable(), + preciseReflectedRegistrations.ToImmutable(), canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); } @@ -170,6 +183,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator candidate.ImplementationLogName, candidate.Registrations, candidate.ReflectedImplementationRegistrations, + candidate.PreciseReflectedRegistrations, candidate.ReflectionTypeMetadataName)); } @@ -214,6 +228,114 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal); } + private static bool TryCreatePreciseReflectedRegistration( + Compilation compilation, + INamedTypeSymbol handlerInterface, + out PreciseReflectedRegistrationSpec registration) + { + var openHandlerTypeDisplayName = handlerInterface.OriginalDefinition + .ConstructUnboundGenericType() + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var typeArguments = ImmutableArray.CreateBuilder(handlerInterface.TypeArguments.Length); + foreach (var typeArgument in handlerInterface.TypeArguments) + { + if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var runtimeTypeReference)) + { + registration = default; + return false; + } + + typeArguments.Add(runtimeTypeReference!); + } + + registration = new PreciseReflectedRegistrationSpec( + openHandlerTypeDisplayName, + GetLogDisplayName(handlerInterface), + typeArguments.ToImmutable()); + return true; + } + + private static bool TryCreateRuntimeTypeReference( + Compilation compilation, + ITypeSymbol type, + out RuntimeTypeReferenceSpec? runtimeTypeReference) + { + if (CanReferenceFromGeneratedRegistry(type)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference( + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + return true; + } + + if (type is IArrayTypeSymbol arrayType && + TryCreateRuntimeTypeReference(compilation, arrayType.ElementType, out var elementTypeReference)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromArray(elementTypeReference!, arrayType.Rank); + return true; + } + + if (type is INamedTypeSymbol genericNamedType && + genericNamedType.IsGenericType && + !genericNamedType.IsUnboundGenericType && + TryCreateGenericTypeDefinitionReference(compilation, genericNamedType, out var genericTypeDefinitionReference)) + { + var genericTypeArguments = + ImmutableArray.CreateBuilder(genericNamedType.TypeArguments.Length); + foreach (var typeArgument in genericNamedType.TypeArguments) + { + if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var genericTypeArgumentReference)) + { + runtimeTypeReference = null; + return false; + } + + genericTypeArguments.Add(genericTypeArgumentReference!); + } + + runtimeTypeReference = RuntimeTypeReferenceSpec.FromConstructedGeneric( + genericTypeDefinitionReference!, + genericTypeArguments.ToImmutable()); + return true; + } + + if (type is INamedTypeSymbol namedType && + SymbolEqualityComparer.Default.Equals(namedType.ContainingAssembly, compilation.Assembly)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromReflectionLookup( + GetReflectionTypeMetadataName(namedType)); + return true; + } + + runtimeTypeReference = null; + return false; + } + + private static bool TryCreateGenericTypeDefinitionReference( + Compilation compilation, + INamedTypeSymbol genericNamedType, + out RuntimeTypeReferenceSpec? genericTypeDefinitionReference) + { + var genericTypeDefinition = genericNamedType.OriginalDefinition; + if (CanReferenceFromGeneratedRegistry(genericTypeDefinition)) + { + genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference( + genericTypeDefinition + .ConstructUnboundGenericType() + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + return true; + } + + if (SymbolEqualityComparer.Default.Equals(genericTypeDefinition.ContainingAssembly, compilation.Assembly)) + { + genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromReflectionLookup( + GetReflectionTypeMetadataName(genericTypeDefinition)); + return true; + } + + genericTypeDefinitionReference = null; + return false; + } + private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) { switch (type) @@ -319,9 +441,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator { var hasReflectedImplementationRegistrations = registrations.Any(static registration => !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); + var hasPreciseReflectedRegistrations = registrations.Any(static registration => + !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); var hasFullReflectionRegistrations = registrations.Any(static registration => !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && - registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); + registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && + registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -354,7 +479,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));"); builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); - if (hasReflectedImplementationRegistrations || hasFullReflectionRegistrations) + if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || hasFullReflectionRegistrations) { builder.AppendLine(); builder.Append(" var registryAssembly = typeof(global::"); @@ -376,6 +501,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator continue; } + if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) + { + AppendPreciseReflectedRegistrations(builder, registration, registrationIndex); + continue; + } + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); @@ -457,6 +588,197 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" }"); } + private static void AppendPreciseReflectedRegistrations( + StringBuilder builder, + ImplementationRegistrationSpec registration, + int registrationIndex) + { + var implementationVariableName = $"implementationType{registrationIndex}"; + if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) + { + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = typeof("); + builder.Append(registration.ImplementationTypeDisplayName); + builder.AppendLine(");"); + } + else + { + var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!; + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + } + + builder.Append(" if ("); + builder.Append(implementationVariableName); + builder.AppendLine(" is not null)"); + builder.AppendLine(" {"); + + for (var registrationOffset = 0; + registrationOffset < registration.PreciseReflectedRegistrations.Length; + registrationOffset++) + { + var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset]; + var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}"; + AppendPreciseReflectedTypeResolution( + builder, + reflectedRegistration.ServiceTypeArguments, + registrationVariablePrefix, + implementationVariableName, + reflectedRegistration.OpenHandlerTypeDisplayName, + registration.ImplementationLogName, + reflectedRegistration.HandlerInterfaceLogName, + 3); + } + + builder.AppendLine(" }"); + } + + private static void AppendPreciseReflectedTypeResolution( + StringBuilder builder, + ImmutableArray serviceTypeArguments, + string registrationVariablePrefix, + string implementationVariableName, + string openHandlerTypeDisplayName, + string implementationLogName, + string handlerInterfaceLogName, + int indentLevel) + { + var indent = new string(' ', indentLevel * 4); + var nestedIndent = new string(' ', (indentLevel + 1) * 4); + var resolvedArgumentNames = new string[serviceTypeArguments.Length]; + var reflectedArgumentNames = new List(); + + for (var argumentIndex = 0; argumentIndex < serviceTypeArguments.Length; argumentIndex++) + { + resolvedArgumentNames[argumentIndex] = AppendRuntimeTypeReferenceResolution( + builder, + serviceTypeArguments[argumentIndex], + $"{registrationVariablePrefix}Argument{argumentIndex}", + reflectedArgumentNames, + indent); + } + + if (reflectedArgumentNames.Count > 0) + { + builder.Append(indent); + builder.Append("if ("); + for (var index = 0; index < reflectedArgumentNames.Count; index++) + { + if (index > 0) + builder.Append(" && "); + + builder.Append(reflectedArgumentNames[index]); + builder.Append(" is not null"); + } + + builder.AppendLine(")"); + builder.Append(indent); + builder.AppendLine("{"); + indent = nestedIndent; + } + + builder.Append(indent); + builder.Append("var "); + builder.Append(registrationVariablePrefix); + builder.Append(" = typeof("); + builder.Append(openHandlerTypeDisplayName); + builder.Append(").MakeGenericType("); + for (var index = 0; index < resolvedArgumentNames.Length; index++) + { + if (index > 0) + builder.Append(", "); + + builder.Append(resolvedArgumentNames[index]); + } + + builder.AppendLine(");"); + builder.Append(indent); + builder.AppendLine("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.Append(indent); + builder.AppendLine(" services,"); + builder.Append(indent); + builder.Append(" "); + builder.Append(registrationVariablePrefix); + builder.AppendLine(","); + builder.Append(indent); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(indent); + builder.Append("logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(implementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(handlerInterfaceLogName)); + builder.AppendLine(".\");"); + + if (reflectedArgumentNames.Count > 0) + { + builder.Append(new string(' ', indentLevel * 4)); + builder.AppendLine("}"); + } + } + + private static string AppendRuntimeTypeReferenceResolution( + StringBuilder builder, + RuntimeTypeReferenceSpec runtimeTypeReference, + string variableBaseName, + ICollection reflectedArgumentNames, + string indent) + { + if (!string.IsNullOrWhiteSpace(runtimeTypeReference.TypeDisplayName)) + return $"typeof({runtimeTypeReference.TypeDisplayName})"; + + if (runtimeTypeReference.ArrayElementTypeReference is not null) + { + var elementExpression = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.ArrayElementTypeReference, + $"{variableBaseName}Element", + reflectedArgumentNames, + indent); + + return runtimeTypeReference.ArrayRank == 1 + ? $"{elementExpression}.MakeArrayType()" + : $"{elementExpression}.MakeArrayType({runtimeTypeReference.ArrayRank})"; + } + + if (runtimeTypeReference.GenericTypeDefinitionReference is not null) + { + var genericTypeDefinitionExpression = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.GenericTypeDefinitionReference, + $"{variableBaseName}GenericDefinition", + reflectedArgumentNames, + indent); + var genericArgumentExpressions = new string[runtimeTypeReference.GenericTypeArguments.Length]; + for (var argumentIndex = 0; argumentIndex < runtimeTypeReference.GenericTypeArguments.Length; argumentIndex++) + { + genericArgumentExpressions[argumentIndex] = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.GenericTypeArguments[argumentIndex], + $"{variableBaseName}GenericArgument{argumentIndex}", + reflectedArgumentNames, + indent); + } + + return $"{genericTypeDefinitionExpression}.MakeGenericType({string.Join(", ", genericArgumentExpressions)})"; + } + + var reflectionTypeMetadataName = runtimeTypeReference.ReflectionTypeMetadataName!; + reflectedArgumentNames.Add(variableBaseName); + builder.Append(indent); + builder.Append("var "); + builder.Append(variableBaseName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + return variableBaseName; + } + private static void AppendReflectionHelpers(StringBuilder builder) { // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. @@ -589,11 +911,52 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceDisplayName, string HandlerInterfaceLogName); + private sealed record RuntimeTypeReferenceSpec( + string? TypeDisplayName, + string? ReflectionTypeMetadataName, + RuntimeTypeReferenceSpec? ArrayElementTypeReference, + int ArrayRank, + RuntimeTypeReferenceSpec? GenericTypeDefinitionReference, + ImmutableArray GenericTypeArguments) + { + public static RuntimeTypeReferenceSpec FromDirectReference(string typeDisplayName) + { + return new RuntimeTypeReferenceSpec(typeDisplayName, null, null, 0, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromReflectionLookup(string reflectionTypeMetadataName) + { + return new RuntimeTypeReferenceSpec(null, reflectionTypeMetadataName, null, 0, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromArray(RuntimeTypeReferenceSpec elementTypeReference, int arrayRank) + { + return new RuntimeTypeReferenceSpec(null, null, elementTypeReference, arrayRank, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromConstructedGeneric( + RuntimeTypeReferenceSpec genericTypeDefinitionReference, + ImmutableArray genericTypeArguments) + { + return new RuntimeTypeReferenceSpec(null, null, null, 0, genericTypeDefinitionReference, + genericTypeArguments); + } + } + + private readonly record struct PreciseReflectedRegistrationSpec( + string OpenHandlerTypeDisplayName, + string HandlerInterfaceLogName, + ImmutableArray ServiceTypeArguments); + private readonly record struct ImplementationRegistrationSpec( string ImplementationTypeDisplayName, string ImplementationLogName, ImmutableArray DirectRegistrations, ImmutableArray ReflectedImplementationRegistrations, + ImmutableArray PreciseReflectedRegistrations, string? ReflectionTypeMetadataName); private readonly struct HandlerCandidateAnalysis : IEquatable @@ -603,12 +966,14 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string implementationLogName, ImmutableArray registrations, ImmutableArray reflectedImplementationRegistrations, + ImmutableArray preciseReflectedRegistrations, string? reflectionTypeMetadataName) { ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationLogName = implementationLogName; Registrations = registrations; ReflectedImplementationRegistrations = reflectedImplementationRegistrations; + PreciseReflectedRegistrations = preciseReflectedRegistrations; ReflectionTypeMetadataName = reflectionTypeMetadataName; } @@ -620,6 +985,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator public ImmutableArray ReflectedImplementationRegistrations { get; } + public ImmutableArray PreciseReflectedRegistrations { get; } + public string? ReflectionTypeMetadataName { get; } public bool Equals(HandlerCandidateAnalysis other) @@ -630,7 +997,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || Registrations.Length != other.Registrations.Length || - ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length) + ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length || + PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length) { return false; } @@ -648,6 +1016,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } + for (var index = 0; index < PreciseReflectedRegistrations.Length; index++) + { + if (!PreciseReflectedRegistrations[index].Equals(other.PreciseReflectedRegistrations[index])) + return false; + } + return true; } @@ -676,6 +1050,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator hashCode = (hashCode * 397) ^ reflectedImplementationRegistration.GetHashCode(); } + foreach (var preciseReflectedRegistration in PreciseReflectedRegistrations) + { + hashCode = (hashCode * 397) ^ preciseReflectedRegistration.GetHashCode(); + } + return hashCode; } }