diff --git a/mindformers/models/qwen3/modeling_qwen3_infer.py b/mindformers/models/qwen3/modeling_qwen3_infer.py
index 90310dad065f34e76586c260a915ef475539c5ea..c53baeb735c23373e97da960287eb2a96fdba1d7 100644
--- a/mindformers/models/qwen3/modeling_qwen3_infer.py
+++ b/mindformers/models/qwen3/modeling_qwen3_infer.py
@@ -28,6 +28,7 @@ from mindformers.parallel_core.inference.base_models.gpt.gpt_model import GPTMod
 from mindformers.parallel_core.inference.base_models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
 from mindformers.parallel_core.process_group_config import ModelCommProcessGroups, default_model_comm_pgs
 from mindformers.parallel_core.inference.model_utils import InferModelMixin
+from mindformers.parallel_core.inference.quantization.utils import get_quant_config
 from .configuration_qwen3 import Qwen3Config


@@ -62,6 +63,7 @@ class InferenceQwen3ForCausalLM(Qwen3PreTrainedModel, InferModelMixin):
         else:
             self.model_comm_pgs = default_model_comm_pgs

+        self.quant_config = get_quant_config(self.config, self.weight_mapping)
         self.pad_token_id = self.config.pad_token_id
         self.vocab_size = config.vocab_size
         self.max_position_embeddings = config.max_position_embeddings
@@ -85,7 +87,8 @@ class InferenceQwen3ForCausalLM(Qwen3PreTrainedModel, InferModelMixin):
             share_embeddings_and_output_weights=self.config.tie_word_embeddings,
             pre_process=config.pre_process,
             post_process=config.post_process,
-            model_comm_pgs=self.model_comm_pgs)
+            model_comm_pgs=self.model_comm_pgs,
+            quant_config=self.quant_config)

     @jit
     def construct(self, input_ids, hidden_states=None, positions=None, batch_valid_length=None,
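
The patch resolves a quantization config once in the causal-LM wrapper and threads it into the GPTModel constructor through the new quant_config argument. The sketch below illustrates that pattern only; it is not the mindformers implementation. QuantConfig, Backbone, and CausalLM are hypothetical names, and the get_quant_config body here is a stand-in with the same call shape as the helper imported in the diff.

# Hypothetical, self-contained sketch of the quant-config threading pattern.
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class QuantConfig:
    # Hypothetical container for quantization settings.
    method: str = "w8a8"
    bits: int = 8
    skip_layers: Tuple[str, ...] = field(default_factory=tuple)


def get_quant_config(model_config, weight_mapping=None) -> Optional[QuantConfig]:
    # Stand-in for the imported helper: return None when the model config
    # declares no quantization, so downstream layers stay full precision.
    desc = getattr(model_config, "quantization_config", None)
    if not desc:
        return None
    return QuantConfig(
        method=desc.get("quant_method", "w8a8"),
        bits=desc.get("bits", 8),
        skip_layers=tuple(desc.get("ignored_layers", ())),
    )


class Backbone:
    # Hypothetical backbone: quant_config=None means unquantized layers.
    def __init__(self, config, quant_config: Optional[QuantConfig] = None):
        self.config = config
        self.quant_config = quant_config


class CausalLM:
    def __init__(self, config, weight_mapping=None):
        # Mirrors the diff: resolve the config once, pass it down explicitly.
        self.quant_config = get_quant_config(config, weight_mapping)
        self.model = Backbone(config, quant_config=self.quant_config)

Keeping the resolution in the wrapper means the backbone never re-parses the checkpoint description; an unquantized model simply sees quant_config=None.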