Merge pull request #232 from GeWuYou/refactor/cqrs-architecture-decoupling-todo-10

Refactor/cqrs architecture decoupling todo 10
This commit is contained in:
gewuyou 2026-04-16 19:14:24 +08:00 committed by GitHub
commit e96623b7f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 878 additions and 181 deletions

View File

@ -7,11 +7,14 @@ using GFramework.Cqrs.Abstractions.Cqrs;
namespace GFramework.Cqrs.Tests.Cqrs;
/// <summary>
/// 验证 CQRS dispatcher 会缓存热路径中的服务类型构造结果
/// 验证 CQRS dispatcher 会缓存热路径中的服务类型与调用委托
/// </summary>
[TestFixture]
internal sealed class CqrsDispatcherCacheTests
{
private MicrosoftDiContainer? _container;
private ArchitectureContext? _context;
/// <summary>
/// 初始化测试上下文。
/// </summary>
@ -29,6 +32,7 @@ internal sealed class CqrsDispatcherCacheTests
_container.Freeze();
_context = new ArchitectureContext(_container);
ClearDispatcherCaches();
}
/// <summary>
@ -41,11 +45,8 @@ internal sealed class CqrsDispatcherCacheTests
_container = null;
}
private MicrosoftDiContainer? _container;
private ArchitectureContext? _context;
/// <summary>
/// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。
/// 验证相同消息类型重复分发时,不会重复扩张服务类型与调用委托缓存。
/// </summary>
[Test]
public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch()
@ -53,8 +54,8 @@ internal sealed class CqrsDispatcherCacheTests
var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes");
var requestServiceTypes = GetCacheField("RequestServiceTypes");
var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes");
var requestInvokers = GetCacheField("RequestInvokers");
var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers");
var requestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers");
var requestPipelineInvokers = GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers");
var notificationInvokers = GetCacheField("NotificationInvokers");
var streamInvokers = GetCacheField("StreamInvokers");
@ -104,14 +105,42 @@ internal sealed class CqrsDispatcherCacheTests
});
}
/// <summary>
/// 验证 request 调用委托会按响应类型分别缓存,避免不同响应类型共用 object 结果桥接。
/// </summary>
[Test]
public async Task Dispatcher_Should_Cache_Request_Invokers_Per_Response_Type()
{
var intRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers");
var stringRequestInvokers = GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers");
var intBefore = intRequestInvokers.Count;
var stringBefore = stringRequestInvokers.Count;
await _context!.SendRequestAsync(new DispatcherCacheRequest());
await _context.SendRequestAsync(new DispatcherStringCacheRequest());
var intAfterFirstDispatch = intRequestInvokers.Count;
var stringAfterFirstDispatch = stringRequestInvokers.Count;
await _context.SendRequestAsync(new DispatcherCacheRequest());
await _context.SendRequestAsync(new DispatcherStringCacheRequest());
Assert.Multiple(() =>
{
Assert.That(intAfterFirstDispatch, Is.EqualTo(intBefore + 1));
Assert.That(stringAfterFirstDispatch, Is.EqualTo(stringBefore + 1));
Assert.That(intRequestInvokers.Count, Is.EqualTo(intAfterFirstDispatch));
Assert.That(stringRequestInvokers.Count, Is.EqualTo(stringAfterFirstDispatch));
});
}
/// <summary>
/// 通过反射读取 dispatcher 的静态缓存字典。
/// </summary>
private static IDictionary GetCacheField(string fieldName)
{
var dispatcherType = typeof(CqrsReflectionFallbackAttribute).Assembly
.GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!;
var dispatcherType = GetDispatcherType();
var field = dispatcherType.GetField(
fieldName,
BindingFlags.NonPublic | BindingFlags.Static);
@ -123,6 +152,57 @@ internal sealed class CqrsDispatcherCacheTests
$"Dispatcher cache field {fieldName} does not implement IDictionary.");
}
/// <summary>
/// 清空本测试依赖的 dispatcher 静态缓存,避免跨用例共享进程级状态导致断言漂移。
/// </summary>
private static void ClearDispatcherCaches()
{
GetCacheField("NotificationHandlerServiceTypes").Clear();
GetCacheField("RequestServiceTypes").Clear();
GetCacheField("StreamHandlerServiceTypes").Clear();
GetCacheField("NotificationInvokers").Clear();
GetCacheField("StreamInvokers").Clear();
GetGenericCacheField("RequestInvokerCache`1", typeof(int), "Invokers").Clear();
GetGenericCacheField("RequestInvokerCache`1", typeof(string), "Invokers").Clear();
GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(int), "Invokers").Clear();
GetGenericCacheField("RequestPipelineInvokerCache`1", typeof(string), "Invokers").Clear();
}
/// <summary>
/// 通过反射读取 dispatcher 嵌套泛型缓存类型上的静态缓存字典。
/// </summary>
private static IDictionary GetGenericCacheField(string nestedTypeName, Type genericTypeArgument, string fieldName)
{
var nestedGenericType = GetDispatcherType().GetNestedType(
nestedTypeName,
BindingFlags.NonPublic);
Assert.That(nestedGenericType, Is.Not.Null, $"Missing dispatcher nested cache type {nestedTypeName}.");
var closedNestedType = nestedGenericType!.MakeGenericType(genericTypeArgument);
var field = closedNestedType.GetField(
fieldName,
BindingFlags.NonPublic | BindingFlags.Static);
Assert.That(
field,
Is.Not.Null,
$"Missing dispatcher nested cache field {nestedTypeName}.{fieldName} for {genericTypeArgument.FullName}.");
return field!.GetValue(null) as IDictionary
?? throw new InvalidOperationException(
$"Dispatcher nested cache field {nestedTypeName}.{fieldName} does not implement IDictionary.");
}
/// <summary>
/// 获取 CQRS dispatcher 运行时类型。
/// </summary>
private static Type GetDispatcherType()
{
return typeof(CqrsReflectionFallbackAttribute).Assembly
.GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!;
}
/// <summary>
/// 消费整个异步流,确保建流路径被真实执行。
/// </summary>
@ -154,6 +234,11 @@ internal sealed record DispatcherCacheStreamRequest : IStreamRequest<int>;
/// </summary>
internal sealed record DispatcherPipelineCacheRequest : IRequest<int>;
/// <summary>
/// 用于验证按响应类型分层 request invoker 缓存的测试请求。
/// </summary>
internal sealed record DispatcherStringCacheRequest : IRequest<string>;
/// <summary>
/// 处理 <see cref="DispatcherCacheRequest" />。
/// </summary>
@ -213,6 +298,20 @@ internal sealed class DispatcherPipelineCacheRequestHandler : IRequestHandler<Di
}
}
/// <summary>
/// 处理 <see cref="DispatcherStringCacheRequest" />。
/// </summary>
internal sealed class DispatcherStringCacheRequestHandler : IRequestHandler<DispatcherStringCacheRequest, string>
{
/// <summary>
/// 返回固定字符串,供按响应类型缓存测试验证 string 路径。
/// </summary>
public ValueTask<string> Handle(DispatcherStringCacheRequest request, CancellationToken cancellationToken)
{
return ValueTask.FromResult("dispatcher-cache");
}
}
/// <summary>
/// 为 <see cref="DispatcherPipelineCacheRequest" /> 提供最小 pipeline 行为,
/// 用于命中 dispatcher 的 pipeline invoker 缓存分支。

