feat(cqrs-generator): 支持泛型与数组类型重建并优化隐藏处理器绑定

This commit is contained in:
gewuyou 2026-04-16 13:02:01 +08:00 committed by GeWuYou
parent 76bb9671d5
commit f25353db8c
2 changed files with 642 additions and 128 deletions

View File

@ -29,115 +29,26 @@ public class CqrsHandlerRegistryGeneratorTests
var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
RegisterReflectedHandler(services, logger, registryAssembly, "TestApp.Container+HiddenHandler");
var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false);
if (implementationType0 is not null)
{
var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
if (serviceType0_0Argument0 is not null)
{
var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, typeof(string));
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
serviceType0_0,
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.Container.HiddenRequest, string>.");
}
}
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<global::TestApp.VisibleRequest, string>),
typeof(global::TestApp.VisibleHandler));
logger.Debug("Registered CQRS handler TestApp.VisibleHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.VisibleRequest, string>.");
}
private static void RegisterReflectedHandler(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger, global::System.Reflection.Assembly registryAssembly, string implementationTypeMetadataName)
{
var implementationType = registryAssembly.GetType(implementationTypeMetadataName, throwOnError: false, ignoreCase: false);
if (implementationType is null)
return;
var handlerInterfaces = implementationType.GetInterfaces();
global::System.Array.Sort(handlerInterfaces, CompareTypes);
foreach (var handlerInterface in handlerInterfaces)
{
if (!IsSupportedHandlerInterface(handlerInterface))
continue;
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
handlerInterface,
implementationType);
logger.Debug($"Registered CQRS handler {GetRuntimeTypeDisplayName(implementationType)} as {GetRuntimeTypeDisplayName(handlerInterface)}.");
}
}
private static int CompareTypes(global::System.Type left, global::System.Type right)
{
return global::System.StringComparer.Ordinal.Compare(GetRuntimeTypeDisplayName(left), GetRuntimeTypeDisplayName(right));
}
private static bool IsSupportedHandlerInterface(global::System.Type interfaceType)
{
if (!interfaceType.IsGenericType)
return false;
var definitionFullName = interfaceType.GetGenericTypeDefinition().FullName;
return global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler`2")
|| global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.INotificationHandler`1")
|| global::System.StringComparer.Ordinal.Equals(definitionFullName, "GFramework.Cqrs.Abstractions.Cqrs.IStreamRequestHandler`2");
}
private static string GetRuntimeTypeDisplayName(global::System.Type type)
{
if (type == typeof(string))
return "string";
if (type == typeof(int))
return "int";
if (type == typeof(long))
return "long";
if (type == typeof(short))
return "short";
if (type == typeof(byte))
return "byte";
if (type == typeof(bool))
return "bool";
if (type == typeof(object))
return "object";
if (type == typeof(void))
return "void";
if (type == typeof(uint))
return "uint";
if (type == typeof(ulong))
return "ulong";
if (type == typeof(ushort))
return "ushort";
if (type == typeof(sbyte))
return "sbyte";
if (type == typeof(float))
return "float";
if (type == typeof(double))
return "double";
if (type == typeof(decimal))
return "decimal";
if (type == typeof(char))
return "char";
if (type.IsArray)
return GetRuntimeTypeDisplayName(type.GetElementType()!) + "[]";
if (!type.IsGenericType)
return (type.FullName ?? type.Name).Replace('+', '.');
var genericTypeName = type.GetGenericTypeDefinition().FullName ?? type.Name;
var arityIndex = genericTypeName.IndexOf('`');
if (arityIndex >= 0)
genericTypeName = genericTypeName[..arityIndex];
genericTypeName = genericTypeName.Replace('+', '.');
var arguments = type.GetGenericArguments();
var builder = new global::System.Text.StringBuilder();
builder.Append(genericTypeName);
builder.Append('<');
for (var index = 0; index < arguments.Length; index++)
{
if (index > 0)
builder.Append(", ");
builder.Append(GetRuntimeTypeDisplayName(arguments[index]));
}
builder.Append('>');
return builder.ToString();
}
}
""";
@ -175,6 +86,84 @@ public class CqrsHandlerRegistryGeneratorTests
""";
private const string HiddenArrayResponseFallbackExpected = """
// <auto-generated />
#nullable enable
[assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
namespace GFramework.Generated.Cqrs;
internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
{
public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
{
if (services is null)
throw new global::System.ArgumentNullException(nameof(services));
if (logger is null)
throw new global::System.ArgumentNullException(nameof(logger));
var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false);
if (implementationType0 is not null)
{
var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
var serviceType0_0Argument1Element = registryAssembly.GetType("TestApp.Container+HiddenResponse", throwOnError: false, ignoreCase: false);
if (serviceType0_0Argument0 is not null && serviceType0_0Argument1Element is not null)
{
var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1Element.MakeArrayType());
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
serviceType0_0,
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.Container.HiddenRequest, TestApp.Container.HiddenResponse[]>.");
}
}
}
}
""";
private const string HiddenGenericEnvelopeResponseExpected = """
// <auto-generated />
#nullable enable
[assembly: global::GFramework.Cqrs.CqrsHandlerRegistryAttribute(typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry))]
namespace GFramework.Generated.Cqrs;
internal sealed class __GFrameworkGeneratedCqrsHandlerRegistry : global::GFramework.Cqrs.ICqrsHandlerRegistry
{
public void Register(global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::GFramework.Core.Abstractions.Logging.ILogger logger)
{
if (services is null)
throw new global::System.ArgumentNullException(nameof(services));
if (logger is null)
throw new global::System.ArgumentNullException(nameof(logger));
var registryAssembly = typeof(global::GFramework.Generated.Cqrs.__GFrameworkGeneratedCqrsHandlerRegistry).Assembly;
var implementationType0 = registryAssembly.GetType("TestApp.Container+HiddenHandler", throwOnError: false, ignoreCase: false);
if (implementationType0 is not null)
{
var serviceType0_0Argument0 = registryAssembly.GetType("TestApp.Container+HiddenRequest", throwOnError: false, ignoreCase: false);
var serviceType0_0Argument1GenericDefinition = registryAssembly.GetType("TestApp.Container+HiddenEnvelope`1", throwOnError: false, ignoreCase: false);
if (serviceType0_0Argument0 is not null && serviceType0_0Argument1GenericDefinition is not null)
{
var serviceType0_0 = typeof(global::GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<,>).MakeGenericType(serviceType0_0Argument0, serviceType0_0Argument1GenericDefinition.MakeGenericType(typeof(string)));
global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(
services,
serviceType0_0,
implementationType0);
logger.Debug("Registered CQRS handler TestApp.Container.HiddenHandler as GFramework.Cqrs.Abstractions.Cqrs.IRequestHandler<TestApp.Container.HiddenRequest, TestApp.Container.HiddenEnvelope<string>>.");
}
}
}
}
""";
/// <summary>
/// 验证生成器会为当前程序集中的 request、notification 和 stream 处理器生成稳定顺序的注册器。
/// </summary>
@ -444,6 +433,152 @@ public class CqrsHandlerRegistryGeneratorTests
("CqrsHandlerRegistry.g.cs", HiddenImplementationDirectInterfaceRegistrationExpected));
}
/// <summary>
/// 验证精确重建路径会递归覆盖隐藏元素类型数组,
/// 使这类 handler interface 也能直接生成 closed service type而不再退回 <c>GetInterfaces()</c>。
/// </summary>
[Test]
public async Task Generates_Precise_Service_Type_For_Hidden_Array_Type_Arguments()
{
const string source = """
using System;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
using GFramework.Cqrs.Abstractions.Cqrs;
public sealed class Container
{
private sealed record HiddenResponse();
private sealed record HiddenRequest() : IRequest<HiddenResponse[]>;
private sealed class HiddenHandler : IRequestHandler<HiddenRequest, HiddenResponse[]> { }
}
}
""";
await GeneratorTest<CqrsHandlerRegistryGenerator>.RunAsync(
source,
("CqrsHandlerRegistry.g.cs", HiddenArrayResponseFallbackExpected));
}
/// <summary>
/// 验证精确重建路径会递归覆盖隐藏泛型定义,
/// 使“隐藏泛型定义 + 可见/常量型实参”的闭包类型也能直接生成 closed service type。
/// </summary>
[Test]
public async Task Generates_Precise_Service_Type_For_Hidden_Generic_Type_Definitions()
{
const string source = """
using System;
namespace Microsoft.Extensions.DependencyInjection
{
public interface IServiceCollection { }
public static class ServiceCollectionServiceExtensions
{
public static void AddTransient(IServiceCollection services, Type serviceType, Type implementationType) { }
}
}
namespace GFramework.Core.Abstractions.Logging
{
public interface ILogger
{
void Debug(string msg);
}
}
namespace GFramework.Cqrs.Abstractions.Cqrs
{
public interface IRequest<TResponse> { }
public interface INotification { }
public interface IStreamRequest<TResponse> { }
public interface IRequestHandler<in TRequest, TResponse> where TRequest : IRequest<TResponse> { }
public interface INotificationHandler<in TNotification> where TNotification : INotification { }
public interface IStreamRequestHandler<in TRequest, out TResponse> where TRequest : IStreamRequest<TResponse> { }
}
namespace GFramework.Cqrs
{
public interface ICqrsHandlerRegistry
{
void Register(Microsoft.Extensions.DependencyInjection.IServiceCollection services, GFramework.Core.Abstractions.Logging.ILogger logger);
}
[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)]
public sealed class CqrsHandlerRegistryAttribute : Attribute
{
public CqrsHandlerRegistryAttribute(Type registryType) { }
}
}
namespace TestApp
{
using GFramework.Cqrs.Abstractions.Cqrs;
public sealed class Container
{
private sealed class HiddenEnvelope<T> { }
private sealed record HiddenRequest() : IRequest<HiddenEnvelope<string>>;
private sealed class HiddenHandler : IRequestHandler<HiddenRequest, HiddenEnvelope<string>> { }
}
}
""";
await GeneratorTest<CqrsHandlerRegistryGenerator>.RunAsync(
source,
("CqrsHandlerRegistry.g.cs", HiddenGenericEnvelopeResponseExpected));
}
/// <summary>
/// 验证即使 runtime 仍暴露旧版无参 fallback marker生成器也会优先在生成注册器内部处理隐藏 handler
/// 不再输出 fallback marker。

