diff --git a/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py b/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py
index 2232bcccb385a06dbd99fdda97505dc88db7396e..4d84ec3816c8e912a2f7960c1a3d651629ddbb42 100644
--- a/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py
+++ b/mindspeed/core/memory/adaptive_recomputing/adaptive_recompute.py
@@ -299,8 +299,10 @@ class AdaptiveRecomputePolicy:
         cur_pp_noop_layers = []
         pp_size = all_args.pipeline_model_parallel_size or 1
         layers_per_pp = all_args.num_layers // pp_size
+        vpp_layer = all_args.num_layers_per_virtual_pipeline_stage or layers_per_pp
+        vpp_layers = vpp_layer * pp_size
         for i in noop_layers:
-            pp_id = i // layers_per_pp
+            pp_id = (i % vpp_layers) // vpp_layer
             if pp_id == cur_pp_rank:
                 cur_pp_noop_layers.append(i)
         return cur_pp_noop_layers
diff --git a/mindspeed/core/memory/adaptive_recomputing/prefetch.py b/mindspeed/core/memory/adaptive_recomputing/prefetch.py
index 1e791587bebee4a1ce6fefd7ddbb03cfa311cd51..8abd0b7e234c19aad9316d6b5fcc6e32d04b78a7 100644
--- a/mindspeed/core/memory/adaptive_recomputing/prefetch.py
+++ b/mindspeed/core/memory/adaptive_recomputing/prefetch.py
@@ -34,7 +34,6 @@ class SwapTensor:
         self.prefetch_data_ptr = tensor.data_ptr()
         self.storage_data_ptr = tensor.storage().data_ptr()
         self.layer_id = None
-        self.last_tensor = False
         self.first_tensor = False
         self.is_slice_tensor = tensor.storage().size() != tensor.numel()
         self.stream = None
@@ -131,6 +130,8 @@ class SwapPrefetch:
             return True
         if ori_tensor.storage().size() == 0:
             return True
+        if ori_tensor.storage().size() != ori_tensor.numel():
+            return True
         return False
 
 
@@ -230,6 +231,7 @@ class SwapPrefetch:
             self.first_layer_id = self.swap_tensors[0].layer_id
         elif self.prefetch_list and self.swap_tensors[0].layer_id <= self.prefetch_list[-1][-1][-1].layer_id:
             self.first_layer_id = self.swap_tensors[0].layer_id
+        first_resize_tensor = False
         for swap_tensor in self.swap_tensors:
             if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
                 swap_tensor.layer_index = len(self.prefetch_list[-1])
@@ -237,10 +239,12 @@ class SwapPrefetch:
                     and swap_tensor.stat == "d2h":
                 if not self.update_slice_tensor_stat(swap_tensor):
                     continue
+            if not first_resize_tensor:
+                swap_tensor.first_tensor = True
+                first_resize_tensor = True
             # During synchronization, let the first tensor wait for d2h
             swap_tensor.wait_d2h_finished(swap_tensor.stream, swap_tensor.first_tensor)
-        self.swap_tensors[-1].last_tensor = True
-        self.swap_tensors[0].first_tensor = True
+
         if self.swap_tensors[-1].stat == 'host':
             if self.swap_tensors[0].layer_id > self.first_layer_id and self.prefetch_list:
                 self.prefetch_list[-1].append(self.swap_tensors)
diff --git a/tests_extend/unit_tests/features/swap_attention/test_swap_attention.py b/tests_extend/unit_tests/features/swap_attention/test_swap_attention.py
index 1c8e850879b4a43a824826ca97e3c1b6fd3aaa69..90ebb3bcbb031b265c62086192d0ece0b5856de3 100644
--- a/tests_extend/unit_tests/features/swap_attention/test_swap_attention.py
+++ b/tests_extend/unit_tests/features/swap_attention/test_swap_attention.py
@@ -67,8 +67,10 @@ class AdaptiveRecomputePolicy:
         cur_pp_rank = self.pp_rank
         pp_size = all_args.pipeline_model_parallel_size or 1
         layers_per_pp = all_args.num_layers // pp_size
+        vpp_layer = all_args.num_layers_per_virtual_pipeline_stage or layers_per_pp
+        vpp_layers = vpp_layer * pp_size
         for i in noop_layers:
-            pp_id = i // layers_per_pp
+            pp_id = (i % vpp_layers) // vpp_layer
             if pp_id == cur_pp_rank:
                 cur_pp_noop_layers.append(i)
         return cur_pp_noop_layers
@@ -263,3 +265,26 @@ class TestSwapAttention(DistributedTest):
                           [['0', '1'], ['', '']],
                           [['0', '1'], ['', '']],
                           [6, 7])
+
+    def test_swap_attention_cal_prefetch_list_enable_vpp_enable_multiple_noop_layers_with_inter_layer(self):
+        args = Config()
+        args.num_layers = 16
+        args.pipeline_model_parallel_size = 4
+        args.virtual_pipeline_model_parallel_size = 2
+        args.num_layers_per_virtual_pipeline_stage = 2
+        args.noop_layers = {0, 7}
+        args.enable_recompute_layers_per_pp_rank = True
+        arp = AdaptiveRecomputePolicy(args)
+        arp.pp_rank = 0
+        self.check_result(arp,
+                          [['', '1'], ['0', '1']],
+                          [['', '1'], ['0', '1']],
+                          [['', '1'], ['0', '1']],
+                          [0])
+
+        arp.pp_rank = 3
+        self.check_result(arp,
+                          [['0', ''], ['0', '1']],
+                          [['0', ''], ['0', '1']],
+                          [['0', ''], ['0', '1']],
+                          [7])
\ No newline at end of file
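
Reviewer note (not part of the patch): with virtual pipeline parallelism, layers are dealt to pp ranks in round-robin chunks of num_layers_per_virtual_pipeline_stage, so a global layer index i maps to rank (i % (vpp_layer * pp_size)) // vpp_layer rather than i // layers_per_pp. The standalone sketch below replays the new test's configuration to sanity-check the formula; variable names mirror the diff, and the values come from the test case above.

    # Standalone sketch of the noop-layer -> pp-rank mapping (values from the new test).
    num_layers = 16
    pp_size = 4
    vpp_layer = 2                            # num_layers_per_virtual_pipeline_stage
    layers_per_pp = num_layers // pp_size    # 4
    vpp_layers = vpp_layer * pp_size         # 8: layers consumed per round-robin cycle

    for i in (0, 7):                         # noop_layers = {0, 7} in the test
        old_pp_id = i // layers_per_pp            # pre-fix mapping
        new_pp_id = (i % vpp_layers) // vpp_layer # fixed mapping
        print(i, old_pp_id, new_pp_id)
    # layer 0 -> old rank 0, new rank 0 (both agree)
    # layer 7 -> old rank 1, new rank 3: under vpp, layer 7 is in the 4th chunk
    #            of the first cycle, matching the test's arp.pp_rank = 3 -> [7]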
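
The new exclusion in prefetch.py skips tensors whose backing storage is larger than the tensor itself (slices/views), since such tensors share storage with others and cannot be swapped out independently. A minimal PyTorch illustration of the added condition (an assumed toy example using only standard torch APIs; .storage().size() matches the call used in the diff):

    import torch

    base = torch.arange(12.)
    view = base[2:6]                 # a slice sharing base's storage

    # The condition added to the exclusion check in SwapPrefetch:
    print(view.storage().size())     # 12 -- size of the shared backing storage
    print(view.numel())              # 4  -- elements actually owned by the view
    print(view.storage().size() != view.numel())  # True -> tensor is excluded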