From 391e3e98138d5023192857c37eca201a5fee78b3 Mon Sep 17 00:00:00 2001
From: GeWuYou <95328647+GeWuYou@users.noreply.github.com>
Date: Thu, 16 Apr 2026 11:11:29 +0800
Subject: [PATCH 1/4] =?UTF-8?q?feat(cqrs):=20=E6=B7=BB=E5=8A=A0CQRS?=
=?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=86=8C?=
=?UTF-8?q?=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 实现CqrsHandlerRegistrar类,支持扫描并注册CQRS请求/通知/流式处理器
- 添加源码生成注册器优先策略,减少冷启动时的反射开销
- 实现运行时反射扫描回退机制,确保处理器注册的完整性
- 添加CqrsReflectionFallbackAttribute特性,标记需要运行时补充扫描的程序集
- 创建完整的单元测试套件,验证处理器注册顺序与容错行为
- 实现CqrsHandlerRegistryGenerator源码生成器,自动生成处理器注册代码
- 添加详细的日志记录与诊断功能,便于调试注册过程
- 实现类型安全的处理器映射验证与重复注册检测机制
---
.../Cqrs/CqrsHandlerRegistrarTests.cs | 27 +++-
.../CqrsReflectionFallbackAttribute.cs | 22 +++
.../Internal/CqrsHandlerRegistrar.cs | 152 +++++++++++++++---
.../Cqrs/CqrsHandlerRegistryGeneratorTests.cs | 111 +++++++++++++
.../Cqrs/CqrsHandlerRegistryGenerator.cs | 145 +++++++++++++++--
5 files changed, 414 insertions(+), 43 deletions(-)
diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
index 318f6d21..95afa92f 100644
--- a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
+++ b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
@@ -190,11 +190,11 @@ internal sealed class CqrsHandlerRegistrarTests
}
///
- /// 验证当生成注册器显式要求 reflection fallback 时,运行时会补扫剩余 handlers,
- /// 同时避免把已由生成注册器注册的映射重复写入服务集合。
+ /// 验证当生成注册器提供精确 fallback 类型名时,运行时会定向补扫剩余 handlers,
+ /// 而不是重新枚举整个程序集的类型列表。
///
[Test]
- public void RegisterHandlers_Should_Combine_Generated_Registry_With_Reflection_Fallback_Without_Duplicates()
+ public void RegisterHandlers_Should_Use_Targeted_Type_Lookups_For_Reflection_Fallback_Without_Duplicates()
{
var generatedAssembly = new Mock();
generatedAssembly
@@ -205,14 +205,17 @@ internal sealed class CqrsHandlerRegistrarTests
.Returns([new CqrsHandlerRegistryAttribute(typeof(PartialGeneratedNotificationHandlerRegistry))]);
generatedAssembly
.Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), false))
- .Returns([new CqrsReflectionFallbackAttribute()]);
- generatedAssembly
- .Setup(static assembly => assembly.GetTypes())
.Returns(
[
- typeof(GeneratedRegistryNotificationHandler),
- ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType
+ new CqrsReflectionFallbackAttribute(
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!)
]);
+ generatedAssembly
+ .Setup(static assembly => assembly.GetType(
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!,
+ false,
+ false))
+ .Returns(ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType);
var container = new MicrosoftDiContainer();
CqrsTestRuntime.RegisterHandlers(container, generatedAssembly.Object);
@@ -231,6 +234,14 @@ internal sealed class CqrsHandlerRegistrarTests
typeof(GeneratedRegistryNotificationHandler),
ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType
]));
+
+ generatedAssembly.Verify(
+ static assembly => assembly.GetType(
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!,
+ false,
+ false),
+ Times.Once);
+ generatedAssembly.Verify(static assembly => assembly.GetTypes(), Times.Never);
}
}
diff --git a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
index f18a7344..9d3c21bf 100644
--- a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
+++ b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
@@ -11,4 +11,26 @@ namespace GFramework.Cqrs;
[AttributeUsage(AttributeTargets.Assembly)]
public sealed class CqrsReflectionFallbackAttribute : Attribute
{
+ ///
+ /// 初始化 。
+ ///
+ ///
+ /// 需要运行时补充反射注册的处理器类型全名。
+ /// 当该清单为空时,运行时会回退到整程序集扫描,以兼容旧版 marker 语义。
+ ///
+ public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames)
+ {
+ ArgumentNullException.ThrowIfNull(fallbackHandlerTypeNames);
+
+ FallbackHandlerTypeNames = fallbackHandlerTypeNames
+ .Where(static typeName => !string.IsNullOrWhiteSpace(typeName))
+ .Distinct(StringComparer.Ordinal)
+ .OrderBy(static typeName => typeName, StringComparer.Ordinal)
+ .ToArray();
+ }
+
+ ///
+ /// 获取需要运行时补充反射注册的处理器类型全名集合。
+ ///
+ public IReadOnlyList FallbackHandlerTypeNames { get; }
}
diff --git a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
index 867c6887..968424b7 100644
--- a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
+++ b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
@@ -33,10 +33,14 @@ internal static class CqrsHandlerRegistrar
{
var generatedRegistrationResult =
TryRegisterGeneratedHandlers(container.GetServicesUnsafe, assembly, logger);
- if (generatedRegistrationResult == GeneratedRegistrationResult.FullyHandled)
+ if (generatedRegistrationResult is { UsedGeneratedRegistry: true, RequiresReflectionFallback: false })
continue;
- RegisterAssemblyHandlers(container.GetServicesUnsafe, assembly, logger);
+ RegisterAssemblyHandlers(
+ container.GetServicesUnsafe,
+ assembly,
+ logger,
+ generatedRegistrationResult.ReflectionFallbackTypeNames);
}
}
@@ -66,7 +70,7 @@ internal static class CqrsHandlerRegistrar
.ToList();
if (registryTypes.Count == 0)
- return GeneratedRegistrationResult.NoGeneratedRegistry;
+ return GeneratedRegistrationResult.NoGeneratedRegistry();
var registries = new List(registryTypes.Count);
foreach (var registryType in registryTypes)
@@ -75,21 +79,21 @@ internal static class CqrsHandlerRegistrar
{
logger.Warn(
$"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it does not implement {typeof(ICqrsHandlerRegistry).FullName}.");
- return GeneratedRegistrationResult.NoGeneratedRegistry;
+ return GeneratedRegistrationResult.NoGeneratedRegistry();
}
if (registryType.IsAbstract)
{
logger.Warn(
$"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it is abstract.");
- return GeneratedRegistrationResult.NoGeneratedRegistry;
+ return GeneratedRegistrationResult.NoGeneratedRegistry();
}
if (Activator.CreateInstance(registryType, nonPublic: true) is not ICqrsHandlerRegistry registry)
{
logger.Warn(
$"Ignoring generated CQRS handler registry {registryType.FullName} in assembly {assemblyName} because it could not be instantiated.");
- return GeneratedRegistrationResult.NoGeneratedRegistry;
+ return GeneratedRegistrationResult.NoGeneratedRegistry();
}
registries.Add(registry);
@@ -102,14 +106,24 @@ internal static class CqrsHandlerRegistrar
registry.Register(services, logger);
}
- if (RequiresReflectionFallback(assembly))
+ var reflectionFallbackTypeNames = GetReflectionFallbackTypeNames(assembly);
+ if (reflectionFallbackTypeNames is not null)
{
- logger.Debug(
- $"Generated CQRS registry for assembly {assemblyName} requested reflection fallback for unsupported handlers.");
- return GeneratedRegistrationResult.RequiresReflectionFallback;
+ if (reflectionFallbackTypeNames.Count > 0)
+ {
+ logger.Debug(
+ $"Generated CQRS registry for assembly {assemblyName} requested targeted reflection fallback for {reflectionFallbackTypeNames.Count} unsupported handler type(s).");
+ }
+ else
+ {
+ logger.Debug(
+ $"Generated CQRS registry for assembly {assemblyName} requested full reflection fallback for unsupported handlers.");
+ }
+
+ return GeneratedRegistrationResult.WithReflectionFallback(reflectionFallbackTypeNames);
}
- return GeneratedRegistrationResult.FullyHandled;
+ return GeneratedRegistrationResult.FullyHandled();
}
catch (Exception exception)
{
@@ -117,16 +131,21 @@ internal static class CqrsHandlerRegistrar
$"Generated CQRS handler registry discovery failed for assembly {assemblyName}. Falling back to reflection scan.");
logger.Warn(
$"Failed to use generated CQRS handler registry for assembly {assemblyName}: {exception.Message}");
- return GeneratedRegistrationResult.NoGeneratedRegistry;
+ return GeneratedRegistrationResult.NoGeneratedRegistry();
}
}
///
/// 注册单个程序集里的所有 CQRS 处理器映射。
///
- private static void RegisterAssemblyHandlers(IServiceCollection services, Assembly assembly, ILogger logger)
+ private static void RegisterAssemblyHandlers(
+ IServiceCollection services,
+ Assembly assembly,
+ ILogger logger,
+ IReadOnlyList? reflectionFallbackTypeNames)
{
- foreach (var implementationType in GetLoadableTypes(assembly, logger).Where(IsConcreteHandlerType))
+ foreach (var implementationType in GetCandidateHandlerTypes(assembly, logger, reflectionFallbackTypeNames)
+ .Where(IsConcreteHandlerType))
{
var handlerInterfaces = implementationType
.GetInterfaces()
@@ -155,6 +174,58 @@ internal static class CqrsHandlerRegistrar
}
}
+ ///
+ /// 根据生成器提供的 fallback 清单或整程序集扫描结果,获取本轮要注册的候选处理器类型。
+ ///
+ private static IReadOnlyList GetCandidateHandlerTypes(
+ Assembly assembly,
+ ILogger logger,
+ IReadOnlyList? reflectionFallbackTypeNames)
+ {
+ return reflectionFallbackTypeNames is { Count: > 0 }
+ ? GetNamedFallbackTypes(assembly, reflectionFallbackTypeNames, logger)
+ : GetLoadableTypes(assembly, logger);
+ }
+
+ ///
+ /// 根据生成器记录的类型全名,精确解析仍需运行时补充注册的处理器类型。
+ ///
+ private static IReadOnlyList GetNamedFallbackTypes(
+ Assembly assembly,
+ IReadOnlyList reflectionFallbackTypeNames,
+ ILogger logger)
+ {
+ var assemblyName = GetAssemblySortKey(assembly);
+ var resolvedTypes = new List(reflectionFallbackTypeNames.Count);
+ foreach (var typeName in reflectionFallbackTypeNames
+ .Where(static name => !string.IsNullOrWhiteSpace(name))
+ .Distinct(StringComparer.Ordinal)
+ .OrderBy(static name => name, StringComparer.Ordinal))
+ {
+ try
+ {
+ var type = assembly.GetType(typeName, throwOnError: false, ignoreCase: false);
+ if (type is null)
+ {
+ logger.Warn(
+ $"Generated CQRS reflection fallback type {typeName} could not be resolved in assembly {assemblyName}. Skipping targeted fallback entry.");
+ continue;
+ }
+
+ resolvedTypes.Add(type);
+ }
+ catch (Exception exception)
+ {
+ logger.Warn(
+ $"Generated CQRS reflection fallback type {typeName} failed to load in assembly {assemblyName}: {exception.Message}");
+ }
+ }
+
+ return resolvedTypes
+ .OrderBy(GetTypeSortKey, StringComparer.Ordinal)
+ .ToList();
+ }
+
///
/// 安全获取程序集中的可加载类型,并在部分类型加载失败时保留其余处理器注册能力。
///
@@ -221,11 +292,24 @@ internal static class CqrsHandlerRegistrar
}
///
- /// 判断生成注册器是否要求运行时继续补充反射扫描。
+ /// 获取生成注册器要求运行时继续补充反射扫描的 handler 类型名清单。
///
- private static bool RequiresReflectionFallback(Assembly assembly)
+ private static IReadOnlyList? GetReflectionFallbackTypeNames(Assembly assembly)
{
- return assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)?.Length > 0;
+ var fallbackAttributes = assembly
+ .GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)
+ .OfType()
+ .ToList();
+
+ if (fallbackAttributes.Count == 0)
+ return null;
+
+ return fallbackAttributes
+ .SelectMany(static attribute => attribute.FallbackHandlerTypeNames)
+ .Where(static typeName => !string.IsNullOrWhiteSpace(typeName))
+ .Distinct(StringComparer.Ordinal)
+ .OrderBy(static typeName => typeName, StringComparer.Ordinal)
+ .ToArray();
}
///
@@ -259,10 +343,36 @@ internal static class CqrsHandlerRegistrar
return type.FullName ?? type.Name;
}
- private enum GeneratedRegistrationResult
+ private readonly record struct GeneratedRegistrationResult(
+ bool UsedGeneratedRegistry,
+ bool RequiresReflectionFallback,
+ IReadOnlyList? ReflectionFallbackTypeNames)
{
- NoGeneratedRegistry,
- FullyHandled,
- RequiresReflectionFallback
+ public static GeneratedRegistrationResult NoGeneratedRegistry()
+ {
+ return new GeneratedRegistrationResult(
+ UsedGeneratedRegistry: false,
+ RequiresReflectionFallback: false,
+ ReflectionFallbackTypeNames: null);
+ }
+
+ public static GeneratedRegistrationResult FullyHandled()
+ {
+ return new GeneratedRegistrationResult(
+ UsedGeneratedRegistry: true,
+ RequiresReflectionFallback: false,
+ ReflectionFallbackTypeNames: null);
+ }
+
+ public static GeneratedRegistrationResult WithReflectionFallback(
+ IReadOnlyList reflectionFallbackTypeNames)
+ {
+ ArgumentNullException.ThrowIfNull(reflectionFallbackTypeNames);
+
+ return new GeneratedRegistrationResult(
+ UsedGeneratedRegistry: true,
+ RequiresReflectionFallback: true,
+ ReflectionFallbackTypeNames: reflectionFallbackTypeNames);
+ }
}
}
diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index 0cd91844..e1ec1546 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -65,6 +65,7 @@ public class CqrsHandlerRegistryGeneratorTests
[AttributeUsage(AttributeTargets.Assembly)]
public sealed class CqrsReflectionFallbackAttribute : Attribute
{
+ public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames) { }
}
}
@@ -180,6 +181,116 @@ public class CqrsHandlerRegistryGeneratorTests
[AttributeUsage(AttributeTargets.Assembly)]
public sealed class CqrsReflectionFallbackAttribute : Attribute
{
+ public CqrsReflectionFallbackAttribute(params string[] fallbackHandlerTypeNames) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed record HiddenRequest() : IRequest;
+
+ private sealed class HiddenHandler : IRequestHandler { }
+ }
+
+ public sealed class VisibleHandler : IRequestHandler { }
+ }
+ """;
+
+ const string expected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+ [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute("TestApp.Container+HiddenHandler")]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ typeof(global::TestApp.VisibleHandler));
+ logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", expected));
+ }
+
+ ///
+ /// 验证当 runtime 仅支持旧版无参 fallback marker 时,生成器会退回旧语义,
+ /// 只输出 marker 而不输出精确类型名。
+ ///
+ [Test]
+ public async Task Generates_Legacy_Fallback_Marker_When_Runtime_Does_Not_Support_Type_Name_List()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly)]
+ public sealed class CqrsReflectionFallbackAttribute : Attribute
+ {
+ public CqrsReflectionFallbackAttribute() { }
}
}
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 80561248..1e260e32 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -59,7 +59,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return new GenerationEnvironment(
generationEnabled,
- compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName) is not null);
+ GetReflectionFallbackEmissionMode(
+ compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName)));
}
private static bool IsHandlerCandidate(SyntaxNode node)
@@ -96,7 +97,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
ImmutableArray.Empty,
- true);
+ true,
+ GetReflectionFallbackTypeName(type));
}
var implementationLogName = GetLogDisplayName(type);
@@ -113,7 +115,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
registrations.MoveToImmutable(),
- false);
+ false,
+ null);
}
private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment,
@@ -122,27 +125,37 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
if (!generationEnvironment.GenerationEnabled)
return;
- var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler);
+ var registrations = CollectRegistrations(
+ candidates,
+ out var hasUnsupportedConcreteHandler,
+ out var reflectionFallbackTypeNames);
if (registrations.Count == 0)
return;
// If the runtime contract does not yet expose the reflection fallback marker,
// keep the previous all-or-nothing behavior so unsupported handlers are not silently dropped.
- if (hasUnsupportedConcreteHandler && !generationEnvironment.SupportsReflectionFallbackMarker)
+ if (hasUnsupportedConcreteHandler &&
+ generationEnvironment.ReflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.Disabled)
return;
context.AddSource(
HintName,
- GenerateSource(registrations, hasUnsupportedConcreteHandler));
+ GenerateSource(
+ registrations,
+ hasUnsupportedConcreteHandler,
+ generationEnvironment.ReflectionFallbackEmissionMode,
+ reflectionFallbackTypeNames));
}
private static List CollectRegistrations(
ImmutableArray candidates,
- out bool hasUnsupportedConcreteHandler)
+ out bool hasUnsupportedConcreteHandler,
+ out IReadOnlyList reflectionFallbackTypeNames)
{
var registrations = new List();
hasUnsupportedConcreteHandler = false;
+ var fallbackTypeNames = new SortedSet(StringComparer.Ordinal);
// Partial declarations surface the same symbol through multiple syntax nodes.
// Collapse them by implementation type so generated registrations stay stable and duplicate-free.
@@ -156,6 +169,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
if (candidate.Value.HasUnsupportedConcreteHandler)
{
hasUnsupportedConcreteHandler = true;
+ var reflectionFallbackTypeName = candidate.Value.ReflectionFallbackTypeName;
+ if (reflectionFallbackTypeName is not null &&
+ !string.IsNullOrWhiteSpace(reflectionFallbackTypeName))
+ {
+ fallbackTypeNames.Add(reflectionFallbackTypeName);
+ }
+
continue;
}
@@ -178,9 +198,30 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
: StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName);
});
+ reflectionFallbackTypeNames = fallbackTypeNames.ToArray();
return registrations;
}
+ private static ReflectionFallbackEmissionMode GetReflectionFallbackEmissionMode(INamedTypeSymbol? attributeType)
+ {
+ if (attributeType is null)
+ return ReflectionFallbackEmissionMode.Disabled;
+
+ foreach (var constructor in attributeType.InstanceConstructors)
+ {
+ if (constructor.Parameters.Length != 1)
+ continue;
+
+ if (constructor.Parameters[0].Type is IArrayTypeSymbol arrayType &&
+ arrayType.ElementType.SpecialType == SpecialType.System_String)
+ {
+ return ReflectionFallbackEmissionMode.PreciseTypeNames;
+ }
+ }
+
+ return ReflectionFallbackEmissionMode.MarkerOnly;
+ }
+
private static bool IsConcreteHandlerType(INamedTypeSymbol type)
{
return type.TypeKind is TypeKind.Class or TypeKind.Struct &&
@@ -272,6 +313,34 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return builder.ToString();
}
+ private static string GetReflectionFallbackTypeName(INamedTypeSymbol type)
+ {
+ var nestedTypes = new Stack();
+ for (var current = type; current is not null; current = current.ContainingType)
+ {
+ nestedTypes.Push(current.MetadataName);
+ }
+
+ var builder = new StringBuilder();
+ if (!type.ContainingNamespace.IsGlobalNamespace)
+ {
+ builder.Append(type.ContainingNamespace.ToDisplayString());
+ builder.Append('.');
+ }
+
+ var isFirstType = true;
+ while (nestedTypes.Count > 0)
+ {
+ if (!isFirstType)
+ builder.Append('+');
+
+ builder.Append(nestedTypes.Pop());
+ isFirstType = false;
+ }
+
+ return builder.ToString();
+ }
+
private static string GetTypeSortKey(ITypeSymbol type)
{
return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
@@ -284,7 +353,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static string GenerateSource(
IReadOnlyList registrations,
- bool emitReflectionFallbackAttribute)
+ bool emitReflectionFallbackAttribute,
+ ReflectionFallbackEmissionMode reflectionFallbackEmissionMode,
+ IReadOnlyList reflectionFallbackTypeNames)
{
var builder = new StringBuilder();
builder.AppendLine("// ");
@@ -297,11 +368,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.Append('.');
builder.Append(GeneratedTypeName);
builder.AppendLine("))]");
- if (emitReflectionFallbackAttribute)
+ if (emitReflectionFallbackAttribute &&
+ reflectionFallbackEmissionMode != ReflectionFallbackEmissionMode.Disabled)
{
- builder.Append("[assembly: global::");
- builder.Append(CqrsRuntimeNamespace);
- builder.AppendLine(".CqrsReflectionFallbackAttribute()]");
+ AppendReflectionFallbackAttribute(builder, reflectionFallbackEmissionMode, reflectionFallbackTypeNames);
}
builder.AppendLine();
@@ -349,6 +419,36 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return builder.ToString();
}
+ private static void AppendReflectionFallbackAttribute(
+ StringBuilder builder,
+ ReflectionFallbackEmissionMode reflectionFallbackEmissionMode,
+ IReadOnlyList reflectionFallbackTypeNames)
+ {
+ builder.Append("[assembly: global::");
+ builder.Append(CqrsRuntimeNamespace);
+ builder.Append(".CqrsReflectionFallbackAttribute");
+
+ if (reflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.PreciseTypeNames &&
+ reflectionFallbackTypeNames.Count > 0)
+ {
+ builder.Append('(');
+ for (var index = 0; index < reflectionFallbackTypeNames.Count; index++)
+ {
+ if (index > 0)
+ builder.Append(", ");
+
+ builder.Append('"');
+ builder.Append(EscapeStringLiteral(reflectionFallbackTypeNames[index]));
+ builder.Append('"');
+ }
+
+ builder.AppendLine(")]");
+ return;
+ }
+
+ builder.AppendLine("()]");
+ }
+
private static string EscapeStringLiteral(string value)
{
return value.Replace("\\", "\\\\")
@@ -368,11 +468,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public HandlerCandidateAnalysis(
string implementationTypeDisplayName,
ImmutableArray registrations,
- bool hasUnsupportedConcreteHandler)
+ bool hasUnsupportedConcreteHandler,
+ string? reflectionFallbackTypeName)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
Registrations = registrations;
HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler;
+ ReflectionFallbackTypeName = reflectionFallbackTypeName;
}
public string ImplementationTypeDisplayName { get; }
@@ -381,11 +483,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public bool HasUnsupportedConcreteHandler { get; }
+ public string? ReflectionFallbackTypeName { get; }
+
public bool Equals(HandlerCandidateAnalysis other)
{
if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
StringComparison.Ordinal) ||
HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler ||
+ !string.Equals(ReflectionFallbackTypeName, other.ReflectionFallbackTypeName,
+ StringComparison.Ordinal) ||
Registrations.Length != other.Registrations.Length)
{
return false;
@@ -411,6 +517,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
{
var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName);
hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode();
+ hashCode = (hashCode * 397) ^
+ (ReflectionFallbackTypeName is null
+ ? 0
+ : StringComparer.Ordinal.GetHashCode(ReflectionFallbackTypeName));
foreach (var registration in Registrations)
{
hashCode = (hashCode * 397) ^ registration.GetHashCode();
@@ -423,5 +533,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private readonly record struct GenerationEnvironment(
bool GenerationEnabled,
- bool SupportsReflectionFallbackMarker);
+ ReflectionFallbackEmissionMode ReflectionFallbackEmissionMode);
+
+ private enum ReflectionFallbackEmissionMode
+ {
+ Disabled,
+ MarkerOnly,
+ PreciseTypeNames
+ }
}
From 06f95db5933a82d173c6f9930db5dd20e3746baa Mon Sep 17 00:00:00 2001
From: GeWuYou <95328647+GeWuYou@users.noreply.github.com>
Date: Thu, 16 Apr 2026 11:36:31 +0800
Subject: [PATCH 2/4] =?UTF-8?q?feat(cqrs):=20=E6=B7=BB=E5=8A=A0CQRS?=
=?UTF-8?q?=E8=B0=83=E5=BA=A6=E5=99=A8=E5=AE=9E=E7=8E=B0=E5=92=8C=E6=94=B9?=
=?UTF-8?q?=E8=BF=9B=E5=A4=84=E7=90=86=E5=99=A8=E6=B3=A8=E5=86=8C=E6=9C=BA?=
=?UTF-8?q?=E5=88=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 实现GFramework自有CQRS运行时分发器,支持请求/通知/流式请求处理
- 添加进程级缓存机制优化反射调用性能,包括请求、通知、流水线调用委托缓存
- 重构CqrsHandlerRegistrar使用ReflectionFallbackMetadata替代字符串类型名
- 引入CqrsReflectionFallbackAttribute支持运行时补充反射扫描的处理器类型
- 添加完整的CQRS处理器注册单元测试,验证有序执行和容错行为
- 修复MicrosoftDiContainer中异常消息的格式化空白问题
- 实现上下文感知处理器的CQRS分发上下文注入功能
---
GFramework.Core/Ioc/MicrosoftDiContainer.cs | 2 +-
.../Cqrs/CqrsDispatcherCacheTests.cs | 172 ++++++++++++++++++
.../Cqrs/CqrsHandlerRegistrarTests.cs | 49 +++++
.../CqrsReflectionFallbackAttribute.cs | 35 ++++
GFramework.Cqrs/Internal/CqrsDispatcher.cs | 34 +++-
.../Internal/CqrsHandlerRegistrar.cs | 107 ++++++-----
6 files changed, 345 insertions(+), 54 deletions(-)
create mode 100644 GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
diff --git a/GFramework.Core/Ioc/MicrosoftDiContainer.cs b/GFramework.Core/Ioc/MicrosoftDiContainer.cs
index 6152366f..d1a0576d 100644
--- a/GFramework.Core/Ioc/MicrosoftDiContainer.cs
+++ b/GFramework.Core/Ioc/MicrosoftDiContainer.cs
@@ -400,7 +400,7 @@ public class MicrosoftDiContainer(IServiceCollection? serviceCollection = null)
/// 要接入的程序集集合。
/// 为 。
/// 中存在 元素。
- /// 容器已冻结,无法继续注册 CQRS 处理器。
+ /// 容器已冻结,无法继续注册 CQRS 处理器。
public void RegisterCqrsHandlersFromAssemblies(IEnumerable assemblies)
{
ArgumentNullException.ThrowIfNull(assemblies);
diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
new file mode 100644
index 00000000..6c52a910
--- /dev/null
+++ b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
@@ -0,0 +1,172 @@
+using GFramework.Core.Abstractions.Logging;
+using GFramework.Core.Architectures;
+using GFramework.Core.Ioc;
+using GFramework.Core.Logging;
+using GFramework.Cqrs.Abstractions.Cqrs;
+
+namespace GFramework.Cqrs.Tests.Cqrs;
+
+///
+/// 验证 CQRS dispatcher 会缓存热路径中的服务类型构造结果。
+///
+[TestFixture]
+internal sealed class CqrsDispatcherCacheTests
+{
+ private MicrosoftDiContainer? _container;
+ private ArchitectureContext? _context;
+
+ ///
+ /// 初始化测试上下文。
+ ///
+ [SetUp]
+ public void SetUp()
+ {
+ LoggerFactoryResolver.Provider = new ConsoleLoggerFactoryProvider();
+ _container = new MicrosoftDiContainer();
+
+ CqrsTestRuntime.RegisterHandlers(
+ _container,
+ typeof(CqrsDispatcherCacheTests).Assembly,
+ typeof(ArchitectureContext).Assembly);
+
+ _container.Freeze();
+ _context = new ArchitectureContext(_container);
+ }
+
+ ///
+ /// 清理测试上下文引用。
+ ///
+ [TearDown]
+ public void TearDown()
+ {
+ _context = null;
+ _container = null;
+ }
+
+ ///
+ /// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。
+ ///
+ [Test]
+ public async Task Dispatcher_Should_Cache_Service_Types_After_First_Dispatch()
+ {
+ var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes");
+ var requestServiceTypes = GetCacheField("RequestServiceTypes");
+ var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes");
+
+ var notificationBefore = notificationServiceTypes.Count;
+ var requestBefore = requestServiceTypes.Count;
+ var streamBefore = streamServiceTypes.Count;
+
+ await _context!.SendRequestAsync(new DispatcherCacheRequest());
+ await _context.PublishAsync(new DispatcherCacheNotification());
+ await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest()));
+
+ var notificationAfterFirstDispatch = notificationServiceTypes.Count;
+ var requestAfterFirstDispatch = requestServiceTypes.Count;
+ var streamAfterFirstDispatch = streamServiceTypes.Count;
+
+ await _context.SendRequestAsync(new DispatcherCacheRequest());
+ await _context.PublishAsync(new DispatcherCacheNotification());
+ await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest()));
+
+ Assert.Multiple(() =>
+ {
+ Assert.That(notificationAfterFirstDispatch, Is.EqualTo(notificationBefore + 1));
+ Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 1));
+ Assert.That(streamAfterFirstDispatch, Is.EqualTo(streamBefore + 1));
+
+ Assert.That(notificationServiceTypes.Count, Is.EqualTo(notificationAfterFirstDispatch));
+ Assert.That(requestServiceTypes.Count, Is.EqualTo(requestAfterFirstDispatch));
+ Assert.That(streamServiceTypes.Count, Is.EqualTo(streamAfterFirstDispatch));
+ });
+ }
+
+ ///
+ /// 通过反射读取 dispatcher 的静态缓存字典。
+ ///
+ private static IDictionary GetCacheField(string fieldName)
+ {
+ var dispatcherType = typeof(CqrsReflectionFallbackAttribute).Assembly
+ .GetType("GFramework.Cqrs.Internal.CqrsDispatcher", throwOnError: true)!;
+
+ var field = dispatcherType.GetField(
+ fieldName,
+ BindingFlags.NonPublic | BindingFlags.Static);
+
+ Assert.That(field, Is.Not.Null, $"Missing dispatcher cache field {fieldName}.");
+
+ return field!.GetValue(null) as IDictionary
+ ?? throw new InvalidOperationException(
+ $"Dispatcher cache field {fieldName} does not implement IDictionary.");
+ }
+
+ ///
+ /// 消费整个异步流,确保建流路径被真实执行。
+ ///
+ private static async Task DrainAsync(IAsyncEnumerable stream)
+ {
+ await foreach (var _ in stream)
+ {
+ }
+ }
+}
+
+///
+/// 用于验证 request 服务类型缓存的测试请求。
+///
+internal sealed record DispatcherCacheRequest : IRequest;
+
+///
+/// 用于验证 notification 服务类型缓存的测试通知。
+///
+internal sealed record DispatcherCacheNotification : INotification;
+
+///
+/// 用于验证 stream 服务类型缓存的测试请求。
+///
+internal sealed record DispatcherCacheStreamRequest : IStreamRequest;
+
+///
+/// 处理 。
+///
+internal sealed class DispatcherCacheRequestHandler : IRequestHandler
+{
+ ///
+ /// 返回固定结果,供缓存测试验证 dispatcher 请求路径。
+ ///
+ public ValueTask Handle(DispatcherCacheRequest request, CancellationToken cancellationToken)
+ {
+ return ValueTask.FromResult(1);
+ }
+}
+
+///
+/// 处理 。
+///
+internal sealed class DispatcherCacheNotificationHandler : INotificationHandler
+{
+ ///
+ /// 消费通知,不执行额外副作用。
+ ///
+ public ValueTask Handle(DispatcherCacheNotification notification, CancellationToken cancellationToken)
+ {
+ return ValueTask.CompletedTask;
+ }
+}
+
+///
+/// 处理 。
+///
+internal sealed class DispatcherCacheStreamHandler : IStreamRequestHandler
+{
+ ///
+ /// 返回一个最小流,供缓存测试命中 stream 分发路径。
+ ///
+ public async IAsyncEnumerable Handle(
+ DispatcherCacheStreamRequest request,
+ [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ yield return 1;
+ await Task.CompletedTask;
+ }
+}
diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
index 95afa92f..b44b0bb1 100644
--- a/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
+++ b/GFramework.Cqrs.Tests/Cqrs/CqrsHandlerRegistrarTests.cs
@@ -243,6 +243,55 @@ internal sealed class CqrsHandlerRegistrarTests
Times.Once);
generatedAssembly.Verify(static assembly => assembly.GetTypes(), Times.Never);
}
+
+ ///
+ /// 验证手写 fallback metadata 直接提供 handler 类型时,运行时会复用这些类型,
+ /// 而不会再通过程序集名称查找或整程序集扫描补齐映射。
+ ///
+ [Test]
+ public void RegisterHandlers_Should_Use_Direct_Fallback_Types_Without_GetType_Or_GetTypes()
+ {
+ var generatedAssembly = new Mock();
+ generatedAssembly
+ .SetupGet(static assembly => assembly.FullName)
+ .Returns(ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.Assembly.FullName);
+ generatedAssembly
+ .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsHandlerRegistryAttribute), false))
+ .Returns([new CqrsHandlerRegistryAttribute(typeof(PartialGeneratedNotificationHandlerRegistry))]);
+ generatedAssembly
+ .Setup(static assembly => assembly.GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), false))
+ .Returns(
+ [
+ new CqrsReflectionFallbackAttribute(
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType)
+ ]);
+
+ var container = new MicrosoftDiContainer();
+ CqrsTestRuntime.RegisterHandlers(container, generatedAssembly.Object);
+
+ var registrations = container.GetServicesUnsafe
+ .Where(static descriptor =>
+ descriptor.ServiceType == typeof(INotificationHandler) &&
+ descriptor.ImplementationType is not null)
+ .Select(static descriptor => descriptor.ImplementationType!)
+ .ToList();
+
+ Assert.That(
+ registrations,
+ Is.EqualTo(
+ [
+ typeof(GeneratedRegistryNotificationHandler),
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType
+ ]));
+
+ generatedAssembly.Verify(
+ static assembly => assembly.GetType(
+ ReflectionFallbackNotificationContainer.ReflectionOnlyHandlerType.FullName!,
+ false,
+ false),
+ Times.Never);
+ generatedAssembly.Verify(static assembly => assembly.GetTypes(), Times.Never);
+ }
}
///
diff --git a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
index 9d3c21bf..da557d84 100644
--- a/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
+++ b/GFramework.Cqrs/CqrsReflectionFallbackAttribute.cs
@@ -11,6 +11,15 @@ namespace GFramework.Cqrs;
[AttributeUsage(AttributeTargets.Assembly)]
public sealed class CqrsReflectionFallbackAttribute : Attribute
{
+ ///
+ /// 初始化 ,保留旧版“仅标记需要补扫”的语义。
+ ///
+ public CqrsReflectionFallbackAttribute()
+ {
+ FallbackHandlerTypeNames = [];
+ FallbackHandlerTypes = [];
+ }
+
///
/// 初始化 。
///
@@ -27,10 +36,36 @@ public sealed class CqrsReflectionFallbackAttribute : Attribute
.Distinct(StringComparer.Ordinal)
.OrderBy(static typeName => typeName, StringComparer.Ordinal)
.ToArray();
+ FallbackHandlerTypes = [];
+ }
+
+ ///
+ /// 初始化 。
+ ///
+ ///
+ /// 需要运行时补充反射注册的处理器类型。
+ /// 该重载适合手写或第三方程序集显式声明可直接引用的 fallback handlers,
+ /// 避免再通过字符串名称回查程序集元数据。
+ ///
+ public CqrsReflectionFallbackAttribute(params Type[] fallbackHandlerTypes)
+ {
+ ArgumentNullException.ThrowIfNull(fallbackHandlerTypes);
+
+ FallbackHandlerTypeNames = [];
+ FallbackHandlerTypes = fallbackHandlerTypes
+ .Where(static type => type is not null)
+ .Distinct()
+ .OrderBy(static type => type.FullName ?? type.Name, StringComparer.Ordinal)
+ .ToArray();
}
///
/// 获取需要运行时补充反射注册的处理器类型全名集合。
///
public IReadOnlyList FallbackHandlerTypeNames { get; }
+
+ ///
+ /// 获取可直接供运行时补充反射注册的处理器类型集合。
+ ///
+ public IReadOnlyList FallbackHandlerTypes { get; }
}
diff --git a/GFramework.Cqrs/Internal/CqrsDispatcher.cs b/GFramework.Cqrs/Internal/CqrsDispatcher.cs
index 9a125789..91532e17 100644
--- a/GFramework.Cqrs/Internal/CqrsDispatcher.cs
+++ b/GFramework.Cqrs/Internal/CqrsDispatcher.cs
@@ -1,5 +1,3 @@
-using System.Collections.Concurrent;
-using System.Reflection;
using GFramework.Core.Abstractions.Architectures;
using GFramework.Core.Abstractions.Ioc;
using GFramework.Core.Abstractions.Logging;
@@ -30,10 +28,22 @@ internal sealed class CqrsDispatcher(
// 进程级缓存:缓存通知调用委托,复用并发安全字典以支撑多线程发布路径。
private static readonly ConcurrentDictionary NotificationInvokers = new();
+ // 进程级缓存:缓存通知处理器服务类型,避免每次发布都重复 MakeGenericType。
+ private static readonly ConcurrentDictionary NotificationHandlerServiceTypes = new();
+
// 进程级缓存:缓存流式请求调用委托,避免每次创建流时重复解析反射签名。
private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), StreamInvoker> StreamInvokers =
new();
+ // 进程级缓存:缓存请求处理器与 pipeline 行为的服务类型,减少热路径中的泛型类型构造。
+ private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), RequestServiceTypeSet>
+ RequestServiceTypes = new();
+
+ // 进程级缓存:缓存流式请求处理器服务类型,避免每次建流时重复 MakeGenericType。
+ private static readonly ConcurrentDictionary<(Type RequestType, Type ResponseType), Type>
+ StreamHandlerServiceTypes =
+ new();
+
///
/// 发布通知到所有已注册处理器。
///
@@ -51,7 +61,9 @@ internal sealed class CqrsDispatcher(
ArgumentNullException.ThrowIfNull(notification);
var notificationType = notification.GetType();
- var handlerType = typeof(INotificationHandler<>).MakeGenericType(notificationType);
+ var handlerType = NotificationHandlerServiceTypes.GetOrAdd(
+ notificationType,
+ static type => typeof(INotificationHandler<>).MakeGenericType(type));
var handlers = container.GetAll(handlerType);
if (handlers.Count == 0)
@@ -88,14 +100,18 @@ internal sealed class CqrsDispatcher(
ArgumentNullException.ThrowIfNull(request);
var requestType = request.GetType();
- var handlerType = typeof(IRequestHandler<,>).MakeGenericType(requestType, typeof(TResponse));
+ var serviceTypes = RequestServiceTypes.GetOrAdd(
+ (requestType, typeof(TResponse)),
+ static key => new RequestServiceTypeSet(
+ typeof(IRequestHandler<,>).MakeGenericType(key.RequestType, key.ResponseType),
+ typeof(IPipelineBehavior<,>).MakeGenericType(key.RequestType, key.ResponseType)));
+ var handlerType = serviceTypes.HandlerType;
var handler = container.Get(handlerType)
?? throw new InvalidOperationException(
$"No CQRS request handler registered for {requestType.FullName}.");
PrepareHandler(handler, context);
- var behaviorType = typeof(IPipelineBehavior<,>).MakeGenericType(requestType, typeof(TResponse));
- var behaviors = container.GetAll(behaviorType);
+ var behaviors = container.GetAll(serviceTypes.BehaviorType);
foreach (var behavior in behaviors)
PrepareHandler(behavior, context);
@@ -135,7 +151,9 @@ internal sealed class CqrsDispatcher(
ArgumentNullException.ThrowIfNull(request);
var requestType = request.GetType();
- var handlerType = typeof(IStreamRequestHandler<,>).MakeGenericType(requestType, typeof(TResponse));
+ var handlerType = StreamHandlerServiceTypes.GetOrAdd(
+ (requestType, typeof(TResponse)),
+ static key => typeof(IStreamRequestHandler<,>).MakeGenericType(key.RequestType, key.ResponseType));
var handler = container.Get(handlerType)
?? throw new InvalidOperationException(
$"No CQRS stream handler registered for {requestType.FullName}.");
@@ -293,4 +311,6 @@ internal sealed class CqrsDispatcher(
CancellationToken cancellationToken);
private delegate object StreamInvoker(object handler, object request, CancellationToken cancellationToken);
+
+ private readonly record struct RequestServiceTypeSet(Type HandlerType, Type BehaviorType);
}
diff --git a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
index 968424b7..3604de83 100644
--- a/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
+++ b/GFramework.Cqrs/Internal/CqrsHandlerRegistrar.cs
@@ -40,7 +40,7 @@ internal static class CqrsHandlerRegistrar
container.GetServicesUnsafe,
assembly,
logger,
- generatedRegistrationResult.ReflectionFallbackTypeNames);
+ generatedRegistrationResult.ReflectionFallbackMetadata);
}
}
@@ -106,13 +106,13 @@ internal static class CqrsHandlerRegistrar
registry.Register(services, logger);
}
- var reflectionFallbackTypeNames = GetReflectionFallbackTypeNames(assembly);
- if (reflectionFallbackTypeNames is not null)
+ var reflectionFallbackMetadata = GetReflectionFallbackMetadata(assembly, logger);
+ if (reflectionFallbackMetadata is not null)
{
- if (reflectionFallbackTypeNames.Count > 0)
+ if (reflectionFallbackMetadata.HasExplicitTypes)
{
logger.Debug(
- $"Generated CQRS registry for assembly {assemblyName} requested targeted reflection fallback for {reflectionFallbackTypeNames.Count} unsupported handler type(s).");
+ $"Generated CQRS registry for assembly {assemblyName} requested targeted reflection fallback for {reflectionFallbackMetadata.Types.Count} unsupported handler type(s).");
}
else
{
@@ -120,7 +120,7 @@ internal static class CqrsHandlerRegistrar
$"Generated CQRS registry for assembly {assemblyName} requested full reflection fallback for unsupported handlers.");
}
- return GeneratedRegistrationResult.WithReflectionFallback(reflectionFallbackTypeNames);
+ return GeneratedRegistrationResult.WithReflectionFallback(reflectionFallbackMetadata);
}
return GeneratedRegistrationResult.FullyHandled();
@@ -142,9 +142,9 @@ internal static class CqrsHandlerRegistrar
IServiceCollection services,
Assembly assembly,
ILogger logger,
- IReadOnlyList? reflectionFallbackTypeNames)
+ ReflectionFallbackMetadata? reflectionFallbackMetadata)
{
- foreach (var implementationType in GetCandidateHandlerTypes(assembly, logger, reflectionFallbackTypeNames)
+ foreach (var implementationType in GetCandidateHandlerTypes(assembly, logger, reflectionFallbackMetadata)
.Where(IsConcreteHandlerType))
{
var handlerInterfaces = implementationType
@@ -180,24 +180,51 @@ internal static class CqrsHandlerRegistrar
private static IReadOnlyList GetCandidateHandlerTypes(
Assembly assembly,
ILogger logger,
- IReadOnlyList? reflectionFallbackTypeNames)
+ ReflectionFallbackMetadata? reflectionFallbackMetadata)
{
- return reflectionFallbackTypeNames is { Count: > 0 }
- ? GetNamedFallbackTypes(assembly, reflectionFallbackTypeNames, logger)
+ return reflectionFallbackMetadata is { HasExplicitTypes: true }
+ ? reflectionFallbackMetadata.Types
: GetLoadableTypes(assembly, logger);
}
///
- /// 根据生成器记录的类型全名,精确解析仍需运行时补充注册的处理器类型。
+ /// 获取生成注册器要求运行时继续补充反射扫描的 handler 元数据。
///
- private static IReadOnlyList GetNamedFallbackTypes(
+ private static ReflectionFallbackMetadata? GetReflectionFallbackMetadata(
Assembly assembly,
- IReadOnlyList reflectionFallbackTypeNames,
ILogger logger)
{
var assemblyName = GetAssemblySortKey(assembly);
- var resolvedTypes = new List(reflectionFallbackTypeNames.Count);
- foreach (var typeName in reflectionFallbackTypeNames
+ var fallbackAttributes = assembly
+ .GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)
+ .OfType()
+ .ToList();
+
+ if (fallbackAttributes.Count == 0)
+ return null;
+
+ var resolvedTypes = new List();
+ foreach (var fallbackType in fallbackAttributes
+ .SelectMany(static attribute => attribute.FallbackHandlerTypes)
+ .Where(static type => type is not null)
+ .Distinct()
+ .OrderBy(GetTypeSortKey, StringComparer.Ordinal))
+ {
+ if (!string.Equals(
+ GetAssemblySortKey(fallbackType.Assembly),
+ assemblyName,
+ StringComparison.Ordinal))
+ {
+ logger.Warn(
+ $"Generated CQRS reflection fallback type {fallbackType.FullName} was declared on assembly {assemblyName} but belongs to assembly {GetAssemblySortKey(fallbackType.Assembly)}. Skipping mismatched fallback entry.");
+ continue;
+ }
+
+ resolvedTypes.Add(fallbackType);
+ }
+
+ foreach (var typeName in fallbackAttributes
+ .SelectMany(static attribute => attribute.FallbackHandlerTypeNames)
.Where(static name => !string.IsNullOrWhiteSpace(name))
.Distinct(StringComparer.Ordinal)
.OrderBy(static name => name, StringComparer.Ordinal))
@@ -221,9 +248,11 @@ internal static class CqrsHandlerRegistrar
}
}
- return resolvedTypes
- .OrderBy(GetTypeSortKey, StringComparer.Ordinal)
- .ToList();
+ return new ReflectionFallbackMetadata(
+ resolvedTypes
+ .Distinct()
+ .OrderBy(GetTypeSortKey, StringComparer.Ordinal)
+ .ToArray());
}
///
@@ -291,27 +320,6 @@ internal static class CqrsHandlerRegistrar
definition == typeof(IStreamRequestHandler<,>);
}
- ///
- /// 获取生成注册器要求运行时继续补充反射扫描的 handler 类型名清单。
- ///
- private static IReadOnlyList? GetReflectionFallbackTypeNames(Assembly assembly)
- {
- var fallbackAttributes = assembly
- .GetCustomAttributes(typeof(CqrsReflectionFallbackAttribute), inherit: false)
- .OfType()
- .ToList();
-
- if (fallbackAttributes.Count == 0)
- return null;
-
- return fallbackAttributes
- .SelectMany(static attribute => attribute.FallbackHandlerTypeNames)
- .Where(static typeName => !string.IsNullOrWhiteSpace(typeName))
- .Distinct(StringComparer.Ordinal)
- .OrderBy(static typeName => typeName, StringComparer.Ordinal)
- .ToArray();
- }
-
///
/// 判断同一 handler 映射是否已经由生成注册器或先前扫描步骤写入服务集合。
///
@@ -346,14 +354,14 @@ internal static class CqrsHandlerRegistrar
private readonly record struct GeneratedRegistrationResult(
bool UsedGeneratedRegistry,
bool RequiresReflectionFallback,
- IReadOnlyList? ReflectionFallbackTypeNames)
+ ReflectionFallbackMetadata? ReflectionFallbackMetadata)
{
public static GeneratedRegistrationResult NoGeneratedRegistry()
{
return new GeneratedRegistrationResult(
UsedGeneratedRegistry: false,
RequiresReflectionFallback: false,
- ReflectionFallbackTypeNames: null);
+ ReflectionFallbackMetadata: null);
}
public static GeneratedRegistrationResult FullyHandled()
@@ -361,18 +369,25 @@ internal static class CqrsHandlerRegistrar
return new GeneratedRegistrationResult(
UsedGeneratedRegistry: true,
RequiresReflectionFallback: false,
- ReflectionFallbackTypeNames: null);
+ ReflectionFallbackMetadata: null);
}
public static GeneratedRegistrationResult WithReflectionFallback(
- IReadOnlyList reflectionFallbackTypeNames)
+ ReflectionFallbackMetadata reflectionFallbackMetadata)
{
- ArgumentNullException.ThrowIfNull(reflectionFallbackTypeNames);
+ ArgumentNullException.ThrowIfNull(reflectionFallbackMetadata);
return new GeneratedRegistrationResult(
UsedGeneratedRegistry: true,
RequiresReflectionFallback: true,
- ReflectionFallbackTypeNames: reflectionFallbackTypeNames);
+ ReflectionFallbackMetadata: reflectionFallbackMetadata);
}
}
+
+ private sealed class ReflectionFallbackMetadata(IReadOnlyList types)
+ {
+ public IReadOnlyList Types { get; } = types ?? throw new ArgumentNullException(nameof(types));
+
+ public bool HasExplicitTypes => Types.Count > 0;
+ }
}
From b07da252c48ebc66875315eb6c4407ffca0ce77e Mon Sep 17 00:00:00 2001
From: GeWuYou <95328647+GeWuYou@users.noreply.github.com>
Date: Thu, 16 Apr 2026 11:53:07 +0800
Subject: [PATCH 3/4] =?UTF-8?q?refactor(cqrs):=20=E4=BC=98=E5=8C=96?=
=?UTF-8?q?=E5=B9=B6=E5=8F=91=E5=A4=84=E7=90=86=E8=83=BD=E5=8A=9B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 在 CqrsDispatcher 中添加 Concurrent 包引用以支持线程安全操作
- 在全局引用文件中增加 Concurrent 包引用,统一并发编程支持
- 为后续的并发处理逻辑改进奠定基础架构支持
---
GFramework.Cqrs/GlobalUsings.cs | 1 +
1 file changed, 1 insertion(+)
diff --git a/GFramework.Cqrs/GlobalUsings.cs b/GFramework.Cqrs/GlobalUsings.cs
index b60938a5..3085d1e1 100644
--- a/GFramework.Cqrs/GlobalUsings.cs
+++ b/GFramework.Cqrs/GlobalUsings.cs
@@ -6,3 +6,4 @@ global using System.Threading.Tasks;
global using System.Reflection;
global using Microsoft.Extensions.DependencyInjection;
global using System.Diagnostics;
+global using System.Collections.Concurrent;
From 4951fb0254be4220725ce7c6a20970d2c971db89 Mon Sep 17 00:00:00 2001
From: GeWuYou <95328647+GeWuYou@users.noreply.github.com>
Date: Thu, 16 Apr 2026 12:19:44 +0800
Subject: [PATCH 4/4] =?UTF-8?q?feat(cqrs):=20=E6=B7=BB=E5=8A=A0=20CQRS=20?=
=?UTF-8?q?=E5=88=86=E5=8F=91=E5=99=A8=E5=92=8C=E6=9C=8D=E5=8A=A1=E6=B3=A8?=
=?UTF-8?q?=E5=86=8C=E7=94=9F=E6=88=90=E5=99=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 实现 CqrsDispatcher 类,支持请求/通知/流式请求的分发处理
- 添加进程级缓存机制,优化热路径中的反射和类型构造性能
- 实现上下文感知处理器的 CQRS 分发上下文注入功能
- 开发 CqrsHandlerRegistryGenerator 源代码生成器,减少运行时反射扫描
- 添加完整的单元测试验证缓存机制和服务类型注册功能
- 支持管道行为链处理和异步流式请求响应模式
---
.../Cqrs/CqrsDispatcherCacheTests.cs | 68 +++-
GFramework.Cqrs/Internal/CqrsDispatcher.cs | 25 +-
.../Cqrs/CqrsHandlerRegistryGeneratorTests.cs | 223 +++++++----
.../Cqrs/CqrsHandlerRegistryGenerator.cs | 347 ++++++++++--------
4 files changed, 429 insertions(+), 234 deletions(-)
diff --git a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
index 6c52a910..5ae794ad 100644
--- a/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
+++ b/GFramework.Cqrs.Tests/Cqrs/CqrsDispatcherCacheTests.cs
@@ -12,9 +12,6 @@ namespace GFramework.Cqrs.Tests.Cqrs;
[TestFixture]
internal sealed class CqrsDispatcherCacheTests
{
- private MicrosoftDiContainer? _container;
- private ArchitectureContext? _context;
-
///
/// 初始化测试上下文。
///
@@ -23,6 +20,7 @@ internal sealed class CqrsDispatcherCacheTests
{
LoggerFactoryResolver.Provider = new ConsoleLoggerFactoryProvider();
_container = new MicrosoftDiContainer();
+ _container.RegisterCqrsPipelineBehavior();
CqrsTestRuntime.RegisterHandlers(
_container,
@@ -43,6 +41,9 @@ internal sealed class CqrsDispatcherCacheTests
_container = null;
}
+ private MicrosoftDiContainer? _container;
+ private ArchitectureContext? _context;
+
///
/// 验证相同消息类型重复分发时,不会重复扩张服务类型缓存。
///
@@ -52,32 +53,54 @@ internal sealed class CqrsDispatcherCacheTests
var notificationServiceTypes = GetCacheField("NotificationHandlerServiceTypes");
var requestServiceTypes = GetCacheField("RequestServiceTypes");
var streamServiceTypes = GetCacheField("StreamHandlerServiceTypes");
+ var requestInvokers = GetCacheField("RequestInvokers");
+ var requestPipelineInvokers = GetCacheField("RequestPipelineInvokers");
+ var notificationInvokers = GetCacheField("NotificationInvokers");
+ var streamInvokers = GetCacheField("StreamInvokers");
var notificationBefore = notificationServiceTypes.Count;
var requestBefore = requestServiceTypes.Count;
var streamBefore = streamServiceTypes.Count;
+ var requestInvokersBefore = requestInvokers.Count;
+ var requestPipelineInvokersBefore = requestPipelineInvokers.Count;
+ var notificationInvokersBefore = notificationInvokers.Count;
+ var streamInvokersBefore = streamInvokers.Count;
await _context!.SendRequestAsync(new DispatcherCacheRequest());
+ await _context.SendRequestAsync(new DispatcherPipelineCacheRequest());
await _context.PublishAsync(new DispatcherCacheNotification());
await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest()));
var notificationAfterFirstDispatch = notificationServiceTypes.Count;
var requestAfterFirstDispatch = requestServiceTypes.Count;
var streamAfterFirstDispatch = streamServiceTypes.Count;
+ var requestInvokersAfterFirstDispatch = requestInvokers.Count;
+ var requestPipelineInvokersAfterFirstDispatch = requestPipelineInvokers.Count;
+ var notificationInvokersAfterFirstDispatch = notificationInvokers.Count;
+ var streamInvokersAfterFirstDispatch = streamInvokers.Count;
await _context.SendRequestAsync(new DispatcherCacheRequest());
+ await _context.SendRequestAsync(new DispatcherPipelineCacheRequest());
await _context.PublishAsync(new DispatcherCacheNotification());
await DrainAsync(_context.CreateStream(new DispatcherCacheStreamRequest()));
Assert.Multiple(() =>
{
Assert.That(notificationAfterFirstDispatch, Is.EqualTo(notificationBefore + 1));
- Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 1));
+ Assert.That(requestAfterFirstDispatch, Is.EqualTo(requestBefore + 2));
Assert.That(streamAfterFirstDispatch, Is.EqualTo(streamBefore + 1));
+ Assert.That(requestInvokersAfterFirstDispatch, Is.EqualTo(requestInvokersBefore + 1));
+ Assert.That(requestPipelineInvokersAfterFirstDispatch, Is.EqualTo(requestPipelineInvokersBefore + 1));
+ Assert.That(notificationInvokersAfterFirstDispatch, Is.EqualTo(notificationInvokersBefore + 1));
+ Assert.That(streamInvokersAfterFirstDispatch, Is.EqualTo(streamInvokersBefore + 1));
Assert.That(notificationServiceTypes.Count, Is.EqualTo(notificationAfterFirstDispatch));
Assert.That(requestServiceTypes.Count, Is.EqualTo(requestAfterFirstDispatch));
Assert.That(streamServiceTypes.Count, Is.EqualTo(streamAfterFirstDispatch));
+ Assert.That(requestInvokers.Count, Is.EqualTo(requestInvokersAfterFirstDispatch));
+ Assert.That(requestPipelineInvokers.Count, Is.EqualTo(requestPipelineInvokersAfterFirstDispatch));
+ Assert.That(notificationInvokers.Count, Is.EqualTo(notificationInvokersAfterFirstDispatch));
+ Assert.That(streamInvokers.Count, Is.EqualTo(streamInvokersAfterFirstDispatch));
});
}
@@ -126,6 +149,11 @@ internal sealed record DispatcherCacheNotification : INotification;
///
internal sealed record DispatcherCacheStreamRequest : IStreamRequest;
+///
+/// 用于验证 pipeline invoker 缓存的测试请求。
+///
+internal sealed record DispatcherPipelineCacheRequest : IRequest;
+
///
/// 处理 。
///
@@ -170,3 +198,35 @@ internal sealed class DispatcherCacheStreamHandler : IStreamRequestHandler
+/// 处理 。
+///
+internal sealed class DispatcherPipelineCacheRequestHandler : IRequestHandler
+{
+ ///
+ /// 返回固定结果,供 pipeline 缓存测试使用。
+ ///
+ public ValueTask Handle(DispatcherPipelineCacheRequest request, CancellationToken cancellationToken)
+ {
+ return ValueTask.FromResult(2);
+ }
+}
+
+///
+/// 为 提供最小 pipeline 行为,
+/// 用于命中 dispatcher 的 pipeline invoker 缓存分支。
+///
+internal sealed class DispatcherPipelineCacheBehavior : IPipelineBehavior
+{
+ ///
+ /// 直接转发到下一个处理器。
+ ///
+ public ValueTask Handle(
+ DispatcherPipelineCacheRequest request,
+ MessageHandlerDelegate next,
+ CancellationToken cancellationToken)
+ {
+ return next(request, cancellationToken);
+ }
+}
diff --git a/GFramework.Cqrs/Internal/CqrsDispatcher.cs b/GFramework.Cqrs/Internal/CqrsDispatcher.cs
index 91532e17..002b7edc 100644
--- a/GFramework.Cqrs/Internal/CqrsDispatcher.cs
+++ b/GFramework.Cqrs/Internal/CqrsDispatcher.cs
@@ -44,6 +44,19 @@ internal sealed class CqrsDispatcher(
StreamHandlerServiceTypes =
new();
+ // 静态方法定义缓存:这些反射查找与消息类型无关,只需解析一次即可复用。
+ private static readonly MethodInfo RequestHandlerInvokerMethodDefinition = typeof(CqrsDispatcher)
+ .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
+
+ private static readonly MethodInfo RequestPipelineInvokerMethodDefinition = typeof(CqrsDispatcher)
+ .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
+
+ private static readonly MethodInfo NotificationHandlerInvokerMethodDefinition = typeof(CqrsDispatcher)
+ .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
+
+ private static readonly MethodInfo StreamHandlerInvokerMethodDefinition = typeof(CqrsDispatcher)
+ .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)!;
+
///
/// 发布通知到所有已注册处理器。
///
@@ -189,8 +202,7 @@ internal sealed class CqrsDispatcher(
///
private static RequestInvoker CreateRequestInvoker(Type requestType, Type responseType)
{
- var method = typeof(CqrsDispatcher)
- .GetMethod(nameof(InvokeRequestHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!
+ var method = RequestHandlerInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestInvoker)Delegate.CreateDelegate(typeof(RequestInvoker), method);
}
@@ -200,8 +212,7 @@ internal sealed class CqrsDispatcher(
///
private static RequestPipelineInvoker CreateRequestPipelineInvoker(Type requestType, Type responseType)
{
- var method = typeof(CqrsDispatcher)
- .GetMethod(nameof(InvokeRequestPipelineAsync), BindingFlags.NonPublic | BindingFlags.Static)!
+ var method = RequestPipelineInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (RequestPipelineInvoker)Delegate.CreateDelegate(typeof(RequestPipelineInvoker), method);
}
@@ -211,8 +222,7 @@ internal sealed class CqrsDispatcher(
///
private static NotificationInvoker CreateNotificationInvoker(Type notificationType)
{
- var method = typeof(CqrsDispatcher)
- .GetMethod(nameof(InvokeNotificationHandlerAsync), BindingFlags.NonPublic | BindingFlags.Static)!
+ var method = NotificationHandlerInvokerMethodDefinition
.MakeGenericMethod(notificationType);
return (NotificationInvoker)Delegate.CreateDelegate(typeof(NotificationInvoker), method);
}
@@ -222,8 +232,7 @@ internal sealed class CqrsDispatcher(
///
private static StreamInvoker CreateStreamInvoker(Type requestType, Type responseType)
{
- var method = typeof(CqrsDispatcher)
- .GetMethod(nameof(InvokeStreamHandler), BindingFlags.NonPublic | BindingFlags.Static)!
+ var method = StreamHandlerInvokerMethodDefinition
.MakeGenericMethod(requestType, responseType);
return (StreamInvoker)Delegate.CreateDelegate(typeof(StreamInvoker), method);
}
diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index e1ec1546..dcdb5e5f 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -10,6 +10,138 @@ namespace GFramework.SourceGenerators.Tests.Cqrs;
[TestFixture]
public class CqrsHandlerRegistryGeneratorTests
{
+ private const string HiddenNestedHandlerSelfRegistrationExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler");
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ typeof(global::TestApp.VisibleHandler));
+ logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+
+ private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)
+ {
+ var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);
+ if (implementationType is null)
+ return;
+
+ var handlerInterfaces = implementationType.GetInterfaces();
+ global::System.Array.Sort(handlerInterfaces, CompareTypes);
+
+ foreach (var handlerInterface in handlerInterfaces)
+ {
+ if (!IsSupportedHandlerInterface(handlerInterface))
+ continue;
+
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ handlerInterface,
+ implementationType);
+ logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.");
+ }
+ }
+
+ private static int CompareTypes(global::System.Type left, global::System.Type right)
+ {
+ return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right));
+ }
+
+ private static bool IsSupportedHandlerInterface(global::System.Type interfaceType)
+ {
+ if (!interfaceType.IsGenericType)
+ return false;
+
+ var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName;
+ return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2")
+ || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1")
+ || global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2");
+ }
+
+ private static string GetRuntimeTypeDisplayName(global::System.Type type)
+ {
+ if (type == typeof(string))
+ return "string";
+ if (type == typeof(int))
+ return "int";
+ if (type == typeof(long))
+ return "long";
+ if (type == typeof(short))
+ return "short";
+ if (type == typeof(byte))
+ return "byte";
+ if (type == typeof(bool))
+ return "bool";
+ if (type == typeof(object))
+ return "object";
+ if (type == typeof(void))
+ return "void";
+ if (type == typeof(uint))
+ return "uint";
+ if (type == typeof(ulong))
+ return "ulong";
+ if (type == typeof(ushort))
+ return "ushort";
+ if (type == typeof(sbyte))
+ return "sbyte";
+ if (type == typeof(float))
+ return "float";
+ if (type == typeof(double))
+ return "double";
+ if (type == typeof(decimal))
+ return "decimal";
+ if (type == typeof(char))
+ return "char";
+
+ if (type.IsArray)
+ return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]";
+
+ if (!type.IsGenericType)
+ return (type.FullName ?? type.Name).Replace('+', '.');
+
+ var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name;
+ var arityIndex = genericTypeName.IndexOf('`');
+ if (arityIndex >= 0)
+ genericTypeName = genericTypeName[..arityIndex];
+
+ genericTypeName = genericTypeName.Replace('+', '.');
+ var arguments = type.GetGenericArguments();
+ var builder = new global::System.Text.StringBuilder();
+ builder.Append(genericTypeName);
+ builder.Append('<');
+
+ for (var index = 0; index < arguments.Length; index++)
+ {
+ if (index > 0)
+ builder.Append(", ");
+
+ builder.Append(GetRuntimeTypeDisplayName(arguments[index]));
+ }
+
+ builder.Append('>');
+ return builder.ToString();
+ }
+ }
+
+ """;
+
///
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
///
@@ -126,12 +258,12 @@ public class CqrsHandlerRegistryGeneratorTests
}
///
- /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器仍会为可见 handlers 生成注册器,
- /// 并额外标记运行时补充反射扫描。
+ /// 验证当程序集包含生成代码无法合法引用的私有嵌套处理器时,生成器会在生成注册器内部执行定向反射注册,
+ /// 不再依赖程序集级 fallback marker。
///
[Test]
public async Task
- Generates_Visible_Handlers_And_Requests_Reflection_Fallback_When_Assembly_Contains_Private_Nested_Handler()
+ Generates_Visible_Handlers_And_Self_Registers_Private_Nested_Handler_When_Assembly_Contains_Hidden_Handler()
{
const string source = """
using System;
@@ -202,45 +334,17 @@ public class CqrsHandlerRegistryGeneratorTests
}
""";
- const string expected = """
- //
- #nullable enable
-
- [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
- [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute("TestApp.Container+HiddenHandler")]
-
- namespace GFramework.Generated.Cqrs;
-
- internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
- {
- public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
- {
- if (services is null)
- throw new global::System.ArgumentNullException(nameof(services));
- if (logger is null)
- throw new global::System.ArgumentNullException(nameof(logger));
-
- global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
- services,
- typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
- typeof(global::TestApp.VisibleHandler));
- logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
- }
- }
-
- """;
-
await GeneratorTest.RunAsync(
source,
- ("CqrsHandlerRegistry.g.cs", expected));
+ ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected));
}
///
- /// 验证当 runtime 仅支持旧版无参 fallback marker 时,生成器会退回旧语义,
- /// 只输出 marker 而不输出精确类型名。
+ /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler,
+ /// 不再输出 fallback marker。
///
[Test]
- public async Task Generates_Legacy_Fallback_Marker_When_Runtime_Does_Not_Support_Type_Name_List()
+ public async Task Does_Not_Emit_Legacy_Fallback_Marker_When_Generated_Registry_Can_Self_Register_Hidden_Handler()
{
const string source = """
using System;
@@ -311,45 +415,17 @@ public class CqrsHandlerRegistryGeneratorTests
}
""";
- const string expected = """
- //
- #nullable enable
-
- [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
- [assembly: global::GFramework.Cqrs.CqrsReflectionFallbackAttribute()]
-
- namespace GFramework.Generated.Cqrs;
-
- internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
- {
- public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
- {
- if (services is null)
- throw new global::System.ArgumentNullException(nameof(services));
- if (logger is null)
- throw new global::System.ArgumentNullException(nameof(logger));
-
- global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
- services,
- typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
- typeof(global::TestApp.VisibleHandler));
- logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
- }
- }
-
- """;
-
await GeneratorTest.RunAsync(
source,
- ("CqrsHandlerRegistry.g.cs", expected));
+ ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected));
}
///
- /// 验证当旧版 runtime 合同中不存在 reflection fallback 标记特性时,
- /// 生成器会保留此前的整程序集回退行为,避免丢失不可见 handlers。
+ /// 验证即使 runtime 合同中完全不存在 reflection fallback 标记特性,
+ /// 生成器仍能通过生成注册器内部的定向反射逻辑覆盖隐藏 handler。
///
[Test]
- public async Task Skips_Generation_For_Unsupported_Handler_When_Fallback_Marker_Is_Unavailable()
+ public async Task Generates_Registry_For_Hidden_Handler_When_Fallback_Marker_Is_Unavailable()
{
const string source = """
using System;
@@ -414,16 +490,9 @@ public class CqrsHandlerRegistryGeneratorTests
}
""";
- var test = new CSharpSourceGeneratorTest
- {
- TestState =
- {
- Sources = { source }
- },
- DisabledDiagnostics = { "GF_Common_Trace_001" }
- };
-
- await test.RunAsync();
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected));
}
///
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 1e260e32..83559781 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -16,9 +16,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private const string IStreamRequestHandlerMetadataName = $"{CqrsContractsNamespace}.IStreamRequestHandler`2";
private const string ICqrsHandlerRegistryMetadataName = $"{CqrsRuntimeNamespace}.ICqrsHandlerRegistry";
- private const string CqrsReflectionFallbackAttributeMetadataName =
- $"{CqrsRuntimeNamespace}.CqrsReflectionFallbackAttribute";
-
private const string CqrsHandlerRegistryAttributeMetadataName =
$"{CqrsRuntimeNamespace}.CqrsHandlerRegistryAttribute";
@@ -57,10 +54,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null &&
compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null;
- return new GenerationEnvironment(
- generationEnabled,
- GetReflectionFallbackEmissionMode(
- compilation.GetTypeByMetadataName(CqrsReflectionFallbackAttributeMetadataName)));
+ return new GenerationEnvironment(generationEnabled);
}
private static bool IsHandlerCandidate(SyntaxNode node)
@@ -91,17 +85,20 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return null;
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ var implementationLogName = GetLogDisplayName(type);
if (!CanReferenceFromGeneratedRegistry(type) ||
handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
{
+ // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...)
+ // expressions inside generated code. Preserve generator hit rate by resolving just that implementation
+ // type back from the current assembly instead of asking the runtime registrar to rescan the assembly.
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
+ implementationLogName,
ImmutableArray.Empty,
- true,
- GetReflectionFallbackTypeName(type));
+ GetReflectionTypeMetadataName(type));
}
- var implementationLogName = GetLogDisplayName(type);
var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length);
foreach (var handlerInterface in handlerInterfaces)
{
@@ -114,8 +111,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
+ implementationLogName,
registrations.MoveToImmutable(),
- false,
null);
}
@@ -125,40 +122,23 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
if (!generationEnvironment.GenerationEnabled)
return;
- var registrations = CollectRegistrations(
- candidates,
- out var hasUnsupportedConcreteHandler,
- out var reflectionFallbackTypeNames);
+ var registrations = CollectRegistrations(candidates);
if (registrations.Count == 0)
return;
- // If the runtime contract does not yet expose the reflection fallback marker,
- // keep the previous all-or-nothing behavior so unsupported handlers are not silently dropped.
- if (hasUnsupportedConcreteHandler &&
- generationEnvironment.ReflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.Disabled)
- return;
-
context.AddSource(
HintName,
- GenerateSource(
- registrations,
- hasUnsupportedConcreteHandler,
- generationEnvironment.ReflectionFallbackEmissionMode,
- reflectionFallbackTypeNames));
+ GenerateSource(registrations));
}
- private static List CollectRegistrations(
- ImmutableArray candidates,
- out bool hasUnsupportedConcreteHandler,
- out IReadOnlyList reflectionFallbackTypeNames)
+ private static List CollectRegistrations(
+ ImmutableArray candidates)
{
- var registrations = new List();
- hasUnsupportedConcreteHandler = false;
- var fallbackTypeNames = new SortedSet(StringComparer.Ordinal);
+ var registrations = new List();
// Partial declarations surface the same symbol through multiple syntax nodes.
- // Collapse them by implementation type so generated registrations stay stable and duplicate-free.
+ // Collapse them by implementation type so direct and reflected registrations stay stable and duplicate-free.
var uniqueCandidates = new Dictionary(StringComparer.Ordinal);
foreach (var candidate in candidates)
@@ -166,25 +146,16 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
if (candidate is null)
continue;
- if (candidate.Value.HasUnsupportedConcreteHandler)
- {
- hasUnsupportedConcreteHandler = true;
- var reflectionFallbackTypeName = candidate.Value.ReflectionFallbackTypeName;
- if (reflectionFallbackTypeName is not null &&
- !string.IsNullOrWhiteSpace(reflectionFallbackTypeName))
- {
- fallbackTypeNames.Add(reflectionFallbackTypeName);
- }
-
- continue;
- }
-
uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value;
}
foreach (var candidate in uniqueCandidates.Values)
{
- registrations.AddRange(candidate.Registrations);
+ registrations.Add(new ImplementationRegistrationSpec(
+ candidate.ImplementationTypeDisplayName,
+ candidate.ImplementationLogName,
+ candidate.Registrations,
+ candidate.ReflectionTypeMetadataName));
}
registrations.Sort(static (left, right) =>
@@ -193,35 +164,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
left.ImplementationLogName,
right.ImplementationLogName);
- return implementationComparison != 0
- ? implementationComparison
- : StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName);
+ return implementationComparison;
});
- reflectionFallbackTypeNames = fallbackTypeNames.ToArray();
return registrations;
}
- private static ReflectionFallbackEmissionMode GetReflectionFallbackEmissionMode(INamedTypeSymbol? attributeType)
- {
- if (attributeType is null)
- return ReflectionFallbackEmissionMode.Disabled;
-
- foreach (var constructor in attributeType.InstanceConstructors)
- {
- if (constructor.Parameters.Length != 1)
- continue;
-
- if (constructor.Parameters[0].Type is IArrayTypeSymbol arrayType &&
- arrayType.ElementType.SpecialType == SpecialType.System_String)
- {
- return ReflectionFallbackEmissionMode.PreciseTypeNames;
- }
- }
-
- return ReflectionFallbackEmissionMode.MarkerOnly;
- }
-
private static bool IsConcreteHandlerType(INamedTypeSymbol type)
{
return type.TypeKind is TypeKind.Class or TypeKind.Struct &&
@@ -313,7 +261,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return builder.ToString();
}
- private static string GetReflectionFallbackTypeName(INamedTypeSymbol type)
+ private static string GetReflectionTypeMetadataName(INamedTypeSymbol type)
{
var nestedTypes = new Stack();
for (var current = type; current is not null; current = current.ContainingType)
@@ -352,11 +300,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
private static string GenerateSource(
- IReadOnlyList registrations,
- bool emitReflectionFallbackAttribute,
- ReflectionFallbackEmissionMode reflectionFallbackEmissionMode,
- IReadOnlyList reflectionFallbackTypeNames)
+ IReadOnlyList registrations)
{
+ var hasReflectionRegistrations = registrations.Any(static registration =>
+ !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName));
var builder = new StringBuilder();
builder.AppendLine("// ");
builder.AppendLine("#nullable enable");
@@ -368,11 +315,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.Append('.');
builder.Append(GeneratedTypeName);
builder.AppendLine("))]");
- if (emitReflectionFallbackAttribute &&
- reflectionFallbackEmissionMode != ReflectionFallbackEmissionMode.Disabled)
- {
- AppendReflectionFallbackAttribute(builder, reflectionFallbackEmissionMode, reflectionFallbackTypeNames);
- }
builder.AppendLine();
builder.Append("namespace ");
@@ -394,59 +336,177 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));");
builder.AppendLine(" if (logger is null)");
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));");
- builder.AppendLine();
+ if (hasReflectionRegistrations)
+ {
+ builder.AppendLine();
+ builder.Append(" var registryAssembly = typeof(global::");
+ builder.Append(GeneratedNamespace);
+ builder.Append('.');
+ builder.Append(GeneratedTypeName);
+ builder.AppendLine(").Assembly;");
+ }
+
+ if (registrations.Count > 0)
+ builder.AppendLine();
foreach (var registration in registrations)
{
- builder.AppendLine(
- " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
- builder.AppendLine(" services,");
- builder.Append(" typeof(");
- builder.Append(registration.HandlerInterfaceDisplayName);
- builder.AppendLine("),");
- builder.Append(" typeof(");
- builder.Append(registration.ImplementationTypeDisplayName);
- builder.AppendLine("));");
- builder.Append(" logger.Debug(\"Registered CQRS handler ");
- builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
- builder.Append(" as ");
- builder.Append(EscapeStringLiteral(registration.HandlerInterfaceLogName));
- builder.AppendLine(".\");");
+ if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
+ {
+ AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
+ continue;
+ }
+
+ foreach (var directRegistration in registration.DirectRegistrations)
+ {
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.ImplementationTypeDisplayName);
+ builder.AppendLine("));");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ }
}
builder.AppendLine(" }");
+
+ if (hasReflectionRegistrations)
+ {
+ builder.AppendLine();
+ AppendReflectionHelpers(builder);
+ }
+
builder.AppendLine("}");
return builder.ToString();
}
- private static void AppendReflectionFallbackAttribute(
- StringBuilder builder,
- ReflectionFallbackEmissionMode reflectionFallbackEmissionMode,
- IReadOnlyList reflectionFallbackTypeNames)
+ private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName)
{
- builder.Append("[assembly: global::");
- builder.Append(CqrsRuntimeNamespace);
- builder.Append(".CqrsReflectionFallbackAttribute");
+ builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \"");
+ builder.Append(EscapeStringLiteral(reflectionTypeMetadataName));
+ builder.AppendLine("\");");
+ }
- if (reflectionFallbackEmissionMode == ReflectionFallbackEmissionMode.PreciseTypeNames &&
- reflectionFallbackTypeNames.Count > 0)
- {
- builder.Append('(');
- for (var index = 0; index < reflectionFallbackTypeNames.Count; index++)
- {
- if (index > 0)
- builder.Append(", ");
-
- builder.Append('"');
- builder.Append(EscapeStringLiteral(reflectionFallbackTypeNames[index]));
- builder.Append('"');
- }
-
- builder.AppendLine(")]");
- return;
- }
-
- builder.AppendLine("()]");
+ private static void AppendReflectionHelpers(StringBuilder builder)
+ {
+ // Emit the runtime helper methods only when at least one handler requires metadata-name lookup.
+ builder.AppendLine(
+ " private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)");
+ builder.AppendLine(" {");
+ builder.AppendLine(
+ " var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);");
+ builder.AppendLine(" if (implementationType is null)");
+ builder.AppendLine(" return;");
+ builder.AppendLine();
+ builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();");
+ builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);");
+ builder.AppendLine();
+ builder.AppendLine(" foreach (var handlerInterface in handlerInterfaces)");
+ builder.AppendLine(" {");
+ builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))");
+ builder.AppendLine(" continue;");
+ builder.AppendLine();
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.AppendLine(" handlerInterface,");
+ builder.AppendLine(" implementationType);");
+ builder.AppendLine(
+ " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");");
+ builder.AppendLine(" }");
+ builder.AppendLine(" }");
+ builder.AppendLine();
+ builder.AppendLine(" private static int CompareTypes(global::System.Type left, global::System.Type right)");
+ builder.AppendLine(" {");
+ builder.AppendLine(
+ " return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right));");
+ builder.AppendLine(" }");
+ builder.AppendLine();
+ builder.AppendLine(" private static bool IsSupportedHandlerInterface(global::System.Type interfaceType)");
+ builder.AppendLine(" {");
+ builder.AppendLine(" if (!interfaceType.IsGenericType)");
+ builder.AppendLine(" return false;");
+ builder.AppendLine();
+ builder.AppendLine(" var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName;");
+ builder.AppendLine(
+ $" return global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IRequestHandlerMetadataName}\")");
+ builder.AppendLine(
+ $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{INotificationHandlerMetadataName}\")");
+ builder.AppendLine(
+ $" || global::System.StringComparer.Ordinal.Equals(definitionFullName, \"{IStreamRequestHandlerMetadataName}\");");
+ builder.AppendLine(" }");
+ builder.AppendLine();
+ builder.AppendLine(" private static string GetRuntimeTypeDisplayName(global::System.Type type)");
+ builder.AppendLine(" {");
+ builder.AppendLine(" if (type == typeof(string))");
+ builder.AppendLine(" return \"string\";");
+ builder.AppendLine(" if (type == typeof(int))");
+ builder.AppendLine(" return \"int\";");
+ builder.AppendLine(" if (type == typeof(long))");
+ builder.AppendLine(" return \"long\";");
+ builder.AppendLine(" if (type == typeof(short))");
+ builder.AppendLine(" return \"short\";");
+ builder.AppendLine(" if (type == typeof(byte))");
+ builder.AppendLine(" return \"byte\";");
+ builder.AppendLine(" if (type == typeof(bool))");
+ builder.AppendLine(" return \"bool\";");
+ builder.AppendLine(" if (type == typeof(object))");
+ builder.AppendLine(" return \"object\";");
+ builder.AppendLine(" if (type == typeof(void))");
+ builder.AppendLine(" return \"void\";");
+ builder.AppendLine(" if (type == typeof(uint))");
+ builder.AppendLine(" return \"uint\";");
+ builder.AppendLine(" if (type == typeof(ulong))");
+ builder.AppendLine(" return \"ulong\";");
+ builder.AppendLine(" if (type == typeof(ushort))");
+ builder.AppendLine(" return \"ushort\";");
+ builder.AppendLine(" if (type == typeof(sbyte))");
+ builder.AppendLine(" return \"sbyte\";");
+ builder.AppendLine(" if (type == typeof(float))");
+ builder.AppendLine(" return \"float\";");
+ builder.AppendLine(" if (type == typeof(double))");
+ builder.AppendLine(" return \"double\";");
+ builder.AppendLine(" if (type == typeof(decimal))");
+ builder.AppendLine(" return \"decimal\";");
+ builder.AppendLine(" if (type == typeof(char))");
+ builder.AppendLine(" return \"char\";");
+ builder.AppendLine();
+ builder.AppendLine(" if (type.IsArray)");
+ builder.AppendLine(" return GetRuntimeTypeDisplayName(type.GetElementType()!) + \"[]\";");
+ builder.AppendLine();
+ builder.AppendLine(" if (!type.IsGenericType)");
+ builder.AppendLine(" return (type.FullName ?? type.Name).Replace('+', '.');");
+ builder.AppendLine();
+ builder.AppendLine(" var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name;");
+ builder.AppendLine(" var arityIndex = genericTypeName.IndexOf('`');");
+ builder.AppendLine(" if (arityIndex >= 0)");
+ builder.AppendLine(" genericTypeName = genericTypeName[..arityIndex];");
+ builder.AppendLine();
+ builder.AppendLine(" genericTypeName = genericTypeName.Replace('+', '.');");
+ builder.AppendLine(" var arguments = type.GetGenericArguments();");
+ builder.AppendLine(" var builder = new global::System.Text.StringBuilder();");
+ builder.AppendLine(" builder.Append(genericTypeName);");
+ builder.AppendLine(" builder.Append('<');");
+ builder.AppendLine();
+ builder.AppendLine(" for (var index = 0; index < arguments.Length; index++)");
+ builder.AppendLine(" {");
+ builder.AppendLine(" if (index > 0)");
+ builder.AppendLine(" builder.Append(\", \");");
+ builder.AppendLine();
+ builder.AppendLine(" builder.Append(GetRuntimeTypeDisplayName(arguments[index]));");
+ builder.AppendLine(" }");
+ builder.AppendLine();
+ builder.AppendLine(" builder.Append('>');");
+ builder.AppendLine(" return builder.ToString();");
+ builder.AppendLine(" }");
}
private static string EscapeStringLiteral(string value)
@@ -463,34 +523,40 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceLogName,
string ImplementationLogName);
+ private readonly record struct ImplementationRegistrationSpec(
+ string ImplementationTypeDisplayName,
+ string ImplementationLogName,
+ ImmutableArray DirectRegistrations,
+ string? ReflectionTypeMetadataName);
+
private readonly struct HandlerCandidateAnalysis : IEquatable
{
public HandlerCandidateAnalysis(
string implementationTypeDisplayName,
+ string implementationLogName,
ImmutableArray registrations,
- bool hasUnsupportedConcreteHandler,
- string? reflectionFallbackTypeName)
+ string? reflectionTypeMetadataName)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
+ ImplementationLogName = implementationLogName;
Registrations = registrations;
- HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler;
- ReflectionFallbackTypeName = reflectionFallbackTypeName;
+ ReflectionTypeMetadataName = reflectionTypeMetadataName;
}
public string ImplementationTypeDisplayName { get; }
+ public string ImplementationLogName { get; }
+
public ImmutableArray Registrations { get; }
- public bool HasUnsupportedConcreteHandler { get; }
-
- public string? ReflectionFallbackTypeName { get; }
+ public string? ReflectionTypeMetadataName { get; }
public bool Equals(HandlerCandidateAnalysis other)
{
if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
StringComparison.Ordinal) ||
- HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler ||
- !string.Equals(ReflectionFallbackTypeName, other.ReflectionFallbackTypeName,
+ !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) ||
+ !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName,
StringComparison.Ordinal) ||
Registrations.Length != other.Registrations.Length)
{
@@ -516,11 +582,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
unchecked
{
var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName);
- hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode();
+ hashCode = (hashCode * 397) ^ StringComparer.Ordinal.GetHashCode(ImplementationLogName);
hashCode = (hashCode * 397) ^
- (ReflectionFallbackTypeName is null
+ (ReflectionTypeMetadataName is null
? 0
- : StringComparer.Ordinal.GetHashCode(ReflectionFallbackTypeName));
+ : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName));
foreach (var registration in Registrations)
{
hashCode = (hashCode * 397) ^ registration.GetHashCode();
@@ -531,14 +597,5 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
}
- private readonly record struct GenerationEnvironment(
- bool GenerationEnabled,
- ReflectionFallbackEmissionMode ReflectionFallbackEmissionMode);
-
- private enum ReflectionFallbackEmissionMode
- {
- Disabled,
- MarkerOnly,
- PreciseTypeNames
- }
+ private readonly record struct GenerationEnvironment(bool GenerationEnabled);
}