feat(cqrs): 添加CQRS处理器注册器源代码生成器

- 实现了CqrsHandlerRegistryGenerator源代码生成器
- 为CQRS处理器减少运行时程序集反射扫描开销
- 支持IRequestHandler、INotificationHandler和IStreamRequestHandler接口
- 提供静态类型引用和运行时反射发现的混合注册策略
- 生成服务注册代码并添加调试日志记录功能
- 实现精确的运行时类型引用描述和泛型类型处理
This commit is contained in:
GeWuYou 2026-04-16 18:41:20 +08:00
parent 1792fafc85
commit eca2d67529
4 changed files with 347 additions and 70 deletions

View File

@ -16,6 +16,24 @@ public static class GeneratorTest<TGenerator>
public static async Task RunAsync( public static async Task RunAsync(
string source, string source,
params (string filename, string content)[] generatedSources) params (string filename, string content)[] generatedSources)
{
await RunAsync(
source,
additionalReferences: [],
generatedSources);
}
/// <summary>
/// 运行源代码生成器测试,并为测试编译显式追加元数据引用。
/// </summary>
/// <param name="source">输入的源代码。</param>
/// <param name="additionalReferences">附加元数据引用,用于构造多程序集场景。</param>
/// <param name="generatedSources">期望生成的源文件集合,包含文件名和内容的元组。</param>
/// <returns>异步操作任务。</returns>
public static async Task RunAsync(
string source,
IEnumerable<MetadataReference> additionalReferences,
params (string filename, string content)[] generatedSources)
{ {
var test = new CSharpSourceGeneratorTest<TGenerator, DefaultVerifier> var test = new CSharpSourceGeneratorTest<TGenerator, DefaultVerifier>
{ {
@ -31,6 +49,9 @@ public static class GeneratorTest<TGenerator>
test.TestState.GeneratedSources.Add( test.TestState.GeneratedSources.Add(
(typeof(TGenerator), filename, NormalizeLineEndings(content))); (typeof(TGenerator), filename, NormalizeLineEndings(content)));
foreach (var additionalReference in additionalReferences)
test.TestState.AdditionalReferences.Add(additionalReference);
await test.RunAsync(); await test.RunAsync();
} }

View File

@ -0,0 +1,65 @@
using System.Collections.Immutable;
using System.IO;
using Microsoft.CodeAnalysis.CSharp;
namespace GFramework.SourceGenerators.Tests.Core;
/// <summary>
/// 为多程序集源生成器测试构建内存元数据引用。
/// </summary>
public static class MetadataReferenceTestBuilder
{
/// <summary>
/// 将给定源码编译为内存程序集,并返回可供测试编译消费的元数据引用。
/// </summary>
/// <param name="assemblyName">目标程序集名称。</param>
/// <param name="source">待编译源码。</param>
/// <param name="additionalReferences">附加元数据引用,用于构造依赖链。</param>
/// <returns>编译成功后的内存元数据引用。</returns>
public static MetadataReference CreateFromSource(
string assemblyName,
string source,
params MetadataReference[] additionalReferences)
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var references = GetRuntimeMetadataReferences()
.Concat(additionalReferences)
.ToImmutableArray();
var compilation = CSharpCompilation.Create(
assemblyName,
[syntaxTree],
references,
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
using var stream = new MemoryStream();
var emitResult = compilation.Emit(stream);
if (!emitResult.Success)
{
var diagnostics = string.Join(
Environment.NewLine,
emitResult.Diagnostics
.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.Select(static diagnostic => diagnostic.ToString()));
throw new InvalidOperationException(
$"Failed to build metadata reference '{assemblyName}'.{Environment.NewLine}{diagnostics}");
}
stream.Position = 0;
return MetadataReference.CreateFromImage(stream.ToArray());
}
/// <summary>
/// 获取当前测试运行时可直接复用的基础元数据引用集合。
/// </summary>
/// <returns>当前运行时可信平台程序集对应的元数据引用。</returns>
public static ImmutableArray<MetadataReference> GetRuntimeMetadataReferences()
{
var trustedPlatformAssemblies = ((string?)AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES"))?
.Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries)
?? Array.Empty<string>();
return trustedPlatformAssemblies
.Select(static path => (MetadataReference)MetadataReference.CreateFromFile(path))
.ToImmutableArray();
}
}

View File

@ -1,6 +1,7 @@
using System.Reflection; using System.Reflection;
using GFramework.SourceGenerators.Cqrs; using GFramework.SourceGenerators.Cqrs;
using GFramework.SourceGenerators.Tests.Core; using GFramework.SourceGenerators.Tests.Core;
using Microsoft.CodeAnalysis.CSharp;
namespace GFramework.SourceGenerators.Tests.Cqrs; namespace GFramework.SourceGenerators.Tests.Cqrs;
@ -825,6 +826,135 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected)); ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected));
} }
/// <summary>
/// 验证当外部基类暴露的 handler interface 含有生成注册器顶层上下文不可直接引用的 protected 类型时,
/// 生成器会保留已知直注册,并只对剩余未知接口做本地 interface discovery。
/// </summary>
[Test]
public void Generates_Partial_Runtime_Interface_Discovery_For_Inaccessible_External_Protected_Types()
{
const string contractsSource = """
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
""";
const string dependencySource = """
using GFramework.Cqrs.Abstractions.Cqrs;
namespace Dep;
public sealed record VisibleRequest() : IRequest<string>;
public abstract class VisibilityScope
{
protected internal sealed record ProtectedResponse();
protected internal sealed record ProtectedRequest() : IRequest<ProtectedResponse[]>;
}
public abstract class HandlerBase :
VisibilityScope,
IRequestHandler<VisibleRequest, string>,
IRequestHandler<VisibilityScope.ProtectedRequest, VisibilityScope.ProtectedResponse[]>
{
}
""";
const string source = """
using System;
using Dep;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
public sealed class DerivedHandler : HandlerBase
{
}
}
""";
var contractsReference = MetadataReferenceTestBuilder.CreateFromSource(
"Contracts",
contractsSource);
var dependencyReference = MetadataReferenceTestBuilder.CreateFromSource(
"Dependency",
dependencySource,
contractsReference);
var generatedSource = RunGenerator(
source,
contractsReference,
dependencyReference);
Assert.Multiple(() =>
{
Assert.That(
generatedSource,
Does.Contain("var implementationType0 = typeof(global::TestApp.DerivedHandler);"));
Assert.That(
generatedSource,
Does.Contain(
"var knownServiceTypes0 = new global::System.Collections.Generic.HashSet<global::System.Type>();"));
Assert.That(
generatedSource,
Does.Contain(
"knownServiceTypes0.Add(typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::Dep.VisibleRequest, string>));"));
Assert.That(
generatedSource,
Does.Contain(
"RegisterRemainingReflectedHandlerInterfaces(services, logger, implementationType0, knownServiceTypes0);"));
Assert.That(
generatedSource,
Does.Contain("if (knownServiceTypes.Contains(handlerInterface))"));
Assert.That(
generatedSource,
Does.Contain(
"Registered CQRS handler TestApp.DerivedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<Dep.VisibleRequest, string>."));
Assert.That(
generatedSource,
Does.Not.Contain(
"typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::Dep.VisibilityScope.ProtectedRequest"));
});
}
/// <summary> /// <summary>
/// 验证即使 runtime 仍暴露旧版无参 fallback marker生成器也会优先在生成注册器内部处理隐藏 handler /// 验证即使 runtime 仍暴露旧版无参 fallback marker生成器也会优先在生成注册器内部处理隐藏 handler
/// 不再输出 fallback marker。 /// 不再输出 fallback marker。
@ -999,4 +1129,38 @@ public class CqrsHandlerRegistryGeneratorTests
Assert.That(escaped, Is.EqualTo(expected)); Assert.That(escaped, Is.EqualTo(expected));
} }
/// <summary>
/// 运行 CQRS handler registry generator并返回单个生成文件的源码文本。
/// </summary>
private static string RunGenerator(
string source,
params MetadataReference[] additionalReferences)
{
var syntaxTree = CSharpSyntaxTree.ParseText(source);
var compilation = CSharpCompilation.Create(
"TestProject",
[syntaxTree],
MetadataReferenceTestBuilder.GetRuntimeMetadataReferences().AddRange(additionalReferences),
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
GeneratorDriver driver = CSharpGeneratorDriver.Create(
generators: [new CqrsHandlerRegistryGenerator().AsSourceGenerator()],
parseOptions: (CSharpParseOptions)syntaxTree.Options);
driver = driver.RunGeneratorsAndUpdateCompilation(
compilation,
out var updatedCompilation,
out _);
var compilationErrors = updatedCompilation.GetDiagnostics()
.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.ToArray();
Assert.That(compilationErrors, Is.Empty, string.Join(Environment.NewLine, compilationErrors));
var runResult = driver.GetRunResult();
Assert.That(runResult.Results, Has.Length.EqualTo(1));
Assert.That(runResult.Results[0].GeneratedSources, Has.Length.EqualTo(1));
return runResult.Results[0].GeneratedSources[0].SourceText.ToString();
}
} }

