diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index dcdb5e5f..679fd0fa 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -142,6 +142,39 @@ public class CqrsHandlerRegistryGeneratorTests
""";
+ private const string HiddenImplementationDirectInterfaceRegistrationExpected = """
+ //
+ #nullable enable
+
+ [assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
+
+ namespace GFramework.Generated.Cqrs;
+
+ internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
+ {
+ public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
+ {
+ if (services is null)
+ throw new global::System.ArgumentNullException(nameof(services));
+ if (logger is null)
+ throw new global::System.ArgumentNullException(nameof(logger));
+
+ var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
+
+ var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false);
+ if (implementationType0 is not null)
+ {
+ global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
+ services,
+ typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler),
+ implementationType0);
+ logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.");
+ }
+ }
+ }
+
+ """;
+
///
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
///
@@ -339,6 +372,78 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", HiddenNestedHandlerSelfRegistrationExpected));
}
+ ///
+ /// 验证当隐藏实现类型的 handler 接口仍可被生成代码直接引用时,
+ /// 生成器只会定向反射实现类型,而不会再生成基于 GetInterfaces() 的接口发现辅助逻辑。
+ ///
+ [Test]
+ public async Task
+ Generates_Direct_Interface_Registrations_For_Hidden_Implementation_When_Handler_Interface_Is_Public()
+ {
+ const string source = """
+ using System;
+
+ namespace Microsoft.Extensions.DependencyInjection
+ {
+ public interface IServiceCollection { }
+
+ public static class ServiceCollectionServiceExtensions
+ {
+ public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
+ }
+ }
+
+ namespace GFramework.Core.Abstractions.Logging
+ {
+ public interface ILogger
+ {
+ void Debug(string msg);
+ }
+ }
+
+ namespace GFramework.Cqrs.Abstractions.Cqrs
+ {
+ public interface IRequest { }
+ public interface INotification { }
+ public interface IStreamRequest { }
+
+ public interface IRequestHandler where TRequest : IRequest { }
+ public interface INotificationHandler where TNotification : INotification { }
+ public interface IStreamRequestHandler where TRequest : IStreamRequest { }
+ }
+
+ namespace GFramework.Cqrs
+ {
+ public interface ICqrsHandlerRegistry
+ {
+ void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
+ }
+
+ [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
+ public sealed class CqrsHandlerRegistryAttribute : Attribute
+ {
+ public CqrsHandlerRegistryAttribute(Type registryType) { }
+ }
+ }
+
+ namespace TestApp
+ {
+ using GFramework.Cqrs.Abstractions.Cqrs;
+
+ public sealed record VisibleRequest() : IRequest;
+
+ public sealed class Container
+ {
+ private sealed class HiddenHandler : IRequestHandler { }
+ }
+ }
+ """;
+
+ await GeneratorTest.RunAsync(
+ source,
+ ("CqrsHandlerRegistry.g.cs", HiddenImplementationDirectInterfaceRegistrationExpected));
+ }
+
///
/// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler,
/// 不再输出 fallback marker。
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 83559781..27b01f9f 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -86,22 +86,35 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var implementationLogName = GetLogDisplayName(type);
- if (!CanReferenceFromGeneratedRegistry(type) ||
- handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
- {
- // Non-public handlers and handlers closed over non-public message types cannot appear in typeof(...)
- // expressions inside generated code. Preserve generator hit rate by resolving just that implementation
- // type back from the current assembly instead of asking the runtime registrar to rescan the assembly.
- return new HandlerCandidateAnalysis(
- implementationTypeDisplayName,
- implementationLogName,
- ImmutableArray.Empty,
- GetReflectionTypeMetadataName(type));
- }
-
+ var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type);
var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length);
+ var reflectedImplementationRegistrations =
+ ImmutableArray.CreateBuilder(handlerInterfaces.Length);
foreach (var handlerInterface in handlerInterfaces)
{
+ var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface);
+ if (!canReferenceImplementation || !canReferenceHandlerInterface)
+ {
+ if (!canReferenceImplementation && canReferenceHandlerInterface)
+ {
+ reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec(
+ handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
+ GetLogDisplayName(handlerInterface)));
+ continue;
+ }
+
+ // Non-public handlers closed over non-public message types still cannot be expressed purely as
+ // typeof(...) registrations. Preserve generator hit rate by resolving only the affected
+ // implementation back from the current assembly instead of asking the runtime registrar to rescan
+ // the whole assembly.
+ return new HandlerCandidateAnalysis(
+ implementationTypeDisplayName,
+ implementationLogName,
+ ImmutableArray.Empty,
+ ImmutableArray.Empty,
+ GetReflectionTypeMetadataName(type));
+ }
+
registrations.Add(new HandlerRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
implementationTypeDisplayName,
@@ -112,8 +125,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
implementationLogName,
- registrations.MoveToImmutable(),
- null);
+ registrations.ToImmutable(),
+ reflectedImplementationRegistrations.ToImmutable(),
+ canReferenceImplementation ? null : GetReflectionTypeMetadataName(type));
}
private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment,
@@ -155,6 +169,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
candidate.ImplementationTypeDisplayName,
candidate.ImplementationLogName,
candidate.Registrations,
+ candidate.ReflectedImplementationRegistrations,
candidate.ReflectionTypeMetadataName));
}
@@ -302,8 +317,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static string GenerateSource(
IReadOnlyList registrations)
{
- var hasReflectionRegistrations = registrations.Any(static registration =>
- !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName));
+ var hasReflectedImplementationRegistrations = registrations.Any(static registration =>
+ !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
+ var hasFullReflectionRegistrations = registrations.Any(static registration =>
+ !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
+ registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
var builder = new StringBuilder();
builder.AppendLine("// ");
builder.AppendLine("#nullable enable");
@@ -336,7 +354,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));");
builder.AppendLine(" if (logger is null)");
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));");
- if (hasReflectionRegistrations)
+ if (hasReflectedImplementationRegistrations || hasFullReflectionRegistrations)
{
builder.AppendLine();
builder.Append(" var registryAssembly = typeof(global::");
@@ -349,8 +367,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
if (registrations.Count > 0)
builder.AppendLine();
- foreach (var registration in registrations)
+ for (var registrationIndex = 0; registrationIndex < registrations.Count; registrationIndex++)
{
+ var registration = registrations[registrationIndex];
+ if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty)
+ {
+ AppendReflectedImplementationRegistrations(builder, registration, registrationIndex);
+ continue;
+ }
+
if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
@@ -378,7 +403,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" }");
- if (hasReflectionRegistrations)
+ if (hasFullReflectionRegistrations)
{
builder.AppendLine();
AppendReflectionHelpers(builder);
@@ -395,6 +420,43 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine("\");");
}
+ private static void AppendReflectedImplementationRegistrations(
+ StringBuilder builder,
+ ImplementationRegistrationSpec registration,
+ int registrationIndex)
+ {
+ var implementationVariableName = $"implementationType{registrationIndex}";
+ builder.Append(" var ");
+ builder.Append(implementationVariableName);
+ builder.Append(" = registryAssembly.GetType(\"");
+ builder.Append(EscapeStringLiteral(registration.ReflectionTypeMetadataName!));
+ builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
+ builder.Append(" if (");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(" is not null)");
+ builder.AppendLine(" {");
+
+ foreach (var reflectedRegistration in registration.ReflectedImplementationRegistrations)
+ {
+ builder.AppendLine(
+ " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
+ builder.AppendLine(" services,");
+ builder.Append(" typeof(");
+ builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
+ builder.AppendLine("),");
+ builder.Append(" ");
+ builder.Append(implementationVariableName);
+ builder.AppendLine(");");
+ builder.Append(" logger.Debug(\"Registered CQRS handler ");
+ builder.Append(EscapeStringLiteral(registration.ImplementationLogName));
+ builder.Append(" as ");
+ builder.Append(EscapeStringLiteral(reflectedRegistration.HandlerInterfaceLogName));
+ builder.AppendLine(".\");");
+ }
+
+ builder.AppendLine(" }");
+ }
+
private static void AppendReflectionHelpers(StringBuilder builder)
{
// Emit the runtime helper methods only when at least one handler requires metadata-name lookup.
@@ -523,10 +585,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceLogName,
string ImplementationLogName);
+ private readonly record struct ReflectedImplementationRegistrationSpec(
+ string HandlerInterfaceDisplayName,
+ string HandlerInterfaceLogName);
+
private readonly record struct ImplementationRegistrationSpec(
string ImplementationTypeDisplayName,
string ImplementationLogName,
ImmutableArray DirectRegistrations,
+ ImmutableArray ReflectedImplementationRegistrations,
string? ReflectionTypeMetadataName);
private readonly struct HandlerCandidateAnalysis : IEquatable
@@ -535,11 +602,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string implementationTypeDisplayName,
string implementationLogName,
ImmutableArray registrations,
+ ImmutableArray reflectedImplementationRegistrations,
string? reflectionTypeMetadataName)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
ImplementationLogName = implementationLogName;
Registrations = registrations;
+ ReflectedImplementationRegistrations = reflectedImplementationRegistrations;
ReflectionTypeMetadataName = reflectionTypeMetadataName;
}
@@ -549,6 +618,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public ImmutableArray Registrations { get; }
+ public ImmutableArray ReflectedImplementationRegistrations { get; }
+
public string? ReflectionTypeMetadataName { get; }
public bool Equals(HandlerCandidateAnalysis other)
@@ -558,7 +629,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) ||
!string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName,
StringComparison.Ordinal) ||
- Registrations.Length != other.Registrations.Length)
+ Registrations.Length != other.Registrations.Length ||
+ ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length)
{
return false;
}
@@ -569,6 +641,13 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false;
}
+ for (var index = 0; index < ReflectedImplementationRegistrations.Length; index++)
+ {
+ if (!ReflectedImplementationRegistrations[index].Equals(
+ other.ReflectedImplementationRegistrations[index]))
+ return false;
+ }
+
return true;
}
@@ -592,6 +671,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
hashCode = (hashCode * 397) ^ registration.GetHashCode();
}
+ foreach (var reflectedImplementationRegistration in ReflectedImplementationRegistrations)
+ {
+ hashCode = (hashCode * 397) ^ reflectedImplementationRegistration.GetHashCode();
+ }
+
return hashCode;
}
}