feat(cqrs): 添加 CQRS 处理器注册生成器

- 实现 CqrsHandlerRegistryGenerator 源代码生成器
- 支持 IRequestHandler、INotificationHandler 和 IStreamRequestHandler 接口的处理器注册
- 生成程序集级别的 CQRS 处理器注册器以减少运行时反射开销
- 添加对请求、通知和流处理器的稳定顺序注册支持
- 实现对私有嵌套处理器的检测和回退机制
- 提供字符串字面量转义功能以避免生成代码中的语法错误
- 添加完整的单元测试验证生成器的功能和边界条件
This commit is contained in:
GeWuYou 2026-04-15 11:12:36 +08:00
parent fd64423741
commit 7a6f966601
2 changed files with 203 additions and 96 deletions

View File

@ -1,3 +1,4 @@
using System.Reflection;
using GFramework.SourceGenerators.Cqrs;
using GFramework.SourceGenerators.Tests.Core;
@ -192,4 +193,23 @@ public class CqrsHandlerRegistryGeneratorTests
await test.RunAsync();
}
/// <summary>
/// 验证日志字符串转义会覆盖换行、反斜杠和双引号,避免生成代码中的字符串字面量被意外截断。
/// </summary>
[Test]
public void Escape_String_Literal_Handles_Control_Characters()
{
var method = typeof(CqrsHandlerRegistryGenerator).GetMethod(
"EscapeStringLiteral",
BindingFlags.NonPublic | BindingFlags.Static);
Assert.That(method, Is.Not.Null);
const string input = "line1\r\nline2\\\"";
const string expected = "line1\\r\\nline2\\\\\\\"";
var escaped = method!.Invoke(null, [input]) as string;
Assert.That(escaped, Is.EqualTo(expected));
}
}

View File

