refactor(state): 优化状态机实现并增强上下文安全检查

- 在 ContextAwareStateBase 和 AsyncContextAwareStateBase 中添加架构上下文空值检查
- 引入 HashSet 用于快速检查状态注册状态,提升性能
- 重构注销逻辑,分离准备和完成阶段的处理
- 优化回退功能,跳过已注销的状态并使用 O(1) 复杂度检查
- 统一状态切换中的进入和退出逻辑处理
- 简化状态转换验证流程,提升代码可读性
- 添加辅助方法处理异步状态操作的统一入口
- [release ci]
This commit is contained in:
GeWuYou 2026-02-15 18:29:13 +08:00 committed by gewuyou
parent a5daadea96
commit 703328deb2
3 changed files with 153 additions and 153 deletions

View File

@ -76,7 +76,8 @@ public class AsyncContextAwareStateBase : IAsyncState, IContextAware, IDisposabl
/// <returns>架构上下文实例</returns>
public IArchitectureContext GetContext()
{
return _context!;
return _context ?? throw new InvalidOperationException(
$"Architecture context has not been set. Call {nameof(SetContext)} before accessing the context.");
}
/// <summary>

View File

@ -32,7 +32,8 @@ public class ContextAwareStateBase : IState, IContextAware, IDisposable
/// <returns>架构上下文实例</returns>
public IArchitectureContext GetContext()
{
return _context!;
return _context ?? throw new InvalidOperationException(
$"Architecture context has not been set. Call {nameof(SetContext)} before accessing the context.");
}
/// <summary>

View File

