diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
index b04d8165..e2b7ffb1 100644
--- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
+++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs
@@ -1,3 +1,4 @@
+using System.Reflection;
using GFramework.SourceGenerators.Cqrs;
using GFramework.SourceGenerators.Tests.Core;
@@ -192,4 +193,23 @@ public class CqrsHandlerRegistryGeneratorTests
await test.RunAsync();
}
+
+ ///
+ /// 验证日志字符串转义会覆盖换行、反斜杠和双引号,避免生成代码中的字符串字面量被意外截断。
+ ///
+ [Test]
+ public void Escape_String_Literal_Handles_Control_Characters()
+ {
+ var method = typeof(CqrsHandlerRegistryGenerator).GetMethod(
+ "EscapeStringLiteral",
+ BindingFlags.NonPublic | BindingFlags.Static);
+
+ Assert.That(method, Is.Not.Null);
+
+ const string input = "line1\r\nline2\\\"";
+ const string expected = "line1\\r\\nline2\\\\\\\"";
+ var escaped = method!.Invoke(null, [input]) as string;
+
+ Assert.That(escaped, Is.EqualTo(expected));
+ }
}
diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
index 59db7c77..39c3aadb 100644
--- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
+++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs
@@ -24,38 +24,93 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
///
public void Initialize(IncrementalGeneratorInitializationContext context)
{
+ var generationEnabled = context.CompilationProvider
+ .Select(static (compilation, _) => HasRequiredTypes(compilation));
+
+ // Restrict semantic analysis to type declarations that can actually contribute implemented interfaces.
+ var handlerCandidates = context.SyntaxProvider.CreateSyntaxProvider(
+ static (node, _) => IsHandlerCandidate(node),
+ static (syntaxContext, _) => TransformHandlerCandidate(syntaxContext))
+ .Where(static candidate => candidate is not null)
+ .Collect();
+
context.RegisterSourceOutput(
- context.CompilationProvider,
- static (productionContext, compilation) => Execute(productionContext, compilation));
+ generationEnabled.Combine(handlerCandidates),
+ static (productionContext, pair) => Execute(productionContext, pair.Left, pair.Right));
}
- private static void Execute(SourceProductionContext context, Compilation compilation)
+ private static bool HasRequiredTypes(Compilation compilation)
{
- var requestHandlerType = compilation.GetTypeByMetadataName(IRequestHandlerMetadataName);
- var notificationHandlerType = compilation.GetTypeByMetadataName(INotificationHandlerMetadataName);
- var streamHandlerType = compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName);
- var registryInterfaceType = compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName);
- var registryAttributeType = compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName);
- var loggerType = compilation.GetTypeByMetadataName(ILoggerMetadataName);
- var serviceCollectionType = compilation.GetTypeByMetadataName(IServiceCollectionMetadataName);
+ return compilation.GetTypeByMetadataName(IRequestHandlerMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(INotificationHandlerMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(IStreamRequestHandlerMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(ICqrsHandlerRegistryMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(CqrsHandlerRegistryAttributeMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(ILoggerMetadataName) is not null &&
+ compilation.GetTypeByMetadataName(IServiceCollectionMetadataName) is not null;
+ }
- if (requestHandlerType is null ||
- notificationHandlerType is null ||
- streamHandlerType is null ||
- registryInterfaceType is null ||
- registryAttributeType is null ||
- loggerType is null ||
- serviceCollectionType is null)
+ private static bool IsHandlerCandidate(SyntaxNode node)
+ {
+ return node is TypeDeclarationSyntax
{
- return;
+ BaseList.Types.Count: > 0
+ };
+ }
+
+ private static HandlerCandidateAnalysis? TransformHandlerCandidate(GeneratorSyntaxContext context)
+ {
+ if (context.Node is not TypeDeclarationSyntax typeDeclaration)
+ return null;
+
+ if (context.SemanticModel.GetDeclaredSymbol(typeDeclaration) is not INamedTypeSymbol type)
+ return null;
+
+ if (!IsConcreteHandlerType(type))
+ return null;
+
+ var handlerInterfaces = type.AllInterfaces
+ .Where(IsSupportedHandlerInterface)
+ .OrderBy(GetTypeSortKey, StringComparer.Ordinal)
+ .ToImmutableArray();
+
+ if (handlerInterfaces.IsDefaultOrEmpty)
+ return null;
+
+ var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ if (!CanReferenceFromGeneratedRegistry(type) ||
+ handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
+ {
+ return new HandlerCandidateAnalysis(
+ implementationTypeDisplayName,
+ ImmutableArray.Empty,
+ true);
}
- var registrations = CollectRegistrations(
- compilation.Assembly.GlobalNamespace,
- requestHandlerType,
- notificationHandlerType,
- streamHandlerType,
- out var hasUnsupportedConcreteHandler);
+ var implementationLogName = GetLogDisplayName(type);
+ var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length);
+ foreach (var handlerInterface in handlerInterfaces)
+ {
+ registrations.Add(new HandlerRegistrationSpec(
+ handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
+ implementationTypeDisplayName,
+ GetLogDisplayName(handlerInterface),
+ implementationLogName));
+ }
+
+ return new HandlerCandidateAnalysis(
+ implementationTypeDisplayName,
+ registrations.MoveToImmutable(),
+ false);
+ }
+
+ private static void Execute(SourceProductionContext context, bool generationEnabled,
+ ImmutableArray candidates)
+ {
+ if (!generationEnabled)
+ return;
+
+ var registrations = CollectRegistrations(candidates, out var hasUnsupportedConcreteHandler);
// If the assembly contains handlers that generated code cannot legally reference
// (for example private nested handlers), keep the runtime on the reflection path
@@ -67,50 +122,33 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
}
private static List CollectRegistrations(
- INamespaceSymbol rootNamespace,
- INamedTypeSymbol requestHandlerType,
- INamedTypeSymbol notificationHandlerType,
- INamedTypeSymbol streamHandlerType,
+ ImmutableArray candidates,
out bool hasUnsupportedConcreteHandler)
{
var registrations = new List();
hasUnsupportedConcreteHandler = false;
- foreach (var type in EnumerateTypes(rootNamespace))
+ // Partial declarations surface the same symbol through multiple syntax nodes.
+ // Collapse them by implementation type so generated registrations stay stable and duplicate-free.
+ var uniqueCandidates = new Dictionary(StringComparer.Ordinal);
+
+ foreach (var candidate in candidates)
{
- if (!IsConcreteHandlerType(type))
+ if (candidate is null)
continue;
- var handlerInterfaces = type.AllInterfaces
- .Where(interfaceType => IsSupportedHandlerInterface(
- interfaceType,
- requestHandlerType,
- notificationHandlerType,
- streamHandlerType))
- .OrderBy(GetTypeSortKey, StringComparer.Ordinal)
- .ToList();
-
- if (handlerInterfaces.Count == 0)
- continue;
-
- if (!CanReferenceFromGeneratedRegistry(type) ||
- handlerInterfaces.Any(interfaceType => !CanReferenceFromGeneratedRegistry(interfaceType)))
+ if (candidate.Value.HasUnsupportedConcreteHandler)
{
hasUnsupportedConcreteHandler = true;
return [];
}
- var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
- var implementationLogName = GetLogDisplayName(type);
+ uniqueCandidates[candidate.Value.ImplementationTypeDisplayName] = candidate.Value;
+ }
- foreach (var handlerInterface in handlerInterfaces)
- {
- registrations.Add(new HandlerRegistrationSpec(
- handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
- implementationTypeDisplayName,
- GetLogDisplayName(handlerInterface),
- implementationLogName));
- }
+ foreach (var candidate in uniqueCandidates.Values)
+ {
+ registrations.AddRange(candidate.Registrations);
}
registrations.Sort(static (left, right) =>
@@ -127,38 +165,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return registrations;
}
- private static IEnumerable EnumerateTypes(INamespaceSymbol namespaceSymbol)
- {
- foreach (var member in namespaceSymbol.GetMembers())
- {
- switch (member)
- {
- case INamespaceSymbol childNamespace:
- foreach (var type in EnumerateTypes(childNamespace))
- yield return type;
-
- break;
-
- case INamedTypeSymbol namedType:
- foreach (var type in EnumerateTypes(namedType))
- yield return type;
-
- break;
- }
- }
- }
-
- private static IEnumerable EnumerateTypes(INamedTypeSymbol typeSymbol)
- {
- yield return typeSymbol;
-
- foreach (var nestedType in typeSymbol.GetTypeMembers())
- {
- foreach (var descendant in EnumerateTypes(nestedType))
- yield return descendant;
- }
- }
-
private static bool IsConcreteHandlerType(INamedTypeSymbol type)
{
return type.TypeKind is TypeKind.Class or TypeKind.Struct &&
@@ -177,19 +183,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false;
}
- private static bool IsSupportedHandlerInterface(
- INamedTypeSymbol interfaceType,
- INamedTypeSymbol requestHandlerType,
- INamedTypeSymbol notificationHandlerType,
- INamedTypeSymbol streamHandlerType)
+ private static bool IsSupportedHandlerInterface(INamedTypeSymbol interfaceType)
{
if (!interfaceType.IsGenericType)
return false;
- var definition = interfaceType.OriginalDefinition;
- return SymbolEqualityComparer.Default.Equals(definition, requestHandlerType) ||
- SymbolEqualityComparer.Default.Equals(definition, notificationHandlerType) ||
- SymbolEqualityComparer.Default.Equals(definition, streamHandlerType);
+ var definitionMetadataName = GetFullyQualifiedMetadataName(interfaceType.OriginalDefinition);
+ return string.Equals(definitionMetadataName, IRequestHandlerMetadataName, StringComparison.Ordinal) ||
+ string.Equals(definitionMetadataName, INotificationHandlerMetadataName, StringComparison.Ordinal) ||
+ string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal);
}
private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type)
@@ -229,6 +231,31 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
or Accessibility.ProtectedOrInternal;
}
+ private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type)
+ {
+ var nestedTypes = new Stack();
+ for (var current = type; current is not null; current = current.ContainingType)
+ {
+ nestedTypes.Push(current.MetadataName);
+ }
+
+ var builder = new StringBuilder();
+ if (!type.ContainingNamespace.IsGlobalNamespace)
+ {
+ builder.Append(type.ContainingNamespace.ToDisplayString());
+ builder.Append('.');
+ }
+
+ while (nestedTypes.Count > 0)
+ {
+ builder.Append(nestedTypes.Pop());
+ if (nestedTypes.Count > 0)
+ builder.Append('.');
+ }
+
+ return builder.ToString();
+ }
+
private static string GetTypeSortKey(ITypeSymbol type)
{
return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
@@ -300,7 +327,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
private static string EscapeStringLiteral(string value)
{
return value.Replace("\\", "\\\\")
- .Replace("\"", "\\\"");
+ .Replace("\"", "\\\"")
+ .Replace("\n", "\\n")
+ .Replace("\r", "\\r");
}
private readonly record struct HandlerRegistrationSpec(
@@ -308,4 +337,62 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string ImplementationTypeDisplayName,
string HandlerInterfaceLogName,
string ImplementationLogName);
+
+ private readonly struct HandlerCandidateAnalysis : IEquatable
+ {
+ public HandlerCandidateAnalysis(
+ string implementationTypeDisplayName,
+ ImmutableArray registrations,
+ bool hasUnsupportedConcreteHandler)
+ {
+ ImplementationTypeDisplayName = implementationTypeDisplayName;
+ Registrations = registrations;
+ HasUnsupportedConcreteHandler = hasUnsupportedConcreteHandler;
+ }
+
+ public string ImplementationTypeDisplayName { get; }
+
+ public ImmutableArray Registrations { get; }
+
+ public bool HasUnsupportedConcreteHandler { get; }
+
+ public bool Equals(HandlerCandidateAnalysis other)
+ {
+ if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName,
+ StringComparison.Ordinal) ||
+ HasUnsupportedConcreteHandler != other.HasUnsupportedConcreteHandler ||
+ Registrations.Length != other.Registrations.Length)
+ {
+ return false;
+ }
+
+ for (var index = 0; index < Registrations.Length; index++)
+ {
+ if (!Registrations[index].Equals(other.Registrations[index]))
+ return false;
+ }
+
+ return true;
+ }
+
+ public override bool Equals(object? obj)
+ {
+ return obj is HandlerCandidateAnalysis other && Equals(other);
+ }
+
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ var hashCode = StringComparer.Ordinal.GetHashCode(ImplementationTypeDisplayName);
+ hashCode = (hashCode * 397) ^ HasUnsupportedConcreteHandler.GetHashCode();
+ foreach (var registration in Registrations)
+ {
+ hashCode = (hashCode * 397) ^ registration.GetHashCode();
+ }
+
+ return hashCode;
+ }
+ }
+ }
}