@ -24,38 +24,93 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
/// <inheritdoc />
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var generationEnabled = context.CompilationProvider
.Select(static (compilation, _) => HasRequiredTypes(compilation));
// Restrict semantic analysis to type declarations that can actually contribute implemented interfaces.
var handlerCandidates = context.SyntaxProvider.CreateSyntaxProvider(
static (node, _) => IsHandlerCandidate(node),
static (syntaxContext, _) => TransformHandlerCandidate(syntaxContext))
.Where(static candidate => candidate is not null)
.Collect();
context.RegisterSourceOutput(
context.CompilationProvider,
static (productionContext, compilation) => Execute(productionContext, compilation));
generationEnabled.Combine(handlerCandidates),
static (productionContext, pair) => Execute(productionContext, pair.Left, pair.Right));
}
private static void Execute(SourceProductionContext context, Compilation compilation)
private static bool HasRequiredTypes(Compilation compilation)
{
var requestHandlerType = compilation.GetTypeByMetadataName(IRequestHandlerMetadataName);
var notificationHandlerType = compilation.GetTypeByMetadataName(INotificationHandlerMetadataName);
var streamHandlerType = compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName);
var registryInterfaceType = compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName);
var registryAttributeType = compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName);
var loggerType = compilation.GetTypeByMetadataName(ILoggerMetadataName);
var serviceCollectionType = compilation.GetTypeByMetadataName(IServiceCollectionMetadataName);
return compilation.GetTypeByMetadataName(IRequestHandlerMetadataName) is not null &&
compilation.GetTypeByMetadataName(INotificationHandlerMetadataName) is not null &&
compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName) is not null &&
compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName) is not null &&
compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName) is not null &&
compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null &&
compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null;
}
if (requestHandlerType is null ||
notificationHandlerType is null ||
streamHandlerType is null ||
registryInterfaceType is null ||
registryAttributeType is null ||
loggerType is null ||
serviceCollectionType is null)
private static bool IsHandlerCandidate(SyntaxNode node)
{
return node is TypeDeclarationSyntax
{
BaseList.Types.Count: > 0
};
}
private static HandlerCandidateAnalysis? TransformHandlerCandidate(GeneratorSyntaxContext context)
{
if (context.Node is not TypeDeclarationSyntax typeDeclaration)
return null;
if (context.SemanticModel.GetDeclaredSymbol(typeDeclaration) is not INamedTypeSymbol type)
return null;
if (!IsConcreteHandlerType(type))
return null;
var handlerInterfaces = type.AllInterfaces
.Where(IsSupportedHandlerInterface)
.OrderBy(GetTypeSortKey, StringComparer.Ordinal)
.ToImmutableArray();
if (handlerInterfaces.IsDefaultOrEmpty)
return null;
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
if (!CanReferenceFromGeneratedRegistry(type) ||
handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
{
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
ImmutableArray<HandlerRegistrationSpec>.Empty,
true);
}
var implementationLogName = GetLogDisplayName(type);
var registrations = ImmutableArray.CreateBuilder<HandlerRegistrationSpec>(handlerInterfaces.Length);
foreach (var handlerInterface in handlerInterfaces)
{
registrations.Add(new HandlerRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
implementationTypeDisplayName,
GetLogDisplayName(handlerInterface),
implementationLogName));
}
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
registrations.MoveToImmutable(),
false);
}
private static void Execute(SourceProductionContext context, bool generationEnabled,
ImmutableArray<HandlerCandidateAnalysis?> candidates)
{
if (!generationEnabled)
return;
}
var registrations = CollectRegistrations(
compilation.Assembly.GlobalNamespace,
requestHandlerType,
notificationHandlerType,
streamHandlerType,
out var hasUnsupportedConcreteHandler);
var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler);
// If the assembly contains handlers that generated code cannot legally reference
// (for example private nested handlers), keep the runtime on the reflection path
@ -67,50 +122,33 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
private static List<HandlerRegistrationSpec> CollectRegistrations(
INamespaceSymbol rootNamespace,
INamedTypeSymbol requestHandlerType,
INamedTypeSymbol notificationHandlerType,
INamedTypeSymbol streamHandlerType,
ImmutableArray<HandlerCandidateAnalysis?> candidates,
out bool hasUnsupportedConcreteHandler)
{
var registrations = new List<HandlerRegistrationSpec>();
hasUnsupportedConcreteHandler = false;
foreach (var type in EnumerateTypes(rootNamespace))
// Partial declarations surface the same symbol through multiple syntax nodes.
// Collapse them by implementation type so generated registrations stay stable and duplicate-free.
var uniqueCandidates = new Dictionary<string, HandlerCandidateAnalysis>(StringComparer.Ordinal);
foreach (var candidate in candidates)
{
if (!IsConcreteHandlerType(type))
if (candidate is null)
continue;
var handlerInterfaces = type.AllInterfaces
.Where(interfaceType => IsSupportedHandlerInterface(
interfaceType,
requestHandlerType,
notificationHandlerType,
streamHandlerType))
.OrderBy(GetTypeSortKey, StringComparer.Ordinal)
.ToList();
if (handlerInterfaces.Count == 0)
continue;
if (!CanReferenceFromGeneratedRegistry(type) ||
handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
if (candidate.Value.HasUnsupportedConcreteHandler)
{
hasUnsupportedConcreteHandler = true;
return [];
}
var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var implementationLogName = GetLogDisplayName(type);
foreach (var handlerInterface in handlerInterfaces)
{
registrations.Add(new HandlerRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
implementationTypeDisplayName,
GetLogDisplayName(handlerInterface),
implementationLogName));
uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value;
}
foreach (var candidate in uniqueCandidates.Values)
{
registrations.AddRange(candidate.Registrations);
}
registrations.Sort(static (left, right) =>
@ -127,38 +165,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return registrations;
}
private static IEnumerable<INamedTypeSymbol> EnumerateTypes(INamespaceSymbol namespaceSymbol)
{
foreach (var member in namespaceSymbol.GetMembers())
{
switch (member)
{
case INamespaceSymbol childNamespace:
foreach (var type in EnumerateTypes(childNamespace))
yield return type;
break;
case INamedTypeSymbol namedType:
foreach (var type in EnumerateTypes(namedType))
yield return type;
break;
}
}
}
private static IEnumerable<INamedTypeSymbol> EnumerateTypes(INamedTypeSymbol typeSymbol)
{
yield return typeSymbol;
foreach (var nestedType in typeSymbol.GetTypeMembers())
{
foreach (var descendant in EnumerateTypes(nestedType))
yield return descendant;
}
}
private static bool IsConcreteHandlerType(INamedTypeSymbol type)
{
return type.TypeKind is TypeKind.Class or TypeKind.Struct &&
@ -177,19 +183,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false;
}
private static bool IsSupportedHandlerInterface(
INamedTypeSymbol interfaceType,
INamedTypeSymbol requestHandlerType,
INamedTypeSymbol notificationHandlerType,
INamedTypeSymbol streamHandlerType)
private static bool IsSupportedHandlerInterface(INamedTypeSymbol interfaceType)
{
if (!interfaceType.IsGenericType)
return false;
var definition = interfaceType.OriginalDefinition;
return SymbolEqualityComparer.Default.Equals(definition, requestHandlerType) ||
SymbolEqualityComparer.Default.Equals(definition, notificationHandlerType) ||
SymbolEqualityComparer.Default.Equals(definition, streamHandlerType);
var definitionMetadataName = GetFullyQualifiedMetadataName(interfaceType.OriginalDefinition);
return string.Equals(definitionMetadataName, IRequestHandlerMetadataName, StringComparison.Ordinal) ||
string.Equals(definitionMetadataName, INotificationHandlerMetadataName, StringComparison.Ordinal) ||
string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal);
}
private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type)
@ -229,6 +231,31 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
or Accessibility.ProtectedOrInternal;
}
private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type)
{
var nestedTypes = new Stack<string>();
for (var current = type; current is not null; current = current.ContainingType)
{
nestedTypes.Push(current.MetadataName);
}
var builder = new StringBuilder();
if (!type.ContainingNamespace.IsGlobalNamespace)
{
builder.Append(type.ContainingNamespace.ToDisplayString());
builder.Append('.');
}
while (nestedTypes.Count > 0)
{
builder.Append(nestedTypes.Pop());
if (nestedTypes.Count > 0)
builder.Append('.');
}
return builder.ToString();
}
private static string GetTypeSortKey(ITypeSymbol type)
{
return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
@ -300,7 +327,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static string EscapeStringLiteral(string value)
{
return value.Replace("\\", "\\\\")
.Replace("\"", "\\\"");
.Replace("\"", "\\\"")
.Replace("\n", "\\n")
.Replace("\r", "\\r");
}
private readonly record struct HandlerRegistrationSpec(
@ -308,4 +337,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string ImplementationTypeDisplayName,
string HandlerInterfaceLogName,
string ImplementationLogName);
private readonly struct HandlerCandidateAnalysis : IEquatable<HandlerCandidateAnalysis>
{
public HandlerCandidateAnalysis(
string implementationTypeDisplayName,
ImmutableArray<HandlerRegistrationSpec> registrations,
bool hasUnsupportedConcreteHandler)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
Registrations = registrations;
HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler;
}
public string ImplementationTypeDisplayName { get; }
public ImmutableArray<HandlerRegistrationSpec> Registrations { get; }
public bool HasUnsupportedConcreteHandler { get; }
public bool Equals(HandlerCandidateAnalysis other)
{
if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
StringComparison.Ordinal) ||
HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler ||
Registrations.Length != other.Registrations.Length)
{
return false;
}
for (var index = 0; index < Registrations.Length; index++)
{
if (!Registrations[index].Equals(other.Registrations[index]))
return false;
}
return true;
}
public override bool Equals(object? obj)
{
return obj is HandlerCandidateAnalysis other && Equals(other);
}
public override int GetHashCode()
{
unchecked
{
var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName);
hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode();
foreach (var registration in Registrations)
{
hashCode = (hashCode * 397) ^ registration.GetHashCode();
}
return hashCode;
}
}
}
}