View File

@ -86,15 +86,17 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var implementationLogName = GetLogDisplayName(type); var implementationLogName = GetLogDisplayName(type);
var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type); var canReferenceImplementation = CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, type);
var registrations = ImmutableArray.CreateBuilder<HandlerRegistrationSpec>(handlerInterfaces.Length); var registrations = ImmutableArray.CreateBuilder<HandlerRegistrationSpec>(handlerInterfaces.Length);
var reflectedImplementationRegistrations = var reflectedImplementationRegistrations =
ImmutableArray.CreateBuilder<ReflectedImplementationRegistrationSpec>(handlerInterfaces.Length); ImmutableArray.CreateBuilder<ReflectedImplementationRegistrationSpec>(handlerInterfaces.Length);
var preciseReflectedRegistrations = var preciseReflectedRegistrations =
ImmutableArray.CreateBuilder<PreciseReflectedRegistrationSpec>(handlerInterfaces.Length); ImmutableArray.CreateBuilder<PreciseReflectedRegistrationSpec>(handlerInterfaces.Length);
var requiresRuntimeInterfaceDiscovery = false;
foreach (var handlerInterface in handlerInterfaces) foreach (var handlerInterface in handlerInterfaces)
{ {
var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); var canReferenceHandlerInterface =
CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, handlerInterface);
if (canReferenceImplementation && canReferenceHandlerInterface) if (canReferenceImplementation && canReferenceHandlerInterface)
{ {
registrations.Add(new HandlerRegistrationSpec( registrations.Add(new HandlerRegistrationSpec(
@ -122,16 +124,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
continue; continue;
} }
// Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over // 某些关闭 handler interface 仍包含只能在实现类型运行时语义里解析的类型形态。
// non-public element types. For those rare cases keep the narrow implementation lookup, but let the // 对这些边角场景保留“已知接口静态注册 + 剩余接口运行时补洞”的组合路径,
// generated registry discover the exact supported interfaces from the implementation type at runtime. // 避免单个未知接口把同实现上的其它已知注册全部拖回整实现反射发现。
return new HandlerCandidateAnalysis( requiresRuntimeInterfaceDiscovery = true;
implementationTypeDisplayName,
implementationLogName,
ImmutableArray<HandlerRegistrationSpec>.Empty,
ImmutableArray<ReflectedImplementationRegistrationSpec>.Empty,
ImmutableArray<PreciseReflectedRegistrationSpec>.Empty,
GetReflectionTypeMetadataName(type));
} }
return new HandlerCandidateAnalysis( return new HandlerCandidateAnalysis(
@ -140,7 +136,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
registrations.ToImmutable(), registrations.ToImmutable(),
reflectedImplementationRegistrations.ToImmutable(), reflectedImplementationRegistrations.ToImmutable(),
preciseReflectedRegistrations.ToImmutable(), preciseReflectedRegistrations.ToImmutable(),
canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); canReferenceImplementation ? null : GetReflectionTypeMetadataName(type),
requiresRuntimeInterfaceDiscovery);
} }
private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment,
@ -184,7 +181,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
candidate.Registrations, candidate.Registrations,
candidate.ReflectedImplementationRegistrations, candidate.ReflectedImplementationRegistrations,
candidate.PreciseReflectedRegistrations, candidate.PreciseReflectedRegistrations,
candidate.ReflectionTypeMetadataName)); candidate.ReflectionTypeMetadataName,
candidate.RequiresRuntimeInterfaceDiscovery));
} }
registrations.Sort(static (left, right) => registrations.Sort(static (left, right) =>
@ -295,7 +293,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ITypeSymbol type, ITypeSymbol type,
out RuntimeTypeReferenceSpec? runtimeTypeReference) out RuntimeTypeReferenceSpec? runtimeTypeReference)
{ {
if (CanReferenceFromGeneratedRegistry(type)) if (CanReferenceFromGeneratedRegistry(compilation, type))
{ {
runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference( runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference(
type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
@ -369,7 +367,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
out RuntimeTypeReferenceSpec? genericTypeDefinitionReference) out RuntimeTypeReferenceSpec? genericTypeDefinitionReference)
{ {
var genericTypeDefinition = genericNamedType.OriginalDefinition; var genericTypeDefinition = genericNamedType.OriginalDefinition;
if (CanReferenceFromGeneratedRegistry(genericTypeDefinition)) if (CanReferenceFromGeneratedRegistry(compilation, genericTypeDefinition))
{ {
genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference( genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference(
genericTypeDefinition genericTypeDefinition
@ -389,19 +387,25 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false; return false;
} }
private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) private static bool CanReferenceFromGeneratedRegistry(Compilation compilation, ITypeSymbol type)
{ {
switch (type) switch (type)
{ {
case IArrayTypeSymbol arrayType: case IArrayTypeSymbol arrayType:
return CanReferenceFromGeneratedRegistry(arrayType.ElementType); return CanReferenceFromGeneratedRegistry(compilation, arrayType.ElementType);
case INamedTypeSymbol namedType: case INamedTypeSymbol namedType:
if (!IsTypeChainAccessible(namedType)) if (!compilation.IsSymbolAccessibleWithin(namedType, compilation.Assembly, throughType: null))
return false; return false;
return namedType.TypeArguments.All(CanReferenceFromGeneratedRegistry); foreach (var typeArgument in namedType.TypeArguments)
{
if (!CanReferenceFromGeneratedRegistry(compilation, typeArgument))
return false;
}
return true;
case IPointerTypeSymbol pointerType: case IPointerTypeSymbol pointerType:
return CanReferenceFromGeneratedRegistry(pointerType.PointedAtType); return CanReferenceFromGeneratedRegistry(compilation, pointerType.PointedAtType);
case ITypeParameterSymbol: case ITypeParameterSymbol:
return false; return false;
default: default:
@ -409,23 +413,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
} }
} }
private static bool IsTypeChainAccessible(INamedTypeSymbol type)
{
for (var current = type; current is not null; current = current.ContainingType)
{
if (!IsSymbolAccessible(current))
return false;
}
return true;
}
private static bool IsSymbolAccessible(ISymbol symbol)
{
return symbol.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal
or Accessibility.ProtectedOrInternal;
}
private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type) private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type)
{ {
var nestedTypes = new Stack<string>(); var nestedTypes = new Stack<string>();
@ -496,10 +483,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
var hasPreciseReflectedRegistrations = registrations.Any(static registration => var hasPreciseReflectedRegistrations = registrations.Any(static registration =>
!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var hasFullReflectionRegistrations = registrations.Any(static registration => var hasRuntimeInterfaceDiscovery = registrations.Any(static registration =>
!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && registration.RequiresRuntimeInterfaceDiscovery);
registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var builder = new StringBuilder(); var builder = new StringBuilder();
builder.AppendLine("// <auto-generated />"); builder.AppendLine("// <auto-generated />");
builder.AppendLine("#nullable enable"); builder.AppendLine("#nullable enable");
@ -533,7 +518,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" if (logger is null)"); builder.AppendLine(" if (logger is null)");
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));");
if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations ||
hasFullReflectionRegistrations) registrations.Any(static registration =>
!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)))
{ {
builder.AppendLine(); builder.AppendLine();
builder.Append(" var registryAssembly = typeof(global::"); builder.Append(" var registryAssembly = typeof(global::");
@ -550,7 +536,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
{ {
var registration = registrations[registrationIndex]; var registration = registrations[registrationIndex];
if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty || if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty ||
!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty ||
registration.RequiresRuntimeInterfaceDiscovery)
{ {
AppendOrderedImplementationRegistrations(builder, registration, registrationIndex); AppendOrderedImplementationRegistrations(builder, registration, registrationIndex);
} }
@ -558,19 +545,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
{ {
AppendDirectRegistrations(builder, registration); AppendDirectRegistrations(builder, registration);
} }
if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
registration.PreciseReflectedRegistrations.IsDefaultOrEmpty &&
registration.DirectRegistrations.IsDefaultOrEmpty)
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
}
} }
builder.AppendLine(" }"); builder.AppendLine(" }");
if (hasFullReflectionRegistrations) if (hasRuntimeInterfaceDiscovery)
{ {
builder.AppendLine(); builder.AppendLine();
AppendReflectionHelpers(builder); AppendReflectionHelpers(builder);
@ -580,13 +559,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return builder.ToString(); return builder.ToString();
} }
private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName)
{
builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \"");
builder.Append(EscapeStringLiteral(reflectionTypeMetadataName));
builder.AppendLine("\");");
}
private static void AppendDirectRegistrations( private static void AppendDirectRegistrations(
StringBuilder builder, StringBuilder builder,
ImplementationRegistrationSpec registration) ImplementationRegistrationSpec registration)
@ -653,6 +625,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName)); StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName));
var implementationVariableName = $"implementationType{registrationIndex}"; var implementationVariableName = $"implementationType{registrationIndex}";
var knownServiceTypesVariableName = $"knownServiceTypes{registrationIndex}";
if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{ {
builder.Append(" var "); builder.Append(" var ");
@ -675,12 +648,28 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" is not null)"); builder.AppendLine(" is not null)");
builder.AppendLine(" {"); builder.AppendLine(" {");
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" var ");
builder.Append(knownServiceTypesVariableName);
builder.AppendLine(" = new global::System.Collections.Generic.HashSet<global::System.Type>();");
}
foreach (var orderedRegistration in orderedRegistrations) foreach (var orderedRegistration in orderedRegistrations)
{ {
switch (orderedRegistration.Kind) switch (orderedRegistration.Kind)
{ {
case OrderedRegistrationKind.Direct: case OrderedRegistrationKind.Direct:
var directRegistration = registration.DirectRegistrations[orderedRegistration.Index]; var directRegistration = registration.DirectRegistrations[orderedRegistration.Index];
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" ");
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(typeof(");
builder.Append(directRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("));");
}
builder.AppendLine( builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,"); builder.AppendLine(" services,");
@ -699,6 +688,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
case OrderedRegistrationKind.ReflectedImplementation: case OrderedRegistrationKind.ReflectedImplementation:
var reflectedRegistration = var reflectedRegistration =
registration.ReflectedImplementationRegistrations[orderedRegistration.Index]; registration.ReflectedImplementationRegistrations[orderedRegistration.Index];
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" ");
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(typeof(");
builder.Append(reflectedRegistration.HandlerInterfaceDisplayName);
builder.AppendLine("));");
}
builder.AppendLine( builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,"); builder.AppendLine(" services,");
@ -725,6 +723,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
preciseRegistration.OpenHandlerTypeDisplayName, preciseRegistration.OpenHandlerTypeDisplayName,
registration.ImplementationLogName, registration.ImplementationLogName,
preciseRegistration.HandlerInterfaceLogName, preciseRegistration.HandlerInterfaceLogName,
knownServiceTypesVariableName,
registration.RequiresRuntimeInterfaceDiscovery,
3); 3);
break; break;
default: default:
@ -733,6 +733,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
} }
} }
if (registration.RequiresRuntimeInterfaceDiscovery)
{
builder.Append(" RegisterRemainingReflectedHandlerInterfaces(services, logger, ");
builder.Append(implementationVariableName);
builder.Append(", ");
builder.Append(knownServiceTypesVariableName);
builder.AppendLine(");");
}
builder.AppendLine(" }"); builder.AppendLine(" }");
} }
@ -744,6 +753,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string openHandlerTypeDisplayName, string openHandlerTypeDisplayName,
string implementationLogName, string implementationLogName,
string handlerInterfaceLogName, string handlerInterfaceLogName,
string knownServiceTypesVariableName,
bool trackKnownServiceTypes,
int indentLevel) int indentLevel)
{ {
var indent = new string(' ', indentLevel * 4); var indent = new string(' ', indentLevel * 4);
@ -808,6 +819,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.Append(" "); builder.Append(" ");
builder.Append(implementationVariableName); builder.Append(implementationVariableName);
builder.AppendLine(");"); builder.AppendLine(");");
if (trackKnownServiceTypes)
{
builder.Append(indent);
builder.Append(knownServiceTypesVariableName);
builder.Append(".Add(");
builder.Append(registrationVariablePrefix);
builder.AppendLine(");");
}
builder.Append(indent); builder.Append(indent);
builder.Append("logger.Debug(\"Registered CQRS handler "); builder.Append("logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(implementationLogName)); builder.Append(EscapeStringLiteral(implementationLogName));
@ -884,15 +904,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static void AppendReflectionHelpers(StringBuilder builder) private static void AppendReflectionHelpers(StringBuilder builder)
{ {
// Emit the runtime helper methods only when at least one handler requires metadata-name lookup. // Emit the runtime helper methods only when at least one handler still needs implementation-scoped
// interface discovery after all direct / precise registrations have been emitted.
builder.AppendLine( builder.AppendLine(
" private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)"); " private static void RegisterRemainingReflectedHandlerInterfaces(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Type implementationType, global::System.Collections.Generic.ISet<global::System.Type> knownServiceTypes)");
builder.AppendLine(" {"); builder.AppendLine(" {");
builder.AppendLine(
" var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);");
builder.AppendLine(" if (implementationType is null)");
builder.AppendLine(" return;");
builder.AppendLine();
builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();"); builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();");
builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);"); builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);");
builder.AppendLine(); builder.AppendLine();
@ -901,6 +917,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))"); builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))");
builder.AppendLine(" continue;"); builder.AppendLine(" continue;");
builder.AppendLine(); builder.AppendLine();
builder.AppendLine(" if (knownServiceTypes.Contains(handlerInterface))");
builder.AppendLine(" continue;");
builder.AppendLine();
builder.AppendLine( builder.AppendLine(
" global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.AppendLine(" services,"); builder.AppendLine(" services,");
@ -908,6 +927,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" implementationType);"); builder.AppendLine(" implementationType);");
builder.AppendLine( builder.AppendLine(
" logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");"); " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");");
builder.AppendLine(" knownServiceTypes.Add(handlerInterface);");
builder.AppendLine(" }"); builder.AppendLine(" }");
builder.AppendLine(" }"); builder.AppendLine(" }");
builder.AppendLine(); builder.AppendLine();
@ -1067,7 +1087,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ImmutableArray<HandlerRegistrationSpec> DirectRegistrations, ImmutableArray<HandlerRegistrationSpec> DirectRegistrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> ReflectedImplementationRegistrations, ImmutableArray<ReflectedImplementationRegistrationSpec> ReflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> PreciseReflectedRegistrations, ImmutableArray<PreciseReflectedRegistrationSpec> PreciseReflectedRegistrations,
string? ReflectionTypeMetadataName); string? ReflectionTypeMetadataName,
bool RequiresRuntimeInterfaceDiscovery);
private readonly struct HandlerCandidateAnalysis : IEquatable<HandlerCandidateAnalysis> private readonly struct HandlerCandidateAnalysis : IEquatable<HandlerCandidateAnalysis>
{ {
@ -1077,7 +1098,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ImmutableArray<HandlerRegistrationSpec> registrations, ImmutableArray<HandlerRegistrationSpec> registrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> reflectedImplementationRegistrations, ImmutableArray<ReflectedImplementationRegistrationSpec> reflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> preciseReflectedRegistrations, ImmutableArray<PreciseReflectedRegistrationSpec> preciseReflectedRegistrations,
string? reflectionTypeMetadataName) string? reflectionTypeMetadataName,
bool requiresRuntimeInterfaceDiscovery)
{ {
ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationTypeDisplayName = implementationTypeDisplayName;
ImplementationLogName = implementationLogName; ImplementationLogName = implementationLogName;
@ -1085,6 +1107,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
ReflectedImplementationRegistrations = reflectedImplementationRegistrations; ReflectedImplementationRegistrations = reflectedImplementationRegistrations;
PreciseReflectedRegistrations = preciseReflectedRegistrations; PreciseReflectedRegistrations = preciseReflectedRegistrations;
ReflectionTypeMetadataName = reflectionTypeMetadataName; ReflectionTypeMetadataName = reflectionTypeMetadataName;
RequiresRuntimeInterfaceDiscovery = requiresRuntimeInterfaceDiscovery;
} }
public string ImplementationTypeDisplayName { get; } public string ImplementationTypeDisplayName { get; }
@ -1099,6 +1122,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public string? ReflectionTypeMetadataName { get; } public string? ReflectionTypeMetadataName { get; }
public bool RequiresRuntimeInterfaceDiscovery { get; }
public bool Equals(HandlerCandidateAnalysis other) public bool Equals(HandlerCandidateAnalysis other)
{ {
if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
@ -1106,6 +1131,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) ||
!string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName,
StringComparison.Ordinal) || StringComparison.Ordinal) ||
RequiresRuntimeInterfaceDiscovery != other.RequiresRuntimeInterfaceDiscovery ||
Registrations.Length != other.Registrations.Length || Registrations.Length != other.Registrations.Length ||
ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length || ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length ||
PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length) PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length)
@ -1150,6 +1176,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
(ReflectionTypeMetadataName is null (ReflectionTypeMetadataName is null
? 0 ? 0
: StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName)); : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName));
hashCode = (hashCode * 397) ^ RequiresRuntimeInterfaceDiscovery.GetHashCode();
foreach (var registration in Registrations) foreach (var registration in Registrations)
{ {
hashCode = (hashCode * 397) ^ registration.GetHashCode(); hashCode = (hashCode * 397) ^ registration.GetHashCode();