using System.Linq;
using System.Text;
using GFramework.SourceGenerators.Common.diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace GFramework.SourceGenerators.rule;

/// <summary>
/// Incremental source generator that, for every class annotated with
/// <c>ContextAwareAttribute</c>, emits a partial class part containing a
/// <c>Context</c> property and an explicit implementation of
/// <c>GFramework.Core.rule.IContextAware.SetContext</c>.
/// </summary>
[Generator]
public sealed class ContextAwareGenerator : IIncrementalGenerator
{
    private const string AttributeMetadataName =
        "GFramework.SourceGenerators.Attributes.rule.ContextAwareAttribute";

    private const string ContextAwareInterfaceName = "GFramework.Core.rule.IContextAware";

    /// <summary>
    /// Wires the generator pipeline: collect annotated class symbols, then validate each
    /// one and emit its generated source.
    /// </summary>
    /// <param name="context">The initialization context supplied by the compiler.</param>
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // 1. Find candidate classes.
        var candidates = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is ClassDeclarationSyntax,
                transform: static (ctx, _) => GetCandidate(ctx)
            )
            .Where(static s => s is not null);

        // 2. Register the output step.
        context.RegisterSourceOutput(candidates, static (spc, symbol) =>
        {
            if (symbol is not null)
            {
                GenerateOutput(spc, symbol);
            }
        });
    }

    #region Candidate discovery

    /// <summary>
    /// Resolves the class declared by the current syntax node and returns its symbol when
    /// the class carries <c>ContextAwareAttribute</c>.
    /// </summary>
    /// <param name="context">The syntax context of a class declaration node.</param>
    /// <returns>
    /// The class symbol, or <c>null</c> when the node does not declare a named type, is not
    /// the first declaration of that type, or the type lacks the attribute.
    /// </returns>
    private static INamedTypeSymbol? GetCandidate(GeneratorSyntaxContext context)
    {
        if (context.SemanticModel.GetDeclaredSymbol(context.Node) is not INamedTypeSymbol symbol)
        {
            return null;
        }

        // A class split over several partial declarations reaches this method once per
        // declaration. Only the first declaring reference is accepted so that the symbol is
        // processed exactly once; otherwise AddSource would be called repeatedly with the
        // same hint name and diagnostics would be duplicated.
        var primary = symbol.DeclaringSyntaxReferences.FirstOrDefault();
        if (primary is null
            || primary.SyntaxTree != context.Node.SyntaxTree
            || primary.Span != context.Node.Span)
        {
            return null;
        }

        // Keep only classes annotated with ContextAwareAttribute.
        var hasAttr = symbol.GetAttributes()
            .Any(attr => attr.AttributeClass?.ToDisplayString() == AttributeMetadataName);

        return hasAttr ? symbol : null;
    }

    #endregion

    #region Output generation and diagnostics

    /// <summary>
    /// Validates the candidate class and, when valid, adds its generated source file.
    /// Reports a diagnostic and emits nothing when the class is not partial or does not
    /// implement <c>GFramework.Core.rule.IContextAware</c>.
    /// </summary>
    /// <param name="context">The production context used to report diagnostics and add sources.</param>
    /// <param name="symbol">The annotated class symbol.</param>
    private static void GenerateOutput(SourceProductionContext context, INamedTypeSymbol symbol)
    {
        if (symbol.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() is not ClassDeclarationSyntax syntax)
        {
            return;
        }

        // 1. The class must be partial.
        if (!syntax.Modifiers.Any(SyntaxKind.PartialKeyword))
        {
            context.ReportDiagnostic(Diagnostic.Create(
                CommonDiagnostics.ClassMustBePartial,
                syntax.Identifier.GetLocation(),
                symbol.Name
            ));
            return;
        }

        // 2. The class must implement IContextAware, directly or indirectly.
        if (!symbol.AllInterfaces.Any(i => i.ToDisplayString() == ContextAwareInterfaceName))
        {
            context.ReportDiagnostic(Diagnostic.Create(
                ContextAwareDiagnostic.ClassMustImplementIContextAware,
                syntax.Identifier.GetLocation(),
                symbol.Name
            ));
            return;
        }

        // 3. Generate the source.
        var ns = symbol.ContainingNamespace.IsGlobalNamespace
            ? null
            : symbol.ContainingNamespace.ToDisplayString();

        var source = GenerateSource(ns, symbol);

        // NOTE(review): the hint name uses only the simple class name, so two annotated
        // classes with the same name in different namespaces would produce the same hint
        // name. Left unchanged to keep existing generated file names stable - confirm.
        context.AddSource($"{symbol.Name}.ContextAware.g.cs", source);
    }

    #endregion

    #region Source generation

    /// <summary>
    /// Builds the text of the generated partial class part.
    /// </summary>
    /// <param name="ns">The containing namespace, or <c>null</c> for the global namespace.</param>
    /// <param name="symbol">The class to generate the partial part for.</param>
    /// <returns>The generated C# source with trailing whitespace trimmed.</returns>
    private static string GenerateSource(string? ns, INamedTypeSymbol symbol)
    {
        var sb = new StringBuilder();

        // Marks the file as generated code.
        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("#nullable enable");

        if (ns is not null)
        {
            sb.AppendLine($"namespace {ns};");
            sb.AppendLine();
        }

        sb.AppendLine($"partial class {GetDeclarationName(symbol)}");
        sb.AppendLine("{");
        sb.AppendLine(
            " protected GFramework.Core.architecture.IArchitectureContext Context { get; private set; } = null!;");
        sb.AppendLine();
        sb.AppendLine(" void GFramework.Core.rule.IContextAware.SetContext(");
        sb.AppendLine(" GFramework.Core.architecture.IArchitectureContext context)");
        sb.AppendLine(" {");
        sb.AppendLine(" Context = context;");
        sb.AppendLine(" }");
        sb.AppendLine("}");

        return sb.ToString().TrimEnd();
    }

    /// <summary>
    /// Returns the class name as it must appear in a partial declaration, including the
    /// type-parameter list for generic classes (for example <c>Foo&lt;T&gt;</c>), so the
    /// generated part merges with the user's declaration instead of declaring a separate
    /// non-generic type.
    /// </summary>
    /// <param name="symbol">The class symbol.</param>
    /// <returns>The simple name, followed by <c>&lt;T1, T2, ...&gt;</c> when the class is generic.</returns>
    private static string GetDeclarationName(INamedTypeSymbol symbol)
    {
        if (symbol.TypeParameters.Length == 0)
        {
            return symbol.Name;
        }

        var typeParameters = string.Join(", ", symbol.TypeParameters.Select(p => p.Name));
        return $"{symbol.Name}<{typeParameters}>";
    }

    #endregion
}