From 7a6f966601d235abbd20fb2b8470409b6636ddec Mon Sep 17 00:00:00 2001 From: GeWuYou <95328647+GeWuYou@users.noreply.github.com> Date: Wed, 15 Apr 2026 11:12:36 +0800 Subject: [PATCH] =?UTF-8?q?feat(cqrs):=20=E6=B7=BB=E5=8A=A0=20CQRS=20?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=E6=B3=A8=E5=86=8C=E7=94=9F=E6=88=90?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 实现 CqrsHandlerRegistryGenerator 源代码生成器 - 支持 IRequestHandler、INotificationHandler 和 IStreamRequestHandler 接口的处理器注册 - 生成程序集级别的 CQRS 处理器注册器以减少运行时反射开销 - 添加对请求、通知和流处理器的稳定顺序注册支持 - 实现对私有嵌套处理器的检测和回退机制 - 提供字符串字面量转义功能以避免生成代码中的语法错误 - 添加完整的单元测试验证生成器的功能和边界条件 --- .../Cqrs/CqrsHandlerRegistryGeneratorTests.cs | 20 ++ .../Cqrs/CqrsHandlerRegistryGenerator.cs | 279 ++++++++++++------ 2 files changed, 203 insertions(+), 96 deletions(-) diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index b04d8165..e2b7ffb1 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -1,3 +1,4 @@ +using System.Reflection; using GFramework.SourceGenerators.Cqrs; using GFramework.SourceGenerators.Tests.Core; @@ -192,4 +193,23 @@ public class CqrsHandlerRegistryGeneratorTests await test.RunAsync(); } + + /// + /// 验证日志字符串转义会覆盖换行、反斜杠和双引号,避免生成代码中的字符串字面量被意外截断。 + /// + [Test] + public void Escape_String_Literal_Handles_Control_Characters() + { + var method = typeof(CqrsHandlerRegistryGenerator).GetMethod( + "EscapeStringLiteral", + BindingFlags.NonPublic | BindingFlags.Static); + + Assert.That(method, Is.Not.Null); + + const string input = "line1\r\nline2\\\""; + const string expected = "line1\\r\\nline2\\\\\\\""; + var escaped = method!.Invoke(null, [input]) as string; + + Assert.That(escaped, Is.EqualTo(expected)); + } } diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index 59db7c77..39c3aadb 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -24,38 +24,93 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator /// public void Initialize(IncrementalGeneratorInitializationContext context) { + var generationEnabled = context.CompilationProvider + .Select(static (compilation, _) => HasRequiredTypes(compilation)); + + // Restrict semantic analysis to type declarations that can actually contribute implemented interfaces. + var handlerCandidates = context.SyntaxProvider.CreateSyntaxProvider( + static (node, _) => IsHandlerCandidate(node), + static (syntaxContext, _) => TransformHandlerCandidate(syntaxContext)) + .Where(static candidate => candidate is not null) + .Collect(); + context.RegisterSourceOutput( - context.CompilationProvider, - static (productionContext, compilation) => Execute(productionContext, compilation)); + generationEnabled.Combine(handlerCandidates), + static (productionContext, pair) => Execute(productionContext, pair.Left, pair.Right)); } - private static void Execute(SourceProductionContext context, Compilation compilation) + private static bool HasRequiredTypes(Compilation compilation) { - var requestHandlerType = compilation.GetTypeByMetadataName(IRequestHandlerMetadataName); - var notificationHandlerType = compilation.GetTypeByMetadataName(INotificationHandlerMetadataName); - var streamHandlerType = compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName); - var registryInterfaceType = compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName); - var registryAttributeType = compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName); - var loggerType = compilation.GetTypeByMetadataName(ILoggerMetadataName); - var serviceCollectionType = compilation.GetTypeByMetadataName(IServiceCollectionMetadataName); + return compilation.GetTypeByMetadataName(IRequestHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(INotificationHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName) is not null && + compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName) is not null && + compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName) is not null && + compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null && + compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null; + } - if (requestHandlerType is null || - notificationHandlerType is null || - streamHandlerType is null || - registryInterfaceType is null || - registryAttributeType is null || - loggerType is null || - serviceCollectionType is null) + private static bool IsHandlerCandidate(SyntaxNode node) + { + return node is TypeDeclarationSyntax { - return; + BaseList.Types.Count: > 0 + }; + } + + private static HandlerCandidateAnalysis? TransformHandlerCandidate(GeneratorSyntaxContext context) + { + if (context.Node is not TypeDeclarationSyntax typeDeclaration) + return null; + + if (context.SemanticModel.GetDeclaredSymbol(typeDeclaration) is not INamedTypeSymbol type) + return null; + + if (!IsConcreteHandlerType(type)) + return null; + + var handlerInterfaces = type.AllInterfaces + .Where(IsSupportedHandlerInterface) + .OrderBy(GetTypeSortKey, StringComparer.Ordinal) + .ToImmutableArray(); + + if (handlerInterfaces.IsDefaultOrEmpty) + return null; + + var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (!CanReferenceFromGeneratedRegistry(type) || + handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) + { + return new HandlerCandidateAnalysis( + implementationTypeDisplayName, + ImmutableArray.Empty, + true); } - var registrations = CollectRegistrations( - compilation.Assembly.GlobalNamespace, - requestHandlerType, - notificationHandlerType, - streamHandlerType, - out var hasUnsupportedConcreteHandler); + var implementationLogName = GetLogDisplayName(type); + var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + foreach (var handlerInterface in handlerInterfaces) + { + registrations.Add(new HandlerRegistrationSpec( + handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + implementationTypeDisplayName, + GetLogDisplayName(handlerInterface), + implementationLogName)); + } + + return new HandlerCandidateAnalysis( + implementationTypeDisplayName, + registrations.MoveToImmutable(), + false); + } + + private static void Execute(SourceProductionContext context, bool generationEnabled, + ImmutableArray candidates) + { + if (!generationEnabled) + return; + + var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler); // If the assembly contains handlers that generated code cannot legally reference // (for example private nested handlers), keep the runtime on the reflection path @@ -67,50 +122,33 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } private static List CollectRegistrations( - INamespaceSymbol rootNamespace, - INamedTypeSymbol requestHandlerType, - INamedTypeSymbol notificationHandlerType, - INamedTypeSymbol streamHandlerType, + ImmutableArray candidates, out bool hasUnsupportedConcreteHandler) { var registrations = new List(); hasUnsupportedConcreteHandler = false; - foreach (var type in EnumerateTypes(rootNamespace)) + // Partial declarations surface the same symbol through multiple syntax nodes. + // Collapse them by implementation type so generated registrations stay stable and duplicate-free. + var uniqueCandidates = new Dictionary(StringComparer.Ordinal); + + foreach (var candidate in candidates) { - if (!IsConcreteHandlerType(type)) + if (candidate is null) continue; - var handlerInterfaces = type.AllInterfaces - .Where(interfaceType => IsSupportedHandlerInterface( - interfaceType, - requestHandlerType, - notificationHandlerType, - streamHandlerType)) - .OrderBy(GetTypeSortKey, StringComparer.Ordinal) - .ToList(); - - if (handlerInterfaces.Count == 0) - continue; - - if (!CanReferenceFromGeneratedRegistry(type) || - handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType))) + if (candidate.Value.HasUnsupportedConcreteHandler) { hasUnsupportedConcreteHandler = true; return []; } - var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var implementationLogName = GetLogDisplayName(type); + uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value; + } - foreach (var handlerInterface in handlerInterfaces) - { - registrations.Add(new HandlerRegistrationSpec( - handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), - implementationTypeDisplayName, - GetLogDisplayName(handlerInterface), - implementationLogName)); - } + foreach (var candidate in uniqueCandidates.Values) + { + registrations.AddRange(candidate.Registrations); } registrations.Sort(static (left, right) => @@ -127,38 +165,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return registrations; } - private static IEnumerable EnumerateTypes(INamespaceSymbol namespaceSymbol) - { - foreach (var member in namespaceSymbol.GetMembers()) - { - switch (member) - { - case INamespaceSymbol childNamespace: - foreach (var type in EnumerateTypes(childNamespace)) - yield return type; - - break; - - case INamedTypeSymbol namedType: - foreach (var type in EnumerateTypes(namedType)) - yield return type; - - break; - } - } - } - - private static IEnumerable EnumerateTypes(INamedTypeSymbol typeSymbol) - { - yield return typeSymbol; - - foreach (var nestedType in typeSymbol.GetTypeMembers()) - { - foreach (var descendant in EnumerateTypes(nestedType)) - yield return descendant; - } - } - private static bool IsConcreteHandlerType(INamedTypeSymbol type) { return type.TypeKind is TypeKind.Class or TypeKind.Struct && @@ -177,19 +183,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } - private static bool IsSupportedHandlerInterface( - INamedTypeSymbol interfaceType, - INamedTypeSymbol requestHandlerType, - INamedTypeSymbol notificationHandlerType, - INamedTypeSymbol streamHandlerType) + private static bool IsSupportedHandlerInterface(INamedTypeSymbol interfaceType) { if (!interfaceType.IsGenericType) return false; - var definition = interfaceType.OriginalDefinition; - return SymbolEqualityComparer.Default.Equals(definition, requestHandlerType) || - SymbolEqualityComparer.Default.Equals(definition, notificationHandlerType) || - SymbolEqualityComparer.Default.Equals(definition, streamHandlerType); + var definitionMetadataName = GetFullyQualifiedMetadataName(interfaceType.OriginalDefinition); + return string.Equals(definitionMetadataName, IRequestHandlerMetadataName, StringComparison.Ordinal) || + string.Equals(definitionMetadataName, INotificationHandlerMetadataName, StringComparison.Ordinal) || + string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal); } private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) @@ -229,6 +231,31 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator or Accessibility.ProtectedOrInternal; } + private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type) + { + var nestedTypes = new Stack(); + for (var current = type; current is not null; current = current.ContainingType) + { + nestedTypes.Push(current.MetadataName); + } + + var builder = new StringBuilder(); + if (!type.ContainingNamespace.IsGlobalNamespace) + { + builder.Append(type.ContainingNamespace.ToDisplayString()); + builder.Append('.'); + } + + while (nestedTypes.Count > 0) + { + builder.Append(nestedTypes.Pop()); + if (nestedTypes.Count > 0) + builder.Append('.'); + } + + return builder.ToString(); + } + private static string GetTypeSortKey(ITypeSymbol type) { return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); @@ -300,7 +327,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private static string EscapeStringLiteral(string value) { return value.Replace("\\", "\\\\") - .Replace("\"", "\\\""); + .Replace("\"", "\\\"") + .Replace("\n", "\\n") + .Replace("\r", "\\r"); } private readonly record struct HandlerRegistrationSpec( @@ -308,4 +337,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string ImplementationTypeDisplayName, string HandlerInterfaceLogName, string ImplementationLogName); + + private readonly struct HandlerCandidateAnalysis : IEquatable + { + public HandlerCandidateAnalysis( + string implementationTypeDisplayName, + ImmutableArray registrations, + bool hasUnsupportedConcreteHandler) + { + ImplementationTypeDisplayName = implementationTypeDisplayName; + Registrations = registrations; + HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler; + } + + public string ImplementationTypeDisplayName { get; } + + public ImmutableArray Registrations { get; } + + public bool HasUnsupportedConcreteHandler { get; } + + public bool Equals(HandlerCandidateAnalysis other) + { + if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, + StringComparison.Ordinal) || + HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler || + Registrations.Length != other.Registrations.Length) + { + return false; + } + + for (var index = 0; index < Registrations.Length; index++) + { + if (!Registrations[index].Equals(other.Registrations[index])) + return false; + } + + return true; + } + + public override bool Equals(object? obj) + { + return obj is HandlerCandidateAnalysis other && Equals(other); + } + + public override int GetHashCode() + { + unchecked + { + var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName); + hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode(); + foreach (var registration in Registrations) + { + hashCode = (hashCode * 397) ^ registration.GetHashCode(); + } + + return hashCode; + } + } + } }