using System.Text;
using GFramework.SourceGenerators.Common.constants;
using GFramework.SourceGenerators.Common.diagnostics;
using GFramework.SourceGenerators.Common.generator;
using GFramework.SourceGenerators.diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace GFramework.SourceGenerators.rule;
/// <summary>
/// Context-aware generator. For every class marked with the <c>ContextAware</c> attribute it emits a
/// partial declaration that implements <c>IContextAware</c>. It also adds a lazily-resolved
/// <c>Context</c> property.
/// </summary>
[Generator]
public sealed class ContextAwareGenerator : MetadataAttributeClassGeneratorBase
{
    /// <summary>
    /// Metadata name of the <c>IContextAware</c> interface that the generated code implements.
    /// </summary>
    private static string ContextAwareInterfaceMetadataName =>
        $"{PathContests.CoreAbstractionsNamespace}.rule.IContextAware";

    /// <summary>
    /// Gets the fully-qualified metadata name of the trigger attribute.
    /// </summary>
    protected override string AttributeMetadataName =>
        $"{PathContests.SourceGeneratorsAbstractionsPath}.rule.ContextAwareAttribute";

    /// <summary>
    /// Gets the attribute's short name, without the <c>Attribute</c> suffix.
    /// </summary>
    protected override string AttributeShortNameWithoutSuffix => "ContextAware";

