From 1792fafc85269b929d121687bee4416bc319250b Mon Sep 17 00:00:00 2001
From: GeWuYou <95328647+GeWuYou@users.noreply.github.com>
Date: Thu, 16 Apr 2026 17:24:52 +0800
Subject: [PATCH] =?UTF-8?q?refactor(Cqrs):=20=E9=87=8D=E6=9E=84CQRS?=
=?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=E6=B3=A8=E5=86=8C=E7=94=9F=E6=88=90?=
=?UTF-8?q?=E9=80=BB=E8=BE=91=E4=BB=A5=E6=94=AF=E6=8C=81=E6=B7=B7=E5=90=88?=
=?UTF-8?q?=E6=B3=A8=E5=86=8C=E7=B1=BB=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 修改注册条件判断逻辑,支持多种注册类型的组合处理
- 新增有序注册实现方法,统一处理直接、反射和精确反射注册
- 添加注册类型枚举以区分不同的注册方式
- 实现混合注册场景下的稳定排序机制
- 更新反射注册逻辑以支持更复杂的类型解析
- 优化代码结构提升可读性和维护性
- 添加单元测试验证各种混合注册场景的正确性
---
.../Cqrs/CqrsHandlerRegistryGeneratorTests.cs | 246 ++++++++++++++++++
.../Cqrs/CqrsHandlerRegistryGenerator.cs | 198 ++++++++------
2 files changed, 371 insertions(+), 73 deletions(-)
diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index ef70ec4a..3026d279 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -164,6 +164,94 @@ public class CqrsHandlerRegistryGeneratorTests
""";
+ private const string MixedDirectAndPreciseRegistrationsExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ var implementationType0 = typeof(global::TestApp.Container.MixedHandler);
+ if (implementationType0 is not null)
+ {
+ var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
+ var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
+ if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
+ {
+ var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ serviceType0_0,
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+ }
+
+ """;
+
+ private const string MixedReflectedImplementationAndPreciseRegistrationsExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenMixedHandler", throwOnError: false, ignoreCase: false);
+ if (implementationType0 is not null)
+ {
+ var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
+ var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
+ if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
+ {
+ var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ serviceType0_0,
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+ }
+
+ """;
+
///
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
///
@@ -579,6 +667,164 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected));
}
+ ///
+ /// 验证同一个 implementation 同时包含可直接注册接口与需精确重建接口时,
+ /// 生成器会保留两类注册,并继续按 handler interface 名称稳定排序。
+ ///
+ [Test]
+ public async Task Generates_Mixed_Direct_And_Precise_Registrations_For_Same_Implementation()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed record HiddenResponse();
+
+ private sealed record HiddenRequest() : IRequest;
+
+ public sealed class MixedHandler :
+ IRequestHandler,
+ IRequestHandler
+ {
+ }
+ }
+ }
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", MixedDirectAndPreciseRegistrationsExpected));
+ }
+
+ ///
+ /// 验证隐藏 implementation 同时包含可见 handler interface 与需精确重建接口时,
+ /// 生成器会保留两类注册,而不会让可见接口被整实现回退吞掉。
+ ///
+ [Test]
+ public async Task Generates_Mixed_Reflected_Implementation_And_Precise_Registrations_For_Same_Implementation()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed record HiddenResponse();
+
+ private sealed record HiddenRequest() : IRequest;
+
+ private sealed class HiddenMixedHandler :
+ IRequestHandler,
+ IRequestHandler
+ {
+ }
+ }
+ }
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected));
+ }
+
///
/// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler,
/// 不再输出 fallback marker。
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 36a72be5..f3bb5e96 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -549,40 +549,22 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++)
{
var registration = registrations[registrationIndex];
- if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty)
+ if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty ||
+ !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
{
- AppendReflectedImplementationRegistrations(builder, registration, registrationIndex);
- continue;
+ AppendOrderedImplementationRegistrations(builder, registration, registrationIndex);
+ }
+ else if (!registration.DirectRegistrations.IsDefaultOrEmpty)
+ {
+ AppendDirectRegistrations(builder, registration);
}
- if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
- {
- AppendPreciseReflectedRegistrations(builder, registration, registrationIndex);
- continue;
- }
-
- if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
+ if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
+ registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
+ registration.PreciseReflectedRegistrations.IsDefaultOrEmpty &&
+ registration.DirectRegistrations.IsDefaultOrEmpty)
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
- continue;
- }
-
- foreach (var directRegistration in registration.DirectRegistrations)
- {
- builder.AppendLine(
- " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
- builder.AppendLine(" services,");
- builder.Append(" typeof(");
- builder.Append(directRegistration.HandlerInterfaceDisplayName);
- builder.AppendLine("),");
- builder.Append(" typeof(");
- builder.Append(directRegistration.ImplementationTypeDisplayName);
- builder.AppendLine("));");
- builder.Append(" logger.Debug(\"Registered CQRS handler ");
- builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
- builder.Append(" as ");
- builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
- builder.AppendLine(".\");");
}
}
@@ -605,48 +587,71 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine("\");");
}
- private static void AppendReflectedImplementationRegistrations(
+ private static void AppendDirectRegistrations(
StringBuilder builder,
- ImplementationRegistrationSpec registration,
- int registrationIndex)
+ ImplementationRegistrationSpec registration)
{
- var implementationVariableName = $"implementationType{registrationIndex}";
- builder.Append(" var ");
- builder.Append(implementationVariableName);
- builder.Append(" = registryAssembly.GetType(\"");
- builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
- builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
- builder.Append(" if (");
- builder.Append(implementationVariableName);
- builder.AppendLine(" is not null)");
- builder.AppendLine(" {");
-
- foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations)
+ foreach (var directRegistration in registration.DirectRegistrations)
{
builder.AppendLine(
- " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
- builder.AppendLine(" services,");
- builder.Append(" typeof(");
- builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
- builder.Append(" ");
- builder.Append(implementationVariableName);
- builder.AppendLine(");");
- builder.Append(" logger.Debug(\"Registered CQRS handler ");
- builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.ImplementationTypeDisplayName);
+ builder.AppendLine("));");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
builder.Append(" as ");
- builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
+ builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
}
-
- builder.AppendLine(" }");
}
- private static void AppendPreciseReflectedRegistrations(
+ private static void AppendOrderedImplementationRegistrations(
StringBuilder builder,
ImplementationRegistrationSpec registration,
int registrationIndex)
{
+ var orderedRegistrations =
+ new List<(string HandlerInterfaceLogName, OrderedRegistrationKind Kind, int Index)>(
+ registration.DirectRegistrations.Length +
+ registration.ReflectedImplementationRegistrations.Length +
+ registration.PreciseReflectedRegistrations.Length);
+
+ for (var directIndex = 0; directIndex < registration.DirectRegistrations.Length; directIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.DirectRegistrations[directIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.Direct,
+ directIndex));
+ }
+
+ for (var reflectedIndex = 0;
+ reflectedIndex < registration.ReflectedImplementationRegistrations.Length;
+ reflectedIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.ReflectedImplementationRegistrations[reflectedIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.ReflectedImplementation,
+ reflectedIndex));
+ }
+
+ for (var preciseIndex = 0;
+ preciseIndex < registration.PreciseReflectedRegistrations.Length;
+ preciseIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.PreciseReflectedRegistrations[preciseIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.PreciseReflected,
+ preciseIndex));
+ }
+
+ orderedRegistrations.Sort(static (left, right) =>
+ StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName));
+
var implementationVariableName = $"implementationType{registrationIndex}";
if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
@@ -658,11 +663,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
else
{
- var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!;
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = registryAssembly.GetType(\"");
- builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName));
+ builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
}
@@ -671,21 +675,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" is not null)");
builder.AppendLine(" {");
- for (var registrationOffset = 0;
- registrationOffset < registration.PreciseReflectedRegistrations.Length;
- registrationOffset++)
+ foreach (var orderedRegistration in orderedRegistrations)
{
- var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset];
- var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}";
- AppendPreciseReflectedTypeResolution(
- builder,
- reflectedRegistration.ServiceTypeArguments,
- registrationVariablePrefix,
- implementationVariableName,
- reflectedRegistration.OpenHandlerTypeDisplayName,
- registration.ImplementationLogName,
- reflectedRegistration.HandlerInterfaceLogName,
- 3);
+ switch (orderedRegistration.Kind)
+ {
+ case OrderedRegistrationKind.Direct:
+ var directRegistration = registration.DirectRegistrations[orderedRegistration.Index];
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" ");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(");");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ break;
+ case OrderedRegistrationKind.ReflectedImplementation:
+ var reflectedRegistration =
+ registration.ReflectedImplementationRegistrations[orderedRegistration.Index];
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" ");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(");");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ break;
+ case OrderedRegistrationKind.PreciseReflected:
+ var preciseRegistration = registration.PreciseReflectedRegistrations[orderedRegistration.Index];
+ var registrationVariablePrefix = $"serviceType{registrationIndex}_{orderedRegistration.Index}";
+ AppendPreciseReflectedTypeResolution(
+ builder,
+ preciseRegistration.ServiceTypeArguments,
+ registrationVariablePrefix,
+ implementationVariableName,
+ preciseRegistration.OpenHandlerTypeDisplayName,
+ registration.ImplementationLogName,
+ preciseRegistration.HandlerInterfaceLogName,
+ 3);
+ break;
+ default:
+ throw new InvalidOperationException(
+ $"Unsupported ordered CQRS registration kind {orderedRegistration.Kind}.");
+ }
}
builder.AppendLine(" }");
@@ -969,6 +1014,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceDisplayName,
string HandlerInterfaceLogName);
+ private enum OrderedRegistrationKind
+ {
+ Direct,
+ ReflectedImplementation,
+ PreciseReflected
+ }
+
private sealed record RuntimeTypeReferenceSpec(
string? TypeDisplayName,
string? ReflectionTypeMetadataName,