diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs index 3100ce84..318f6d21 100644 --- a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs +++ b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs @@ -13,6 +13,9 @@ namespace GFramework.Cqrs.Tests.Cqrs; [TestFixture] internal sealed class CqrsHandlerRegistrarTests { + private MicrosoftDiContainer? _container; + private ArchitectureContext? _context; + /// /// 初始化测试容器并重置共享状态。 /// @@ -42,9 +45,6 @@ internal sealed class CqrsHandlerRegistrarTests DeterministicNotificationHandlerState.Reset(); } - private MicrosoftDiContainer? _container; - private ArchitectureContext? _context; - /// /// 验证自动扫描到的通知处理器会按稳定名称顺序执行,而不是依赖反射枚举顺序。 /// @@ -188,6 +188,50 @@ internal sealed class CqrsHandlerRegistrarTests LoggerFactoryResolver.Provider = originalProvider; } } + + /// + /// 验证当生成注册器显式要求 reflection fallback 时,运行时会补扫剩余 handlers, + /// 同时避免把已由生成注册器注册的映射重复写入服务集合。 + /// + [Test] + public void RegisterHandlers_Should_Combine_Generated_Registry_With_Reflection_Fallback_Without_Duplicates() + { + var generatedAssembly = new Mock(); + generatedAssembly + .SetupGet(static assembly => assembly.FullName) + .Returns("GFramework.Core.Tests.Cqrs.PartialGeneratedRegistryAssembly, Version=1.0.0.0"); + generatedAssembly + .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsHandlerRegistryAttribute), false)) + .Returns([new CqrsHandlerRegistryAttribute(typeof(PartialGeneratedNotificationHandlerRegistry))]); + generatedAssembly + .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), false)) + .Returns([new CqrsReflectionFallbackAttribute()]); + generatedAssembly + .Setup(static assembly => assembly.GetTypes()) + .Returns( + [ + typeof(GeneratedRegistryNotificationHandler), + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType + ]); + + var container = new MicrosoftDiContainer(); + CqrsTestRuntime.RegisterHandlers(container, generatedAssembly.Object); + + var registrations = container.GetServicesUnsafe + .Where(static descriptor => + descriptor.ServiceType == typeof(INotificationHandler) && + descriptor.ImplementationType is not null) + .Select(static descriptor => descriptor.ImplementationType!) + .ToList(); + + Assert.That( + registrations, + Is.EqualTo( + [ + typeof(GeneratedRegistryNotificationHandler), + ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType + ])); + } } /// @@ -337,3 +381,52 @@ internal sealed class GeneratedNotificationHandlerRegistry : ICqrsHandlerRegistr $"Registered CQRS handler {typeof(GeneratedRegistryNotificationHandler).FullName} as {typeof(INotificationHandler).FullName}."); } } + +/// +/// 用于验证“生成注册器 + reflection fallback”组合路径的私有嵌套处理器容器。 +/// +internal sealed class ReflectionFallbackNotificationContainer +{ + /// + /// 获取仅能通过反射补扫接入的私有嵌套处理器类型。 + /// + public static Type ReflectionOnlyHandlerType => typeof(ReflectionOnlyGeneratedRegistryNotificationHandler); + + private sealed class ReflectionOnlyGeneratedRegistryNotificationHandler + : INotificationHandler + { + /// + /// 处理测试通知。 + /// + /// 通知实例。 + /// 取消令牌。 + /// 已完成任务。 + public ValueTask Handle(GeneratedRegistryNotification notification, CancellationToken cancellationToken) + { + return ValueTask.CompletedTask; + } + } +} + +/// +/// 模拟局部生成注册器场景中,仅注册“可由生成代码直接引用”的那部分 handlers。 +/// +internal sealed class PartialGeneratedNotificationHandlerRegistry : ICqrsHandlerRegistry +{ + /// + /// 将生成路径可见的通知处理器注册到目标服务集合。 + /// + /// 承载处理器映射的服务集合。 + /// 用于记录注册诊断的日志器。 + public void Register(IServiceCollection services, ILogger logger) + { + ArgumentNullException.ThrowIfNull(services); + ArgumentNullException.ThrowIfNull(logger); + + services.AddTransient( + typeof(INotificationHandler), + typeof(GeneratedRegistryNotificationHandler)); + logger.Debug( + $"Registered CQRS handler {typeof(GeneratedRegistryNotificationHandler).FullName} as {typeof(INotificationHandler).FullName}."); + } +} diff --git a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs new file mode 100644 index 00000000..f18a7344 --- /dev/null +++ b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs @@ -0,0 +1,14 @@ +namespace GFramework.Cqrs; + +/// +/// 标记程序集中的 CQRS 生成注册器仍需要运行时补充反射扫描。 +/// +/// +/// 该特性通常由源码生成器自动添加到消费端程序集。 +/// 当生成器只能安全生成部分 handler 映射时,运行时会先执行生成注册器,再补一次带去重的反射扫描, +/// 以覆盖那些生成代码无法直接引用的 handler 类型。 +/// +[AttributeUsage(AttributeTargets.Assembly)] +public sealed class CqrsReflectionFallbackAttribute : Attribute +{ +} diff --git a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs index 85a95509..435c9cd5 100644 --- a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs +++ b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs @@ -1,4 +1,3 @@ -using System.Reflection; using GFramework.Core.Abstractions.Ioc; using GFramework.Core.Abstractions.Logging; using GFramework.Cqrs.Abstractions.Cqrs; @@ -32,7 +31,9 @@ internal static class CqrsHandlerRegistrar .Distinct() .OrderBy(GetAssemblySortKey, StringComparer.Ordinal)) { - if (TryRegisterGeneratedHandlers(container.GetServicesUnsafe, assembly, logger)) + var generatedRegistrationResult = + TryRegisterGeneratedHandlers(container.GetServicesUnsafe, assembly, logger); + if (generatedRegistrationResult == GeneratedRegistrationResult.FullyHandled) continue; RegisterAssemblyHandlers(container.GetServicesUnsafe, assembly, logger); @@ -45,8 +46,11 @@ internal static class CqrsHandlerRegistrar /// 目标服务集合。 /// 当前要处理的程序集。 /// 日志记录器。 - /// 当成功使用生成注册器时返回 ;否则返回 - private static bool TryRegisterGeneratedHandlers(IServiceCollection services, Assembly assembly, ILogger logger) + /// 生成注册器的使用结果。 + private static GeneratedRegistrationResult TryRegisterGeneratedHandlers( + IServiceCollection services, + Assembly assembly, + ILogger logger) { var assemblyName = GetAssemblySortKey(assembly); @@ -62,7 +66,7 @@ internal static class CqrsHandlerRegistrar .ToList(); if (registryTypes.Count == 0) - return false; + return GeneratedRegistrationResult.NoGeneratedRegistry; var registries = new List(registryTypes.Count); foreach (var registryType in registryTypes) @@ -71,21 +75,21 @@ internal static class CqrsHandlerRegistrar { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it does not implement {typeof(ICqrsHandlerRegistry).FullName}."); - return false; + return GeneratedRegistrationResult.NoGeneratedRegistry; } if (registryType.IsAbstract) { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it is abstract."); - return false; + return GeneratedRegistrationResult.NoGeneratedRegistry; } if (Activator.CreateInstance(registryType, nonPublic: true) is not ICqrsHandlerRegistry registry) { logger.Warn( $"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it could not be instantiated."); - return false; + return GeneratedRegistrationResult.NoGeneratedRegistry; } registries.Add(registry); @@ -98,7 +102,14 @@ internal static class CqrsHandlerRegistrar registry.Register(services, logger); } - return true; + if (RequiresReflectionFallback(assembly)) + { + logger.Debug( + $"Generated CQRS registry for assembly {assemblyName} requested reflection fallback for unsupported handlers."); + return GeneratedRegistrationResult.RequiresReflectionFallback; + } + + return GeneratedRegistrationResult.FullyHandled; } catch (Exception exception) { @@ -106,7 +117,7 @@ internal static class CqrsHandlerRegistrar $"Generated CQRS handler registry discovery failed for assembly {assemblyName}. Falling back to reflection scan."); logger.Warn( $"Failed to use generated CQRS handler registry for assembly {assemblyName}: {exception.Message}"); - return false; + return GeneratedRegistrationResult.NoGeneratedRegistry; } } @@ -128,6 +139,13 @@ internal static class CqrsHandlerRegistrar foreach (var handlerInterface in handlerInterfaces) { + if (IsHandlerMappingAlreadyRegistered(services, handlerInterface, implementationType)) + { + logger.Debug( + $"Skipping duplicate CQRS handler {implementationType.FullName} as {handlerInterface.FullName}."); + continue; + } + // Request/notification handlers receive context injection before every dispatch. // Transient registration avoids sharing mutable Context across concurrent requests. services.AddTransient(handlerInterface, implementationType); @@ -202,6 +220,27 @@ internal static class CqrsHandlerRegistrar definition == typeof(IStreamRequestHandler<,>); } + /// + /// 判断生成注册器是否要求运行时继续补充反射扫描。 + /// + private static bool RequiresReflectionFallback(Assembly assembly) + { + return assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)?.Length > 0; + } + + /// + /// 判断同一 handler 映射是否已经由生成注册器或先前扫描步骤写入服务集合。 + /// + private static bool IsHandlerMappingAlreadyRegistered( + IServiceCollection services, + Type handlerInterface, + Type implementationType) + { + return services.Any(descriptor => + descriptor.ServiceType == handlerInterface && + descriptor.ImplementationType == implementationType); + } + /// /// 生成程序集排序键,保证跨运行环境的处理器注册顺序稳定。 /// @@ -217,4 +256,11 @@ internal static class CqrsHandlerRegistrar { return type.FullName ?? type.Name; } + + private enum GeneratedRegistrationResult + { + NoGeneratedRegistry, + FullyHandled, + RequiresReflectionFallback + } } diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index 0392ac8a..0cd91844 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -61,6 +61,11 @@ public class CqrsHandlerRegistryGeneratorTests { public CqrsHandlerRegistryAttribute(Type registryType) { } } + + [AttributeUsage(AttributeTargets.Assembly)] + public sealed class CqrsReflectionFallbackAttribute : Attribute + { + } } namespace TestApp @@ -120,10 +125,120 @@ public class CqrsHandlerRegistryGeneratorTests } /// - /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器会放弃产出并让运行时回退到反射扫描。 + /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器仍会为可见 handlers 生成注册器, + /// 并额外标记运行时补充反射扫描。 /// [Test] - public async Task Skips_Generation_When_Assembly_Contains_Private_Nested_Handler() + public async Task + Generates_Visible_Handlers_And_Requests_Reflection_Fallback_When_Assembly_Contains_Private_Nested_Handler() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + + [AttributeUsage(AttributeTargets.Assembly)] + public sealed class CqrsReflectionFallbackAttribute : Attribute + { + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenHandler : IRequestHandler { } + } + + public sealed class VisibleHandler : IRequestHandler { } + } + """; + + const string expected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute()] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + typeof(global::TestApp.VisibleHandler)); + logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", expected)); + } + + /// + /// 验证当旧版 runtime 合同中不存在 reflection fallback 标记特性时, + /// 生成器会保留此前的整程序集回退行为,避免丢失不可见 handlers。 + /// + [Test] + public async Task Skips_Generation_For_Unsupported_Handler_When_Fallback_Marker_Is_Unavailable() { const string source = """ using System; diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 65aad69a..80561248 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -16,6 +16,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private const string IStreamRequestHandlerMetadataName = $"{CqrsContractsNamespace}.IStreamRequestHandler`2"; private const string ICqrsHandlerRegistryMetadataName = $"{CqrsRuntimeNamespace}.ICqrsHandlerRegistry"; + private const string CqrsReflectionFallbackAttributeMetadataName = + $"{CqrsRuntimeNamespace}.CqrsReflectionFallbackAttribute"; + private const string CqrsHandlerRegistryAttributeMetadataName = $"{CqrsRuntimeNamespace}.CqrsHandlerRegistryAttribute"; @@ -28,8 +31,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator /// public void Initialize(IncrementalGeneratorInitializationContext context) { - var generationEnabled = context.CompilationProvider - .Select(static (compilation, _) => HasRequiredTypes(compilation)); + var generationEnvironment = context.CompilationProvider + .Select(static (compilation, _) => CreateGenerationEnvironment(compilation)); // Restrict semantic analysis to type declarations that can actually contribute implemented interfaces. var handlerCandidates = context.SyntaxProvider.CreateSyntaxProvider( @@ -39,19 +42,24 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator .Collect(); context.RegisterSourceOutput( - generationEnabled.Combine(handlerCandidates), + generationEnvironment.Combine(handlerCandidates), static (productionContext, pair) => Execute(productionContext, pair.Left, pair.Right)); } - private static bool HasRequiredTypes(Compilation compilation) + private static GenerationEnvironment CreateGenerationEnvironment(Compilation compilation) { - return compilation.GetTypeByMetadataName(IRequestHandlerMetadataName) is not null && - compilation.GetTypeByMetadataName(INotificationHandlerMetadataName) is not null && - compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName) is not null && - compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName) is not null && - compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName) is not null && - compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null && - compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null; + var generationEnabled = compilation.GetTypeByMetadataName(IRequestHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(INotificationHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName) is not null && + compilation.GetTypeByMetadataName( + CqrsHandlerRegistryAttributeMetadataName) is not null && + compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null && + compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null; + + return new GenerationEnvironment( + generationEnabled, + compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName) is not null); } private static bool IsHandlerCandidate(SyntaxNode node) @@ -108,21 +116,25 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator false); } - private static void Execute(SourceProductionContext context, bool generationEnabled, + private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, ImmutableArray candidates) { - if (!generationEnabled) + if (!generationEnvironment.GenerationEnabled) return; var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler); - // If the assembly contains handlers that generated code cannot legally reference - // (for example private nested handlers), keep the runtime on the reflection path - // so registration behavior remains complete instead of silently dropping handlers. - if (hasUnsupportedConcreteHandler || registrations.Count == 0) + if (registrations.Count == 0) return; - context.AddSource(HintName, GenerateSource(registrations)); + // If the runtime contract does not yet expose the reflection fallback marker, + // keep the previous all-or-nothing behavior so unsupported handlers are not silently dropped. + if (hasUnsupportedConcreteHandler && !generationEnvironment.SupportsReflectionFallbackMarker) + return; + + context.AddSource( + HintName, + GenerateSource(registrations, hasUnsupportedConcreteHandler)); } private static List CollectRegistrations( @@ -144,7 +156,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (candidate.Value.HasUnsupportedConcreteHandler) { hasUnsupportedConcreteHandler = true; - return []; + continue; } uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value; @@ -270,7 +282,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return GetTypeSortKey(type).Replace("global::", string.Empty); } - private static string GenerateSource(IReadOnlyList registrations) + private static string GenerateSource( + IReadOnlyList registrations, + bool emitReflectionFallbackAttribute) { var builder = new StringBuilder(); builder.AppendLine("// "); @@ -283,6 +297,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.Append('.'); builder.Append(GeneratedTypeName); builder.AppendLine("))]"); + if (emitReflectionFallbackAttribute) + { + builder.Append("[assembly: global::"); + builder.Append(CqrsRuntimeNamespace); + builder.AppendLine(".CqrsReflectionFallbackAttribute()]"); + } + builder.AppendLine(); builder.Append("namespace "); builder.Append(GeneratedNamespace); @@ -399,4 +420,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } } } + + private readonly record struct GenerationEnvironment( + bool GenerationEnabled, + bool SupportsReflectionFallbackMarker); }