diff --git a/GFramework.SourceGenerators/Rule/ContextGetGenerator.cs b/GFramework.SourceGenerators/Rule/ContextGetGenerator.cs index d2beb2a..902b0df 100644 --- a/GFramework.SourceGenerators/Rule/ContextGetGenerator.cs +++ b/GFramework.SourceGenerators/Rule/ContextGetGenerator.cs @@ -171,11 +171,22 @@ public sealed class ContextGetGenerator : IIncrementalGenerator var descriptors = ResolveBindingDescriptors(compilation); var getAllAttribute = compilation.GetTypeByMetadataName(GetAllAttributeMetadataName); - if (descriptors.Length == 0 && getAllAttribute is null) return; - var symbols = new ContextSymbols( + var symbols = CreateContextSymbols(compilation); + var workItems = CollectWorkItems( + fieldCandidates, + typeCandidates, + descriptors, + getAllAttribute); + + GenerateSources(context, descriptors, symbols, workItems); + } + + private static ContextSymbols CreateContextSymbols(Compilation compilation) + { + return new ContextSymbols( compilation.GetTypeByMetadataName(ContextAwareAttributeMetadataName), compilation.GetTypeByMetadataName(IContextAwareMetadataName), compilation.GetTypeByMetadataName(ContextAwareBaseMetadataName), @@ -184,83 +195,20 @@ public sealed class ContextGetGenerator : IIncrementalGenerator compilation.GetTypeByMetadataName(IUtilityMetadataName), compilation.GetTypeByMetadataName(IReadOnlyListMetadataName), compilation.GetTypeByMetadataName(GodotNodeMetadataName)); + } - var workItems = CollectWorkItems( - fieldCandidates, - typeCandidates, - descriptors, - getAllAttribute); - + private static void GenerateSources( + SourceProductionContext context, + ImmutableArray descriptors, + ContextSymbols symbols, + Dictionary workItems) + { foreach (var workItem in workItems.Values) { if (!CanGenerateForType(context, workItem, symbols)) continue; - var bindings = new List(); - var explicitFields = new HashSet(SymbolEqualityComparer.Default); - - foreach (var candidate in workItem.FieldCandidates - .OrderBy(static candidate => candidate.Variable.SpanStart) - .ThenBy(static candidate => candidate.FieldSymbol.Name, StringComparer.Ordinal)) - { - var matches = ResolveExplicitBindings(candidate.FieldSymbol, descriptors); - if (matches.Length == 0) - continue; - - explicitFields.Add(candidate.FieldSymbol); - - if (matches.Length > 1) - { - ReportFieldDiagnostic( - context, - ContextGetDiagnostics.MultipleBindingAttributesNotSupported, - candidate); - continue; - } - - if (!TryCreateExplicitBinding( - context, - candidate, - matches[0], - symbols, - out var binding)) - continue; - - bindings.Add(binding); - } - - if (workItem.GetAllDeclaration is not null) - { - foreach (var field in GetAllFields(workItem.TypeSymbol)) - { - if (explicitFields.Contains(field)) - continue; - - if (field.IsStatic) - { - ReportFieldDiagnostic( - context, - ContextGetDiagnostics.StaticFieldNotSupported, - field); - continue; - } - - if (field.IsReadOnly) - { - ReportFieldDiagnostic( - context, - ContextGetDiagnostics.ReadOnlyFieldNotSupported, - field); - continue; - } - - if (!TryCreateInferredBinding(field, symbols, out var binding)) - continue; - - bindings.Add(binding); - } - } - + var bindings = CollectBindings(context, workItem, descriptors, symbols); if (bindings.Count == 0 && workItem.GetAllDeclaration is null) continue; @@ -269,6 +217,106 @@ public sealed class ContextGetGenerator : IIncrementalGenerator } } + private static List CollectBindings( + SourceProductionContext context, + TypeWorkItem workItem, + ImmutableArray descriptors, + ContextSymbols symbols) + { + var bindings = new List(); + var explicitFields = new HashSet(SymbolEqualityComparer.Default); + + AddExplicitBindings(context, workItem, descriptors, symbols, bindings, explicitFields); + AddInferredBindings(context, workItem, symbols, bindings, explicitFields); + + return bindings; + } + + private static void AddExplicitBindings( + SourceProductionContext context, + TypeWorkItem workItem, + ImmutableArray descriptors, + ContextSymbols symbols, + ICollection bindings, + ISet explicitFields) + { + foreach (var candidate in workItem.FieldCandidates + .OrderBy(static candidate => candidate.Variable.SpanStart) + .ThenBy(static candidate => candidate.FieldSymbol.Name, StringComparer.Ordinal)) + { + var matches = ResolveExplicitBindings(candidate.FieldSymbol, descriptors); + if (matches.Length == 0) + continue; + + explicitFields.Add(candidate.FieldSymbol); + + if (matches.Length > 1) + { + ReportFieldDiagnostic( + context, + ContextGetDiagnostics.MultipleBindingAttributesNotSupported, + candidate); + continue; + } + + if (!TryCreateExplicitBinding( + context, + candidate, + matches[0], + symbols, + out var binding)) + continue; + + bindings.Add(binding); + } + } + + private static void AddInferredBindings( + SourceProductionContext context, + TypeWorkItem workItem, + ContextSymbols symbols, + ICollection bindings, + ISet explicitFields) + { + if (workItem.GetAllDeclaration is null) + return; + + foreach (var field in GetAllFields(workItem.TypeSymbol)) + { + if (explicitFields.Contains(field)) + continue; + + if (!CanInferBinding(context, field)) + continue; + + if (!TryCreateInferredBinding(field, symbols, out var binding)) + continue; + + bindings.Add(binding); + } + } + + private static bool CanInferBinding(SourceProductionContext context, IFieldSymbol field) + { + if (field.IsStatic) + { + ReportFieldDiagnostic( + context, + ContextGetDiagnostics.StaticFieldNotSupported, + field); + return false; + } + + if (!field.IsReadOnly) + return true; + + ReportFieldDiagnostic( + context, + ContextGetDiagnostics.ReadOnlyFieldNotSupported, + field); + return false; + } + private static Dictionary CollectWorkItems( ImmutableArray fieldCandidates, ImmutableArray typeCandidates,