View File

@ -90,36 +90,48 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
var registrations = ImmutableArray.CreateBuilder<HandlerRegistrationSpec>(handlerInterfaces.Length);
var reflectedImplementationRegistrations =
ImmutableArray.CreateBuilder<ReflectedImplementationRegistrationSpec>(handlerInterfaces.Length);
var preciseReflectedRegistrations =
ImmutableArray.CreateBuilder<PreciseReflectedRegistrationSpec>(handlerInterfaces.Length);
foreach (var handlerInterface in handlerInterfaces)
{
var canReferenceHandlerInterface = CanReferenceFromGeneratedRegistry(handlerInterface);
if (!canReferenceImplementation || !canReferenceHandlerInterface)
if (canReferenceImplementation && canReferenceHandlerInterface)
{
if (!canReferenceImplementation && canReferenceHandlerInterface)
{
reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
GetLogDisplayName(handlerInterface)));
continue;
}
// Non-public handlers closed over non-public message types still cannot be expressed purely as
// typeof(...) registrations. Preserve generator hit rate by resolving only the affected
// implementation back from the current assembly instead of asking the runtime registrar to rescan
// the whole assembly.
return new HandlerCandidateAnalysis(
registrations.Add(new HandlerRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
implementationTypeDisplayName,
implementationLogName,
ImmutableArray<HandlerRegistrationSpec>.Empty,
ImmutableArray<ReflectedImplementationRegistrationSpec>.Empty,
GetReflectionTypeMetadataName(type));
GetLogDisplayName(handlerInterface),
implementationLogName));
continue;
}
registrations.Add(new HandlerRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
if (!canReferenceImplementation && canReferenceHandlerInterface)
{
reflectedImplementationRegistrations.Add(new ReflectedImplementationRegistrationSpec(
handlerInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
GetLogDisplayName(handlerInterface)));
continue;
}
if (TryCreatePreciseReflectedRegistration(
context.SemanticModel.Compilation,
handlerInterface,
out var preciseReflectedRegistration))
{
preciseReflectedRegistrations.Add(preciseReflectedRegistration);
continue;
}
// Some closed handler interfaces still contain runtime-only type shapes such as arrays closed over
// non-public element types. For those rare cases keep the narrow implementation lookup, but let the
// generated registry discover the exact supported interfaces from the implementation type at runtime.
return new HandlerCandidateAnalysis(
implementationTypeDisplayName,
GetLogDisplayName(handlerInterface),
implementationLogName));
implementationLogName,
ImmutableArray<HandlerRegistrationSpec>.Empty,
ImmutableArray<ReflectedImplementationRegistrationSpec>.Empty,
ImmutableArray<PreciseReflectedRegistrationSpec>.Empty,
GetReflectionTypeMetadataName(type));
}
return new HandlerCandidateAnalysis(
@ -127,6 +139,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
implementationLogName,
registrations.ToImmutable(),
reflectedImplementationRegistrations.ToImmutable(),
preciseReflectedRegistrations.ToImmutable(),
canReferenceImplementation ? null : GetReflectionTypeMetadataName(type));
}
@ -170,6 +183,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
candidate.ImplementationLogName,
candidate.Registrations,
candidate.ReflectedImplementationRegistrations,
candidate.PreciseReflectedRegistrations,
candidate.ReflectionTypeMetadataName));
}
@ -214,6 +228,114 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string.Equals(definitionMetadataName, IStreamRequestHandlerMetadataName, StringComparison.Ordinal);
}
private static bool TryCreatePreciseReflectedRegistration(
Compilation compilation,
INamedTypeSymbol handlerInterface,
out PreciseReflectedRegistrationSpec registration)
{
var openHandlerTypeDisplayName = handlerInterface.OriginalDefinition
.ConstructUnboundGenericType()
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var typeArguments = ImmutableArray.CreateBuilder<RuntimeTypeReferenceSpec>(handlerInterface.TypeArguments.Length);
foreach (var typeArgument in handlerInterface.TypeArguments)
{
if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var runtimeTypeReference))
{
registration = default;
return false;
}
typeArguments.Add(runtimeTypeReference!);
}
registration = new PreciseReflectedRegistrationSpec(
openHandlerTypeDisplayName,
GetLogDisplayName(handlerInterface),
typeArguments.ToImmutable());
return true;
}
private static bool TryCreateRuntimeTypeReference(
Compilation compilation,
ITypeSymbol type,
out RuntimeTypeReferenceSpec? runtimeTypeReference)
{
if (CanReferenceFromGeneratedRegistry(type))
{
runtimeTypeReference = RuntimeTypeReferenceSpec.FromDirectReference(
type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
return true;
}
if (type is IArrayTypeSymbol arrayType &&
TryCreateRuntimeTypeReference(compilation, arrayType.ElementType, out var elementTypeReference))
{
runtimeTypeReference = RuntimeTypeReferenceSpec.FromArray(elementTypeReference!, arrayType.Rank);
return true;
}
if (type is INamedTypeSymbol genericNamedType &&
genericNamedType.IsGenericType &&
!genericNamedType.IsUnboundGenericType &&
TryCreateGenericTypeDefinitionReference(compilation, genericNamedType, out var genericTypeDefinitionReference))
{
var genericTypeArguments =
ImmutableArray.CreateBuilder<RuntimeTypeReferenceSpec>(genericNamedType.TypeArguments.Length);
foreach (var typeArgument in genericNamedType.TypeArguments)
{
if (!TryCreateRuntimeTypeReference(compilation, typeArgument, out var genericTypeArgumentReference))
{
runtimeTypeReference = null;
return false;
}
genericTypeArguments.Add(genericTypeArgumentReference!);
}
runtimeTypeReference = RuntimeTypeReferenceSpec.FromConstructedGeneric(
genericTypeDefinitionReference!,
genericTypeArguments.ToImmutable());
return true;
}
if (type is INamedTypeSymbol namedType &&
SymbolEqualityComparer.Default.Equals(namedType.ContainingAssembly, compilation.Assembly))
{
runtimeTypeReference = RuntimeTypeReferenceSpec.FromReflectionLookup(
GetReflectionTypeMetadataName(namedType));
return true;
}
runtimeTypeReference = null;
return false;
}
private static bool TryCreateGenericTypeDefinitionReference(
Compilation compilation,
INamedTypeSymbol genericNamedType,
out RuntimeTypeReferenceSpec? genericTypeDefinitionReference)
{
var genericTypeDefinition = genericNamedType.OriginalDefinition;
if (CanReferenceFromGeneratedRegistry(genericTypeDefinition))
{
genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromDirectReference(
genericTypeDefinition
.ConstructUnboundGenericType()
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
return true;
}
if (SymbolEqualityComparer.Default.Equals(genericTypeDefinition.ContainingAssembly, compilation.Assembly))
{
genericTypeDefinitionReference = RuntimeTypeReferenceSpec.FromReflectionLookup(
GetReflectionTypeMetadataName(genericTypeDefinition));
return true;
}
genericTypeDefinitionReference = null;
return false;
}
private static bool CanReferenceFromGeneratedRegistry(ITypeSymbol type)
{
switch (type)
@ -319,9 +441,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
{
var hasReflectedImplementationRegistrations = registrations.Any(static registration =>
!registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
var hasPreciseReflectedRegistrations = registrations.Any(static registration =>
!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var hasFullReflectionRegistrations = registrations.Any(static registration =>
!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName) &&
registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty);
registration.ReflectedImplementationRegistrations.IsDefaultOrEmpty &&
registration.PreciseReflectedRegistrations.IsDefaultOrEmpty);
var builder = new StringBuilder();
builder.AppendLine("// <auto-generated />");
builder.AppendLine("#nullable enable");
@ -354,7 +479,7 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(services));");
builder.AppendLine(" if (logger is null)");
builder.AppendLine(" throw new global::System.ArgumentNullException(nameof(logger));");
if (hasReflectedImplementationRegistrations || hasFullReflectionRegistrations)
if (hasReflectedImplementationRegistrations || hasPreciseReflectedRegistrations || hasFullReflectionRegistrations)
{
builder.AppendLine();
builder.Append(" var registryAssembly = typeof(global::");
@ -376,6 +501,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
continue;
}
if (!registration.PreciseReflectedRegistrations.IsDefaultOrEmpty)
{
AppendPreciseReflectedRegistrations(builder, registration, registrationIndex);
continue;
}
if (!string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
AppendReflectionRegistration(builder, registration.ReflectionTypeMetadataName!);
@ -457,6 +588,197 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
builder.AppendLine(" }");
}
private static void AppendPreciseReflectedRegistrations(
StringBuilder builder,
ImplementationRegistrationSpec registration,
int registrationIndex)
{
var implementationVariableName = $"implementationType{registrationIndex}";
if (string.IsNullOrWhiteSpace(registration.ReflectionTypeMetadataName))
{
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = typeof(");
builder.Append(registration.ImplementationTypeDisplayName);
builder.AppendLine(");");
}
else
{
var implementationReflectionTypeMetadataName = registration.ReflectionTypeMetadataName!;
builder.Append(" var ");
builder.Append(implementationVariableName);
builder.Append(" = registryAssembly.GetType(\"");
builder.Append(EscapeStringLiteral(implementationReflectionTypeMetadataName));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
}
builder.Append(" if (");
builder.Append(implementationVariableName);
builder.AppendLine(" is not null)");
builder.AppendLine(" {");
for (var registrationOffset = 0;
registrationOffset < registration.PreciseReflectedRegistrations.Length;
registrationOffset++)
{
var reflectedRegistration = registration.PreciseReflectedRegistrations[registrationOffset];
var registrationVariablePrefix = $"serviceType{registrationIndex}_{registrationOffset}";
AppendPreciseReflectedTypeResolution(
builder,
reflectedRegistration.ServiceTypeArguments,
registrationVariablePrefix,
implementationVariableName,
reflectedRegistration.OpenHandlerTypeDisplayName,
registration.ImplementationLogName,
reflectedRegistration.HandlerInterfaceLogName,
3);
}
builder.AppendLine(" }");
}
private static void AppendPreciseReflectedTypeResolution(
StringBuilder builder,
ImmutableArray<RuntimeTypeReferenceSpec> serviceTypeArguments,
string registrationVariablePrefix,
string implementationVariableName,
string openHandlerTypeDisplayName,
string implementationLogName,
string handlerInterfaceLogName,
int indentLevel)
{
var indent = new string(' ', indentLevel * 4);
var nestedIndent = new string(' ', (indentLevel + 1) * 4);
var resolvedArgumentNames = new string[serviceTypeArguments.Length];
var reflectedArgumentNames = new List<string>();
for (var argumentIndex = 0; argumentIndex < serviceTypeArguments.Length; argumentIndex++)
{
resolvedArgumentNames[argumentIndex] = AppendRuntimeTypeReferenceResolution(
builder,
serviceTypeArguments[argumentIndex],
$"{registrationVariablePrefix}Argument{argumentIndex}",
reflectedArgumentNames,
indent);
}
if (reflectedArgumentNames.Count > 0)
{
builder.Append(indent);
builder.Append("if (");
for (var index = 0; index < reflectedArgumentNames.Count; index++)
{
if (index > 0)
builder.Append(" && ");
builder.Append(reflectedArgumentNames[index]);
builder.Append(" is not null");
}
builder.AppendLine(")");
builder.Append(indent);
builder.AppendLine("{");
indent = nestedIndent;
}
builder.Append(indent);
builder.Append("var ");
builder.Append(registrationVariablePrefix);
builder.Append(" = typeof(");
builder.Append(openHandlerTypeDisplayName);
builder.Append(").MakeGenericType(");
for (var index = 0; index < resolvedArgumentNames.Length; index++)
{
if (index > 0)
builder.Append(", ");
builder.Append(resolvedArgumentNames[index]);
}
builder.AppendLine(");");
builder.Append(indent);
builder.AppendLine("global::Microsoft.Extensions.DependencyInjection.ServiceCollectionServiceExtensions.AddTransient(");
builder.Append(indent);
builder.AppendLine(" services,");
builder.Append(indent);
builder.Append(" ");
builder.Append(registrationVariablePrefix);
builder.AppendLine(",");
builder.Append(indent);
builder.Append(" ");
builder.Append(implementationVariableName);
builder.AppendLine(");");
builder.Append(indent);
builder.Append("logger.Debug(\"Registered CQRS handler ");
builder.Append(EscapeStringLiteral(implementationLogName));
builder.Append(" as ");
builder.Append(EscapeStringLiteral(handlerInterfaceLogName));
builder.AppendLine(".\");");
if (reflectedArgumentNames.Count > 0)
{
builder.Append(new string(' ', indentLevel * 4));
builder.AppendLine("}");
}
}
private static string AppendRuntimeTypeReferenceResolution(
StringBuilder builder,
RuntimeTypeReferenceSpec runtimeTypeReference,
string variableBaseName,
ICollection<string> reflectedArgumentNames,
string indent)
{
if (!string.IsNullOrWhiteSpace(runtimeTypeReference.TypeDisplayName))
return $"typeof({runtimeTypeReference.TypeDisplayName})";
if (runtimeTypeReference.ArrayElementTypeReference is not null)
{
var elementExpression = AppendRuntimeTypeReferenceResolution(
builder,
runtimeTypeReference.ArrayElementTypeReference,
$"{variableBaseName}Element",
reflectedArgumentNames,
indent);
return runtimeTypeReference.ArrayRank == 1
? $"{elementExpression}.MakeArrayType()"
: $"{elementExpression}.MakeArrayType({runtimeTypeReference.ArrayRank})";
}
if (runtimeTypeReference.GenericTypeDefinitionReference is not null)
{
var genericTypeDefinitionExpression = AppendRuntimeTypeReferenceResolution(
builder,
runtimeTypeReference.GenericTypeDefinitionReference,
$"{variableBaseName}GenericDefinition",
reflectedArgumentNames,
indent);
var genericArgumentExpressions = new string[runtimeTypeReference.GenericTypeArguments.Length];
for (var argumentIndex = 0; argumentIndex < runtimeTypeReference.GenericTypeArguments.Length; argumentIndex++)
{
genericArgumentExpressions[argumentIndex] = AppendRuntimeTypeReferenceResolution(
builder,
runtimeTypeReference.GenericTypeArguments[argumentIndex],
$"{variableBaseName}GenericArgument{argumentIndex}",
reflectedArgumentNames,
indent);
}
return $"{genericTypeDefinitionExpression}.MakeGenericType({string.Join(", ", genericArgumentExpressions)})";
}
var reflectionTypeMetadataName = runtimeTypeReference.ReflectionTypeMetadataName!;
reflectedArgumentNames.Add(variableBaseName);
builder.Append(indent);
builder.Append("var ");
builder.Append(variableBaseName);
builder.Append(" = registryAssembly.GetType(\"");
builder.Append(EscapeStringLiteral(reflectionTypeMetadataName));
builder.AppendLine("\", throwOnError: false, ignoreCase: false);");
return variableBaseName;
}
private static void AppendReflectionHelpers(StringBuilder builder)
{
// Emit the runtime helper methods only when at least one handler requires metadata-name lookup.
@ -589,11 +911,52 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string HandlerInterfaceDisplayName,
string HandlerInterfaceLogName);
private sealed record RuntimeTypeReferenceSpec(
string? TypeDisplayName,
string? ReflectionTypeMetadataName,
RuntimeTypeReferenceSpec? ArrayElementTypeReference,
int ArrayRank,
RuntimeTypeReferenceSpec? GenericTypeDefinitionReference,
ImmutableArray<RuntimeTypeReferenceSpec> GenericTypeArguments)
{
public static RuntimeTypeReferenceSpec FromDirectReference(string typeDisplayName)
{
return new RuntimeTypeReferenceSpec(typeDisplayName, null, null, 0, null,
ImmutableArray<RuntimeTypeReferenceSpec>.Empty);
}
public static RuntimeTypeReferenceSpec FromReflectionLookup(string reflectionTypeMetadataName)
{
return new RuntimeTypeReferenceSpec(null, reflectionTypeMetadataName, null, 0, null,
ImmutableArray<RuntimeTypeReferenceSpec>.Empty);
}
public static RuntimeTypeReferenceSpec FromArray(RuntimeTypeReferenceSpec elementTypeReference, int arrayRank)
{
return new RuntimeTypeReferenceSpec(null, null, elementTypeReference, arrayRank, null,
ImmutableArray<RuntimeTypeReferenceSpec>.Empty);
}
public static RuntimeTypeReferenceSpec FromConstructedGeneric(
RuntimeTypeReferenceSpec genericTypeDefinitionReference,
ImmutableArray<RuntimeTypeReferenceSpec> genericTypeArguments)
{
return new RuntimeTypeReferenceSpec(null, null, null, 0, genericTypeDefinitionReference,
genericTypeArguments);
}
}
private readonly record struct PreciseReflectedRegistrationSpec(
string OpenHandlerTypeDisplayName,
string HandlerInterfaceLogName,
ImmutableArray<RuntimeTypeReferenceSpec> ServiceTypeArguments);
private readonly record struct ImplementationRegistrationSpec(
string ImplementationTypeDisplayName,
string ImplementationLogName,
ImmutableArray<HandlerRegistrationSpec> DirectRegistrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> ReflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> PreciseReflectedRegistrations,
string? ReflectionTypeMetadataName);
private readonly struct HandlerCandidateAnalysis : IEquatable<HandlerCandidateAnalysis>
@ -603,12 +966,14 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
string implementationLogName,
ImmutableArray<HandlerRegistrationSpec> registrations,
ImmutableArray<ReflectedImplementationRegistrationSpec> reflectedImplementationRegistrations,
ImmutableArray<PreciseReflectedRegistrationSpec> preciseReflectedRegistrations,
string? reflectionTypeMetadataName)
{
ImplementationTypeDisplayName = implementationTypeDisplayName;
ImplementationLogName = implementationLogName;
Registrations = registrations;
ReflectedImplementationRegistrations = reflectedImplementationRegistrations;
PreciseReflectedRegistrations = preciseReflectedRegistrations;
ReflectionTypeMetadataName = reflectionTypeMetadataName;
}
@ -620,6 +985,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
public ImmutableArray<ReflectedImplementationRegistrationSpec> ReflectedImplementationRegistrations { get; }
public ImmutableArray<PreciseReflectedRegistrationSpec> PreciseReflectedRegistrations { get; }
public string? ReflectionTypeMetadataName { get; }
public bool Equals(HandlerCandidateAnalysis other)
@ -630,7 +997,8 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
!string.Equals(ReflectionTypeMetadataName, other.ReflectionTypeMetadataName,
StringComparison.Ordinal) ||
Registrations.Length != other.Registrations.Length ||
ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length)
ReflectedImplementationRegistrations.Length != other.ReflectedImplementationRegistrations.Length ||
PreciseReflectedRegistrations.Length != other.PreciseReflectedRegistrations.Length)
{
return false;
}
@ -648,6 +1016,12 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
return false;
}
for (var index = 0; index < PreciseReflectedRegistrations.Length; index++)
{
if (!PreciseReflectedRegistrations[index].Equals(other.PreciseReflectedRegistrations[index]))
return false;
}
return true;
}
@ -676,6 +1050,11 @@ public sealed class CqrsHandlerRegistryGenerator : IIncrementalGenerator
hashCode = (hashCode * 397) ^ reflectedImplementationRegistration.GetHashCode();
}
foreach (var preciseReflectedRegistration in PreciseReflectedRegistrations)
{
hashCode = (hashCode * 397) ^ preciseReflectedRegistration.GetHashCode();
}
return hashCode;
}
}