From 7f857cbdabfa8cfceda6f85c14349ac71fd3215b Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Wed, 11 Jun 2025 09:05:33 +0800 Subject: [PATCH] compatible with helper --- mindspore_gs/ptq/ptq/quant.py | 12 +++++----- .../daily_test/llama2-13b/daily_test_llama.sh | 24 +++++++++++++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/mindspore_gs/ptq/ptq/quant.py b/mindspore_gs/ptq/ptq/quant.py index 6070cfc4..3abca14d 100644 --- a/mindspore_gs/ptq/ptq/quant.py +++ b/mindspore_gs/ptq/ptq/quant.py @@ -165,7 +165,9 @@ class PTQ(CompAlgo): self.decoder_layer_types.append(LLamaDecodeLayer) self.decoder_layer_types.append(ParallelTransformerLayer) - def generate(network, input_ids): + def generate(network, input_ids, helper=None): + if isinstance(helper, NetworkHelper): + return helper.generate(network, input_ids, do_sample=False, max_new_tokens=1) return network.generate(input_ids, do_sample=False, max_new_tokens=1) self._generate_func = generate @@ -279,8 +281,6 @@ class PTQ(CompAlgo): self._config.update_comm_info() self._get_decoder_layers(network) if self._config.mode == PTQMode.DEPLOY: - os.environ.pop('FORCE_EAGER', None) - os.environ.pop('MS_JIT', None) logger.info("unset environ FORCE_EAGER and MS_JIT because of PTQMode.DEPLOY mode") for i in tqdm.tqdm(range(len(self.decoder_layers)), desc="Running PTQ Deploy..."): layer_name, layer = self.decoder_layers[i] @@ -301,7 +301,7 @@ class PTQ(CompAlgo): logger.info("Analysis network structure.") start_time = time.time() logger.info(f"Catching inputs for first decoder layer with {datasets.get_dataset_size()} datasets samples.") - catcher, network = self._get_first_layer_input(network, datasets) + catcher, network = self._get_first_layer_input(network, datasets, network_helper) all_args = catcher.args all_kwargs = catcher.kwargs logger.info(f"_get_first_layer_input time cost {time.time() - start_time}") @@ -352,7 +352,7 @@ class PTQ(CompAlgo): logger.info(f"{i}th layer offload network time cost {time.time() - start_time}") return network - def _get_first_layer_input(self, network: Cell, ds=None): + def _get_first_layer_input(self, network: Cell, ds=None, helper=None): """get first layer input""" catcher = InputCatcher() catcher.patch(self.decoder_layers[0][1]) @@ -364,7 +364,7 @@ class PTQ(CompAlgo): logger.info(f"Calibrating: dataset count: {data_count}/{total_count}") input_ids = ds_item['input_ids'].asnumpy() try: - self._generate_func(network, input_ids) + self._generate_func(network, input_ids, helper) except GeneratorExit: if hasattr(network, "block_mgr") and network.block_mgr: network.block_mgr.clear_cache() diff --git a/tests/daily_test/llama2-13b/daily_test_llama.sh b/tests/daily_test/llama2-13b/daily_test_llama.sh index f69a6ae6..a3f85222 100644 --- a/tests/daily_test/llama2-13b/daily_test_llama.sh +++ b/tests/daily_test/llama2-13b/daily_test_llama.sh @@ -108,7 +108,6 @@ sed_mode() eval() { unset FORCE_EAGER - unset MS_JIT echo "enter test workspace." cd ws || exit 1 echo "${1}, save yaml to ${2}_eval_log/" @@ -124,10 +123,28 @@ eval() cd .. } +eval_pynative() +{ + export FORCE_EAGER=true + export MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST=PageAttention + echo "enter test workspace." + cd ws || exit 1 + echo "${1}, save yaml to ${2}_pynative_eval_log/" + mkdir -p "${2}_pynative_eval_log" + cp "${3}" "${2}_pynative_eval_log/" + echo "msrun --worker_num=2 --local_worker_num=2 --master_port=${port} --log_dir=${2}_pynative_eval_log --join=True --cluster_time_out=300 python daily_eval.py -c ${3} -s ${dataset} -n 2000 > pynative_eval_${2}_log 2>&1 &" > "${2}_pynative_eval_log/cmd.sh" + msrun --worker_num=2 --local_worker_num=2 --master_port=${port} --log_dir="${2}_pynative_eval_log" --join=True --cluster_time_out=300 python daily_eval.py -c "${3}" -s ${dataset} -n 2000 > "pynative_eval_${2}_log" 2>&1 & + sleep ${sleep_time} + pid=$(ps -u | grep msrun | grep "daily_eval.py" | grep -v grep | awk -F ' ' '{print$2}') + echo "waiting pid ${pid}" + tail --pid ${pid} -f "${2}_pynative_eval_log/worker_0.log" + sleep ${sleep_time} + cd .. +} + quant() { export FORCE_EAGER=true - export MS_JIT=0 echo "enter test workspace." cd ws || exit 1 echo "${1}, save yaml to ${2}_quant_log/" @@ -146,7 +163,6 @@ quant() quant_awq() { export FORCE_EAGER=true - export MS_JIT=0 echo "enter test workspace." cd ws || exit 1 echo "${1}, save yaml to ${2}_quant_log/" @@ -165,7 +181,6 @@ quant_awq() quant_gptq() { export FORCE_EAGER=true - export MS_JIT=0 echo "enter test workspace." cd ws || exit 1 echo "${1}, save yaml to ${2}_quant_log/" @@ -295,6 +310,7 @@ echo_result() } echo_result "fp16 llama2-13b" "${BASEPATH}/ws/fp16_eval_log/worker_0.log" +echo_result "fp16 pynative llama2-13b" "${BASEPATH}/ws/fp16_pynative_eval_log/worker_0.log" echo_result "fp16->a8w8 llama2-13b" "${BASEPATH}/ws/fp16-a8w8_eval_log/worker_0.log" echo_result "fp16->a16w8 llama2-13b" "${BASEPATH}/ws/fp16-a16w8_eval_log/worker_0.log" echo_result "fp16->a8w8c8 llama2-13b" "${BASEPATH}/ws/fp16-a8w8c8_eval_log/worker_0.log" -- Gitee