View File

@ -15,16 +15,6 @@ internal sealed class CqrsDispatcher(
IIocContainer container,
ILogger logger) : ICqrsRuntime
{
// 进程级缓存:按请求/响应类型缓存直接处理器调用委托,避免热路径重复反射。
// 线程安全依赖 ConcurrentDictionary缓存与进程同寿命默认假设请求类型集合有限且稳定。
private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestInvoker>
RequestInvokers = new();
// 进程级缓存:缓存带 pipeline 的请求调用委托,减少每次分发时的反射与表达式重建开销。
// 若后续引入动态生成请求类型,需要重新评估该缓存的增长边界。
private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestPipelineInvoker>
RequestPipelineInvokers = new();
// 进程级缓存:缓存通知调用委托,复用并发安全字典以支撑多线程发布路径。
private static readonly ConcurrentDictionary<Type, NotificationInvoker> NotificationInvokers = new();
@ -131,20 +121,18 @@ internal sealed class CqrsDispatcher(
if (behaviors.Count == 0)
{
var invoker = RequestInvokers.GetOrAdd(
(requestType, typeof(TResponse)),
static key => CreateRequestInvoker(key.RequestType, key.ResponseType));
var invoker = RequestInvokerCache<TResponse>.Invokers.GetOrAdd(
requestType,
CreateRequestInvoker<TResponse>);
var result = await invoker(handler, request, cancellationToken);
return result is null ? default! : (TResponse)result;
return await invoker(handler, request, cancellationToken);
}
var pipelineInvoker = RequestPipelineInvokers.GetOrAdd(
(requestType, typeof(TResponse)),
static key => CreateRequestPipelineInvoker(key.RequestType, key.ResponseType));
var pipelineInvoker = RequestPipelineInvokerCache<TResponse>.Invokers.GetOrAdd(
requestType,
CreateRequestPipelineInvoker<TResponse>);
var pipelineResult = await pipelineInvoker(handler, behaviors, request, cancellationToken);
return pipelineResult is null ? default! : (TResponse)pipelineResult;
return await pipelineInvoker(handler, behaviors, request, cancellationToken);
}
/// <summary>
@ -200,21 +188,23 @@ internal sealed class CqrsDispatcher(
/// <summary>
/// 生成请求处理器调用委托,避免每次发送都重复反射。
/// </summary>
private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType)
private static RequestInvoker<TResponse> CreateRequestInvoker<TResponse>(Type requestType)
{
var method = RequestHandlerInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method);
.MakeGenericMethod(requestType, typeof(TResponse));
return (RequestInvoker<TResponse>)Delegate.CreateDelegate(typeof(RequestInvoker<TResponse>), method);
}
/// <summary>
/// 生成带管道行为的请求处理委托,避免每次发送都重复反射。
/// </summary>
private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType)
private static RequestPipelineInvoker<TResponse> CreateRequestPipelineInvoker<TResponse>(Type requestType)
{
var method = RequestPipelineInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method);
.MakeGenericMethod(requestType, typeof(TResponse));
return (RequestPipelineInvoker<TResponse>)Delegate.CreateDelegate(
typeof(RequestPipelineInvoker<TResponse>),
method);
}
/// <summary>
@ -240,7 +230,7 @@ internal sealed class CqrsDispatcher(
/// <summary>
/// 执行已强类型化的请求处理器调用。
/// </summary>
private static async ValueTask<object?> InvokeRequestHandlerAsync<TRequest, TResponse>(
private static ValueTask<TResponse> InvokeRequestHandlerAsync<TRequest, TResponse>(
object handler,
object request,
CancellationToken cancellationToken)
@ -248,14 +238,13 @@ internal sealed class CqrsDispatcher(
{
var typedHandler = (IRequestHandler<TRequest, TResponse>)handler;
var typedRequest = (TRequest)request;
var result = await typedHandler.Handle(typedRequest, cancellationToken);
return result;
return typedHandler.Handle(typedRequest, cancellationToken);
}
/// <summary>
/// 执行包含管道行为链的请求处理。
/// </summary>
private static async ValueTask<object?> InvokeRequestPipelineAsync<TRequest, TResponse>(
private static ValueTask<TResponse> InvokeRequestPipelineAsync<TRequest, TResponse>(
object handler,
IReadOnlyList<object> behaviors,
object request,
@ -275,8 +264,7 @@ internal sealed class CqrsDispatcher(
next = (message, token) => behavior.Handle(message, currentNext, token);
}
var result = await next(typedRequest, cancellationToken);
return result;
return next(typedRequest, cancellationToken);
}
/// <summary>
@ -307,10 +295,12 @@ internal sealed class CqrsDispatcher(
return typedHandler.Handle(typedRequest, cancellationToken);
}
private delegate ValueTask<object?> RequestInvoker(object handler, object request,
private delegate ValueTask<TResponse> RequestInvoker<TResponse>(
object handler,
object request,
CancellationToken cancellationToken);
private delegate ValueTask<object?> RequestPipelineInvoker(
private delegate ValueTask<TResponse> RequestPipelineInvoker<TResponse>(
object handler,
IReadOnlyList<object> behaviors,
object request,
@ -321,5 +311,23 @@ internal sealed class CqrsDispatcher(
private delegate object StreamInvoker(object handler, object request, CancellationToken cancellationToken);
/// <summary>
/// 按响应类型分层缓存 request 处理器调用委托,避免 value-type 响应在 object 桥接中产生装箱。
/// </summary>
/// <typeparam name="TResponse">请求响应类型。</typeparam>
private static class RequestInvokerCache<TResponse>
{
internal static readonly ConcurrentDictionary<Type, RequestInvoker<TResponse>> Invokers = new();
}
/// <summary>
/// 按响应类型分层缓存带 pipeline 的 request 调用委托,避免 pipeline 热路径上的额外装箱。
/// </summary>
/// <typeparam name="TResponse">请求响应类型。</typeparam>
private static class RequestPipelineInvokerCache<TResponse>
{
internal static readonly ConcurrentDictionary<Type, RequestPipelineInvoker<TResponse>> Invokers = new();
}
private readonly record struct RequestServiceTypeSet(Type HandlerType, Type BehaviorType);
}

View File

@ -16,6 +16,24 @@ public static class GeneratorTest<TGenerator>
public static async Task RunAsync(
string source,
params (string filename, string content)[] generatedSources)
{
await RunAsync(
source,
additionalReferences: [],
generatedSources);
}
/// <summary>
/// 运行源代码生成器测试,并为测试编译显式追加元数据引用。
/// </summary>
/// <param name="source">输入的源代码。</param>
/// <param name="additionalReferences">附加元数据引用,用于构造多程序集场景。</param>
/// <param name="generatedSources">期望生成的源文件集合,包含文件名和内容的元组。</param>
/// <returns>异步操作任务。</returns>
public static async Task RunAsync(
string source,
IEnumerable<MetadataReference> additionalReferences,
params (string filename, string content)[] generatedSources)
{
var test = new CSharpSourceGeneratorTest<TGenerator, DefaultVerifier>
{
@ -31,6 +49,9 @@ public static class GeneratorTest<TGenerator>
test.TestState.GeneratedSources.Add(
(typeof(TGenerator), filename, NormalizeLineEndings(content)));
foreach (var additionalReference in additionalReferences)
test.TestState.AdditionalReferences.Add(additionalReference);
await test.RunAsync();
}
@ -46,4 +67,4 @@ public static class GeneratorTest<TGenerator>
.Replace("\r", "\n", StringComparison.Ordinal)
.Replace("\n", Environment.NewLine, StringComparison.Ordinal);
}
}
}

View File

@ -0,0 +1,74 @@
using System.Collections.Immutable;
using System.IO;
namespace GFramework.SourceGenerators.Tests.Core;
/// <summary>
/// 为多程序集源生成器测试构建内存元数据引用。
/// </summary>
public static class MetadataReferenceTestBuilder
{
// Reuse the runtime reference set across generator tests to avoid reparsing TRUSTED_PLATFORM_ASSEMBLIES
// for every in-memory compilation.
private static readonly Lazy<ImmutableArray<MetadataReference>> CachedRuntimeReferences =
new(CreateRuntimeMetadataReferences);
/// <summary>
/// 将给定源码编译为内存程序集,并返回可供测试编译消费的元数据引用。
/// </summary>
/// <param name="assemblyName">目标程序集名称。</param>
/// <param name="source">待编译源码。</param>
/// <param name="additionalReferences">附加元数据引用,用于构造依赖链。</param>
/// <returns>编译成功后的内存元数据引用。</returns>
public static MetadataReference CreateFromSource(
string assemblyName,
string source,
params MetadataReference[] additionalReferences)
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var references = CachedRuntimeReferences.Value
.Concat(additionalReferences)
.ToImmutableArray();
var compilation = CSharpCompilation.Create(
assemblyName,
[syntaxTree],
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
using var stream = new MemoryStream();
var emitResult = compilation.Emit(stream);
if (!emitResult.Success)
{
var diagnostics = string.Join(
Environment.NewLine,
emitResult.Diagnostics
.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.Select(static diagnostic => diagnostic.ToString()));
throw new InvalidOperationException(
$"Failed to build metadata reference '{assemblyName}'.{Environment.NewLine}{diagnostics}");
}
stream.Position = 0;
return MetadataReference.CreateFromImage(stream.ToArray());
}
/// <summary>
/// 获取当前测试运行时可直接复用的基础元数据引用集合。
/// </summary>
/// <returns>当前运行时可信平台程序集对应的元数据引用。</returns>
public static ImmutableArray<MetadataReference> GetRuntimeMetadataReferences()
{
return CachedRuntimeReferences.Value;
}
private static ImmutableArray<MetadataReference> CreateRuntimeMetadataReferences()
{
var trustedPlatformAssemblies = ((string?)AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES"))?
.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
?? Array.Empty<string>();
return trustedPlatformAssemblies
.Select(static path => (MetadataReference)MetadataReference.CreateFromFile(path))
.ToImmutableArray();
}
}

View File

@ -164,6 +164,94 @@ public class CqrsHandlerRegistryGeneratorTests
""";
private const string MixedDirectAndPreciseRegistrationsExpected = """
// <auto-generated />
#nullable enable
[assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
namespace GFramework.Generated.Cqrs;
internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
{
public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
{
if (services is null)
throw new global::System.ArgumentNullException(nameof(services));
if (logger is null)
throw new global::System.ArgumentNullException(nameof(logger));
var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
var implementationType0 = typeof(global::TestApp.Container.MixedHandler);
if (implementationType0 is not null)
{
var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
{
var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
serviceType0_0,
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.Container.HiddenRequest, TestApp.Container.HiddenResponse[]>.");
}
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::TestApp.VisibleRequest, string>),
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.VisibleRequest, string>.");
}
}
}
""";
private const string MixedReflectedImplementationAndPreciseRegistrationsExpected = """
// <auto-generated />
#nullable enable
[assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
namespace GFramework.Generated.Cqrs;
internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
{
public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
{
if (services is null)
throw new global::System.ArgumentNullException(nameof(services));
if (logger is null)
throw new global::System.ArgumentNullException(nameof(logger));
var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenMixedHandler", throwOnError: false, ignoreCase: false);
if (implementationType0 is not null)
{
var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
{
var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
serviceType0_0,
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.Container.HiddenRequest, TestApp.Container.HiddenResponse[]>.");
}
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::TestApp.VisibleRequest, string>),
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.VisibleRequest, string>.");
}
}
}
""";
/// <summary>
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
/// </summary>
@ -579,6 +667,293 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected));
}
/// <summary>
/// 验证同一个 implementation 同时包含可直接注册接口与需精确重建接口时,
/// 生成器会保留两类注册,并继续按 handler interface 名称稳定排序。
/// </summary>
[Test]
public async Task Generates_Mixed_Direct_And_Precise_Registrations_For_Same_Implementation()
{
const string source = """
using System;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
using GFramework.Cqrs.Abstractions.Cqrs;
public sealed record VisibleRequest() : IRequest<string>;
public sealed class Container
{
private sealed record HiddenResponse();
private sealed record HiddenRequest() : IRequest<HiddenResponse[]>;
public sealed class MixedHandler :
IRequestHandler<HiddenRequest, HiddenResponse[]>,
IRequestHandler<VisibleRequest, string>
{
}
}
}
""";
await GeneratorTest<CqrsHandlerRegistryGenerator>.RunAsync(
source,
("CqrsHandlerRegistry.g.cs", MixedDirectAndPreciseRegistrationsExpected));
}
/// <summary>
/// 验证隐藏 implementation 同时包含可见 handler interface 与需精确重建接口时,
/// 生成器会保留两类注册,而不会让可见接口被整实现回退吞掉。
/// </summary>
[Test]
public async Task Generates_Mixed_Reflected_Implementation_And_Precise_Registrations_For_Same_Implementation()
{
const string source = """
using System;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
using GFramework.Cqrs.Abstractions.Cqrs;
public sealed record VisibleRequest() : IRequest<string>;
public sealed class Container
{
private sealed record HiddenResponse();
private sealed record HiddenRequest() : IRequest<HiddenResponse[]>;
private sealed class HiddenMixedHandler :
IRequestHandler<HiddenRequest, HiddenResponse[]>,
IRequestHandler<VisibleRequest, string>
{
}
}
}
""";
await GeneratorTest<CqrsHandlerRegistryGenerator>.RunAsync(
source,
("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected));
}
/// <summary>
/// 验证当外部基类暴露的 handler interface 含有生成注册器顶层上下文不可直接引用的 protected 类型时,
/// 生成器会保留已知直注册,并只对剩余未知接口做本地 interface discovery。
/// </summary>
[Test]
public void Generates_Partial_Runtime_Interface_Discovery_For_Inaccessible_External_Protected_Types()
{
const string contractsSource = """
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
""";
const string dependencySource = """
using GFramework.Cqrs.Abstractions.Cqrs;
namespace Dep;
public sealed record VisibleRequest() : IRequest<string>;
public abstract class VisibilityScope
{
protected internal sealed record ProtectedResponse();
protected internal sealed record ProtectedRequest() : IRequest<ProtectedResponse[]>;
}
public abstract class HandlerBase :
VisibilityScope,
IRequestHandler<VisibleRequest, string>,
IRequestHandler<VisibilityScope.ProtectedRequest, VisibilityScope.ProtectedResponse[]>
{
}
""";
const string source = """
using System;
using Dep;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
public sealed class DerivedHandler : HandlerBase
{
}
}
""";
var contractsReference = MetadataReferenceTestBuilder.CreateFromSource(
"Contracts",
contractsSource);
var dependencyReference = MetadataReferenceTestBuilder.CreateFromSource(
"Dependency",
dependencySource,
contractsReference);
var generatedSource = RunGenerator(
source,
contractsReference,
dependencyReference);
Assert.Multiple(() =>
{
Assert.That(
generatedSource,
Does.Contain("var implementationType0 = typeof(global::TestApp.DerivedHandler);"));
Assert.That(
generatedSource,
Does.Contain(
"var knownServiceTypes0 = new global::System.Collections.Generic.HashSet<global::System.Type>();"));
Assert.That(
generatedSource,
Does.Contain(
"knownServiceTypes0.Add(typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::Dep.VisibleRequest, string>));"));
Assert.That(
generatedSource,
Does.Contain(
"RegisterRemainingReflectedHandlerInterfaces(services, logger, implementationType0, knownServiceTypes0);"));
Assert.That(
generatedSource,
Does.Contain("if (knownServiceTypes.Contains(handlerInterface))"));
Assert.That(
generatedSource,
Does.Contain(
"Registered CQRS handler TestApp.DerivedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<Dep.VisibleRequest, string>."));
Assert.That(
generatedSource,
Does.Not.Contain(
"typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::Dep.VisibilityScope.ProtectedRequest"));
});
}
/// <summary>
/// 验证即使 runtime 仍暴露旧版无参 fallback marker生成器也会优先在生成注册器内部处理隐藏 handler
/// 不再输出 fallback marker。
@ -753,4 +1128,42 @@ public class CqrsHandlerRegistryGeneratorTests
Assert.That(escaped, Is.EqualTo(expected));
}
/// <summary>
/// 运行 CQRS handler registry generator并返回单个生成文件的源码文本。
/// </summary>
private static string RunGenerator(
string source,
params MetadataReference[] additionalReferences)
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var compilation = CSharpCompilation.Create(
"TestProject",
[syntaxTree],
MetadataReferenceTestBuilder.GetRuntimeMetadataReferences().AddRange(additionalReferences),
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
GeneratorDriver driver = CSharpGeneratorDriver.Create(
generators: [new CqrsHandlerRegistryGenerator().AsSourceGenerator()],
parseOptions: (CSharpParseOptions)syntaxTree.Options);
driver = driver.RunGeneratorsAndUpdateCompilation(
compilation,
out var updatedCompilation,
out _);
var compilationErrors = updatedCompilation.GetDiagnostics()
.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.ToArray();
Assert.That(
compilationErrors,
Is.Empty,
() =>
$"编译生成的代码时出现错误:{Environment.NewLine}{string.Join(Environment.NewLine, compilationErrors.Select(static diagnostic => diagnostic.ToString()))}");
var runResult = driver.GetRunResult();
Assert.That(runResult.Results, Has.Length.EqualTo(1));
Assert.That(runResult.Results[0].GeneratedSources, Has.Length.EqualTo(1));
return runResult.Results[0].GeneratedSources[0].SourceText.ToString();
}
}

View File

@ -20,4 +20,5 @@ global using Microsoft.CodeAnalysis;
global using Microsoft.CodeAnalysis.Text;
global using Microsoft.CodeAnalysis.CSharp.Testing;
global using Microsoft.CodeAnalysis.Testing;
global using NUnit.Framework;
global using Microsoft.CodeAnalysis.CSharp;
global using NUnit.Framework;

View File

@ -86,15 +86,17 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var implementationLogName = GetLogDisplayName(type);
var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type);
var canReferenceImplementation = CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, type);
var registrations = ImmutableArray.CreateBuilder<HandlerRegistrationSpec>(handlerInterfaces.Length);
var reflectedImplementationRegistrations =
ImmutableArray.CreateBuilder<ReflectedImplementationRegistrationSpec>(handlerInterfaces.Length);
var preciseReflectedRegistrations =
ImmutableArray.CreateBuilder<PreciseReflectedRegistrationSpec>(handlerInterfaces.Length);
var requiresRuntimeInterfaceDiscovery = false;
foreach (var handlerInterface in handlerInterfaces)
{
var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface);
var canReferenceHandlerInterface =
CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, handlerInterface);
if (canReferenceImplementation && canReferenceHandlerInterface)
{
registrations.Add(new HandlerRegistrationSpec(
@ -122,16 +124,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
continue;
}
// Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over
// non-public element types. For those rare cases keep the narrow implementation lookup, but let the
// generated registry discover the exact supported interfaces from the implementation type at runtime.
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
implementationLogName,
ImmutableArray<HandlerRegistrationSpec>.Empty,
ImmutableArray<ReflectedImplementationRegistrationSpec>.Empty,
ImmutableArray<PreciseReflectedRegistrationSpec>.Empty,
GetReflectionTypeMetadataName(type));
// 某些关闭 handler interface 仍包含只能在实现类型运行时语义里解析的类型形态。
// 对这些边角场景保留“已知接口静态注册 + 剩余接口运行时补洞”的组合路径,
// 避免单个未知接口把同实现上的其它已知注册全部拖回整实现反射发现。
requiresRuntimeInterfaceDiscovery = true;
}
return new HandlerCandidateAnalysis(
@ -140,7 +136,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
registrations.ToImmutable(),
reflectedImplementationRegistrations.ToImmutable(),
preciseReflectedRegistrations.ToImmutable(),
canReferenceImplementation ? null : GetReflectionTypeMetadataName(type));
canReferenceImplementation ? null : GetReflectionTypeMetadataName(type),
requiresRuntimeInterfaceDiscovery);
}
private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment,
@ -184,7 +181,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
candidate.Registrations,
candidate.ReflectedImplementationRegistrations,
candidate.PreciseReflectedRegistrations,
candidate.ReflectionTypeMetadataName));
candidate.ReflectionTypeMetadataName,
candidate.RequiresRuntimeInterfaceDiscovery));
}
registrations.Sort(static (left, right) =>
@ -295,7 +293,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ITypeSymbol type,
out RuntimeTypeReferenceSpec? runtimeTypeReference)
{
if (CanReferenceFromGeneratedRegistry(type))
if (CanReferenceFromGeneratedRegistry(compilation, type))
{
runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference(
type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
@ -369,7 +367,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
out RuntimeTypeReferenceSpec? genericTypeDefinitionReference)
{
var genericTypeDefinition = genericNamedType.OriginalDefinition;
if (CanReferenceFromGeneratedRegistry(genericTypeDefinition))
if (CanReferenceFromGeneratedRegistry(compilation, genericTypeDefinition))
{
genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference(
genericTypeDefinition
@ -389,43 +387,34 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false;
}
private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type)
private static bool CanReferenceFromGeneratedRegistry(Compilation compilation, ITypeSymbol type)
{
switch (type)
{
case IArrayTypeSymbol arrayType:
return CanReferenceFromGeneratedRegistry(arrayType.ElementType);
return CanReferenceFromGeneratedRegistry(compilation, arrayType.ElementType);
case INamedTypeSymbol namedType:
if (!IsTypeChainAccessible(namedType))
if (!compilation.IsSymbolAccessibleWithin(namedType, compilation.Assembly, throughType: null))
return false;
return namedType.TypeArguments.All(CanReferenceFromGeneratedRegistry);
foreach (var typeArgument in namedType.TypeArguments)
{
if (!CanReferenceFromGeneratedRegistry(compilation, typeArgument))
return false;
}
return true;
case IPointerTypeSymbol pointerType:
return CanReferenceFromGeneratedRegistry(pointerType.PointedAtType);
return CanReferenceFromGeneratedRegistry(compilation, pointerType.PointedAtType);
case ITypeParameterSymbol:
return false;
default:
// Treat other Roslyn type kinds, such as dynamic or unresolved error types, as referenceable for now.
// If a real-world case proves unsafe, tighten this branch instead of broadening the named-type path above.
return true;
}
}
private static bool IsTypeChainAccessible(INamedTypeSymbol type)
{
for (var current = type; current is not null; current = current.ContainingType)
{
if (!IsSymbolAccessible(current))
return false;
}
return true;
}
private static bool IsSymbolAccessible(ISymbol symbol)
{
return symbol.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal
or Accessibility.ProtectedOrInternal;
}
private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type)
{
var nestedTypes = new Stack<string>();
@ -496,10 +485,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
var hasPreciseReflectedRegistrations = registrations.Any(static registration =>
!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var hasFullReflectionRegistrations = registrations.Any(static registration =>
!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var hasRuntimeInterfaceDiscovery = registrations.Any(static registration =>
registration.RequiresRuntimeInterfaceDiscovery);
var builder = new StringBuilder();
builder.AppendLine("// <auto-generated />");
builder.AppendLine("#nullable enable");
@ -533,7 +520,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" if (logger is null)");
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));");
if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations ||
hasFullReflectionRegistrations)
registrations.Any(static registration =>
!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)))
{
builder.AppendLine();
builder.Append(" var registryAssembly = typeof(global::");
@ -549,46 +537,21 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++)
{
var registration = registrations[registrationIndex];
if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty)
if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty ||
!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty ||
registration.RequiresRuntimeInterfaceDiscovery)
{
AppendReflectedImplementationRegistrations(builder, registration, registrationIndex);
continue;
AppendOrderedImplementationRegistrations(builder, registration, registrationIndex);
}
if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
else if (!registration.DirectRegistrations.IsDefaultOrEmpty)
{
AppendPreciseReflectedRegistrations(builder, registration, registrationIndex);
continue;
}
if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
continue;
}
foreach (var directRegistration in registration.DirectRegistrations)
{
builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
builder.Append(" typeof(");
builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
builder.Append(" typeof(");
builder.Append(directRegistration.ImplementationTypeDisplayName);
builder.AppendLine("));");
builder.Append(" logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
builder.Append(" as ");
builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
AppendDirectRegistrations(builder, registration);
}
}
builder.AppendLine(" }");
if (hasFullReflectionRegistrations)
if (hasRuntimeInterfaceDiscovery)
{
builder.AppendLine();
AppendReflectionHelpers(builder);
@ -598,56 +561,73 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return builder.ToString();
}
private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName)
{
builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \"");
builder.Append(EscapeStringLiteral(reflectionTypeMetadataName));
builder.AppendLine("\");");
}
private static void AppendReflectedImplementationRegistrations(
private static void AppendDirectRegistrations(
StringBuilder builder,
ImplementationRegistrationSpec registration,
int registrationIndex)
ImplementationRegistrationSpec registration)
{
var implementationVariableName = $"implementationType{registrationIndex}";
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = registryAssembly.GetType(\"");
builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
builder.Append(" if (");
builder.Append(implementationVariableName);
builder.AppendLine(" is not null)");
builder.AppendLine(" {");
foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations)
foreach (var directRegistration in registration.DirectRegistrations)
{
builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
builder.Append(" typeof(");
builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
builder.Append(" typeof(");
builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
builder.Append(" ");
builder.Append(implementationVariableName);
builder.AppendLine(");");
builder.Append(" logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
builder.Append(" typeof(");
builder.Append(directRegistration.ImplementationTypeDisplayName);
builder.AppendLine("));");
builder.Append(" logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
builder.Append(" as ");
builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
}
builder.AppendLine(" }");
}
private static void AppendPreciseReflectedRegistrations(
private static void AppendOrderedImplementationRegistrations(
StringBuilder builder,
ImplementationRegistrationSpec registration,
int registrationIndex)
{
var orderedRegistrations =
new List<(string HandlerInterfaceLogName, OrderedRegistrationKind Kind, int Index)>(
registration.DirectRegistrations.Length +
registration.ReflectedImplementationRegistrations.Length +
registration.PreciseReflectedRegistrations.Length);
for (var directIndex = 0; directIndex < registration.DirectRegistrations.Length; directIndex++)
{
orderedRegistrations.Add((
registration.DirectRegistrations[directIndex].HandlerInterfaceLogName,
OrderedRegistrationKind.Direct,
directIndex));
}
for (var reflectedIndex = 0;
reflectedIndex < registration.ReflectedImplementationRegistrations.Length;
reflectedIndex++)
{
orderedRegistrations.Add((
registration.ReflectedImplementationRegistrations[reflectedIndex].HandlerInterfaceLogName,
OrderedRegistrationKind.ReflectedImplementation,
reflectedIndex));
}
for (var preciseIndex = 0;
preciseIndex < registration.PreciseReflectedRegistrations.Length;
preciseIndex++)
{
orderedRegistrations.Add((
registration.PreciseReflectedRegistrations[preciseIndex].HandlerInterfaceLogName,
OrderedRegistrationKind.PreciseReflected,
preciseIndex));
}
orderedRegistrations.Sort(static (left, right) =>
StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName));
var implementationVariableName = $"implementationType{registrationIndex}";
var knownServiceTypesVariableName = $"knownServiceTypes{registrationIndex}";
if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
builder.Append(" var ");
@ -658,11 +638,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
else
{
var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!;
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = registryAssembly.GetType(\"");
builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName));
builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
}
@ -671,21 +650,98 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" is not null)");
builder.AppendLine(" {");
for (var registrationOffset = 0;
registrationOffset < registration.PreciseReflectedRegistrations.Length;
registrationOffset++)
if (registration.RequiresRuntimeInterfaceDiscovery)
{
var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset];
var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}";
AppendPreciseReflectedTypeResolution(
builder,
reflectedRegistration.ServiceTypeArguments,
registrationVariablePrefix,
implementationVariableName,
reflectedRegistration.OpenHandlerTypeDisplayName,
registration.ImplementationLogName,
reflectedRegistration.HandlerInterfaceLogName,
3);
builder.Append(" var ");
builder.Append(knownServiceTypesVariableName);
builder.AppendLine(" = new global::System.Collections.Generic.HashSet<global::System.Type>();");
}
foreach (var orderedRegistration in orderedRegistrations)
{
switch (orderedRegistration.Kind)
{
case OrderedRegistrationKind.Direct:
var directRegistration = registration.DirectRegistrations[orderedRegistration.Index];
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" ");
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(typeof(");
builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("));");
}
builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
builder.Append(" typeof(");
builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
builder.Append(" ");
builder.Append(implementationVariableName);
builder.AppendLine(");");
builder.Append(" logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
builder.Append(" as ");
builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
break;
case OrderedRegistrationKind.ReflectedImplementation:
var reflectedRegistration =
registration.ReflectedImplementationRegistrations[orderedRegistration.Index];
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" ");
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(typeof(");
builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("));");
}
builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
builder.Append(" typeof(");
builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
builder.Append(" ");
builder.Append(implementationVariableName);
builder.AppendLine(");");
builder.Append(" logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
builder.Append(" as ");
builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
break;
case OrderedRegistrationKind.PreciseReflected:
var preciseRegistration = registration.PreciseReflectedRegistrations[orderedRegistration.Index];
var registrationVariablePrefix = $"serviceType{registrationIndex}_{orderedRegistration.Index}";
AppendPreciseReflectedTypeResolution(
builder,
preciseRegistration.ServiceTypeArguments,
registrationVariablePrefix,
implementationVariableName,
preciseRegistration.OpenHandlerTypeDisplayName,
registration.ImplementationLogName,
preciseRegistration.HandlerInterfaceLogName,
knownServiceTypesVariableName,
registration.RequiresRuntimeInterfaceDiscovery,
3);
break;
default:
throw new InvalidOperationException(
$"Unsupported ordered CQRS registration kind {orderedRegistration.Kind}.");
}
}
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" RegisterRemainingReflectedHandlerInterfaces(services, logger, ");
builder.Append(implementationVariableName);
builder.Append(", ");
builder.Append(knownServiceTypesVariableName);
builder.AppendLine(");");
}
builder.AppendLine(" }");
@ -699,6 +755,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string openHandlerTypeDisplayName,
string implementationLogName,
string handlerInterfaceLogName,
string knownServiceTypesVariableName,
bool trackKnownServiceTypes,
int indentLevel)
{
var indent = new string(' ', indentLevel * 4);
@ -763,6 +821,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.Append(" ");
builder.Append(implementationVariableName);
builder.AppendLine(");");
if (trackKnownServiceTypes)
{
builder.Append(indent);
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(");
builder.Append(registrationVariablePrefix);
builder.AppendLine(");");
}
builder.Append(indent);
builder.Append("logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(implementationLogName));
@ -839,15 +906,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static void AppendReflectionHelpers(StringBuilder builder)
{
// Emit the runtime helper methods only when at least one handler requires metadata-name lookup.
// Emit the runtime helper methods only when at least one handler still needs implementation-scoped
// interface discovery after all direct / precise registrations have been emitted.
builder.AppendLine(
" private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)");
" private static void RegisterRemainingReflectedHandlerInterfaces(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Type implementationType, global::System.Collections.Generic.ISet<global::System.Type> knownServiceTypes)");
builder.AppendLine(" {");
builder.AppendLine(
" var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);");
builder.AppendLine(" if (implementationType is null)");
builder.AppendLine(" return;");
builder.AppendLine();
builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();");
builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);");
builder.AppendLine();
@ -856,6 +919,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))");
builder.AppendLine(" continue;");
builder.AppendLine();
builder.AppendLine(" if (knownServiceTypes.Contains(handlerInterface))");
builder.AppendLine(" continue;");
builder.AppendLine();
builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,");
@ -863,6 +929,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" implementationType);");
builder.AppendLine(
" logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");");
builder.AppendLine(" knownServiceTypes.Add(handlerInterface);");
builder.AppendLine(" }");
builder.AppendLine(" }");
builder.AppendLine();
@ -969,6 +1036,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceDisplayName,
string HandlerInterfaceLogName);
private enum OrderedRegistrationKind
{
Direct,
ReflectedImplementation,
PreciseReflected
}
private sealed record RuntimeTypeReferenceSpec(
string? TypeDisplayName,
string? ReflectionTypeMetadataName,
@ -1015,7 +1089,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ImmutableArray<HandlerRegistrationSpec> DirectRegistrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> ReflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> PreciseReflectedRegistrations,
string? ReflectionTypeMetadataName);
string? ReflectionTypeMetadataName,
bool RequiresRuntimeInterfaceDiscovery);
private readonly struct HandlerCandidateAnalysis : IEquatable<HandlerCandidateAnalysis>
{
@ -1025,7 +1100,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ImmutableArray<HandlerRegistrationSpec> registrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> reflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> preciseReflectedRegistrations,
string? reflectionTypeMetadataName)
string? reflectionTypeMetadataName,
bool requiresRuntimeInterfaceDiscovery)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
ImplementationLogName = implementationLogName;
@ -1033,6 +1109,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ReflectedImplementationRegistrations = reflectedImplementationRegistrations;
PreciseReflectedRegistrations = preciseReflectedRegistrations;
ReflectionTypeMetadataName = reflectionTypeMetadataName;
RequiresRuntimeInterfaceDiscovery = requiresRuntimeInterfaceDiscovery;
}
public string ImplementationTypeDisplayName { get; }
@ -1047,6 +1124,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public string? ReflectionTypeMetadataName { get; }
public bool RequiresRuntimeInterfaceDiscovery { get; }
public bool Equals(HandlerCandidateAnalysis other)
{
if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
@ -1054,6 +1133,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) ||
!string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName,
StringComparison.Ordinal) ||
RequiresRuntimeInterfaceDiscovery != other.RequiresRuntimeInterfaceDiscovery ||
Registrations.Length != other.Registrations.Length ||
ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length ||
PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length)
@ -1098,6 +1178,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
(ReflectionTypeMetadataName is null
? 0
: StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName));
hashCode = (hashCode * 397) ^ RequiresRuntimeInterfaceDiscovery.GetHashCode();
foreach (var registration in Registrations)
{
hashCode = (hashCode * 397) ^ registration.GetHashCode();