diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index ef70ec4a..3026d279 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -164,6 +164,94 @@ public class CqrsHandlerRegistryGeneratorTests """; + private const string MixedDirectAndPreciseRegistrationsExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = typeof(global::TestApp.Container.MixedHandler); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.MixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + + private const string MixedReflectedImplementationAndPreciseRegistrationsExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenMixedHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false); + var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false); + if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null) + { + var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType()); + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + serviceType0_0, + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenMixedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -579,6 +667,164 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected)); } + /// + /// 验证同一个 implementation 同时包含可直接注册接口与需精确重建接口时, + /// 生成器会保留两类注册,并继续按 handler interface 名称稳定排序。 + /// + [Test] + public async Task Generates_Mixed_Direct_And_Precise_Registrations_For_Same_Implementation() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + public sealed class MixedHandler : + IRequestHandler, + IRequestHandler + { + } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", MixedDirectAndPreciseRegistrationsExpected)); + } + + /// + /// 验证隐藏 implementation 同时包含可见 handler interface 与需精确重建接口时, + /// 生成器会保留两类注册,而不会让可见接口被整实现回退吞掉。 + /// + [Test] + public async Task Generates_Mixed_Reflected_Implementation_And_Precise_Registrations_For_Same_Implementation() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed record HiddenResponse(); + + private sealed record HiddenRequest() : IRequest; + + private sealed class HiddenMixedHandler : + IRequestHandler, + IRequestHandler + { + } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected)); + } + /// /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 36a72be5..f3bb5e96 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -549,40 +549,22 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++) { var registration = registrations[registrationIndex]; - if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty) + if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty || + !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) { - AppendReflectedImplementationRegistrations(builder, registration, registrationIndex); - continue; + AppendOrderedImplementationRegistrations(builder, registration, registrationIndex); + } + else if (!registration.DirectRegistrations.IsDefaultOrEmpty) + { + AppendDirectRegistrations(builder, registration); } - if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) - { - AppendPreciseReflectedRegistrations(builder, registration, registrationIndex); - continue; - } - - if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && + registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && + registration.PreciseReflectedRegistrations.IsDefaultOrEmpty && + registration.DirectRegistrations.IsDefaultOrEmpty) { AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); - continue; - } - - foreach (var directRegistration in registration.DirectRegistrations) - { - builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(directRegistration.HandlerInterfaceDisplayName); - builder.AppendLine("),"); - builder.Append(" typeof("); - builder.Append(directRegistration.ImplementationTypeDisplayName); - builder.AppendLine("));"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); - builder.Append(" as "); - builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); - builder.AppendLine(".\");"); } } @@ -605,48 +587,71 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine("\");"); } - private static void AppendReflectedImplementationRegistrations( + private static void AppendDirectRegistrations( StringBuilder builder, - ImplementationRegistrationSpec registration, - int registrationIndex) + ImplementationRegistrationSpec registration) { - var implementationVariableName = $"implementationType{registrationIndex}"; - builder.Append(" var "); - builder.Append(implementationVariableName); - builder.Append(" = registryAssembly.GetType(\""); - builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); - builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); - builder.Append(" if ("); - builder.Append(implementationVariableName); - builder.AppendLine(" is not null)"); - builder.AppendLine(" {"); - - foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations) + foreach (var directRegistration in registration.DirectRegistrations) { builder.AppendLine( - " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); - builder.AppendLine(" services,"); - builder.Append(" typeof("); - builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); builder.AppendLine("),"); - builder.Append(" "); - builder.Append(implementationVariableName); - builder.AppendLine(");"); - builder.Append(" logger.Debug(\"Registered CQRS handler "); - builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" typeof("); + builder.Append(directRegistration.ImplementationTypeDisplayName); + builder.AppendLine("));"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(directRegistration.ImplementationLogName)); builder.Append(" as "); - builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); builder.AppendLine(".\");"); } - - builder.AppendLine(" }"); } - private static void AppendPreciseReflectedRegistrations( + private static void AppendOrderedImplementationRegistrations( StringBuilder builder, ImplementationRegistrationSpec registration, int registrationIndex) { + var orderedRegistrations = + new List<(string HandlerInterfaceLogName, OrderedRegistrationKind Kind, int Index)>( + registration.DirectRegistrations.Length + + registration.ReflectedImplementationRegistrations.Length + + registration.PreciseReflectedRegistrations.Length); + + for (var directIndex = 0; directIndex < registration.DirectRegistrations.Length; directIndex++) + { + orderedRegistrations.Add(( + registration.DirectRegistrations[directIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.Direct, + directIndex)); + } + + for (var reflectedIndex = 0; + reflectedIndex < registration.ReflectedImplementationRegistrations.Length; + reflectedIndex++) + { + orderedRegistrations.Add(( + registration.ReflectedImplementationRegistrations[reflectedIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.ReflectedImplementation, + reflectedIndex)); + } + + for (var preciseIndex = 0; + preciseIndex < registration.PreciseReflectedRegistrations.Length; + preciseIndex++) + { + orderedRegistrations.Add(( + registration.PreciseReflectedRegistrations[preciseIndex].HandlerInterfaceLogName, + OrderedRegistrationKind.PreciseReflected, + preciseIndex)); + } + + orderedRegistrations.Sort(static (left, right) => + StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName)); + var implementationVariableName = $"implementationType{registrationIndex}"; if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { @@ -658,11 +663,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } else { - var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!; builder.Append(" var "); builder.Append(implementationVariableName); builder.Append(" = registryAssembly.GetType(\""); - builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName)); + builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); } @@ -671,21 +675,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" is not null)"); builder.AppendLine(" {"); - for (var registrationOffset = 0; - registrationOffset < registration.PreciseReflectedRegistrations.Length; - registrationOffset++) + foreach (var orderedRegistration in orderedRegistrations) { - var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset]; - var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}"; - AppendPreciseReflectedTypeResolution( - builder, - reflectedRegistration.ServiceTypeArguments, - registrationVariablePrefix, - implementationVariableName, - reflectedRegistration.OpenHandlerTypeDisplayName, - registration.ImplementationLogName, - reflectedRegistration.HandlerInterfaceLogName, - 3); + switch (orderedRegistration.Kind) + { + case OrderedRegistrationKind.Direct: + var directRegistration = registration.DirectRegistrations[orderedRegistration.Index]; + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(directRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + break; + case OrderedRegistrationKind.ReflectedImplementation: + var reflectedRegistration = + registration.ReflectedImplementationRegistrations[orderedRegistration.Index]; + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + break; + case OrderedRegistrationKind.PreciseReflected: + var preciseRegistration = registration.PreciseReflectedRegistrations[orderedRegistration.Index]; + var registrationVariablePrefix = $"serviceType{registrationIndex}_{orderedRegistration.Index}"; + AppendPreciseReflectedTypeResolution( + builder, + preciseRegistration.ServiceTypeArguments, + registrationVariablePrefix, + implementationVariableName, + preciseRegistration.OpenHandlerTypeDisplayName, + registration.ImplementationLogName, + preciseRegistration.HandlerInterfaceLogName, + 3); + break; + default: + throw new InvalidOperationException( + $"Unsupported ordered CQRS registration kind {orderedRegistration.Kind}."); + } } builder.AppendLine(" }"); @@ -969,6 +1014,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceDisplayName, string HandlerInterfaceLogName); + private enum OrderedRegistrationKind + { + Direct, + ReflectedImplementation, + PreciseReflected + } + private sealed record RuntimeTypeReferenceSpec( string? TypeDisplayName, string? ReflectionTypeMetadataName,