diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs index 5ae794ad..badd7490 100644 --- a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs +++ b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs @@ -7,11 +7,14 @@ using GFramework.Cqrs.Abstractions.Cqrs; namespace GFramework.Cqrs.Tests.Cqrs; /// -/// 验证 CQRS dispatcher 会缓存热路径中的服务类型构造结果。 +/// 验证 CQRS dispatcher 会缓存热路径中的服务类型与调用委托。 /// [TestFixture] internal sealed class CqrsDispatcherCacheTests { + private MicrosoftDiContainer? _container; + private ArchitectureContext? _context; + /// /// 初始化测试上下文。 /// @@ -29,6 +32,7 @@ internal sealed class CqrsDispatcherCacheTests _container.Freeze(); _context = new ArchitectureContext(_container); + ClearDispatcherCaches(); } /// @@ -41,11 +45,8 @@ internal sealed class CqrsDispatcherCacheTests _container = null; } - private MicrosoftDiContainer? _container; - private ArchitectureContext? _context; - /// - /// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。 + /// 验证相同消息类型重复分发时,不会重复扩张服务类型与调用委托缓存。 /// [Test] public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch() @@ -53,8 +54,8 @@ internal sealed class CqrsDispatcherCacheTests var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes"); var requestServiceTypes = GetCacheField("RequestServiceTypes"); var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes"); - var requestInvokers = GetCacheField("RequestInvokers"); - var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers"); + var requestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers"); + var requestPipelineInvokers = GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers"); var notificationInvokers = GetCacheField("NotificationInvokers"); var streamInvokers = GetCacheField("StreamInvokers"); @@ -104,14 +105,42 @@ internal sealed class CqrsDispatcherCacheTests }); } + /// + /// 验证 request 调用委托会按响应类型分别缓存,避免不同响应类型共用 object 结果桥接。 + /// + [Test] + public async Task Dispatcher_Should_Cache_Request_Invokers_Per_Response_Type() + { + var intRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers"); + var stringRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers"); + + var intBefore = intRequestInvokers.Count; + var stringBefore = stringRequestInvokers.Count; + + await _context!.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherStringCacheRequest()); + + var intAfterFirstDispatch = intRequestInvokers.Count; + var stringAfterFirstDispatch = stringRequestInvokers.Count; + + await _context.SendRequestAsync(new DispatcherCacheRequest()); + await _context.SendRequestAsync(new DispatcherStringCacheRequest()); + + Assert.Multiple(() => + { + Assert.That(intAfterFirstDispatch, Is.EqualTo(intBefore + 1)); + Assert.That(stringAfterFirstDispatch, Is.EqualTo(stringBefore + 1)); + Assert.That(intRequestInvokers.Count, Is.EqualTo(intAfterFirstDispatch)); + Assert.That(stringRequestInvokers.Count, Is.EqualTo(stringAfterFirstDispatch)); + }); + } + /// /// 通过反射读取 dispatcher 的静态缓存字典。 /// private static IDictionary GetCacheField(string fieldName) { - var dispatcherType = typeof(CqrsReflectionFallbackAttribute).Assembly - .GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!; - + var dispatcherType = GetDispatcherType(); var field = dispatcherType.GetField( fieldName, BindingFlags.NonPublic | BindingFlags.Static); @@ -123,6 +152,57 @@ internal sealed class CqrsDispatcherCacheTests $"Dispatcher cache field {fieldName} does not implement IDictionary."); } + /// + /// 清空本测试依赖的 dispatcher 静态缓存,避免跨用例共享进程级状态导致断言漂移。 + /// + private static void ClearDispatcherCaches() + { + GetCacheField("NotificationHandlerServiceTypes").Clear(); + GetCacheField("RequestServiceTypes").Clear(); + GetCacheField("StreamHandlerServiceTypes").Clear(); + GetCacheField("NotificationInvokers").Clear(); + GetCacheField("StreamInvokers").Clear(); + GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers").Clear(); + GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers").Clear(); + GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers").Clear(); + GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(string), "Invokers").Clear(); + } + + /// + /// 通过反射读取 dispatcher 嵌套泛型缓存类型上的静态缓存字典。 + /// + private static IDictionary GetGenericCacheField(string nestedTypeName, Type genericTypeArgument, string fieldName) + { + var nestedGenericType = GetDispatcherType().GetNestedType( + nestedTypeName, + BindingFlags.NonPublic); + + Assert.That(nestedGenericType, Is.Not.Null, $"Missing dispatcher nested cache type {nestedTypeName}."); + + var closedNestedType = nestedGenericType!.MakeGenericType(genericTypeArgument); + var field = closedNestedType.GetField( + fieldName, + BindingFlags.NonPublic | BindingFlags.Static); + + Assert.That( + field, + Is.Not.Null, + $"Missing dispatcher nested cache field {nestedTypeName}.{fieldName} for {genericTypeArgument.FullName}."); + + return field!.GetValue(null) as IDictionary + ?? throw new InvalidOperationException( + $"Dispatcher nested cache field {nestedTypeName}.{fieldName} does not implement IDictionary."); + } + + /// + /// 获取 CQRS dispatcher 运行时类型。 + /// + private static Type GetDispatcherType() + { + return typeof(CqrsReflectionFallbackAttribute).Assembly + .GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!; + } + /// /// 消费整个异步流,确保建流路径被真实执行。 /// @@ -154,6 +234,11 @@ internal sealed record DispatcherCacheStreamRequest : IStreamRequest; /// internal sealed record DispatcherPipelineCacheRequest : IRequest; +/// +/// 用于验证按响应类型分层 request invoker 缓存的测试请求。 +/// +internal sealed record DispatcherStringCacheRequest : IRequest; + /// /// 处理 。 /// @@ -213,6 +298,20 @@ internal sealed class DispatcherPipelineCacheRequestHandler : IRequestHandler +/// 处理 。 +/// +internal sealed class DispatcherStringCacheRequestHandler : IRequestHandler +{ + /// + /// 返回固定字符串,供按响应类型缓存测试验证 string 路径。 + /// + public ValueTask Handle(DispatcherStringCacheRequest request, CancellationToken cancellationToken) + { + return ValueTask.FromResult("dispatcher-cache"); + } +} + /// /// 为 提供最小 pipeline 行为, /// 用于命中 dispatcher 的 pipeline invoker 缓存分支。 diff --git a/GFramework.Cqrs/Internal/CqrsDispatcher.cs b/GFramework.Cqrs/Internal/CqrsDispatcher.cs index 002b7edc..a6d62f96 100644 --- a/GFramework.Cqrs/Internal/CqrsDispatcher.cs +++ b/GFramework.Cqrs/Internal/CqrsDispatcher.cs @@ -15,16 +15,6 @@ internal sealed class CqrsDispatcher( IIocContainer container, ILogger logger) : ICqrsRuntime { - // 进程级缓存:按请求/响应类型缓存直接处理器调用委托,避免热路径重复反射。 - // 线程安全依赖 ConcurrentDictionary;缓存与进程同寿命,默认假设请求类型集合有限且稳定。 - private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestInvoker> - RequestInvokers = new(); - - // 进程级缓存:缓存带 pipeline 的请求调用委托,减少每次分发时的反射与表达式重建开销。 - // 若后续引入动态生成请求类型,需要重新评估该缓存的增长边界。 - private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestPipelineInvoker> - RequestPipelineInvokers = new(); - // 进程级缓存:缓存通知调用委托,复用并发安全字典以支撑多线程发布路径。 private static readonly ConcurrentDictionary NotificationInvokers = new(); @@ -131,20 +121,18 @@ internal sealed class CqrsDispatcher( if (behaviors.Count == 0) { - var invoker = RequestInvokers.GetOrAdd( - (requestType, typeof(TResponse)), - static key => CreateRequestInvoker(key.RequestType, key.ResponseType)); + var invoker = RequestInvokerCache.Invokers.GetOrAdd( + requestType, + CreateRequestInvoker); - var result = await invoker(handler, request, cancellationToken); - return result is null ? default! : (TResponse)result; + return await invoker(handler, request, cancellationToken); } - var pipelineInvoker = RequestPipelineInvokers.GetOrAdd( - (requestType, typeof(TResponse)), - static key => CreateRequestPipelineInvoker(key.RequestType, key.ResponseType)); + var pipelineInvoker = RequestPipelineInvokerCache.Invokers.GetOrAdd( + requestType, + CreateRequestPipelineInvoker); - var pipelineResult = await pipelineInvoker(handler, behaviors, request, cancellationToken); - return pipelineResult is null ? default! : (TResponse)pipelineResult; + return await pipelineInvoker(handler, behaviors, request, cancellationToken); } /// @@ -200,21 +188,23 @@ internal sealed class CqrsDispatcher( /// /// 生成请求处理器调用委托,避免每次发送都重复反射。 /// - private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType) + private static RequestInvoker CreateRequestInvoker(Type requestType) { var method = RequestHandlerInvokerMethodDefinition - .MakeGenericMethod(requestType, responseType); - return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method); + .MakeGenericMethod(requestType, typeof(TResponse)); + return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method); } /// /// 生成带管道行为的请求处理委托,避免每次发送都重复反射。 /// - private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType) + private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType) { var method = RequestPipelineInvokerMethodDefinition - .MakeGenericMethod(requestType, responseType); - return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method); + .MakeGenericMethod(requestType, typeof(TResponse)); + return (RequestPipelineInvoker)Delegate.CreateDelegate( + typeof(RequestPipelineInvoker), + method); } /// @@ -240,7 +230,7 @@ internal sealed class CqrsDispatcher( /// /// 执行已强类型化的请求处理器调用。 /// - private static async ValueTask InvokeRequestHandlerAsync( + private static ValueTask InvokeRequestHandlerAsync( object handler, object request, CancellationToken cancellationToken) @@ -248,14 +238,13 @@ internal sealed class CqrsDispatcher( { var typedHandler = (IRequestHandler)handler; var typedRequest = (TRequest)request; - var result = await typedHandler.Handle(typedRequest, cancellationToken); - return result; + return typedHandler.Handle(typedRequest, cancellationToken); } /// /// 执行包含管道行为链的请求处理。 /// - private static async ValueTask InvokeRequestPipelineAsync( + private static ValueTask InvokeRequestPipelineAsync( object handler, IReadOnlyList behaviors, object request, @@ -275,8 +264,7 @@ internal sealed class CqrsDispatcher( next = (message, token) => behavior.Handle(message, currentNext, token); } - var result = await next(typedRequest, cancellationToken); - return result; + return next(typedRequest, cancellationToken); } /// @@ -307,10 +295,12 @@ internal sealed class CqrsDispatcher( return typedHandler.Handle(typedRequest, cancellationToken); } - private delegate ValueTask RequestInvoker(object handler, object request, + private delegate ValueTask RequestInvoker( + object handler, + object request, CancellationToken cancellationToken); - private delegate ValueTask RequestPipelineInvoker( + private delegate ValueTask RequestPipelineInvoker( object handler, IReadOnlyList behaviors, object request, @@ -321,5 +311,23 @@ internal sealed class CqrsDispatcher( private delegate object StreamInvoker(object handler, object request, CancellationToken cancellationToken); + /// + /// 按响应类型分层缓存 request 处理器调用委托,避免 value-type 响应在 object 桥接中产生装箱。 + /// + /// 请求响应类型。 + private static class RequestInvokerCache + { + internal static readonly ConcurrentDictionary> Invokers = new(); + } + + /// + /// 按响应类型分层缓存带 pipeline 的 request 调用委托,避免 pipeline 热路径上的额外装箱。 + /// + /// 请求响应类型。 + private static class RequestPipelineInvokerCache + { + internal static readonly ConcurrentDictionary> Invokers = new(); + } + private readonly record struct RequestServiceTypeSet(Type HandlerType, Type BehaviorType); } diff --git a/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs b/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs index a622d387..bed493e7 100644 --- a/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs +++ b/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs @@ -16,6 +16,24 @@ public static class GeneratorTest public static async Task RunAsync( string source, params (string filename, string content)[] generatedSources) + { + await RunAsync( + source, + additionalReferences: [], + generatedSources); + } + + /// + /// 运行源代码生成器测试,并为测试编译显式追加元数据引用。 + /// + /// 输入的源代码。 + /// 附加元数据引用,用于构造多程序集场景。 + /// 期望生成的源文件集合,包含文件名和内容的元组。 + /// 异步操作任务。 + public static async Task RunAsync( + string source, + IEnumerable additionalReferences, + params (string filename, string content)[] generatedSources) { var test = new CSharpSourceGeneratorTest { @@ -31,6 +49,9 @@ public static class GeneratorTest test.TestState.GeneratedSources.Add( (typeof(TGenerator), filename, NormalizeLineEndings(content))); + foreach (var additionalReference in additionalReferences) + test.TestState.AdditionalReferences.Add(additionalReference); + await test.RunAsync(); } @@ -46,4 +67,4 @@ public static class GeneratorTest .Replace("\r", "\n", StringComparison.Ordinal) .Replace("\n", Environment.NewLine, StringComparison.Ordinal); } -} \ No newline at end of file +} diff --git a/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs b/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs new file mode 100644 index 00000000..45fd506a --- /dev/null +++ b/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs @@ -0,0 +1,74 @@ +using System.Collections.Immutable; +using System.IO; + +namespace GFramework.SourceGenerators.Tests.Core; + +/// +/// 为多程序集源生成器测试构建内存元数据引用。 +/// +public static class MetadataReferenceTestBuilder +{ + // Reuse the runtime reference set across generator tests to avoid reparsing TRUSTED_PLATFORM_ASSEMBLIES + // for every in-memory compilation. + private static readonly Lazy> CachedRuntimeReferences = + new(CreateRuntimeMetadataReferences); + + /// + /// 将给定源码编译为内存程序集,并返回可供测试编译消费的元数据引用。 + /// + /// 目标程序集名称。 + /// 待编译源码。 + /// 附加元数据引用,用于构造依赖链。 + /// 编译成功后的内存元数据引用。 + public static MetadataReference CreateFromSource( + string assemblyName, + string source, + params MetadataReference[] additionalReferences) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var references = CachedRuntimeReferences.Value + .Concat(additionalReferences) + .ToImmutableArray(); + var compilation = CSharpCompilation.Create( + assemblyName, + [syntaxTree], + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + using var stream = new MemoryStream(); + var emitResult = compilation.Emit(stream); + if (!emitResult.Success) + { + var diagnostics = string.Join( + Environment.NewLine, + emitResult.Diagnostics + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .Select(static diagnostic => diagnostic.ToString())); + throw new InvalidOperationException( + $"Failed to build metadata reference '{assemblyName}'.{Environment.NewLine}{diagnostics}"); + } + + stream.Position = 0; + return MetadataReference.CreateFromImage(stream.ToArray()); + } + + /// + /// 获取当前测试运行时可直接复用的基础元数据引用集合。 + /// + /// 当前运行时可信平台程序集对应的元数据引用。 + public static ImmutableArray GetRuntimeMetadataReferences() + { + return CachedRuntimeReferences.Value; + } + + private static ImmutableArray CreateRuntimeMetadataReferences() + { + var trustedPlatformAssemblies = ((string?)AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES"))? + .Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries) + ?? Array.Empty(); + + return trustedPlatformAssemblies + .Select(static path => (MetadataReference)MetadataReference.CreateFromFile(path)) + .ToImmutableArray(); + } +} diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index ef70ec4a..f51dde95 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -164,6 +164,94 @@ public class CqrsHandlerRegistryGeneratorTests """; + private const string MixedDirectAndPreciseRegistrationsExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = typeof(global::TestApp.Container.MixedHandler); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + + private const string MixedReflectedImplementationAndPreciseRegistrationsExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenMixedHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -579,6 +667,293 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected)); } + /// + /// 验证同一个 implementation 同时包含可直接注册接口与需精确重建接口时, + /// 生成器会保留两类注册,并继续按 handler interface 名称稳定排序。 + /// + [Test] + public async Task Generates_Mixed_Direct_And_Precise_Registrations_For_Same_Implementation() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + public sealed class MixedHandler : + IRequestHandler, + IRequestHandler + { + } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", MixedDirectAndPreciseRegistrationsExpected)); + } + + /// + /// 验证隐藏 implementation 同时包含可见 handler interface 与需精确重建接口时, + /// 生成器会保留两类注册,而不会让可见接口被整实现回退吞掉。 + /// + [Test] + public async Task Generates_Mixed_Reflected_Implementation_And_Precise_Registrations_For_Same_Implementation() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenMixedHandler : + IRequestHandler, + IRequestHandler + { + } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected)); + } + + /// + /// 验证当外部基类暴露的 handler interface 含有生成注册器顶层上下文不可直接引用的 protected 类型时, + /// 生成器会保留已知直注册,并只对剩余未知接口做本地 interface discovery。 + /// + [Test] + public void Generates_Partial_Runtime_Interface_Discovery_For_Inaccessible_External_Protected_Types() + { + const string contractsSource = """ + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + """; + + const string dependencySource = """ + using GFramework.Cqrs.Abstractions.Cqrs; + + namespace Dep; + + public sealed record VisibleRequest() : IRequest; + + public abstract class VisibilityScope + { + protected internal sealed record ProtectedResponse(); + + protected internal sealed record ProtectedRequest() : IRequest; + } + + public abstract class HandlerBase : + VisibilityScope, + IRequestHandler, + IRequestHandler + { + } + """; + + const string source = """ + using System; + using Dep; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + public sealed class DerivedHandler : HandlerBase + { + } + } + """; + + var contractsReference = MetadataReferenceTestBuilder.CreateFromSource( + "Contracts", + contractsSource); + var dependencyReference = MetadataReferenceTestBuilder.CreateFromSource( + "Dependency", + dependencySource, + contractsReference); + var generatedSource = RunGenerator( + source, + contractsReference, + dependencyReference); + + Assert.Multiple(() => + { + Assert.That( + generatedSource, + Does.Contain("var implementationType0 = typeof(global::TestApp.DerivedHandler);")); + Assert.That( + generatedSource, + Does.Contain( + "var knownServiceTypes0 = new global::System.Collections.Generic.HashSet();")); + Assert.That( + generatedSource, + Does.Contain( + "knownServiceTypes0.Add(typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler));")); + Assert.That( + generatedSource, + Does.Contain( + "RegisterRemainingReflectedHandlerInterfaces(services, logger, implementationType0, knownServiceTypes0);")); + Assert.That( + generatedSource, + Does.Contain("if (knownServiceTypes.Contains(handlerInterface))")); + Assert.That( + generatedSource, + Does.Contain( + "Registered CQRS handler TestApp.DerivedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.")); + Assert.That( + generatedSource, + Does.Not.Contain( + "typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 @@ -753,4 +1128,42 @@ public class CqrsHandlerRegistryGeneratorTests Assert.That(escaped, Is.EqualTo(expected)); } + + /// + /// 运行 CQRS handler registry generator,并返回单个生成文件的源码文本。 + /// + private static string RunGenerator( + string source, + params MetadataReference[] additionalReferences) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var compilation = CSharpCompilation.Create( + "TestProject", + [syntaxTree], + MetadataReferenceTestBuilder.GetRuntimeMetadataReferences().AddRange(additionalReferences), + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [new CqrsHandlerRegistryGenerator().AsSourceGenerator()], + parseOptions: (CSharpParseOptions)syntaxTree.Options); + driver = driver.RunGeneratorsAndUpdateCompilation( + compilation, + out var updatedCompilation, + out _); + + var compilationErrors = updatedCompilation.GetDiagnostics() + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .ToArray(); + Assert.That( + compilationErrors, + Is.Empty, + () => + $"编译生成的代码时出现错误:{Environment.NewLine}{string.Join(Environment.NewLine, compilationErrors.Select(static diagnostic => diagnostic.ToString()))}"); + + var runResult = driver.GetRunResult(); + Assert.That(runResult.Results, Has.Length.EqualTo(1)); + Assert.That(runResult.Results[0].GeneratedSources, Has.Length.EqualTo(1)); + + return runResult.Results[0].GeneratedSources[0].SourceText.ToString(); + } } diff --git a/GFramework.SourceGenerators.Tests/GlobalUsings.cs b/GFramework.SourceGenerators.Tests/GlobalUsings.cs index 78c09fee..03266a66 100644 --- a/GFramework.SourceGenerators.Tests/GlobalUsings.cs +++ b/GFramework.SourceGenerators.Tests/GlobalUsings.cs @@ -20,4 +20,5 @@ global using Microsoft.CodeAnalysis; global using Microsoft.CodeAnalysis.Text; global using Microsoft.CodeAnalysis.CSharp.Testing; global using Microsoft.CodeAnalysis.Testing; -global using NUnit.Framework; \ No newline at end of file +global using Microsoft.CodeAnalysis.CSharp; +global using NUnit.Framework; diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 36a72be5..3bd3571d 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -86,15 +86,17 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var implementationLogName = GetLogDisplayName(type); - var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type); + var canReferenceImplementation = CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, type); var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); var reflectedImplementationRegistrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); var preciseReflectedRegistrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var requiresRuntimeInterfaceDiscovery = false; foreach (var handlerInterface in handlerInterfaces) { - var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); + var canReferenceHandlerInterface = + CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, handlerInterface); if (canReferenceImplementation && canReferenceHandlerInterface) { registrations.Add(new HandlerRegistrationSpec( @@ -122,16 +124,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator continue; } - // Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over - // non-public element types. For those rare cases keep the narrow implementation lookup, but let the - // generated registry discover the exact supported interfaces from the implementation type at runtime. - return new HandlerCandidateAnalysis( - implementationTypeDisplayName, - implementationLogName, - ImmutableArray.Empty, - ImmutableArray.Empty, - ImmutableArray.Empty, - GetReflectionTypeMetadataName(type)); + // 某些关闭 handler interface 仍包含只能在实现类型运行时语义里解析的类型形态。 + // 对这些边角场景保留“已知接口静态注册 + 剩余接口运行时补洞”的组合路径, + // 避免单个未知接口把同实现上的其它已知注册全部拖回整实现反射发现。 + requiresRuntimeInterfaceDiscovery = true; } return new HandlerCandidateAnalysis( @@ -140,7 +136,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator registrations.ToImmutable(), reflectedImplementationRegistrations.ToImmutable(), preciseReflectedRegistrations.ToImmutable(), - canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); + canReferenceImplementation ? null : GetReflectionTypeMetadataName(type), + requiresRuntimeInterfaceDiscovery); } private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, @@ -184,7 +181,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator candidate.Registrations, candidate.ReflectedImplementationRegistrations, candidate.PreciseReflectedRegistrations, - candidate.ReflectionTypeMetadataName)); + candidate.ReflectionTypeMetadataName, + candidate.RequiresRuntimeInterfaceDiscovery)); } registrations.Sort(static (left, right) => @@ -295,7 +293,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ITypeSymbol type, out RuntimeTypeReferenceSpec? runtimeTypeReference) { - if (CanReferenceFromGeneratedRegistry(type)) + if (CanReferenceFromGeneratedRegistry(compilation, type)) { runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference( type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); @@ -369,7 +367,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator out RuntimeTypeReferenceSpec? genericTypeDefinitionReference) { var genericTypeDefinition = genericNamedType.OriginalDefinition; - if (CanReferenceFromGeneratedRegistry(genericTypeDefinition)) + if (CanReferenceFromGeneratedRegistry(compilation, genericTypeDefinition)) { genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference( genericTypeDefinition @@ -389,43 +387,34 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } - private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) + private static bool CanReferenceFromGeneratedRegistry(Compilation compilation, ITypeSymbol type) { switch (type) { case IArrayTypeSymbol arrayType: - return CanReferenceFromGeneratedRegistry(arrayType.ElementType); + return CanReferenceFromGeneratedRegistry(compilation, arrayType.ElementType); case INamedTypeSymbol namedType: - if (!IsTypeChainAccessible(namedType)) + if (!compilation.IsSymbolAccessibleWithin(namedType, compilation.Assembly, throughType: null)) return false; - return namedType.TypeArguments.All(CanReferenceFromGeneratedRegistry); + foreach (var typeArgument in namedType.TypeArguments) + { + if (!CanReferenceFromGeneratedRegistry(compilation, typeArgument)) + return false; + } + + return true; case IPointerTypeSymbol pointerType: - return CanReferenceFromGeneratedRegistry(pointerType.PointedAtType); + return CanReferenceFromGeneratedRegistry(compilation, pointerType.PointedAtType); case ITypeParameterSymbol: return false; default: + // Treat other Roslyn type kinds, such as dynamic or unresolved error types, as referenceable for now. + // If a real-world case proves unsafe, tighten this branch instead of broadening the named-type path above. return true; } } - private static bool IsTypeChainAccessible(INamedTypeSymbol type) - { - for (var current = type; current is not null; current = current.ContainingType) - { - if (!IsSymbolAccessible(current)) - return false; - } - - return true; - } - - private static bool IsSymbolAccessible(ISymbol symbol) - { - return symbol.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal - or Accessibility.ProtectedOrInternal; - } - private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type) { var nestedTypes = new Stack(); @@ -496,10 +485,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); var hasPreciseReflectedRegistrations = registrations.Any(static registration => !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); - var hasFullReflectionRegistrations = registrations.Any(static registration => - !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && - registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && - registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); + var hasRuntimeInterfaceDiscovery = registrations.Any(static registration => + registration.RequiresRuntimeInterfaceDiscovery); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -533,7 +520,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || - hasFullReflectionRegistrations) + registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))) { builder.AppendLine(); builder.Append(" var registryAssembly = typeof(global::"); @@ -549,46 +537,21 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++) { var registration = registrations[registrationIndex]; - if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty) + if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty || + !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty || + registration.RequiresRuntimeInterfaceDiscovery) { - AppendReflectedImplementationRegistrations(builder, registration, registrationIndex); - continue; + AppendOrderedImplementationRegistrations(builder, registration, registrationIndex); } - - if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) + else if (!registration.DirectRegistrations.IsDefaultOrEmpty) { - AppendPreciseReflectedRegistrations(builder, registration, registrationIndex); - continue; - } - - if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) - { - AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); - continue; - } - - foreach (var directRegistration in registration.DirectRegistrations) - { - builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(directRegistration.HandlerInterfaceDisplayName); - builder.AppendLine("),"); - builder.Append(" typeof("); - builder.Append(directRegistration.ImplementationTypeDisplayName); - builder.AppendLine("));"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); - builder.Append(" as "); - builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); - builder.AppendLine(".\");"); + AppendDirectRegistrations(builder, registration); } } builder.AppendLine(" }"); - if (hasFullReflectionRegistrations) + if (hasRuntimeInterfaceDiscovery) { builder.AppendLine(); AppendReflectionHelpers(builder); @@ -598,56 +561,73 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return builder.ToString(); } - private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName) - { - builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \""); - builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); - builder.AppendLine("\");"); - } - - private static void AppendReflectedImplementationRegistrations( + private static void AppendDirectRegistrations( StringBuilder builder, - ImplementationRegistrationSpec registration, - int registrationIndex) + ImplementationRegistrationSpec registration) { - var implementationVariableName = $"implementationType{registrationIndex}"; - builder.Append(" var "); - builder.Append(implementationVariableName); - builder.Append(" = registryAssembly.GetType(\""); - builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); - builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); - builder.Append(" if ("); - builder.Append(implementationVariableName); - builder.AppendLine(" is not null)"); - builder.AppendLine(" {"); - - foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations) + foreach (var directRegistration in registration.DirectRegistrations) { builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); builder.AppendLine("),"); - builder.Append(" "); - builder.Append(implementationVariableName); - builder.AppendLine(");"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" typeof("); + builder.Append(directRegistration.ImplementationTypeDisplayName); + builder.AppendLine("));"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); builder.Append(" as "); - builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); builder.AppendLine(".\");"); } - - builder.AppendLine(" }"); } - private static void AppendPreciseReflectedRegistrations( + private static void AppendOrderedImplementationRegistrations( StringBuilder builder, ImplementationRegistrationSpec registration, int registrationIndex) { + var orderedRegistrations = + new List<(string HandlerInterfaceLogName, OrderedRegistrationKind Kind, int Index)>( + registration.DirectRegistrations.Length + + registration.ReflectedImplementationRegistrations.Length + + registration.PreciseReflectedRegistrations.Length); + + for (var directIndex = 0; directIndex < registration.DirectRegistrations.Length; directIndex++) + { + orderedRegistrations.Add(( + registration.DirectRegistrations[directIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.Direct, + directIndex)); + } + + for (var reflectedIndex = 0; + reflectedIndex < registration.ReflectedImplementationRegistrations.Length; + reflectedIndex++) + { + orderedRegistrations.Add(( + registration.ReflectedImplementationRegistrations[reflectedIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.ReflectedImplementation, + reflectedIndex)); + } + + for (var preciseIndex = 0; + preciseIndex < registration.PreciseReflectedRegistrations.Length; + preciseIndex++) + { + orderedRegistrations.Add(( + registration.PreciseReflectedRegistrations[preciseIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.PreciseReflected, + preciseIndex)); + } + + orderedRegistrations.Sort(static (left, right) => + StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName)); + var implementationVariableName = $"implementationType{registrationIndex}"; + var knownServiceTypesVariableName = $"knownServiceTypes{registrationIndex}"; if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { builder.Append(" var "); @@ -658,11 +638,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } else { - var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!; builder.Append(" var "); builder.Append(implementationVariableName); builder.Append(" = registryAssembly.GetType(\""); - builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName)); + builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); } @@ -671,21 +650,98 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" is not null)"); builder.AppendLine(" {"); - for (var registrationOffset = 0; - registrationOffset < registration.PreciseReflectedRegistrations.Length; - registrationOffset++) + if (registration.RequiresRuntimeInterfaceDiscovery) { - var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset]; - var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}"; - AppendPreciseReflectedTypeResolution( - builder, - reflectedRegistration.ServiceTypeArguments, - registrationVariablePrefix, - implementationVariableName, - reflectedRegistration.OpenHandlerTypeDisplayName, - registration.ImplementationLogName, - reflectedRegistration.HandlerInterfaceLogName, - 3); + builder.Append(" var "); + builder.Append(knownServiceTypesVariableName); + builder.AppendLine(" = new global::System.Collections.Generic.HashSet();"); + } + + foreach (var orderedRegistration in orderedRegistrations) + { + switch (orderedRegistration.Kind) + { + case OrderedRegistrationKind.Direct: + var directRegistration = registration.DirectRegistrations[orderedRegistration.Index]; + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" "); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add(typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("));"); + } + + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + break; + case OrderedRegistrationKind.ReflectedImplementation: + var reflectedRegistration = + registration.ReflectedImplementationRegistrations[orderedRegistration.Index]; + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" "); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add(typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("));"); + } + + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + break; + case OrderedRegistrationKind.PreciseReflected: + var preciseRegistration = registration.PreciseReflectedRegistrations[orderedRegistration.Index]; + var registrationVariablePrefix = $"serviceType{registrationIndex}_{orderedRegistration.Index}"; + AppendPreciseReflectedTypeResolution( + builder, + preciseRegistration.ServiceTypeArguments, + registrationVariablePrefix, + implementationVariableName, + preciseRegistration.OpenHandlerTypeDisplayName, + registration.ImplementationLogName, + preciseRegistration.HandlerInterfaceLogName, + knownServiceTypesVariableName, + registration.RequiresRuntimeInterfaceDiscovery, + 3); + break; + default: + throw new InvalidOperationException( + $"Unsupported ordered CQRS registration kind {orderedRegistration.Kind}."); + } + } + + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" RegisterRemainingReflectedHandlerInterfaces(services, logger, "); + builder.Append(implementationVariableName); + builder.Append(", "); + builder.Append(knownServiceTypesVariableName); + builder.AppendLine(");"); } builder.AppendLine(" }"); @@ -699,6 +755,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string openHandlerTypeDisplayName, string implementationLogName, string handlerInterfaceLogName, + string knownServiceTypesVariableName, + bool trackKnownServiceTypes, int indentLevel) { var indent = new string(' ', indentLevel * 4); @@ -763,6 +821,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.Append(" "); builder.Append(implementationVariableName); builder.AppendLine(");"); + if (trackKnownServiceTypes) + { + builder.Append(indent); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add("); + builder.Append(registrationVariablePrefix); + builder.AppendLine(");"); + } + builder.Append(indent); builder.Append("logger.Debug(\"Registered CQRS handler "); builder.Append(EscapeStringLiteral(implementationLogName)); @@ -839,15 +906,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private static void AppendReflectionHelpers(StringBuilder builder) { - // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. + // Emit the runtime helper methods only when at least one handler still needs implementation-scoped + // interface discovery after all direct / precise registrations have been emitted. builder.AppendLine( - " private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)"); + " private static void RegisterRemainingReflectedHandlerInterfaces(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Type implementationType, global::System.Collections.Generic.ISet knownServiceTypes)"); builder.AppendLine(" {"); - builder.AppendLine( - " var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);"); - builder.AppendLine(" if (implementationType is null)"); - builder.AppendLine(" return;"); - builder.AppendLine(); builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();"); builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);"); builder.AppendLine(); @@ -856,6 +919,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))"); builder.AppendLine(" continue;"); builder.AppendLine(); + builder.AppendLine(" if (knownServiceTypes.Contains(handlerInterface))"); + builder.AppendLine(" continue;"); + builder.AppendLine(); builder.AppendLine( " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); builder.AppendLine(" services,"); @@ -863,6 +929,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" implementationType);"); builder.AppendLine( " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");"); + builder.AppendLine(" knownServiceTypes.Add(handlerInterface);"); builder.AppendLine(" }"); builder.AppendLine(" }"); builder.AppendLine(); @@ -969,6 +1036,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceDisplayName, string HandlerInterfaceLogName); + private enum OrderedRegistrationKind + { + Direct, + ReflectedImplementation, + PreciseReflected + } + private sealed record RuntimeTypeReferenceSpec( string? TypeDisplayName, string? ReflectionTypeMetadataName, @@ -1015,7 +1089,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ImmutableArray DirectRegistrations, ImmutableArray ReflectedImplementationRegistrations, ImmutableArray PreciseReflectedRegistrations, - string? ReflectionTypeMetadataName); + string? ReflectionTypeMetadataName, + bool RequiresRuntimeInterfaceDiscovery); private readonly struct HandlerCandidateAnalysis : IEquatable { @@ -1025,7 +1100,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ImmutableArray registrations, ImmutableArray reflectedImplementationRegistrations, ImmutableArray preciseReflectedRegistrations, - string? reflectionTypeMetadataName) + string? reflectionTypeMetadataName, + bool requiresRuntimeInterfaceDiscovery) { ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationLogName = implementationLogName; @@ -1033,6 +1109,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ReflectedImplementationRegistrations = reflectedImplementationRegistrations; PreciseReflectedRegistrations = preciseReflectedRegistrations; ReflectionTypeMetadataName = reflectionTypeMetadataName; + RequiresRuntimeInterfaceDiscovery = requiresRuntimeInterfaceDiscovery; } public string ImplementationTypeDisplayName { get; } @@ -1047,6 +1124,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator public string? ReflectionTypeMetadataName { get; } + public bool RequiresRuntimeInterfaceDiscovery { get; } + public bool Equals(HandlerCandidateAnalysis other) { if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, @@ -1054,6 +1133,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || + RequiresRuntimeInterfaceDiscovery != other.RequiresRuntimeInterfaceDiscovery || Registrations.Length != other.Registrations.Length || ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length || PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length) @@ -1098,6 +1178,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator (ReflectionTypeMetadataName is null ? 0 : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName)); + hashCode = (hashCode * 397) ^ RequiresRuntimeInterfaceDiscovery.GetHashCode(); foreach (var registration in Registrations) { hashCode = (hashCode * 397) ^ registration.GetHashCode();