diff --git a/mindformers/parallel_core/training_graph/transformer/utils.py b/mindformers/parallel_core/training_graph/transformer/utils.py index 46243be006db13f10f1481465371bb62b4b3a303..61bc60b411962678c4080498f9ec31c5793aae1f 100644 --- a/mindformers/parallel_core/training_graph/transformer/utils.py +++ b/mindformers/parallel_core/training_graph/transformer/utils.py @@ -159,7 +159,6 @@ class LayerSetting: self.pp_interleave_num = pp_interleave_num if use_pp_interleave else 1 self.offset = np.array(offset, np.int32) self._check_inputs() - self.offset = np.broadcast_to(self.offset, (self.pp_interleave_num, self.pp)) self.is_zbv = ms.get_auto_parallel_context("pipeline_scheduler") == "zero_bubble_v" avg_layer = self.num_layers // (self.pp * self.pp_interleave_num) @@ -571,17 +570,39 @@ class LayerSetting: def _check_inputs(self): """Check the inputs of offset.""" + total_stages = self.pp * self.pp_interleave_num + if total_stages <= 1: + self.offset = np.zeros((1, 1), dtype=np.int32) + return + base_layers = self.num_layers // total_stages + remainder_layers = self.num_layers % total_stages + if self.offset.ndim >= 1 and self.offset.shape[-1] != self.pp: raise ValueError(f"offset.shape[-1] should equal to `pp` ({self.pp}), " f"but got ({self.offset.shape[-1]}). `offset`: {self.offset}") if self.offset.ndim >= 2 and self.offset.shape[-2] != self.pp_interleave_num: raise ValueError(f"offset.shape[-2] should equal to `pp_interleave_num` ({self.pp_interleave_num}), " f"but got ({self.offset.shape[-2]}). `offset`: {self.offset}") - if self.offset.sum() != self.num_layers % (self.pp * self.pp_interleave_num): - r = self.num_layers % (self.pp * self.pp_interleave_num) + if self.offset.sum() != remainder_layers: raise ValueError(f"The sum of `offset` ({self.offset.sum()}) should equal to remainder of `num_layers` " f"({self.num_layers}) % (pp ({self.pp}) * pp_interleave_num ({self.pp_interleave_num})) " - f"= {r}") + f"= {remainder_layers}") + + # Broadcast offset + self.offset = np.broadcast_to(self.offset, (self.pp_interleave_num, self.pp)) + actual_layers = self.offset.flatten() + base_layers + + # Ensure head and tail stages have non-negative layer counts (0 is allowed) + if actual_layers[0] < 0: + raise ValueError(f"Head stage has negative layers. Offset must be ≥ {-base_layers}.") + if actual_layers[-1] < 0: + raise ValueError(f"Tail stage has negative layers. Offset must be ≥ {-base_layers}.") + + # Ensure all middle stages have at least 1 layer + if total_stages > 2: + middle_layers = actual_layers[1:-1] + if np.any(middle_layers < 1): + raise ValueError(f"Some middle stage has fewer than 1 layer. Offset must be ≥ {1 - base_layers}.") @staticmethod def set_pattern_recompute(layer, p_list, add_prim_attr=False, set_on=True, info=''):