diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index dcdb5e5f..679fd0fa 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -142,6 +142,39 @@ public class CqrsHandlerRegistryGeneratorTests """; + private const string HiddenImplementationDirectInterfaceRegistrationExpected = """ + // + #nullable enable + + [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))] + + namespace GFramework.Generated.Cqrs; + + internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry + { + public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger) + { + if (services is null) + throw new global::System.ArgumentNullException(nameof(services)); + if (logger is null) + throw new global::System.ArgumentNullException(nameof(logger)); + + var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly; + + var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false); + if (implementationType0 is not null) + { + global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient( + services, + typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler), + implementationType0); + logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler."); + } + } + } + + """; + /// /// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。 /// @@ -339,6 +372,78 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected)); } + /// + /// 验证当隐藏实现类型的 handler 接口仍可被生成代码直接引用时, + /// 生成器只会定向反射实现类型,而不会再生成基于 GetInterfaces() 的接口发现辅助逻辑。 + /// + [Test] + public async Task + Generates_Direct_Interface_Registrations_For_Hidden_Implementation_When_Handler_Interface_Is_Public() + { + const string source = """ + using System; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + using GFramework.Cqrs.Abstractions.Cqrs; + + public sealed record VisibleRequest() : IRequest; + + public sealed class Container + { + private sealed class HiddenHandler : IRequestHandler { } + } + } + """; + + await GeneratorTest.RunAsync( + source, + ("CqrsHandlerRegistry.g.cs", HiddenImplementationDirectInterfaceRegistrationExpected)); + } + /// /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 83559781..27b01f9f 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -86,22 +86,35 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var implementationLogName = GetLogDisplayName(type); - if (!CanReferenceFromGeneratedRegistry(type) || - handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) - { - // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...) - // expressions inside generated code. Preserve generator hit rate by resolving just that implementation - // type back from the current assembly instead of asking the runtime registrar to rescan the assembly. - return new HandlerCandidateAnalysis( - implementationTypeDisplayName, - implementationLogName, - ImmutableArray.Empty, - GetReflectionTypeMetadataName(type)); - } - + var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type); var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var reflectedImplementationRegistrations = + ImmutableArray.CreateBuilder(handlerInterfaces.Length); foreach (var handlerInterface in handlerInterfaces) { + var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); + if (!canReferenceImplementation || !canReferenceHandlerInterface) + { + if (!canReferenceImplementation && canReferenceHandlerInterface) + { + reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + GetLogDisplayName(handlerInterface))); + continue; + } + + // Non-public handlers closed over non-public message types still cannot be expressed purely as + // typeof(...) registrations. Preserve generator hit rate by resolving only the affected + // implementation back from the current assembly instead of asking the runtime registrar to rescan + // the whole assembly. + return new HandlerCandidateAnalysis( + implementationTypeDisplayName, + implementationLogName, + ImmutableArray.Empty, + ImmutableArray.Empty, + GetReflectionTypeMetadataName(type)); + } + registrations.Add(new HandlerRegistrationSpec( handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), implementationTypeDisplayName, @@ -112,8 +125,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return new HandlerCandidateAnalysis( implementationTypeDisplayName, implementationLogName, - registrations.MoveToImmutable(), - null); + registrations.ToImmutable(), + reflectedImplementationRegistrations.ToImmutable(), + canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); } private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, @@ -155,6 +169,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator candidate.ImplementationTypeDisplayName, candidate.ImplementationLogName, candidate.Registrations, + candidate.ReflectedImplementationRegistrations, candidate.ReflectionTypeMetadataName)); } @@ -302,8 +317,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private static string GenerateSource( IReadOnlyList registrations) { - var hasReflectionRegistrations = registrations.Any(static registration => - !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)); + var hasReflectedImplementationRegistrations = registrations.Any(static registration => + !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); + var hasFullReflectionRegistrations = registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && + registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -336,7 +354,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));"); builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); - if (hasReflectionRegistrations) + if (hasReflectedImplementationRegistrations || hasFullReflectionRegistrations) { builder.AppendLine(); builder.Append(" var registryAssembly = typeof(global::"); @@ -349,8 +367,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator if (registrations.Count > 0) builder.AppendLine(); - foreach (var registration in registrations) + for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++) { + var registration = registrations[registrationIndex]; + if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty) + { + AppendReflectedImplementationRegistrations(builder, registration, registrationIndex); + continue; + } + if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); @@ -378,7 +403,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" }"); - if (hasReflectionRegistrations) + if (hasFullReflectionRegistrations) { builder.AppendLine(); AppendReflectionHelpers(builder); @@ -395,6 +420,43 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine("\");"); } + private static void AppendReflectedImplementationRegistrations( + StringBuilder builder, + ImplementationRegistrationSpec registration, + int registrationIndex) + { + var implementationVariableName = $"implementationType{registrationIndex}"; + builder.Append(" var "); + builder.Append(implementationVariableName); + builder.Append(" = registryAssembly.GetType(\""); + builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!)); + builder.AppendLine("\", throwOnError: false, ignoreCase: false);"); + builder.Append(" if ("); + builder.Append(implementationVariableName); + builder.AppendLine(" is not null)"); + builder.AppendLine(" {"); + + foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations) + { + builder.AppendLine( + " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); + builder.AppendLine(" services,"); + builder.Append(" typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("),"); + builder.Append(" "); + builder.Append(implementationVariableName); + builder.AppendLine(");"); + builder.Append(" logger.Debug(\"Registered CQRS handler "); + builder.Append(EscapeStringLiteral(registration.ImplementationLogName)); + builder.Append(" as "); + builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName)); + builder.AppendLine(".\");"); + } + + builder.AppendLine(" }"); + } + private static void AppendReflectionHelpers(StringBuilder builder) { // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. @@ -523,10 +585,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string HandlerInterfaceLogName, string ImplementationLogName); + private readonly record struct ReflectedImplementationRegistrationSpec( + string HandlerInterfaceDisplayName, + string HandlerInterfaceLogName); + private readonly record struct ImplementationRegistrationSpec( string ImplementationTypeDisplayName, string ImplementationLogName, ImmutableArray DirectRegistrations, + ImmutableArray ReflectedImplementationRegistrations, string? ReflectionTypeMetadataName); private readonly struct HandlerCandidateAnalysis : IEquatable @@ -535,11 +602,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string implementationTypeDisplayName, string implementationLogName, ImmutableArray registrations, + ImmutableArray reflectedImplementationRegistrations, string? reflectionTypeMetadataName) { ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationLogName = implementationLogName; Registrations = registrations; + ReflectedImplementationRegistrations = reflectedImplementationRegistrations; ReflectionTypeMetadataName = reflectionTypeMetadataName; } @@ -549,6 +618,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator public ImmutableArray Registrations { get; } + public ImmutableArray ReflectedImplementationRegistrations { get; } + public string? ReflectionTypeMetadataName { get; } public bool Equals(HandlerCandidateAnalysis other) @@ -558,7 +629,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || - Registrations.Length != other.Registrations.Length) + Registrations.Length != other.Registrations.Length || + ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length) { return false; } @@ -569,6 +641,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } + for (var index = 0; index < ReflectedImplementationRegistrations.Length; index++) + { + if (!ReflectedImplementationRegistrations[index].Equals( + other.ReflectedImplementationRegistrations[index])) + return false; + } + return true; } @@ -592,6 +671,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator hashCode = (hashCode * 397) ^ registration.GetHashCode(); } + foreach (var reflectedImplementationRegistration in ReflectedImplementationRegistrations) + { + hashCode = (hashCode * 397) ^ reflectedImplementationRegistration.GetHashCode(); + } + return hashCode; } }