diff --git a/GFramework.Core/Ioc/MicrosoftDiContainer.cs b/GFramework.Core/Ioc/MicrosoftDiContainer.cs index 6152366f..d1a0576d 100644 --- a/GFramework.Core/Ioc/MicrosoftDiContainer.cs +++ b/GFramework.Core/Ioc/MicrosoftDiContainer.cs @@ -400,7 +400,7 @@ public class MicrosoftDiContainer(IServiceCollection? serviceCollection = null) /// 要接入的程序集集合。 /// /// 中存在 元素。 - /// 容器已冻结,无法继续注册 CQRS 处理器。 + /// 容器已冻结,无法继续注册 CQRS 处理器。 public void RegisterCqrsHandlersFromAssemblies(IEnumerable assemblies) { ArgumentNullException.ThrowIfNull(assemblies); diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs new file mode 100644 index 00000000..5ae794ad --- /dev/null +++ b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs @@ -0,0 +1,232 @@ +using GFramework.Core.Abstractions.Logging; +using GFramework.Core.Architectures; +using GFramework.Core.Ioc; +using GFramework.Core.Logging; +using GFramework.Cqrs.Abstractions.Cqrs; + +namespace GFramework.Cqrs.Tests.Cqrs; + +/// +/// 验证 CQRS dispatcher 会缓存热路径中的服务类型构造结果。 +/// +[TestFixture] +internal sealed class CqrsDispatcherCacheTests +{ + /// + /// 初始化测试上下文。 + /// + [SetUp] + public void SetUp() + { + LoggerFactoryResolver.Provider = new ConsoleLoggerFactoryProvider(); + _container = new MicrosoftDiContainer(); + _container.RegisterCqrsPipelineBehavior(); + + CqrsTestRuntime.RegisterHandlers( + _container, + typeof(CqrsDispatcherCacheTests).Assembly, + typeof(ArchitectureContext).Assembly); + + _container.Freeze(); + _context = new ArchitectureContext(_container); + } + + /// + /// 清理测试上下文引用。 + /// + [TearDown] + public void TearDown() + { + _context = null; + _container = null; + } + + private MicrosoftDiContainer? _container; + private ArchitectureContext? _context; + + /// + /// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。 + /// + [Test] + public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch() + { + var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes"); + var requestServiceTypes = GetCacheField("RequestServiceTypes"); + var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes"); + var requestInvokers = GetCacheField("RequestInvokers"); + var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers"); + var notificationInvokers = GetCacheField("NotificationInvokers"); + var streamInvokers = GetCacheField("StreamInvokers"); + + var notificationBefore = notificationServiceTypes.Count; + var requestBefore = requestServiceTypes.Count; + var streamBefore = streamServiceTypes.Count; + var requestInvokersBefore = requestInvokers.Count; + var requestPipelineInvokersBefore = requestPipelineInvokers.Count; + var notificationInvokersBefore = notificationInvokers.Count; + var streamInvokersBefore = streamInvokers.Count; + + await _context!.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherPipelineCacheRequest()); + await _context.PublishAsync(new DispatcherCacheNotification()); + await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest())); + + var notificationAfterFirstDispatch = notificationServiceTypes.Count; + var requestAfterFirstDispatch = requestServiceTypes.Count; + var streamAfterFirstDispatch = streamServiceTypes.Count; + var requestInvokersAfterFirstDispatch = requestInvokers.Count; + var requestPipelineInvokersAfterFirstDispatch = requestPipelineInvokers.Count; + var notificationInvokersAfterFirstDispatch = notificationInvokers.Count; + var streamInvokersAfterFirstDispatch = streamInvokers.Count; + + await _context.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherPipelineCacheRequest()); + await _context.PublishAsync(new DispatcherCacheNotification()); + await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest())); + + Assert.Multiple(() => + { + Assert.That(notificationAfterFirstDispatch, Is.EqualTo(notificationBefore + 1)); + Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 2)); + Assert.That(streamAfterFirstDispatch, Is.EqualTo(streamBefore + 1)); + Assert.That(requestInvokersAfterFirstDispatch, Is.EqualTo(requestInvokersBefore + 1)); + Assert.That(requestPipelineInvokersAfterFirstDispatch, Is.EqualTo(requestPipelineInvokersBefore + 1)); + Assert.That(notificationInvokersAfterFirstDispatch, Is.EqualTo(notificationInvokersBefore + 1)); + Assert.That(streamInvokersAfterFirstDispatch, Is.EqualTo(streamInvokersBefore + 1)); + + Assert.That(notificationServiceTypes.Count, Is.EqualTo(notificationAfterFirstDispatch)); + Assert.That(requestServiceTypes.Count, Is.EqualTo(requestAfterFirstDispatch)); + Assert.That(streamServiceTypes.Count, Is.EqualTo(streamAfterFirstDispatch)); + Assert.That(requestInvokers.Count, Is.EqualTo(requestInvokersAfterFirstDispatch)); + Assert.That(requestPipelineInvokers.Count, Is.EqualTo(requestPipelineInvokersAfterFirstDispatch)); + Assert.That(notificationInvokers.Count, Is.EqualTo(notificationInvokersAfterFirstDispatch)); + Assert.That(streamInvokers.Count, Is.EqualTo(streamInvokersAfterFirstDispatch)); + }); + } + + /// + /// 通过反射读取 dispatcher 的静态缓存字典。 + /// + private static IDictionary GetCacheField(string fieldName) + { + var dispatcherType = typeof(CqrsReflectionFallbackAttribute).Assembly + .GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!; + + var field = dispatcherType.GetField( + fieldName, + BindingFlags.NonPublic | BindingFlags.Static); + + Assert.That(field, Is.Not.Null, $"Missing dispatcher cache field {fieldName}."); + + return field!.GetValue(null) as IDictionary + ?? throw new InvalidOperationException( + $"Dispatcher cache field {fieldName} does not implement IDictionary."); + } + + /// + /// 消费整个异步流,确保建流路径被真实执行。 + /// + private static async Task DrainAsync(IAsyncEnumerable stream) + { + await foreach (var _ in stream) + { + } + } +} + +/// +/// 用于验证 request 服务类型缓存的测试请求。 +/// +internal sealed record DispatcherCacheRequest : IRequest; + +/// +/// 用于验证 notification 服务类型缓存的测试通知。 +/// +internal sealed record DispatcherCacheNotification : INotification; + +/// +/// 用于验证 stream 服务类型缓存的测试请求。 +/// +internal sealed record DispatcherCacheStreamRequest : IStreamRequest; + +/// +/// 用于验证 pipeline invoker 缓存的测试请求。 +/// +internal sealed record DispatcherPipelineCacheRequest : IRequest; + +/// +/// 处理 。 +/// +internal sealed class DispatcherCacheRequestHandler : IRequestHandler +{ + /// + /// 返回固定结果,供缓存测试验证 dispatcher 请求路径。 + /// + public ValueTask Handle(DispatcherCacheRequest request, CancellationToken cancellationToken) + { + return ValueTask.FromResult(1); + } +} + +/// +/// 处理 。 +/// +internal sealed class DispatcherCacheNotificationHandler : INotificationHandler +{ + /// + /// 消费通知,不执行额外副作用。 + /// + public ValueTask Handle(DispatcherCacheNotification notification, CancellationToken cancellationToken) + { + return ValueTask.CompletedTask; + } +} + +/// +/// 处理 。 +/// +internal sealed class DispatcherCacheStreamHandler : IStreamRequestHandler +{ + /// + /// 返回一个最小流,供缓存测试命中 stream 分发路径。 + /// + public async IAsyncEnumerable Handle( + DispatcherCacheStreamRequest request, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + yield return 1; + await Task.CompletedTask; + } +} + +/// +/// 处理 。 +/// +internal sealed class DispatcherPipelineCacheRequestHandler : IRequestHandler +{ + /// + /// 返回固定结果,供 pipeline 缓存测试使用。 + /// + public ValueTask Handle(DispatcherPipelineCacheRequest request, CancellationToken cancellationToken) + { + return ValueTask.FromResult(2); + } +} + +/// +/// 为 提供最小 pipeline 行为, +/// 用于命中 dispatcher 的 pipeline invoker 缓存分支。 +/// +internal sealed class DispatcherPipelineCacheBehavior : IPipelineBehavior +{ + /// + /// 直接转发到下一个处理器。 + /// + public ValueTask Handle( + DispatcherPipelineCacheRequest request, + MessageHandlerDelegate next, + CancellationToken cancellationToken) + { + return next(request, cancellationToken); + } +} diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs index 318f6d21..b44b0bb1 100644 --- a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs +++ b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs @@ -190,11 +190,11 @@ internal sealed class CqrsHandlerRegistrarTests } /// - /// 验证当生成注册器显式要求 reflection fallback 时,运行时会补扫剩余 handlers, - /// 同时避免把已由生成注册器注册的映射重复写入服务集合。 + /// 验证当生成注册器提供精确 fallback 类型名时,运行时会定向补扫剩余 handlers, + /// 而不是重新枚举整个程序集的类型列表。 /// [Test] - public void RegisterHandlers_Should_Combine_Generated_Registry_With_Reflection_Fallback_Without_Duplicates() + public void RegisterHandlers_Should_Use_Targeted_Type_Lookups_For_Reflection_Fallback_Without_Duplicates() { var generatedAssembly = new Mock(); generatedAssembly @@ -205,13 +205,65 @@ internal sealed class CqrsHandlerRegistrarTests .Returns([new CqrsHandlerRegistryAttribute(typeof(PartialGeneratedNotificationHandlerRegistry))]); generatedAssembly .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), false)) - .Returns([new CqrsReflectionFallbackAttribute()]); - generatedAssembly - .Setup(static assembly => assembly.GetTypes()) .Returns( + [ + new CqrsReflectionFallbackAttribute( + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!) + ]); + generatedAssembly + .Setup(static assembly => assembly.GetType( + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!, + false, + false)) + .Returns(ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType); + + var container = new MicrosoftDiContainer(); + CqrsTestRuntime.RegisterHandlers(container, generatedAssembly.Object); + + var registrations = container.GetServicesUnsafe + .Where(static descriptor => + descriptor.ServiceType == typeof(INotificationHandler) && + descriptor.ImplementationType is not null) + .Select(static descriptor => descriptor.ImplementationType!) + .ToList(); + + Assert.That( + registrations, + Is.EqualTo( [ typeof(GeneratedRegistryNotificationHandler), ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType + ])); + + generatedAssembly.Verify( + static assembly => assembly.GetType( + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!, + false, + false), + Times.Once); + generatedAssembly.Verify(static assembly => assembly.GetTypes(), Times.Never); + } + + /// + /// 验证手写 fallback metadata 直接提供 handler 类型时,运行时会复用这些类型, + /// 而不会再通过程序集名称查找或整程序集扫描补齐映射。 + /// + [Test] + public void RegisterHandlers_Should_Use_Direct_Fallback_Types_Without_GetType_Or_GetTypes() + { + var generatedAssembly = new Mock(); + generatedAssembly + .SetupGet(static assembly => assembly.FullName) + .Returns(ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.Assembly.FullName); + generatedAssembly + .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsHandlerRegistryAttribute), false)) + .Returns([new CqrsHandlerRegistryAttribute(typeof(PartialGeneratedNotificationHandlerRegistry))]); + generatedAssembly + .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), false)) + .Returns( + [ + new CqrsReflectionFallbackAttribute( + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType) ]); var container = new MicrosoftDiContainer(); @@ -231,6 +283,14 @@ internal sealed class CqrsHandlerRegistrarTests typeof(GeneratedRegistryNotificationHandler), ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType ])); + + generatedAssembly.Verify( + static assembly => assembly.GetType( + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!, + false, + false), + Times.Never); + generatedAssembly.Verify(static assembly => assembly.GetTypes(), Times.Never); } } diff --git a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs index f18a7344..da557d84 100644 --- a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs +++ b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs @@ -11,4 +11,61 @@ namespace GFramework.Cqrs; [AttributeUsage(AttributeTargets.Assembly)] public sealed class CqrsReflectionFallbackAttribute : Attribute { + /// + /// 初始化 ,保留旧版“仅标记需要补扫”的语义。 + /// + public CqrsReflectionFallbackAttribute() + { + FallbackHandlerTypeNames = []; + FallbackHandlerTypes = []; + } + + /// + /// 初始化 。 + /// + /// + /// 需要运行时补充反射注册的处理器类型全名。 + /// 当该清单为空时,运行时会回退到整程序集扫描,以兼容旧版 marker 语义。 + /// + public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames) + { + ArgumentNullException.ThrowIfNull(fallbackHandlerTypeNames); + + FallbackHandlerTypeNames = fallbackHandlerTypeNames + .Where(static typeName => !string.IsNullOrWhiteSpace(typeName)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static typeName => typeName, StringComparer.Ordinal) + .ToArray(); + FallbackHandlerTypes = []; + } + + /// + /// 初始化 。 + /// + /// + /// 需要运行时补充反射注册的处理器类型。 + /// 该重载适合手写或第三方程序集显式声明可直接引用的 fallback handlers, + /// 避免再通过字符串名称回查程序集元数据。 + /// + public CqrsReflectionFallbackAttribute(params Type[] fallbackHandlerTypes) + { + ArgumentNullException.ThrowIfNull(fallbackHandlerTypes); + + FallbackHandlerTypeNames = []; + FallbackHandlerTypes = fallbackHandlerTypes + .Where(static type => type is not null) + .Distinct() + .OrderBy(static type => type.FullName ?? type.Name, StringComparer.Ordinal) + .ToArray(); + } + + /// + /// 获取需要运行时补充反射注册的处理器类型全名集合。 + /// + public IReadOnlyList FallbackHandlerTypeNames { get; } + + /// + /// 获取可直接供运行时补充反射注册的处理器类型集合。 + /// + public IReadOnlyList FallbackHandlerTypes { get; } } diff --git a/GFramework.Cqrs/GlobalUsings.cs b/GFramework.Cqrs/GlobalUsings.cs index b60938a5..3085d1e1 100644 --- a/GFramework.Cqrs/GlobalUsings.cs +++ b/GFramework.Cqrs/GlobalUsings.cs @@ -6,3 +6,4 @@ global using System.Threading.Tasks; global using System.Reflection; global using Microsoft.Extensions.DependencyInjection; global using System.Diagnostics; +global using System.Collections.Concurrent; diff --git a/GFramework.Cqrs/Internal/CqrsDispatcher.cs b/GFramework.Cqrs/Internal/CqrsDispatcher.cs index 9a125789..002b7edc 100644 --- a/GFramework.Cqrs/Internal/CqrsDispatcher.cs +++ b/GFramework.Cqrs/Internal/CqrsDispatcher.cs @@ -1,5 +1,3 @@ -using System.Collections.Concurrent; -using System.Reflection; using GFramework.Core.Abstractions.Architectures; using GFramework.Core.Abstractions.Ioc; using GFramework.Core.Abstractions.Logging; @@ -30,10 +28,35 @@ internal sealed class CqrsDispatcher( // 进程级缓存:缓存通知调用委托,复用并发安全字典以支撑多线程发布路径。 private static readonly ConcurrentDictionary NotificationInvokers = new(); + // 进程级缓存:缓存通知处理器服务类型,避免每次发布都重复 MakeGenericType。 + private static readonly ConcurrentDictionary NotificationHandlerServiceTypes = new(); + // 进程级缓存:缓存流式请求调用委托,避免每次创建流时重复解析反射签名。 private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), StreamInvoker> StreamInvokers = new(); + // 进程级缓存:缓存请求处理器与 pipeline 行为的服务类型,减少热路径中的泛型类型构造。 + private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestServiceTypeSet> + RequestServiceTypes = new(); + + // 进程级缓存:缓存流式请求处理器服务类型,避免每次建流时重复 MakeGenericType。 + private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), Type> + StreamHandlerServiceTypes = + new(); + + // 静态方法定义缓存:这些反射查找与消息类型无关,只需解析一次即可复用。 + private static readonly MethodInfo RequestHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo RequestPipelineInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo NotificationHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!; + + private static readonly MethodInfo StreamHandlerInvokerMethodDefinition = typeof(CqrsDispatcher) + .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)!; + /// /// 发布通知到所有已注册处理器。 /// @@ -51,7 +74,9 @@ internal sealed class CqrsDispatcher( ArgumentNullException.ThrowIfNull(notification); var notificationType = notification.GetType(); - var handlerType = typeof(INotificationHandler<>).MakeGenericType(notificationType); + var handlerType = NotificationHandlerServiceTypes.GetOrAdd( + notificationType, + static type => typeof(INotificationHandler<>).MakeGenericType(type)); var handlers = container.GetAll(handlerType); if (handlers.Count == 0) @@ -88,14 +113,18 @@ internal sealed class CqrsDispatcher( ArgumentNullException.ThrowIfNull(request); var requestType = request.GetType(); - var handlerType = typeof(IRequestHandler<,>).MakeGenericType(requestType, typeof(TResponse)); + var serviceTypes = RequestServiceTypes.GetOrAdd( + (requestType, typeof(TResponse)), + static key => new RequestServiceTypeSet( + typeof(IRequestHandler<,>).MakeGenericType(key.RequestType, key.ResponseType), + typeof(IPipelineBehavior<,>).MakeGenericType(key.RequestType, key.ResponseType))); + var handlerType = serviceTypes.HandlerType; var handler = container.Get(handlerType) ?? throw new InvalidOperationException( $"No CQRS request handler registered for {requestType.FullName}."); PrepareHandler(handler, context); - var behaviorType = typeof(IPipelineBehavior<,>).MakeGenericType(requestType, typeof(TResponse)); - var behaviors = container.GetAll(behaviorType); + var behaviors = container.GetAll(serviceTypes.BehaviorType); foreach (var behavior in behaviors) PrepareHandler(behavior, context); @@ -135,7 +164,9 @@ internal sealed class CqrsDispatcher( ArgumentNullException.ThrowIfNull(request); var requestType = request.GetType(); - var handlerType = typeof(IStreamRequestHandler<,>).MakeGenericType(requestType, typeof(TResponse)); + var handlerType = StreamHandlerServiceTypes.GetOrAdd( + (requestType, typeof(TResponse)), + static key => typeof(IStreamRequestHandler<,>).MakeGenericType(key.RequestType, key.ResponseType)); var handler = container.Get(handlerType) ?? throw new InvalidOperationException( $"No CQRS stream handler registered for {requestType.FullName}."); @@ -171,8 +202,7 @@ internal sealed class CqrsDispatcher( /// private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = RequestHandlerInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method); } @@ -182,8 +212,7 @@ internal sealed class CqrsDispatcher( /// private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = RequestPipelineInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method); } @@ -193,8 +222,7 @@ internal sealed class CqrsDispatcher( /// private static NotificationInvoker CreateNotificationInvoker(Type notificationType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)! + var method = NotificationHandlerInvokerMethodDefinition .MakeGenericMethod(notificationType); return (NotificationInvoker)Delegate.CreateDelegate(typeof(NotificationInvoker), method); } @@ -204,8 +232,7 @@ internal sealed class CqrsDispatcher( /// private static StreamInvoker CreateStreamInvoker(Type requestType, Type responseType) { - var method = typeof(CqrsDispatcher) - .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)! + var method = StreamHandlerInvokerMethodDefinition .MakeGenericMethod(requestType, responseType); return (StreamInvoker)Delegate.CreateDelegate(typeof(StreamInvoker), method); } @@ -293,4 +320,6 @@ internal sealed class CqrsDispatcher( CancellationToken cancellationToken); private delegate object StreamInvoker(object handler, object request, CancellationToken cancellationToken); + + private readonly record struct RequestServiceTypeSet(Type HandlerType, Type BehaviorType); } diff --git a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs index 867c6887..3604de83 100644 --- a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs +++ b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs @@ -33,10 +33,14 @@ internal static class CqrsHandlerRegistrar { var generatedRegistrationResult = TryRegisterGeneratedHandlers(container.GetServicesUnsafe, assembly, logger); - if (generatedRegistrationResult == GeneratedRegistrationResult.FullyHandled) + if (generatedRegistrationResult is { UsedGeneratedRegistry: true, RequiresReflectionFallback: false }) continue; - RegisterAssemblyHandlers(container.GetServicesUnsafe, assembly, logger); + RegisterAssemblyHandlers( + container.GetServicesUnsafe, + assembly, + logger, + generatedRegistrationResult.ReflectionFallbackMetadata); } } @@ -66,7 +70,7 @@ internal static class CqrsHandlerRegistrar .ToList(); if (registryTypes.Count == 0) - return GeneratedRegistrationResult.NoGeneratedRegistry; + return GeneratedRegistrationResult.NoGeneratedRegistry(); var registries = new List(registryTypes.Count); foreach (var registryType in registryTypes) @@ -75,21 +79,21 @@ internal static class CqrsHandlerRegistrar { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it does not implement {typeof(ICqrsHandlerRegistry).FullName}."); - return GeneratedRegistrationResult.NoGeneratedRegistry; + return GeneratedRegistrationResult.NoGeneratedRegistry(); } if (registryType.IsAbstract) { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it is abstract."); - return GeneratedRegistrationResult.NoGeneratedRegistry; + return GeneratedRegistrationResult.NoGeneratedRegistry(); } if (Activator.CreateInstance(registryType, nonPublic: true) is not ICqrsHandlerRegistry registry) { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it could not be instantiated."); - return GeneratedRegistrationResult.NoGeneratedRegistry; + return GeneratedRegistrationResult.NoGeneratedRegistry(); } registries.Add(registry); @@ -102,14 +106,24 @@ internal static class CqrsHandlerRegistrar registry.Register(services, logger); } - if (RequiresReflectionFallback(assembly)) + var reflectionFallbackMetadata = GetReflectionFallbackMetadata(assembly, logger); + if (reflectionFallbackMetadata is not null) { - logger.Debug( - $"Generated CQRS registry for assembly {assemblyName} requested reflection fallback for unsupported handlers."); - return GeneratedRegistrationResult.RequiresReflectionFallback; + if (reflectionFallbackMetadata.HasExplicitTypes) + { + logger.Debug( + $"Generated CQRS registry for assembly {assemblyName} requested targeted reflection fallback for {reflectionFallbackMetadata.Types.Count} unsupported handler type(s)."); + } + else + { + logger.Debug( + $"Generated CQRS registry for assembly {assemblyName} requested full reflection fallback for unsupported handlers."); + } + + return GeneratedRegistrationResult.WithReflectionFallback(reflectionFallbackMetadata); } - return GeneratedRegistrationResult.FullyHandled; + return GeneratedRegistrationResult.FullyHandled(); } catch (Exception exception) { @@ -117,16 +131,21 @@ internal static class CqrsHandlerRegistrar $"Generated CQRS handler registry discovery failed for assembly {assemblyName}. Falling back to reflection scan."); logger.Warn( $"Failed to use generated CQRS handler registry for assembly {assemblyName}: {exception.Message}"); - return GeneratedRegistrationResult.NoGeneratedRegistry; + return GeneratedRegistrationResult.NoGeneratedRegistry(); } } /// /// 注册单个程序集里的所有 CQRS 处理器映射。 /// - private static void RegisterAssemblyHandlers(IServiceCollection services, Assembly assembly, ILogger logger) + private static void RegisterAssemblyHandlers( + IServiceCollection services, + Assembly assembly, + ILogger logger, + ReflectionFallbackMetadata? reflectionFallbackMetadata) { - foreach (var implementationType in GetLoadableTypes(assembly, logger).Where(IsConcreteHandlerType)) + foreach (var implementationType in GetCandidateHandlerTypes(assembly, logger, reflectionFallbackMetadata) + .Where(IsConcreteHandlerType)) { var handlerInterfaces = implementationType .GetInterfaces() @@ -155,6 +174,87 @@ internal static class CqrsHandlerRegistrar } } + /// + /// 根据生成器提供的 fallback 清单或整程序集扫描结果,获取本轮要注册的候选处理器类型。 + /// + private static IReadOnlyList GetCandidateHandlerTypes( + Assembly assembly, + ILogger logger, + ReflectionFallbackMetadata? reflectionFallbackMetadata) + { + return reflectionFallbackMetadata is { HasExplicitTypes: true } + ? reflectionFallbackMetadata.Types + : GetLoadableTypes(assembly, logger); + } + + /// + /// 获取生成注册器要求运行时继续补充反射扫描的 handler 元数据。 + /// + private static ReflectionFallbackMetadata? GetReflectionFallbackMetadata( + Assembly assembly, + ILogger logger) + { + var assemblyName = GetAssemblySortKey(assembly); + var fallbackAttributes = assembly + .GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false) + .OfType() + .ToList(); + + if (fallbackAttributes.Count == 0) + return null; + + var resolvedTypes = new List(); + foreach (var fallbackType in fallbackAttributes + .SelectMany(static attribute => attribute.FallbackHandlerTypes) + .Where(static type => type is not null) + .Distinct() + .OrderBy(GetTypeSortKey, StringComparer.Ordinal)) + { + if (!string.Equals( + GetAssemblySortKey(fallbackType.Assembly), + assemblyName, + StringComparison.Ordinal)) + { + logger.Warn( + $"Generated CQRS reflection fallback type {fallbackType.FullName} was declared on assembly {assemblyName} but belongs to assembly {GetAssemblySortKey(fallbackType.Assembly)}. Skipping mismatched fallback entry."); + continue; + } + + resolvedTypes.Add(fallbackType); + } + + foreach (var typeName in fallbackAttributes + .SelectMany(static attribute => attribute.FallbackHandlerTypeNames) + .Where(static name => !string.IsNullOrWhiteSpace(name)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static name => name, StringComparer.Ordinal)) + { + try + { + var type = assembly.GetType(typeName, throwOnError: false, ignoreCase: false); + if (type is null) + { + logger.Warn( + $"Generated CQRS reflection fallback type {typeName} could not be resolved in assembly {assemblyName}. Skipping targeted fallback entry."); + continue; + } + + resolvedTypes.Add(type); + } + catch (Exception exception) + { + logger.Warn( + $"Generated CQRS reflection fallback type {typeName} failed to load in assembly {assemblyName}: {exception.Message}"); + } + } + + return new ReflectionFallbackMetadata( + resolvedTypes + .Distinct() + .OrderBy(GetTypeSortKey, StringComparer.Ordinal) + .ToArray()); + } + /// /// 安全获取程序集中的可加载类型,并在部分类型加载失败时保留其余处理器注册能力。 /// @@ -220,14 +320,6 @@ internal static class CqrsHandlerRegistrar definition == typeof(IStreamRequestHandler<,>); } - /// - /// 判断生成注册器是否要求运行时继续补充反射扫描。 - /// - private static bool RequiresReflectionFallback(Assembly assembly) - { - return assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)?.Length > 0; - } - /// /// 判断同一 handler 映射是否已经由生成注册器或先前扫描步骤写入服务集合。 /// @@ -259,10 +351,43 @@ internal static class CqrsHandlerRegistrar return type.FullName ?? type.Name; } - private enum GeneratedRegistrationResult + private readonly record struct GeneratedRegistrationResult( + bool UsedGeneratedRegistry, + bool RequiresReflectionFallback, + ReflectionFallbackMetadata? ReflectionFallbackMetadata) { - NoGeneratedRegistry, - FullyHandled, - RequiresReflectionFallback + public static GeneratedRegistrationResult NoGeneratedRegistry() + { + return new GeneratedRegistrationResult( + UsedGeneratedRegistry: false, + RequiresReflectionFallback: false, + ReflectionFallbackMetadata: null); + } + + public static GeneratedRegistrationResult FullyHandled() + { + return new GeneratedRegistrationResult( + UsedGeneratedRegistry: true, + RequiresReflectionFallback: false, + ReflectionFallbackMetadata: null); + } + + public static GeneratedRegistrationResult WithReflectionFallback( + ReflectionFallbackMetadata reflectionFallbackMetadata) + { + ArgumentNullException.ThrowIfNull(reflectionFallbackMetadata); + + return new GeneratedRegistrationResult( + UsedGeneratedRegistry: true, + RequiresReflectionFallback: true, + ReflectionFallbackMetadata: reflectionFallbackMetadata); + } + } + + private sealed class ReflectionFallbackMetadata(IReadOnlyList types) + { + public IReadOnlyList Types { get; } = types ?? throw new ArgumentNullException(nameof(types)); + + public bool HasExplicitTypes => Types.Count > 0; } } diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index 0cd91844..dcdb5e5f 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -10,6 +10,138 @@ namespace GFramework.SourceGenerators.Tests.Cqrs; [TestFixture] public class CqrsHandlerRegistryGeneratorTests { + private const string HiddenNestedHandlerSelfRegistrationExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler"); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + typeof(global::TestApp.VisibleHandler)); + logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + + private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName) + { + var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false); + if (implementationType is null) + return; + + var handlerInterfaces = implementationType.GetInterfaces(); + global::System.Array.Sort(handlerInterfaces, CompareTypes); + + foreach (var handlerInterface in handlerInterfaces) + { + if (!IsSupportedHandlerInterface(handlerInterface)) + continue; + + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + handlerInterface, + implementationType); + logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}."); + } + } + + private static int CompareTypes(global::System.Type left, global::System.Type right) + { + return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right)); + } + + private static bool IsSupportedHandlerInterface(global::System.Type interfaceType) + { + if (!interfaceType.IsGenericType) + return false; + + var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName; + return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2") + || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1") + || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2"); + } + + private static string GetRuntimeTypeDisplayName(global::System.Type type) + { + if (type == typeof(string)) + return "string"; + if (type == typeof(int)) + return "int"; + if (type == typeof(long)) + return "long"; + if (type == typeof(short)) + return "short"; + if (type == typeof(byte)) + return "byte"; + if (type == typeof(bool)) + return "bool"; + if (type == typeof(object)) + return "object"; + if (type == typeof(void)) + return "void"; + if (type == typeof(uint)) + return "uint"; + if (type == typeof(ulong)) + return "ulong"; + if (type == typeof(ushort)) + return "ushort"; + if (type == typeof(sbyte)) + return "sbyte"; + if (type == typeof(float)) + return "float"; + if (type == typeof(double)) + return "double"; + if (type == typeof(decimal)) + return "decimal"; + if (type == typeof(char)) + return "char"; + + if (type.IsArray) + return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]"; + + if (!type.IsGenericType) + return (type.FullName ?? type.Name).Replace('+', '.'); + + var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name; + var arityIndex = genericTypeName.IndexOf('`'); + if (arityIndex >= 0) + genericTypeName = genericTypeName[..arityIndex]; + + genericTypeName = genericTypeName.Replace('+', '.'); + var arguments = type.GetGenericArguments(); + var builder = new global::System.Text.StringBuilder(); + builder.Append(genericTypeName); + builder.Append('<'); + + for (var index = 0; index < arguments.Length; index++) + { + if (index > 0) + builder.Append(", "); + + builder.Append(GetRuntimeTypeDisplayName(arguments[index])); + } + + builder.Append('>'); + return builder.ToString(); + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -65,6 +197,7 @@ public class CqrsHandlerRegistryGeneratorTests [AttributeUsage(AttributeTargets.Assembly)] public sealed class CqrsReflectionFallbackAttribute : Attribute { + public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames) { } } } @@ -125,12 +258,12 @@ public class CqrsHandlerRegistryGeneratorTests } /// - /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器仍会为可见 handlers 生成注册器, - /// 并额外标记运行时补充反射扫描。 + /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器会在生成注册器内部执行定向反射注册, + /// 不再依赖程序集级 fallback marker。 /// [Test] public async Task - Generates_Visible_Handlers_And_Requests_Reflection_Fallback_When_Assembly_Contains_Private_Nested_Handler() + Generates_Visible_Handlers_And_Self_Registers_Private_Nested_Handler_When_Assembly_Contains_Hidden_Handler() { const string source = """ using System; @@ -180,6 +313,7 @@ public class CqrsHandlerRegistryGeneratorTests [AttributeUsage(AttributeTargets.Assembly)] public sealed class CqrsReflectionFallbackAttribute : Attribute { + public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames) { } } } @@ -200,45 +334,98 @@ public class CqrsHandlerRegistryGeneratorTests } """; - const string expected = """ - // - #nullable enable - - [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] - [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute()] - - namespace GFramework.Generated.Cqrs; - - internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry - { - public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) - { - if (services is null) - throw new global::System.ArgumentNullException(nameof(services)); - if (logger is null) - throw new global::System.ArgumentNullException(nameof(logger)); - - global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( - services, - typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), - typeof(global::TestApp.VisibleHandler)); - logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); - } - } - - """; - await GeneratorTest.RunAsync( source, - ("CqrsHandlerRegistry.g.cs", expected)); + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } /// - /// 验证当旧版 runtime 合同中不存在 reflection fallback 标记特性时, - /// 生成器会保留此前的整程序集回退行为,避免丢失不可见 handlers。 + /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, + /// 不再输出 fallback marker。 /// [Test] - public async Task Skips_Generation_For_Unsupported_Handler_When_Fallback_Marker_Is_Unavailable() + public async Task Does_Not_Emit_Legacy_Fallback_Marker_When_Generated_Registry_Can_Self_Register_Hidden_Handler() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + + [AttributeUsage(AttributeTargets.Assembly)] + public sealed class CqrsReflectionFallbackAttribute : Attribute + { + public CqrsReflectionFallbackAttribute() { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenHandler : IRequestHandler { } + } + + public sealed class VisibleHandler : IRequestHandler { } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); + } + + /// + /// 验证即使 runtime 合同中完全不存在 reflection fallback 标记特性, + /// 生成器仍能通过生成注册器内部的定向反射逻辑覆盖隐藏 handler。 + /// + [Test] + public async Task Generates_Registry_For_Hidden_Handler_When_Fallback_Marker_Is_Unavailable() { const string source = """ using System; @@ -303,16 +490,9 @@ public class CqrsHandlerRegistryGeneratorTests } """; - var test = new CSharpSourceGeneratorTest - { - TestState = - { - Sources = { source } - }, - DisabledDiagnostics = { "GF_Common_Trace_001" } - }; - - await test.RunAsync(); + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } /// diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 80561248..83559781 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -16,9 +16,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private const string IStreamRequestHandlerMetadataName = $"{CqrsContractsNamespace}.IStreamRequestHandler`2"; private const string ICqrsHandlerRegistryMetadataName = $"{CqrsRuntimeNamespace}.ICqrsHandlerRegistry"; - private const string CqrsReflectionFallbackAttributeMetadataName = - $"{CqrsRuntimeNamespace}.CqrsReflectionFallbackAttribute"; - private const string CqrsHandlerRegistryAttributeMetadataName = $"{CqrsRuntimeNamespace}.CqrsHandlerRegistryAttribute"; @@ -57,9 +54,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null && compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null; - return new GenerationEnvironment( - generationEnabled, - compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName) is not null); + return new GenerationEnvironment(generationEnabled); } private static bool IsHandlerCandidate(SyntaxNode node) @@ -90,16 +85,20 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return null; var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var implementationLogName = GetLogDisplayName(type); if (!CanReferenceFromGeneratedRegistry(type) || handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) { + // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...) + // expressions inside generated code. Preserve generator hit rate by resolving just that implementation + // type back from the current assembly instead of asking the runtime registrar to rescan the assembly. return new HandlerCandidateAnalysis( implementationTypeDisplayName, + implementationLogName, ImmutableArray.Empty, - true); + GetReflectionTypeMetadataName(type)); } - var implementationLogName = GetLogDisplayName(type); var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); foreach (var handlerInterface in handlerInterfaces) { @@ -112,8 +111,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return new HandlerCandidateAnalysis( implementationTypeDisplayName, + implementationLogName, registrations.MoveToImmutable(), - false); + null); } private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, @@ -122,30 +122,23 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (!generationEnvironment.GenerationEnabled) return; - var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler); + var registrations = CollectRegistrations(candidates); if (registrations.Count == 0) return; - // If the runtime contract does not yet expose the reflection fallback marker, - // keep the previous all-or-nothing behavior so unsupported handlers are not silently dropped. - if (hasUnsupportedConcreteHandler && !generationEnvironment.SupportsReflectionFallbackMarker) - return; - context.AddSource( HintName, - GenerateSource(registrations, hasUnsupportedConcreteHandler)); + GenerateSource(registrations)); } - private static List CollectRegistrations( - ImmutableArray candidates, - out bool hasUnsupportedConcreteHandler) + private static List CollectRegistrations( + ImmutableArray candidates) { - var registrations = new List(); - hasUnsupportedConcreteHandler = false; + var registrations = new List(); // Partial declarations surface the same symbol through multiple syntax nodes. - // Collapse them by implementation type so generated registrations stay stable and duplicate-free. + // Collapse them by implementation type so direct and reflected registrations stay stable and duplicate-free. var uniqueCandidates = new Dictionary(StringComparer.Ordinal); foreach (var candidate in candidates) @@ -153,18 +146,16 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (candidate is null) continue; - if (candidate.Value.HasUnsupportedConcreteHandler) - { - hasUnsupportedConcreteHandler = true; - continue; - } - uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value; } foreach (var candidate in uniqueCandidates.Values) { - registrations.AddRange(candidate.Registrations); + registrations.Add(new ImplementationRegistrationSpec( + candidate.ImplementationTypeDisplayName, + candidate.ImplementationLogName, + candidate.Registrations, + candidate.ReflectionTypeMetadataName)); } registrations.Sort(static (left, right) => @@ -173,9 +164,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator left.ImplementationLogName, right.ImplementationLogName); - return implementationComparison != 0 - ? implementationComparison - : StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName); + return implementationComparison; }); return registrations; @@ -272,6 +261,34 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return builder.ToString(); } + private static string GetReflectionTypeMetadataName(INamedTypeSymbol type) + { + var nestedTypes = new Stack(); + for (var current = type; current is not null; current = current.ContainingType) + { + nestedTypes.Push(current.MetadataName); + } + + var builder = new StringBuilder(); + if (!type.ContainingNamespace.IsGlobalNamespace) + { + builder.Append(type.ContainingNamespace.ToDisplayString()); + builder.Append('.'); + } + + var isFirstType = true; + while (nestedTypes.Count > 0) + { + if (!isFirstType) + builder.Append('+'); + + builder.Append(nestedTypes.Pop()); + isFirstType = false; + } + + return builder.ToString(); + } + private static string GetTypeSortKey(ITypeSymbol type) { return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); @@ -283,9 +300,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } private static string GenerateSource( - IReadOnlyList registrations, - bool emitReflectionFallbackAttribute) + IReadOnlyList registrations) { + var hasReflectionRegistrations = registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -297,12 +315,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.Append('.'); builder.Append(GeneratedTypeName); builder.AppendLine("))]"); - if (emitReflectionFallbackAttribute) - { - builder.Append("[assembly: global::"); - builder.Append(CqrsRuntimeNamespace); - builder.AppendLine(".CqrsReflectionFallbackAttribute()]"); - } builder.AppendLine(); builder.Append("namespace "); @@ -324,31 +336,179 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));"); builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); - builder.AppendLine(); + if (hasReflectionRegistrations) + { + builder.AppendLine(); + builder.Append(" var registryAssembly = typeof(global::"); + builder.Append(GeneratedNamespace); + builder.Append('.'); + builder.Append(GeneratedTypeName); + builder.AppendLine(").Assembly;"); + } + + if (registrations.Count > 0) + builder.AppendLine(); foreach (var registration in registrations) { - builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(registration.HandlerInterfaceDisplayName); - builder.AppendLine("),"); - builder.Append(" typeof("); - builder.Append(registration.ImplementationTypeDisplayName); - builder.AppendLine("));"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); - builder.Append(" as "); - builder.Append(EscapeStringLiteral(registration.HandlerInterfaceLogName)); - builder.AppendLine(".\");"); + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) + { + AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); + continue; + } + + foreach (var directRegistration in registration.DirectRegistrations) + { + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" typeof("); + builder.Append(directRegistration.ImplementationTypeDisplayName); + builder.AppendLine("));"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + } } builder.AppendLine(" }"); + + if (hasReflectionRegistrations) + { + builder.AppendLine(); + AppendReflectionHelpers(builder); + } + builder.AppendLine("}"); return builder.ToString(); } + private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName) + { + builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \""); + builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); + builder.AppendLine("\");"); + } + + private static void AppendReflectionHelpers(StringBuilder builder) + { + // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. + builder.AppendLine( + " private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)"); + builder.AppendLine(" {"); + builder.AppendLine( + " var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);"); + builder.AppendLine(" if (implementationType is null)"); + builder.AppendLine(" return;"); + builder.AppendLine(); + builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();"); + builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);"); + builder.AppendLine(); + builder.AppendLine(" foreach (var handlerInterface in handlerInterfaces)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))"); + builder.AppendLine(" continue;"); + builder.AppendLine(); + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.AppendLine(" handlerInterface,"); + builder.AppendLine(" implementationType);"); + builder.AppendLine( + " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");"); + builder.AppendLine(" }"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static int CompareTypes(global::System.Type left, global::System.Type right)"); + builder.AppendLine(" {"); + builder.AppendLine( + " return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right));"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static bool IsSupportedHandlerInterface(global::System.Type interfaceType)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (!interfaceType.IsGenericType)"); + builder.AppendLine(" return false;"); + builder.AppendLine(); + builder.AppendLine(" var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName;"); + builder.AppendLine( + $" return global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IRequestHandlerMetadataName}\")"); + builder.AppendLine( + $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{INotificationHandlerMetadataName}\")"); + builder.AppendLine( + $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IStreamRequestHandlerMetadataName}\");"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" private static string GetRuntimeTypeDisplayName(global::System.Type type)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (type == typeof(string))"); + builder.AppendLine(" return \"string\";"); + builder.AppendLine(" if (type == typeof(int))"); + builder.AppendLine(" return \"int\";"); + builder.AppendLine(" if (type == typeof(long))"); + builder.AppendLine(" return \"long\";"); + builder.AppendLine(" if (type == typeof(short))"); + builder.AppendLine(" return \"short\";"); + builder.AppendLine(" if (type == typeof(byte))"); + builder.AppendLine(" return \"byte\";"); + builder.AppendLine(" if (type == typeof(bool))"); + builder.AppendLine(" return \"bool\";"); + builder.AppendLine(" if (type == typeof(object))"); + builder.AppendLine(" return \"object\";"); + builder.AppendLine(" if (type == typeof(void))"); + builder.AppendLine(" return \"void\";"); + builder.AppendLine(" if (type == typeof(uint))"); + builder.AppendLine(" return \"uint\";"); + builder.AppendLine(" if (type == typeof(ulong))"); + builder.AppendLine(" return \"ulong\";"); + builder.AppendLine(" if (type == typeof(ushort))"); + builder.AppendLine(" return \"ushort\";"); + builder.AppendLine(" if (type == typeof(sbyte))"); + builder.AppendLine(" return \"sbyte\";"); + builder.AppendLine(" if (type == typeof(float))"); + builder.AppendLine(" return \"float\";"); + builder.AppendLine(" if (type == typeof(double))"); + builder.AppendLine(" return \"double\";"); + builder.AppendLine(" if (type == typeof(decimal))"); + builder.AppendLine(" return \"decimal\";"); + builder.AppendLine(" if (type == typeof(char))"); + builder.AppendLine(" return \"char\";"); + builder.AppendLine(); + builder.AppendLine(" if (type.IsArray)"); + builder.AppendLine(" return GetRuntimeTypeDisplayName(type.GetElementType()!) + \"[]\";"); + builder.AppendLine(); + builder.AppendLine(" if (!type.IsGenericType)"); + builder.AppendLine(" return (type.FullName ?? type.Name).Replace('+', '.');"); + builder.AppendLine(); + builder.AppendLine(" var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name;"); + builder.AppendLine(" var arityIndex = genericTypeName.IndexOf('`');"); + builder.AppendLine(" if (arityIndex >= 0)"); + builder.AppendLine(" genericTypeName = genericTypeName[..arityIndex];"); + builder.AppendLine(); + builder.AppendLine(" genericTypeName = genericTypeName.Replace('+', '.');"); + builder.AppendLine(" var arguments = type.GetGenericArguments();"); + builder.AppendLine(" var builder = new global::System.Text.StringBuilder();"); + builder.AppendLine(" builder.Append(genericTypeName);"); + builder.AppendLine(" builder.Append('<');"); + builder.AppendLine(); + builder.AppendLine(" for (var index = 0; index < arguments.Length; index++)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (index > 0)"); + builder.AppendLine(" builder.Append(\", \");"); + builder.AppendLine(); + builder.AppendLine(" builder.Append(GetRuntimeTypeDisplayName(arguments[index]));"); + builder.AppendLine(" }"); + builder.AppendLine(); + builder.AppendLine(" builder.Append('>');"); + builder.AppendLine(" return builder.ToString();"); + builder.AppendLine(" }"); + } + private static string EscapeStringLiteral(string value) { return value.Replace("\\", "\\\\") @@ -363,29 +523,41 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceLogName, string ImplementationLogName); + private readonly record struct ImplementationRegistrationSpec( + string ImplementationTypeDisplayName, + string ImplementationLogName, + ImmutableArray DirectRegistrations, + string? ReflectionTypeMetadataName); + private readonly struct HandlerCandidateAnalysis : IEquatable { public HandlerCandidateAnalysis( string implementationTypeDisplayName, + string implementationLogName, ImmutableArray registrations, - bool hasUnsupportedConcreteHandler) + string? reflectionTypeMetadataName) { ImplementationTypeDisplayName = implementationTypeDisplayName; + ImplementationLogName = implementationLogName; Registrations = registrations; - HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler; + ReflectionTypeMetadataName = reflectionTypeMetadataName; } public string ImplementationTypeDisplayName { get; } + public string ImplementationLogName { get; } + public ImmutableArray Registrations { get; } - public bool HasUnsupportedConcreteHandler { get; } + public string? ReflectionTypeMetadataName { get; } public bool Equals(HandlerCandidateAnalysis other) { if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, StringComparison.Ordinal) || - HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler || + !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || + !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, + StringComparison.Ordinal) || Registrations.Length != other.Registrations.Length) { return false; @@ -410,7 +582,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator unchecked { var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName); - hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode(); + hashCode = (hashCode * 397) ^ StringComparer.Ordinal.GetHashCode(ImplementationLogName); + hashCode = (hashCode * 397) ^ + (ReflectionTypeMetadataName is null + ? 0 + : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName)); foreach (var registration in Registrations) { hashCode = (hashCode * 397) ^ registration.GetHashCode(); @@ -421,7 +597,5 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } } - private readonly record struct GenerationEnvironment( - bool GenerationEnabled, - bool SupportsReflectionFallbackMarker); + private readonly record struct GenerationEnvironment(bool GenerationEnabled); }