From 0f3bc143bca8713bbde34105c8c6e10887aa3362 Mon Sep 17 00:00:00 2001 From: zyw_hw Date: Tue, 23 Dec 2025 15:07:37 +0800 Subject: [PATCH] fix case error --- .../test_training_state_monitor.py | 22 ++----------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py b/tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py index e463361dc..58a8af923 100644 --- a/tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py +++ b/tests/st/test_ut/test_core/test_callback/test_training_state_monitor.py @@ -741,27 +741,9 @@ class TestTrainingStateMonitorPrintStableRank: mock_tensorboard, mock_group_size): """Test _print_stable_rank with 3D tensor in MoE 'all' mode""" - # Create a simple wrapper to avoid numpy array comparison issue in tuple - class _ArrayLike: - """self-define array to compare values""" - def __init__(self, arr): - self._arr = np.array(arr) - - def __eq__(self, other): - return False if np.isscalar(other) else np.array_equal(self._arr, other) - - def __iter__(self): - return iter(self._arr) - - def __getitem__(self, i): - return self._arr[i] - - def __array__(self): - return self._arr - mock_get_stable_rank.return_value = ( - _ArrayLike([2.5, 2.6, 2.7]), - _ArrayLike([3.0, 3.1, 3.2]) + np.array([2.5, 2.6, 2.7]), + np.array([3.0, 3.1, 3.2]) ) config = { -- Gitee