using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using GFramework.SourceGenerators.Common.diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace GFramework.SourceGenerators.rule;

/// <summary>
/// Incremental source generator that makes every <c>partial</c> class annotated with
/// <c>GFramework.SourceGenerators.Attributes.rule.ContextAwareAttribute</c> implement
/// <c>GFramework.Core.rule.IContextAware</c>.
/// </summary>
/// <remarks>
/// For each annotated class it emits a <c>protected Context</c> property and an explicit
/// <c>IContextAware.SetContext</c> implementation that stores the supplied context in it.
/// If the class, or any type it is nested in, is not declared <c>partial</c>,
/// <see cref="CommonDiagnostics.ClassMustBePartial"/> is reported and no code is generated.
/// </remarks>
[Generator]
public sealed class ContextAwareGenerator : IIncrementalGenerator
{
    /// <summary>Fully-qualified metadata name of the marker attribute.</summary>
    private const string AttributeMetadataName =
        "GFramework.SourceGenerators.Attributes.rule.ContextAwareAttribute";

    // Emitted with global:: so a consumer namespace/type named "GFramework" cannot capture the lookup.
    private const string ContextAwareInterface = "global::GFramework.Core.rule.IContextAware";
    private const string ArchitectureContextType = "global::GFramework.Core.architecture.IArchitectureContext";

