diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs index 6c52a910..5ae794ad 100644 --- a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs +++ b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs @@ -12,9 +12,6 @@ namespace GFramework.Cqrs.Tests.Cqrs; [TestFixture] internal sealed class CqrsDispatcherCacheTests { - private MicrosoftDiContainer? _container; - private ArchitectureContext? _context; - /// /// 初始化测试上下文。 /// @@ -23,6 +20,7 @@ internal sealed class CqrsDispatcherCacheTests { LoggerFactoryResolver.Provider = new ConsoleLoggerFactoryProvider(); _container = new MicrosoftDiContainer(); + _container.RegisterCqrsPipelineBehavior(); CqrsTestRuntime.RegisterHandlers( _container, @@ -43,6 +41,9 @@ internal sealed class CqrsDispatcherCacheTests _container = null; } + private MicrosoftDiContainer? _container; + private ArchitectureContext? _context; + /// /// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。 /// @@ -52,32 +53,54 @@ internal sealed class CqrsDispatcherCacheTests var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes"); var requestServiceTypes = GetCacheField("RequestServiceTypes"); var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes"); + var requestInvokers = GetCacheField("RequestInvokers"); + var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers"); + var notificationInvokers = GetCacheField("NotificationInvokers"); + var streamInvokers = GetCacheField("StreamInvokers"); var notificationBefore = notificationServiceTypes.Count; var requestBefore = requestServiceTypes.Count; var streamBefore = streamServiceTypes.Count; + var requestInvokersBefore = requestInvokers.Count; + var requestPipelineInvokersBefore = requestPipelineInvokers.Count; + var notificationInvokersBefore = notificationInvokers.Count; + var streamInvokersBefore = streamInvokers.Count; await _context!.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherPipelineCacheRequest()); await _context.PublishAsync(new DispatcherCacheNotification()); await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest())); var notificationAfterFirstDispatch = notificationServiceTypes.Count; var requestAfterFirstDispatch = requestServiceTypes.Count; var streamAfterFirstDispatch = streamServiceTypes.Count; + var requestInvokersAfterFirstDispatch = requestInvokers.Count; + var requestPipelineInvokersAfterFirstDispatch = requestPipelineInvokers.Count; + var notificationInvokersAfterFirstDispatch = notificationInvokers.Count; + var streamInvokersAfterFirstDispatch = streamInvokers.Count; await _context.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherPipelineCacheRequest()); await _context.PublishAsync(new DispatcherCacheNotification()); await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest())); Assert.Multiple(() => { Assert.That(notificationAfterFirstDispatch, Is.EqualTo(notificationBefore + 1)); - Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 1)); + Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 2)); Assert.That(streamAfterFirstDispatch, Is.EqualTo(streamBefore + 1)); + Assert.That(requestInvokersAfterFirstDispatch, Is.EqualTo(requestInvokersBefore + 1)); + Assert.That(requestPipelineInvokersAfterFirstDispatch, Is.EqualTo(requestPipelineInvokersBefore + 1)); + Assert.That(notificationInvokersAfterFirstDispatch, Is.EqualTo(notificationInvokersBefore + 1)); + Assert.That(streamInvokersAfterFirstDispatch, Is.EqualTo(streamInvokersBefore + 1)); Assert.That(notificationServiceTypes.Count, Is.EqualTo(notificationAfterFirstDispatch)); Assert.That(requestServiceTypes.Count, Is.EqualTo(requestAfterFirstDispatch)); Assert.That(streamServiceTypes.Count, Is.EqualTo(streamAfterFirstDispatch)); + Assert.That(requestInvokers.Count, Is.EqualTo(requestInvokersAfterFirstDispatch)); + Assert.That(requestPipelineInvokers.Count, Is.EqualTo(requestPipelineInvokersAfterFirstDispatch)); + Assert.That(notificationInvokers.Count, Is.EqualTo(notificationInvokersAfterFirstDispatch)); + Assert.That(streamInvokers.Count, Is.EqualTo(streamInvokersAfterFirstDispatch)); }); } @@ -126,6 +149,11 @@ internal sealed record DispatcherCacheNotification : INotification; /// internal sealed record DispatcherCacheStreamRequest : IStreamRequest; +/// +/// 用于验证 pipeline invoker 缓存的测试请求。 +/// +internal sealed record DispatcherPipelineCacheRequest : IRequest; + /// /// 处理 。 /// @@ -170,3 +198,35 @@ internal sealed class DispatcherCacheStreamHandler : IStreamRequestHandler +/// 处理 。 +/// +internal sealed class DispatcherPipelineCacheRequestHandler : IRequestHandler +{ + /// + /// 返回固定结果,供 pipeline 缓存测试使用。 + /// + public ValueTask Handle(DispatcherPipelineCacheRequest request, CancellationToken cancellationToken) + { + return ValueTask.FromResult(2); + } +} + +/// +/// 为 提供最小 pipeline 行为, +/// 用于命中 dispatcher 的 pipeline invoker 缓存分支。 +/// +internal sealed class DispatcherPipelineCacheBehavior : IPipelineBehavior +{ + /// + /// 直接转发到下一个处理器。 + /// + public ValueTask Handle( + DispatcherPipelineCacheRequest request, + MessageHandlerDelegate next, + CancellationToken cancellationToken) + { + return next(request, cancellationToken); + } +} diff --git a/GFramework.Cqrs/Internal/CqrsDispatcher.cs b/GFramework.Cqrs/Internal/CqrsDispatcher.cs index 91532e17..002b7edc 100644 --- a/GFramework.Cqrs/Internal/CqrsDispatcher.cs +++ b/GFramework.Cqrs/Internal/CqrsDispatcher.cs @@ -44,6 +44,19 @@ internal sealed class CqrsDispatcher( StreamHandlerServiceTypes = new(); + // 静态方法定义缓存:这些反射查找与消息类型无关,只需解析一次即可复用。 + private static readonly MethodInfo RequestHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo RequestPipelineInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo NotificationHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo StreamHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)!; + /// /// 发布通知到所有已注册处理器。 /// @@ -189,8 +202,7 @@ internal sealed class CqrsDispatcher( /// private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = RequestHandlerInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method); } @@ -200,8 +212,7 @@ internal sealed class CqrsDispatcher( /// private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = RequestPipelineInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method); } @@ -211,8 +222,7 @@ internal sealed class CqrsDispatcher( /// private static NotificationInvoker CreateNotificationInvoker(Type notificationType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = NotificationHandlerInvokerMethodDefinition .MakeGenericMethod(notificationType); return (NotificationInvoker)Delegate.CreateDelegate(typeof(NotificationInvoker), method); } @@ -222,8 +232,7 @@ internal sealed class CqrsDispatcher( /// private static StreamInvoker CreateStreamInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)! + var method = StreamHandlerInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (StreamInvoker)Delegate.CreateDelegate(typeof(StreamInvoker), method); } diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index e1ec1546..dcdb5e5f 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -10,6 +10,138 @@ namespace GFramework.SourceGenerators.Tests.Cqrs; [TestFixture] public class CqrsHandlerRegistryGeneratorTests { + private const string HiddenNestedHandlerSelfRegistrationExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler"); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + typeof(global::TestApp.VisibleHandler)); + logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + + private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName) + { + var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false); + if (implementationType is null) + return; + + var handlerInterfaces = implementationType.GetInterfaces(); + global::System.Array.Sort(handlerInterfaces, CompareTypes); + + foreach (var handlerInterface in handlerInterfaces) + { + if (!IsSupportedHandlerInterface(handlerInterface)) + continue; + + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + handlerInterface, + implementationType); + logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}."); + } + } + + private static int CompareTypes(global::System.Type left, global::System.Type right) + { + return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right)); + } + + private static bool IsSupportedHandlerInterface(global::System.Type interfaceType) + { + if (!interfaceType.IsGenericType) + return false; + + var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName; + return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2") + || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1") + || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2"); + } + + private static string GetRuntimeTypeDisplayName(global::System.Type type) + { + if (type == typeof(string)) + return "string"; + if (type == typeof(int)) + return "int"; + if (type == typeof(long)) + return "long"; + if (type == typeof(short)) + return "short"; + if (type == typeof(byte)) + return "byte"; + if (type == typeof(bool)) + return "bool"; + if (type == typeof(object)) + return "object"; + if (type == typeof(void)) + return "void"; + if (type == typeof(uint)) + return "uint"; + if (type == typeof(ulong)) + return "ulong"; + if (type == typeof(ushort)) + return "ushort"; + if (type == typeof(sbyte)) + return "sbyte"; + if (type == typeof(float)) + return "float"; + if (type == typeof(double)) + return "double"; + if (type == typeof(decimal)) + return "decimal"; + if (type == typeof(char)) + return "char"; + + if (type.IsArray) + return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]"; + + if (!type.IsGenericType) + return (type.FullName ?? type.Name).Replace('+', '.'); + + var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name; + var arityIndex = genericTypeName.IndexOf('`'); + if (arityIndex >= 0) + genericTypeName = genericTypeName[..arityIndex]; + + genericTypeName = genericTypeName.Replace('+', '.'); + var arguments = type.GetGenericArguments(); + var builder = new global::System.Text.StringBuilder(); + builder.Append(genericTypeName); + builder.Append('<'); + + for (var index = 0; index < arguments.Length; index++) + { + if (index > 0) + builder.Append(", "); + + builder.Append(GetRuntimeTypeDisplayName(arguments[index])); + } + + builder.Append('>'); + return builder.ToString(); + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -126,12 +258,12 @@ public class CqrsHandlerRegistryGeneratorTests } /// - /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器仍会为可见 handlers 生成注册器, - /// 并额外标记运行时补充反射扫描。 + /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器会在生成注册器内部执行定向反射注册, + /// 不再依赖程序集级 fallback marker。 /// [Test] public async Task - Generates_Visible_Handlers_And_Requests_Reflection_Fallback_When_Assembly_Contains_Private_Nested_Handler() + Generates_Visible_Handlers_And_Self_Registers_Private_Nested_Handler_When_Assembly_Contains_Hidden_Handler() { const string source = """ using System; @@ -202,45 +334,17 @@ public class CqrsHandlerRegistryGeneratorTests } """; - const string expected = """ - // - #nullable enable - - [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] - [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute("TestApp.Container+HiddenHandler")] - - namespace GFramework.Generated.Cqrs; - - internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry - { - public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) - { - if (services is null) - throw new global::System.ArgumentNullException(nameof(services)); - if (logger is null) - throw new global::System.ArgumentNullException(nameof(logger)); - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( - services, - typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), - typeof(global::TestApp.VisibleHandler)); - logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); - } - } - - """; - await GeneratorTest.RunAsync( source, - ("CqrsHandlerRegistry.g.cs", expected)); + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } /// - /// 验证当 runtime 仅支持旧版无参 fallback marker 时,生成器会退回旧语义, - /// 只输出 marker 而不输出精确类型名。 + /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, + /// 不再输出 fallback marker。 /// [Test] - public async Task Generates_Legacy_Fallback_Marker_When_Runtime_Does_Not_Support_Type_Name_List() + public async Task Does_Not_Emit_Legacy_Fallback_Marker_When_Generated_Registry_Can_Self_Register_Hidden_Handler() { const string source = """ using System; @@ -311,45 +415,17 @@ public class CqrsHandlerRegistryGeneratorTests } """; - const string expected = """ - // - #nullable enable - - [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] - [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute()] - - namespace GFramework.Generated.Cqrs; - - internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry - { - public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) - { - if (services is null) - throw new global::System.ArgumentNullException(nameof(services)); - if (logger is null) - throw new global::System.ArgumentNullException(nameof(logger)); - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( - services, - typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), - typeof(global::TestApp.VisibleHandler)); - logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); - } - } - - """; - await GeneratorTest.RunAsync( source, - ("CqrsHandlerRegistry.g.cs", expected)); + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } /// - /// 验证当旧版 runtime 合同中不存在 reflection fallback 标记特性时, - /// 生成器会保留此前的整程序集回退行为,避免丢失不可见 handlers。 + /// 验证即使 runtime 合同中完全不存在 reflection fallback 标记特性, + /// 生成器仍能通过生成注册器内部的定向反射逻辑覆盖隐藏 handler。 /// [Test] - public async Task Skips_Generation_For_Unsupported_Handler_When_Fallback_Marker_Is_Unavailable() + public async Task Generates_Registry_For_Hidden_Handler_When_Fallback_Marker_Is_Unavailable() { const string source = """ using System; @@ -414,16 +490,9 @@ public class CqrsHandlerRegistryGeneratorTests } """; - var test = new CSharpSourceGeneratorTest - { - TestState = - { - Sources = { source } - }, - DisabledDiagnostics = { "GF_Common_Trace_001" } - }; - - await test.RunAsync(); + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } /// diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 1e260e32..83559781 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -16,9 +16,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private const string IStreamRequestHandlerMetadataName = $"{CqrsContractsNamespace}.IStreamRequestHandler`2"; private const string ICqrsHandlerRegistryMetadataName = $"{CqrsRuntimeNamespace}.ICqrsHandlerRegistry"; - private const string CqrsReflectionFallbackAttributeMetadataName = - $"{CqrsRuntimeNamespace}.CqrsReflectionFallbackAttribute"; - private const string CqrsHandlerRegistryAttributeMetadataName = $"{CqrsRuntimeNamespace}.CqrsHandlerRegistryAttribute"; @@ -57,10 +54,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null && compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null; - return new GenerationEnvironment( - generationEnabled, - GetReflectionFallbackEmissionMode( - compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName))); + return new GenerationEnvironment(generationEnabled); } private static bool IsHandlerCandidate(SyntaxNode node) @@ -91,17 +85,20 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return null; var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var implementationLogName = GetLogDisplayName(type); if (!CanReferenceFromGeneratedRegistry(type) || handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) { + // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...) + // expressions inside generated code. Preserve generator hit rate by resolving just that implementation + // type back from the current assembly instead of asking the runtime registrar to rescan the assembly. return new HandlerCandidateAnalysis( implementationTypeDisplayName, + implementationLogName, ImmutableArray.Empty, - true, - GetReflectionFallbackTypeName(type)); + GetReflectionTypeMetadataName(type)); } - var implementationLogName = GetLogDisplayName(type); var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); foreach (var handlerInterface in handlerInterfaces) { @@ -114,8 +111,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return new HandlerCandidateAnalysis( implementationTypeDisplayName, + implementationLogName, registrations.MoveToImmutable(), - false, null); } @@ -125,40 +122,23 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (!generationEnvironment.GenerationEnabled) return; - var registrations = CollectRegistrations( - candidates, - out var hasUnsupportedConcreteHandler, - out var reflectionFallbackTypeNames); + var registrations = CollectRegistrations(candidates); if (registrations.Count == 0) return; - // If the runtime contract does not yet expose the reflection fallback marker, - // keep the previous all-or-nothing behavior so unsupported handlers are not silently dropped. - if (hasUnsupportedConcreteHandler && - generationEnvironment.ReflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.Disabled) - return; - context.AddSource( HintName, - GenerateSource( - registrations, - hasUnsupportedConcreteHandler, - generationEnvironment.ReflectionFallbackEmissionMode, - reflectionFallbackTypeNames)); + GenerateSource(registrations)); } - private static List CollectRegistrations( - ImmutableArray candidates, - out bool hasUnsupportedConcreteHandler, - out IReadOnlyList reflectionFallbackTypeNames) + private static List CollectRegistrations( + ImmutableArray candidates) { - var registrations = new List(); - hasUnsupportedConcreteHandler = false; - var fallbackTypeNames = new SortedSet(StringComparer.Ordinal); + var registrations = new List(); // Partial declarations surface the same symbol through multiple syntax nodes. - // Collapse them by implementation type so generated registrations stay stable and duplicate-free. + // Collapse them by implementation type so direct and reflected registrations stay stable and duplicate-free. var uniqueCandidates = new Dictionary(StringComparer.Ordinal); foreach (var candidate in candidates) @@ -166,25 +146,16 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (candidate is null) continue; - if (candidate.Value.HasUnsupportedConcreteHandler) - { - hasUnsupportedConcreteHandler = true; - var reflectionFallbackTypeName = candidate.Value.ReflectionFallbackTypeName; - if (reflectionFallbackTypeName is not null && - !string.IsNullOrWhiteSpace(reflectionFallbackTypeName)) - { - fallbackTypeNames.Add(reflectionFallbackTypeName); - } - - continue; - } - uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value; } foreach (var candidate in uniqueCandidates.Values) { - registrations.AddRange(candidate.Registrations); + registrations.Add(new ImplementationRegistrationSpec( + candidate.ImplementationTypeDisplayName, + candidate.ImplementationLogName, + candidate.Registrations, + candidate.ReflectionTypeMetadataName)); } registrations.Sort(static (left, right) => @@ -193,35 +164,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator left.ImplementationLogName, right.ImplementationLogName); - return implementationComparison != 0 - ? implementationComparison - : StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName); + return implementationComparison; }); - reflectionFallbackTypeNames = fallbackTypeNames.ToArray(); return registrations; } - private static ReflectionFallbackEmissionMode GetReflectionFallbackEmissionMode(INamedTypeSymbol? attributeType) - { - if (attributeType is null) - return ReflectionFallbackEmissionMode.Disabled; - - foreach (var constructor in attributeType.InstanceConstructors) - { - if (constructor.Parameters.Length != 1) - continue; - - if (constructor.Parameters[0].Type is IArrayTypeSymbol arrayType && - arrayType.ElementType.SpecialType == SpecialType.System_String) - { - return ReflectionFallbackEmissionMode.PreciseTypeNames; - } - } - - return ReflectionFallbackEmissionMode.MarkerOnly; - } - private static bool IsConcreteHandlerType(INamedTypeSymbol type) { return type.TypeKind is TypeKind.Class or TypeKind.Struct && @@ -313,7 +261,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return builder.ToString(); } - private static string GetReflectionFallbackTypeName(INamedTypeSymbol type) + private static string GetReflectionTypeMetadataName(INamedTypeSymbol type) { var nestedTypes = new Stack(); for (var current = type; current is not null; current = current.ContainingType) @@ -352,11 +300,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } private static string GenerateSource( - IReadOnlyList registrations, - bool emitReflectionFallbackAttribute, - ReflectionFallbackEmissionMode reflectionFallbackEmissionMode, - IReadOnlyList reflectionFallbackTypeNames) + IReadOnlyList registrations) { + var hasReflectionRegistrations = registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -368,11 +315,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.Append('.'); builder.Append(GeneratedTypeName); builder.AppendLine("))]"); - if (emitReflectionFallbackAttribute && - reflectionFallbackEmissionMode != ReflectionFallbackEmissionMode.Disabled) - { - AppendReflectionFallbackAttribute(builder, reflectionFallbackEmissionMode, reflectionFallbackTypeNames); - } builder.AppendLine(); builder.Append("namespace "); @@ -394,59 +336,177 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));"); builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); - builder.AppendLine(); + if (hasReflectionRegistrations) + { + builder.AppendLine(); + builder.Append(" var registryAssembly = typeof(global::"); + builder.Append(GeneratedNamespace); + builder.Append('.'); + builder.Append(GeneratedTypeName); + builder.AppendLine(").Assembly;"); + } + + if (registrations.Count > 0) + builder.AppendLine(); foreach (var registration in registrations) { - builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(registration.HandlerInterfaceDisplayName); - builder.AppendLine("),"); - builder.Append(" typeof("); - builder.Append(registration.ImplementationTypeDisplayName); - builder.AppendLine("));"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); - builder.Append(" as "); - builder.Append(EscapeStringLiteral(registration.HandlerInterfaceLogName)); - builder.AppendLine(".\");"); + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) + { + AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); + continue; + } + + foreach (var directRegistration in registration.DirectRegistrations) + { + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" typeof("); + builder.Append(directRegistration.ImplementationTypeDisplayName); + builder.AppendLine("));"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + } } builder.AppendLine(" }"); + + if (hasReflectionRegistrations) + { + builder.AppendLine(); + AppendReflectionHelpers(builder); + } + builder.AppendLine("}"); return builder.ToString(); } - private static void AppendReflectionFallbackAttribute( - StringBuilder builder, - ReflectionFallbackEmissionMode reflectionFallbackEmissionMode, - IReadOnlyList reflectionFallbackTypeNames) + private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName) { - builder.Append("[assembly: global::"); - builder.Append(CqrsRuntimeNamespace); - builder.Append(".CqrsReflectionFallbackAttribute"); + builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \""); + builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); + builder.AppendLine("\");"); + } - if (reflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.PreciseTypeNames && - reflectionFallbackTypeNames.Count > 0) - { - builder.Append('('); - for (var index = 0; index < reflectionFallbackTypeNames.Count; index++) - { - if (index > 0) - builder.Append(", "); - - builder.Append('"'); - builder.Append(EscapeStringLiteral(reflectionFallbackTypeNames[index])); - builder.Append('"'); - } - - builder.AppendLine(")]"); - return; - } - - builder.AppendLine("()]"); + private static void AppendReflectionHelpers(StringBuilder builder) + { + // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. + builder.AppendLine( + " private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)"); + builder.AppendLine(" {"); + builder.AppendLine( + " var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);"); + builder.AppendLine(" if (implementationType is null)"); + builder.AppendLine(" return;"); + builder.AppendLine(); + builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();"); + builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);"); + builder.AppendLine(); + builder.AppendLine(" foreach (var handlerInterface in handlerInterfaces)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))"); + builder.AppendLine(" continue;"); + builder.AppendLine(); + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.AppendLine(" handlerInterface,"); + builder.AppendLine(" implementationType);"); + builder.AppendLine( + " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");"); + builder.AppendLine(" }"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static int CompareTypes(global::System.Type left, global::System.Type right)"); + builder.AppendLine(" {"); + builder.AppendLine( + " return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right));"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static bool IsSupportedHandlerInterface(global::System.Type interfaceType)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (!interfaceType.IsGenericType)"); + builder.AppendLine(" return false;"); + builder.AppendLine(); + builder.AppendLine(" var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName;"); + builder.AppendLine( + $" return global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IRequestHandlerMetadataName}\")"); + builder.AppendLine( + $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{INotificationHandlerMetadataName}\")"); + builder.AppendLine( + $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IStreamRequestHandlerMetadataName}\");"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static string GetRuntimeTypeDisplayName(global::System.Type type)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (type == typeof(string))"); + builder.AppendLine(" return \"string\";"); + builder.AppendLine(" if (type == typeof(int))"); + builder.AppendLine(" return \"int\";"); + builder.AppendLine(" if (type == typeof(long))"); + builder.AppendLine(" return \"long\";"); + builder.AppendLine(" if (type == typeof(short))"); + builder.AppendLine(" return \"short\";"); + builder.AppendLine(" if (type == typeof(byte))"); + builder.AppendLine(" return \"byte\";"); + builder.AppendLine(" if (type == typeof(bool))"); + builder.AppendLine(" return \"bool\";"); + builder.AppendLine(" if (type == typeof(object))"); + builder.AppendLine(" return \"object\";"); + builder.AppendLine(" if (type == typeof(void))"); + builder.AppendLine(" return \"void\";"); + builder.AppendLine(" if (type == typeof(uint))"); + builder.AppendLine(" return \"uint\";"); + builder.AppendLine(" if (type == typeof(ulong))"); + builder.AppendLine(" return \"ulong\";"); + builder.AppendLine(" if (type == typeof(ushort))"); + builder.AppendLine(" return \"ushort\";"); + builder.AppendLine(" if (type == typeof(sbyte))"); + builder.AppendLine(" return \"sbyte\";"); + builder.AppendLine(" if (type == typeof(float))"); + builder.AppendLine(" return \"float\";"); + builder.AppendLine(" if (type == typeof(double))"); + builder.AppendLine(" return \"double\";"); + builder.AppendLine(" if (type == typeof(decimal))"); + builder.AppendLine(" return \"decimal\";"); + builder.AppendLine(" if (type == typeof(char))"); + builder.AppendLine(" return \"char\";"); + builder.AppendLine(); + builder.AppendLine(" if (type.IsArray)"); + builder.AppendLine(" return GetRuntimeTypeDisplayName(type.GetElementType()!) + \"[]\";"); + builder.AppendLine(); + builder.AppendLine(" if (!type.IsGenericType)"); + builder.AppendLine(" return (type.FullName ?? type.Name).Replace('+', '.');"); + builder.AppendLine(); + builder.AppendLine(" var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name;"); + builder.AppendLine(" var arityIndex = genericTypeName.IndexOf('`');"); + builder.AppendLine(" if (arityIndex >= 0)"); + builder.AppendLine(" genericTypeName = genericTypeName[..arityIndex];"); + builder.AppendLine(); + builder.AppendLine(" genericTypeName = genericTypeName.Replace('+', '.');"); + builder.AppendLine(" var arguments = type.GetGenericArguments();"); + builder.AppendLine(" var builder = new global::System.Text.StringBuilder();"); + builder.AppendLine(" builder.Append(genericTypeName);"); + builder.AppendLine(" builder.Append('<');"); + builder.AppendLine(); + builder.AppendLine(" for (var index = 0; index < arguments.Length; index++)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (index > 0)"); + builder.AppendLine(" builder.Append(\", \");"); + builder.AppendLine(); + builder.AppendLine(" builder.Append(GetRuntimeTypeDisplayName(arguments[index]));"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" builder.Append('>');"); + builder.AppendLine(" return builder.ToString();"); + builder.AppendLine(" }"); } private static string EscapeStringLiteral(string value) @@ -463,34 +523,40 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceLogName, string ImplementationLogName); + private readonly record struct ImplementationRegistrationSpec( + string ImplementationTypeDisplayName, + string ImplementationLogName, + ImmutableArray DirectRegistrations, + string? ReflectionTypeMetadataName); + private readonly struct HandlerCandidateAnalysis : IEquatable { public HandlerCandidateAnalysis( string implementationTypeDisplayName, + string implementationLogName, ImmutableArray registrations, - bool hasUnsupportedConcreteHandler, - string? reflectionFallbackTypeName) + string? reflectionTypeMetadataName) { ImplementationTypeDisplayName = implementationTypeDisplayName; + ImplementationLogName = implementationLogName; Registrations = registrations; - HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler; - ReflectionFallbackTypeName = reflectionFallbackTypeName; + ReflectionTypeMetadataName = reflectionTypeMetadataName; } public string ImplementationTypeDisplayName { get; } + public string ImplementationLogName { get; } + public ImmutableArray Registrations { get; } - public bool HasUnsupportedConcreteHandler { get; } - - public string? ReflectionFallbackTypeName { get; } + public string? ReflectionTypeMetadataName { get; } public bool Equals(HandlerCandidateAnalysis other) { if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, StringComparison.Ordinal) || - HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler || - !string.Equals(ReflectionFallbackTypeName, other.ReflectionFallbackTypeName, + !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || + !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || Registrations.Length != other.Registrations.Length) { @@ -516,11 +582,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator unchecked { var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName); - hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode(); + hashCode = (hashCode * 397) ^ StringComparer.Ordinal.GetHashCode(ImplementationLogName); hashCode = (hashCode * 397) ^ - (ReflectionFallbackTypeName is null + (ReflectionTypeMetadataName is null ? 0 - : StringComparer.Ordinal.GetHashCode(ReflectionFallbackTypeName)); + : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName)); foreach (var registration in Registrations) { hashCode = (hashCode * 397) ^ registration.GetHashCode(); @@ -531,14 +597,5 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } } - private readonly record struct GenerationEnvironment( - bool GenerationEnabled, - ReflectionFallbackEmissionMode ReflectionFallbackEmissionMode); - - private enum ReflectionFallbackEmissionMode - { - Disabled, - MarkerOnly, - PreciseTypeNames - } + private readonly record struct GenerationEnvironment(bool GenerationEnabled); }