@ -9,6 +9,7 @@ namespace GFramework.Core.state;
public class StateMachine(int maxHistorySize = 10) : IStateMachine
{
private readonly object _lock = new();
private readonly HashSet<IState> _registeredStates = new(); // 优化:用于快速检查状态是否注册
private readonly Stack<IState> _stateHistory = new();
/// <summary>
@ -30,6 +31,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
lock (_lock)
{
States[state.GetType()] = state;
_registeredStates.Add(state);
}
return this;
@ -41,26 +43,17 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <typeparam name="T">要注销的状态类型</typeparam>
public IStateMachine Unregister<T>() where T : IState
{
lock (_lock)
var stateToUnregister = PrepareUnregister<T>(out var isCurrentState);
if (stateToUnregister == null) return this;
// 如果是当前状态,执行同步退出
if (isCurrentState)
{
var type = typeof(T);
if (!States.TryGetValue(type, out var state)) return this;
// 如果当前状态是要注销的状态,则先执行退出逻辑
if (Current == state)
{
Current.OnExit(null);
Current = null;
}
// 从历史记录中移除该状态的所有引用
var tempStack = new Stack<IState>(_stateHistory.Reverse());
_stateHistory.Clear();
foreach (var historyState in tempStack.Where(s => s != state)) _stateHistory.Push(historyState);
States.Remove(type);
Current!.OnExit(null);
Current = null;
}
CompleteUnregister(stateToUnregister);
return this;
}
@ -70,36 +63,17 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <typeparam name="T">要注销的状态类型</typeparam>
public async Task<IStateMachine> UnregisterAsync<T>() where T : IState
{
IState? stateToUnregister;
var stateToUnregister = PrepareUnregister<T>(out var isCurrentState);
if (stateToUnregister == null) return this;
lock (_lock)
// 如果是当前状态,执行异步退出
if (isCurrentState)
{
var type = typeof(T);
if (!States.TryGetValue(type, out stateToUnregister)) return this;
}
// 如果当前状态是要注销的状态,则先执行退出逻辑
if (Current == stateToUnregister)
{
if (Current is IAsyncState asyncState)
await asyncState.OnExitAsync(null);
else
Current.OnExit(null);
await ExecuteExitAsync(Current!, null);
Current = null;
}
lock (_lock)
{
// 从历史记录中移除该状态的所有引用
var tempStack = new Stack<IState>(_stateHistory.Reverse());
_stateHistory.Clear();
foreach (var historyState in tempStack.Where(s => s != stateToUnregister))
_stateHistory.Push(historyState);
States.Remove(typeof(T));
}
CompleteUnregister(stateToUnregister);
return this;
}
@ -128,12 +102,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
if (Current == null) return true;
// 如果当前状态是异步状态,使用异步方法
if (Current is IAsyncState asyncState)
return await asyncState.CanTransitionToAsync(target);
// 否则使用同步方法
return Current.CanTransitionTo(target);
return await CanTransitionToAsync(Current, target);
}
/// <summary>
@ -146,11 +115,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
{
lock (_lock)
{
// 检查目标状态是否已注册
if (!States.TryGetValue(typeof(T), out var target))
throw new InvalidOperationException("State not registered.");
throw new InvalidOperationException($"State {typeof(T).Name} not registered.");
// 验证当前状态是否可以转换到目标状态
if (Current != null && !Current.CanTransitionTo(target))
{
OnTransitionRejected(Current, target);
@ -174,20 +141,14 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
lock (_lock)
{
// 检查目标状态是否已注册
if (!States.TryGetValue(typeof(T), out target!))
throw new InvalidOperationException("State not registered.");
throw new InvalidOperationException($"State {typeof(T).Name} not registered.");
}
// 验证当前状态是否可以转换到目标状态(异步)
// 验证转换(在锁外执行异步操作
if (Current != null)
{
bool canTransition;
if (Current is IAsyncState asyncState)
canTransition = await asyncState.CanTransitionToAsync(target);
else
canTransition = Current.CanTransitionTo(target);
var canTransition = await CanTransitionToAsync(Current, target);
if (!canTransition)
{
await OnTransitionRejectedAsync(Current, target);
@ -258,21 +219,11 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <returns>如果成功回退则返回true否则返回false</returns>
public bool GoBack()
{
lock (_lock)
{
if (_stateHistory.Count == 0) return false;
var previousState = FindValidPreviousState();
if (previousState == null) return false;
var previousState = _stateHistory.Pop();
// 检查上一个状态是否仍然注册
if (!States.ContainsValue(previousState))
// 如果状态已被注销,继续尝试更早的状态
return GoBack();
// 回退时不添加到历史记录
ChangeInternalWithoutHistory(previousState);
return true;
}
ChangeInternalWithoutHistory(previousState);
return true;
}
/// <summary>
@ -281,28 +232,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <returns>如果成功回退则返回true否则返回false</returns>
public async Task<bool> GoBackAsync()
{
IState? previousState = null;
var previousState = FindValidPreviousState();
if (previousState == null) return false;
// 循环查找有效的历史状态
while (previousState == null)
{
lock (_lock)
{
if (_stateHistory.Count == 0)
return false;
var candidate = _stateHistory.Pop();
// 检查状态是否仍然注册
if (States.ContainsValue(candidate))
{
previousState = candidate;
}
// 如果状态已被注销,继续循环尝试更早的状态
}
}
// 回退时不添加到历史记录
await ChangeInternalWithoutHistoryAsync(previousState);
return true;
}
@ -318,6 +250,63 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
}
}
/// <summary>
/// 准备注销操作,返回要注销的状态
/// </summary>
private IState? PrepareUnregister<T>(out bool isCurrentState) where T : IState
{
lock (_lock)
{
var type = typeof(T);
if (!States.TryGetValue(type, out var state))
{
isCurrentState = false;
return null;
}
isCurrentState = Current == state;
return state;
}
}
/// <summary>
/// 完成注销操作,清理历史记录和状态字典
/// </summary>
private void CompleteUnregister(IState stateToUnregister)
{
lock (_lock)
{
// 从历史记录中移除该状态的所有引用
var tempStack = new Stack<IState>(_stateHistory.Reverse());
_stateHistory.Clear();
foreach (var historyState in tempStack.Where(s => s != stateToUnregister))
_stateHistory.Push(historyState);
States.Remove(stateToUnregister.GetType());
_registeredStates.Remove(stateToUnregister);
}
}
/// <summary>
/// 查找有效的上一个状态(跳过已注销的状态)
/// </summary>
private IState? FindValidPreviousState()
{
lock (_lock)
{
while (_stateHistory.Count > 0)
{
var candidate = _stateHistory.Pop();
// 使用 HashSet 快速检查O(1) 复杂度
if (_registeredStates.Contains(candidate))
return candidate;
}
return null;
}
}
/// <summary>
/// 内部状态切换方法(不记录历史),用于回退操作
/// </summary>
@ -347,17 +336,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
var old = Current;
await OnStateChangingAsync(old, next);
if (old is IAsyncState asyncOld)
await asyncOld.OnExitAsync(next);
else
old?.OnExit(next);
await ExecuteExitAsync(old, next);
Current = next;
if (Current is IAsyncState asyncCurrent)
await asyncCurrent.OnEnterAsync(old);
else
Current.OnEnter(old);
await ExecuteEnterAsync(Current, old);
await OnStateChangedAsync(old, Current);
}
@ -368,10 +349,8 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <param name="next">下一个状态实例</param>
protected virtual void ChangeInternal(IState next)
{
// 检查是否为相同状态,避免不必要的切换
if (Current == next) return;
// 验证当前状态是否允许切换到目标状态
if (Current != null && !Current.CanTransitionTo(next))
{
OnTransitionRejected(Current, next);
@ -381,20 +360,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
var old = Current;
OnStateChanging(old, next);
// 将当前状态添加到历史记录
if (Current != null)
{
_stateHistory.Push(Current);
// 限制历史记录大小
if (_stateHistory.Count > maxHistorySize)
{
// 移除最旧的记录(栈底元素)
var tempStack = new Stack<IState>(_stateHistory.Reverse().Skip(1));
_stateHistory.Clear();
foreach (var state in tempStack.Reverse()) _stateHistory.Push(state);
}
}
AddToHistory(Current);
old?.OnExit(next);
Current = next;
@ -409,47 +375,79 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// <param name="next">下一个状态实例</param>
protected virtual async Task ChangeInternalAsync(IState next)
{
// 检查是否为相同状态,避免不必要的切换
if (Current == next) return;
var old = Current;
await OnStateChangingAsync(old, next);
// 将当前状态添加到历史记录
lock (_lock)
{
if (Current != null)
{
_stateHistory.Push(Current);
// 限制历史记录大小
if (_stateHistory.Count > maxHistorySize)
{
// 移除最旧的记录(栈底元素)
var tempStack = new Stack<IState>(_stateHistory.Reverse().Skip(1));
_stateHistory.Clear();
foreach (var state in tempStack.Reverse()) _stateHistory.Push(state);
}
}
}
// 执行退出逻辑(异步或同步)
if (old is IAsyncState asyncOld)
await asyncOld.OnExitAsync(next);
else
old?.OnExit(next);
AddToHistory(Current);
await ExecuteExitAsync(old, next);
Current = next;
// 执行进入逻辑(异步或同步)
if (Current is IAsyncState asyncCurrent)
await asyncCurrent.OnEnterAsync(old);
else
Current.OnEnter(old);
await ExecuteEnterAsync(Current, old);
await OnStateChangedAsync(old, Current);
}
/// <summary>
/// 将状态添加到历史记录
/// </summary>
private void AddToHistory(IState? state)
{
if (state == null) return;
lock (_lock)
{
_stateHistory.Push(state);
// 限制历史记录大小
if (_stateHistory.Count > maxHistorySize)
{
var tempStack = new Stack<IState>(_stateHistory.Reverse().Skip(1));
_stateHistory.Clear();
foreach (var s in tempStack.Reverse())
_stateHistory.Push(s);
}
}
}
/// <summary>
/// 执行状态进入逻辑(智能判断同步/异步)
/// </summary>
private static async Task ExecuteEnterAsync(IState? state, IState? from)
{
if (state == null) return;
if (state is IAsyncState asyncState)
await asyncState.OnEnterAsync(from);
else
state.OnEnter(from);
}
/// <summary>
/// 执行状态退出逻辑(智能判断同步/异步)
/// </summary>
private static async Task ExecuteExitAsync(IState? state, IState? to)
{
if (state == null) return;
if (state is IAsyncState asyncState)
await asyncState.OnExitAsync(to);
else
state.OnExit(to);
}
/// <summary>
/// 检查是否可以转换到目标状态(智能判断同步/异步)
/// </summary>
private static async Task<bool> CanTransitionToAsync(IState current, IState target)
{
if (current is IAsyncState asyncState)
return await asyncState.CanTransitionToAsync(target);
return current.CanTransitionTo(target);
}
/// <summary>
/// 当状态转换被拒绝时的回调方法
/// </summary>