    /// <inheritdoc />
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // 1. Find class declarations that directly carry the marker attribute.
        //    The attribute-list check runs in the cheap syntactic predicate so the
        //    semantic transform only sees classes that have at least one attribute.
        // NOTE(review): flowing INamedTypeSymbol through the pipeline defeats incremental
        // caching (symbols are not equatable across compilations); consider
        // ForAttributeWithMetadataName (Roslyn 4.3+) plus an equatable model if the
        // referenced Roslyn version allows it.
        var classDeclarations = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is ClassDeclarationSyntax { AttributeLists.Count: > 0 },
                transform: static (ctx, ct) => GetCandidate(ctx, ct))
            .Where(static s => s is not null);

        // 2. Generate one source file per annotated class.
        context.RegisterSourceOutput(
            classDeclarations,
            static (spc, source) => Generate(spc, source!));
    }

    /// <summary>
    /// Returns the symbol of the class declared by <paramref name="context"/>'s node when the
    /// marker attribute is applied on that particular declaration; otherwise <c>null</c>.
    /// </summary>
    /// <param name="context">Syntax context whose node is a <see cref="ClassDeclarationSyntax"/>.</param>
    /// <param name="cancellationToken">Cancellation token flowed from the pipeline.</param>
    private static INamedTypeSymbol? GetCandidate(
        GeneratorSyntaxContext context,
        CancellationToken cancellationToken = default)
    {
        var classDecl = (ClassDeclarationSyntax)context.Node;
        if (classDecl.AttributeLists.Count == 0)
            return null;

        if (context.SemanticModel.GetDeclaredSymbol(classDecl, cancellationToken) is not { } symbol)
            return null;

        foreach (var attribute in symbol.GetAttributes())
        {
            if (attribute.AttributeClass?.ToDisplayString() != AttributeMetadataName)
                continue;

            // GetAttributes() merges the attributes of every partial declaration of the type.
            // Accept the attribute only when it is written on *this* declaration; otherwise a
            // class split across several attributed parts would be yielded once per part and
            // AddSource would throw on the duplicate hint name.
            var reference = attribute.ApplicationSyntaxReference;
            if (reference is not null
                && reference.SyntaxTree == classDecl.SyntaxTree
                && classDecl.AttributeLists.Any(list => list.Span.Contains(reference.Span)))
            {
                return symbol;
            }
        }

        return null;
    }

    /// <summary>
    /// Validates that <paramref name="symbol"/> and all of its containing types are partial,
    /// then adds the generated <c>IContextAware</c> implementation to the compilation.
    /// </summary>
    private static void Generate(SourceProductionContext context, INamedTypeSymbol symbol)
    {
        // Every declaration in the nesting chain must be partial for the generated part to
        // merge with the user's code; report each offender and emit nothing in that case.
        var anyReported = false;
        foreach (var type in GetTypeChain(symbol))
            anyReported |= ReportIfNotPartial(context, type);

        if (anyReported)
            return;

        var ns = symbol.ContainingNamespace.IsGlobalNamespace
            ? null
            : symbol.ContainingNamespace.ToDisplayString();

        var source = GenerateSource(ns, symbol);

        context.AddSource(GetHintName(ns, symbol), source);
    }

    /// <summary>
    /// Reports <see cref="CommonDiagnostics.ClassMustBePartial"/> for <paramref name="type"/>
    /// when its first declaration lacks the <c>partial</c> modifier.
    /// </summary>
    /// <returns><c>true</c> if a diagnostic was reported.</returns>
    private static bool ReportIfNotPartial(SourceProductionContext context, INamedTypeSymbol type)
    {
        var syntax = type.DeclaringSyntaxReferences
            .FirstOrDefault()?
            .GetSyntax(context.CancellationToken) as TypeDeclarationSyntax;

        if (syntax is not null && syntax.Modifiers.Any(SyntaxKind.PartialKeyword))
            return false;

        context.ReportDiagnostic(
            Diagnostic.Create(
                CommonDiagnostics.ClassMustBePartial,
                syntax?.Identifier.GetLocation() ?? type.Locations.FirstOrDefault(),
                type.Name));

        return true;
    }

    /// <summary>
    /// Builds the source text of the generated partial declaration, re-opening every
    /// containing type so nested classes are completed in place.
    /// </summary>
    /// <param name="ns">Containing namespace, or <c>null</c> for the global namespace.</param>
    /// <param name="symbol">The annotated class.</param>
    private static string GenerateSource(string? ns, INamedTypeSymbol symbol)
    {
        var sb = new StringBuilder();
        // Marks the file as generated so analyzers/style rules skip it.
        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");
        sb.AppendLine();

        if (ns is not null)
        {
            sb.AppendLine($"namespace {ns};");
            sb.AppendLine();
        }

        var chain = GetTypeChain(symbol);
        var containerCount = chain.Count - 1;

        // Re-open each containing type (outermost first).
        for (var depth = 0; depth < containerCount; depth++)
        {
            var container = chain[depth];
            sb.AppendLine($"{Indent(depth)}partial {GetTypeKeyword(container)} {GetDeclarationName(container)}");
            sb.AppendLine($"{Indent(depth)}{{");
        }

        var pad = Indent(containerCount);
        sb.AppendLine($"{pad}partial class {GetDeclarationName(symbol)} : {ContextAwareInterface}");
        sb.AppendLine($"{pad}{{");
        sb.AppendLine($"{pad}    protected {ArchitectureContextType} Context {{ get; private set; }} = null!;");
        sb.AppendLine();
        sb.AppendLine($"{pad}    void {ContextAwareInterface}.SetContext(");
        sb.AppendLine($"{pad}        {ArchitectureContextType} context)");
        sb.AppendLine($"{pad}    {{");
        sb.AppendLine($"{pad}        Context = context;");
        sb.AppendLine($"{pad}    }}");
        sb.AppendLine($"{pad}}}");

        // Close the containing types (innermost first).
        for (var depth = containerCount - 1; depth >= 0; depth--)
            sb.AppendLine($"{Indent(depth)}}}");

        return sb.ToString();
    }

    /// <summary>
    /// Returns a hint name that is unique per type: namespace plus the metadata names of the
    /// type and its containing types (so generic arity is included), with characters that are
    /// not letters, digits, '.' or '_' replaced by '_'.
    /// </summary>
    private static string GetHintName(string? ns, INamedTypeSymbol symbol)
    {
        var typePath = string.Join(".", GetTypeChain(symbol).Select(t => t.MetadataName));
        var qualified = ns is null ? typePath : $"{ns}.{typePath}";

        var sanitized = new StringBuilder(qualified.Length);
        foreach (var c in qualified)
            sanitized.Append(char.IsLetterOrDigit(c) || c is '.' or '_' ? c : '_');

        return $"{sanitized}.ContextAware.g.cs";
    }

    /// <summary>Returns <paramref name="symbol"/> and its containing types, outermost first.</summary>
    private static List<INamedTypeSymbol> GetTypeChain(INamedTypeSymbol symbol)
    {
        var chain = new List<INamedTypeSymbol>();
        for (var type = symbol; type is not null; type = type.ContainingType)
            chain.Add(type);

        chain.Reverse();
        return chain;
    }

    /// <summary>Declaration keyword used to re-open a containing type as a partial.</summary>
    private static string GetTypeKeyword(INamedTypeSymbol type) => type.TypeKind switch
    {
        TypeKind.Struct => type.IsRecord ? "record struct" : "struct",
        TypeKind.Interface => "interface",
        _ => type.IsRecord ? "record" : "class",
    };

    /// <summary>
    /// Type name as written in a declaration, including the type-parameter list (with variance,
    /// which partial declarations must repeat) and '@'-escaping of reserved keywords.
    /// </summary>
    private static string GetDeclarationName(INamedTypeSymbol type)
    {
        var name = EscapeIdentifier(type.Name);
        if (type.TypeParameters.IsEmpty)
            return name;

        var parameters = type.TypeParameters.Select(p =>
            (p.Variance switch
            {
                VarianceKind.In => "in ",
                VarianceKind.Out => "out ",
                _ => string.Empty,
            }) + EscapeIdentifier(p.Name));

        return $"{name}<{string.Join(", ", parameters)}>";
    }

    /// <summary>Prefixes reserved C# keywords with '@' so they are valid identifiers.</summary>
    private static string EscapeIdentifier(string name) =>
        SyntaxFacts.GetKeywordKind(name) != SyntaxKind.None ? "@" + name : name;

    /// <summary>Four spaces per nesting level.</summary>
    private static string Indent(int depth) => new(' ', depth * 4);
}