From a750bc12b22435143089e13e1b9ebb6ad3b055b8 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 22 Jul 2025 09:28:37 +0800 Subject: [PATCH 01/11] Cosyvoice2 is supported on 300I --- .../CosyVoice2/300I/diff_CosyVoice_300I.patch | 754 ++++++++++++++ .../audio/CosyVoice/CosyVoice2/300I/infer.py | 108 ++ .../CosyVoice2/300I/modeling_qwen2.py | 919 ++++++++++++++++++ .../diff_CosyVoice_800I.patch} | 0 .../CosyVoice/CosyVoice2/{ => 800I}/infer.py | 0 .../CosyVoice2/{ => 800I}/modeling_qwen2.py | 0 .../audio/CosyVoice/CosyVoice2/README.md | 28 +- .../CosyVoice/CosyVoice2/requirements.txt | 1 + 8 files changed, 1799 insertions(+), 11 deletions(-) create mode 100644 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch create mode 100644 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py create mode 100644 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py rename ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/{diff_CosyVoice.patch => 800I/diff_CosyVoice_800I.patch} (100%) mode change 100755 => 100644 rename ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/{ => 800I}/infer.py (100%) mode change 100755 => 100644 rename ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/{ => 800I}/modeling_qwen2.py (100%) mode change 100755 => 100644 diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch new file mode 100644 index 0000000000..c6bf0eefef --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch @@ -0,0 +1,754 @@ +diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py +index e2d62e2..ac45938 100644 +--- a/cosyvoice/cli/cosyvoice.py ++++ b/cosyvoice/cli/cosyvoice.py +@@ -13,11 +13,14 @@ + # limitations under the License. 
+ import os + import time ++import platform ++import datetime + from typing import Generator + from tqdm import tqdm + from hyperpyyaml import load_hyperpyyaml + from modelscope import snapshot_download + import torch ++from ais_bench.infer.interface import InferSession + from cosyvoice.cli.frontend import CosyVoiceFrontEnd + from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model + from cosyvoice.utils.file_utils import logging +@@ -68,9 +71,12 @@ class CosyVoice: + model_input = self.frontend.frontend_sft(i, spk_id) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if i == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + +@@ -82,9 +88,12 @@ class CosyVoice: + model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if i == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + +@@ -93,9 +102,12 @@ class CosyVoice: + model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if i == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + +@@ -108,25 +120,31 @@ class CosyVoice: + model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + start_time = time.time() + logging.info('synthesis text {}'.format(i)) +- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): ++ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, 
(time.time() - start_time) / speech_len)) ++ if i == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + + def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): + model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate) + start_time = time.time() +- for model_output in self.model.vc(**model_input, stream=stream, speed=speed): ++ for i, model_output in enumerate(self.model.vc(**model_input, stream=stream, speed=speed)): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate +- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ if i == 0: ++ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) ++ else: ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + + + class CosyVoice2(CosyVoice): + +- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): ++ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False): + self.instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 +@@ -155,6 +173,16 @@ class CosyVoice2(CosyVoice): + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) ++ if load_om: ++ arch = platform.machine() ++ system = platform.system().lower() ++ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch)) ++ flow_om_static = InferSession(0, '{}/flow_static.om'.format(model_dir)) ++ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch)) ++ self.frontend.speech_om = speech_om ++ self.frontend.flow_om = flow_om ++ self.model.flow.decoder.flow_om_static = flow_om_static ++ self.model.flow.decoder.flow_om = flow_om + del configs + + def inference_instruct(self, *args, **kwargs): +@@ -171,3 +199,19 @@ class CosyVoice2(CosyVoice): + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() ++ ++ def inference_sft_streaming_input(self, tts_text, char_idx, spk_id, user_id, input_end, stream=False, speed=1.0, text_frontend=True): ++ for i in [tts_text]: ++ model_input = self.frontend.frontend_sft(i, spk_id) ++ model_input["user_id"] = user_id ++ model_input["input_end"] = input_end ++ model_input['char_idx'] = char_idx ++ ++ start_time = time.time() ++ # print('synthesis text {}'.format(i)) ++ for model_output in self.model.tts_streaming_input(**model_input, stream=stream, speed=speed): ++ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate ++ print("finish 1 chunk inference ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')) ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ yield model_output ++ start_time = time.time() +diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py +index 6e10f00..25ad767 100644 +--- a/cosyvoice/cli/frontend.py ++++ 
b/cosyvoice/cli/frontend.py +@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd: + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() ++ self.speech_om = None ++ self.flow_om = None + + def _extract_text_token(self, text): + if isinstance(text, Generator): +@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd: + def _extract_speech_token(self, speech): + assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' + feat = whisper.log_mel_spectrogram(speech, n_mels=128) +- speech_token = self.speech_tokenizer_session.run(None, +- {self.speech_tokenizer_session.get_inputs()[0].name: +- feat.detach().cpu().numpy(), +- self.speech_tokenizer_session.get_inputs()[1].name: +- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() ++ if torch.npu.is_available() and self.speech_om: ++ feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)] ++ speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist() ++ self.flow_om.set_context() ++ else: ++ speech_token = self.speech_tokenizer_session.run(None, ++ {self.speech_tokenizer_session.get_inputs()[0].name: ++ feat.detach().cpu().numpy(), ++ self.speech_tokenizer_session.get_inputs()[1].name: ++ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + return speech_token, speech_token_len +diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py +index 9ebf8cb..3db0b4f 100644 +--- a/cosyvoice/cli/model.py ++++ b/cosyvoice/cli/model.py +@@ -99,7 +99,7 @@ class CosyVoiceModel: + self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): +- with self.llm_context: ++ with self.llm_context(): + if isinstance(text, Generator): + assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' 
+ for i in self.llm.inference_bistream(text=text, +@@ -307,13 +307,25 @@ class CosyVoice2Model(CosyVoiceModel): + self.speech_window = np.hamming(2 * self.source_cache_len) + # rtf and decoding related + self.stream_scale_factor = 1 +- self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext() ++ if torch.cuda.is_available(): ++ stream = torch.cuda.Stream(device=self.device) ++ self.llm_context = lambda: torch.cuda.stream(stream) ++ else: ++ self.llm_context = lambda: contextlib.nullcontext() + self.lock = threading.Lock() + # dict used to store session related variable + self.tts_speech_token_dict = {} + self.llm_end_dict = {} + self.hift_cache_dict = {} + ++ # add for support streaming input ++ self.first_chunk_size = 20 ++ self.token_offset_dict = {} ++ self.prompt_text_dict = {} ++ self.prompt_speech_token_dict = {} ++ self.speech_feat_dict = {} ++ self.embedding_dict = {} ++ + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder +@@ -409,3 +421,83 @@ class CosyVoice2Model(CosyVoiceModel): + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + torch.cuda.empty_cache() ++ ++ def tts_streaming_input(self, text, char_idx, flow_embedding, llm_embedding=torch.zeros(0, 192), ++ prompt_text=torch.zeros(1, 0, dtype=torch.int32), ++ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): ++ this_uuid = kwargs.get("user_id", "AscendDefaultUser") ++ if this_uuid not in self.tts_speech_token_dict: ++ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False ++ self.hift_cache_dict[this_uuid] = None ++ self.token_offset_dict[this_uuid] = 0 ++ ++ self.prompt_text_dict[this_uuid] = prompt_text ++ self.prompt_speech_token_dict[this_uuid] = flow_prompt_speech_token ++ self.speech_feat_dict[this_uuid] = prompt_speech_feat ++ self.embedding_dict[this_uuid] = flow_embedding ++ else: ++ prompt_text = self.prompt_text_dict[this_uuid] ++ llm_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_embedding = self.embedding_dict[this_uuid] ++ llm_embedding = self.embedding_dict[this_uuid] ++ prompt_speech_feat = self.speech_feat_dict[this_uuid] ++ ++ for i in self.llm.inference_bistream_streaming_input(text=text, ++ char_idx=torch.tensor([char_idx]).to(self.device), ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device), ++ uuid=this_uuid, input_end=kwargs['input_end']): ++ self.tts_speech_token_dict[this_uuid].append(i) ++ ++ assert stream is True, "output must be streaming" ++ ++ while True: ++ is_first_chunk_ready = (self.token_offset_dict[this_uuid] == 0 and len(self.tts_speech_token_dict[this_uuid]) >= self.first_chunk_size + self.flow.pre_lookahead_len) ++ is_next_chunk_ready = (self.token_offset_dict[this_uuid] > 0 and len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] >= self.token_hop_len + self.flow.pre_lookahead_len) ++ if 
is_first_chunk_ready or is_next_chunk_ready: ++ if self.token_offset_dict[this_uuid] == 0: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.first_chunk_size + self.flow.pre_lookahead_len]).unsqueeze(dim=0) ++ else: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_offset_dict[this_uuid] + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) # 0-53, 0-103, 0-153... ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=False) ++ if self.token_offset_dict[this_uuid] == 0: ++ self.token_offset_dict[this_uuid] += self.first_chunk_size ++ else: ++ self.token_offset_dict[this_uuid] += self.token_hop_len ++ yield {'tts_speech': this_tts_speech.cpu()} ++ # 是否需要退出循环(token 不够下一次推理) ++ if len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] < self.token_hop_len + self.flow.pre_lookahead_len: ++ break ++ ++ if kwargs['input_end'] is True: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=True) ++ yield {'tts_speech': this_tts_speech.cpu()} ++ ++ self.tts_speech_token_dict.pop(this_uuid) ++ self.llm_end_dict.pop(this_uuid) ++ self.hift_cache_dict.pop(this_uuid) ++ ++ self.token_offset_dict.pop(this_uuid) ++ self.prompt_text_dict.pop(this_uuid) ++ self.prompt_speech_token_dict.pop(this_uuid) ++ self.speech_feat_dict.pop(this_uuid) ++ self.embedding_dict.pop(this_uuid) +diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py +index 6a60f6d..fbe7545 100644 +--- a/cosyvoice/flow/flow_matching.py ++++ b/cosyvoice/flow/flow_matching.py +@@ -14,6 +14,7 @@ + import threading + import torch + import torch.nn.functional as F ++import numpy as np + from matcha.models.components.flow_matching import BASECFM + + +@@ -32,6 +33,8 @@ class ConditionalCFM(BASECFM): + # Just change the architecture of the estimator here + self.estimator = estimator + self.lock = threading.Lock() ++ self.flow_om = None ++ self.flow_om_static = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): +@@ -105,12 +108,26 @@ class ConditionalCFM(BASECFM): + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond +- dphi_dt = self.forward_estimator( +- x_in, mask_in, +- mu_in, t_in, +- spks_in, +- cond_in +- ) ++ # 动态分档推理, 在流式输出中,每次输出的token数目固定,可以采取动态分档模型执行推理 ++ if torch.npu.is_available() and self.flow_om_static and x.size(2)%100==0 and x.size(2)<800: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om_static.infer(feed, mode="dymdims") ++ self.flow_om.set_context() ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ # 输出的token数目不固定场景采用动态模型推理 ++ elif torch.npu.is_available() and self.flow_om: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] ++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om.infer(feed, mode="dymshape", 
custom_sizes=10000000) ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ else: ++ dphi_dt = self.forward_estimator( ++ x_in, mask_in, ++ mu_in, t_in, ++ spks_in, ++ cond_in ++ ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt +diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py +index c47bf05..7dd9fb0 100644 +--- a/cosyvoice/hifigan/generator.py ++++ b/cosyvoice/hifigan/generator.py +@@ -20,9 +20,11 @@ from scipy.signal import get_window + import torch + import torch.nn as nn + import torch.nn.functional as F ++import torch_npu + from torch.nn import Conv1d + from torch.nn import ConvTranspose1d + from torch.nn.utils import remove_weight_norm ++from torch.nn.utils.parametrize import remove_parametrizations + from torch.nn.utils.parametrizations import weight_norm + from torch.distributions.uniform import Uniform + +@@ -99,8 +101,8 @@ class ResBlock(torch.nn.Module): + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): +- remove_weight_norm(self.convs1[idx]) +- remove_weight_norm(self.convs2[idx]) ++ remove_parametrizations(self.convs1[idx], "weight") ++ remove_parametrizations(self.convs2[idx], "weight") + + + class SineGen(torch.nn.Module): +@@ -319,22 +321,19 @@ class HiFTGenerator(nn.Module): + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: +- remove_weight_norm(l) ++ remove_parametrizations(l, 'weight') + for l in self.resblocks: + l.remove_weight_norm() +- remove_weight_norm(self.conv_pre) +- remove_weight_norm(self.conv_post) +- self.m_source.remove_weight_norm() +- for l in self.source_downs: +- remove_weight_norm(l) ++ remove_parametrizations(self.conv_pre, 'weight') ++ remove_parametrizations(self.conv_post, 'weight') + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = torch.stft( +- x, +- self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), +- return_complex=True) ++ x.cpu(), ++ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.cpu(), ++ return_complex=True).npu() + spec = torch.view_as_real(spec) # [B, F, TT, 2] + return spec[..., 0], spec[..., 1] + +@@ -342,13 +341,11 @@ class HiFTGenerator(nn.Module): + magnitude = torch.clip(magnitude, max=1e2) + real = magnitude * torch.cos(phase) + img = magnitude * torch.sin(phase) +- inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], +- self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) ++ inverse_transform = torch.istft(torch.complex(real, img).cpu(), self.istft_params["n_fft"], self.istft_params["hop_len"], ++ self.istft_params["n_fft"], window=self.stft_window.cpu()).npu() + return inverse_transform + +- def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: +- s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) +- s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) ++ def decode(self, x: torch.Tensor, s_stft: torch.Tensor, index: torch.int) -> torch.Tensor: + + x = self.conv_pre(x) + for i in range(self.num_upsamples): +@@ -356,7 +353,7 @@ class HiFTGenerator(nn.Module): + x = self.ups[i](x) + + if i == self.num_upsamples - 1: +- x = self.reflection_pad(x) ++ x = torch.cat((x, x[:,:,-2:-1]), -1) + + # fusion + si = 
self.source_downs[i](s_stft) +@@ -373,12 +370,10 @@ class HiFTGenerator(nn.Module): + + x = F.leaky_relu(x) + x = self.conv_post(x) +- magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) +- phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy ++ magnitude = torch.exp(x[:, :index, :]) ++ phase = torch.sin(x[:, index:, :]) # actually, sin is redundancy + +- x = self._istft(magnitude, phase) +- x = torch.clamp(x, -self.audio_limit, self.audio_limit) +- return x ++ return magnitude, phase + + def forward( + self, +@@ -407,5 +402,12 @@ class HiFTGenerator(nn.Module): + # use cache_source to avoid glitch + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source +- generated_speech = self.decode(x=speech_feat, s=s) ++ # torchair编译,对decode函数做部分适配 ++ s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) ++ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) ++ # 字典取值操作无法被dynamo编译,把decode内部的index拿到外面计算 ++ index = self.istft_params["n_fft"] // 2 + 1 ++ magnitude, phase = self.decode(x=speech_feat, s_stft=s_stft, index=index) ++ x = self._istft(magnitude, phase) ++ generated_speech = torch.clamp(x, -self.audio_limit, self.audio_limit) + return generated_speech, s +diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py +index bbd3305..42d1ccf 100644 +--- a/cosyvoice/llm/llm.py ++++ b/cosyvoice/llm/llm.py +@@ -11,6 +11,7 @@ + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. ++import math + from typing import Dict, Optional, Callable, List, Generator + import torch + from torch import nn +@@ -229,16 +230,17 @@ class Qwen2Encoder(torch.nn.Module): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + +- def forward_one_step(self, xs, masks, cache=None): +- input_masks = masks[:, -1, :] +- outs = self.model( +- inputs_embeds=xs, +- attention_mask=input_masks, +- output_hidden_states=True, +- return_dict=True, +- use_cache=True, +- past_key_values=cache, +- ) ++ def forward_one_step(self, xs, masks, prompt_length, cache=None): ++ with torch.no_grad(): ++ outs = self.model( ++ inputs_embeds=xs, ++ attention_mask=masks, ++ prompt_length=prompt_length, ++ output_hidden_states=True, ++ return_dict=True, ++ use_cache=True, ++ past_key_values=cache, ++ ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + return xs, new_cache +@@ -283,6 +285,15 @@ class Qwen2LM(TransformerLM): + self.sampling = sampling + self.mix_ratio = mix_ratio + ++ # 5. added for support streaming input ++ self.prompt_speech_token_emb_dict = {} ++ self.lm_input_dict = {} ++ self.out_tokens_dict = {} ++ self.cache_dict = {} ++ self.text_cache_dict = {} ++ self.next_fill_index = {} ++ self.prompt_length = {} ++ + @torch.inference_mode() + def inference( + self, +@@ -318,9 +329,17 @@ class Qwen2LM(TransformerLM): + # 5. 
step by step decode + out_tokens = [] + cache = None ++ input_length = lm_input.shape[1] + for i in range(max_len): ++ prompt_length = input_length + i ++ if i == 0: ++ seqlen_align = (lm_input.shape[1] + 15) // 16 * 16 if lm_input.shape[1] > 32 else 32 ++ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool).logical_not() ++ else: ++ masks = None + y_pred, cache = self.llm.forward_one_step(lm_input, +- masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), ++ masks=masks, ++ prompt_length=prompt_length, + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() +@@ -331,7 +350,7 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + out_tokens.append(top_ids) +- lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() + + @torch.inference_mode() + def inference_bistream( +@@ -392,8 +411,10 @@ class Qwen2LM(TransformerLM): + continue + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) ++ seqlen_align = (seq_len + 15) // 16 * 16 if seq_len > 32 else 32 ++ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool) + y_pred, cache = self.llm.forward_one_step(lm_input, +- masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), ++ masks=masks, + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: +@@ -418,8 +439,10 @@ class Qwen2LM(TransformerLM): + logging.info('no more text token, decode until met eos') + while True: + seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) ++ seqlen_align = (seq_len + 15) // 16 * 16 if seq_len > 32 else 32 ++ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool) + y_pred, cache = self.llm.forward_one_step(lm_input, +- masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), ++ masks=masks, + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() +@@ -432,3 +455,142 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ ++ @torch.inference_mode() ++ def inference_bistream_streaming_input( ++ self, ++ text: torch.Tensor, ++ char_idx: torch.Tensor, ++ prompt_text: torch.Tensor, ++ prompt_text_len: torch.Tensor, ++ prompt_speech_token: torch.Tensor, ++ prompt_speech_token_len: torch.Tensor, ++ embedding: torch.Tensor, ++ uuid: str, ++ input_end: bool, ++ sampling: int = 25, ++ max_token_text_ratio: float = 20, ++ min_token_text_ratio: float = 2, ++ ) -> Generator[torch.Tensor, None, None]: ++ ++ def build_causal_mask(query_len, key_len, devices): ++ assert key_len >= query_len ++ key_len = (key_len + 15) // 16 * 16 if key_len > 32 else 32 ++ causal_mask = torch.triu(torch.ones((key_len, key_len), device=devices), diagonal=(key_len - query_len) + 1).to(torch.bool) ++ return causal_mask.unsqueeze(0) ++ ++ device = prompt_text.device 
++ ++ if uuid not in self.cache_dict: ++ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) ++ if prompt_speech_token_len != 0: ++ self.prompt_speech_token_emb_dict[uuid] = self.speech_embedding(prompt_speech_token) ++ else: ++ self.prompt_speech_token_emb_dict[uuid] = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) ++ ++ self.lm_input_dict[uuid] = torch.concat([sos_eos_emb], dim=1) # [1,1,896] ++ ++ self.out_tokens_dict[uuid] = [] ++ self.cache_dict[uuid] = None ++ ++ self.text_cache_dict[uuid] = self.llm.model.model.embed_tokens(prompt_text) # [1, prompt_text, 896] ++ self.next_fill_index[uuid] = -1 ++ self.prompt_length[uuid] = 0 ++ ++ text_emb = self.llm.model.model.embed_tokens(text) ++ ++ for i in range(text_emb.size(1)): ++ self.text_cache_dict[uuid] = torch.concat([self.text_cache_dict[uuid], text_emb[:, i].unsqueeze(1)], dim=1) ++ index = 0 ++ while self.prompt_speech_token_emb_dict[uuid].size(1) != 0: ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text, lm_input_speech = self.text_cache_dict[uuid][:, :self.mix_ratio[0]], self.prompt_speech_token_emb_dict[uuid][:, :self.mix_ratio[1]] ++ index += 1 ++ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text, lm_input_speech], dim=1) ++ self.text_cache_dict[uuid], self.prompt_speech_token_emb_dict[uuid] = self.text_cache_dict[uuid][:, self.mix_ratio[0]:], self.prompt_speech_token_emb_dict[uuid][:, self.mix_ratio[1]:] ++ else: ++ break ++ ++ if self.prompt_speech_token_emb_dict[uuid].size(1) == 0: # 文本token数量多于音频token,混合完以后,剩余文本token,开始解码 ++ # 若上一次解码的 token 是 fill_token,说明 LLM 想要更多 text token ++ # 或者首次预测时,还没开始解码,out_tokens_dict 为空 ++ if ((len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2) ++ or (len(self.out_tokens_dict[uuid]) == 0 and self.lm_input_dict[uuid].size(1) == 1)): ++ # token数量够了 ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text = self.text_cache_dict[uuid][:, :self.mix_ratio[0]] # 抽出5个token ++ if len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2: # 预测出filling token,前面cache已经缓存,当前直接输入即可 ++ self.lm_input_dict[uuid] = lm_input_text ++ else: # sft刚开始预测,需要和sos token拼接在一起 ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text], dim=1) ++ self.text_cache_dict[uuid] = self.text_cache_dict[uuid][:, self.mix_ratio[0]:] ++ else: ++ continue ++ ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, ++ self.lm_input_dict[uuid].device) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ # 判断是否生成 filling_token: ++ if self.next_fill_index[uuid] != -1 and len(self.out_tokens_dict[uuid]) == self.next_fill_index[uuid]: ++ top_ids = self.speech_token_size + 2 # 该预测filling token了 ++ self.next_fill_index[uuid] += (self.mix_ratio[1] + 1) # 找到下一个filling token的位置 ++ else: ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=True).item() ++ # 特殊 token 处理, 
fill_token → 中断预测、等待新文本 token。 ++ if top_ids == self.speech_token_size + 2: ++ self.next_fill_index[uuid] = len(self.out_tokens_dict[uuid]) + self.mix_ratio[1] + 1 # -1 > 30 ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size + 2: # 预测到了filling token, break掉迎接新的文本token ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ if input_end: ++ # 3. final decode 文本全部送完,进行最后的解码。 ++ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], self.text_cache_dict[uuid], task_id_emb, self.prompt_speech_token_emb_dict[uuid]], dim=1) ++ logging.info('no more text token, decode until met eos') ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, self.lm_input_dict[uuid].device) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=False).item() ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size: ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ # in stream mode, yield token one by one ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ # this user is done ++ self.prompt_speech_token_emb_dict.pop(uuid) ++ self.lm_input_dict.pop(uuid) ++ self.out_tokens_dict.pop(uuid) ++ self.cache_dict.pop(uuid) ++ self.text_cache_dict.pop(uuid) ++ self.next_fill_index.pop(uuid) +\ No newline at end of file +diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py +index 3e61a8c..d316b92 100644 +--- a/cosyvoice/utils/common.py ++++ b/cosyvoice/utils/common.py +@@ -107,12 +107,33 @@ def init_weights(m, mean=0.0, std=0.01): + + # Repetition Aware Sampling in VALL-E 2 + def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): +- top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) ++ top_ids = dst_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + return top_ids + ++def dst_sampling(weighted_scores, top_p=0.8, top_k=25): ++ ++ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) ++ ++ cum_sum = torch.cumsum(sorted_value, dim=0) ++ n = sorted_value.size(0) ++ device = cum_sum.device ++ pre_cum_sum = torch.cat([torch.zeros(1, device=device), cum_sum[:-1]]) ++ ++ indices = torch.arange(n ,device=device) ++ condition = (pre_cum_sum < top_p) & (indices < top_k) ++ ++ max_i_tensor = torch.where(condition, indices, torch.tensor(-1, device=device)) ++ n_selected = max_i_tensor.max() + 1 ++ ++ selected_prob = sorted_value[:n_selected] ++ selected_indices 
= sorted_idx[:n_selected] ++ ++ top_ids = selected_indices[selected_prob.multinomial(1, replacement=True)] ++ ++ return top_ids + + def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py new file mode 100644 index 0000000000..72d584396f --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py @@ -0,0 +1,108 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import argparse +from tqdm import tqdm +import torch +import torchaudio +import torch_npu +from torch_npu.contrib import transfer_to_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + + +def no_stream_input_inference(args, cosyvoice, prompt_txt): + with torch.no_grad(): + print('warm up start') + for _ in range(args.warm_up_times): + for _ in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): + pass + print('warm up end') + infer_res = [torch.tensor([]) for _ in range(args.infer_count)] + for i_step in range(args.infer_count): + for _, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): + infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + print(f"save out wav file to sft_out_{i_step+1}.wav") + torchaudio.save(f"sft_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate) + + +def stream_input_inference(args, cosyvoice, prompt_txt): + + def inference_step(step, mode): + times = args.warm_up_times if mode == "warmup" else args.infer_count + print(f"第{step + 1}/{times}轮 {mode}:↓↓↓") + print(f"curr prompt text:{prompt_txt[step % len(prompt_txt)]}") + for char_idx, char in enumerate(prompt_txt[step % len(prompt_txt)]): + if char_idx == len(prompt_txt[step % len(prompt_txt)]) - 1: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): + if mode == "warmup": + pass + else: + infer_res[step] = torch.cat((infer_res[step], j['tts_speech']), dim=1) + else: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): + if mode == "warmup": + pass + else: + infer_res[step] = torch.cat((infer_res[step], j['tts_speech']), dim=1) + + infer_res = [torch.tensor([]) for _ in range(args.infer_count)] + + with torch.no_grad(): + print("warm up start") + for w_step in range(args.warm_up_times): + inference_step(w_step, mode="warmup") + print("warm up end") + + print("inference start") + for i_step in range(args.infer_count): + inference_step(i_step, mode="inference") + print(f"save out wav file to stream_input_out_{i_step+1}.wav") + torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate) + print("inference end") + + +if __name__ == 
'__main__': + torch_npu.npu.set_compile_mode(jit_compile=False) + + parser = argparse.ArgumentParser(description="CosyVoice2 infer") + parser.add_argument("--model_path", type=str, help="model path") + parser.add_argument('--warm_up_times', default=2, type=int, help='warm up times') + parser.add_argument('--infer_count', default=20, type=int, help='infer loop count') + parser.add_argument('--stream_in', action="store_true", help='stream input infer') + parser.add_argument('--stream_out', action="store_true", help='stream output infer') + args = parser.parse_args() + + cosyvoice = CosyVoice2(args.model_path, load_om=True, fp16=True) + cosyvoice.model.llm.eval() + cosyvoice.model.llm.llm.model.model.half() + + # 对hift模型结构进行torchair图模式适配 + cosyvoice.model.hift.remove_weight_norm() + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + config.experimental_config.tiling_schedule_optimize = True + npu_backend = tng.get_npu_backend(compiler_config=config) + cosyvoice.model.hift.decode = torch.compile(cosyvoice.model.hift.decode, dynamic=True, fullgraph=True, backend=npu_backend) + + # 输入数据加载 + prompt_txt = [ + '收到好友从远方寄来的生日礼物,那份意外的惊喜和深深的祝福,让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', + '全球每年有超过一百三十五万人,因吸烟而死亡' + ] + + # 普通输入(非流式输入) + if not args.stream_in: + no_stream_input_inference(args, cosyvoice, prompt_txt) + # 流式输入 + else: + stream_input_inference(args, cosyvoice, prompt_txt) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py new file mode 100644 index 0000000000..94ad8d5a94 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py @@ -0,0 +1,919 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+""" PyTorch Qwen2 model.""" + +import math +import warnings +from typing import List, Optional, Tuple, Union +import torch +import torch_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + logging +) +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + + +QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "Qwen/Qwen2-7B-beta", +] + + +# Ascend优化:Add/Norm昇腾自定义融合算子 +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, + hidden_states, + residual: Optional[torch.Tensor] = None): + if residual is None: + return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0], hidden_states + else: + y, _, x = torch_npu.npu_add_rms_norm(residual, hidden_states, self.weight, self.variance_epsilon) + return y, x + + +# Ascend优化:提前计算位置编码,无需在每层layer中重复计算 +class Qwen2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x=None, seq_len=None): + if x is None and seq_len is None: + return self.cos_cached, self.sin_cached + + return ( + self.cos_cached.to(dtype=x.dtype), + self.sin_cached.to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin): + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Qwen2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + +# Ascend优化:PFA/IFA自定义算子替换,kv cache固定shape并在指定位置更新 +class Qwen2SdpaAttention(Qwen2Attention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # 优化Attention部分逻辑,替换torch_npu算子 + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + rotary_emb_cos: Optional[torch.Tensor] = None, + rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + + # 利用已经提前计算好的位置编码数据对q,k值进行更新 + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + rotary_emb_cos.to(value_states.dtype), + rotary_emb_sin.to(value_states.dtype)) + + if use_cache and past_key_value is not None: + # 把计算好的kv值更新到kv cahce中 + tmp_ids = updated_kv_positions.reshape(-1) + torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 1) + torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 1) + # 流式输入场景,decode阶段 + if q_len == 1: + key_states = past_key_value[self.layer_idx][0] + value_states = past_key_value[self.layer_idx][1] + # 流式输入场景,首次prefill阶段 + elif q_len == actual_seq_len[0]: + key_states = key_states + value_states = value_states + # 流式输入场景,后续prefill阶段 + elif q_len < actual_seq_len[0]: + key_states = past_key_value.key_cache[self.layer_idx][:, :actual_seq_len[0]] + value_states = past_key_value.value_cache[self.layer_idx][:, :actual_seq_len[0]] + else: + raise ValueError(f"Unexpected q_len: {q_len}, actual_seq_len[0]: {actual_seq_len[0]}") + + if q_len > 1: + # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 + kv_len = key_states.shape[1] + # 310P 仅支持qSeqlen = kvSeqlen + q_pad_len = (kv_len + 15) // 16 * 16 - q_len if kv_len > 32 else 32 + if q_pad_len > 0: + q_pad = torch.zeros((bsz, q_pad_len, self.num_heads, self.head_dim), dtype=query_states.dtype, device=query_states.device) + query_states = torch.cat([query_states, q_pad], dim=1).contiguous() + kv_pad_len = (kv_len + 15) // 16 * 16 - kv_len if kv_len > 32 else 32 + if kv_pad_len > 0: + kv_pad = torch.zeros((bsz, kv_pad_len, self.num_key_value_heads, self.head_dim), dtype=key_states.dtype, device=key_states.device) + key_states = torch.cat([key_states, kv_pad], dim=1) + value_states = torch.cat([value_states, kv_pad], dim=1) + attn_output = torch_npu.npu_prompt_flash_attention(query_states, + key_states.contiguous(), + value_states.contiguous(), + num_heads=self.num_heads, + input_layout="BSND", + scale_value=1 / math.sqrt(self.head_dim), + pre_tokens=65535, next_tokens=0, + atten_mask=attention_mask, + sparse_mode=1, + num_key_value_heads=self.num_key_value_heads) + attn_output = attn_output[:, :q_len, :, :] + else: + # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = torch_npu.npu_incre_flash_attention(query_states, + key_states.contiguous(), + value_states.contiguous(), + num_heads=self.num_heads, + input_layout="BNSD", + scale_value=1 / math.sqrt(self.head_dim), + atten_mask=None, + actual_seq_lengths=actual_seq_len, + num_key_value_heads=self.num_key_value_heads) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, 
past_key_value + + +QWEN2_ATTENTION_CLASSES = { + "sdpa": Qwen2SdpaAttention, +} + + +# Ascend优化:每层layer的前后rms替换为昇腾自定义算子 +class Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + past_residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + rotary_emb_cos: Optional[torch.Tensor] = None, + rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. " + "Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + # rms计算替换为昇腾自定义融合算子 + hidden_states, residual = self.input_layernorm(hidden_states, past_residual) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + updated_kv_positions=updated_kv_positions, + actual_seq_len=actual_seq_len, + rotary_emb_cos=rotary_emb_cos, + rotary_emb_sin=rotary_emb_sin, + use_cache=use_cache, + ) + + # rms计算替换为昇腾自定义融合算子 + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (residual, hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +# Ascend优化:forward函数利用torchair编译为图模式,利用cache接口避免重复编译 +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.max_position_embeddings = config.max_position_embeddings + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.rope_theta = config.rope_theta + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + # torchair编译参数,编译Qwen2Model的forward部分 + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + # tiling下沉,主要针对IFA算子,使其算子tiling操作在AICPU上执行 + config.experimental_config.tiling_schedule_optimize = True + + # torchair的cache编译,保证模型编译cache文件,避免重复推理 + self.cached_decode = tng.inference.cache_compile(self.decode, config=config) + self.cached_first_prefill = tng.inference.cache_compile(self.first_prefill, config=config) # 用于首次prefill,无kv_cache + self.cached_next_prefill = tng.inference.cache_compile(self.next_prefill, config=config) # 用于后续prefill,有kv_cache + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_rotary_cos_sin(self, position_ids): + cos, sin = self.rotary_emb() + f_position_ids = position_ids.flatten() + cos = torch.index_select(cos, 0, f_position_ids) + sin = torch.index_select(sin, 0, f_position_ids) + cos = cos.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) + sin = sin.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) + return cos, sin + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + # prefill_1, prefill_2和decode需要编译为3个不同的模型 + seq_len = inputs_embeds.size(1) + + # 流式输入场景,首次prefill + if seq_len > 1 and seq_len == actual_seq_len[0]: + return self.cached_first_prefill( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + # 流式输入场景,后续prefill + elif 1 < seq_len < actual_seq_len[0]: + return self.cached_next_prefill( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + # 
流式输入场景,decode + else: + return self.cached_decode( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + def decode( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + def first_prefill( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + def next_prefill( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + + def _forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + + # prefill阶段初始化kv cache,decode阶段对kv cache进行更新 + # 固定kv cache为最大shape,避免内存的重复申请和拷贝,也保证了模型的静态shape,可整图下发推理 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + kv_shape = ( + batch_size, self.config.max_position_embeddings, + self.config.num_key_value_heads, + self.config.hidden_size // self.config.num_attention_heads) # (1, 32768, 2, 64) + past_key_values = () + for _ in range(self.config.num_hidden_layers): + k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + v_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) + past_key_values += ((k_cache, v_cache),) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + past_key_values_length = self.max_position_embeddings if actual_seq_len[0] > inputs_embeds.shape[1] else 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) # tensor([0, 1, 2, 3, 4, 5], device='npu:0') + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # 此处统一计算位置编码,在每个layer中取对应位置的值 + rotary_emb_cos, rotary_emb_sin = self._prepare_decoder_rotary_cos_sin(position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + residual = None + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # 执行layer层推理 + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_residual=residual, + position_ids=position_ids, + past_key_value=past_key_values, + updated_kv_positions=updated_kv_positions, + actual_seq_len=actual_seq_len, + rotary_emb_cos=rotary_emb_cos, + rotary_emb_sin=rotary_emb_sin, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + residual = layer_outputs[0] + hidden_states = layer_outputs[1] + + if use_cache: + next_decoder_cache = layer_outputs[3 if output_attentions else 2] + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + # norm计算,此处替换为昇腾融合算子 + hidden_states, _ = self.norm(hidden_states, residual) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + out = 
BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + hidden_states = out[0] + # 由于logits最后也只取[:,-1,:],相当于只取最新seq位置上的数据,l + # 所以在全量的最后线性层计算可以只对最新的seq位置做计算,降低计算量 + bs, seq, hidden = hidden_states.size() + if seq > 1: + gather_index = torch.ones(bs, dtype=torch.int64, device=hidden_states.device) * (seq - 1) + gather_index = gather_index.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, 1, hidden) + hidden_states = torch.gather(hidden_states, 1, gather_index) + logits = lm_head(hidden_states) + logits = logits.float() + return out, logits + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + prompt_length: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + 对CosyVoice2模型中使用的Qwen模型进行昇腾适配优化,具体优化点有: + 1. 固定KV CACHE大小,避免重复申请内存和拷贝 + 2. 替换部分算子为昇腾自定义算子 + 3. 首层计算位置编码避免重复计算 + 4. 在decode阶段,固定输入shape大小,保证整图下发 + + 模型有以下输入: + 1. attention_mask + 2. inputs_embeds:CosyVoice会把inputs_ids处理embeding后输入模型 + 3. past_key_values:kv cache,在每次推理后会进行更新 + 4. position_ids:位置id,在每次推理后会进行更新 + 5. 
prompt_length:实际输入长度,在prefill阶段为首token长度,后续每次推理长度加1 + """ + + # 每次推理前对输入数据进行昇腾适配处理,处理为昇腾自定义算子所需类型参数 + updated_kv_positions, past_key_values, position_ids, actual_seq_len = self.prepare_data(inputs_embeds, past_key_values, prompt_length) + + model_inputs = { + "inputs_embeds": inputs_embeds, + "past_key_values": past_key_values, + "position_ids": position_ids, + "actual_seq_len": actual_seq_len, + "attention_mask": attention_mask, + } + + # prefill阶段由于输出token长度不固定,为动态shape推理。decode阶段把输入固定为静态,保证整图静态推理。 + if inputs_embeds.shape[1] == 1: + self._mark_model_inputs_static(model_inputs) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # 主要推理阶段,利用torchair编译为整图推理 + outputs, logits = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + updated_kv_positions=updated_kv_positions, + actual_seq_len=actual_seq_len, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + lm_head=self.lm_head + ) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Ascend优化:把数据输入处理为Ascend优化所需要的格式和类型 + def prepare_data(self, inputs_embeds, past_key_values, prompt_length): + bsz = inputs_embeds.shape[0] + seq_length = inputs_embeds.shape[1] + if past_key_values: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + # 流式输入场景,首次prefill + if seq_length > 1 and prompt_length == inputs_embeds.shape[1]: + updated_kv_positions = torch.zeros(bsz, dtype=torch.long, device=inputs_embeds.device) + position_ids = None + # 流式输入场景,后续prefill + elif seq_length > 1 and prompt_length > inputs_embeds.shape[1]: + # updated_kv_positions,kv_cache需要更新的起始位置 + updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - inputs_embeds.shape[1]) + tmp_head = prompt_length - inputs_embeds.shape[1] + tmp_tail = prompt_length + position_ids = torch.arange(tmp_head, tmp_tail, dtype=torch.long, device=inputs_embeds.device) + # 流式输入场景,decode + else: + updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - 1) + position_ids = torch.tensor([prompt_length - 1], device=inputs_embeds.device) + + # ifa Computational optimization inputs + actual_seq_len = ([prompt_length]) + + return updated_kv_positions, past_key_values, position_ids, actual_seq_len + + # Ascend优化:固定input shape,使能静态推理,模型整图下发 + def _mark_model_inputs_static(self, model_inputs): + for key, value in model_inputs.items(): + 
if key == "past_key_values" and value is not None: + for i in range(self.config.num_hidden_layers): + torch._dynamo.mark_static(value[i][0]) + torch._dynamo.mark_static(value[i][1]) + elif isinstance(value, torch.Tensor): + torch._dynamo.mark_static(value) + diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/diff_CosyVoice_800I.patch old mode 100755 new mode 100644 similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch rename to ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/diff_CosyVoice_800I.patch diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py old mode 100755 new mode 100644 similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py rename to ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/modeling_qwen2.py old mode 100755 new mode 100644 similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py rename to ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/modeling_qwen2.py diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 661d8377a3..8353bc43a4 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -55,28 +55,34 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 cd CosyVoice git reset --hard fd45708 git submodule update --init --recursive - git apply ../diff_CosyVoice.patch + git apply ../${platform}/diff_CosyVoice_${platform}.patch # 获取Transformer源码 + cd .. git clone https://github.com/huggingface/transformers.git cd transformers git checkout v4.37.0 cd .. # 将modeling_qwen模型文件替换到transformers仓内 - mv ../modeling_qwen2.py ./transformers/src/transformers/models/qwen2 + mv ../${platform}/modeling_qwen2.py ./transformers/src/transformers/models/qwen2 ``` 文件目录结构大致如下: ```text 📁 CosyVoice/ ├── 📁 CosyVoice2/ - | |── 📄 diff_CosyVoice.patch + |── 📁 300I + | |── 📄 diff_CosyVoice_300I.patch + | |── 📄 infer.py # 推理脚本 + | |── 📄 modeling_qwen2.py + |── 📁 800I + | |── 📄 diff_CosyVoice_800I.patch + | |── 📄 infer.py # 推理脚本 | |── 📄 modeling_qwen2.py - | |── 📁 CosyVoice - | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 - │ ├── 📁 CosyVoice-0.5B/ # 权重文件 - │ ├── 📁 transformers/ # transformers文件,里面有修改过的modeling_qwen2.py文件 - │ ├── 📄 infer.py # 推理脚本 - │ └── 📄 modify_onnx.py # 模型转换脚本 + | 📁 CosyVoice + | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 + │ ├── 📁 CosyVoice-0.5B/ # 权重文件 + │ ├── 📁 transformers/ # transformers库,里面有修改过的modeling_qwen2.py文件 + │ └── 📄 modify_onnx.py # 模型转换脚本 ``` 2. 安装依赖 @@ -103,7 +109,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 3. 安装msit工具 - 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgen组件。(未安装会提示 ais_bench 导入失败报错) + 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgeon组件。(未安装会提示 ais_bench 导入失败报错) 4. 获取权重数据 @@ -163,7 +169,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 ### 2 开始推理验证 - 1. 首先移动infer.py文件到CosyVoice目录下 + 1. 首先移动对应推理平台路径下的infer.py文件到CosyVoice目录下 2. 
设置环境变量,执行推理命令 diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt index c3cdf647ef..7afcddafb2 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt @@ -20,6 +20,7 @@ onnxruntime==1.16.0 openai-whisper==20231117 protobuf==4.25 pydantic==2.7.0 +pyworld==0.3.4 rich==13.7.1 soundfile==0.12.1 tensorboard==2.14.0 -- Gitee From 33229b3cf2b03c9c3d1d347d2d1281aa034f1a04 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 22 Jul 2025 09:38:35 +0800 Subject: [PATCH 02/11] modify Readme --- .../audio/CosyVoice/CosyVoice2/README.md | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 8353bc43a4..8261e80042 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -70,19 +70,20 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 ```text 📁 CosyVoice/ ├── 📁 CosyVoice2/ - |── 📁 300I - | |── 📄 diff_CosyVoice_300I.patch - | |── 📄 infer.py # 推理脚本 - | |── 📄 modeling_qwen2.py - |── 📁 800I - | |── 📄 diff_CosyVoice_800I.patch - | |── 📄 infer.py # 推理脚本 - | |── 📄 modeling_qwen2.py - | 📁 CosyVoice - | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 - │ ├── 📁 CosyVoice-0.5B/ # 权重文件 - │ ├── 📁 transformers/ # transformers库,里面有修改过的modeling_qwen2.py文件 - │ └── 📄 modify_onnx.py # 模型转换脚本 + | |── 📁 300I + | |── 📄 diff_CosyVoice_300I.patch + | |── 📄 infer.py # 推理脚本 + | |── 📄 modeling_qwen2.py + | |── 📁 800I + | |── 📄 diff_CosyVoice_800I.patch + | |── 📄 infer.py # 推理脚本 + | |── 📄 modeling_qwen2.py + | |── 📁 CosyVoice + | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 + │ ├── 📁 CosyVoice-0.5B/ # 权重文件 + │ ├── 📁 transformers/ # transformers库,里面修改modeling_qwen2.py文件 + │── 📄 requirements.txt # 依赖库 + └── 📄 modify_onnx.py # 模型转换脚本 ``` 2. 
安装依赖 @@ -198,4 +199,5 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 | 模型 |芯片|rtf(实时率)| |-----------|------|------| | cosyvoice |800I A2|0.28s| + | cosyvoice |300I DUO|0.90s| -- Gitee From 62c0ddc2d8c7b60b8acfba3b0c7ff96f614ed749 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 22 Jul 2025 10:30:56 +0800 Subject: [PATCH 03/11] fix decode bug --- .../CosyVoice2/300I/modeling_qwen2.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py index 94ad8d5a94..9ba6e01e21 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py @@ -263,17 +263,18 @@ class Qwen2SdpaAttention(Qwen2Attention): attn_output = attn_output[:, :q_len, :, :] else: # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() attn_output = torch_npu.npu_incre_flash_attention(query_states, - key_states.contiguous(), - value_states.contiguous(), - num_heads=self.num_heads, - input_layout="BNSD", - scale_value=1 / math.sqrt(self.head_dim), - atten_mask=None, - actual_seq_lengths=actual_seq_len, - num_key_value_heads=self.num_key_value_heads) + key_states, + value_states, + num_heads=self.num_heads, + input_layout="BNSD", + scale_value=1 / math.sqrt(self.head_dim), + atten_mask=None, + actual_seq_lengths=actual_seq_len, + num_key_value_heads=self.num_key_value_heads) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) -- Gitee From 99f89e1cbdd8042f51c6701957c8a505ce47035d Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 22 Jul 2025 17:02:44 +0800 Subject: [PATCH 04/11] fix bug --- .../audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py index 9ba6e01e21..185df71213 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py @@ -241,11 +241,11 @@ class Qwen2SdpaAttention(Qwen2Attention): # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 kv_len = key_states.shape[1] # 310P 仅支持qSeqlen = kvSeqlen - q_pad_len = (kv_len + 15) // 16 * 16 - q_len if kv_len > 32 else 32 + q_pad_len = (kv_len + 15) // 16 * 16 - q_len if kv_len > 32 else 32 - q_len if q_pad_len > 0: q_pad = torch.zeros((bsz, q_pad_len, self.num_heads, self.head_dim), dtype=query_states.dtype, device=query_states.device) query_states = torch.cat([query_states, q_pad], dim=1).contiguous() - kv_pad_len = (kv_len + 15) // 16 * 16 - kv_len if kv_len > 32 else 32 + kv_pad_len = (kv_len + 15) // 16 * 16 - kv_len if kv_len > 32 else 32 - kv_len if kv_pad_len > 0: kv_pad = torch.zeros((bsz, kv_pad_len, self.num_key_value_heads, self.head_dim), dtype=key_states.dtype, device=key_states.device) key_states = torch.cat([key_states, kv_pad], dim=1) -- Gitee From 57d095d548677aca7d19d5061ee621a60f6d2eaa Mon Sep 17 00:00:00 2001 From: pu-zhe Date: 
Tue, 22 Jul 2025 20:07:15 +0800 Subject: [PATCH 05/11] optimize decode performance --- .../CosyVoice2/300I/modeling_qwen2.py | 35 ++++++++++--------- .../audio/CosyVoice/CosyVoice2/README.md | 2 ++ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py index 185df71213..fa5aa51f72 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py @@ -220,8 +220,10 @@ class Qwen2SdpaAttention(Qwen2Attention): if use_cache and past_key_value is not None: # 把计算好的kv值更新到kv cahce中 tmp_ids = updated_kv_positions.reshape(-1) - torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 1) - torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 1) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 2) + torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 2) # 流式输入场景,decode阶段 if q_len == 1: key_states = past_key_value[self.layer_idx][0] @@ -232,40 +234,38 @@ class Qwen2SdpaAttention(Qwen2Attention): value_states = value_states # 流式输入场景,后续prefill阶段 elif q_len < actual_seq_len[0]: - key_states = past_key_value.key_cache[self.layer_idx][:, :actual_seq_len[0]] - value_states = past_key_value.value_cache[self.layer_idx][:, :actual_seq_len[0]] + key_states = past_key_value.key_cache[self.layer_idx][:, :, :actual_seq_len[0]] + value_states = past_key_value.value_cache[self.layer_idx][:, :, :actual_seq_len[0]] else: raise ValueError(f"Unexpected q_len: {q_len}, actual_seq_len[0]: {actual_seq_len[0]}") + query_states = query_states.transpose(1, 2).contiguous() if q_len > 1: # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 - kv_len = key_states.shape[1] + kv_len = key_states.shape[2] # 310P 仅支持qSeqlen = kvSeqlen q_pad_len = (kv_len + 15) // 16 * 16 - q_len if kv_len > 32 else 32 - q_len if q_pad_len > 0: - q_pad = torch.zeros((bsz, q_pad_len, self.num_heads, self.head_dim), dtype=query_states.dtype, device=query_states.device) - query_states = torch.cat([query_states, q_pad], dim=1).contiguous() + q_pad = torch.zeros((bsz, self.num_heads, q_pad_len, self.head_dim), dtype=query_states.dtype, device=query_states.device) + query_states = torch.cat([query_states, q_pad], dim=2).contiguous() kv_pad_len = (kv_len + 15) // 16 * 16 - kv_len if kv_len > 32 else 32 - kv_len if kv_pad_len > 0: - kv_pad = torch.zeros((bsz, kv_pad_len, self.num_key_value_heads, self.head_dim), dtype=key_states.dtype, device=key_states.device) - key_states = torch.cat([key_states, kv_pad], dim=1) - value_states = torch.cat([value_states, kv_pad], dim=1) + kv_pad = torch.zeros((bsz, self.num_key_value_heads, kv_pad_len, self.head_dim), dtype=key_states.dtype, device=key_states.device) + key_states = torch.cat([key_states, kv_pad], dim=2) + value_states = torch.cat([value_states, kv_pad], dim=2) attn_output = torch_npu.npu_prompt_flash_attention(query_states, key_states.contiguous(), value_states.contiguous(), num_heads=self.num_heads, - input_layout="BSND", + input_layout="BNSD", scale_value=1 / math.sqrt(self.head_dim), pre_tokens=65535, next_tokens=0, atten_mask=attention_mask, sparse_mode=1, num_key_value_heads=self.num_key_value_heads) - 
attn_output = attn_output[:, :q_len, :, :] + attn_output = attn_output[:, :, :q_len, :].transpose(1, 2) else: # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 - query_states = query_states.transpose(1, 2).contiguous() - key_states = key_states.transpose(1, 2).contiguous() - value_states = value_states.transpose(1, 2).contiguous() attn_output = torch_npu.npu_incre_flash_attention(query_states, key_states, value_states, @@ -662,9 +662,10 @@ class Qwen2Model(Qwen2PreTrainedModel): use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: kv_shape = ( - batch_size, self.config.max_position_embeddings, + batch_size, self.config.num_key_value_heads, - self.config.hidden_size // self.config.num_attention_heads) # (1, 32768, 2, 64) + self.config.max_position_embeddings, + self.config.hidden_size // self.config.num_attention_heads) # (1, 2, 32768, 64) past_key_values = () for _ in range(self.config.num_hidden_layers): k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 8261e80042..72639a4df5 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -194,6 +194,8 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 * 非流式输入:将推理结果保存在`sft_i.wav`中,并打屏性能数据:实时率(rtf),指的是平均1s时长的音频需要多少时间处理。 * 流式输入:将推理结果保存在`stream_input_out_i.wav`文件中,并打屏性能数据:实时率(rtf) + 3. 更换推理模式后需将已生成的.torchair_cache路径删除,避免场景不同但图复用导致的前向出错。 + ### 3 性能数据 | 模型 |芯片|rtf(实时率)| -- Gitee From f337101f80f6713b8221ea9780ae57b015bd0a18 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 24 Jul 2025 15:32:44 +0800 Subject: [PATCH 06/11] fix bugs --- .../CosyVoice2/300I/diff_CosyVoice_300I.patch | 82 ++++++++++--------- .../CosyVoice2/300I/modeling_qwen2.py | 59 +++++++------ 2 files changed, 74 insertions(+), 67 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch index c6bf0eefef..6e519c3f6e 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch @@ -172,10 +172,18 @@ index 6e10f00..25ad767 100644 speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) return speech_token, speech_token_len diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py -index 9ebf8cb..3db0b4f 100644 +index 9ebf8cb..7b8aac4 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py -@@ -99,7 +99,7 @@ class CosyVoiceModel: +@@ -14,6 +14,7 @@ + import os + from typing import Generator + import torch ++import torch_npu + import numpy as np + import threading + import time +@@ -99,7 +100,7 @@ class CosyVoiceModel: self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): @@ -184,7 +192,27 @@ index 9ebf8cb..3db0b4f 100644 if isinstance(text, Generator): assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' 
for i in self.llm.inference_bistream(text=text, -@@ -307,13 +307,25 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -278,6 +279,19 @@ class CosyVoiceModel: + self.hift_cache_dict.pop(this_uuid) + torch.cuda.empty_cache() + ++ def _weight_format_cast(self, model: torch.nn.Module): ++ def _cast_to_internal_format(module: torch.nn.Module, class_name): ++ if issubclass(class_name, torch.nn.Linear): ++ if module.weight.data.is_cpu: ++ raise RuntimeError("FRACTAL_NZ is only supported on NPU tensor.") ++ module.weight.data = torch_npu.npu_format_cast(module.weight.data, 29) ++ current_class = model.__class__ ++ _cast_to_internal_format(model, current_class) ++ if not model.children: ++ return ++ for sub_module in model.children(): ++ if isinstance(sub_module, torch.nn.Module): ++ self._weight_format_cast(sub_module) + + class CosyVoice2Model(CosyVoiceModel): + +@@ -307,13 +321,25 @@ class CosyVoice2Model(CosyVoiceModel): self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related self.stream_scale_factor = 1 @@ -211,7 +239,7 @@ index 9ebf8cb..3db0b4f 100644 def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder -@@ -409,3 +421,83 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -409,3 +435,83 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) torch.cuda.empty_cache() @@ -462,7 +490,7 @@ index c47bf05..7dd9fb0 100644 + generated_speech = torch.clamp(x, -self.audio_limit, self.audio_limit) return generated_speech, s diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py -index bbd3305..42d1ccf 100644 +index bbd3305..1380dad 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -11,6 +11,7 @@ @@ -517,7 +545,7 @@ index bbd3305..42d1ccf 100644 @torch.inference_mode() def inference( self, -@@ -318,9 +329,17 @@ class Qwen2LM(TransformerLM): +@@ -318,9 +329,16 @@ class Qwen2LM(TransformerLM): # 5. 
step by step decode out_tokens = [] cache = None @@ -525,8 +553,7 @@ index bbd3305..42d1ccf 100644 for i in range(max_len): + prompt_length = input_length + i + if i == 0: -+ seqlen_align = (lm_input.shape[1] + 15) // 16 * 16 if lm_input.shape[1] > 32 else 32 -+ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool).logical_not() ++ masks = torch.triu(torch.ones((1, 1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device), diagonal=1).to(lm_input.dtype) * -10000.0 + else: + masks = None y_pred, cache = self.llm.forward_one_step(lm_input, @@ -536,7 +563,7 @@ index bbd3305..42d1ccf 100644 cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() -@@ -331,7 +350,7 @@ class Qwen2LM(TransformerLM): +@@ -331,7 +349,7 @@ class Qwen2LM(TransformerLM): # in stream mode, yield token one by one yield top_ids out_tokens.append(top_ids) @@ -545,31 +572,7 @@ index bbd3305..42d1ccf 100644 @torch.inference_mode() def inference_bistream( -@@ -392,8 +411,10 @@ class Qwen2LM(TransformerLM): - continue - while True: - seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) -+ seqlen_align = (seq_len + 15) // 16 * 16 if seq_len > 32 else 32 -+ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool) - y_pred, cache = self.llm.forward_one_step(lm_input, -- masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), -+ masks=masks, - cache=cache) - logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) - if next_fill_index != -1 and len(out_tokens) == next_fill_index: -@@ -418,8 +439,10 @@ class Qwen2LM(TransformerLM): - logging.info('no more text token, decode until met eos') - while True: - seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2) -+ seqlen_align = (seq_len + 15) // 16 * 16 if seq_len > 32 else 32 -+ masks = torch.tril(torch.ones((1, seqlen_align, seqlen_align), device=lm_input.device)).to(torch.bool) - y_pred, cache = self.llm.forward_one_step(lm_input, -- masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool), -+ masks=masks, - cache=cache) - logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) - top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False).item() -@@ -432,3 +455,142 @@ class Qwen2LM(TransformerLM): +@@ -432,3 +450,141 @@ class Qwen2LM(TransformerLM): # in stream mode, yield token one by one yield top_ids lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) @@ -591,11 +594,10 @@ index bbd3305..42d1ccf 100644 + min_token_text_ratio: float = 2, + ) -> Generator[torch.Tensor, None, None]: + -+ def build_causal_mask(query_len, key_len, devices): ++ def build_causal_mask(query_len, key_len, devices, dtype): + assert key_len >= query_len -+ key_len = (key_len + 15) // 16 * 16 if key_len > 32 else 32 -+ causal_mask = torch.triu(torch.ones((key_len, key_len), device=devices), diagonal=(key_len - query_len) + 1).to(torch.bool) -+ return causal_mask.unsqueeze(0) ++ causal_mask = torch.triu(torch.ones((1, 1, query_len, key_len), device=devices), diagonal=(key_len - query_len) + 1).to(dtype) * -10000.0 ++ return causal_mask + + device = prompt_text.device + @@ -651,7 +653,7 @@ index bbd3305..42d1ccf 100644 + seq_len = self.prompt_length[uuid] + if 
self.lm_input_dict[uuid].shape[1] > 1: + masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, -+ self.lm_input_dict[uuid].device) ++ self.lm_input_dict[uuid].device, self.lm_input_dict[uuid].dtype) + else: + masks = None + y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], @@ -686,7 +688,7 @@ index bbd3305..42d1ccf 100644 + self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] + seq_len = self.prompt_length[uuid] + if self.lm_input_dict[uuid].shape[1] > 1: -+ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, self.lm_input_dict[uuid].device) ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, self.lm_input_dict[uuid].device, self.lm_input_dict[uuid].dtype) + else: + masks = None + y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py index fa5aa51f72..36cbb67912 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py @@ -13,6 +13,7 @@ import math import warnings from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig @@ -142,6 +143,7 @@ class Qwen2Attention(nn.Module): self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads + self.scale = 1 / math.sqrt(self.head_dim) self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings @@ -173,6 +175,16 @@ class Qwen2SdpaAttention(Qwen2Attention): `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. 
""" + def group_mm_torch(self, heads, group_num, A, B): + group_head = heads // group_num + score = None + for i in range(group_num): + group_score = torch.matmul(A[i * group_head: (i + 1) * group_head, :, :], B[i: (i + 1), :, :]) + if score is None: + score = group_score + else: + score = torch.cat((score, group_score), 0) + return score # 优化Attention部分逻辑,替换torch_npu算子 def forward( @@ -217,11 +229,13 @@ class Qwen2SdpaAttention(Qwen2Attention): rotary_emb_cos.to(value_states.dtype), rotary_emb_sin.to(value_states.dtype)) + query_states = query_states.transpose(1, 2) # BNSD + key_states = key_states.transpose(1, 2) # BNSD + value_states = value_states.transpose(1, 2) # BNSD + if use_cache and past_key_value is not None: # 把计算好的kv值更新到kv cahce中 tmp_ids = updated_kv_positions.reshape(-1) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 2) torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 2) # 流式输入场景,decode阶段 @@ -239,31 +253,22 @@ class Qwen2SdpaAttention(Qwen2Attention): else: raise ValueError(f"Unexpected q_len: {q_len}, actual_seq_len[0]: {actual_seq_len[0]}") - query_states = query_states.transpose(1, 2).contiguous() if q_len > 1: - # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 - kv_len = key_states.shape[2] - # 310P 仅支持qSeqlen = kvSeqlen - q_pad_len = (kv_len + 15) // 16 * 16 - q_len if kv_len > 32 else 32 - q_len - if q_pad_len > 0: - q_pad = torch.zeros((bsz, self.num_heads, q_pad_len, self.head_dim), dtype=query_states.dtype, device=query_states.device) - query_states = torch.cat([query_states, q_pad], dim=2).contiguous() - kv_pad_len = (kv_len + 15) // 16 * 16 - kv_len if kv_len > 32 else 32 - kv_len - if kv_pad_len > 0: - kv_pad = torch.zeros((bsz, self.num_key_value_heads, kv_pad_len, self.head_dim), dtype=key_states.dtype, device=key_states.device) - key_states = torch.cat([key_states, kv_pad], dim=2) - value_states = torch.cat([value_states, kv_pad], dim=2) - attn_output = torch_npu.npu_prompt_flash_attention(query_states, - key_states.contiguous(), - value_states.contiguous(), - num_heads=self.num_heads, - input_layout="BNSD", - scale_value=1 / math.sqrt(self.head_dim), - pre_tokens=65535, next_tokens=0, - atten_mask=attention_mask, - sparse_mode=1, - num_key_value_heads=self.num_key_value_heads) - attn_output = attn_output[:, :, :q_len, :].transpose(1, 2) + attn_output = None + for idx in range(bsz): + q_slice = query_states[idx] + k_slice = key_states[idx].transpose(1, 2) + v_slice = value_states[idx] + score = self.group_mm_torch(self.num_heads, self.num_key_value_heads, q_slice, k_slice) + score = score * self.scale + score = score + attention_mask[idx] + score = F.softmax(score, dim=-1, dtype=torch.float32).to(q_slice.dtype) + out = self.group_mm_torch(self.num_heads, self.num_key_value_heads, score, v_slice).unsqueeze(0) + if attn_output is None: + attn_output = out + else: + attn_output = torch.cat((attn_output, out), dim=0) + attn_output = attn_output.transpose(1, 2) else: # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 attn_output = torch_npu.npu_incre_flash_attention(query_states, @@ -271,7 +276,7 @@ class Qwen2SdpaAttention(Qwen2Attention): value_states, num_heads=self.num_heads, input_layout="BNSD", - scale_value=1 / math.sqrt(self.head_dim), + scale_value=self.scale, atten_mask=None, actual_seq_lengths=actual_seq_len, 
num_key_value_heads=self.num_key_value_heads) -- Gitee From 1739c7d1018f1dd30fa8d6831b5b9d5baaaa15dd Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 25 Jul 2025 16:14:50 +0800 Subject: [PATCH 07/11] separate commit --- .../CosyVoice2/300I/modeling_qwen2.py | 926 ------------------ 1 file changed, 926 deletions(-) delete mode 100644 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py deleted file mode 100644 index 36cbb67912..0000000000 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/modeling_qwen2.py +++ /dev/null @@ -1,926 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd -# [Software Name] is licensed under Mulan PSL v2. -# You can use this software according to the terms and conditions of the Mulan PSL v2. -# You may obtain a copy of Mulan PSL v2 at: -# http://license.coscl.org.cn/MulanPSL2 -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, -# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, -# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -# See the Mulan PSL v2 for more details. -""" PyTorch Qwen2 model.""" - -import math -import warnings -from typing import List, Optional, Tuple, Union -import torch -import torch.nn.functional as F -import torch_npu -import torchair as tng -from torchair.configs.compiler_config import CompilerConfig -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - logging -) -from .configuration_qwen2 import Qwen2Config - - -logger = logging.get_logger(__name__) - - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", -] - - -# Ascend优化:Add/Norm昇腾自定义融合算子 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, - hidden_states, - residual: Optional[torch.Tensor] = None): - if residual is None: - return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0], hidden_states - else: - y, _, x = torch_npu.npu_add_rms_norm(residual, hidden_states, self.weight, self.variance_epsilon) - return y, x - - -# Ascend优化:提前计算位置编码,无需在每层layer中重复计算 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x=None, seq_len=None): - if x is None and seq_len is None: - return self.cos_cached, self.sin_cached - - return ( - self.cos_cached.to(dtype=x.dtype), - self.sin_cached.to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin): - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.scale = 1 / math.sqrt(self.head_dim) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - -# Ascend优化:PFA/IFA自定义算子替换,kv cache固定shape并在指定位置更新 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - def group_mm_torch(self, heads, group_num, A, B): - group_head = heads // group_num - score = None - for i in range(group_num): - group_score = torch.matmul(A[i * group_head: (i + 1) * group_head, :, :], B[i: (i + 1), :, :]) - if score is None: - score = group_score - else: - score = torch.cat((score, group_score), 0) - return score - - # 优化Attention部分逻辑,替换torch_npu算子 - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - rotary_emb_cos: Optional[torch.Tensor] = None, - rotary_emb_sin: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - - # 利用已经提前计算好的位置编码数据对q,k值进行更新 - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - rotary_emb_cos.to(value_states.dtype), - rotary_emb_sin.to(value_states.dtype)) - - query_states = query_states.transpose(1, 2) # BNSD - key_states = key_states.transpose(1, 2) # BNSD - value_states = value_states.transpose(1, 2) # BNSD - - if use_cache and past_key_value is not None: - # 把计算好的kv值更新到kv cahce中 - tmp_ids = updated_kv_positions.reshape(-1) - torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 2) - torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 2) - # 流式输入场景,decode阶段 - if q_len == 1: - key_states = past_key_value[self.layer_idx][0] - value_states = past_key_value[self.layer_idx][1] - # 流式输入场景,首次prefill阶段 - elif q_len == actual_seq_len[0]: - key_states = key_states - value_states = value_states - # 流式输入场景,后续prefill阶段 - elif q_len < actual_seq_len[0]: - key_states = past_key_value.key_cache[self.layer_idx][:, :, :actual_seq_len[0]] - value_states = past_key_value.value_cache[self.layer_idx][:, :, :actual_seq_len[0]] - else: - raise ValueError(f"Unexpected q_len: {q_len}, actual_seq_len[0]: {actual_seq_len[0]}") - - if q_len > 1: - attn_output = None - for idx in range(bsz): - q_slice = query_states[idx] - k_slice = key_states[idx].transpose(1, 2) - v_slice = value_states[idx] - score = self.group_mm_torch(self.num_heads, self.num_key_value_heads, q_slice, k_slice) - score = score * self.scale - score = score + attention_mask[idx] - score = F.softmax(score, dim=-1, dtype=torch.float32).to(q_slice.dtype) - out = self.group_mm_torch(self.num_heads, self.num_key_value_heads, score, v_slice).unsqueeze(0) - if attn_output is None: - attn_output = out - else: - attn_output = torch.cat((attn_output, out), dim=0) - attn_output = attn_output.transpose(1, 2) - else: - # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 - attn_output = torch_npu.npu_incre_flash_attention(query_states, - key_states, - value_states, - num_heads=self.num_heads, - input_layout="BNSD", - scale_value=self.scale, - atten_mask=None, - actual_seq_lengths=actual_seq_len, - num_key_value_heads=self.num_key_value_heads) - attn_output = attn_output.transpose(1, 2) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "sdpa": Qwen2SdpaAttention, -} - - -# Ascend优化:每层layer的前后rms替换为昇腾自定义算子 -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; 
" - "unexpected results may be encountered." - ) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - past_residual: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - rotary_emb_cos: Optional[torch.Tensor] = None, - rotary_emb_sin: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - # rms计算替换为昇腾自定义融合算子 - hidden_states, residual = self.input_layernorm(hidden_states, past_residual) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - updated_kv_positions=updated_kv_positions, - actual_seq_len=actual_seq_len, - rotary_emb_cos=rotary_emb_cos, - rotary_emb_sin=rotary_emb_sin, - use_cache=use_cache, - ) - - # rms计算替换为昇腾自定义融合算子 - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) - - outputs = (residual, hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - 
module.weight.data[module.padding_idx].zero_() - - -# Ascend优化:forward函数利用torchair编译为图模式,利用cache接口避免重复编译 -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.max_position_embeddings = config.max_position_embeddings - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.rope_theta = config.rope_theta - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - # torchair编译参数,编译Qwen2Model的forward部分 - config = CompilerConfig() - config.experimental_config.frozen_parameter = True - # tiling下沉,主要针对IFA算子,使其算子tiling操作在AICPU上执行 - config.experimental_config.tiling_schedule_optimize = True - - # torchair的cache编译,保证模型编译cache文件,避免重复推理 - self.cached_decode = tng.inference.cache_compile(self.decode, config=config) - self.cached_first_prefill = tng.inference.cache_compile(self.first_prefill, config=config) # 用于首次prefill,无kv_cache - self.cached_next_prefill = tng.inference.cache_compile(self.next_prefill, config=config) # 用于后续prefill,有kv_cache - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def _prepare_decoder_rotary_cos_sin(self, position_ids): - cos, sin = self.rotary_emb() - f_position_ids = position_ids.flatten() - cos = torch.index_select(cos, 0, f_position_ids) - sin = torch.index_select(sin, 0, f_position_ids) - cos = cos.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) - sin = sin.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) - return cos, sin - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - lm_head: Optional[object] = None - ) -> Union[Tuple, BaseModelOutputWithPast]: - # prefill_1, prefill_2和decode需要编译为3个不同的模型 - seq_len = inputs_embeds.size(1) - - # 流式输入场景,首次prefill - if seq_len > 1 and seq_len == actual_seq_len[0]: - return self.cached_first_prefill( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - 
lm_head - ) - - # 流式输入场景,后续prefill - elif 1 < seq_len < actual_seq_len[0]: - return self.cached_next_prefill( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - lm_head - ) - - # 流式输入场景,decode - else: - return self.cached_decode( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - lm_head - ) - - def decode( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - lm_head: Optional[object] = None - ): - return self._forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - lm_head - ) - - def first_prefill( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - lm_head: Optional[object] = None - ): - return self._forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - lm_head - ) - - def next_prefill( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - lm_head: Optional[object] = None - ): - return self._forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - updated_kv_positions, - actual_seq_len, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - lm_head - ) - - - def _forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - updated_kv_positions: Optional[torch.LongTensor] = None, - actual_seq_len: Optional[list] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] 
= None, - return_dict: Optional[bool] = None, - lm_head: Optional[object] = None - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - - # prefill阶段初始化kv cache,decode阶段对kv cache进行更新 - # 固定kv cache为最大shape,避免内存的重复申请和拷贝,也保证了模型的静态shape,可整图下发推理 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - kv_shape = ( - batch_size, - self.config.num_key_value_heads, - self.config.max_position_embeddings, - self.config.hidden_size // self.config.num_attention_heads) # (1, 2, 32768, 64) - past_key_values = () - for _ in range(self.config.num_hidden_layers): - k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) - v_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) - past_key_values += ((k_cache, v_cache),) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - past_key_values_length = self.max_position_embeddings if actual_seq_len[0] > inputs_embeds.shape[1] else 0 - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) # tensor([0, 1, 2, 3, 4, 5], device='npu:0') - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - hidden_states = inputs_embeds - - # 此处统一计算位置编码,在每个layer中取对应位置的值 - rotary_emb_cos, rotary_emb_sin = self._prepare_decoder_rotary_cos_sin(position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - residual = None - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # 执行layer层推理 - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - past_residual=residual, - position_ids=position_ids, - past_key_value=past_key_values, - updated_kv_positions=updated_kv_positions, - actual_seq_len=actual_seq_len, - rotary_emb_cos=rotary_emb_cos, - rotary_emb_sin=rotary_emb_sin, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - residual = layer_outputs[0] - hidden_states = layer_outputs[1] - - if use_cache: - next_decoder_cache = layer_outputs[3 if output_attentions else 2] - - if output_attentions: - all_self_attns += (layer_outputs[2],) - - # norm计算,此处替换为昇腾融合算子 - hidden_states, _ = self.norm(hidden_states, residual) - - # add hidden states from the last decoder layer - if 
output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - out = BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - hidden_states = out[0] - # 由于logits最后也只取[:,-1,:],相当于只取最新seq位置上的数据,l - # 所以在全量的最后线性层计算可以只对最新的seq位置做计算,降低计算量 - bs, seq, hidden = hidden_states.size() - if seq > 1: - gather_index = torch.ones(bs, dtype=torch.int64, device=hidden_states.device) * (seq - 1) - gather_index = gather_index.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, 1, hidden) - hidden_states = torch.gather(hidden_states, 1, gather_index) - logits = lm_head(hidden_states) - logits = logits.float() - return out, logits - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - prompt_length: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None - ) -> Union[Tuple, CausalLMOutputWithPast]: - """ - 对CosyVoice2模型中使用的Qwen模型进行昇腾适配优化,具体优化点有: - 1. 固定KV CACHE大小,避免重复申请内存和拷贝 - 2. 替换部分算子为昇腾自定义算子 - 3. 首层计算位置编码避免重复计算 - 4. 在decode阶段,固定输入shape大小,保证整图下发 - - 模型有以下输入: - 1. attention_mask - 2. inputs_embeds:CosyVoice会把inputs_ids处理embeding后输入模型 - 3. past_key_values:kv cache,在每次推理后会进行更新 - 4. position_ids:位置id,在每次推理后会进行更新 - 5. 
prompt_length:实际输入长度,在prefill阶段为首token长度,后续每次推理长度加1 - """ - - # 每次推理前对输入数据进行昇腾适配处理,处理为昇腾自定义算子所需类型参数 - updated_kv_positions, past_key_values, position_ids, actual_seq_len = self.prepare_data(inputs_embeds, past_key_values, prompt_length) - - model_inputs = { - "inputs_embeds": inputs_embeds, - "past_key_values": past_key_values, - "position_ids": position_ids, - "actual_seq_len": actual_seq_len, - "attention_mask": attention_mask, - } - - # prefill阶段由于输出token长度不固定,为动态shape推理。decode阶段把输入固定为静态,保证整图静态推理。 - if inputs_embeds.shape[1] == 1: - self._mark_model_inputs_static(model_inputs) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # 主要推理阶段,利用torchair编译为整图推理 - outputs, logits = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - updated_kv_positions=updated_kv_positions, - actual_seq_len=actual_seq_len, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - lm_head=self.lm_head - ) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - # Ascend优化:把数据输入处理为Ascend优化所需要的格式和类型 - def prepare_data(self, inputs_embeds, past_key_values, prompt_length): - bsz = inputs_embeds.shape[0] - seq_length = inputs_embeds.shape[1] - if past_key_values: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - # 流式输入场景,首次prefill - if seq_length > 1 and prompt_length == inputs_embeds.shape[1]: - updated_kv_positions = torch.zeros(bsz, dtype=torch.long, device=inputs_embeds.device) - position_ids = None - # 流式输入场景,后续prefill - elif seq_length > 1 and prompt_length > inputs_embeds.shape[1]: - # updated_kv_positions,kv_cache需要更新的起始位置 - updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - inputs_embeds.shape[1]) - tmp_head = prompt_length - inputs_embeds.shape[1] - tmp_tail = prompt_length - position_ids = torch.arange(tmp_head, tmp_tail, dtype=torch.long, device=inputs_embeds.device) - # 流式输入场景,decode - else: - updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - 1) - position_ids = torch.tensor([prompt_length - 1], device=inputs_embeds.device) - - # ifa Computational optimization inputs - actual_seq_len = ([prompt_length]) - - return updated_kv_positions, past_key_values, position_ids, actual_seq_len - - # Ascend优化:固定input shape,使能静态推理,模型整图下发 - def _mark_model_inputs_static(self, model_inputs): - for key, value in model_inputs.items(): - 
if key == "past_key_values" and value is not None: - for i in range(self.config.num_hidden_layers): - torch._dynamo.mark_static(value[i][0]) - torch._dynamo.mark_static(value[i][1]) - elif isinstance(value, torch.Tensor): - torch._dynamo.mark_static(value) - -- Gitee From e175c93fe9ac26ad06a63859bff4711fe1b4a725 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Sat, 26 Jul 2025 10:25:45 +0800 Subject: [PATCH 08/11] optimize infer script --- .../CosyVoice2/300I/diff_CosyVoice_300I.patch | 85 +++---------------- .../audio/CosyVoice/CosyVoice2/300I/infer.py | 17 +++- .../audio/CosyVoice/CosyVoice2/README.md | 6 +- 3 files changed, 29 insertions(+), 79 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch index 6e519c3f6e..a760ffd311 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/diff_CosyVoice_300I.patch @@ -1,5 +1,5 @@ diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py -index e2d62e2..ac45938 100644 +index e2d62e2..0d4f860 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -13,11 +13,14 @@ @@ -22,10 +22,10 @@ index e2d62e2..ac45938 100644 start_time = time.time() logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): -+ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): ++ for idx, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) -+ if i == 0: ++ if idx == 0: + logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) + else: + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) @@ -37,60 +37,17 @@ index e2d62e2..ac45938 100644 start_time = time.time() logging.info('synthesis text {}'.format(i)) - for model_output in self.model.tts(**model_input, stream=stream, speed=speed): -+ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): ++ for idx, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) -+ if i == 0: ++ if idx == 0: + logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) + else: + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) yield model_output start_time = time.time() -@@ -93,9 +102,12 @@ class CosyVoice: - model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) -- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): -+ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate -- logging.info('yield speech len {}, rtf {}'.format(speech_len, 
(time.time() - start_time) / speech_len)) -+ if i == 0: -+ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) -+ else: -+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - -@@ -108,25 +120,31 @@ class CosyVoice: - model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) - start_time = time.time() - logging.info('synthesis text {}'.format(i)) -- for model_output in self.model.tts(**model_input, stream=stream, speed=speed): -+ for i, model_output in enumerate(self.model.tts(**model_input, stream=stream, speed=speed)): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate -- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) -+ if i == 0: -+ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) -+ else: -+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - - def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): - model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate) - start_time = time.time() -- for model_output in self.model.vc(**model_input, stream=stream, speed=speed): -+ for i, model_output in enumerate(self.model.vc(**model_input, stream=stream, speed=speed)): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate -- logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) -+ if i == 0: -+ logging.info('yield speech len {}, rtf {}, TTFT {}'.format(speech_len, (time.time() - start_time) / speech_len, time.time() - start_time)) -+ else: -+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - +@@ -126,7 +135,7 @@ class CosyVoice: class CosyVoice2(CosyVoice): @@ -99,7 +56,7 @@ index e2d62e2..ac45938 100644 self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 -@@ -155,6 +173,16 @@ class CosyVoice2(CosyVoice): +@@ -155,6 +164,16 @@ class CosyVoice2(CosyVoice): self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), self.fp16) @@ -116,7 +73,7 @@ index e2d62e2..ac45938 100644 del configs def inference_instruct(self, *args, **kwargs): -@@ -171,3 +199,19 @@ class CosyVoice2(CosyVoice): +@@ -171,3 +190,19 @@ class CosyVoice2(CosyVoice): logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) yield model_output start_time = time.time() @@ -172,7 +129,7 @@ index 6e10f00..25ad767 100644 speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) return speech_token, speech_token_len diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py -index 9ebf8cb..7b8aac4 100644 +index 9ebf8cb..407f1ae 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -14,6 +14,7 @@ @@ -192,27 +149,7 @@ index 9ebf8cb..7b8aac4 100644 if isinstance(text, Generator): assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' 
for i in self.llm.inference_bistream(text=text, -@@ -278,6 +279,19 @@ class CosyVoiceModel: - self.hift_cache_dict.pop(this_uuid) - torch.cuda.empty_cache() - -+ def _weight_format_cast(self, model: torch.nn.Module): -+ def _cast_to_internal_format(module: torch.nn.Module, class_name): -+ if issubclass(class_name, torch.nn.Linear): -+ if module.weight.data.is_cpu: -+ raise RuntimeError("FRACTAL_NZ is only supported on NPU tensor.") -+ module.weight.data = torch_npu.npu_format_cast(module.weight.data, 29) -+ current_class = model.__class__ -+ _cast_to_internal_format(model, current_class) -+ if not model.children: -+ return -+ for sub_module in model.children(): -+ if isinstance(sub_module, torch.nn.Module): -+ self._weight_format_cast(sub_module) - - class CosyVoice2Model(CosyVoiceModel): - -@@ -307,13 +321,25 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -307,13 +308,25 @@ class CosyVoice2Model(CosyVoiceModel): self.speech_window = np.hamming(2 * self.source_cache_len) # rtf and decoding related self.stream_scale_factor = 1 @@ -239,7 +176,7 @@ index 9ebf8cb..7b8aac4 100644 def load_jit(self, flow_encoder_model): flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) self.flow.encoder = flow_encoder -@@ -409,3 +435,83 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -409,3 +422,83 @@ class CosyVoice2Model(CosyVoiceModel): self.tts_speech_token_dict.pop(this_uuid) self.llm_end_dict.pop(this_uuid) torch.cuda.empty_cache() diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py index 72d584396f..3f17564040 100644 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py @@ -9,7 +9,7 @@ # See the Mulan PSL v2 for more details. 
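The hunks above wire the ATC-converted OM models into CosyVoice through `ais_bench`'s `InferSession` (one session for the speech tokenizer, one for the flow decoder estimator). For orientation, a minimal standalone sketch of that call pattern follows; the `speech.om` path and the 200-frame input are illustrative, and the `mode="dymshape"` / `custom_sizes` keywords are assumed from ais_bench's usual dynamic-shape usage rather than copied from this patch.

```
# Standalone sketch, not the patched CosyVoice code.
# Assumption: InferSession.infer() takes a list of NumPy arrays in the model's
# declared input order; for a model converted with a free dynamic axis, the
# dynamic-shape mode plus an output-size hint is normally required.
import numpy as np
from ais_bench.infer.interface import InferSession

session = InferSession(0, "./CosyVoice2-0.5B/speech.om")    # device_id, OM path

feats = np.random.randn(1, 128, 200).astype(np.float32)     # (batch, mel_bins, frames); frames is the dynamic axis
feats_length = np.array([200], dtype=np.int32)               # valid frame count

outputs = session.infer([feats, feats_length], mode="dymshape", custom_sizes=10_000_000)
speech_token = outputs[0]                                     # discrete speech tokens
print(speech_token.shape)
```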
import argparse -from tqdm import tqdm +import time import torch import torchaudio import torch_npu @@ -17,7 +17,6 @@ from torch_npu.contrib import transfer_to_npu import torchair as tng from torchair.configs.compiler_config import CompilerConfig from cosyvoice.cli.cosyvoice import CosyVoice2 -from cosyvoice.utils.file_utils import load_wav def no_stream_input_inference(args, cosyvoice, prompt_txt): @@ -28,11 +27,18 @@ def no_stream_input_inference(args, cosyvoice, prompt_txt): pass print('warm up end') infer_res = [torch.tensor([]) for _ in range(args.infer_count)] + rtf = [] for i_step in range(args.infer_count): + start_time = time.time() for _, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + end_time = time.time() + speech_len = infer_res[i_step].shape[1] / cosyvoice.sample_rate + print(f"singe infer RTF: {(end_time - start_time) / speech_len}") + rtf.append((end_time - start_time) / speech_len) print(f"save out wav file to sft_out_{i_step+1}.wav") torchaudio.save(f"sft_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate) + print(f"avg RTF: {sum(rtf) / len(rtf)}") def stream_input_inference(args, cosyvoice, prompt_txt): @@ -64,10 +70,17 @@ def stream_input_inference(args, cosyvoice, prompt_txt): print("warm up end") print("inference start") + rtf = [] for i_step in range(args.infer_count): + start_time = time.time() inference_step(i_step, mode="inference") + end_time = time.time() + speech_len = infer_res[i_step].shape[1] / cosyvoice.sample_rate + print(f"avg RTF: {(end_time - start_time) / speech_len}") + rtf.append((end_time - start_time) / speech_len) print(f"save out wav file to stream_input_out_{i_step+1}.wav") torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], cosyvoice.sample_rate) + print(f"avg RTF: {sum(rtf) / len(rtf)}") print("inference end") diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 72639a4df5..94e79df866 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -194,12 +194,12 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 * 非流式输入:将推理结果保存在`sft_i.wav`中,并打屏性能数据:实时率(rtf),指的是平均1s时长的音频需要多少时间处理。 * 流式输入:将推理结果保存在`stream_input_out_i.wav`文件中,并打屏性能数据:实时率(rtf) - 3. 更换推理模式后需将已生成的.torchair_cache路径删除,避免场景不同但图复用导致的前向出错。 + 3. 
如因为意外操作导致torchair编译失败,需将已生成的.torchair_cache路径删除,避免使用编译错误的图导致的前向出错。 ### 3 性能数据 | 模型 |芯片|rtf(实时率)| |-----------|------|------| - | cosyvoice |800I A2|0.28s| - | cosyvoice |300I DUO|0.90s| + | cosyvoice |800I A2|0.28| + | cosyvoice |300I DUO|0.7s| -- Gitee From 5542e3b0bf8c8367168e798fbc7699adf58a36fd Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Sat, 26 Jul 2025 12:02:43 +0800 Subject: [PATCH 09/11] 310 and 910 share infer.py --- .../audio/CosyVoice/CosyVoice2/800I/infer.py | 106 ------------------ .../audio/CosyVoice/CosyVoice2/README.md | 7 +- .../CosyVoice/CosyVoice2/{300I => }/infer.py | 0 3 files changed, 4 insertions(+), 109 deletions(-) delete mode 100644 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py rename ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/{300I => }/infer.py (100%) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py deleted file mode 100644 index 972cb18522..0000000000 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/800I/infer.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd -# [Software Name] is licensed under Mulan PSL v2. -# You can use this software according to the terms and conditions of the Mulan PSL v2. -# You may obtain a copy of Mulan PSL v2 at: -# http://license.coscl.org.cn/MulanPSL2 -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, -# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, -# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. -# See the Mulan PSL v2 for more details. - -import argparse -from tqdm import tqdm -import torch -import torchaudio -import torch_npu -from torch_npu.contrib import transfer_to_npu -import torchair as tng -from torchair.configs.compiler_config import CompilerConfig -from cosyvoice.cli.cosyvoice import CosyVoice2 -from cosyvoice.utils.file_utils import load_wav - - -def no_stream_input_inference(args, cosyvoice, prompt_txt): - with torch.no_grad(): - print('warm up start') - for _ in range(args.warm_up_times): - for _ in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): - pass - print('warm up end') - for _ in range(args.infer_count): - for i, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): - torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) - - -def stream_input_inference(args, cosyvoice, prompt_txt): - - def inference_step(step, mode): - times = args.warm_up_times if mode == "warmup" else args.infer_count - print(f"第{step + 1}/{times}轮 {mode}:↓↓↓") - print(f"curr prompt text:{prompt_txt[step % len(prompt_txt)]}") - for char_idx, char in enumerate(prompt_txt[step % len(prompt_txt)]): - if char_idx == len(prompt_txt[step % len(prompt_txt)]) - 1: - for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): - if mode == "warmup": - pass - else: - infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) - else: - for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): - if mode == "warmup": - pass - else: - infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) - - infer_res = [torch.tensor([]) for _ in range(args.infer_count)] - - with torch.no_grad(): - print("warm up start") - for w_step in 
range(args.warm_up_times): - inference_step(w_step, mode="warmup") - print("warm up end") - - print("inference start") - for i_step in range(args.infer_count): - inference_step(i_step, mode="inference") - print("inference end") - - print(f"save out wav file ...") - for i_step in tqdm(range(args.infer_count)): - torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], 24000) - -if __name__ == '__main__': - torch_npu.npu.set_compile_mode(jit_compile=False) - - parser = argparse.ArgumentParser(description="CosyVoice2 infer") - parser.add_argument("--model_path", type=str, help="model path") - parser.add_argument('--warm_up_times', default=2, type=int, help='warm up times') - parser.add_argument('--infer_count', default=20, type=int, help='infer loop count') - parser.add_argument('--stream_in', action="store_true", help='stream input infer') - parser.add_argument('--stream_out', action="store_true", help='stream output infer') - args = parser.parse_args() - - cosyvoice = CosyVoice2(args.model_path, load_om=True, fp16=True) - cosyvoice.model.llm.eval() - cosyvoice.model.llm.llm.model.model.half() - - # 对hift模型结构进行torchair图模式适配 - cosyvoice.model.hift.remove_weight_norm() - config = CompilerConfig() - config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True - npu_backend = tng.get_npu_backend(compiler_config=config) - cosyvoice.model.hift.decode = torch.compile(cosyvoice.model.hift.decode, dynamic=True, fullgraph=True, backend=npu_backend) - - # 输入数据加载 - prompt_txt = [ - '收到好友从远方寄来的生日礼物,那份意外的惊喜和深深的祝福,让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', - '全球每年有超过一百三十五万人,因吸烟而死亡' - ] - - # 普通输入(非流式输入) - if not args.stream_in: - no_stream_input_inference(args, cosyvoice, prompt_txt) - # 流式输入 - else: - stream_input_inference(args, cosyvoice, prompt_txt) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 94e79df866..d398ee4bc1 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -56,6 +56,8 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 git reset --hard fd45708 git submodule update --init --recursive git apply ../${platform}/diff_CosyVoice_${platform}.patch + # 将infer.py复制到CosyVoice中 + cp ../infer.py ./ # 获取Transformer源码 cd .. 
git clone https://github.com/huggingface/transformers.git @@ -72,17 +74,16 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 ├── 📁 CosyVoice2/ | |── 📁 300I | |── 📄 diff_CosyVoice_300I.patch - | |── 📄 infer.py # 推理脚本 | |── 📄 modeling_qwen2.py | |── 📁 800I | |── 📄 diff_CosyVoice_800I.patch - | |── 📄 infer.py # 推理脚本 | |── 📄 modeling_qwen2.py | |── 📁 CosyVoice | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 │ ├── 📁 CosyVoice-0.5B/ # 权重文件 │ ├── 📁 transformers/ # transformers库,里面修改modeling_qwen2.py文件 │── 📄 requirements.txt # 依赖库 + |── 📄 infer.py # 推理脚本 └── 📄 modify_onnx.py # 模型转换脚本 ``` @@ -201,5 +202,5 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 | 模型 |芯片|rtf(实时率)| |-----------|------|------| | cosyvoice |800I A2|0.28| - | cosyvoice |300I DUO|0.7s| + | cosyvoice |300I DUO|0.75| diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py similarity index 100% rename from ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/300I/infer.py rename to ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py -- Gitee From 7c56b9afb73b71db2f197ab5dfe8c66aa7cf909c Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Sat, 26 Jul 2025 12:09:44 +0800 Subject: [PATCH 10/11] update readme --- ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index d398ee4bc1..4860dd966a 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -171,10 +171,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 ### 2 开始推理验证 - 1. 首先移动对应推理平台路径下的infer.py文件到CosyVoice目录下 - - - 2. 设置环境变量,执行推理命令 + 1. 设置环境变量,执行推理命令 ``` # 1. 
指定使用NPU ID,默认为0 -- Gitee From 1b7d8b58560441083751dc485980caaa2b2ad9ae Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 29 Jul 2025 08:27:41 +0800 Subject: [PATCH 11/11] increase E2E precision --- ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 4860dd966a..b418438adf 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -161,7 +161,7 @@ cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 执行ATC命令,将利用npu-smi info命令获取的芯片型号填入${soc_version}中 ``` - atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1" + atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1" --precision_mode allow_fp32_to_fp16 atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1" atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow_static --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1" --dynamic_dims="100,100,100,100;200,200,200,200;300,300,300,300;400,400,400,400;500,500,500,500;600,600,600,600;700,700,700,700" --input_format=ND ``` -- Gitee
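Across this series the reported metrics are RTF (实时率: seconds of processing per second of generated audio, so values below 1.0 mean faster than real time) and, for streaming output, TTFT (time to the first audio chunk). Below is a minimal sketch of how both are derived, mirroring the timing pattern the patches add to infer.py and cosyvoice.py; the `synthesize` callable and variable names here are illustrative, not part of the patches.

```
# Illustrative only: reproduces the RTF / TTFT bookkeeping used in this series.
# `synthesize` stands in for any generator yielding {'tts_speech': Tensor}
# chunks, e.g. lambda: cosyvoice.inference_sft(text, '中文女', stream=True).
import time
import torch

def measure(synthesize, sample_rate):
    start = time.time()
    ttft = None
    speech = torch.tensor([])
    for idx, chunk in enumerate(synthesize()):
        if idx == 0:
            ttft = time.time() - start            # time to first audio chunk
        speech = torch.cat((speech, chunk['tts_speech']), dim=1)
    elapsed = time.time() - start
    speech_len = speech.shape[1] / sample_rate    # generated audio duration (s)
    rtf = elapsed / speech_len                    # < 1.0 means faster than real time
    return rtf, ttft
```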