diff --git a/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs b/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs index a622d387..bed493e7 100644 --- a/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs +++ b/GFramework.SourceGenerators.Tests/Core/GeneratorTest.cs @@ -16,6 +16,24 @@ public static class GeneratorTest public static async Task RunAsync( string source, params (string filename, string content)[] generatedSources) + { + await RunAsync( + source, + additionalReferences: [], + generatedSources); + } + + /// + /// 运行源代码生成器测试,并为测试编译显式追加元数据引用。 + /// + /// 输入的源代码。 + /// 附加元数据引用,用于构造多程序集场景。 + /// 期望生成的源文件集合,包含文件名和内容的元组。 + /// 异步操作任务。 + public static async Task RunAsync( + string source, + IEnumerable additionalReferences, + params (string filename, string content)[] generatedSources) { var test = new CSharpSourceGeneratorTest { @@ -31,6 +49,9 @@ public static class GeneratorTest test.TestState.GeneratedSources.Add( (typeof(TGenerator), filename, NormalizeLineEndings(content))); + foreach (var additionalReference in additionalReferences) + test.TestState.AdditionalReferences.Add(additionalReference); + await test.RunAsync(); } @@ -46,4 +67,4 @@ public static class GeneratorTest .Replace("\r", "\n", StringComparison.Ordinal) .Replace("\n", Environment.NewLine, StringComparison.Ordinal); } -} \ No newline at end of file +} diff --git a/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs b/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs new file mode 100644 index 00000000..173a88d9 --- /dev/null +++ b/GFramework.SourceGenerators.Tests/Core/MetadataReferenceTestBuilder.cs @@ -0,0 +1,65 @@ +using System.Collections.Immutable; +using System.IO; +using Microsoft.CodeAnalysis.CSharp; + +namespace GFramework.SourceGenerators.Tests.Core; + +/// +/// 为多程序集源生成器测试构建内存元数据引用。 +/// +public static class MetadataReferenceTestBuilder +{ + /// + /// 将给定源码编译为内存程序集,并返回可供测试编译消费的元数据引用。 + /// + /// 目标程序集名称。 + /// 待编译源码。 + /// 附加元数据引用,用于构造依赖链。 + /// 编译成功后的内存元数据引用。 + public static MetadataReference CreateFromSource( + string assemblyName, + string source, + params MetadataReference[] additionalReferences) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var references = GetRuntimeMetadataReferences() + .Concat(additionalReferences) + .ToImmutableArray(); + var compilation = CSharpCompilation.Create( + assemblyName, + [syntaxTree], + references, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + using var stream = new MemoryStream(); + var emitResult = compilation.Emit(stream); + if (!emitResult.Success) + { + var diagnostics = string.Join( + Environment.NewLine, + emitResult.Diagnostics + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .Select(static diagnostic => diagnostic.ToString())); + throw new InvalidOperationException( + $"Failed to build metadata reference '{assemblyName}'.{Environment.NewLine}{diagnostics}"); + } + + stream.Position = 0; + return MetadataReference.CreateFromImage(stream.ToArray()); + } + + /// + /// 获取当前测试运行时可直接复用的基础元数据引用集合。 + /// + /// 当前运行时可信平台程序集对应的元数据引用。 + public static ImmutableArray GetRuntimeMetadataReferences() + { + var trustedPlatformAssemblies = ((string?)AppContext.GetData("TRUSTED_PLATFORM_ASSEMBLIES"))? + .Split(Path.PathSeparator, StringSplitOptions.RemoveEmptyEntries) + ?? Array.Empty(); + + return trustedPlatformAssemblies + .Select(static path => (MetadataReference)MetadataReference.CreateFromFile(path)) + .ToImmutableArray(); + } +} diff --git a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs index 3026d279..dc8c9abb 100644 --- a/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs +++ b/GFramework.SourceGenerators.Tests/Cqrs/CqrsHandlerRegistryGeneratorTests.cs @@ -1,6 +1,7 @@ using System.Reflection; using GFramework.SourceGenerators.Cqrs; using GFramework.SourceGenerators.Tests.Core; +using Microsoft.CodeAnalysis.CSharp; namespace GFramework.SourceGenerators.Tests.Cqrs; @@ -825,6 +826,135 @@ public class CqrsHandlerRegistryGeneratorTests ("CqrsHandlerRegistry.g.cs", MixedReflectedImplementationAndPreciseRegistrationsExpected)); } + /// + /// 验证当外部基类暴露的 handler interface 含有生成注册器顶层上下文不可直接引用的 protected 类型时, + /// 生成器会保留已知直注册,并只对剩余未知接口做本地 interface discovery。 + /// + [Test] + public void Generates_Partial_Runtime_Interface_Discovery_For_Inaccessible_External_Protected_Types() + { + const string contractsSource = """ + namespace GFramework.Cqrs.Abstractions.Cqrs + { + public interface IRequest { } + public interface INotification { } + public interface IStreamRequest { } + + public interface IRequestHandler where TRequest : IRequest { } + public interface INotificationHandler where TNotification : INotification { } + public interface IStreamRequestHandler where TRequest : IStreamRequest { } + } + """; + + const string dependencySource = """ + using GFramework.Cqrs.Abstractions.Cqrs; + + namespace Dep; + + public sealed record VisibleRequest() : IRequest; + + public abstract class VisibilityScope + { + protected internal sealed record ProtectedResponse(); + + protected internal sealed record ProtectedRequest() : IRequest; + } + + public abstract class HandlerBase : + VisibilityScope, + IRequestHandler, + IRequestHandler + { + } + """; + + const string source = """ + using System; + using Dep; + + namespace Microsoft.Extensions.DependencyInjection + { + public interface IServiceCollection { } + + public static class ServiceCollectionServiceExtensions + { + public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { } + } + } + + namespace GFramework.Core.Abstractions.Logging + { + public interface ILogger + { + void Debug(string msg); + } + } + + namespace GFramework.Cqrs + { + public interface ICqrsHandlerRegistry + { + void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger); + } + + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + public sealed class CqrsHandlerRegistryAttribute : Attribute + { + public CqrsHandlerRegistryAttribute(Type registryType) { } + } + } + + namespace TestApp + { + public sealed class DerivedHandler : HandlerBase + { + } + } + """; + + var contractsReference = MetadataReferenceTestBuilder.CreateFromSource( + "Contracts", + contractsSource); + var dependencyReference = MetadataReferenceTestBuilder.CreateFromSource( + "Dependency", + dependencySource, + contractsReference); + var generatedSource = RunGenerator( + source, + contractsReference, + dependencyReference); + + Assert.Multiple(() => + { + Assert.That( + generatedSource, + Does.Contain("var implementationType0 = typeof(global::TestApp.DerivedHandler);")); + Assert.That( + generatedSource, + Does.Contain( + "var knownServiceTypes0 = new global::System.Collections.Generic.HashSet();")); + Assert.That( + generatedSource, + Does.Contain( + "knownServiceTypes0.Add(typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler));")); + Assert.That( + generatedSource, + Does.Contain( + "RegisterRemainingReflectedHandlerInterfaces(services, logger, implementationType0, knownServiceTypes0);")); + Assert.That( + generatedSource, + Does.Contain("if (knownServiceTypes.Contains(handlerInterface))")); + Assert.That( + generatedSource, + Does.Contain( + "Registered CQRS handler TestApp.DerivedHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler.")); + Assert.That( + generatedSource, + Does.Not.Contain( + "typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler /// 验证即使 runtime 仍暴露旧版无参 fallback marker,生成器也会优先在生成注册器内部处理隐藏 handler, /// 不再输出 fallback marker。 @@ -999,4 +1129,38 @@ public class CqrsHandlerRegistryGeneratorTests Assert.That(escaped, Is.EqualTo(expected)); } + + /// + /// 运行 CQRS handler registry generator,并返回单个生成文件的源码文本。 + /// + private static string RunGenerator( + string source, + params MetadataReference[] additionalReferences) + { + var syntaxTree = CSharpSyntaxTree.ParseText(source); + var compilation = CSharpCompilation.Create( + "TestProject", + [syntaxTree], + MetadataReferenceTestBuilder.GetRuntimeMetadataReferences().AddRange(additionalReferences), + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [new CqrsHandlerRegistryGenerator().AsSourceGenerator()], + parseOptions: (CSharpParseOptions)syntaxTree.Options); + driver = driver.RunGeneratorsAndUpdateCompilation( + compilation, + out var updatedCompilation, + out _); + + var compilationErrors = updatedCompilation.GetDiagnostics() + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .ToArray(); + Assert.That(compilationErrors, Is.Empty, string.Join(Environment.NewLine, compilationErrors)); + + var runResult = driver.GetRunResult(); + Assert.That(runResult.Results, Has.Length.EqualTo(1)); + Assert.That(runResult.Results[0].GeneratedSources, Has.Length.EqualTo(1)); + + return runResult.Results[0].GeneratedSources[0].SourceText.ToString(); + } } diff --git a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs index f3bb5e96..49065f0c 100644 --- a/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs +++ b/GFramework.SourceGenerators/Cqrs/CqrsHandlerRegistryGenerator.cs @@ -86,15 +86,17 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator var implementationTypeDisplayName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var implementationLogName = GetLogDisplayName(type); - var canReferenceImplementation = CanReferenceFromGeneratedRegistry(type); + var canReferenceImplementation = CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, type); var registrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); var reflectedImplementationRegistrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); var preciseReflectedRegistrations = ImmutableArray.CreateBuilder(handlerInterfaces.Length); + var requiresRuntimeInterfaceDiscovery = false; foreach (var handlerInterface in handlerInterfaces) { - var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface); + var canReferenceHandlerInterface = + CanReferenceFromGeneratedRegistry(context.SemanticModel.Compilation, handlerInterface); if (canReferenceImplementation && canReferenceHandlerInterface) { registrations.Add(new HandlerRegistrationSpec( @@ -122,16 +124,10 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator continue; } - // Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over - // non-public element types. For those rare cases keep the narrow implementation lookup, but let the - // generated registry discover the exact supported interfaces from the implementation type at runtime. - return new HandlerCandidateAnalysis( - implementationTypeDisplayName, - implementationLogName, - ImmutableArray.Empty, - ImmutableArray.Empty, - ImmutableArray.Empty, - GetReflectionTypeMetadataName(type)); + // 某些关闭 handler interface 仍包含只能在实现类型运行时语义里解析的类型形态。 + // 对这些边角场景保留“已知接口静态注册 + 剩余接口运行时补洞”的组合路径, + // 避免单个未知接口把同实现上的其它已知注册全部拖回整实现反射发现。 + requiresRuntimeInterfaceDiscovery = true; } return new HandlerCandidateAnalysis( @@ -140,7 +136,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator registrations.ToImmutable(), reflectedImplementationRegistrations.ToImmutable(), preciseReflectedRegistrations.ToImmutable(), - canReferenceImplementation ? null : GetReflectionTypeMetadataName(type)); + canReferenceImplementation ? null : GetReflectionTypeMetadataName(type), + requiresRuntimeInterfaceDiscovery); } private static void Execute(SourceProductionContext context, GenerationEnvironment generationEnvironment, @@ -184,7 +181,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator candidate.Registrations, candidate.ReflectedImplementationRegistrations, candidate.PreciseReflectedRegistrations, - candidate.ReflectionTypeMetadataName)); + candidate.ReflectionTypeMetadataName, + candidate.RequiresRuntimeInterfaceDiscovery)); } registrations.Sort(static (left, right) => @@ -295,7 +293,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ITypeSymbol type, out RuntimeTypeReferenceSpec? runtimeTypeReference) { - if (CanReferenceFromGeneratedRegistry(type)) + if (CanReferenceFromGeneratedRegistry(compilation, type)) { runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference( type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); @@ -369,7 +367,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator out RuntimeTypeReferenceSpec? genericTypeDefinitionReference) { var genericTypeDefinition = genericNamedType.OriginalDefinition; - if (CanReferenceFromGeneratedRegistry(genericTypeDefinition)) + if (CanReferenceFromGeneratedRegistry(compilation, genericTypeDefinition)) { genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference( genericTypeDefinition @@ -389,19 +387,25 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return false; } - private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type) + private static bool CanReferenceFromGeneratedRegistry(Compilation compilation, ITypeSymbol type) { switch (type) { case IArrayTypeSymbol arrayType: - return CanReferenceFromGeneratedRegistry(arrayType.ElementType); + return CanReferenceFromGeneratedRegistry(compilation, arrayType.ElementType); case INamedTypeSymbol namedType: - if (!IsTypeChainAccessible(namedType)) + if (!compilation.IsSymbolAccessibleWithin(namedType, compilation.Assembly, throughType: null)) return false; - return namedType.TypeArguments.All(CanReferenceFromGeneratedRegistry); + foreach (var typeArgument in namedType.TypeArguments) + { + if (!CanReferenceFromGeneratedRegistry(compilation, typeArgument)) + return false; + } + + return true; case IPointerTypeSymbol pointerType: - return CanReferenceFromGeneratedRegistry(pointerType.PointedAtType); + return CanReferenceFromGeneratedRegistry(compilation, pointerType.PointedAtType); case ITypeParameterSymbol: return false; default: @@ -409,23 +413,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } } - private static bool IsTypeChainAccessible(INamedTypeSymbol type) - { - for (var current = type; current is not null; current = current.ContainingType) - { - if (!IsSymbolAccessible(current)) - return false; - } - - return true; - } - - private static bool IsSymbolAccessible(ISymbol symbol) - { - return symbol.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal - or Accessibility.ProtectedOrInternal; - } - private static string GetFullyQualifiedMetadataName(INamedTypeSymbol type) { var nestedTypes = new Stack(); @@ -496,10 +483,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty); var hasPreciseReflectedRegistrations = registrations.Any(static registration => !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); - var hasFullReflectionRegistrations = registrations.Any(static registration => - !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && - registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && - registration.PreciseReflectedRegistrations.IsDefaultOrEmpty); + var hasRuntimeInterfaceDiscovery = registrations.Any(static registration => + registration.RequiresRuntimeInterfaceDiscovery); var builder = new StringBuilder(); builder.AppendLine("// "); builder.AppendLine("#nullable enable"); @@ -533,7 +518,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" if (logger is null)"); builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));"); if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || - hasFullReflectionRegistrations) + registrations.Any(static registration => + !string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))) { builder.AppendLine(); builder.Append(" var registryAssembly = typeof(global::"); @@ -550,7 +536,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator { var registration = registrations[registrationIndex]; if (!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty || - !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty) + !registration.PreciseReflectedRegistrations.IsDefaultOrEmpty || + registration.RequiresRuntimeInterfaceDiscovery) { AppendOrderedImplementationRegistrations(builder, registration, registrationIndex); } @@ -558,19 +545,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator { AppendDirectRegistrations(builder, registration); } - - if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) && - registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty && - registration.PreciseReflectedRegistrations.IsDefaultOrEmpty && - registration.DirectRegistrations.IsDefaultOrEmpty) - { - AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!); - } } builder.AppendLine(" }"); - if (hasFullReflectionRegistrations) + if (hasRuntimeInterfaceDiscovery) { builder.AppendLine(); AppendReflectionHelpers(builder); @@ -580,13 +559,6 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator return builder.ToString(); } - private static void AppendReflectionRegistration(StringBuilder builder, string reflectionTypeMetadataName) - { - builder.Append(" RegisterReflectedHandler(services, logger, registryAssembly, \""); - builder.Append(EscapeStringLiteral(reflectionTypeMetadataName)); - builder.AppendLine("\");"); - } - private static void AppendDirectRegistrations( StringBuilder builder, ImplementationRegistrationSpec registration) @@ -653,6 +625,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator StringComparer.Ordinal.Compare(left.HandlerInterfaceLogName, right.HandlerInterfaceLogName)); var implementationVariableName = $"implementationType{registrationIndex}"; + var knownServiceTypesVariableName = $"knownServiceTypes{registrationIndex}"; if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName)) { builder.Append(" var "); @@ -675,12 +648,28 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" is not null)"); builder.AppendLine(" {"); + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" var "); + builder.Append(knownServiceTypesVariableName); + builder.AppendLine(" = new global::System.Collections.Generic.HashSet();"); + } + foreach (var orderedRegistration in orderedRegistrations) { switch (orderedRegistration.Kind) { case OrderedRegistrationKind.Direct: var directRegistration = registration.DirectRegistrations[orderedRegistration.Index]; + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" "); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add(typeof("); + builder.Append(directRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("));"); + } + builder.AppendLine( " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); builder.AppendLine(" services,"); @@ -699,6 +688,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator case OrderedRegistrationKind.ReflectedImplementation: var reflectedRegistration = registration.ReflectedImplementationRegistrations[orderedRegistration.Index]; + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" "); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add(typeof("); + builder.Append(reflectedRegistration.HandlerInterfaceDisplayName); + builder.AppendLine("));"); + } + builder.AppendLine( " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); builder.AppendLine(" services,"); @@ -725,6 +723,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator preciseRegistration.OpenHandlerTypeDisplayName, registration.ImplementationLogName, preciseRegistration.HandlerInterfaceLogName, + knownServiceTypesVariableName, + registration.RequiresRuntimeInterfaceDiscovery, 3); break; default: @@ -733,6 +733,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator } } + if (registration.RequiresRuntimeInterfaceDiscovery) + { + builder.Append(" RegisterRemainingReflectedHandlerInterfaces(services, logger, "); + builder.Append(implementationVariableName); + builder.Append(", "); + builder.Append(knownServiceTypesVariableName); + builder.AppendLine(");"); + } + builder.AppendLine(" }"); } @@ -744,6 +753,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator string openHandlerTypeDisplayName, string implementationLogName, string handlerInterfaceLogName, + string knownServiceTypesVariableName, + bool trackKnownServiceTypes, int indentLevel) { var indent = new string(' ', indentLevel * 4); @@ -808,6 +819,15 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.Append(" "); builder.Append(implementationVariableName); builder.AppendLine(");"); + if (trackKnownServiceTypes) + { + builder.Append(indent); + builder.Append(knownServiceTypesVariableName); + builder.Append(".Add("); + builder.Append(registrationVariablePrefix); + builder.AppendLine(");"); + } + builder.Append(indent); builder.Append("logger.Debug(\"Registered CQRS handler "); builder.Append(EscapeStringLiteral(implementationLogName)); @@ -884,15 +904,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator private static void AppendReflectionHelpers(StringBuilder builder) { - // Emit the runtime helper methods only when at least one handler requires metadata-name lookup. + // Emit the runtime helper methods only when at least one handler still needs implementation-scoped + // interface discovery after all direct / precise registrations have been emitted. builder.AppendLine( - " private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)"); + " private static void RegisterRemainingReflectedHandlerInterfaces(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Type implementationType, global::System.Collections.Generic.ISet knownServiceTypes)"); builder.AppendLine(" {"); - builder.AppendLine( - " var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);"); - builder.AppendLine(" if (implementationType is null)"); - builder.AppendLine(" return;"); - builder.AppendLine(); builder.AppendLine(" var handlerInterfaces = implementationType.GetInterfaces();"); builder.AppendLine(" global::System.Array.Sort(handlerInterfaces, CompareTypes);"); builder.AppendLine(); @@ -901,6 +917,9 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" if (!IsSupportedHandlerInterface(handlerInterface))"); builder.AppendLine(" continue;"); builder.AppendLine(); + builder.AppendLine(" if (knownServiceTypes.Contains(handlerInterface))"); + builder.AppendLine(" continue;"); + builder.AppendLine(); builder.AppendLine( " global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient("); builder.AppendLine(" services,"); @@ -908,6 +927,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator builder.AppendLine(" implementationType);"); builder.AppendLine( " logger.Debug($\"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.\");"); + builder.AppendLine(" knownServiceTypes.Add(handlerInterface);"); builder.AppendLine(" }"); builder.AppendLine(" }"); builder.AppendLine(); @@ -1067,7 +1087,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ImmutableArray DirectRegistrations, ImmutableArray ReflectedImplementationRegistrations, ImmutableArray PreciseReflectedRegistrations, - string? ReflectionTypeMetadataName); + string? ReflectionTypeMetadataName, + bool RequiresRuntimeInterfaceDiscovery); private readonly struct HandlerCandidateAnalysis : IEquatable { @@ -1077,7 +1098,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ImmutableArray registrations, ImmutableArray reflectedImplementationRegistrations, ImmutableArray preciseReflectedRegistrations, - string? reflectionTypeMetadataName) + string? reflectionTypeMetadataName, + bool requiresRuntimeInterfaceDiscovery) { ImplementationTypeDisplayName = implementationTypeDisplayName; ImplementationLogName = implementationLogName; @@ -1085,6 +1107,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator ReflectedImplementationRegistrations = reflectedImplementationRegistrations; PreciseReflectedRegistrations = preciseReflectedRegistrations; ReflectionTypeMetadataName = reflectionTypeMetadataName; + RequiresRuntimeInterfaceDiscovery = requiresRuntimeInterfaceDiscovery; } public string ImplementationTypeDisplayName { get; } @@ -1099,6 +1122,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator public string? ReflectionTypeMetadataName { get; } + public bool RequiresRuntimeInterfaceDiscovery { get; } + public bool Equals(HandlerCandidateAnalysis other) { if (!string.Equals(ImplementationTypeDisplayName, other.ImplementationTypeDisplayName, @@ -1106,6 +1131,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator !string.Equals(ImplementationLogName, other.ImplementationLogName, StringComparison.Ordinal) || !string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName, StringComparison.Ordinal) || + RequiresRuntimeInterfaceDiscovery != other.RequiresRuntimeInterfaceDiscovery || Registrations.Length != other.Registrations.Length || ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length || PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length) @@ -1150,6 +1176,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator (ReflectionTypeMetadataName is null ? 0 : StringComparer.Ordinal.GetHashCode(ReflectionTypeMetadataName)); + hashCode = (hashCode * 397) ^ RequiresRuntimeInterfaceDiscovery.GetHashCode(); foreach (var registration in Registrations) { hashCode = (hashCode * 397) ^ registration.GetHashCode();