diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index dcdb5e5f..ef70ec4a 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -29,119 +29,141 @@ public class CqrsHandlerRegistryGeneratorTests var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; - RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler"); + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, typeof(string)); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( services, typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), typeof(global::TestApp.VisibleHandler)); logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); } - - private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName) - { - var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false); - if (implementationType is null) - return; - - var handlerInterfaces = implementationType.GetInterfaces(); - global::System.Array.Sort(handlerInterfaces, CompareTypes); - - foreach (var handlerInterface in handlerInterfaces) - { - if (!IsSupportedHandlerInterface(handlerInterface)) - continue; - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( - services, - handlerInterface, - implementationType); - logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}."); - } - } - - private static int CompareTypes(global::System.Type left, global::System.Type right) - { - return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right)); - } - - private static bool IsSupportedHandlerInterface(global::System.Type interfaceType) - { - if (!interfaceType.IsGenericType) - return false; - - var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName; - return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2") - || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1") - || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2"); - } - - private static string GetRuntimeTypeDisplayName(global::System.Type type) - { - if (type == typeof(string)) - return "string"; - if (type == typeof(int)) - return "int"; - if (type == typeof(long)) - return "long"; - if (type == typeof(short)) - return "short"; - if (type == typeof(byte)) - return "byte"; - if (type == typeof(bool)) - return "bool"; - if (type == typeof(object)) - return "object"; - if (type == typeof(void)) - return "void"; - if (type == typeof(uint)) - return "uint"; - if (type == typeof(ulong)) - return "ulong"; - if (type == typeof(ushort)) - return "ushort"; - if (type == typeof(sbyte)) - return "sbyte"; - if (type == typeof(float)) - return "float"; - if (type == typeof(double)) - return "double"; - if (type == typeof(decimal)) - return "decimal"; - if (type == typeof(char)) - return "char"; - - if (type.IsArray) - return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]"; - - if (!type.IsGenericType) - return (type.FullName ?? type.Name).Replace('+', '.'); - - var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name; - var arityIndex = genericTypeName.IndexOf('`'); - if (arityIndex >= 0) - genericTypeName = genericTypeName[..arityIndex]; - - genericTypeName = genericTypeName.Replace('+', '.'); - var arguments = type.GetGenericArguments(); - var builder = new global::System.Text.StringBuilder(); - builder.Append(genericTypeName); - builder.Append('<'); - - for (var index = 0; index < arguments.Length; index++) - { - if (index > 0) - builder.Append(", "); - - builder.Append(GetRuntimeTypeDisplayName(arguments[index])); - } - - builder.Append('>'); - return builder.ToString(); - } } """; + private const string HiddenImplementationDirectInterfaceRegistrationExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + + private const string HiddenArrayResponseFallbackExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + } + + """; + + private const string HiddenGenericEnvelopeResponseExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1GenericDefinition = registryAssembly.GetType("TestApp.Container+HiddenEnvelope`1", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1GenericDefinition is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1GenericDefinition.MakeGenericType(typeof(string))); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler>."); + } + } + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -339,6 +361,224 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } + /// + /// 验证当隐藏实现类型的 handler 接口仍可被生成代码直接引用时, + /// 生成器只会定向反射实现类型,而不会再生成基于 GetInterfaces() 的接口发现辅助逻辑。 + /// + [Test] + public async Task + Generates_Direct_Interface_Registrations_For_Hidden_Implementation_When_Handler_Interface_Is_Public() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed class HiddenHandler : IRequestHandler { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenImplementationDirectInterfaceRegistrationExpected)); + } + + /// + /// 验证精确重建路径会递归覆盖隐藏元素类型数组, + /// 使这类 handler interface 也能直接生成 closed service type,而不再退回 GetInterfaces()。 + /// + [Test] + public async Task Generates_Precise_Service_Type_For_Hidden_Array_Type_Arguments() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenHandler : IRequestHandler { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenArrayResponseFallbackExpected)); + } + + /// + /// 验证精确重建路径会递归覆盖隐藏泛型定义, + /// 使“隐藏泛型定义 + 可见/常量型实参”的闭包类型也能直接生成 closed service type。 + /// + [Test] + public async Task Generates_Precise_Service_Type_For_Hidden_Generic_Type_Definitions() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed class Container + { + private sealed class HiddenEnvelope { } + + private sealed record HiddenRequest() : IRequest>; + + private sealed class HiddenHandler : IRequestHandler> { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected)); + } + /// /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 83559781..36a72be5 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -86,34 +86,61 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var implementationLogName = GetLogDisplayName(type); - if (!CanReferenceFromGeneratedRegistry(type) || - handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) + var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type); + var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var reflectedImplementationRegistrations = + ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var preciseReflectedRegistrations = + ImmutableArray.CreateBuilder(handlerInterfaces.Length); + foreach (var handlerInterface in handlerInterfaces) { - // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...) - // expressions inside generated code. Preserve generator hit rate by resolving just that implementation - // type back from the current assembly instead of asking the runtime registrar to rescan the assembly. + var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); + if (canReferenceImplementation && canReferenceHandlerInterface) + { + registrations.Add(new HandlerRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + implementationTypeDisplayName, + GetLogDisplayName(handlerInterface), + implementationLogName)); + continue; + } + + if (!canReferenceImplementation && canReferenceHandlerInterface) + { + reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + GetLogDisplayName(handlerInterface))); + continue; + } + + if (TryCreatePreciseReflectedRegistration( + context.SemanticModel.Compilation, + handlerInterface, + out var preciseReflectedRegistration)) + { + preciseReflectedRegistrations.Add(preciseReflectedRegistration); + continue; + } + + // Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over + // non-public element types. For those rare cases keep the narrow implementation lookup, but let the + // generated registry discover the exact supported interfaces from the implementation type at runtime. return new HandlerCandidateAnalysis( implementationTypeDisplayName, implementationLogName, ImmutableArray.Empty, + ImmutableArray.Empty, + ImmutableArray.Empty, GetReflectionTypeMetadataName(type)); } - var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); - foreach (var handlerInterface in handlerInterfaces) - { - registrations.Add(new HandlerRegistrationSpec( - handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - implementationTypeDisplayName, - GetLogDisplayName(handlerInterface), - implementationLogName)); - } - return new HandlerCandidateAnalysis( implementationTypeDisplayName, implementationLogName, - registrations.MoveToImmutable(), - null); + registrations.ToImmutable(), + reflectedImplementationRegistrations.ToImmutable(), + preciseReflectedRegistrations.ToImmutable(), + canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); } private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, @@ -155,6 +182,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator candidate.ImplementationTypeDisplayName, candidate.ImplementationLogName, candidate.Registrations, + candidate.ReflectedImplementationRegistrations, + candidate.PreciseReflectedRegistrations, candidate.ReflectionTypeMetadataName)); } @@ -199,6 +228,167 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal); } + /// + /// 为无法直接在生成代码中书写的关闭处理器接口构造精确的运行时注册描述。 + /// + /// + /// 当前生成轮次对应的编译上下文,用于判断类型是否属于当前程序集,从而决定是生成直接类型引用还是延迟到运行时反射解析。 + /// + /// + /// 需要注册的关闭处理器接口。调用方应保证它来自受支持的 CQRS 处理器接口定义,并且其泛型参数顺序与运行时注册约定一致。 + /// + /// + /// 当方法返回 时,包含开放泛型处理器类型和每个运行时类型实参的精确描述; + /// 当方法返回 时,为默认值,调用方应回退到基于实现类型的宽松反射发现路径。 + /// + /// + /// 当接口上的所有运行时类型引用都能在生成阶段被稳定描述时返回 ; + /// 只要任一泛型实参无法安全编码到生成输出中,就返回 。 + /// + private static bool TryCreatePreciseReflectedRegistration( + Compilation compilation, + INamedTypeSymbol handlerInterface, + out PreciseReflectedRegistrationSpec registration) + { + var openHandlerTypeDisplayName = handlerInterface.OriginalDefinition + .ConstructUnboundGenericType() + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var typeArguments = + ImmutableArray.CreateBuilder(handlerInterface.TypeArguments.Length); + foreach (var typeArgument in handlerInterface.TypeArguments) + { + if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var runtimeTypeReference)) + { + registration = default; + return false; + } + + typeArguments.Add(runtimeTypeReference!); + } + + registration = new PreciseReflectedRegistrationSpec( + openHandlerTypeDisplayName, + GetLogDisplayName(handlerInterface), + typeArguments.ToImmutable()); + return true; + } + + /// + /// 将 Roslyn 类型符号转换为生成注册器可消费的运行时类型引用描述。 + /// + /// + /// 当前编译上下文,用于区分可直接引用的外部可访问类型与必须通过当前程序集运行时反射查找的内部类型。 + /// + /// + /// 需要转换的类型符号。该方法会递归处理数组元素类型和已构造泛型的类型实参,但不会为未绑定泛型或类型参数生成引用。 + /// + /// + /// 当方法返回 时,包含可直接引用、数组、已构造泛型或反射查找中的一种运行时表示; + /// 当方法返回 时为 ,调用方应回退到更宽泛的实现类型反射扫描策略。 + /// + /// + /// 当 及其递归子结构都能映射为稳定的运行时引用时返回 ; + /// 若遇到类型参数、无法访问的运行时结构,或任一递归分支无法表示,则返回 。 + /// + private static bool TryCreateRuntimeTypeReference( + Compilation compilation, + ITypeSymbol type, + out RuntimeTypeReferenceSpec? runtimeTypeReference) + { + if (CanReferenceFromGeneratedRegistry(type)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference( + type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + return true; + } + + if (type is IArrayTypeSymbol arrayType && + TryCreateRuntimeTypeReference(compilation, arrayType.ElementType, out var elementTypeReference)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromArray(elementTypeReference!, arrayType.Rank); + return true; + } + + if (type is INamedTypeSymbol genericNamedType && + genericNamedType.IsGenericType && + !genericNamedType.IsUnboundGenericType && + TryCreateGenericTypeDefinitionReference(compilation, genericNamedType, + out var genericTypeDefinitionReference)) + { + var genericTypeArguments = + ImmutableArray.CreateBuilder(genericNamedType.TypeArguments.Length); + foreach (var typeArgument in genericNamedType.TypeArguments) + { + if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var genericTypeArgumentReference)) + { + runtimeTypeReference = null; + return false; + } + + genericTypeArguments.Add(genericTypeArgumentReference!); + } + + runtimeTypeReference = RuntimeTypeReferenceSpec.FromConstructedGeneric( + genericTypeDefinitionReference!, + genericTypeArguments.ToImmutable()); + return true; + } + + if (type is INamedTypeSymbol namedType && + SymbolEqualityComparer.Default.Equals(namedType.ContainingAssembly, compilation.Assembly)) + { + runtimeTypeReference = RuntimeTypeReferenceSpec.FromReflectionLookup( + GetReflectionTypeMetadataName(namedType)); + return true; + } + + runtimeTypeReference = null; + return false; + } + + /// + /// 为已构造泛型类型解析其泛型定义的运行时引用描述。 + /// + /// + /// 当前编译上下文,用于判断泛型定义是否应以内联类型引用形式生成,或在运行时通过当前程序集反射解析。 + /// + /// + /// 已构造的泛型类型。该方法只处理其原始泛型定义,不负责递归解析类型实参。 + /// + /// + /// 当方法返回 时,包含泛型定义的直接引用或反射查找描述; + /// 当方法返回 时为 ,调用方应停止精确构造并回退到更保守的注册路径。 + /// + /// + /// 当泛型定义能够以稳定方式编码到生成输出中时返回 ; + /// 若泛型定义既不能直接引用,也不属于当前程序集可供反射查找,则返回 。 + /// + private static bool TryCreateGenericTypeDefinitionReference( + Compilation compilation, + INamedTypeSymbol genericNamedType, + out RuntimeTypeReferenceSpec? genericTypeDefinitionReference) + { + var genericTypeDefinition = genericNamedType.OriginalDefinition; + if (CanReferenceFromGeneratedRegistry(genericTypeDefinition)) + { + genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference( + genericTypeDefinition + .ConstructUnboundGenericType() + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + return true; + } + + if (SymbolEqualityComparer.Default.Equals(genericTypeDefinition.ContainingAssembly, compilation.Assembly)) + { + genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromReflectionLookup( + GetReflectionTypeMetadataName(genericTypeDefinition)); + return true; + } + + genericTypeDefinitionReference = null; + return false; + } + private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) { switch (type) @@ -302,8 +492,14 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private static string GenerateSource( IReadOnlyList registrations) { - var hasReflectionRegistrations = registrations.Any(static registration => - !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)); + var hasReflectedImplementationRegistrations = registrations.Any(static registration => + !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); + var hasPreciseReflectedRegistrations = registrations.Any(static registration => + !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); + var hasFullReflectionRegistrations = registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && + registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && + registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -336,7 +532,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));"); builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); - if (hasReflectionRegistrations) + if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || + hasFullReflectionRegistrations) { builder.AppendLine(); builder.Append(" var registryAssembly = typeof(global::"); @@ -349,8 +546,21 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (registrations.Count > 0) builder.AppendLine(); - foreach (var registration in registrations) + for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++) { + var registration = registrations[registrationIndex]; + if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty) + { + AppendReflectedImplementationRegistrations(builder, registration, registrationIndex); + continue; + } + + if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) + { + AppendPreciseReflectedRegistrations(builder, registration, registrationIndex); + continue; + } + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); @@ -378,7 +588,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" }"); - if (hasReflectionRegistrations) + if (hasFullReflectionRegistrations) { builder.AppendLine(); AppendReflectionHelpers(builder); @@ -395,6 +605,238 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine("\");"); } + private static void AppendReflectedImplementationRegistrations( + StringBuilder builder, + ImplementationRegistrationSpec registration, + int registrationIndex) + { + var implementationVariableName = $"implementationType{registrationIndex}"; + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + builder.Append(" if ("); + builder.Append(implementationVariableName); + builder.AppendLine(" is not null)"); + builder.AppendLine(" {"); + + foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations) + { + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + } + + builder.AppendLine(" }"); + } + + private static void AppendPreciseReflectedRegistrations( + StringBuilder builder, + ImplementationRegistrationSpec registration, + int registrationIndex) + { + var implementationVariableName = $"implementationType{registrationIndex}"; + if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) + { + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = typeof("); + builder.Append(registration.ImplementationTypeDisplayName); + builder.AppendLine(");"); + } + else + { + var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!; + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + } + + builder.Append(" if ("); + builder.Append(implementationVariableName); + builder.AppendLine(" is not null)"); + builder.AppendLine(" {"); + + for (var registrationOffset = 0; + registrationOffset < registration.PreciseReflectedRegistrations.Length; + registrationOffset++) + { + var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset]; + var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}"; + AppendPreciseReflectedTypeResolution( + builder, + reflectedRegistration.ServiceTypeArguments, + registrationVariablePrefix, + implementationVariableName, + reflectedRegistration.OpenHandlerTypeDisplayName, + registration.ImplementationLogName, + reflectedRegistration.HandlerInterfaceLogName, + 3); + } + + builder.AppendLine(" }"); + } + + private static void AppendPreciseReflectedTypeResolution( + StringBuilder builder, + ImmutableArray serviceTypeArguments, + string registrationVariablePrefix, + string implementationVariableName, + string openHandlerTypeDisplayName, + string implementationLogName, + string handlerInterfaceLogName, + int indentLevel) + { + var indent = new string(' ', indentLevel * 4); + var nestedIndent = new string(' ', (indentLevel + 1) * 4); + var resolvedArgumentNames = new string[serviceTypeArguments.Length]; + var reflectedArgumentNames = new List(); + + for (var argumentIndex = 0; argumentIndex < serviceTypeArguments.Length; argumentIndex++) + { + resolvedArgumentNames[argumentIndex] = AppendRuntimeTypeReferenceResolution( + builder, + serviceTypeArguments[argumentIndex], + $"{registrationVariablePrefix}Argument{argumentIndex}", + reflectedArgumentNames, + indent); + } + + if (reflectedArgumentNames.Count > 0) + { + builder.Append(indent); + builder.Append("if ("); + for (var index = 0; index < reflectedArgumentNames.Count; index++) + { + if (index > 0) + builder.Append(" && "); + + builder.Append(reflectedArgumentNames[index]); + builder.Append(" is not null"); + } + + builder.AppendLine(")"); + builder.Append(indent); + builder.AppendLine("{"); + indent = nestedIndent; + } + + builder.Append(indent); + builder.Append("var "); + builder.Append(registrationVariablePrefix); + builder.Append(" = typeof("); + builder.Append(openHandlerTypeDisplayName); + builder.Append(").MakeGenericType("); + for (var index = 0; index < resolvedArgumentNames.Length; index++) + { + if (index > 0) + builder.Append(", "); + + builder.Append(resolvedArgumentNames[index]); + } + + builder.AppendLine(");"); + builder.Append(indent); + builder.AppendLine( + "global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.Append(indent); + builder.AppendLine(" services,"); + builder.Append(indent); + builder.Append(" "); + builder.Append(registrationVariablePrefix); + builder.AppendLine(","); + builder.Append(indent); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(indent); + builder.Append("logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(implementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(handlerInterfaceLogName)); + builder.AppendLine(".\");"); + + if (reflectedArgumentNames.Count > 0) + { + builder.Append(new string(' ', indentLevel * 4)); + builder.AppendLine("}"); + } + } + + private static string AppendRuntimeTypeReferenceResolution( + StringBuilder builder, + RuntimeTypeReferenceSpec runtimeTypeReference, + string variableBaseName, + ICollection reflectedArgumentNames, + string indent) + { + if (!string.IsNullOrWhiteSpace(runtimeTypeReference.TypeDisplayName)) + return $"typeof({runtimeTypeReference.TypeDisplayName})"; + + if (runtimeTypeReference.ArrayElementTypeReference is not null) + { + var elementExpression = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.ArrayElementTypeReference, + $"{variableBaseName}Element", + reflectedArgumentNames, + indent); + + return runtimeTypeReference.ArrayRank == 1 + ? $"{elementExpression}.MakeArrayType()" + : $"{elementExpression}.MakeArrayType({runtimeTypeReference.ArrayRank})"; + } + + if (runtimeTypeReference.GenericTypeDefinitionReference is not null) + { + var genericTypeDefinitionExpression = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.GenericTypeDefinitionReference, + $"{variableBaseName}GenericDefinition", + reflectedArgumentNames, + indent); + var genericArgumentExpressions = new string[runtimeTypeReference.GenericTypeArguments.Length]; + for (var argumentIndex = 0; + argumentIndex < runtimeTypeReference.GenericTypeArguments.Length; + argumentIndex++) + { + genericArgumentExpressions[argumentIndex] = AppendRuntimeTypeReferenceResolution( + builder, + runtimeTypeReference.GenericTypeArguments[argumentIndex], + $"{variableBaseName}GenericArgument{argumentIndex}", + reflectedArgumentNames, + indent); + } + + return + $"{genericTypeDefinitionExpression}.MakeGenericType({string.Join(", ", genericArgumentExpressions)})"; + } + + var reflectionTypeMetadataName = runtimeTypeReference.ReflectionTypeMetadataName!; + reflectedArgumentNames.Add(variableBaseName); + builder.Append(indent); + builder.Append("var "); + builder.Append(variableBaseName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + return variableBaseName; + } + private static void AppendReflectionHelpers(StringBuilder builder) { // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. @@ -523,10 +965,56 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceLogName, string ImplementationLogName); + private readonly record struct ReflectedImplementationRegistrationSpec( + string HandlerInterfaceDisplayName, + string HandlerInterfaceLogName); + + private sealed record RuntimeTypeReferenceSpec( + string? TypeDisplayName, + string? ReflectionTypeMetadataName, + RuntimeTypeReferenceSpec? ArrayElementTypeReference, + int ArrayRank, + RuntimeTypeReferenceSpec? GenericTypeDefinitionReference, + ImmutableArray GenericTypeArguments) + { + public static RuntimeTypeReferenceSpec FromDirectReference(string typeDisplayName) + { + return new RuntimeTypeReferenceSpec(typeDisplayName, null, null, 0, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromReflectionLookup(string reflectionTypeMetadataName) + { + return new RuntimeTypeReferenceSpec(null, reflectionTypeMetadataName, null, 0, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromArray(RuntimeTypeReferenceSpec elementTypeReference, int arrayRank) + { + return new RuntimeTypeReferenceSpec(null, null, elementTypeReference, arrayRank, null, + ImmutableArray.Empty); + } + + public static RuntimeTypeReferenceSpec FromConstructedGeneric( + RuntimeTypeReferenceSpec genericTypeDefinitionReference, + ImmutableArray genericTypeArguments) + { + return new RuntimeTypeReferenceSpec(null, null, null, 0, genericTypeDefinitionReference, + genericTypeArguments); + } + } + + private readonly record struct PreciseReflectedRegistrationSpec( + string OpenHandlerTypeDisplayName, + string HandlerInterfaceLogName, + ImmutableArray ServiceTypeArguments); + private readonly record struct ImplementationRegistrationSpec( string ImplementationTypeDisplayName, string ImplementationLogName, ImmutableArray DirectRegistrations, + ImmutableArray ReflectedImplementationRegistrations, + ImmutableArray PreciseReflectedRegistrations, string? ReflectionTypeMetadataName); private readonly struct HandlerCandidateAnalysis : IEquatable @@ -535,11 +1023,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string implementationTypeDisplayName, string implementationLogName, ImmutableArray registrations, + ImmutableArray reflectedImplementationRegistrations, + ImmutableArray preciseReflectedRegistrations, string? reflectionTypeMetadataName) { ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationLogName = implementationLogName; Registrations = registrations; + ReflectedImplementationRegistrations = reflectedImplementationRegistrations; + PreciseReflectedRegistrations = preciseReflectedRegistrations; ReflectionTypeMetadataName = reflectionTypeMetadataName; } @@ -549,6 +1041,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator public ImmutableArray Registrations { get; } + public ImmutableArray ReflectedImplementationRegistrations { get; } + + public ImmutableArray PreciseReflectedRegistrations { get; } + public string? ReflectionTypeMetadataName { get; } public bool Equals(HandlerCandidateAnalysis other) @@ -558,7 +1054,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || - Registrations.Length != other.Registrations.Length) + Registrations.Length != other.Registrations.Length || + ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length || + PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length) { return false; } @@ -569,6 +1067,19 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } + for (var index = 0; index < ReflectedImplementationRegistrations.Length; index++) + { + if (!ReflectedImplementationRegistrations[index].Equals( + other.ReflectedImplementationRegistrations[index])) + return false; + } + + for (var index = 0; index < PreciseReflectedRegistrations.Length; index++) + { + if (!PreciseReflectedRegistrations[index].Equals(other.PreciseReflectedRegistrations[index])) + return false; + } + return true; } @@ -592,6 +1103,16 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator hashCode = (hashCode * 397) ^ registration.GetHashCode(); } + foreach (var reflectedImplementationRegistration in ReflectedImplementationRegistrations) + { + hashCode = (hashCode * 397) ^ reflectedImplementationRegistration.GetHashCode(); + } + + foreach (var preciseReflectedRegistration in PreciseReflectedRegistrations) + { + hashCode = (hashCode * 397) ^ preciseReflectedRegistration.GetHashCode(); + } + return hashCode; } }