diff --git a/MindIE/MultiModal/StableAnimator/README.md b/MindIE/MultiModal/StableAnimator/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f296cfd0d670653062800eb8989599e551e6866c
--- /dev/null
+++ b/MindIE/MultiModal/StableAnimator/README.md
@@ -0,0 +1,125 @@
+# MindSample-infer-Models
+
+### Introduction
+NPU porting and adaptation of StableAnimator (https://github.com/Francis-Rings/StableAnimator).
+
+### Environment
+- Hardware
+  - Atlas 800I A2 (32 GB/64 GB memory)
+  - Atlas 800T A2
+
+### 1. Install the base environment
+- Driver and firmware
+- Python
+- CANN 8.1.RC1
+- PyTorch & PTA
+- NNAL (ATB package)
+- MindIE
+
+### 2. Install dependencies
+```
+pip3 install -r requirements.txt
+```
+
+#### Pull the source code
+```
+git clone https://github.com/Francis-Rings/StableAnimator
+
+cp -rf npu_adaptive/animation StableAnimator/
+cp -rf npu_adaptive/inference_basic.py StableAnimator/
+```
+
+### 3. Prepare weights and dataset
+
+#### Download weights and dataset
+```
+cd StableAnimator
+
+git lfs clone https://huggingface.co/FrancisRing/StableAnimator checkpoints
+
+mv checkpoints/inference.zip ./
+unzip inference.zip
+```
+
+### 4. Inference guide
+
+#### Two-card inference
+```
+cd StableAnimator
+
+export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True'
+export TASK_QUEUE_ENABLE=2
+export CPU_AFFINITY_CONF=1
+export TOKENIZERS_PARALLELISM=false
+
+torchrun --nproc_per_node=2 inference_basic.py \
+  --pretrained_model_name_or_path="./checkpoints/SVD/stable-video-diffusion-img2vid-xt" \
+  --output_dir="./basic_infer" \
+  --validation_control_folder="./inference/case-1/poses" \
+  --validation_image="./inference/case-1/reference.png" \
+  --width=576 \
+  --height=1024 \
+  --guidance_scale=3.0 \
+  --num_inference_steps=25 \
+  --posenet_model_name_or_path="./checkpoints/Animation/pose_net.pth" \
+  --face_encoder_model_name_or_path="./checkpoints/Animation/face_encoder.pth" \
+  --unet_model_name_or_path="./checkpoints/Animation/unet.pth" \
+  --tile_size=16 \
+  --overlap=4 \
+  --noise_aug_strength=0.02 \
+  --frames_overlap=4 \
+  --decode_chunk_size=4 \
+  --gradient_checkpointing
+```
+
+#### Eight-card inference
+Change `nproc_per_node` to set the number of inference cards; only even card counts (multiples of 2) are supported.
+```
+cd StableAnimator
+
+export PYTORCH_NPU_ALLOC_CONF='expandable_segments:True'
+export TASK_QUEUE_ENABLE=2
+export CPU_AFFINITY_CONF=1
+export TOKENIZERS_PARALLELISM=false
+
+torchrun --nproc_per_node=8 inference_basic.py \
+  --pretrained_model_name_or_path="./checkpoints/SVD/stable-video-diffusion-img2vid-xt" \
+  --output_dir="./basic_infer" \
+  --validation_control_folder="./inference/case-2/poses" \
+  --validation_image="./inference/case-2/reference.png" \
+  --width=576 \
+  --height=1024 \
+  --guidance_scale=3.0 \
+  --num_inference_steps=25 \
+  --posenet_model_name_or_path="./checkpoints/Animation/pose_net.pth" \
+  --face_encoder_model_name_or_path="./checkpoints/Animation/face_encoder.pth" \
+  --unet_model_name_or_path="./checkpoints/Animation/unet.pth" \
+  --tile_size=16 \
+  --overlap=4 \
+  --noise_aug_strength=0.02 \
+  --frames_overlap=4 \
+  --decode_chunk_size=4 \
+  --gradient_checkpointing
+```
+
+Parallel inference is limited by the number of frames: for cases with only a few frames, the maximum number of usable cards is capped.
+
+If the message "use 8 cards now, Current case max support 4 cards !!!!!" appears, the current case supports at most 4-card inference, and running it on eight cards performs the same as on four cards.
+
+#### Contributing
+
+1. Fork this repository
+2. Create a Feat_xxx branch
+3. Commit your code
+4. 
新建 Pull Request diff --git a/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor.py b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b9c2b6a8e2ef2657c2bd86f0f947c825c19ede75 --- /dev/null +++ b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor.py @@ -0,0 +1,340 @@ +from time import process_time_ns + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.lora import LoRALinearLayer +from diffusers.utils.import_utils import is_xformers_available + +import math +import torch_npu + +from mindiesd.layers.flash_attn.attention_forward import attention_forward +import os + + +class AnimationAttnProcessor(nn.Module): + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0,): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.algo = int(os.getenv('ALGO', 0)) + # self.rank = rank + # self.lora_scale = lora_scale + # + # self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + # self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + # self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + # self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + + # hidden_states = hidden_states.to(dtype=torch.float16) + + residual = hidden_states + + # print("-----------------------------") + # print("This is AnimationAttnProcessor") + # print(hidden_states.dtype) + # print("-----------------------------") + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attention_mask is not None: + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + # query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + # value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, 
head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + if self.algo == 0: + hidden_states = attention_forward(query, key, value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + else: + hidden_states = torch_npu.npu_fusion_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0].transpose(1, 2) + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + # hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AnimationIDAttnProcessor(nn.Module): + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4): + super().__init__() + + # self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + # self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + # self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + # self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + self.lora_scale = lora_scale + self.num_tokens = num_tokens + self.algo = int(os.getenv('ALGO', 0)) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + + # hidden_states = hidden_states.to(encoder_hidden_states.dtype) + + residual = hidden_states + + # print("-----------------------------") + # print("This is AnimationIDAttnProcessor") + # print(hidden_states.dtype) + # print("-----------------------------") + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + 
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + # query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + query = attn.to_q(hidden_states) + + + # print(attn.heads) # 5 + # print(batch_size) # 21 + # print(encoder_hidden_states.size()) # [21, 5, 1024] + # print(self.num_tokens) # 4 + + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + # value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print(query.size()) # [21, 4096, 320] + # print(key.size()) # [21, 1, 320] + # print(value.size()) # [21, 1, 320] + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + if self.algo == 0: + hidden_states = attention_forward(query, key, value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + else: + hidden_states = torch_npu.npu_fusion_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0].transpose(1, 2) + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # print("==========================This is AnimationIDAttnProcessor==========================") + # print(hidden_states.size()) # [21, 4096, 320] + + ip_key = self.id_to_k(ip_hidden_states) + ip_value = self.id_to_v(ip_hidden_states) + + ip_key = ip_key.to(query.dtype) + ip_value = ip_value.to(query.dtype) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + #推理场景 + if self.algo == 0: + ip_hidden_states = attention_forward(query, ip_key, ip_value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + else: + ip_hidden_states = torch_npu.npu_fusion_attention( + query.transpose(1, 2), + ip_key.transpose(1, 2), + ip_value.transpose(1, 2), + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + 
)[0].transpose(1, 2) + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), ip_key.transpose(1, 2), ip_value.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + + ip_hidden_states = ip_hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + # print(ip_hidden_states.size()) # [105, 4096, 64] + # ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + # ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + # print(ip_hidden_states.size()) # [21, 4096, 320] + hidden_states = hidden_states + self.scale * ip_hidden_states + # linear proj + # hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor_normalized.py b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor_normalized.py new file mode 100644 index 0000000000000000000000000000000000000000..2021d780838f9a07c4dd8fbe6d2de46531407fec --- /dev/null +++ b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/attention_processor_normalized.py @@ -0,0 +1,193 @@ +from time import process_time_ns + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.lora import LoRALinearLayer +from diffusers.utils.import_utils import is_xformers_available + +import math +import torch_npu +from mindiesd.layers.flash_attn.attention_forward import attention_forward +import os + +class AnimationIDAttnNormalizedProcessor(nn.Module): + def __init__( + self, + hidden_size, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + scale=1.0, + num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + self.lora_scale = lora_scale + self.num_tokens = num_tokens + self.algo = int(os.getenv('ALGO', 0)) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale=1.0, + ): + + # hidden_states = hidden_states.to(encoder_hidden_states.dtype) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 
2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + # print(attn.heads) # 5 + # print(batch_size) # 21 + # print(encoder_hidden_states.size()) # [21, 5, 1024] + # print(self.num_tokens) # 4 + + encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # print(query.size()) # [21, 4096, 320] + # print(key.size()) # [21, 1, 320] + # print(value.size()) # [21, 1, 320] + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + #推理场景 + if self.algo == 0: + hidden_states = attention_forward(query, key, value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + else: + hidden_states = torch_npu.npu_fusion_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0].transpose(1, 2) + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + # print("==========================This is AnimationIDAttnProcessor==========================") + # print(hidden_states.size()) # [21, 4096, 320] + + ip_key = self.id_to_k(ip_hidden_states) + ip_value = self.id_to_v(ip_hidden_states) + + ip_key = ip_key.to(query.dtype) + ip_value = ip_value.to(query.dtype) + + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + if query.dtype in (torch.float16, torch.bfloat16): + #推理场景 + if self.algo == 0: + ip_hidden_states = attention_forward(query, ip_key, ip_value, + opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + else: + ip_hidden_states = torch_npu.npu_fusion_attention( + query.transpose(1, 2), + ip_key.transpose(1, 2), + ip_value.transpose(1, 2), + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0].transpose(1, 2) + else: + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query.transpose(1, 2), ip_key.transpose(1, 2), ip_value.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ).transpose(1, 2) + + ip_hidden_states = 
ip_hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + + mean_latents, std_latents = torch.mean(hidden_states, dim=(1, 2), keepdim=True), torch.std(hidden_states, dim=(1, 2), keepdim=True) + mean_ip, std_ip = torch.mean(ip_hidden_states, dim=(1, 2), keepdim=True), torch.std(ip_hidden_states, dim=(1, 2), keepdim=True) + ip_hidden_states = (ip_hidden_states - mean_ip) * (std_latents / (std_ip + 1e-5)) + mean_latents + hidden_states = hidden_states + self.scale * ip_hidden_states + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/parallel_mgr.py b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/parallel_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..52f7b428801249788ee5bf160a20be76905d0f6a --- /dev/null +++ b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/modules/parallel_mgr.py @@ -0,0 +1,36 @@ +import torch +import torch.distributed as dist +import os +SEQ = {} +_SP = None + + +PARALLEL = False + + +def use_paralle(): + global PARALLEL + return PARALLEL + + +def init_parallel_env(): + rank = int(os.getenv('RANK', 0)) + world_size = int(os.getenv('WORLD_SIZE', 1)) + torch.npu.set_device(rank) + dist.init_process_group( + backend='hccl', init_method='env://', + world_size=world_size, rank=rank + ) + global PARALLEL + PARALLEL = True + + +def all_gather_aggregate(tensor, world_size): + # 步骤1: 使用 all_gather 收集所有卡的数据到 gather_list + gather_list = [torch.zeros_like(tensor) for _ in range(world_size)] + dist.all_gather(gather_list, tensor) # 每个卡都会收到所有数据 + + # 步骤2: 拼接所有数据(假设数据按第0维分割) + aggregated = torch.cat(gather_list, dim=0) + + return aggregated diff --git a/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/pipelines/inference_pipeline_animation.py b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/pipelines/inference_pipeline_animation.py new file mode 100644 index 0000000000000000000000000000000000000000..18255c04b1ef59b407cce59c16b5582f4349b4d1 --- /dev/null +++ b/MindIE/MultiModal/StableAnimator/npu_adaptive/animation/pipelines/inference_pipeline_animation.py @@ -0,0 +1,835 @@ +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union + +import PIL.Image +import einops +import numpy as np +import torch +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \ + import _resize_with_antialiasing, _append_dims +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + +from animation.modules.attention_processor import AnimationAttnProcessor, AnimationIDAttnProcessor +from einops import rearrange +from animation.modules.parallel_mgr import use_paralle, all_gather_aggregate +import os + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _append_dims(x, target_dims): + 
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + return outputs + + +@dataclass +class InferenceAnimationPipelineOutput(BaseOutput): + r""" + Output class for mimicmotion pipeline. + + Args: + frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]): + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + num_frames, height, width, num_channels)`. + """ + + frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor] + + +class InferenceAnimationPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKLTemporalDecoder`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K] + (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + pose_net ([`PoseNet`]): + A `` to inject pose signals into unet. 
+ """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae, + image_encoder, + unet, + scheduler, + feature_extractor, + pose_net, + face_encoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + pose_net=pose_net, + face_encoder=face_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.num_tokens = 4 + + # self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + # self.app.prepare(ctx_id=0, det_size=(640, 640)) + # self.lora_rank = 128 + # self.set_ip_adapter() + + def get_prepare_faceid(self, face_image): + faceid_image = np.array(face_image) + faces = self.app.get(faceid_image) + if faces == []: + faceid_embeds = torch.zeros_like(torch.empty((1, 512))) + else: + faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + return faceid_embeds + + def set_ip_adapter(self): + unet = self.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AnimationAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = AnimationIDAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + + unet.set_attn_processor(attn_procs) + + def _encode_image( + self, + image: PipelineImageInput, + device: Union[str, torch.device], + num_videos_per_prompt: int, + do_classifier_free_guidance: bool): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. 
+ image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device: Union[str, torch.device], + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + image = image.to(device=device, dtype=self.vae.dtype) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps: int, + motion_bucket_id: int, + noise_aug_strength: float, + dtype: torch.dtype, + batch_size: int, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \ + f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \ + f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents( + self, + latents: torch.Tensor, + num_frames: int, + decode_chunk_size: int = 8): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + if use_paralle(): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # 核心分块参数计算 + total_latent_frames = latents.shape[0] # 潜在空间总帧数 + num_total_chunks = (total_latent_frames + decode_chunk_size - 1) // decode_chunk_size # 总块数(向上取整) + chunks_per_gpu = (num_total_chunks + world_size - 1) // world_size # 每GPU分配块数(向上取整) + + # -------------------------- 当前GPU任务范围确定 -------------------------- + # 边界保护:防止越界(理论上无需但保留防御性编程) + start_idx = local_rank * chunks_per_gpu + end_idx = min(start_idx + chunks_per_gpu, num_total_chunks) + start_idx = max(0, start_idx) # 理论上不会小于0,显式保护 + end_idx = min(end_idx, num_total_chunks) # 确保不超过总块数 + + # -------------------------- 当前GPU块解码 -------------------------- + frames = [] + for i in range(start_idx, end_idx): + # 计算当前块的实际起止位置 + chunk_start = i * decode_chunk_size + chunk_end = min((i + 1) * decode_chunk_size, total_latent_frames) + num_frames_in_chunk = chunk_end - chunk_start # 当前块实际帧数 + + # 动态参数构造(支持num_frames参数的解码器) + decode_kwargs = {} + if accepts_num_frames: + decode_kwargs["num_frames"] = num_frames_in_chunk + + # 混合精度解码当前块 + with torch.cuda.amp.autocast(enabled=True): + chunk_frames = self.vae.decode(latents[chunk_start:chunk_end], **decode_kwargs).sample + frames.append(chunk_frames) + + # -------------------------- 边缘情况处理:无任务分配 -------------------------- + if not frames: + # 生成虚拟帧(形状与正常解码结果一致) + dummy_decode_kwargs = {"num_frames": decode_chunk_size} if accepts_num_frames else {} + with torch.cuda.amp.autocast(enabled=True): + dummy_frame = self.vae.decode(latents[:decode_chunk_size], **dummy_decode_kwargs).sample + # 填充虚拟帧至当前GPU分配的块数(避免后续拼接错误) + dummy_frame = torch.zeros_like(dummy_frame) + frames = [dummy_frame] * chunks_per_gpu + + # -------------------------- 填充未占满的块(当前GPU任务不足) -------------------------- + # 计算需要填充的帧数:预期总帧数(chunks_per_gpu * decode_chunk_size) - 已解码总帧数 + expected_frames = chunks_per_gpu * decode_chunk_size + actual_frames = sum(f.shape[0] for f in frames) + pad_num = expected_frames - actual_frames + + if pad_num > 0: + # 生成零填充帧(形状与单帧一致) + pad_frame = torch.zeros_like(frames[0][:1, ...]) # 保持维度一致性 + frames.extend([pad_frame] * pad_num) + + # -------------------------- 结果聚合与裁剪 -------------------------- + # 拼接当前GPU的所有块 + frames = torch.cat(frames, dim=0) + # 分布式聚合所有GPU的结果 + gather_list = all_gather_aggregate(frames, world_size) + # 裁剪至实际总帧数(去除多GPU聚合的冗余数据) + frames = gather_list.view(-1, *frames.shape[1:])[:total_latent_frames, ...] 
+ else: + # 原单卡解码逻辑 + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i: i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size: int, + num_frames: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: Union[str, torch.device], + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale > 1 + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + image_pose: Union[torch.FloatTensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + tile_size: Optional[int] = 16, + tile_overlap: Optional[int] = 4, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + image_only_indicator: bool = False, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + validation_image_id_ante_embedding=None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/ + feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` + and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second.The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. + The higher the number the more motion will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, + the higher it is the less the video will look like the init image. Increase it for more motion. 
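+            image_pose (`torch.FloatTensor`):
+                Pose frames, one per generated video frame, that are encoded by `pose_net` into pose latents
+                for the unet.
+            tile_size (`int`, *optional*, defaults to 16):
+                Number of frames denoised together in one temporal tile.
+            tile_overlap (`int`, *optional*, defaults to 4):
+                Number of frames shared between consecutive tiles so that tile boundaries blend smoothly.
+            validation_image_id_ante_embedding (`np.ndarray`, *optional*):
+                Face identity embedding of the reference image, fused with the CLIP image embedding by
+                `face_encoder`.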
+ image_only_indicator (`bool`, *optional*, defaults to False): + Whether to treat the inputs as batch of images instead of videos. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time.The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. + By default, the decoder will decode all frames at once for maximal quality. + Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + device: + On which device the pipeline runs on. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. + + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained( + "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image( + "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale >= 1.0 + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + # self.image_encoder.cpu() + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + fps = fps - 1 + + # 4. Encode input image using VAE + + # print(image_embeddings.size()) # [2, 1, 1024] + validation_image_id_ante_embedding = torch.from_numpy(validation_image_id_ante_embedding).unsqueeze(0) + validation_image_id_ante_embedding = validation_image_id_ante_embedding.to(device=device, dtype=image_embeddings.dtype) + + faceid_latents = self.face_encoder(validation_image_id_ante_embedding, image_embeddings[1:]) + # print(faceid_latents.size()) # [1, 4, 1024] + uncond_image_embeddings = image_embeddings[:1] + uncond_faceid_latents = torch.zeros_like(faceid_latents) + uncond_image_embeddings = torch.cat([uncond_image_embeddings, uncond_faceid_latents], dim=1) + cond_image_embeddings = image_embeddings[1:] + cond_image_embeddings = torch.cat([cond_image_embeddings, faceid_latents], dim=1) + image_embeddings = torch.cat([uncond_image_embeddings, cond_image_embeddings]) + + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = (self.vae.dtype == torch.float16 or self.vae.dtype == torch.bfloat16) and self.vae.config.force_upcast + if needs_upcasting: + self_vae_dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + + if needs_upcasting: + self.vae.to(dtype=self_vae_dtype) + # self.vae.cpu() + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None) + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + tile_size, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames] + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0) + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 8. Denoising loop + self._num_timesteps = len(timesteps) + indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in + range(0, num_frames - tile_size + 1, tile_size - tile_overlap)] + if indices[-1][-1] < num_frames - 1: + indices.append([0, *range(num_frames - tile_size + 1, num_frames)]) + + pose_pil_image_list = [] + for pose in image_pose: + pose = torch.from_numpy(np.array(pose)).float() + pose = pose / 127.5 - 1 + pose_pil_image_list.append(pose) + pose_pil_image_list = torch.stack(pose_pil_image_list, dim=0) + pose_pil_image_list = rearrange(pose_pil_image_list, "f h w c -> f c h w") + + + # print(indices) # [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]] + # print(pose_pil_image_list.size()) # [16, 3, 512, 512] + + self.pose_net.to(device) + self.unet.to(device) + + with torch.cuda.device(device): + torch.cuda.empty_cache() + + with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimension + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred = torch.zeros_like(image_latents) + noise_pred_cnt = image_latents.new_zeros((num_frames,)) + weight = (torch.arange(tile_size, device=device) + 0.5) * 2. 
/ tile_size + weight = torch.minimum(weight, 2 - weight) + + if use_paralle(): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # 模型实例数,每个实例使用两张卡 + num_model_instances = world_size // 2 + if num_model_instances > len(indices): + print( + f"use {world_size} cards now, Current case max support {len(indices)*2} cards !!!!!") + + # 每个实例处理idx_partnum份数据 + idx_partnum = len(indices) // num_model_instances + + remainder = len(indices) % num_model_instances + + model_instance_idx = local_rank // 2 + # 前remainder个实例额外处理一份数据 + if model_instance_idx < remainder: + start = model_instance_idx * idx_partnum + model_instance_idx + end = start + idx_partnum + 1 + else: + start = model_instance_idx * idx_partnum + remainder + end = start + idx_partnum + if remainder != 0: + progress_bar.update() + if model_instance_idx == world_size // 2 - 1: + spilt_indices = indices[start:] + else: + spilt_indices = indices[start:end] + + if local_rank % 2 == 0: + for idx in spilt_indices: + # classification-free inference + _noise_pred = self.unet( + latent_model_input[:1, idx], + t, + encoder_hidden_states=image_embeddings[:1], + added_time_ids=added_time_ids[:1], + pose_latents=None, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None] + for _ in range(num_model_instances): + progress_bar.update() + + noise_pred_allgather = all_gather_aggregate(noise_pred[:1, ...], world_size) + noise_pred[0, ...] = noise_pred_allgather[0::2, ...].sum(dim=0) + noise_pred[1, ...] = noise_pred_allgather[1::2, ...].sum(dim=0) + + elif local_rank % 2 == 1: + for idx in spilt_indices: + # normal inference + pose_latents = self.pose_net(pose_pil_image_list[idx].to( + device=device, dtype=latent_model_input.dtype)) + _noise_pred = self.unet( + latent_model_input[1:, idx], + t, + encoder_hidden_states=image_embeddings[1:], + added_time_ids=added_time_ids[1:], + pose_latents=pose_latents, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None] + for _ in range(num_model_instances): + progress_bar.update() + + noise_pred_allgather = all_gather_aggregate(noise_pred[1:, ...], world_size) + noise_pred[0, ...] = noise_pred_allgather[0::2, ...].sum(dim=0) + noise_pred[1, ...] 
= noise_pred_allgather[1::2, ...].sum(dim=0) + else: + for idx in indices: + # classification-free inference + pose_latents = self.pose_net(pose_pil_image_list[idx].to( + device=device, dtype=latent_model_input.dtype)) + _noise_pred = self.unet( + latent_model_input[:1, idx], + t, + encoder_hidden_states=image_embeddings[:1], + added_time_ids=added_time_ids[:1], + pose_latents=None, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None] + + # normal inference + _noise_pred = self.unet( + latent_model_input[1:, idx], + t, + encoder_hidden_states=image_embeddings[1:], + added_time_ids=added_time_ids[1:], + pose_latents=pose_latents, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None] + progress_bar.update() + + for idx in indices: + noise_pred_cnt[idx] += weight + noise_pred.div_(noise_pred_cnt[:, None, None, None]) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + # self.pose_net.cpu() + # self.unet.cpu() + # self.face_encoder.cpu() + + if not output_type == "latent": + self.vae.decoder.to(device) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + # print(frames.size()) # [1, 3, 16, 512, 512] + # print(latents.size()) # [1, 16, 4, 64, 64] + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + # print(frames[0].size()) # [16, 3, 512, 512] + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return InferenceAnimationPipelineOutput(frames=frames) diff --git a/MindIE/MultiModal/StableAnimator/npu_adaptive/inference_basic.py b/MindIE/MultiModal/StableAnimator/npu_adaptive/inference_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..8514daead43b475f33458fdf40e11f014d6402f7 --- /dev/null +++ b/MindIE/MultiModal/StableAnimator/npu_adaptive/inference_basic.py @@ -0,0 +1,469 @@ +import argparse +import os +import cv2 +import numpy as np +from PIL import Image +from diffusers.models.attention_processor import AttnProcessorNPU +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format=False +from animation.modules.parallel_mgr import init_parallel_env +from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler + +from animation.modules.attention_processor import AnimationAttnProcessor +from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor +from animation.modules.face_model import FaceModel +from animation.modules.id_encoder import FusionFaceId +from animation.modules.pose_net import PoseNet +from animation.modules.unet import UNetSpatioTemporalConditionModel +from 
animation.pipelines.inference_pipeline_animation import InferenceAnimationPipeline
+import random
+
+def seed_everything(seed):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed % (2**32))
+    random.seed(seed)
+
+
+def load_images_from_folder(folder, width, height):
+    images = []
+    files = os.listdir(folder)
+    # expects pose frames named like '<prefix>_<index>.png'; sort them by frame index
+    png_files = [f for f in files if f.endswith('.png')]
+    png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
+    for filename in png_files:
+        img = Image.open(os.path.join(folder, filename)).convert('RGB')
+        img = img.resize((width, height))
+        images.append(img)
+
+    return images
+
+def save_frames_as_png(frames, output_path):
+    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame for frame in frames]
+    num_frames = len(pil_frames)
+    for i in range(num_frames):
+        pil_frame = pil_frames[i]
+        save_path = os.path.join(output_path, f'frame_{i}.png')
+        pil_frame.save(save_path)
+
+def save_frames_as_mp4(frames, output_mp4_path, fps):
+    print("Saving frames as mp4")
+    height, width, _ = frames[0].shape
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'H264' for better quality
+    out = cv2.VideoWriter(output_mp4_path, fourcc, fps, (width, height))
+    for frame in frames:
+        # frames are RGB numpy arrays here; OpenCV expects BGR
+        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) if frame.shape[2] == 3 else frame
+        out.write(frame_bgr)
+    out.release()
+
+
+def export_to_gif(frames, output_gif_path, fps):
+    """
+    Export a list of frames to a GIF.
+
+    Args:
+    - frames (list): List of frames (as numpy arrays or PIL Image objects).
+    - output_gif_path (str): Output path; a trailing '.mp4' suffix is replaced with '.gif'.
+    - fps (int): Playback frame rate, converted to a per-frame duration in milliseconds.
+
+    """
+    # Convert numpy arrays to PIL Images if needed
+    pil_frames = [Image.fromarray(frame) if isinstance(
+        frame, np.ndarray) else frame for frame in frames]
+
+    pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
+                       format='GIF',
+                       append_images=pil_frames[1:],
+                       save_all=True,
+                       duration=int(1000 / fps),
+                       loop=0)
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Basic inference script for StableAnimator on NPU."
+    )
+
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True
+    )
+
+    parser.add_argument(
+        "--validation_image",
+        type=str,
+        default=None,
+        help=(
+            "Path to the reference image that drives the animation; the corresponding pose"
+            " sequence is read from `--validation_control_folder`."
+        ),
+    )
+
+    parser.add_argument(
+        "--validation_control_folder",
+        type=str,
+        default=None,
+        help=(
+            "Path to the folder of pose control images, one PNG per output frame."
+        ),
+    )
+
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=None,
+        required=True
+    )
+
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=768,
+        required=False
+    )
+
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=512,
+        required=False
+    )
+
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=2.0,
+        required=False
+    )
+
+    parser.add_argument(
+        "--num_inference_steps",
+        type=int,
+        default=25,
+        required=False
+    )
+
+    parser.add_argument(
+        "--posenet_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained posenet model",
+    )
+    parser.add_argument(
+        "--face_encoder_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained face encoder model",
+    )
+    parser.add_argument(
+        "--unet_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained unet model",
+    )
+
+    parser.add_argument(
+        "--tile_size",
+        type=int,
+        default=16,
+        required=False
+    )
+
+    parser.add_argument(
+        "--overlap",
+        type=int,
+        default=4,
+        required=False
+    )
+
+    parser.add_argument(
+        "--noise_aug_strength",
+        type=float,
+        default=0.0,  # or set to 0.02
+        required=False
+    )
+    parser.add_argument(
+        "--frames_overlap",
+        type=int,
+        default=4,
+        required=False
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--decode_chunk_size",
+        type=int,
+        default=None,
+        required=False
+    )
+
+    args = parser.parse_args()
+    return args
+
+import torch.nn as nn
+class LayerNormNPU(nn.LayerNorm):
+    # LayerNorm variant that dispatches to the fused NPU layer-norm kernel (inference only).
+
+    def forward(self, x):
+        # return super().forward(x.float()).type_as(x)
+        return torch_npu.npu_layer_norm_eval(
+            x, normalized_shape=self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps,
+        )
+
+def replace_layer_norm(model):
+    for name, module in model.named_children():
+        # Recurse into child modules first
+        replace_layer_norm(module)
+
+        # Replace every nn.LayerNorm with the NPU variant
+        if isinstance(module, nn.LayerNorm):
+            # Read the parameters dynamically (also covers elementwise_affine=False)
+            kwargs = {
+                "normalized_shape": module.normalized_shape,
+                "eps": module.eps,
+                "elementwise_affine": module.elementwise_affine  # keep the learnable/non-learnable setting
+            }
+
+            # Create the replacement LayerNormNPU instance
+            new_layer = LayerNormNPU(**kwargs)
+            if hasattr(module, 'weight'):
+                new_layer.weight = module.weight
+            if hasattr(module, 'bias'):
+                new_layer.bias = module.bias
+
+            setattr(model, name, new_layer)
+    return model
+
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    init_parallel_env()
+    # torch.set_default_dtype(torch.float16)
+    seed = 23123134
+    # seed = 42
+    # seed = 123
+    seed_everything(seed)
+    generator = torch.Generator(device='cpu').manual_seed(seed)
+
+    feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision)
+    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision
+    )
+    vae = AutoencoderKLTemporalDecoder.from_pretrained(
+        args.pretrained_model_name_or_path,
subfolder="vae", revision=args.revision) + unet = UNetSpatioTemporalConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + low_cpu_mem_usage=True, + ) + pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0]) + face_encoder = FusionFaceId( + cross_attention_dim=1024, + id_embeddings_dim=512, + # clip_embeddings_dim=image_encoder.config.hidden_size, + clip_embeddings_dim=1024, + num_tokens=4, ) + face_model = FaceModel() + + lora_rank = 128 + attn_procs = {} + unet_svd = unet.state_dict() + + for name in unet.attn_processors.keys(): + if "transformer_blocks" in name and "temporal_transformer_blocks" not in name: + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + # print(f"This is AnimationAttnProcessor: {name}") + attn_procs[name] = AnimationAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank) + else: + # print(f"This is AnimationIDAttnProcessor: {name}") + layer_name = name.split(".processor")[0] + weights = { + "id_to_k.weight": unet_svd[layer_name + ".to_k.weight"], + "id_to_v.weight": unet_svd[layer_name + ".to_v.weight"], + } + attn_procs[name] = AnimationIDAttnNormalizedProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank) + attn_procs[name].load_state_dict(weights, strict=False) + elif "temporal_transformer_blocks" in name: + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessorNPU() + else: + attn_procs[name] = AttnProcessorNPU() + unet.set_attn_processor(attn_procs) + + replace_layer_norm(unet) + + # resume the previous checkpoint + if args.posenet_model_name_or_path is not None and args.face_encoder_model_name_or_path is not None and args.unet_model_name_or_path is not None: + print("Loading existing posenet weights, face_encoder weights and unet weights.") + if args.posenet_model_name_or_path.endswith(".pth"): + pose_net_state_dict = torch.load(args.posenet_model_name_or_path, map_location="cpu") + pose_net.load_state_dict(pose_net_state_dict, strict=True) + else: + print("posenet weights loading fail") + print(1/0) + if args.face_encoder_model_name_or_path.endswith(".pth"): + face_encoder_state_dict = torch.load(args.face_encoder_model_name_or_path, map_location="cpu") + face_encoder.load_state_dict(face_encoder_state_dict, strict=True) + else: + print("face_encoder weights loading fail") + print(1/0) + if args.unet_model_name_or_path.endswith(".pth"): + unet_state_dict = torch.load(args.unet_model_name_or_path, map_location="cpu") + unet.load_state_dict(unet_state_dict, strict=True) + else: + 
print("unet weights loading fail") + print(1/0) + + torch.cuda.empty_cache() + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) + pose_net.requires_grad_(False) + face_encoder.requires_grad_(False) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # weight_dtype = torch.float16 + # weight_dtype = torch.float32 + weight_dtype = torch.bfloat16 + + pipeline = InferenceAnimationPipeline( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=noise_scheduler, + feature_extractor=feature_extractor, + pose_net=pose_net, + face_encoder=face_encoder, + ).to(device='cuda', dtype=weight_dtype) + + os.makedirs(args.output_dir, exist_ok=True) + + validation_image_path = args.validation_image + validation_image = Image.open(args.validation_image).convert('RGB') + validation_control_images = load_images_from_folder(args.validation_control_folder, width=args.width, height=args.height) + + num_frames = len(validation_control_images) + face_model.face_helper.clean_all() + validation_face = cv2.imread(validation_image_path) + validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR) + validation_image_face_info = face_model.app.get(validation_image_bgr) + if len(validation_image_face_info) > 0: + validation_image_face_info = sorted(validation_image_face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1] + validation_image_id_ante_embedding = validation_image_face_info['embedding'] + else: + validation_image_id_ante_embedding = None + + if validation_image_id_ante_embedding is None: + face_model.face_helper.read_image(validation_image_bgr) + face_model.face_helper.get_face_landmarks_5(only_center_face=True) + face_model.face_helper.align_warp_face() + + if len(face_model.face_helper.cropped_faces) == 0: + validation_image_id_ante_embedding = np.zeros((512,)) + else: + validation_image_align_face = face_model.face_helper.cropped_faces[0] + print('fail to detect face using insightface, extract embedding on align face') + validation_image_id_ante_embedding = face_model.handler_ante.get_feat(validation_image_align_face) + + # generator = torch.Generator(device=accelerator.device).manual_seed(23123134) + + decode_chunk_size = args.decode_chunk_size + + print("warm up ") + video_frames = pipeline( + image=validation_image, + image_pose=validation_control_images, + height=args.height, + width=args.width, + num_frames=num_frames, + tile_size=args.tile_size, + tile_overlap=args.frames_overlap, + decode_chunk_size=decode_chunk_size, + motion_bucket_id=127., + fps=7, + min_guidance_scale=args.guidance_scale, + max_guidance_scale=args.guidance_scale, + noise_aug_strength=args.noise_aug_strength, + num_inference_steps=3, + generator=generator, + output_type="pil", + validation_image_id_ante_embedding=validation_image_id_ante_embedding, + ).frames[0] + + import time + torch.npu.synchronize() + a=time.time() + video_frames = pipeline( + image=validation_image, + image_pose=validation_control_images, + height=args.height, + width=args.width, + num_frames=num_frames, + tile_size=args.tile_size, + tile_overlap=args.frames_overlap, + decode_chunk_size=decode_chunk_size, + motion_bucket_id=127., + fps=7, + min_guidance_scale=args.guidance_scale, + max_guidance_scale=args.guidance_scale, + noise_aug_strength=args.noise_aug_strength, + num_inference_steps=args.num_inference_steps, + generator=generator, + output_type="pil", + validation_image_id_ante_embedding=validation_image_id_ante_embedding, + 
).frames[0]
+    torch.npu.synchronize()
+    b = time.time()
+    print("pipeline time:", b - a)
+
+
+    out_file = os.path.join(
+        args.output_dir,
+        "animation_video.mp4",
+    )
+    for i in range(num_frames):
+        img = video_frames[i]
+        video_frames[i] = np.array(img)
+
+    png_out_file = os.path.join(args.output_dir, "animated_images")
+    os.makedirs(png_out_file, exist_ok=True)
+    # export_to_gif writes animation_video.gif next to out_file; PNG frames go to animated_images/
+    export_to_gif(video_frames, out_file, 8)
+    save_frames_as_png(video_frames, png_out_file)
+
+
diff --git a/MindIE/MultiModal/StableAnimator/requirements.txt b/MindIE/MultiModal/StableAnimator/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..67faee5a18cf5869b8aa9848e6af93bfa5dd62b3
--- /dev/null
+++ b/MindIE/MultiModal/StableAnimator/requirements.txt
@@ -0,0 +1,20 @@
+diffusers==0.33.1
+transformers==4.35.2
+accelerate==0.25.0
+timm==0.4.12
+decord
+einops
+scipy
+pandas
+coloredlogs
+flatbuffers
+numpy==1.26.4
+packaging
+protobuf
+sympy
+imageio-ffmpeg
+insightface
+facexlib
+opencv-python-headless
+gradio
+onnxruntime-gpu
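
Note on the parallel path in `inference_pipeline_animation.py` above: the temporal-tile indices are divided among `world_size // 2` model instances (each instance occupies two cards, one for the unconditional pass and one for the pose-conditioned pass), and the first `len(indices) % num_instances` instances take one extra tile. A case with N tiles therefore keeps at most 2*N cards busy, which is what the "max support ... cards" message reports. The snippet below is a minimal standalone sketch of that partitioning rule, included for illustration only; it is not part of the patch, and the names `partition_tiles` and `num_instances` are invented here.

```
from typing import List


def partition_tiles(num_tiles: int, world_size: int) -> List[range]:
    """Split num_tiles tile indices across world_size // 2 two-card model instances."""
    num_instances = world_size // 2
    base, remainder = divmod(num_tiles, num_instances)
    parts, start = [], 0
    for instance_idx in range(num_instances):
        # the first `remainder` instances each take one extra tile
        size = base + (1 if instance_idx < remainder else 0)
        parts.append(range(start, start + size))
        start += size
    return parts


if __name__ == "__main__":
    # e.g. 5 tiles on 8 cards -> 4 instances: the first instance gets 2 tiles, the rest 1 each
    for instance_idx, tiles in enumerate(partition_tiles(num_tiles=5, world_size=8)):
        cards = (2 * instance_idx, 2 * instance_idx + 1)
        print(f"instance {instance_idx} on cards {cards}: tiles {list(tiles)}")
```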