diff --git a/GFramework.Core/state/AsyncContextAwareStateBase.cs b/GFramework.Core/state/AsyncContextAwareStateBase.cs
index 77dfdec..266a610 100644
--- a/GFramework.Core/state/AsyncContextAwareStateBase.cs
+++ b/GFramework.Core/state/AsyncContextAwareStateBase.cs
@@ -76,7 +76,8 @@ public class AsyncContextAwareStateBase : IAsyncState, IContextAware, IDisposabl
/// 架构上下文实例
public IArchitectureContext GetContext()
{
- return _context!;
+ return _context ?? throw new InvalidOperationException(
+ $"Architecture context has not been set. Call {nameof(SetContext)} before accessing the context.");
}
///
diff --git a/GFramework.Core/state/ContextAwareStateBase.cs b/GFramework.Core/state/ContextAwareStateBase.cs
index 4f28d16..9bca0d9 100644
--- a/GFramework.Core/state/ContextAwareStateBase.cs
+++ b/GFramework.Core/state/ContextAwareStateBase.cs
@@ -32,7 +32,8 @@ public class ContextAwareStateBase : IState, IContextAware, IDisposable
/// 架构上下文实例
public IArchitectureContext GetContext()
{
- return _context!;
+ return _context ?? throw new InvalidOperationException(
+ $"Architecture context has not been set. Call {nameof(SetContext)} before accessing the context.");
}
///
diff --git a/GFramework.Core/state/StateMachine.cs b/GFramework.Core/state/StateMachine.cs
index 5b7a372..70f4495 100644
--- a/GFramework.Core/state/StateMachine.cs
+++ b/GFramework.Core/state/StateMachine.cs
@@ -9,6 +9,7 @@ namespace GFramework.Core.state;
public class StateMachine(int maxHistorySize = 10) : IStateMachine
{
private readonly object _lock = new();
+ private readonly HashSet _registeredStates = new(); // 优化:用于快速检查状态是否注册
private readonly Stack _stateHistory = new();
///
@@ -30,6 +31,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
lock (_lock)
{
States[state.GetType()] = state;
+ _registeredStates.Add(state);
}
return this;
@@ -41,26 +43,17 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 要注销的状态类型
public IStateMachine Unregister() where T : IState
{
- lock (_lock)
+ var stateToUnregister = PrepareUnregister(out var isCurrentState);
+ if (stateToUnregister == null) return this;
+
+ // 如果是当前状态,执行同步退出
+ if (isCurrentState)
{
- var type = typeof(T);
- if (!States.TryGetValue(type, out var state)) return this;
-
- // 如果当前状态是要注销的状态,则先执行退出逻辑
- if (Current == state)
- {
- Current.OnExit(null);
- Current = null;
- }
-
- // 从历史记录中移除该状态的所有引用
- var tempStack = new Stack(_stateHistory.Reverse());
- _stateHistory.Clear();
- foreach (var historyState in tempStack.Where(s => s != state)) _stateHistory.Push(historyState);
-
- States.Remove(type);
+ Current!.OnExit(null);
+ Current = null;
}
+ CompleteUnregister(stateToUnregister);
return this;
}
@@ -70,36 +63,17 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 要注销的状态类型
public async Task UnregisterAsync() where T : IState
{
- IState? stateToUnregister;
+ var stateToUnregister = PrepareUnregister(out var isCurrentState);
+ if (stateToUnregister == null) return this;
- lock (_lock)
+ // 如果是当前状态,执行异步退出
+ if (isCurrentState)
{
- var type = typeof(T);
- if (!States.TryGetValue(type, out stateToUnregister)) return this;
- }
-
- // 如果当前状态是要注销的状态,则先执行退出逻辑
- if (Current == stateToUnregister)
- {
- if (Current is IAsyncState asyncState)
- await asyncState.OnExitAsync(null);
- else
- Current.OnExit(null);
-
+ await ExecuteExitAsync(Current!, null);
Current = null;
}
- lock (_lock)
- {
- // 从历史记录中移除该状态的所有引用
- var tempStack = new Stack(_stateHistory.Reverse());
- _stateHistory.Clear();
- foreach (var historyState in tempStack.Where(s => s != stateToUnregister))
- _stateHistory.Push(historyState);
-
- States.Remove(typeof(T));
- }
-
+ CompleteUnregister(stateToUnregister);
return this;
}
@@ -128,12 +102,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
if (Current == null) return true;
- // 如果当前状态是异步状态,使用异步方法
- if (Current is IAsyncState asyncState)
- return await asyncState.CanTransitionToAsync(target);
-
- // 否则使用同步方法
- return Current.CanTransitionTo(target);
+ return await CanTransitionToAsync(Current, target);
}
///
@@ -146,11 +115,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
{
lock (_lock)
{
- // 检查目标状态是否已注册
if (!States.TryGetValue(typeof(T), out var target))
- throw new InvalidOperationException("State not registered.");
+ throw new InvalidOperationException($"State {typeof(T).Name} not registered.");
- // 验证当前状态是否可以转换到目标状态
if (Current != null && !Current.CanTransitionTo(target))
{
OnTransitionRejected(Current, target);
@@ -174,20 +141,14 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
lock (_lock)
{
- // 检查目标状态是否已注册
if (!States.TryGetValue(typeof(T), out target!))
- throw new InvalidOperationException("State not registered.");
+ throw new InvalidOperationException($"State {typeof(T).Name} not registered.");
}
- // 验证当前状态是否可以转换到目标状态(异步)
+ // 验证转换(在锁外执行异步操作)
if (Current != null)
{
- bool canTransition;
- if (Current is IAsyncState asyncState)
- canTransition = await asyncState.CanTransitionToAsync(target);
- else
- canTransition = Current.CanTransitionTo(target);
-
+ var canTransition = await CanTransitionToAsync(Current, target);
if (!canTransition)
{
await OnTransitionRejectedAsync(Current, target);
@@ -258,21 +219,11 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 如果成功回退则返回true,否则返回false
public bool GoBack()
{
- lock (_lock)
- {
- if (_stateHistory.Count == 0) return false;
+ var previousState = FindValidPreviousState();
+ if (previousState == null) return false;
- var previousState = _stateHistory.Pop();
-
- // 检查上一个状态是否仍然注册
- if (!States.ContainsValue(previousState))
- // 如果状态已被注销,继续尝试更早的状态
- return GoBack();
-
- // 回退时不添加到历史记录
- ChangeInternalWithoutHistory(previousState);
- return true;
- }
+ ChangeInternalWithoutHistory(previousState);
+ return true;
}
///
@@ -281,28 +232,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 如果成功回退则返回true,否则返回false
public async Task GoBackAsync()
{
- IState? previousState = null;
+ var previousState = FindValidPreviousState();
+ if (previousState == null) return false;
- // 循环查找有效的历史状态
- while (previousState == null)
- {
- lock (_lock)
- {
- if (_stateHistory.Count == 0)
- return false;
-
- var candidate = _stateHistory.Pop();
-
- // 检查状态是否仍然注册
- if (States.ContainsValue(candidate))
- {
- previousState = candidate;
- }
- // 如果状态已被注销,继续循环尝试更早的状态
- }
- }
-
- // 回退时不添加到历史记录
await ChangeInternalWithoutHistoryAsync(previousState);
return true;
}
@@ -318,6 +250,63 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
}
}
+ ///
+ /// 准备注销操作,返回要注销的状态
+ ///
+ private IState? PrepareUnregister(out bool isCurrentState) where T : IState
+ {
+ lock (_lock)
+ {
+ var type = typeof(T);
+ if (!States.TryGetValue(type, out var state))
+ {
+ isCurrentState = false;
+ return null;
+ }
+
+ isCurrentState = Current == state;
+ return state;
+ }
+ }
+
+ ///
+ /// 完成注销操作,清理历史记录和状态字典
+ ///
+ private void CompleteUnregister(IState stateToUnregister)
+ {
+ lock (_lock)
+ {
+ // 从历史记录中移除该状态的所有引用
+ var tempStack = new Stack(_stateHistory.Reverse());
+ _stateHistory.Clear();
+ foreach (var historyState in tempStack.Where(s => s != stateToUnregister))
+ _stateHistory.Push(historyState);
+
+ States.Remove(stateToUnregister.GetType());
+ _registeredStates.Remove(stateToUnregister);
+ }
+ }
+
+ ///
+ /// 查找有效的上一个状态(跳过已注销的状态)
+ ///
+ private IState? FindValidPreviousState()
+ {
+ lock (_lock)
+ {
+ while (_stateHistory.Count > 0)
+ {
+ var candidate = _stateHistory.Pop();
+
+ // 使用 HashSet 快速检查,O(1) 复杂度
+ if (_registeredStates.Contains(candidate))
+ return candidate;
+ }
+
+ return null;
+ }
+ }
+
///
/// 内部状态切换方法(不记录历史),用于回退操作
///
@@ -347,17 +336,9 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
var old = Current;
await OnStateChangingAsync(old, next);
- if (old is IAsyncState asyncOld)
- await asyncOld.OnExitAsync(next);
- else
- old?.OnExit(next);
-
+ await ExecuteExitAsync(old, next);
Current = next;
-
- if (Current is IAsyncState asyncCurrent)
- await asyncCurrent.OnEnterAsync(old);
- else
- Current.OnEnter(old);
+ await ExecuteEnterAsync(Current, old);
await OnStateChangedAsync(old, Current);
}
@@ -368,10 +349,8 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 下一个状态实例
protected virtual void ChangeInternal(IState next)
{
- // 检查是否为相同状态,避免不必要的切换
if (Current == next) return;
- // 验证当前状态是否允许切换到目标状态
if (Current != null && !Current.CanTransitionTo(next))
{
OnTransitionRejected(Current, next);
@@ -381,20 +360,7 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
var old = Current;
OnStateChanging(old, next);
- // 将当前状态添加到历史记录
- if (Current != null)
- {
- _stateHistory.Push(Current);
-
- // 限制历史记录大小
- if (_stateHistory.Count > maxHistorySize)
- {
- // 移除最旧的记录(栈底元素)
- var tempStack = new Stack(_stateHistory.Reverse().Skip(1));
- _stateHistory.Clear();
- foreach (var state in tempStack.Reverse()) _stateHistory.Push(state);
- }
- }
+ AddToHistory(Current);
old?.OnExit(next);
Current = next;
@@ -409,47 +375,79 @@ public class StateMachine(int maxHistorySize = 10) : IStateMachine
/// 下一个状态实例
protected virtual async Task ChangeInternalAsync(IState next)
{
- // 检查是否为相同状态,避免不必要的切换
if (Current == next) return;
var old = Current;
await OnStateChangingAsync(old, next);
- // 将当前状态添加到历史记录
- lock (_lock)
- {
- if (Current != null)
- {
- _stateHistory.Push(Current);
-
- // 限制历史记录大小
- if (_stateHistory.Count > maxHistorySize)
- {
- // 移除最旧的记录(栈底元素)
- var tempStack = new Stack(_stateHistory.Reverse().Skip(1));
- _stateHistory.Clear();
- foreach (var state in tempStack.Reverse()) _stateHistory.Push(state);
- }
- }
- }
-
- // 执行退出逻辑(异步或同步)
- if (old is IAsyncState asyncOld)
- await asyncOld.OnExitAsync(next);
- else
- old?.OnExit(next);
+ AddToHistory(Current);
+ await ExecuteExitAsync(old, next);
Current = next;
-
- // 执行进入逻辑(异步或同步)
- if (Current is IAsyncState asyncCurrent)
- await asyncCurrent.OnEnterAsync(old);
- else
- Current.OnEnter(old);
+ await ExecuteEnterAsync(Current, old);
await OnStateChangedAsync(old, Current);
}
+ ///
+ /// 将状态添加到历史记录
+ ///
+ private void AddToHistory(IState? state)
+ {
+ if (state == null) return;
+
+ lock (_lock)
+ {
+ _stateHistory.Push(state);
+
+ // 限制历史记录大小
+ if (_stateHistory.Count > maxHistorySize)
+ {
+ var tempStack = new Stack(_stateHistory.Reverse().Skip(1));
+ _stateHistory.Clear();
+ foreach (var s in tempStack.Reverse())
+ _stateHistory.Push(s);
+ }
+ }
+ }
+
+ ///
+ /// 执行状态进入逻辑(智能判断同步/异步)
+ ///
+ private static async Task ExecuteEnterAsync(IState? state, IState? from)
+ {
+ if (state == null) return;
+
+ if (state is IAsyncState asyncState)
+ await asyncState.OnEnterAsync(from);
+ else
+ state.OnEnter(from);
+ }
+
+ ///
+ /// 执行状态退出逻辑(智能判断同步/异步)
+ ///
+ private static async Task ExecuteExitAsync(IState? state, IState? to)
+ {
+ if (state == null) return;
+
+ if (state is IAsyncState asyncState)
+ await asyncState.OnExitAsync(to);
+ else
+ state.OnExit(to);
+ }
+
+ ///
+ /// 检查是否可以转换到目标状态(智能判断同步/异步)
+ ///
+ private static async Task CanTransitionToAsync(IState current, IState target)
+ {
+ if (current is IAsyncState asyncState)
+ return await asyncState.CanTransitionToAsync(target);
+
+ return current.CanTransitionTo(target);
+ }
+
///
/// 当状态转换被拒绝时的回调方法
///