diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index ef70ec4a..3026d279 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -164,6 +164,94 @@ public class CqrsHandlerRegistryGeneratorTests
""";
+ private const string MixedDirectAndPreciseRegistrationsExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ var implementationType0 = typeof(global::TestApp.Container.MixedHandler);
+ if (implementationType0 is not null)
+ {
+ var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
+ var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
+ if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
+ {
+ var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ serviceType0_0,
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+ }
+
+ """;
+
+ private const string MixedReflectedImplementationAndPreciseRegistrationsExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenMixedHandler", throwOnError: false, ignoreCase: false);
+ if (implementationType0 is not null)
+ {
+ var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
+ var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
+ if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
+ {
+ var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ serviceType0_0,
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+ }
+
+ """;
+
///
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
///
@@ -579,6 +667,164 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected));
}
+ ///
+ /// 验证同一个 implementation 同时包含可直接注册接口与需精确重建接口时,
+ /// 生成器会保留两类注册,并继续按 handler interface 名称稳定排序。
+ ///
+ [Test]
+ public async Task Generates_Mixed_Direct_And_Precise_Registrations_For_Same_Implementation()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed record HiddenResponse();
+
+ private sealed record HiddenRequest() : IRequest;
+
+ public sealed class MixedHandler :
+ IRequestHandler,
+ IRequestHandler
+ {
+ }
+ }
+ }
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", MixedDirectAndPreciseRegistrationsExpected));
+ }
+
+ ///
+ /// 验证隐藏 implementation 同时包含可见 handler interface 与需精确重建接口时,
+ /// 生成器会保留两类注册,而不会让可见接口被整实现回退吞掉。
+ ///
+ [Test]
+ public async Task Generates_Mixed_Reflected_Implementation_And_Precise_Registrations_For_Same_Implementation()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed record HiddenResponse();
+
+ private sealed record HiddenRequest() : IRequest;
+
+ private sealed class HiddenMixedHandler :
+ IRequestHandler,
+ IRequestHandler
+ {
+ }
+ }
+ }
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected));
+ }
+
///
/// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler,
/// 不再输出 fallback marker。
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 36a72be5..f3bb5e96 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -549,40 +549,22 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++)
{
var registration = registrations[registrationIndex];
- if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty)
+ if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty ||
+ !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
{
- AppendReflectedImplementationRegistrations(builder, registration, registrationIndex);
- continue;
+ AppendOrderedImplementationRegistrations(builder, registration, registrationIndex);
+ }
+ else if (!registration.DirectRegistrations.IsDefaultOrEmpty)
+ {
+ AppendDirectRegistrations(builder, registration);
}
- if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
- {
- AppendPreciseReflectedRegistrations(builder, registration, registrationIndex);
- continue;
- }
-
- if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
+ if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
+ registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
+ registration.PreciseReflectedRegistrations.IsDefaultOrEmpty &&
+ registration.DirectRegistrations.IsDefaultOrEmpty)
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
- continue;
- }
-
- foreach (var directRegistration in registration.DirectRegistrations)
- {
- builder.AppendLine(
- " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
- builder.AppendLine(" services,");
- builder.Append(" typeof(");
- builder.Append(directRegistration.HandlerInterfaceDisplayName);
- builder.AppendLine("),");
- builder.Append(" typeof(");
- builder.Append(directRegistration.ImplementationTypeDisplayName);
- builder.AppendLine("));");
- builder.Append(" logger.Debug(\"Registered CQRS handler ");
- builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
- builder.Append(" as ");
- builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
- builder.AppendLine(".\");");
}
}
@@ -605,48 +587,71 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine("\");");
}
- private static void AppendReflectedImplementationRegistrations(
+ private static void AppendDirectRegistrations(
StringBuilder builder,
- ImplementationRegistrationSpec registration,
- int registrationIndex)
+ ImplementationRegistrationSpec registration)
{
- var implementationVariableName = $"implementationType{registrationIndex}";
- builder.Append(" var ");
- builder.Append(implementationVariableName);
- builder.Append(" = registryAssembly.GetType(\"");
- builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
- builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
- builder.Append(" if (");
- builder.Append(implementationVariableName);
- builder.AppendLine(" is not null)");
- builder.AppendLine(" {");
-
- foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations)
+ foreach (var directRegistration in registration.DirectRegistrations)
{
builder.AppendLine(
- " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
- builder.AppendLine(" services,");
- builder.Append(" typeof(");
- builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("),");
- builder.Append(" ");
- builder.Append(implementationVariableName);
- builder.AppendLine(");");
- builder.Append(" logger.Debug(\"Registered CQRS handler ");
- builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.ImplementationTypeDisplayName);
+ builder.AppendLine("));");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName));
builder.Append(" as ");
- builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
+ builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
builder.AppendLine(".\");");
}
-
- builder.AppendLine(" }");
}
- private static void AppendPreciseReflectedRegistrations(
+ private static void AppendOrderedImplementationRegistrations(
StringBuilder builder,
ImplementationRegistrationSpec registration,
int registrationIndex)
{
+ var orderedRegistrations =
+ new List<(string HandlerInterfaceLogName, OrderedRegistrationKind Kind, int Index)>(
+ registration.DirectRegistrations.Length +
+ registration.ReflectedImplementationRegistrations.Length +
+ registration.PreciseReflectedRegistrations.Length);
+
+ for (var directIndex = 0; directIndex < registration.DirectRegistrations.Length; directIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.DirectRegistrations[directIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.Direct,
+ directIndex));
+ }
+
+ for (var reflectedIndex = 0;
+ reflectedIndex < registration.ReflectedImplementationRegistrations.Length;
+ reflectedIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.ReflectedImplementationRegistrations[reflectedIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.ReflectedImplementation,
+ reflectedIndex));
+ }
+
+ for (var preciseIndex = 0;
+ preciseIndex < registration.PreciseReflectedRegistrations.Length;
+ preciseIndex++)
+ {
+ orderedRegistrations.Add((
+ registration.PreciseReflectedRegistrations[preciseIndex].HandlerInterfaceLogName,
+ OrderedRegistrationKind.PreciseReflected,
+ preciseIndex));
+ }
+
+ orderedRegistrations.Sort(static (left, right) =>
+ StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName));
+
var implementationVariableName = $"implementationType{registrationIndex}";
if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
@@ -658,11 +663,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
else
{
- var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!;
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = registryAssembly.GetType(\"");
- builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName));
+ builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
}
@@ -671,21 +675,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" is not null)");
builder.AppendLine(" {");
- for (var registrationOffset = 0;
- registrationOffset < registration.PreciseReflectedRegistrations.Length;
- registrationOffset++)
+ foreach (var orderedRegistration in orderedRegistrations)
{
- var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset];
- var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}";
- AppendPreciseReflectedTypeResolution(
- builder,
- reflectedRegistration.ServiceTypeArguments,
- registrationVariablePrefix,
- implementationVariableName,
- reflectedRegistration.OpenHandlerTypeDisplayName,
- registration.ImplementationLogName,
- reflectedRegistration.HandlerInterfaceLogName,
- 3);
+ switch (orderedRegistration.Kind)
+ {
+ case OrderedRegistrationKind.Direct:
+ var directRegistration = registration.DirectRegistrations[orderedRegistration.Index];
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(directRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" ");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(");");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ break;
+ case OrderedRegistrationKind.ReflectedImplementation:
+ var reflectedRegistration =
+ registration.ReflectedImplementationRegistrations[orderedRegistration.Index];
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" ");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(");");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ break;
+ case OrderedRegistrationKind.PreciseReflected:
+ var preciseRegistration = registration.PreciseReflectedRegistrations[orderedRegistration.Index];
+ var registrationVariablePrefix = $"serviceType{registrationIndex}_{orderedRegistration.Index}";
+ AppendPreciseReflectedTypeResolution(
+ builder,
+ preciseRegistration.ServiceTypeArguments,
+ registrationVariablePrefix,
+ implementationVariableName,
+ preciseRegistration.OpenHandlerTypeDisplayName,
+ registration.ImplementationLogName,
+ preciseRegistration.HandlerInterfaceLogName,
+ 3);
+ break;
+ default:
+ throw new InvalidOperationException(
+ $"Unsupported ordered CQRS registration kind {orderedRegistration.Kind}.");
+ }
}
builder.AppendLine(" }");
@@ -969,6 +1014,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceDisplayName,
string HandlerInterfaceLogName);
+ private enum OrderedRegistrationKind
+ {
+ Direct,
+ ReflectedImplementation,
+ PreciseReflected
+ }
+
private sealed record RuntimeTypeReferenceSpec(
string? TypeDisplayName,
string? ReflectionTypeMetadataName,