    /// <summary>
    /// Checks whether the annotated class can receive generated code.
    /// Reports a diagnostic and rejects it if not.
    /// </summary>
    /// <param name="context">Source production context used to report diagnostics.</param>
    /// <param name="compilation">The current compilation (not used by this check).</param>
    /// <param name="syntax">The class declaration that carries the attribute.</param>
    /// <param name="symbol">The type symbol declared by <paramref name="syntax"/>.</param>
    /// <param name="attr">The matched attribute data (not used by this check).</param>
    /// <returns><see langword="true"/> if generation should proceed; otherwise <see langword="false"/>.</returns>
    protected override bool ValidateSymbol(
        SourceProductionContext context,
        Compilation compilation,
        ClassDeclarationSyntax syntax,
        INamedTypeSymbol symbol,
        AttributeData attr)
    {
        // 1. The class itself must be partial so the generated part can merge with it.
        if (!syntax.Modifiers.Any(SyntaxKind.PartialKeyword))
        {
            context.ReportDiagnostic(Diagnostic.Create(
                CommonDiagnostics.ClassMustBePartial,
                syntax.Identifier.GetLocation(),
                symbol.Name));
            return false;
        }

        // 2. Only classes are supported.
        if (symbol.TypeKind != TypeKind.Class)
        {
            context.ReportDiagnostic(Diagnostic.Create(
                ContextAwareDiagnostic.ContextAwareOnlyForClass,
                syntax.Identifier.GetLocation(),
                symbol.Name));
            return false;
        }

        // 3. For nested classes, the generated code re-opens every enclosing type.
        //    So each of them must be partial as well.
        foreach (var outer in syntax.Ancestors().OfType<TypeDeclarationSyntax>())
        {
            if (!outer.Modifiers.Any(SyntaxKind.PartialKeyword))
            {
                context.ReportDiagnostic(Diagnostic.Create(
                    CommonDiagnostics.ClassMustBePartial,
                    outer.Identifier.GetLocation(),
                    outer.Identifier.ValueText));
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Builds the source text for the partial declaration of <paramref name="symbol"/>.
    /// </summary>
    /// <param name="context">Source production context (not used here).</param>
    /// <param name="compilation">Compilation used to resolve the <c>IContextAware</c> interface.</param>
    /// <param name="symbol">The annotated class.</param>
    /// <param name="attr">The matched attribute data (not used here).</param>
    /// <returns>
    /// The generated source. If <c>IContextAware</c> cannot be resolved, a comment-only file is
    /// returned instead of failing the whole generator run.
    /// </returns>
    protected override string Generate(
        SourceProductionContext context,
        Compilation compilation,
        INamedTypeSymbol symbol,
        AttributeData attr)
    {
        var sb = new StringBuilder();
        sb.AppendLine("// <auto-generated/>");

        // GetTypeByMetadataName returns null when the type is missing (or ambiguous across references).
        // A null-forgiving '!' would turn that into a NullReferenceException that aborts the generator.
        var iContextAware = compilation.GetTypeByMetadataName(ContextAwareInterfaceMetadataName);
        if (iContextAware is null)
        {
            sb.AppendLine($"// Skipped: type '{ContextAwareInterfaceMetadataName}' could not be resolved in this compilation.");
            return sb.ToString().TrimEnd();
        }

        sb.AppendLine("#nullable enable");
        sb.AppendLine();

        if (!symbol.ContainingNamespace.IsGlobalNamespace)
        {
            sb.AppendLine($"namespace {symbol.ContainingNamespace.ToDisplayString()};");
            sb.AppendLine();
        }

        // Re-open enclosing types, outermost first, so a nested class is extended in place
        // rather than spawning an unrelated top-level class with the same name.
        var nestingDepth = AppendContainingTypeHeaders(sb, symbol.ContainingType);

        var interfaceName = iContextAware.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

        // MinimallyQualifiedFormat keeps type parameters (Foo<T>) and escapes keyword identifiers,
        // so the partial declaration matches the user's declaration.
        var typeName = symbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat);

        sb.AppendLine($"partial class {typeName} : {interfaceName}");
        sb.AppendLine("{");
        GenerateContextProperty(sb);
        GenerateInterfaceImplementations(sb, iContextAware);
        sb.AppendLine("}");

        for (var i = 0; i < nestingDepth; i++)
        {
            sb.AppendLine("}");
        }

        return sb.ToString().TrimEnd();
    }

    /// <summary>
    /// Gets the hint (file) name for the generated source of <paramref name="symbol"/>.
    /// </summary>
    /// <param name="symbol">The annotated class.</param>
    /// <returns>
    /// <c>{Name}.ContextAware.g.cs</c> for top-level types.
    /// Nested types are prefixed with their containing type names (<c>Outer.Inner.ContextAware.g.cs</c>).
    /// </returns>
    /// <remarks>
    /// NOTE(review): two top-level classes with the same name in different namespaces still produce
    /// the same hint name. Roslyn rejects duplicate hint names. Adding the namespace would change
    /// the existing file names, so this is left as-is.
    /// </remarks>
    protected override string GetHintName(INamedTypeSymbol symbol)
    {
        var name = symbol.Name;
        for (var outer = symbol.ContainingType; outer is not null; outer = outer.ContainingType)
        {
            name = $"{outer.Name}.{name}";
        }

        return $"{name}.ContextAware.g.cs";
    }

    /// <summary>
    /// Emits <c>partial</c> headers for every type that contains the annotated class.
    /// </summary>
    /// <param name="sb">Builder receiving the headers.</param>
    /// <param name="type">The innermost containing type, or <see langword="null"/> for a top-level class.</param>
    /// <returns>The number of headers opened, i.e. how many closing braces the caller must append.</returns>
    private static int AppendContainingTypeHeaders(StringBuilder sb, INamedTypeSymbol? type)
    {
        if (type is null)
        {
            return 0;
        }

        // Recurse first so the outermost type is written first.
        var depth = AppendContainingTypeHeaders(sb, type.ContainingType);
        sb.AppendLine($"partial {GetTypeKeyword(type)} {type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)}");
        sb.AppendLine("{");
        return depth + 1;
    }

    /// <summary>
    /// Returns the declaration keyword(s) needed to re-open <paramref name="type"/> as a partial type.
    /// </summary>
    /// <param name="type">A containing type of the annotated class.</param>
    /// <returns>One of <c>record struct</c>, <c>record</c>, <c>struct</c>, <c>interface</c> or <c>class</c>.</returns>
    private static string GetTypeKeyword(INamedTypeSymbol type)
    {
        if (type.IsRecord)
        {
            return type.TypeKind == TypeKind.Struct ? "record struct" : "record";
        }

        return type.TypeKind switch
        {
            TypeKind.Struct => "struct",
            TypeKind.Interface => "interface",
            _ => "class",
        };
    }

    // =========================
    // Context property (type names are emitted with global:: qualification)
    // =========================

    /// <summary>
    /// Emits the private <c>_context</c> backing field and the protected <c>Context</c> property.
    /// On first read, the property assigns <c>_context</c> from
    /// <c>GameContext.GetFirstArchitectureContext()</c> if it is still null.
    /// </summary>
    /// <param name="sb">Builder receiving the generated members.</param>
    private static void GenerateContextProperty(StringBuilder sb)
    {
        sb.AppendLine(" private global::GFramework.Core.Abstractions.architecture.IArchitectureContext? _context;");
        sb.AppendLine();
        sb.AppendLine(" /// <summary>");
        sb.AppendLine(" /// 自动获取的架构上下文(懒加载,默认使用第一个架构)");
        sb.AppendLine(" /// </summary>");
        sb.AppendLine(" protected global::GFramework.Core.Abstractions.architecture.IArchitectureContext Context");
        sb.AppendLine(" {");
        sb.AppendLine(" get");
        sb.AppendLine(" {");
        sb.AppendLine(" if (_context == null)");
        sb.AppendLine(" {");
        sb.AppendLine(
            " _context = global::GFramework.Core.architecture.GameContext.GetFirstArchitectureContext();");
        sb.AppendLine(" }");
        sb.AppendLine();
        sb.AppendLine(" return _context;");
        sb.AppendLine(" }");
        sb.AppendLine(" }");
        sb.AppendLine();
    }

    // =========================
    // Explicit interface implementations (using global::)
    // =========================

    /// <summary>
    /// Emits an explicit implementation for every ordinary method declared on <paramref name="interfaceSymbol"/>.
    /// Each implementation is followed by a blank line.
    /// </summary>
    /// <param name="sb">Builder receiving the generated members.</param>
    /// <param name="interfaceSymbol">The resolved <c>IContextAware</c> interface.</param>
    private static void GenerateInterfaceImplementations(
        StringBuilder sb,
        INamedTypeSymbol interfaceSymbol)
    {
        var interfaceName = interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

        foreach (var method in interfaceSymbol.GetMembers().OfType<IMethodSymbol>())
        {
            // Skip accessors, constructors, operators, etc.; only plain methods are implemented.
            if (method.MethodKind != MethodKind.Ordinary)
            {
                continue;
            }

            GenerateMethod(sb, interfaceName, method);
            sb.AppendLine();
        }
    }

    /// <summary>
    /// Emits a single explicit interface method implementation (signature plus body).
    /// </summary>
    /// <param name="sb">Builder receiving the method.</param>
    /// <param name="interfaceName">Fully-qualified interface name used for the explicit implementation.</param>
    /// <param name="method">The interface method to implement.</param>
    private static void GenerateMethod(
        StringBuilder sb,
        string interfaceName,
        IMethodSymbol method)
    {
        var returnType = method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
        var parameters = string.Join(", ",
            method.Parameters.Select(p =>
                $"{p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)} {p.Name}"));

        sb.AppendLine($" {returnType} {interfaceName}.{method.Name}({parameters})");
        sb.AppendLine(" {");
        GenerateMethodBody(sb, method);
        sb.AppendLine(" }");
    }

    /// <summary>
    /// Emits the body of an interface method implementation.
    /// </summary>
    /// <param name="sb">Builder receiving the body.</param>
    /// <param name="method">The interface method being implemented.</param>
    /// <remarks>
    /// <c>SetContext</c> stores its first argument in <c>_context</c>.
    /// <c>GetContext</c> returns the <c>Context</c> property.
    /// Any other non-void method throws <c>NotImplementedException</c>; other void methods get an empty body.
    /// </remarks>
    private static void GenerateMethodBody(
        StringBuilder sb,
        IMethodSymbol method)
    {
        switch (method.Name)
        {
            case "SetContext":
            {
                // Use the interface's own parameter name; the signature above is emitted with p.Name,
                // so a hard-coded "context" would not compile if the interface named it differently.
                var parameterName = method.Parameters.Length > 0 ? method.Parameters[0].Name : "context";
                sb.AppendLine($" _context = {parameterName};");
                break;
            }
            case "GetContext":
                sb.AppendLine(" return Context;");
                break;
            default:
                if (!method.ReturnsVoid)
                {
                    sb.AppendLine(
                        $" throw new System.NotImplementedException(\"Method '{method.Name}' is not supported.\");");
                }

                break;
        }
    }
}