using System;
using System.Linq;
using GFramework.SourceGenerators.Common.diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace GFramework.SourceGenerators.Common.generator;
///
/// 属性类生成器基类,用于处理带有特定属性的类并生成相应的源代码
///
public abstract class AttributeClassGeneratorBase : IIncrementalGenerator
{
///
/// 获取属性的元数据名称
///
protected abstract Type AttributeType { get; }
///
/// Attribute 的短名称(不含 Attribute 后缀)
/// 仅用于 Syntax 层宽松匹配
///
protected abstract string AttributeShortNameWithoutSuffix { get; }
///
/// 初始化增量生成器
///
/// 增量生成器初始化上下文
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var targets = context.SyntaxProvider.CreateSyntaxProvider(
(node, _) =>
node is ClassDeclarationSyntax cls &&
cls.AttributeLists
.SelectMany(a => a.Attributes)
.Any(a => a.Name.ToString()
.Contains(AttributeShortNameWithoutSuffix)),
static (ctx, t) =>
{
var cls = (ClassDeclarationSyntax)ctx.Node;
var symbol = ctx.SemanticModel.GetDeclaredSymbol(cls, t);
return (ClassDecl: cls, Symbol: symbol);
}
)
.Where(x => x.Symbol is not null);
context.RegisterSourceOutput(targets, (spc, pair) =>
{
try
{
Execute(spc, pair.ClassDecl, pair.Symbol!);
}
catch (Exception ex)
{
EmitError(spc, pair.Symbol, ex);
}
});
}
///
/// 执行源代码生成的主要逻辑
///
/// 源生产上下文
/// 类声明语法节点
/// 命名类型符号
private void Execute(
SourceProductionContext context,
ClassDeclarationSyntax classDecl,
INamedTypeSymbol symbol)
{
var attr = GetAttribute(symbol);
if (attr == null) return;
// partial 校验
if (!classDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
{
ReportClassMustBePartial(context, classDecl, symbol);
return;
}
// 子类校验
if (!ValidateSymbol(context, classDecl, symbol, attr))
return;
var source = Generate(symbol, attr);
context.AddSource(GetHintName(symbol), source);
}
#region 可覆写点
///
/// 验证符号的有效性
///
/// 源生产上下文
/// 类声明语法节点
/// 命名类型符号
/// 属性数据
/// 验证是否通过
protected virtual bool ValidateSymbol(
SourceProductionContext context,
ClassDeclarationSyntax syntax,
INamedTypeSymbol symbol,
AttributeData attr)
{
return true;
}
///
/// 生成源代码
///
/// 命名类型符号
/// 属性数据
/// 生成的源代码字符串
protected abstract string Generate(
INamedTypeSymbol symbol,
AttributeData attr);
///
/// 获取生成文件的提示名称
///
/// 命名类型符号
/// 生成文件的提示名称
protected virtual string GetHintName(INamedTypeSymbol symbol)
{
return $"{symbol.Name}.g.cs";
}
#endregion
#region Attribute / Diagnostic
///
/// 获取指定符号的属性数据
///
/// 命名类型符号
/// 属性数据,如果未找到则返回null
protected virtual AttributeData? GetAttribute(INamedTypeSymbol symbol)
{
return symbol.GetAttributes().FirstOrDefault(a =>
string.Equals(a.AttributeClass?.ToDisplayString(), AttributeType.FullName, StringComparison.Ordinal));
}
///
/// 报告类必须是partial的诊断信息
///
/// 源生产上下文
/// 类声明语法节点
/// 命名类型符号
protected virtual void ReportClassMustBePartial(
SourceProductionContext context,
ClassDeclarationSyntax syntax,
INamedTypeSymbol symbol)
{
context.ReportDiagnostic(Diagnostic.Create(
CommonDiagnostics.ClassMustBePartial,
syntax.Identifier.GetLocation(),
symbol.Name));
}
///
/// 发出错误信息
///
/// 源生产上下文
/// 命名类型符号
/// 异常对象
protected virtual void EmitError(
SourceProductionContext context,
INamedTypeSymbol? symbol,
Exception ex)
{
var name = symbol?.Name ?? "Unknown";
var text =
$"// source generator error: {ex.Message}\n// {ex.StackTrace}";
context.AddSource($"{name}.Error.g.cs", text);
}
#endregion
}