From 711e4007f1f4c13d8295b0edde22e4c0131a24a4 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 20:48:13 +0800 Subject: [PATCH 01/32] add stable_audio --- .../foundation/stable_audio/README.md | 186 ++ .../stable_audio/inference_stableaudio.py | 129 ++ .../stable_audio/prompts/prompts.txt | 3 + .../stable_audio/stableaudio/__init__.py | 6 + .../stableaudio/layers/__init__.py | 6 + .../stableaudio/layers/activations.py | 165 ++ .../stableaudio/layers/attention.py | 434 +++++ .../stableaudio/layers/attention_processor.py | 904 ++++++++++ .../stableaudio/layers/embeddings.py | 1556 +++++++++++++++++ .../stableaudio/layers/normalization.py | 451 +++++ .../stableaudio/models/__init__.py | 2 + .../models/modeling_stable_audio.py | 158 ++ .../models/stable_audio_transformer.py | 457 +++++ .../stableaudio/pipeline/__init__.py | 1 + .../pipeline/pipeline_stable_audio.py | 745 ++++++++ .../stableaudio/scheduler/__init__.py | 1 + .../scheduling_cosine_dpmsolver_multistep.py | 572 ++++++ .../stable_audio/stableaudio/vae/__init__.py | 1 + .../stableaudio/vae/autoencoder_oobleck.py | 464 +++++ ...ly_200k_bs64_crop_640_640_coco_dsconv.yaml | 42 - 20 files changed, 6241 insertions(+), 42 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/prompts/prompts.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/normalization.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py delete mode 100644 PyTorch/dev/cv/image_classification/SlowFast_ID0646_for_PyTorch/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md 
b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md new file mode 100644 index 0000000000..809973a563 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -0,0 +1,186 @@ +# stable-audio-open-1.0模型-diffusers方式推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + +# 概述 + + [此处获得](https://huggingface.co/stabilityai/stable-audio-open-1.0) + +- 参考实现: + ```bash + # StableAudioOpen1.0 + https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 +## 获取源码 +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + apt-get update + apt-get install libsndfile1 + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + +- 执行命令: + ```bash + python3 diffusers_aie_patch.py + python3 brownian_interval_patch.py + ``` + +4. MindieTorch配套Torch_NPU使用 + + MindieTorch采用dlopen的方式动态加载Torch_NPU,需要手动编译libtorch_npu_bridge.so,并将其放在libtorch_aie.so同一路径下,或者将其路径设置到LD_LIBRARY_PATH环境变量中,具体参考: + ```bash + https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html + ``` + +## 模型推理 + +1. 模型转换。 + + 1. 提前下载权重,放到代码同级目录下。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # 下载stable-audio-open-1.0权重 + git clone https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + + 2. 导出pt模型并进行编译。 + + (1) 设置模型权重的路径 + ```bash + # stable-audio-open-1.0 (执行时下载权重) + model_base="stabilityai/stable-audio-open-1.0" + + # stable-audio-open-1.0 (使用上一步下载的权重) + model_base="./stable-audio-open-1.0" + ``` + + (2) 执行命令查看芯片名称($\{chip\_name\})。 + + ``` + npu-smi info + ``` + + (3) 执行export命令 + + ```bash + python3 export_ts.py --model ${model_base} --output_dir ./models --soc Ascend${chip_name} --device 0 + ``` + + 参数说明: + - --model:模型权重路径 + - --output_dir: 存放导出模型的路径 + - --soc:处理器型号。 + - --device:推理设备ID + + 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 + +2. 开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 
执行推理脚本。 + ```bash + numactl -C 0-23 python3 stable_audio_open_aie_pipeline.py \ + --model ${model_base} \ + --output_dir ./models \ + --prompt_file ./prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --num_waveforms_per_prompt 1 \ + --guidance_scale 7 \ + --save_dir ./results \ + --device 0 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --num_inference_steps: 语音生成迭代次数。 + - --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 + - --num_waveforms_per_prompt:一个提示词生成的语音数量。 + - --guidance_scale:音频生成质量与准确度系数。 + - --save_dir:生成语音的存放目录。 + - --device:推理设备ID。 + + 执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + + + +# 模型推理性能&精度 +性能参考下列数据。 + +### Stable-Audio-Open-1.0 + +| 硬件形态 | 迭代次数 | 平均耗时| +| :------: |:----:|:----:| +| Atlas 800I A2 (32G) | 100 | 5.895s | \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py new file mode 100644 index 0000000000..fc4f242512 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py @@ -0,0 +1,129 @@ +import torch +import torch_npu +import time +import json +import os +import argparse +import soundfile as sf +from safetensors.torch import load_file + +from stableaudio.vae.autoencoder_oobleck import AutoencoderOobleck +from stableaudio.pipeline import StableAudioPipeline +from transformers import T5TokenizerFast +from transformers import T5EncoderModel +from stableaudio.models.modeling_stable_audio import StableAudioProjectionModel +from stableaudio.models.stable_audio_transformer import StableAudioDiTModel +from stableaudio.schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts/prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt or prompts to guide what to not include in audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. 
More denoising steps usually lead to a higher quality audio at the expense of slower inference.", + ) + parser.add_argument( + "--model", + type=str, + default="./stable-audio-open-1.0", + help="The path of stable-audio-open-1.0.", + ) + parser.add_argument( + "--audio_end_in_s", + nargs='+', + default=[10], + help="Audio end index in seconds.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result audio files.", + ) + return parser.parse_args() + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch_npu.npu.set_device(args.device) + torch.manual_seed(1) + latents = torch.randn(1, 64, 1024, dtype=torch.float16,device="cpu") + with open(args.stable_audio_open_dir + "/vae/config.json", "r", encoding="utf-8") as reader: + data = reader.read() + json_data = json.loads(data) + init_dict = {key: json_data[key] for key in json_data} + vae = AutoencoderOobleck(**init_dict) + vae.load_state_dict(load_file(args.stable_audio_open_dir + "/vae/diffusion_pytorch_model.safetensors"), strict=False) + + tokenizer = T5TokenizerFast.from_pretrained(args.stable_audio_open_dir + "/tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.stable_audio_open_dir + "/text_encoder") + projection_model = StableAudioProjectionModel.from_pretrained(args.stable_audio_open_dir + "/projection_model") + audio_dit = StableAudioDiTModel.from_pretrained(args.stable_audio_open_dir + "/transformer") + scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.stable_audio_open_dir + "/scheduler") + + npu_stream = torch_npu.npu.Stream() + vae = vae.to("npu").to(torch.float16).eval() + text_encoder = text_encoder.to("npu").to(torch.float16).eval() + projection_model = projection_model.to("npu").to(torch.float16).eval() + audio_dit = audio_dit.to("npu").to(torch.float16).eval() + + pipe = StableAudioPipeline(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, + projection_model=projection_model, transformer=audio_dit, scheduler=scheduler) + pipe.to("npu") + + total_time = 0 + prompts_num = 0 + average_time = 0 + skip = 2 + with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: + for i, prompt in enumerate(f): + with torch.no_grad(): + npu_stream.synchronize() + audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0 + begin = time.time() + audio = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + num_inference_steps=args.num_inference_steps, + latents=latents.to("npu"), + audio_end_in_s=audio_end_in_s, + ).audios + npu_stream.synchronize() + end = time.time() + if i > skip - 1: + total_time += end - begin + prompts_num = i+1 + output = audio[0].T.float().cpu().numpy() + sf.write(args.save_dir + "/audio_by_prompt" + str(prompts_num) + ".wav", output, pipe.vae.sampling_rate) + if prompts_num > skip: + average_time = total_time / (prompts_num-skip) + else: + raise ValueError("Infer average time skip first two prompts, ensure that prompts.txt \ + contains more than three prompts") + print(f"Infer average time: {average_time:.3f}s\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/prompts/prompts.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/prompts/prompts.txt new file mode 100644 index 0000000000..e1c7734ef9 --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/prompts/prompts.txt @@ -0,0 +1,3 @@ +Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP. +Uplifting acoustic loop. 120 BPM. +Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py new file mode 100644 index 0000000000..bb0fa4ea0b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py @@ -0,0 +1,6 @@ +__version__ = "0.30.0" +from .pipeline import StableAudioPipeline +from .models import StableAudioDiTModel +from .models import StableAudioProjectionModel +from .vae.autoencoder_oobleck import AutoencoderOobleck +from .scheduler.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/__init__.py new file mode 100644 index 0000000000..febbf254b3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/__init__.py @@ -0,0 +1,6 @@ +from .attention import FeedForward +from .attention_processor import ( + Attention, + AttentionProcessor, + StableAudioAttnProcessor2_0, +) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py new file mode 100644 index 0000000000..fb24a36bae --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py @@ -0,0 +1,165 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate +from ..utils.import_utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + +def get_activation(act_fn: str) -> nn.Module: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(nn.Module): + r""" + SiLU activation function with input upcasted to torch.float32. 
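+
+    Example (an illustrative sketch; the tensor shape and dtype below are arbitrary):
+
+        >>> import torch
+        >>> layer = FP32SiLU()
+        >>> x = torch.randn(2, 4, dtype=torch.float16)
+        >>> y = layer(x)  # SiLU is computed in float32, then cast back to the input dtype (float16)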
+ """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + # using torch_npu.npu_geglu can run faster and save memory on NPU. + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class SwiGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU` + but uses SiLU / Swish instead of GeLU. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
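+
+    Example (an illustrative sketch; the layer sizes are arbitrary):
+
+        >>> import torch
+        >>> layer = SwiGLU(dim_in=64, dim_out=64)
+        >>> out = layer(torch.randn(2, 10, 64))  # proj -> 128 channels, split into value/gate halves
+        >>> tuple(out.shape)
+        (2, 10, 64)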
+ """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + +class ApproximateGELU(nn.Module): + r""" + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py new file mode 100644 index 0000000000..d8f4b1ceac --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py @@ -0,0 +1,434 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .attention_processor import Attention, JointAttnProcessor2_0 +from .embeddings import SinusoidalPositionalEmbedding +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero + + +logger = logging.get_logger(__name__) + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. 
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. 
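+        # Each flag mirrors one `norm_type` value; e.g. `norm_type="ada_norm_zero"` together with a non-None
+        # `num_embeds_ada_norm` is what later selects the AdaLayerNormZero branch when `self.norm1` is built.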
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + if norm_type == "ada_norm_single": # For Latte + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. 
Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py new file mode 100644 index 0000000000..55264cbfd2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -0,0 +1,904 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. 
+ kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
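+
+    Example (an illustrative sketch; the sizes are arbitrary and the plain `AttnProcessor` defined later in
+    this module is passed explicitly):
+
+        >>> import torch
+        >>> attn = Attention(query_dim=64, heads=8, dim_head=8, processor=AttnProcessor())
+        >>> hidden_states = torch.randn(2, 16, 64)
+        >>> out = attn(hidden_states)  # self-attention, output shape (2, 16, 64)
+        >>> out = attn(hidden_states, encoder_hidden_states=torch.randn(2, 77, 64))  # cross-attention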
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applys qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. 
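+
+        Example (illustrative; `attn` is an `Attention` instance and a working `xformers` install on CUDA is
+        assumed. On Ascend NPU the analogous switch is `set_use_npu_flash_attention(True)` defined above):
+
+            >>> attn.set_use_memory_efficient_attention_xformers(True)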
+ """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. 
+ """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. 
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). 
In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. + if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
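`fuse_projections` above relies on the fact that stacking the q/k/v weight matrices along the output dimension yields a single linear layer whose output can be split back into the three projections. A minimal sketch of that equivalence (layer names are illustrative):

```python
import torch
from torch import nn

dim = 32
x = torch.randn(2, 10, dim)

to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Fuse: concatenate the three weight matrices along the output dimension.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

q_ref, k_ref, v_ref = to_q(x), to_k(x), to_v(x)
q, k, v = to_qkv(x).chunk(3, dim=-1)  # split the fused output back into q, k, v

print(all(torch.allclose(a, b, atol=1e-6) for a, b in [(q, q_ref), (k, k_ref), (v, v_ref)]))  # True
```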
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def apply_partial_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor], + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + rot_dim = freqs_cis[0].shape[-1] + x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] + + x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) + + out = torch.cat((x_rotated, x_unrotated), dim=-1) + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
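The GQA/MQA expansion that follows this comment simply repeats each key/value head so that every group of query heads shares one K/V head. A self-contained sketch of that expansion, verified against a per-head reference loop:

```python
import torch
import torch.nn.functional as F

batch, heads, kv_heads, seq, head_dim = 1, 8, 2, 16, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, kv_heads, seq, head_dim)
v = torch.randn(batch, kv_heads, seq, head_dim)

# GQA/MQA: each key/value head serves heads // kv_heads query heads.
group = heads // kv_heads
k_full = torch.repeat_interleave(k, group, dim=1)
v_full = torch.repeat_interleave(v, group, dim=1)
out = F.scaled_dot_product_attention(q, k_full, v_full)

# Reference: run each query head against its own group's shared K/V head.
ref = torch.cat(
    [
        F.scaled_dot_product_attention(
            q[:, h : h + 1], k[:, h // group : h // group + 1], v[:, h // group : h // group + 1]
        )
        for h in range(heads)
    ],
    dim=1,
)
print(torch.allclose(out, ref, atol=1e-5))  # True
```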
+ heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +AttentionProcessor =AttnProcessor \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py new file mode 100644 index 0000000000..1258964385 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py @@ -0,0 +1,1556 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ..utils import deprecate +from .activations import FP32SiLU, get_activation +from .attention_processor import Attention + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 
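Before the embedding utilities below, here is a minimal sketch of the partial rotary embedding used by the Stable Audio processor above: only the first `rot_dim` channels of each head are rotated, the rest pass through unchanged, and the rotation follows the `use_real_unbind_dim=-2` layout (first half of the rotated channels treated as the real part, second half as the imaginary part). The frequency table mirrors the standard RoPE schedule; exact table construction in the patch lives in `get_1d_rotary_pos_embed`:

```python
import torch

def rotate_half_stacked(x):
    # `use_real_unbind_dim=-2` layout: first half of the channels is the "real"
    # part, second half is the "imaginary" part.
    x_real, x_imag = x.chunk(2, dim=-1)
    return torch.cat([-x_imag, x_real], dim=-1)

batch, heads, seq, head_dim, rot_dim = 1, 4, 10, 64, 32
query = torch.randn(batch, heads, seq, head_dim)

# cos/sin tables with shape [seq, rot_dim], broadcast over batch and heads.
freqs = torch.outer(
    torch.arange(seq, dtype=torch.float32),
    1.0 / 10000 ** (torch.arange(0, rot_dim, 2) / rot_dim),
)
cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)[None, None]
sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)[None, None]

q_rot, q_pass = query[..., :rot_dim], query[..., rot_dim:]
q_rot = q_rot * cos + rotate_half_stacked(q_rot) * sin
query = torch.cat([q_rot, q_pass], dim=-1)  # only the first rot_dim channels carry positional phase
print(query.shape)  # torch.Size([1, 4, 10, 64])
```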
+ + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, +) -> np.ndarray: + r""" + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. Spatial + grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # 2. Temporal + grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # 3. 
Concat + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] + return pos_embed + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for SD3 cropping.""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + pos_embed_type="sincos", + pos_embed_max_size=None, # For SD3 cropping + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + self.interpolation_scale = interpolation_scale + + # Calculate positional embeddings based on max size or default + if pos_embed_max_size: + grid_size = pos_embed_max_size 
+ else: + grid_size = int(num_patches**0.5) + + if pos_embed_type is None: + self.pos_embed = None + elif pos_embed_type == "sincos": + pos_embed = get_2d_sincos_pos_embed( + embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + persistent = True if pos_embed_max_size else False + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + else: + raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") + + def cropped_pos_embed(self, height, width): + """Crops positional embeddings for SD3 compatibility.""" + if self.pos_embed_max_size is None: + raise ValueError("`pos_embed_max_size` must be set for cropping.") + + height = height // self.patch_size + width = width // self.patch_size + if height > self.pos_embed_max_size: + raise ValueError( + f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + if width > self.pos_embed_max_size: + raise ValueError( + f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." + ) + + top = (self.pos_embed_max_size - height) // 2 + left = (self.pos_embed_max_size - width) // 2 + spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) + spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] + spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) + return spatial_pos_embed + + def forward(self, latent): + if self.pos_embed_max_size is not None: + height, width = latent.shape[-2:] + else: + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.pos_embed is None: + return latent.to(latent.dtype) + # Interpolate or crop positional embeddings as needed + if self.pos_embed_max_size: + pos_embed = self.cropped_pos_embed(height, width) + else: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + + +class LuminaPatchEmbed(nn.Module): + """2D Image to Patch Embedding with support for Lumina-T2X""" + + def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=embed_dim, + bias=bias, + ) + + def forward(self, x, freqs_cis): + """ + Patchifies and embeds the input tensor(s). + + Args: + x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded. + + Returns: + Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified + and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the + frequency tensor(s). 
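The 2D sin-cos table that `PatchEmbed` registers above is built by encoding the row and column index of each patch separately and concatenating the two halves. A compact numpy sketch of that construction (the `sincos_1d` helper mirrors `get_1d_sincos_pos_embed_from_grid`; interpolation/base-size scaling is omitted):

```python
import numpy as np

def sincos_1d(embed_dim, pos):
    # omega_i = 1 / 10000^(2i / D), then outer product with positions, then sin/cos concat
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

embed_dim, grid_size = 64, 8  # an 8x8 grid of patches
ys, xs = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")

# half of the channels encode the row index, half the column index
pos_embed = np.concatenate(
    [sincos_1d(embed_dim // 2, ys.astype(np.float32)), sincos_1d(embed_dim // 2, xs.astype(np.float32))],
    axis=1,
)
print(pos_embed.shape)  # (64, 64): one embedding per patch
```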
+ """ + freqs_cis = freqs_cis.to(x[0].device) + patch_height = patch_width = self.patch_size + batch_size, channel, height, width = x.size() + height_tokens, width_tokens = height // patch_height, width // patch_width + + x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute( + 0, 2, 4, 1, 3, 5 + ) + x = x.flatten(3) + x = self.proj(x) + x = x.flatten(1, 2) + + mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) + + return ( + x, + mask, + [(height, width)] * batch_size, + freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), + ) + + +class CogVideoXPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + embed_dim: int = 1920, + text_embed_dim: int = 4096, + bias: bool = True, + ) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + r""" + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + + embeds = torch.cat( + [text_embeds, image_embeds], dim=1 + ).contiguous() # [batch, seq_length + num_frames x height x width, channels] + return embeds + + +def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): + """ + RoPE for image tokens with 2d structure. + + Args: + embed_dim: (`int`): + The embedding dimension size + crops_coords (`Tuple[int]`) + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the positional embedding. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. 
+ """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) # [2, W, H] + + grid = grid.reshape([2, 1, *grid.shape[1:]]) + pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) + return pos_embed + + +def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): + assert embed_dim % 4 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + + if use_real: + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) + return cos, sin + else: + emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) + return emb + + +def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): + assert embed_dim % 4 == 0 + + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (H, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor + ) # (W, D/4) + emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) + emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) + + emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, +): + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end + index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 + data type. + + Args: + dim (`int`): Dimension of the frequency tensor. + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar + theta (`float`, *optional*, defaults to 10000.0): + Scaling factor for frequency computation. Defaults to 10000.0. + use_real (`bool`, *optional*): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + linear_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the context extrapolation. Defaults to 1.0. + ntk_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. + Returns: + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] + """ + assert dim % 2 == 0 + + if isinstance(pos, int): + pos = np.arange(pos) + theta = theta * ntk_factor + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] + t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] + freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] + if use_real and repeat_interleave_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + elif use_real: + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
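The `use_real=False` and `use_real=True` branches of the frequency precomputation above describe the same rotation, once as complex exponentials and once as separate cos/sin tables. A short check of that correspondence with a tiny dimension and sequence length:

```python
import torch

dim, seq_len, theta = 8, 4, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # [dim/2]
t = torch.arange(seq_len).float()
angles = torch.outer(t, freqs)                                     # [seq_len, dim/2]

# use_real=False branch: complex exponentials of unit magnitude
freqs_cis = torch.polar(torch.ones_like(angles), angles)

# use_real=True branch (repeat_interleave layout): separate cos/sin tables of width dim
cos = angles.cos().repeat_interleave(2, dim=1)
sin = angles.sin().repeat_interleave(2, dim=1)

print(torch.allclose(freqs_cis.real, angles.cos()), torch.allclose(freqs_cis.imag, angles.sin()))  # True True
print(cos.shape, sin.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```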
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Use for example in Lumina + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Use for example in Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class 
SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x + + +class ImagePositionalEmbeddings(nn.Module): + """ + Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the + height and width of the latent space. + + For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 + + For VQ-diffusion: + + Output vector embeddings are used as input for the transformer. + + Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. + + Args: + num_embed (`int`): + Number of embeddings for the latent pixels embeddings. + height (`int`): + Height of the latent image i.e. the number of height embeddings. + width (`int`): + Width of the latent image i.e. the number of width embeddings. + embed_dim (`int`): + Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. + """ + + def __init__( + self, + num_embed: int, + height: int, + width: int, + embed_dim: int, + ): + super().__init__() + + self.height = height + self.width = width + self.num_embed = num_embed + self.embed_dim = embed_dim + + self.emb = nn.Embedding(self.num_embed, embed_dim) + self.height_emb = nn.Embedding(self.height, embed_dim) + self.width_emb = nn.Embedding(self.width, embed_dim) + + def forward(self, index): + emb = self.emb(index) + + height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) + + # 1 x H x D -> 1 x H x 1 x D + height_emb = height_emb.unsqueeze(2) + + width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) + + # 1 x W x D -> 1 x 1 x W x D + width_emb = width_emb.unsqueeze(1) + + pos_emb = height_emb + width_emb + + # 1 x H x W x D -> 1 x L xD + pos_emb = pos_emb.view(1, self.height * self.width, -1) + + emb = emb + pos_emb[:, : emb.shape[1], :] + + return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
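The label dropout described here reserves one extra row of the embedding table as a "null" class and randomly maps labels to it during training, which is what enables classifier-free guidance at sampling time. A minimal sketch of that mechanism outside the class:

```python
import torch

num_classes, dropout_prob = 10, 0.1
labels = torch.tensor([3, 7, 1, 9])

# With probability `dropout_prob`, replace a label with the extra "null" class
# index (num_classes), whose embedding row stands in for "no condition".
drop_ids = torch.rand(labels.shape[0]) < dropout_prob
labels = torch.where(drop_ids, torch.full_like(labels, num_classes), labels)

embedding_table = torch.nn.Embedding(num_classes + 1, 16)  # +1 row for the null class
class_emb = embedding_table(labels)
print(class_emb.shape)  # torch.Size([4, 16])
```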
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + +class ImageProjection(nn.Module): + def __init__( + self, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 32, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + batch_size = image_embeds.shape[0] + + # image + image_embeds = self.image_embeds(image_embeds) + image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + image_embeds = self.norm(image_embeds) + return image_embeds + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from .attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + return self.norm(self.ff(image_embeds)) + + +class IPAdapterFaceIDImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.Tensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, 
hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class HunyuanDiTAttentionPool(nn.Module): + # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 + + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return 
x.squeeze(0) + + +class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): + def __init__( + self, + embedding_dim, + pooled_projection_dim=1024, + seq_len=256, + cross_attention_dim=2048, + use_style_cond_and_image_meta_size=True, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + + self.pooler = HunyuanDiTAttentionPool( + seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim + ) + + # Here we use a default learned embedder layer for future extension. + self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size + if use_style_cond_and_image_meta_size: + self.style_embedder = nn.Embedding(1, embedding_dim) + extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim + else: + extra_in_dim = pooled_projection_dim + + self.extra_embedder = PixArtAlphaTextProjection( + in_features=extra_in_dim, + hidden_size=embedding_dim * 4, + out_features=embedding_dim, + act_fn="silu_fp32", + ) + + def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) + + # extra condition1: text + pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) + + if self.use_style_cond_and_image_meta_size: + # extra condition2: image meta size embedding + image_meta_size = self.size_proj(image_meta_size.view(-1)) + image_meta_size = image_meta_size.to(dtype=hidden_dtype) + image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) + + # extra condition3: style embedding + style_embedding = self.style_embedder(style) # (N, embedding_dim) + + # Concatenate all extra vectors + extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) + else: + extra_cond = torch.cat([pooled_projections], dim=1) + + conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] + + return conditioning + + +class LuminaCombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): + super().__init__() + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) + + self.caption_embedder = nn.Sequential( + nn.LayerNorm(cross_attention_dim), + nn.Linear( + cross_attention_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat, caption_mask): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_mask_float = caption_mask.float().unsqueeze(-1) + caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1) + caption_feats_pool = caption_feats_pool.to(caption_feat) + caption_embed = self.caption_embedder(caption_feats_pool) + + conditioning = time_embed + caption_embed + + return conditioning + + +class TextTimeEmbedding(nn.Module): + def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): + super().__init__() + self.norm1 = 
nn.LayerNorm(encoder_dim) + self.pool = AttentionPooling(num_heads, encoder_dim) + self.proj = nn.Linear(encoder_dim, time_embed_dim) + self.norm2 = nn.LayerNorm(time_embed_dim) + + def forward(self, hidden_states): + hidden_states = self.norm1(hidden_states) + hidden_states = self.pool(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + +class ImageTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + + def forward(self, image_embeds: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + return time_image_embeds + + +class ImageHintTimeEmbedding(nn.Module): + def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + self.image_norm = nn.LayerNorm(time_embed_dim) + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 32, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(32, 32, 3, padding=1), + nn.SiLU(), + nn.Conv2d(32, 96, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(96, 96, 3, padding=1), + nn.SiLU(), + nn.Conv2d(96, 256, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(256, 4, 3, padding=1), + ) + + def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): + # image + time_image_embeds = self.image_proj(image_embeds) + time_image_embeds = self.image_norm(time_image_embeds) + hint = self.input_hint_block(hint) + return time_image_embeds, hint + + +class AttentionPooling(nn.Module): + # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 + + def __init__(self, num_heads, embed_dim, dtype=None): + super().__init__() + self.dtype = dtype + self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) + self.num_heads = num_heads + self.dim_per_head = embed_dim // self.num_heads + + def forward(self, x): + bs, length, width = x.size() + + def shape(x): + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, -1, self.num_heads, self.dim_per_head) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) + # (bs*n_heads, 
length, dim_per_head) --> (bs*n_heads, dim_per_head, length) + x = x.transpose(1, 2) + return x + + class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) + x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) + + # (bs*n_heads, class_token_length, dim_per_head) + q = shape(self.q_proj(class_token)) + # (bs*n_heads, length+class_token_length, dim_per_head) + k = shape(self.k_proj(x)) + v = shape(self.v_proj(x)) + + # (bs*n_heads, class_token_length, length+class_token_length): + scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # (bs*n_heads, dim_per_head, class_token_length) + a = torch.einsum("bts,bcs->bct", weight, v) + + # (bs, length+1, width) + a = a.reshape(bs, -1, 1).transpose(1, 2) + + return a[:, 0, :] # cls_token + + +def get_fourier_embeds_from_boundingbox(embed_dim, box): + """ + Args: + embed_dim: int + box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline + Returns: + [B x N x embed_dim] tensor of positional embeddings + """ + + batch_size, num_boxes = box.shape[:2] + + emb = 100 ** (torch.arange(embed_dim) / embed_dim) + emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) + emb = emb * box.unsqueeze(-1) + + emb = torch.stack((emb.sin(), emb.cos()), dim=-1) + emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) + + return emb + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder_dim = fourier_freqs + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C + + # learnable null embedding + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + # positionet with text only information + if positive_embeddings is not None: + # learnable null 
embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + + # positionet with text and image information + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + # learnable null embedding + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + +class PixArtAlphaTextProjection(nn.Module): + """ + Projects caption embeddings. Also handles dropout for classifier-free guidance. 
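The Fourier bounding-box embedding used by the GLIGEN projection above maps each of the four xyxy coordinates to `embed_dim` sin/cos pairs, giving `embed_dim * 2 * 4` channels per box. A compact sketch of that computation (the `fourier_box_embed` name is illustrative; it mirrors `get_fourier_embeds_from_boundingbox`):

```python
import torch

def fourier_box_embed(embed_dim, boxes):
    # per-frequency sin/cos of the four box coordinates -> [B, N, embed_dim * 2 * 4]
    freqs = 100 ** (torch.arange(embed_dim) / embed_dim)
    emb = freqs[None, None, None] * boxes.unsqueeze(-1)          # [B, N, 4, embed_dim]
    emb = torch.stack([emb.sin(), emb.cos()], dim=-1)            # [B, N, 4, embed_dim, 2]
    return emb.permute(0, 1, 3, 4, 2).reshape(*boxes.shape[:2], embed_dim * 2 * 4)

boxes = torch.tensor([[[0.1, 0.2, 0.5, 0.8]]])  # one xyxy box, normalized to [0, 1]
print(fourier_box_embed(8, boxes).shape)  # torch.Size([1, 1, 64])
```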
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = nn.SiLU() + elif act_fn == "silu_fp32": + self.act_1 = FP32SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class IPAdapterPlusImageProjectionBlock(nn.Module): + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents + + +class IPAdapterPlusImageProjection(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (torch.Tensor): Input Tensor. + Returns: + torch.Tensor: Output Tensor. 
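+
+        Example (illustrative sketch; the 257-token image feature length is only an
+        assumed CLIP-like input, and the defaults yield 8 query tokens of dim 1024):
+
+        ```py
+        import torch
+
+        resampler = IPAdapterPlusImageProjection()  # embed_dims=768, num_queries=8
+        image_feats = torch.randn(2, 257, 768)      # (batch, seq_len, embed_dims)
+        ip_tokens = resampler(image_feats)          # -> (2, 8, 1024)
+        ```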
+ """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim: int = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens: int = 4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 + + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + id_embeds (torch.Tensor): Input Tensor (ID embeds). + Returns: + torch.Tensor: Output Tensor. + """ + id_embeds = id_embeds.to(self.clip_embeds.dtype) + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + id_embeds = self.norm(id_embeds) + latents = id_embeds + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for block in self.layers: + residual = latents + latents = block(x, latents, residual) + + latents = self.proj_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out + + +class MultiIPAdapterImageProjection(nn.Module): + def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): + super().__init__() + self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.Tensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. 
list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + deprecation_message = ( + "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." + ) + deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/normalization.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/normalization.py new file mode 100644 index 0000000000..6a23843a9e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/normalization.py @@ -0,0 +1,451 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numbers +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.utils import is_torch_version +from .activations import get_activation +from .embeddings import ( + CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings, +) + + +class AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`, *optional*): The size of the embeddings dictionary. 
+ output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + chunk_dim (`int`, defaults to `0`): + """ + + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): + super().__init__() + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + + temb = self.linear(self.silu(temb)) + + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = temb.chunk(2, dim=0) + + x = self.norm(x) * (1 + scale) + shift + return x + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class AdaLayerNormZero(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True): + super().__init__() + if num_embeddings is not None: + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + else: + self.emb = None + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + elif norm_type == "fp32_layer_norm": + self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + timestep: Optional[torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + hidden_dtype: Optional[torch.dtype] = None, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + r""" + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError( + f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'." + ) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa + + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + """ + + def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # emb = self.emb(timestep, encoder_hidden_states, encoder_mask) + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + + return x, gate_msa, scale_mlp, gate_mlp + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class AdaGroupNorm(nn.Module): + r""" + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. 
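+
+    Example (illustrative; note the constructor's second argument is `out_dim`, the
+    number of channels to modulate, and the sizes below are arbitrary choices):
+
+    ```py
+    import torch
+
+    norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=32, act_fn="silu")
+    x = torch.randn(2, 64, 32, 32)  # (batch, out_dim, height, width)
+    emb = torch.randn(2, 512)       # (batch, embedding_dim)
+    out = norm(x, emb)              # same shape as x, scaled and shifted per channel
+    ```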
+ """ + + def __init__( + self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + + self.linear = nn.Linear(embedding_dim, out_dim * 2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + + x = F.group_norm(x, self.num_groups, eps=self.eps) + x = x * (1 + scale) + shift + return x + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
+ elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + # linear_2 + if out_dim is not None: + self.linear_2 = nn.Linear( + embedding_dim, + out_dim, + bias=bias, + ) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class CogVideoXLayerNormZero(nn.Module): + def __init__( + self, + conditioning_dim: int, + embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + ) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) + self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +class 
GlobalResponseNorm(nn.Module): + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * nx) + self.beta + x diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/__init__.py new file mode 100644 index 0000000000..6f6a082442 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/__init__.py @@ -0,0 +1,2 @@ +from .modeling_stable_audio import StableAudioProjectionModel +from .stable_audio_transformer import StableAudioDiTModel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py new file mode 100644 index 0000000000..b8f8a705de --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py @@ -0,0 +1,158 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import pi +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + times = times[..., None] + freqs = times * self.weights[None] * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((times, fouriered), dim=-1) + return fouriered + + +@dataclass +class StableAudioProjectionModelOutput(BaseOutput): + """ + Args: + Class for StableAudio projection layer's outputs. + text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. + seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio start hidden states. 
+ seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio end hidden states. + """ + + text_hidden_states: Optional[torch.Tensor] = None + seconds_start_hidden_states: Optional[torch.Tensor] = None + seconds_end_hidden_states: Optional[torch.Tensor] = None + + +class StableAudioNumberConditioner(nn.Module): + """ + A simple linear projection model to map numbers to a latent space. + + Args: + number_embedding_dim (`int`): + Dimensionality of the number embeddings. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + internal_dim (`int`): + Dimensionality of the intermediate number hidden states. + """ + + def __init__( + self, + number_embedding_dim, + min_value, + max_value, + internal_dim: Optional[int] = 256, + ): + super().__init__() + self.time_positional_embedding = nn.Sequential( + StableAudioPositionalEmbedding(internal_dim), + nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), + ) + + self.number_embedding_dim = number_embedding_dim + self.min_value = min_value + self.max_value = max_value + + def forward( + self, + floats: torch.Tensor, + ): + floats = floats.clamp(self.min_value, self.max_value) + + normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) + + # Cast floats to same type as embedder + embedder_dtype = next(self.time_positional_embedding.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + embedding = self.time_positional_embedding(normalized_floats) + float_embeds = embedding.view(-1, 1, self.number_embedding_dim) + + return float_embeds + + +class StableAudioProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map the conditioning values to a shared latent space. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the text encoder (T5). + conditioning_dim (`int`): + Dimensionality of the output conditioning tensors. + min_value (`int`): + The minimum value of the seconds number conditioning modules. 
+ max_value (`int`): + The maximum value of the seconds number conditioning modules + """ + + @register_to_config + def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): + super().__init__() + self.text_projection = ( + nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) + ) + self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + + def forward( + self, + text_hidden_states: Optional[torch.Tensor] = None, + start_seconds: Optional[torch.Tensor] = None, + end_seconds: Optional[torch.Tensor] = None, + ): + text_hidden_states = ( + text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) + ) + seconds_start_hidden_states = ( + start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) + ) + seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) + + return StableAudioProjectionModelOutput( + text_hidden_states=text_hidden_states, + seconds_start_hidden_states=seconds_start_hidden_states, + seconds_end_hidden_states=seconds_end_hidden_states, + ) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py new file mode 100644 index 0000000000..97151fa9f9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py @@ -0,0 +1,457 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
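+#
+# Illustrative note (not part of the upstream stable-audio-tools / diffusers sources):
+# the `StableAudioDiTModel` defined in this file denoises latent audio of shape
+# (batch, in_channels, sequence_length) given already-projected text, timing and
+# timestep conditioning. A rough smoke test with the default config, assuming
+# `torch` and `get_1d_rotary_pos_embed` (from the embeddings module) are available,
+# could look like:
+#
+#     model = StableAudioDiTModel()             # in_channels=64, inner_dim=24*64
+#     latents = torch.randn(1, 64, 1024)        # (batch, in_channels, seq_len)
+#     timestep = torch.tensor([1.0])
+#     text_cond = torch.randn(1, 130, 768)      # cross_attention_input_dim=768
+#     global_cond = torch.randn(1, 1, 1536)     # global_states_input_dim=1536
+#     rotary = get_1d_rotary_pos_embed(
+#         32, latents.shape[2] + 1, use_real=True, repeat_interleave_real=False
+#     )
+#     out = model(latents, timestep, text_cond, global_cond, rotary).sample
+#     # out.shape == (1, 64, 1024); the prepended global token is stripped again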
+ + +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from ..layers.attention import FeedForward +from ..layers.attention_processor import ( + Attention, + AttentionProcessor, + StableAudioAttnProcessor2_0, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.utils import is_torch_version, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + + +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. 
Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. + global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. 
+ cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. 
+ + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(StableAudioAttnProcessor2_0()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + cross_attention_hidden_states, + encoder_attention_mask, + rotary_embedding, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=cross_attention_hidden_states, + encoder_attention_mask=encoder_attention_mask, + rotary_embedding=rotary_embedding, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/__init__.py new file mode 100644 index 0000000000..b1ddb485f1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_stable_audio import StableAudioPipeline \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py new file mode 100644 index 0000000000..4fe082d889 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -0,0 +1,745 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from ...models import AutoencoderOobleck, StableAudioDiTModel +from ...models.embeddings import get_1d_rotary_pos_embed +from ...schedulers import EDMDPMSolverMultistepScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_stable_audio import StableAudioProjectionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. 
+ scheduler ([`EDMDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. + """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: EDMDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. 
Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. 
If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. 
The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py new file mode 100644 index 0000000000..5bad3d9b15 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py @@ -0,0 +1 @@ +from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py new file mode 100644 index 0000000000..69d0458d12 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,572 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
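+
+# A minimal usage sketch (illustrative only, mirroring how the pipeline above drives this
+# scheduler: `set_timesteps` once, then `scale_model_input` + `step` on every iteration).
+# `model` is a hypothetical denoiser standing in for the Stable Audio DiT.
+#
+#     scheduler = CosineDPMSolverMultistepScheduler()
+#     scheduler.set_timesteps(num_inference_steps=100, device="cpu")
+#     latents = torch.randn(1, 64, 1024) * scheduler.init_noise_sigma
+#     for t in scheduler.timesteps:
+#         model_input = scheduler.scale_model_input(latents, t)
+#         noise_pred = model(model_input, t)
+#         latents = scheduler.step(noise_pred, t, latents).prev_sample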
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler +from diffusers.scheduling_utils import SchedulerMixin, SchedulerOutput + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
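+
+    For orientation (this mirrors `precondition_noise` below and adds no new configuration):
+    noise levels `sigma` are mapped to timesteps via `t = atan(sigma) / pi * 2`, so timesteps
+    lie in `[0, 1)` and saturate towards 1 for large `sigma`. A quick sketch, assuming only
+    `torch` and `math` are imported:
+
+    ```py
+    >>> sigma = torch.tensor([0.3, 1.0, 500.0])
+    >>> sigma.atan() / math.pi * 2
+    tensor([0.1855, 0.5000, 0.9987])
+    ```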
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. 
DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
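+
+        For reference, with `h = lambda_t - lambda_s0`, `r0 = (lambda_s0 - lambda_s1) / h`,
+        `D0 = m0` and `D1 = (m0 - m1) / r0`, the midpoint variant implemented below computes
+        (a plain-text transcription of the code, not an alternative formulation):
+
+            x_t = (sigma_t / sigma_s0) * exp(-h) * sample
+                  + alpha_t * (1 - exp(-2h)) * D0
+                  + 0.5 * alpha_t * (1 - exp(-2h)) * D1
+                  + sigma_t * sqrt(1 - exp(-2h)) * noise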
+ """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
+ + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * 
timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/__init__.py new file mode 100644 index 0000000000..24a5f3bc5f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/__init__.py @@ -0,0 +1 @@ +from .autoencoder_oobleck import AutoencoderOobleck \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py new file mode 100644 index 0000000000..e8e372a709 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py @@ -0,0 +1,464 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.torch_utils import randn_tensor +from ..modeling_utils import ModelMixin + + +class Snake1d(nn.Module): + """ + A 1-dimensional Snake activation function module. + """ + + def __init__(self, hidden_dim, logscale=True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + + self.alpha.requires_grad = True + self.beta.requires_grad = True + self.logscale = logscale + + def forward(self, hidden_states): + shape = hidden_states.shape + + alpha = self.alpha if not self.logscale else torch.exp(self.alpha) + beta = self.beta if not self.logscale else torch.exp(self.beta) + + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + hidden_states = hidden_states.reshape(shape) + return hidden_states + + +class OobleckResidualUnit(nn.Module): + """ + A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations. 
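+
+    The unit applies `Snake1d -> Conv1d(kernel_size=7, dilation=dilation) -> Snake1d ->
+    Conv1d(kernel_size=1)` and adds the (centre-cropped, if needed) input back as a residual;
+    with the padding chosen in `__init__`, the time dimension is preserved. A rough shape
+    check (illustrative only, assuming `torch` is imported):
+
+    ```py
+    >>> unit = OobleckResidualUnit(dimension=16, dilation=3)
+    >>> unit(torch.randn(1, 16, 1024)).shape
+    torch.Size([1, 16, 1024])
+    ```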
+ """ + + def __init__(self, dimension: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state): + """ + Forward pass through the residual unit. + + Args: + hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`): + Input tensor . + + Returns: + output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`) + Input tensor after passing through the residual unit. + """ + output_tensor = hidden_state + output_tensor = self.conv1(self.snake1(output_tensor)) + output_tensor = self.conv2(self.snake2(output_tensor)) + + padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + output_tensor = hidden_state + output_tensor + return output_tensor + + +class OobleckEncoderBlock(nn.Module): + """Encoder block used in Oobleck encoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9) + self.snake1 = Snake1d(input_dim) + self.conv1 = weight_norm( + nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + + def forward(self, hidden_state): + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.snake1(self.res_unit3(hidden_state)) + hidden_state = self.conv1(hidden_state) + + return hidden_state + + +class OobleckDecoderBlock(nn.Module): + """Decoder block used in Oobleck decoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ) + ) + self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state): + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.res_unit3(hidden_state) + + return hidden_state + + +class OobleckDiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.scale = parameters.chunk(2, dim=1) + self.std = nn.functional.softplus(self.scale) + 1e-4 + self.var = self.std * self.std + self.logvar = torch.log(self.var) + self.deterministic = deterministic + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: 
+ return torch.Tensor([0.0]) + else: + if other is None: + return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean() + else: + normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var + var_ratio = self.var / other.var + logvar_diff = self.logvar - other.logvar + + kl = normalized_diff + var_ratio + logvar_diff - 1 + + kl = kl.sum(1).mean() + return kl + + def mode(self) -> torch.Tensor: + return self.mean + + +@dataclass +class AutoencoderOobleckOutput(BaseOutput): + """ + Output of AutoencoderOobleck encoding method. + + Args: + latent_dist (`OobleckDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and standard deviation of + `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents + from the distribution. + """ + + latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821 + + +@dataclass +class OobleckDecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + + +class OobleckEncoder(nn.Module): + """Oobleck Encoder""" + + def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples): + super().__init__() + + strides = downsampling_ratios + channel_multiples = [1] + channel_multiples + + # Create first convolution + self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3)) + + self.block = [] + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride_index, stride in enumerate(strides): + self.block += [ + OobleckEncoderBlock( + input_dim=encoder_hidden_size * channel_multiples[stride_index], + output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(self.block) + d_model = encoder_hidden_size * channel_multiples[-1] + self.snake1 = Snake1d(d_model) + self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for module in self.block: + hidden_state = module(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv2(hidden_state) + + return hidden_state + + +class OobleckDecoder(nn.Module): + """Oobleck Decoder""" + + def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples): + super().__init__() + + strides = upsampling_ratios + channel_multiples = [1] + channel_multiples + + # Add first conv layer + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + + # Add upsampling + MRF blocks + block = [] + for stride_index, stride in enumerate(strides): + block += [ + OobleckDecoderBlock( + input_dim=channels * channel_multiples[len(strides) - stride_index], + output_dim=channels * channel_multiples[len(strides) - stride_index - 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(block) + output_dim = channels + self.snake1 = Snake1d(output_dim) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for layer in self.block: + hidden_state = layer(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = 
self.conv2(hidden_state) + + return hidden_state + + +class AutoencoderOobleck(ModelMixin, ConfigMixin): + r""" + An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First + introduced in Stable Audio. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + encoder_hidden_size (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the encoder. + downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`): + Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder. + channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`): + Multiples used to determine the hidden sizes of the hidden layers. + decoder_channels (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the decoder. + decoder_input_channels (`int`, *optional*, defaults to 64): + Input dimension for the decoder. Corresponds to the latent dimension. + audio_channels (`int`, *optional*, defaults to 2): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 44100): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + encoder_hidden_size=128, + downsampling_ratios=[2, 4, 4, 8, 8], + channel_multiples=[1, 2, 4, 8, 16], + decoder_channels=128, + decoder_input_channels=64, + audio_channels=2, + sampling_rate=44100, + ): + super().__init__() + + self.encoder_hidden_size = encoder_hidden_size + self.downsampling_ratios = downsampling_ratios + self.decoder_channels = decoder_channels + self.upsampling_ratios = downsampling_ratios[::-1] + self.hop_length = int(np.prod(downsampling_ratios)) + self.sampling_rate = sampling_rate + + self.encoder = OobleckEncoder( + encoder_hidden_size=encoder_hidden_size, + audio_channels=audio_channels, + downsampling_ratios=downsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.decoder = OobleckDecoder( + channels=decoder_channels, + input_channels=decoder_input_channels, + audio_channels=audio_channels, + upsampling_ratios=self.upsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.use_slicing = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. 
If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + posterior = OobleckDiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderOobleckOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]: + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[OobleckDecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.OobleckDecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple` + is returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return OobleckDecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[OobleckDecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple. 
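+
+        A minimal round-trip sketch (illustrative; assumes the default config, whose hop
+        length is 2 * 4 * 4 * 8 * 8 = 2048, and an input length that is a multiple of it):
+
+        ```py
+        >>> vae = AutoencoderOobleck()
+        >>> waveform = torch.randn(1, 2, 2048 * 64)  # (batch, audio_channels, samples)
+        >>> vae(waveform, sample_posterior=True).sample.shape
+        torch.Size([1, 2, 131072])
+        ```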
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) diff --git a/PyTorch/dev/cv/image_classification/SlowFast_ID0646_for_PyTorch/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml b/PyTorch/dev/cv/image_classification/SlowFast_ID0646_for_PyTorch/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml deleted file mode 100644 index 6944c6fdf3..0000000000 --- a/PyTorch/dev/cv/image_classification/SlowFast_ID0646_for_PyTorch/detectron2/projects/Panoptic-DeepLab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml +++ /dev/null @@ -1,42 +0,0 @@ -_BASE_: ../Cityscapes-PanopticSegmentation/Base-PanopticDeepLab-OS16.yaml -MODEL: - WEIGHTS: "detectron2://DeepLab/R-52.pkl" - PIXEL_MEAN: [123.675, 116.280, 103.530] - PIXEL_STD: [58.395, 57.120, 57.375] - BACKBONE: - NAME: "build_resnet_deeplab_backbone" - RESNETS: - DEPTH: 50 - NORM: "SyncBN" - RES5_MULTI_GRID: [1, 2, 4] - STEM_TYPE: "deeplab" - STEM_OUT_CHANNELS: 128 - STRIDE_IN_1X1: False - SEM_SEG_HEAD: - NUM_CLASSES: 133 - LOSS_TOP_K: 1.0 - USE_DEPTHWISE_SEPARABLE_CONV: True - PANOPTIC_DEEPLAB: - STUFF_AREA: 4096 - NMS_KERNEL: 41 - SIZE_DIVISIBILITY: 640 - USE_DEPTHWISE_SEPARABLE_CONV: True -DATASETS: - TRAIN: ("coco_2017_train_panoptic",) - TEST: ("coco_2017_val_panoptic",) -SOLVER: - BASE_LR: 0.0005 - MAX_ITER: 200000 - IMS_PER_BATCH: 64 -INPUT: - FORMAT: "RGB" - GAUSSIAN_SIGMA: 8 - MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 16)]"] - MIN_SIZE_TRAIN_SAMPLING: "choice" - MIN_SIZE_TEST: 640 - MAX_SIZE_TRAIN: 960 - MAX_SIZE_TEST: 640 - CROP: - ENABLED: True - TYPE: "absolute" - SIZE: (640, 640) -- Gitee From a90c4587d7a5b708b726c738687e79621df10b66 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:01:02 +0800 Subject: [PATCH 02/32] add stable_audio --- .../stableaudio/layers/activations.py | 4 ++-- .../pipeline/pipeline_stable_audio.py | 19 ++++++++++--------- .../{scheduler => schedulers}/__init__.py | 0 .../scheduling_cosine_dpmsolver_multistep.py | 0 .../stableaudio/vae/autoencoder_oobleck.py | 10 +++++----- 5 files changed, 17 insertions(+), 16 deletions(-) rename MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/{scheduler => schedulers}/__init__.py (100%) rename MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/{scheduler => schedulers}/scheduling_cosine_dpmsolver_multistep.py (100%) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py index fb24a36bae..7cd6938b22 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/activations.py @@ -17,8 +17,8 @@ import torch import torch.nn.functional as F from torch import nn -from ..utils import deprecate -from ..utils.import_utils import is_torch_npu_available +from diffusers.utils import deprecate +from diffusers.utils.import_utils import is_torch_npu_available if is_torch_npu_available(): diff 
--git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py index 4fe082d889..0a4698ec09 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -22,16 +22,17 @@ from transformers import ( T5TokenizerFast, ) -from ...models import AutoencoderOobleck, StableAudioDiTModel -from ...models.embeddings import get_1d_rotary_pos_embed -from ...schedulers import EDMDPMSolverMultistepScheduler -from ...utils import ( +from ..models import StableAudioDiTModel +from ..models.modeling_stable_audio import StableAudioProjectionModel +from ..layers.embeddings import get_1d_rotary_pos_embed +from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler +from ..vae.autoencoder_oobleck import AutoencoderOobleck +from diffusers.utils import ( logging, replace_example_docstring, ) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline -from .modeling_stable_audio import StableAudioProjectionModel +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -93,7 +94,7 @@ class StableAudioPipeline(DiffusionPipeline): Tokenizer to tokenize text for the frozen text-encoder. transformer ([`StableAudioDiTModel`]): A `StableAudioDiTModel` to denoise the encoded audio latents. - scheduler ([`EDMDPMSolverMultistepScheduler`]): + scheduler ([`CosineDPMSolverMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. 
""" @@ -106,7 +107,7 @@ class StableAudioPipeline(DiffusionPipeline): projection_model: StableAudioProjectionModel, tokenizer: Union[T5Tokenizer, T5TokenizerFast], transformer: StableAudioDiTModel, - scheduler: EDMDPMSolverMultistepScheduler, + scheduler: CosineDPMSolverMultistepScheduler, ): super().__init__() diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/__init__.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/__init__.py rename to MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/__init__.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py similarity index 100% rename from MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/scheduler/scheduling_cosine_dpmsolver_multistep.py rename to MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py index e8e372a709..8650d40b32 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn from torch.nn.utils import weight_norm -from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import BaseOutput -from ...utils.accelerate_utils import apply_forward_hook -from ...utils.torch_utils import randn_tensor -from ..modeling_utils import ModelMixin +from diffsuers.configuration_utils import ConfigMixin, register_to_config +from diffsuers.utils import BaseOutput +from diffsuers.utils.accelerate_utils import apply_forward_hook +from diffsuers.utils.torch_utils import randn_tensor +from diffsuers.models.modeling_utils import ModelMixin class Snake1d(nn.Module): -- Gitee From bec6f6fe2abe81aada6554ea6eb70744b9d537ff Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:02:16 +0800 Subject: [PATCH 03/32] add stable_audio --- .../stableaudio/layers/embeddings.py | 881 +----------------- .../models/modeling_stable_audio.py | 6 +- 2 files changed, 4 insertions(+), 883 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py index 1258964385..de34071086 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py @@ -19,112 +19,11 @@ import torch import torch.nn.functional as F from torch import nn -from ..utils import deprecate +from diffusers.utils import deprecate from .activations import FP32SiLU, get_activation from .attention_processor import Attention -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the 
implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - Args - timesteps (torch.Tensor): - a 1-D Tensor of N indices, one per batch element. These may be fractional. - embedding_dim (int): - the dimension of the output. - flip_sin_to_cos (bool): - Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) - downscale_freq_shift (float): - Controls the delta between frequencies between dimensions - scale (float): - Scaling factor applied to the embeddings. - max_period (int): - Controls the maximum frequency of the embeddings - Returns - torch.Tensor: an [N x dim] Tensor of positional embeddings. - """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_3d_sincos_pos_embed( - embed_dim: int, - spatial_size: Union[int, Tuple[int, int]], - temporal_size: int, - spatial_interpolation_scale: float = 1.0, - temporal_interpolation_scale: float = 1.0, -) -> np.ndarray: - r""" - Args: - embed_dim (`int`): - spatial_size (`int` or `Tuple[int, int]`): - temporal_size (`int`): - spatial_interpolation_scale (`float`, defaults to 1.0): - temporal_interpolation_scale (`float`, defaults to 1.0): - """ - if embed_dim % 4 != 0: - raise ValueError("`embed_dim` must be divisible by 4") - if isinstance(spatial_size, int): - spatial_size = (spatial_size, spatial_size) - - embed_dim_spatial = 3 * embed_dim // 4 - embed_dim_temporal = embed_dim // 4 - - # 1. Spatial - grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale - grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) - pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) - - # 2. Temporal - grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale - pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) - - # 3. 
Concat - pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] - pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3] - - pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] - pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4] - - pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D] - return pos_embed - - def get_2d_sincos_pos_embed( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): @@ -288,156 +187,6 @@ class PatchEmbed(nn.Module): return (latent + pos_embed).to(latent.dtype) -class LuminaPatchEmbed(nn.Module): - """2D Image to Patch Embedding with support for Lumina-T2X""" - - def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True): - super().__init__() - self.patch_size = patch_size - self.proj = nn.Linear( - in_features=patch_size * patch_size * in_channels, - out_features=embed_dim, - bias=bias, - ) - - def forward(self, x, freqs_cis): - """ - Patchifies and embeds the input tensor(s). - - Args: - x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded. - - Returns: - Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified - and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the - frequency tensor(s). - """ - freqs_cis = freqs_cis.to(x[0].device) - patch_height = patch_width = self.patch_size - batch_size, channel, height, width = x.size() - height_tokens, width_tokens = height // patch_height, width // patch_width - - x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute( - 0, 2, 4, 1, 3, 5 - ) - x = x.flatten(3) - x = self.proj(x) - x = x.flatten(1, 2) - - mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) - - return ( - x, - mask, - [(height, width)] * batch_size, - freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0), - ) - - -class CogVideoXPatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 2, - in_channels: int = 16, - embed_dim: int = 1920, - text_embed_dim: int = 4096, - bias: bool = True, - ) -> None: - super().__init__() - self.patch_size = patch_size - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - self.text_proj = nn.Linear(text_embed_dim, embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - r""" - Args: - text_embeds (`torch.Tensor`): - Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). - image_embeds (`torch.Tensor`): - Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). 
- """ - text_embeds = self.text_proj(text_embeds) - - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] - - embeds = torch.cat( - [text_embeds, image_embeds], dim=1 - ).contiguous() # [batch, seq_length + num_frames x height x width, channels] - return embeds - - -def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): - """ - RoPE for image tokens with 2d structure. - - Args: - embed_dim: (`int`): - The embedding dimension size - crops_coords (`Tuple[int]`) - The top-left and bottom-right coordinates of the crop. - grid_size (`Tuple[int]`): - The grid size of the positional embedding. - use_real (`bool`): - If True, return real part and imaginary part separately. Otherwise, return complex numbers. - - Returns: - `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. - """ - start, stop = crops_coords - grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) # [2, W, H] - - grid = grid.reshape([2, 1, *grid.shape[1:]]) - pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) - return pos_embed - - -def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): - assert embed_dim % 4 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_rotary_pos_embed( - embed_dim // 2, grid[0].reshape(-1), use_real=use_real - ) # (H*W, D/2) if use_real else (H*W, D/4) - emb_w = get_1d_rotary_pos_embed( - embed_dim // 2, grid[1].reshape(-1), use_real=use_real - ) # (H*W, D/2) if use_real else (H*W, D/4) - - if use_real: - cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) - sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) - return cos, sin - else: - emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) - return emb - - -def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0): - assert embed_dim % 4 == 0 - - emb_h = get_1d_rotary_pos_embed( - embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor - ) # (H, D/4) - emb_w = get_1d_rotary_pos_embed( - embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor - ) # (W, D/4) - emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1) - emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1) - - emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2) - return emb - - def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], @@ -540,104 +289,6 @@ def apply_rotary_emb( return x_out.type_as(x) -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - sample_proj_bias=True, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, 
in_channels, bias=False) - else: - self.cond_proj = None - - self.act = get_activation(act_fn) - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) - - if post_act_fn is None: - self.post_act = None - else: - self.post_act = get_activation(post_act_fn) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - self.scale = scale - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - scale=self.scale, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - del self.weight - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.weight = self.W - del self.W - - def forward(self, x): - if self.log: - x = torch.log(x) - - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - - if self.flip_sin_to_cos: - out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) - else: - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out - - class SinusoidalPositionalEmbedding(nn.Module): """Apply positional information to a sequence of embeddings. @@ -730,69 +381,6 @@ class ImagePositionalEmbeddings(nn.Module): return emb -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. 
- """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels: torch.LongTensor, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class TextImageProjection(nn.Module): - def __init__( - self, - text_embed_dim: int = 1024, - image_embed_dim: int = 768, - cross_attention_dim: int = 768, - num_image_text_embeds: int = 10, - ): - super().__init__() - - self.num_image_text_embeds = num_image_text_embeds - self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) - self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - batch_size = text_embeds.shape[0] - - # image - image_text_embeds = self.image_embeds(image_embeds) - image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) - - # text - text_embeds = self.text_proj(text_embeds) - - return torch.cat([image_text_embeds, text_embeds], dim=1) - - class ImageProjection(nn.Module): def __init__( self, @@ -816,473 +404,6 @@ class ImageProjection(nn.Module): return image_embeds -class IPAdapterFullImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): - super().__init__() - from .attention import FeedForward - - self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor): - return self.norm(self.ff(image_embeds)) - - -class IPAdapterFaceIDImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): - super().__init__() - from .attention import FeedForward - - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.Tensor): - x = self.ff(image_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - return self.norm(x) - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) - - def forward(self, timestep, class_labels, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - class_labels = self.class_embedder(class_labels) # (N, D) - - conditioning = timesteps_emb + class_labels # (N, D) - - return conditioning - - -class CombinedTimestepTextProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = 
TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") - - def forward(self, timestep, pooled_projection): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - pooled_projections = self.text_embedder(pooled_projection) - - conditioning = timesteps_emb + pooled_projections - - return conditioning - - -class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") - - def forward(self, timestep, guidance, pooled_projection): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D) - - time_guidance_emb = timesteps_emb + guidance_emb - - pooled_projections = self.text_embedder(pooled_projection) - conditioning = time_guidance_emb + pooled_projections - - return conditioning - - -class HunyuanDiTAttentionPool(nn.Module): - # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6 - - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.permute(1, 0, 2) # NLC -> LNC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], - key=x, - value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False, - ) - return x.squeeze(0) - - -class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module): - def __init__( - self, - embedding_dim, - pooled_projection_dim=1024, - seq_len=256, - cross_attention_dim=2048, - use_style_cond_and_image_meta_size=True, - ): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - - self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - - 
self.pooler = HunyuanDiTAttentionPool( - seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim - ) - - # Here we use a default learned embedder layer for future extension. - self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size - if use_style_cond_and_image_meta_size: - self.style_embedder = nn.Embedding(1, embedding_dim) - extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim - else: - extra_in_dim = pooled_projection_dim - - self.extra_embedder = PixArtAlphaTextProjection( - in_features=extra_in_dim, - hidden_size=embedding_dim * 4, - out_features=embedding_dim, - act_fn="silu_fp32", - ) - - def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256) - - # extra condition1: text - pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024) - - if self.use_style_cond_and_image_meta_size: - # extra condition2: image meta size embedding - image_meta_size = self.size_proj(image_meta_size.view(-1)) - image_meta_size = image_meta_size.to(dtype=hidden_dtype) - image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536) - - # extra condition3: style embedding - style_embedding = self.style_embedder(style) # (N, embedding_dim) - - # Concatenate all extra vectors - extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1) - else: - extra_cond = torch.cat([pooled_projections], dim=1) - - conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D] - - return conditioning - - -class LuminaCombinedTimestepCaptionEmbedding(nn.Module): - def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): - super().__init__() - self.time_proj = Timesteps( - num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 - ) - - self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size) - - self.caption_embedder = nn.Sequential( - nn.LayerNorm(cross_attention_dim), - nn.Linear( - cross_attention_dim, - hidden_size, - bias=True, - ), - ) - - def forward(self, timestep, caption_feat, caption_mask): - # timestep embedding: - time_freq = self.time_proj(timestep) - time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) - - # caption condition embedding: - caption_mask_float = caption_mask.float().unsqueeze(-1) - caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1) - caption_feats_pool = caption_feats_pool.to(caption_feat) - caption_embed = self.caption_embedder(caption_feats_pool) - - conditioning = time_embed + caption_embed - - return conditioning - - -class TextTimeEmbedding(nn.Module): - def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): - super().__init__() - self.norm1 = nn.LayerNorm(encoder_dim) - self.pool = AttentionPooling(num_heads, encoder_dim) - self.proj = nn.Linear(encoder_dim, time_embed_dim) - self.norm2 = nn.LayerNorm(time_embed_dim) - - def forward(self, hidden_states): - hidden_states = self.norm1(hidden_states) - hidden_states = self.pool(hidden_states) - hidden_states = self.proj(hidden_states) - hidden_states = self.norm2(hidden_states) - return hidden_states - - -class TextImageTimeEmbedding(nn.Module): - def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 
1536): - super().__init__() - self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) - self.text_norm = nn.LayerNorm(time_embed_dim) - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - - def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): - # text - time_text_embeds = self.text_proj(text_embeds) - time_text_embeds = self.text_norm(time_text_embeds) - - # image - time_image_embeds = self.image_proj(image_embeds) - - return time_image_embeds + time_text_embeds - - -class ImageTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - - def forward(self, image_embeds: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - return time_image_embeds - - -class ImageHintTimeEmbedding(nn.Module): - def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): - super().__init__() - self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) - self.image_norm = nn.LayerNorm(time_embed_dim) - self.input_hint_block = nn.Sequential( - nn.Conv2d(3, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 16, 3, padding=1), - nn.SiLU(), - nn.Conv2d(16, 32, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(32, 32, 3, padding=1), - nn.SiLU(), - nn.Conv2d(32, 96, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(96, 96, 3, padding=1), - nn.SiLU(), - nn.Conv2d(96, 256, 3, padding=1, stride=2), - nn.SiLU(), - nn.Conv2d(256, 4, 3, padding=1), - ) - - def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): - # image - time_image_embeds = self.image_proj(image_embeds) - time_image_embeds = self.image_norm(time_image_embeds) - hint = self.input_hint_block(hint) - return time_image_embeds, hint - - -class AttentionPooling(nn.Module): - # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 - - def __init__(self, num_heads, embed_dim, dtype=None): - super().__init__() - self.dtype = dtype - self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) - self.num_heads = num_heads - self.dim_per_head = embed_dim // self.num_heads - - def forward(self, x): - bs, length, width = x.size() - - def shape(x): - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, -1, self.num_heads, self.dim_per_head) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) - # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) - x = x.transpose(1, 2) - return x - - class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) - x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) - - # (bs*n_heads, class_token_length, dim_per_head) - q = shape(self.q_proj(class_token)) - # (bs*n_heads, length+class_token_length, dim_per_head) - k = shape(self.k_proj(x)) - v = shape(self.v_proj(x)) - - # (bs*n_heads, class_token_length, length+class_token_length): - scale = 1 / 
math.sqrt(math.sqrt(self.dim_per_head)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - - # (bs*n_heads, dim_per_head, class_token_length) - a = torch.einsum("bts,bcs->bct", weight, v) - - # (bs, length+1, width) - a = a.reshape(bs, -1, 1).transpose(1, 2) - - return a[:, 0, :] # cls_token - - -def get_fourier_embeds_from_boundingbox(embed_dim, box): - """ - Args: - embed_dim: int - box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline - Returns: - [B x N x embed_dim] tensor of positional embeddings - """ - - batch_size, num_boxes = box.shape[:2] - - emb = 100 ** (torch.arange(embed_dim) / embed_dim) - emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) - emb = emb * box.unsqueeze(-1) - - emb = torch.stack((emb.sin(), emb.cos()), dim=-1) - emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) - - return emb - - -class GLIGENTextBoundingboxProjection(nn.Module): - def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): - super().__init__() - self.positive_len = positive_len - self.out_dim = out_dim - - self.fourier_embedder_dim = fourier_freqs - self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy - - if isinstance(out_dim, tuple): - out_dim = out_dim[0] - - if feature_type == "text-only": - self.linears = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - - elif feature_type == "text-image": - self.linears_text = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - self.linears_image = nn.Sequential( - nn.Linear(self.positive_len + self.position_dim, 512), - nn.SiLU(), - nn.Linear(512, 512), - nn.SiLU(), - nn.Linear(512, out_dim), - ) - self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) - - self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) - - def forward( - self, - boxes, - masks, - positive_embeddings=None, - phrases_masks=None, - image_masks=None, - phrases_embeddings=None, - image_embeddings=None, - ): - masks = masks.unsqueeze(-1) - - # embedding position (it may includes padding as placeholder) - xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C - - # learnable null embedding - xyxy_null = self.null_position_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null - - # positionet with text only information - if positive_embeddings is not None: - # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null - - objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) - - # positionet with text and image information - else: - phrases_masks = phrases_masks.unsqueeze(-1) - image_masks = image_masks.unsqueeze(-1) - - # learnable null embedding - text_null = self.null_text_feature.view(1, 1, -1) - image_null = 
self.null_image_feature.view(1, 1, -1) - - # replace padding with learnable null embedding - phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null - image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null - - objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) - objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) - objs = torch.cat([objs_text, objs_image], dim=1) - - return objs - - class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py index b8f8a705de..6e31de55c8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/modeling_stable_audio.py @@ -20,9 +20,9 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin -from ...utils import BaseOutput, logging +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -- Gitee From e62037a02c6bf1919c23360e4164ae10417f3358 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:03:37 +0800 Subject: [PATCH 04/32] add stable_audio --- .../foundation/stable_audio/stableaudio/layers/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py index d8f4b1ceac..46c4ea495b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention.py @@ -19,7 +19,7 @@ from torch import nn from diffusers.utils import deprecate, logging from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU -from .attention_processor import Attention, JointAttnProcessor2_0 +from .attention_processor import Attention from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero -- Gitee From 8f57f533f7359591bf45825da2bf76b15872498a Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:08:31 +0800 Subject: [PATCH 05/32] add stable_audio --- .../stableaudio/layers/embeddings.py | 384 ++++++++---------- 1 file changed, 177 insertions(+), 207 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py index de34071086..a4d66b7d1d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py @@ -24,6 +24,61 @@ from .activations import FP32SiLU, get_activation from .attention_processor import Attention + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + 
flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + def get_2d_sincos_pos_embed( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): @@ -404,6 +459,128 @@ class ImageProjection(nn.Module): return image_embeds +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + sample_proj_bias=True, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. 
Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels: torch.LongTensor, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning + + class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. @@ -468,210 +645,3 @@ class PixArtAlphaTextProjection(nn.Module): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states - - -class IPAdapterPlusImageProjectionBlock(nn.Module): - def __init__( - self, - embed_dims: int = 768, - dim_head: int = 64, - heads: int = 16, - ffn_ratio: float = 4, - ) -> None: - super().__init__() - from .attention import FeedForward - - self.ln0 = nn.LayerNorm(embed_dims) - self.ln1 = nn.LayerNorm(embed_dims) - self.attn = Attention( - query_dim=embed_dims, - dim_head=dim_head, - heads=heads, - out_bias=False, - ) - self.ff = nn.Sequential( - nn.LayerNorm(embed_dims), - FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), - ) - - def forward(self, x, latents, residual): - encoder_hidden_states = self.ln0(x) - latents = self.ln1(latents) - encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) - latents = self.attn(latents, encoder_hidden_states) + residual - latents = self.ff(latents) + latents - return latents - - -class IPAdapterPlusImageProjection(nn.Module): - """Resampler of IP-Adapter Plus. - - Args: - embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, - that is the same - number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): - The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults - to 8. 
dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_queries (int): - The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio - of feedforward network hidden - layer channels. Defaults to 4. - """ - - def __init__( - self, - embed_dims: int = 768, - output_dims: int = 1024, - hidden_dims: int = 1280, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_queries: int = 8, - ffn_ratio: float = 4, - ) -> None: - super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) - - self.proj_in = nn.Linear(embed_dims, hidden_dims) - - self.proj_out = nn.Linear(hidden_dims, output_dims) - self.norm_out = nn.LayerNorm(output_dims) - - self.layers = nn.ModuleList( - [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - x (torch.Tensor): Input Tensor. - Returns: - torch.Tensor: Output Tensor. - """ - latents = self.latents.repeat(x.size(0), 1, 1) - - x = self.proj_in(x) - - for block in self.layers: - residual = latents - latents = block(x, latents, residual) - - latents = self.proj_out(latents) - return self.norm_out(latents) - - -class IPAdapterFaceIDPlusImageProjection(nn.Module): - """FacePerceiverResampler of IP-Adapter Plus. - - Args: - embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, - that is the same - number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): - The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults - to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. - ffn_ratio (float): The expansion ratio of feedforward network hidden - layer channels. Defaults to 4. - ffproj_ratio (float): The expansion ratio of feedforward network hidden - layer channels (for ID embeddings). Defaults to 4. - """ - - def __init__( - self, - embed_dims: int = 768, - output_dims: int = 768, - hidden_dims: int = 1280, - id_embeddings_dim: int = 512, - depth: int = 4, - dim_head: int = 64, - heads: int = 16, - num_tokens: int = 4, - num_queries: int = 8, - ffn_ratio: float = 4, - ffproj_ratio: int = 2, - ) -> None: - super().__init__() - from .attention import FeedForward - - self.num_tokens = num_tokens - self.embed_dim = embed_dims - self.clip_embeds = None - self.shortcut = False - self.shortcut_scale = 1.0 - - self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) - self.norm = nn.LayerNorm(embed_dims) - - self.proj_in = nn.Linear(hidden_dims, embed_dims) - - self.proj_out = nn.Linear(embed_dims, output_dims) - self.norm_out = nn.LayerNorm(output_dims) - - self.layers = nn.ModuleList( - [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] - ) - - def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: - """Forward pass. - - Args: - id_embeds (torch.Tensor): Input Tensor (ID embeds). - Returns: - torch.Tensor: Output Tensor. 
- """ - id_embeds = id_embeds.to(self.clip_embeds.dtype) - id_embeds = self.proj(id_embeds) - id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) - id_embeds = self.norm(id_embeds) - latents = id_embeds - - clip_embeds = self.proj_in(self.clip_embeds) - x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) - - for block in self.layers: - residual = latents - latents = block(x, latents, residual) - - latents = self.proj_out(latents) - out = self.norm_out(latents) - if self.shortcut: - out = id_embeds + self.shortcut_scale * out - return out - - -class MultiIPAdapterImageProjection(nn.Module): - def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): - super().__init__() - self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) - - def forward(self, image_embeds: List[torch.Tensor]): - projected_image_embeds = [] - - # currently, we accept `image_embeds` as - # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] - # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] - if not isinstance(image_embeds, list): - deprecation_message = ( - "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." - ) - deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) - image_embeds = [image_embeds.unsqueeze(1)] - - if len(image_embeds) != len(self.image_projection_layers): - raise ValueError( - f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" - ) - - for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): - batch_size, num_images = image_embed.shape[0], image_embed.shape[1] - image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) - image_embed = image_projection_layer(image_embed) - image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) - - projected_image_embeds.append(image_embed) - - return projected_image_embeds -- Gitee From 5bfd7b94d69ad567f1d9f42c05615101c8951d35 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:09:40 +0800 Subject: [PATCH 06/32] add stable_audio --- .../schedulers/scheduling_cosine_dpmsolver_multistep.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py index 69d0458d12..000189436e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -21,8 +21,8 @@ import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler -from diffusers.scheduling_utils import SchedulerMixin, SchedulerOutput +from diffusers.schedulers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler +from 
diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


 class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
-- Gitee

From b1a26486e0f1969bb2aaeb04540e249c2982f856 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:10:44 +0800
Subject: [PATCH 07/32] add stable_audio

---
 .../stableaudio/vae/autoencoder_oobleck.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py
index 8650d40b32..1facb94849 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/vae/autoencoder_oobleck.py
@@ -20,11 +20,11 @@ import torch
 import torch.nn as nn
 from torch.nn.utils import weight_norm

-from diffsuers.configuration_utils import ConfigMixin, register_to_config
-from diffsuers.utils import BaseOutput
-from diffsuers.utils.accelerate_utils import apply_forward_hook
-from diffsuers.utils.torch_utils import randn_tensor
-from diffsuers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.modeling_utils import ModelMixin


 class Snake1d(nn.Module):
-- Gitee

From 17d5889bb7d14e9a6f8d13bb0f96cecaf16ff4f7 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:13:10 +0800
Subject: [PATCH 08/32] add stable_audio

---
 .../built-in/foundation/stable_audio/stableaudio/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py
index bb0fa4ea0b..0267fdcab4 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/__init__.py
@@ -3,4 +3,4 @@ from .pipeline import StableAudioPipeline
 from .models import StableAudioDiTModel
 from .models import StableAudioProjectionModel
 from .vae.autoencoder_oobleck import AutoencoderOobleck
-from .scheduler.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
\ No newline at end of file
+from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
\ No newline at end of file
-- Gitee

From 12231e02c937ebbfa1ffc0310290c00137f20594 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:14:17 +0800
Subject: [PATCH 09/32] add stable_audio

---
 .../stable_audio/inference_stableaudio.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py
index fc4f242512..e645c8aa48 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/inference_stableaudio.py
@@ -71,18 +71,18 @@ def main():
     torch_npu.npu.set_device(args.device)
     torch.manual_seed(1)
     latents = torch.randn(1, 64, 1024, dtype=torch.float16,device="cpu")
-    with open(args.stable_audio_open_dir + "/vae/config.json", "r", encoding="utf-8") as reader:
+    with open(args.model + "/vae/config.json", "r", encoding="utf-8") as reader:
         data = reader.read()
         json_data = json.loads(data)
     init_dict = {key: json_data[key] for key in json_data}
     vae = AutoencoderOobleck(**init_dict)
-    vae.load_state_dict(load_file(args.stable_audio_open_dir + "/vae/diffusion_pytorch_model.safetensors"), strict=False)
+    vae.load_state_dict(load_file(args.model + "/vae/diffusion_pytorch_model.safetensors"), strict=False)

-    tokenizer = T5TokenizerFast.from_pretrained(args.stable_audio_open_dir + "/tokenizer")
-    text_encoder = T5EncoderModel.from_pretrained(args.stable_audio_open_dir + "/text_encoder")
-    projection_model = StableAudioProjectionModel.from_pretrained(args.stable_audio_open_dir + "/projection_model")
-    audio_dit = StableAudioDiTModel.from_pretrained(args.stable_audio_open_dir + "/transformer")
-    scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.stable_audio_open_dir + "/scheduler")
+    tokenizer = T5TokenizerFast.from_pretrained(args.model + "/tokenizer")
+    text_encoder = T5EncoderModel.from_pretrained(args.model + "/text_encoder")
+    projection_model = StableAudioProjectionModel.from_pretrained(args.model + "/projection_model")
+    audio_dit = StableAudioDiTModel.from_pretrained(args.model + "/transformer")
+    scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.model + "/scheduler")

     npu_stream = torch_npu.npu.Stream()
     vae = vae.to("npu").to(torch.float16).eval()
-- Gitee

From 0e27af2c1a66c063a5271bb10d4977fc43ce64fd Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:27:54 +0800
Subject: [PATCH 10/32] add stable_audio

---
 .../stableaudio/layers/attention_processor.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
index 55264cbfd2..a373a783db 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
@@ -16,6 +16,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union

 import torch
+import torch_npu
 import torch.nn.functional as F
 from torch import nn

@@ -876,11 +877,21 @@ class StableAudioAttnProcessor2_0:
             query = query.to(query_dtype)
             key = key.to(key_dtype)

-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+        if query.device is not torch.device("cpu"):
+            if attention_mask is not None:
+                attention_mask=~attention_mask
+                attention_mask=attention_mask.to(torch.bool)
+            hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value,
+                atten_mask=attention_mask,
+                input_layout='BNSD',
+                scale_vaule=head_dim**-0.5,
+                pre_tokens=65535,
+                next_tokens=65535,
+                num_heads=attn.heads)
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
-- Gitee

From f1b831b8a0a617df98bd3696415b43c46a463e38 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:30:16 +0800
Subject: [PATCH 11/32] add stable_audio

---
 .../stable_audio/stableaudio/layers/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
index a373a783db..f573dba08f 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
@@ -884,7 +884,7 @@ class StableAudioAttnProcessor2_0:
             hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value,
                 atten_mask=attention_mask,
                 input_layout='BNSD',
-                scale_vaule=head_dim**-0.5,
+                scale_value=head_dim**-0.5,
                 pre_tokens=65535,
                 next_tokens=65535,
                 num_heads=attn.heads)
-- Gitee

From 5d34e2bcc4f2f10838ecc7bb9bfe564827ed3a86 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:32:27 +0800
Subject: [PATCH 12/32] add stable_audio

---
 .../stableaudio/layers/attention_processor.py | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
index f573dba08f..971c49bf0f 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py
@@ -855,27 +855,27 @@ class StableAudioAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # Apply RoPE if needed
-        if rotary_emb is not None:
-            query_dtype = query.dtype
-            key_dtype = key.dtype
-            query = query.to(torch.float32)
-            key = key.to(torch.float32)
+        # # Apply RoPE if needed
+        # if rotary_emb is not None:
+        #     query_dtype = query.dtype
+        #     key_dtype = key.dtype
+        #     query = query.to(torch.float32)
+        #     key = key.to(torch.float32)

-            rot_dim = rotary_emb[0].shape[-1]
-            query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
-            query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+        #     rot_dim = rotary_emb[0].shape[-1]
+        #     query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
+        #     query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)

-            query = torch.cat((query_rotated, query_unrotated), dim=-1)
+        #     query = torch.cat((query_rotated, query_unrotated), dim=-1)

-            if not attn.is_cross_attention:
-                key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
-                key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+        #     if not attn.is_cross_attention:
+        #         key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
+        #         key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)

-                key = torch.cat((key_rotated, key_unrotated), dim=-1)
+        #         key = torch.cat((key_rotated, key_unrotated), dim=-1)

-            query = query.to(query_dtype)
-            key = key.to(key_dtype)
+        #     query = query.to(query_dtype)
+        #     key = key.to(key_dtype)

         if query.device is not torch.device("cpu"):
             if attention_mask is not None:
-- Gitee

From e805f1dc5a81f0499dceedfd32d50ee8adb3ae26 Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Tue, 24 Dec 2024 21:36:49 +0800
Subject:
[PATCH 13/32] add stable_audio --- .../stableaudio/models/stable_audio_transformer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py index 97151fa9f9..4ab52a3838 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/models/stable_audio_transformer.py @@ -67,7 +67,6 @@ class StableAudioGaussianFourierProjection(nn.Module): return out - class StableAudioDiTBlock(nn.Module): r""" Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip @@ -230,7 +229,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): self.sample_size = sample_size self.out_channels = out_channels self.inner_dim = num_attention_heads * attention_head_dim - + self. init_dtype = self.dtype self.time_proj = StableAudioGaussianFourierProjection( embedding_size=time_proj_dim // 2, flip_sin_to_cos=True, @@ -396,7 +395,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): """ cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) global_hidden_states = self.global_proj(global_hidden_states) - time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype))) global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) -- Gitee From c8ebca9d78fbf215af1e01281f46c826ef0f5df0 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:41:29 +0800 Subject: [PATCH 14/32] add stable_audio --- .../layers/attention_processor copy.py | 915 ++++++++++++++++++ .../stableaudio/layers/attention_processor.py | 47 +- 2 files changed, 938 insertions(+), 24 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py new file mode 100644 index 0000000000..f573dba08f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py @@ -0,0 +1,915 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch_npu +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Attention(nn.Module): + r""" + A cross attention layer. 
+ + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + # To prevent circular import. + from .normalization import FP32LayerNorm, RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + elif qk_norm == "fp32_layer_norm": + self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "layer_norm_across_heads": + # Lumina applys qk norm across all heads + self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) + self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + r""" + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ) -> None: + r""" + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. 
+ """ + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + ), + ) + + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and is_custom_diffusion: + raise NotImplementedError( + f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + + if is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + else: + if is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.load_state_dict(self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.device) + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) + + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + r""" + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. 
+ """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + + return tensor + + def get_attention_scores( + self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. 
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). 
In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @torch.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if not self.is_cross_attention: + # fetch weight matrices. + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + # create a new single projection layer and copy over the weights. + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + + # handle added projections for SD3 and others. + if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = fuse + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def apply_partial_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor], + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + rot_dim = freqs_cis[0].shape[-1] + x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] + + x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) + + out = torch.cat((x_rotated, x_unrotated), dim=-1) + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
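+            # Illustrative numbers, not taken from any shipped config: with attn.heads = 8 and
+            # kv_heads = 2, heads_per_kv_head = 4, so each of the 2 key/value heads below is
+            # tiled 4 times along the head axis and every query head gets a matching key/value head.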
+ heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + if query.device is not torch.device("cpu"): + if attention_mask is not None: + attention_mask=~attention_mask + attention_mask=attention_mask.to(torch.bool) + hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, + atten_mask=attention_mask, + input_layout='BNSD', + scale_value=head_dim**-0.5, + pre_tokens=65535, + next_tokens=65535, + num_heads=attn.heads) + else: + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +AttentionProcessor =AttnProcessor \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index 971c49bf0f..e4e2104bd9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -839,43 +839,42 @@ class StableAudioAttnProcessor2_0: head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=2) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - # # Apply RoPE if needed - # if rotary_emb is not None: - # query_dtype = query.dtype - # key_dtype = key.dtype - # query = query.to(torch.float32) - # key = key.to(torch.float32) + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) - # rot_dim = rotary_emb[0].shape[-1] - # query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] - # query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - # query = torch.cat((query_rotated, query_unrotated), dim=-1) + query = torch.cat((query_rotated, query_unrotated), dim=-1) - # if not attn.is_cross_attention: - # key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] - # key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - # key = torch.cat((key_rotated, key_unrotated), dim=-1) + key = torch.cat((key_rotated, key_unrotated), dim=-1) - # query = query.to(query_dtype) - # key = key.to(key_dtype) + query = query.to(query_dtype) + key = key.to(key_dtype) if query.device is not torch.device("cpu"): if attention_mask is not None: @@ -883,7 +882,7 @@ class StableAudioAttnProcessor2_0: attention_mask=attention_mask.to(torch.bool) hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, atten_mask=attention_mask, - input_layout='BNSD', + input_layout='BSND', scale_value=head_dim**-0.5, pre_tokens=65535, next_tokens=65535, @@ -893,7 +892,7 @@ class StableAudioAttnProcessor2_0: query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj -- Gitee From 797556b7eb8f576b6ac2be33b84680a672fe40eb Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Tue, 24 Dec 2024 21:44:02 +0800 Subject: [PATCH 15/32] add stable_audio --- .../foundation/stable_audio/stableaudio/layers/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py index a4d66b7d1d..19aacf6ba2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/embeddings.py @@ -318,8 +318,8 @@ def apply_rotary_emb( """ if use_real: cos, sin = freqs_cis # [S, D] - cos = cos[None, None] - sin = 
sin[None, None] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: -- Gitee From 59c069b0a6b9b2cb1518ab986ee894aad03eef04 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 10:02:26 +0800 Subject: [PATCH 16/32] add stable_audio --- .../built-in/foundation/stable_audio/requirements.txt | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/requirements.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/requirements.txt new file mode 100644 index 0000000000..6353806d3f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/requirements.txt @@ -0,0 +1,6 @@ +torch==2.1.0 +torchsde==0.2.6 +diffusers==0.30.0 +transformers==4.40.0 +soundfile==0.12.1 +torch_npu==2.1.0.post6 \ No newline at end of file -- Gitee From 4a28e11a23265f3959954b173a461b0416fd3c57 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 19:43:32 +0800 Subject: [PATCH 17/32] add stable_audio --- .../layers/attention_processor copy.py | 1 - .../scheduling_cosine_dpmsolver_multistep.py | 2 +- .../schedulers/scheduling_dpmsolver_sde.py | 70 ++ .../schedulers/scheduling_utils.py | 786 ++++++++++++++++++ 4 files changed, 857 insertions(+), 2 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_dpmsolver_sde.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py index f573dba08f..4cdb8bd044 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py @@ -840,7 +840,6 @@ class StableAudioAttnProcessor2_0: kv_heads = key.shape[-1] // head_dim query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py index 000189436e..7a33eae2b9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -21,8 +21,8 @@ import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_dpmsolver_sde.py 
b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_dpmsolver_sde.py new file mode 100644 index 0000000000..2193d4175d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,70 @@ +# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from .scheduling_utils import BrownianTree + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py new file mode 100644 index 0000000000..932be322c7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py @@ -0,0 +1,786 @@ + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import trampoline +import warnings + +import numpy as np +import torch + +from torchsde._brownian import brownian_base +from torchsde.settings import LEVY_AREA_APPROXIMATIONS +from torchsde.types import Scalar, Optional, Tuple, Union, Tensor + +_rsqrt3 = 1 / math.sqrt(3) +_r12 = 1 / 12 + + +def _randn(size, dtype, device, seed): + generator = torch.Generator(device).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=device, generator=generator) + + +def _is_scalar(x): + return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1) + + +def _assert_floating_tensor(name, tensor): + if not torch.is_tensor(tensor): + raise ValueError(f"{name}={tensor} should be a Tensor.") + if not tensor.is_floating_point(): + raise ValueError(f"{name}={tensor} should be floating point.") + + +def _check_tensor_info(*tensors, size, dtype, device): + """Check if sizes, dtypes, and devices of input tensors all match prescribed values.""" + tensors = list(filter(torch.is_tensor, tensors)) + + if dtype is None and len(tensors) == 0: + dtype = torch.get_default_dtype() + if device is None and len(tensors) == 0: + device = torch.device("cpu") + + sizes = [] if size is None else [size] + sizes += [t.shape for t in tensors] + + dtypes = [] if dtype is None else [dtype] + dtypes += [t.dtype for t in tensors] + + devices = [] if device is None else [device] + devices += [t.device for t in tensors] + + if len(sizes) == 0: + raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.") + + if not all(i == sizes[0] for i in sizes): + raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") + if not all(i == dtypes[0] for i in dtypes): + raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") + if not all(i == devices[0] for i in devices): + raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.") + + # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes. + return tuple(sizes[0]), dtypes[0], devices[0] + + +def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise): + if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time): + return None + elif W.ndimension() in (0, 1): + # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the + # Brownian motion is one dimensional and the Levy area is zero. 
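+        # The Levy area A_ij = 0.5 * (int W_i dW_j - int W_j dW_i) is antisymmetric in (i, j),
+        # so with zero or one Brownian dimensions there is no cross term left to approximate and
+        # the area is identically zero, hence the zeros returned here.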
+ return torch.zeros_like(W) + else: + # Davie's approximation to the Levy area from space-time Levy area + A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2) + noise = get_noise() + noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2 + if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster: + # Foster's additional correction to Davie's approximation + tenth_h = 0.1 * h + H_squared = H ** 2 + std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt() + else: # davie approximation + std = math.sqrt(_r12 * h ** 2) + a_tilde = std * noise + A += a_tilde + return A + + +def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor: + return h * (.5 * W + H) + + +class _EmptyDict: + def __setitem__(self, key, value): + pass + + def __getitem__(self, item): + raise KeyError + + +class _LRUDict(dict): + def __init__(self, max_size): + super().__init__() + self._max_size = max_size + self._keys = [] + + def __setitem__(self, key, value): + if key in self: + self._keys.remove(key) + elif len(self) >= self._max_size: + del self[self._keys.pop(0)] + super().__setitem__(key, value) + self._keys.append(key) + + +class _Interval: + # Intervals correspond to some subinterval of the overall interval [t0, t1]. + # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left + # and right subintervals, which partition the parent interval. + + __slots__ = ( + # These are the things that every interval has + '_start', + '_end', + '_parent', + '_is_left', + '_top', + # These are the things that intervals which are parents also have + '_midway', + '_spawn_key', + '_depth', + '_W_seed', + '_H_seed', + '_left_a_seed', + '_right_a_seed', + '_left_child', + '_right_child') + + def __init__(self, start, end, parent, is_left, top): + self._start = top._round(start) # the left hand edge of the interval + self._end = top._round(end) # the right hand edge of the interval + self._parent = parent # our parent interval + self._is_left = is_left # are we the left or right child of our parent + self._top = top # the top-level BrownianInterval, where we cache certain state + self._midway = None # The point at which we split between left and right subintervals + + ######################################## + # Calculate increments and levy area # + ######################################## + # + # This is a little bit convoluted, so here's an explanation. + # + # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area, + # applies the space-time to full Levy area correction, and then returns. + # + # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and + # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely + # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as + # well. + # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and + # space-time Levy area, effectively holding them permanently in the cache.) + # + # If the request isn't found in the LRU cache then it computes it from its parent. 
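# Illustrative note (not part of the original torchsde code): the "space-time
# Levy area" H used throughout this file is the normalised area between the
# Brownian path and the chord joining its endpoints,
#     U_{s,t} = \int_s^t (W_u - W_s) du,   H_{s,t} = U_{s,t} / h - W_{s,t} / 2,
# with h = t - s, which is why _H_to_U above returns h * (0.5 * W + H).  A
# discrete check on a simulated path (the helper name is hypothetical):
import torch

def discrete_w_and_h(increments: torch.Tensor, h: float):
    """Approximate W_{s,t} and H_{s,t} from equally spaced Brownian increments."""
    dt = h / increments.shape[0]
    path = torch.cat([torch.zeros(1, dtype=increments.dtype), increments.cumsum(0)])  # W_u - W_s on the grid
    w = path[-1]
    u = torch.trapz(path, dx=dt)                    # \int_s^t (W_u - W_s) du
    return w, u / h - 0.5 * w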
+ # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a + # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or + # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the + # left or right of our parent and does most of the computation using the parent's attributes. + + def _increment_and_levy_area(self): + W, H = trampoline.trampoline(self._increment_and_space_time_levy_area()) + A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation, + self._randn_levy) + return W, H, A + + def _increment_and_space_time_levy_area(self): + try: + return self._top._increment_and_space_time_levy_area_cache[self] + except KeyError: + parent = self._parent + + W, H = yield parent._increment_and_space_time_levy_area() + h_reciprocal = 1 / (parent._end - parent._start) + left_diff = parent._midway - parent._start + right_diff = parent._end - parent._midway + + if self._top._have_H: + left_diff_squared = left_diff ** 2 + right_diff_squared = right_diff ** 2 + left_diff_cubed = left_diff * left_diff_squared + right_diff_cubed = right_diff * right_diff_squared + + v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed)) + + a = v * left_diff_squared * h_reciprocal + b = v * right_diff_squared * h_reciprocal + c = v * _rsqrt3 + + X1 = parent._randn(parent._W_seed) + X2 = parent._randn(parent._H_seed) + + third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal + + if self._is_left: + first_coeff = left_diff * h_reciprocal + second_coeff = 6 * first_coeff * right_diff * h_reciprocal + out_W = first_coeff * W + second_coeff * H + third_coeff * X1 + out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2 + else: + first_coeff = right_diff * h_reciprocal + second_coeff = 6 * first_coeff * left_diff * h_reciprocal + out_W = first_coeff * W - second_coeff * H - third_coeff * X1 + out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2 + else: + # Don't compute space-time Levy area unless we need to + + mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise + + if self._is_left: + out_W = left_W + else: + out_W = W - left_W + out_H = None + + self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H) + return out_W, out_H + + def _randn(self, seed): + # We generate random noise deterministically wrt some seed; this seed is determined by the generator. + # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still + # have the generator. + size = self._top._size + return _randn(size, self._top._dtype, self._top._device, seed) + + def _a_seed(self): + return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed + + def _randn_levy(self): + size = (*self._top._size, *self._top._size[-1:]) + return _randn(size, self._top._dtype, self._top._device, self._a_seed()) + + ######################################## + # Locate an interval in the hierarchy # + ######################################## + # + # The other important piece of this construction is a way to locate any given interval within the binary tree + # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is + # the bit to target.) 
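# Illustrative sketch (not part of the original torchsde code): the
# parent-to-child step above, in the branch that skips space-time Levy area,
# is ordinary Brownian-bridge conditioning.  Given the parent increment W over
# [s, t] and a split point u,
#     W_{s,u} | W_{s,t}=W  ~  Normal((u - s)/(t - s) * W, (u - s)(t - u)/(t - s)),
# and W_{u,t} = W - W_{s,u}.  A minimal standalone version (hypothetical name):
import torch

def bridge_split(total_w: torch.Tensor, s: float, u: float, t: float,
                 generator: torch.Generator = None):
    """Sample (W_{s,u}, W_{u,t}) conditional on the total increment W_{s,t}."""
    mean = (u - s) / (t - s) * total_w
    std = ((u - s) * (t - u) / (t - s)) ** 0.5
    left = mean + std * torch.randn(total_w.shape, generator=generator, dtype=total_w.dtype)
    return left, total_w - left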
+ # + # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if + # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then + # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals. + # + # It calls _loc, which operates recursively. See _loc for more details on how the search works. + + def _loc(self, ta, tb): + out = [] + ta = self._top._round(ta) + tb = self._top._round(tb) + trampoline.trampoline(self._loc_inner(ta, tb, out)) + return out + + def _loc_inner(self, ta, tb, out): + # Expect to have ta < tb + + # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for + # something outside of that then we pass the buck up to our parent, who is strictly larger. + if ta < self._start or tb > self._end: + raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out)) + + # If it's us that's being asked for, then we add ourselves on to out and return. + if ta == self._start and tb == self._end: + out.append(self) + return + + # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us. + # So next we check if it's up to us to figure out, or up to our children. + if self._midway is None: + # It's up to us. Create subintervals (_split) if appropriate. + if ta == self._start: + self._split(tb) + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + # implies ta > self._start + self._split(ta) + # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it + # will tell us so. But if tb < self._end then our right_child will need to make another split of its own. + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + + # If we're here then we have children: self._midway is not None + if tb <= self._midway: + # Strictly our left_child's problem + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + if ta >= self._midway: + # Strictly our right_child's problem + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child + # first (to append to out in the correct order), then call our right child. + # (Implies ta < self._midway < tb) + yield self._left_child._loc_inner(ta, self._midway, out) + raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out)) + + def _set_spawn_key_and_depth(self): + self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1) + self._depth = self._parent._depth + 1 + + def _split(self, midway): + if self._top._halfway_tree: + self._split_exact(0.5 * (self._end + self._start)) + # self._midway is now the rounded halfway point. + if midway > self._midway: + self._right_child._split(midway) + elif midway < self._midway: + self._left_child._split(midway) + else: + self._split_exact(midway) + + def _split_exact(self, midway): # Create two children + self._midway = self._top._round(midway) + # Use splittable PRNGs to generate noise. 
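# Illustrative sketch (not part of the original torchsde code): numpy's
# SeedSequence is "splittable" -- the same (entropy, spawn_key) always yields
# the same child seeds, so an interval that drops out of the LRU cache
# regenerates identical noise when it is recomputed later.
import numpy as np

seq_a = np.random.SeedSequence(entropy=1234, spawn_key=(3, 2), pool_size=8)
seq_b = np.random.SeedSequence(entropy=1234, spawn_key=(3, 2), pool_size=8)
assert (seq_a.generate_state(4) == seq_b.generate_state(4)).all()  # deterministic child seeds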
+ self._set_spawn_key_and_depth() + generator = np.random.SeedSequence(entropy=self._top._entropy, + spawn_key=(self._spawn_key, self._depth), + pool_size=self._top._pool_size) + self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4) + + self._left_child = _Interval(start=self._start, + end=midway, + parent=self, + is_left=True, + top=self._top) + self._right_child = _Interval(start=midway, + end=self._end, + parent=self, + is_left=False, + top=self._top) + + +class BrownianInterval(brownian_base.BaseBrownian, _Interval): + """Brownian interval with fixed entropy. + + Computes increments (and optionally Levy area). + + To use: + >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda') + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + __slots__ = ( + # Inputs + '_size', + '_dtype', + '_device', + '_entropy', + '_levy_area_approximation', + '_dt', + '_tol', + '_pool_size', + '_cache_size', + '_halfway_tree', + # Quantisation + '_round', + # Caching, searching and computing values + '_increment_and_space_time_levy_area_cache', + '_last_interval', + '_have_H', + '_have_A', + '_w_h', + '_top_a_seed', + # Dependency tree creation + '_average_dt', + '_tree_dt', + '_num_evaluations' + ) + + def __init__(self, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + entropy: Optional[int] = None, + dt: Optional[Scalar] = None, + tol: Scalar = 0., + pool_size: int = 8, + cache_size: Optional[int] = 45, + halfway_tree: bool = False, + levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, + W: Optional[Tensor] = None, + H: Optional[Tensor] = None): + """Initialize the Brownian interval. + + Args: + t0 (float or Tensor): Initial time. + t1 (float or Tensor): Terminal time. + size (tuple of int): The shape of each Brownian sample. + If zero dimensional represents a scalar Brownian motion. + If one dimensional represents a batch of scalar Brownian motions. + If >two dimensional the last dimension represents the size of a + a multidimensional Brownian motion, and all previous dimensions + represent batch dimensions. + dtype (torch.dtype): The dtype of each Brownian sample. + Defaults to the PyTorch default. + device (str or torch.device): The device of each Brownian sample. + Defaults to the CPU. + entropy (int): Global seed, defaults to `None` for random entropy. + levy_area_approximation (str): Whether to also approximate Levy + area. Defaults to 'none'. Valid options are 'none', + 'space-time', 'davie' or 'foster', corresponding to different + approximation types. + This is needed for some higher-order SDE solvers. + dt (float or Tensor): The expected average step size of the SDE + solver. Set it if you know it (e.g. when using a fixed-step + solver); else it will be estimated from the first few queries. + This is used to set up the data structure such that it is + efficient to query at these intervals. + tol (float or Tensor): What tolerance to resolve the Brownian motion + to. Must be non-negative. Defaults to zero, i.e. floating point + resolution. Usually worth setting in conjunction with + `halfway_tree`, below. + pool_size (int): Size of the pooled entropy. If you care about + statistical randomness then increasing this will help (but will + slow things down). + cache_size (int): How big a cache of recent calculations to use. 
+ (As new calculations depend on old calculations, this speeds + things up dramatically, rather than recomputing things.) + Set this to `None` to use an infinite cache, which will be fast + but memory inefficient. + halfway_tree (bool): Whether the dependency tree (the internal data + structure) should be the dyadic tree. Defaults to `False`. + Normally, the sample path is determined by both `entropy`, + _and_ the locations and order of the query points. Setting this + to `True` will make it deterministic with respect to just + `entropy`; however this is much slower. + W (Tensor): The increment of the Brownian motion over the interval + [t0, t1]. Will be generated randomly if not provided. + H (Tensor): The space-time Levy area of the Brownian motion over the + interval [t0, t1]. Will be generated randomly if not provided. + """ + + ##################################### + # Check and normalise inputs # + ##################################### + + if not _is_scalar(t0): + raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.') + if not _is_scalar(t1): + raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.') + if dt is not None and not _is_scalar(dt): + raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.') + + if t0 > t1: + raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.') + t0 = float(t0) + t1 = float(t1) + if dt is not None: + dt = float(dt) + + if halfway_tree: + if tol <= 0.: + raise ValueError("`tol` should be positive.") + if dt is not None: + raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.") + else: + if tol < 0.: + raise ValueError("`tol` should be non-negative.") + + size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device) + + # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility. + if entropy is None: + entropy = np.random.randint(0, 2 ** 31 - 1) + + if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS: + raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got " + f"'{levy_area_approximation}'.") + + ##################################### + # Record inputs # + ##################################### + + self._size = size + self._dtype = dtype + self._device = device + self._entropy = entropy + self._levy_area_approximation = levy_area_approximation + self._dt = dt + self._tol = tol + self._pool_size = pool_size + self._cache_size = cache_size + self._halfway_tree = halfway_tree + + ##################################### + # A miscellany of other things # + ##################################### + + # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't + # recurse all the way up to the top every time we have a query. + if cache_size is None: + self._increment_and_space_time_levy_area_cache = {} + elif cache_size == 0: + self._increment_and_space_time_levy_area_cache = _EmptyDict() + else: + self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size) + + # We keep track of the most recently queried interval, and start searching for the next interval from that + # element of the binary tree. This is because subsequent queries are likely to be near the most recent query. + self._last_interval = self + + # Precompute these as we don't want to spend lots of time checking strings in hot loops. 
+ self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time, + LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + + # If we like we can quantise what level we want to compute the Brownian motion to. + if tol == 0.: + self._round = lambda x: x + else: + ndigits = -int(math.log10(tol)) + self._round = lambda x: round(x, ndigits) + + # Initalise as _Interval. + # (Must come after _round but before _w_h) + super(BrownianInterval, self).__init__(start=t0, + end=t1, + parent=None, + is_left=None, + top=self) + + # Set the global increment and space-time Levy area + generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size) + initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3) + if W is None: + W = self._randn(initial_W_seed) * math.sqrt(t1 - t0) + else: + _assert_floating_tensor('W', W) + if H is None: + H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12) + else: + _assert_floating_tensor('H', H) + self._w_h = (W, H) + self._top_a_seed = top_a_seed + + if not self._halfway_tree: + # We create a binary tree dependency between the points. If we don't do this then the forward pass is still + # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making + # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we + # can instead make both directions O(N log N). + self._average_dt = 0 + self._tree_dt = t1 - t0 + self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average + if dt is not None: + # Create the dependency tree based on the supplied hint `dt`. + self._create_dependency_tree(dt) + # If dt is None, then create the dependency tree based on observed statistics of query points. (In __call__) + + # Effectively permanently store our increment and space-time Levy area in the cache. 
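# Illustrative check (not part of the original torchsde code): over [t0, t1]
# with h = t1 - t0 the increment is W ~ N(0, h) and the space-time Levy area is
# H ~ N(0, h / 12), independent of W -- hence the sqrt(t1 - t0) and
# sqrt((t1 - t0) / 12) scalings used just above.  A quick Monte-Carlo sanity
# check (the helper name is hypothetical):
import torch

def empirical_h_variance(h: float = 1.0, steps: int = 1000, paths: int = 20000) -> float:
    dt = h / steps
    dw = torch.randn(paths, steps) * dt ** 0.5
    path = torch.cat([torch.zeros(paths, 1), dw.cumsum(dim=1)], dim=1)
    u = torch.trapz(path, dx=dt, dim=1)             # \int_0^h W_s ds for each path
    h_areas = u / h - 0.5 * path[:, -1]
    return h_areas.var().item()                     # approximately h / 12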
+ def _increment_and_space_time_levy_area(self): + return self._w_h + yield # make it a generator + + def _a_seed(self): + return self._top_a_seed + + def _set_spawn_key_and_depth(self): + self._spawn_key = 0 + self._depth = 0 + + def __call__(self, ta, tb=None, return_U=False, return_A=False): + if tb is None: + warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.") + ta, tb = self._start, ta + tb_name = 'ta' + else: + tb_name = 'tb' + ta = float(ta) + tb = float(tb) + if ta < self._start: + warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.") + ta = self._start + if tb < self._start: + warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.") + tb = self._start + if ta > self._end: + warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.") + ta = self._end + if tb > self._end: + warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.") + tb = self._end + if ta > tb: + raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.") + + if ta == tb: + W = torch.zeros(self._size, dtype=self._dtype, device=self._device) + H = None + A = None + if self._have_H: + H = torch.zeros(self._size, dtype=self._dtype, device=self._device) + if self._have_A: + size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist + A = torch.zeros(size, dtype=self._dtype, device=self._device) + else: + if self._dt is None and not self._halfway_tree: + self._num_evaluations += 1 + # We start off with "negative" num evaluations, to give us a small warm-up period at the start. + if self._num_evaluations > 0: + # Compute average step size so far + dt = tb - ta + self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations + if self._average_dt < 0.5 * self._tree_dt: + # If 'dt' wasn't specified, then check the average interval length against the size of the + # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all + # the bottom pieces into two. + self._create_dependency_tree(dt) + + # Find the intervals that correspond to the query. We start our search at the last interval we accessed in + # the binary tree, as it's likely that the next query will come nearby. + intervals = self._last_interval._loc(ta, tb) + # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to + # be 1 or 2 almost always so this isn't a huge deal. + self._last_interval = intervals[-1] + + W, H, A = intervals[0]._increment_and_levy_area() + if len(intervals) > 1: + # If we have multiple intervals then add up their increments and Levy areas. + + for interval in intervals[1:]: + Wi, Hi, Ai = interval._increment_and_levy_area() + if self._have_H: + # Aggregate H: + # Given s < u < t, then + # H_{s,t} = (term1 + term2) / (t - s) + # where + # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2) + # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2) + term1 = (interval._end - interval._start) * (Hi + 0.5 * W) + term2 = (interval._start - ta) * (H - 0.5 * Wi) + H = (term1 + term2) / (interval._end - ta) + if self._have_A and len(self._size) not in (0, 1): + # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch + # dimension, so we have zero Levy area. (And these unsqueezes will result in a tensor of shape + # (batch, batch) which is wrong.) + + # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u. 
+ # Then + # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u + # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u + # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t} + # + # A is now the antisymmetric part of B, which gives the formula below. + A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2)) + W = W + Wi + + U = None + if self._have_H: + U = _H_to_U(W, H, tb - ta) + + if return_U: + if return_A: + return W, U, A + else: + return W, U + else: + if return_A: + return W, A + else: + return W + + def _create_dependency_tree(self, dt): + # For safety we take a min with 100: if people take very large cache sizes then this would then break the + # logarithmic into linear, which causes RecursionErrors. + if self._cache_size is None: # cache_size=None corresponds to infinite cache. + cache_size = 100 + else: + cache_size = min(self._cache_size, 100) + + self._tree_dt = min(self._tree_dt, dt) + # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt` + # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`. + # For safety we then make this a bit smaller by multiplying by 0.8. + piece_length = self._tree_dt * cache_size * 0.8 + + def _set_points(interval): + start = interval._start + end = interval._end + if end - start > piece_length: + midway = (end + start) / 2 + interval._loc(start, midway) + _set_points(interval._left_child) + _set_points(interval._right_child) + + _set_points(self) + + def __repr__(self): + if self._dt is None: + dt = None + else: + dt = f"{self._dt:.3f}" + return (f"{self.__class__.__name__}(" + f"t0={self._start:.3f}, " + f"t1={self._end:.3f}, " + f"size={self._size}, " + f"dtype={self._dtype}, " + f"device={repr(self._device)}, " + f"entropy={self._entropy}, " + f"dt={dt}, " + f"tol={self._tol}, " + f"pool_size={self._pool_size}, " + f"cache_size={self._cache_size}, " + f"levy_area_approximation={repr(self._levy_area_approximation)}" + f")") + + def display_binary_tree(self): + stack = [(self, 0)] + out = [] + while len(stack): + elem, depth = stack.pop() + out.append(" " * depth + f"({elem._start}, {elem._end})") + if elem._midway is not None: + stack.append((elem._right_child, depth + 1)) + stack.append((elem._left_child, depth + 1)) + print("\n".join(out)) + + @property + def shape(self): + return self._size + + @property + def dtype(self): + return self._dtype + + @property + def device(self): + return self._device + + @property + def entropy(self): + return self._entropy + + @property + def levy_area_approximation(self): + return self._levy_area_approximation + + @property + def dt(self): + return self._dt + + @property + def tol(self): + return self._tol + + @property + def pool_size(self): + return self._pool_size + + @property + def cache_size(self): + return self._cache_size + + @property + def halfway_tree(self): + return self._halfway_tree + + def size(self): + return self._size -- Gitee From c0545c0e0b7754301fe97c630176f208752bf2df Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 20:01:31 +0800 Subject: [PATCH 18/32] add stable_audio --- .../schedulers/scheduling_utils.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py index 932be322c7..9bc31717cf 100644 --- 
a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/schedulers/scheduling_utils.py @@ -784,3 +784,105 @@ class BrownianInterval(brownian_base.BaseBrownian, _Interval): def size(self): return self._size + + +class BrownianTree(brownian_base.BaseBrownian): + """Brownian tree with fixed entropy. + + Useful when the map from entropy -> Brownian motion shouldn't depend on the + locations and order of the query points. (As the usual BrownianInterval + does - note that BrownianTree is slower as a result though.) + + To use: + >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1)) + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + def __init__(self, t0: Scalar, + w0: Tensor, + t1: Optional[Scalar] = None, + w1: Optional[Tensor] = None, + entropy: Optional[int] = None, + tol: float = 1e-6, + pool_size: int = 24, + cache_depth: int = 9, + safety: Optional[float] = None): + """Initialize the Brownian tree. + + The random value generation process exploits the parallel random number paradigm and uses + `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`). + + Arguments: + t0: Initial time. + w0: Initial state. + t1: Terminal time. + w1: Terminal state. + entropy: Global seed, defaults to `None` for random entropy. + tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol). + pool_size: Size of the pooled entropy. This parameter affects the query speed significantly. + cache_depth: Unused; deprecated. + safety: Unused; deprecated. + """ + + if t1 is None: + t1 = t0 + 1 + if w1 is None: + W = None + else: + W = w1 - w0 + self._w0 = w0 + self._interval = BrownianInterval(t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=entropy, + tol=tol, + pool_size=pool_size, + halfway_tree=True, + W=W) + super(BrownianTree, self).__init__() + + def __call__(self, t, tb=None, return_U=False, return_A=False): + # Deliberately called t rather than ta, for backward compatibility + out = self._interval(t, tb, return_U=return_U, return_A=return_A) + if tb is None and not return_U and not return_A: + out = out + self._w0 + return out + + def __repr__(self): + return f"{self.__class__.__name__}(interval={self._interval})" + + @property + def dtype(self): + return self._interval.dtype + + @property + def device(self): + return self._interval.device + + @property + def shape(self): + return self._interval.shape + + @property + def levy_area_approximation(self): + return self._interval.levy_area_approximation + + +def brownian_interval_like(y: Tensor, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + **kwargs): + """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor.""" + size = y.shape if size is None else size + dtype = y.dtype if dtype is None else dtype + device = y.device if device is None else device + return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs) \ No newline at end of file -- Gitee From 82b79ef6d00296be9f51e932a43594467abd5d10 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 20:07:02 +0800 Subject: [PATCH 19/32] add stable_audio --- .../stableaudio/layers/attention_processor.py | 14 +++++++------- 
...n_processor copy.py => attention_processor0.py} | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) rename MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/{attention_processor copy.py => attention_processor0.py} (99%) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index e4e2104bd9..4cdb8bd044 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -839,15 +839,15 @@ class StableAudioAttnProcessor2_0: head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, kv_heads, head_dim) - value = value.view(batch_size, -1, kv_heads, head_dim) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=2) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=2) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) if attn.norm_q is not None: query = attn.norm_q(query) @@ -882,7 +882,7 @@ class StableAudioAttnProcessor2_0: attention_mask=attention_mask.to(torch.bool) hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, atten_mask=attention_mask, - input_layout='BSND', + input_layout='BNSD', scale_value=head_dim**-0.5, pre_tokens=65535, next_tokens=65535, @@ -892,7 +892,7 @@ class StableAudioAttnProcessor2_0: query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py similarity index 99% rename from MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py rename to MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py index 4cdb8bd044..e4e2104bd9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor copy.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py @@ -839,15 +839,15 @@ class StableAudioAttnProcessor2_0: head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = 
value.view(batch_size, -1, kv_heads, head_dim) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=2) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=2) if attn.norm_q is not None: query = attn.norm_q(query) @@ -882,7 +882,7 @@ class StableAudioAttnProcessor2_0: attention_mask=attention_mask.to(torch.bool) hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, atten_mask=attention_mask, - input_layout='BNSD', + input_layout='BSND', scale_value=head_dim**-0.5, pre_tokens=65535, next_tokens=65535, @@ -892,7 +892,7 @@ class StableAudioAttnProcessor2_0: query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj -- Gitee From 2da1200e2129818c7b5817c3b0815d452b8de426 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 20:08:25 +0800 Subject: [PATCH 20/32] add stable_audio --- .../stableaudio/layers/attention_processor.py | 14 +- .../layers/attention_processor0.py | 914 ------------------ 2 files changed, 7 insertions(+), 921 deletions(-) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index 4cdb8bd044..e4e2104bd9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -839,15 +839,15 @@ class StableAudioAttnProcessor2_0: head_dim = query.shape[-1] // attn.heads kv_heads = key.shape[-1] // head_dim - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) if kv_heads != attn.heads: # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
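# Illustrative sketch (not part of the patch): these hunks toggle between two
# equivalent attention layouts -- BSND (batch, seq, heads, head_dim; no
# transpose; repeat KV heads on dim=2) and BNSD (after .transpose(1, 2); repeat
# on dim=1).  In either case the repeat_interleave axis must be the heads axis,
# and the input_layout string passed to the flash-attention kernel must match.
import torch

batch, seq, kv_heads, head_dim, heads_per_kv = 2, 16, 4, 64, 3
key_bsnd = torch.randn(batch, seq, kv_heads, head_dim)
key_bnsd = key_bsnd.transpose(1, 2)                 # (batch, kv_heads, seq, head_dim)

rep_bsnd = torch.repeat_interleave(key_bsnd, heads_per_kv, dim=2)
rep_bnsd = torch.repeat_interleave(key_bnsd, heads_per_kv, dim=1)
assert torch.equal(rep_bsnd.transpose(1, 2), rep_bnsd)  # same values, different layout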
heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + key = torch.repeat_interleave(key, heads_per_kv_head, dim=2) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=2) if attn.norm_q is not None: query = attn.norm_q(query) @@ -882,7 +882,7 @@ class StableAudioAttnProcessor2_0: attention_mask=attention_mask.to(torch.bool) hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, atten_mask=attention_mask, - input_layout='BNSD', + input_layout='BSND', scale_value=head_dim**-0.5, pre_tokens=65535, next_tokens=65535, @@ -892,7 +892,7 @@ class StableAudioAttnProcessor2_0: query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py deleted file mode 100644 index e4e2104bd9..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor0.py +++ /dev/null @@ -1,914 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -import math -from typing import Callable, List, Optional, Tuple, Union - -import torch -import torch_npu -import torch.nn.functional as F -from torch import nn - -from diffusers.utils import deprecate, logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class Attention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): - The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): - The number of heads to use for multi-head attention. - kv_heads (`int`, *optional*, defaults to `None`): - The number of key and value heads to use for multi-head attention. Defaults to `heads`. If - `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi - Query Attention (MQA) otherwise GQA is used. - dim_head (`int`, *optional*, defaults to 64): - The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): - The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - upcast_attention (`bool`, *optional*, defaults to False): - Set to `True` to upcast the attention computation to `float32`. 
- upcast_softmax (`bool`, *optional*, defaults to False): - Set to `True` to upcast the softmax computation to `float32`. - cross_attention_norm (`str`, *optional*, defaults to `None`): - The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. - cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups to use for the group norm in the cross attention. - added_kv_proj_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the added key and value projections. If `None`, no projection is used. - norm_num_groups (`int`, *optional*, defaults to `None`): - The number of groups to use for the group norm in the attention. - spatial_norm_dim (`int`, *optional*, defaults to `None`): - The number of channels to use for the spatial normalization. - out_bias (`bool`, *optional*, defaults to `True`): - Set to `True` to use a bias in the output linear layer. - scale_qk (`bool`, *optional*, defaults to `True`): - Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. - only_cross_attention (`bool`, *optional*, defaults to `False`): - Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if - `added_kv_proj_dim` is not `None`. - eps (`float`, *optional*, defaults to 1e-5): - An additional value added to the denominator in group normalization that is used for numerical stability. - rescale_output_factor (`float`, *optional*, defaults to 1.0): - A factor to rescale the output by dividing it with this value. - residual_connection (`bool`, *optional*, defaults to `False`): - Set to `True` to add the residual connection to the output. - _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): - Set to `True` if the attention block is loaded from a deprecated state dict. - processor (`AttnProcessor`, *optional*, defaults to `None`): - The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and - `AttnProcessor` otherwise. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: Optional[str] = None, - cross_attention_norm_num_groups: int = 32, - qk_norm: Optional[str] = None, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - processor: Optional["AttnProcessor"] = None, - out_dim: int = None, - context_pre_only=None, - pre_only=False, - ): - super().__init__() - - # To prevent circular import. 
- from .normalization import FP32LayerNorm, RMSNorm - - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if self.added_kv_proj_dim is None and self.only_cross_attention: - raise ValueError( - "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." - ) - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - if qk_norm is None: - self.norm_q = None - self.norm_k = None - elif qk_norm == "layer_norm": - self.norm_q = nn.LayerNorm(dim_head, eps=eps) - self.norm_k = nn.LayerNorm(dim_head, eps=eps) - elif qk_norm == "fp32_layer_norm": - self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - elif qk_norm == "layer_norm_across_heads": - # Lumina applys qk norm across all heads - self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) - self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) - elif qk_norm == "rms_norm": - self.norm_q = RMSNorm(dim_head, eps=eps) - self.norm_k = RMSNorm(dim_head, eps=eps) - else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") - - if cross_attention_norm is None: - self.norm_cross = None - elif cross_attention_norm == "layer_norm": - self.norm_cross = nn.LayerNorm(self.cross_attention_dim) - elif cross_attention_norm == "group_norm": - if self.added_kv_proj_dim is not None: - # The given `encoder_hidden_states` are initially of shape - # (batch_size, seq_len, added_kv_proj_dim) before being projected - # to (batch_size, seq_len, cross_attention_dim). The norm is applied - # before the projection, so we need to use `added_kv_proj_dim` as - # the number of channels for the group norm. 
- norm_cross_num_channels = added_kv_proj_dim - else: - norm_cross_num_channels = self.cross_attention_dim - - self.norm_cross = nn.GroupNorm( - num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True - ) - else: - raise ValueError( - f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" - ) - - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) - - if qk_norm is not None and added_kv_proj_dim is not None: - if qk_norm == "fp32_layer_norm": - self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) - elif qk_norm == "rms_norm": - self.norm_added_q = RMSNorm(dim_head, eps=eps) - self.norm_added_k = RMSNorm(dim_head, eps=eps) - else: - self.norm_added_q = None - self.norm_added_k = None - - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - if processor is None: - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) - self.set_processor(processor) - - def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: - r""" - Set whether to use npu flash attention from `torch_npu` or not. - - """ - if use_npu_flash_attention: - processor = AttnProcessorNPU() - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ) -> None: - r""" - Set whether to use memory efficient attention from `xformers` or not. - - Args: - use_memory_efficient_attention_xformers (`bool`): - Whether to use memory efficient attention from `xformers` or not. 
- attention_op (`Callable`, *optional*): - The attention operation to use. Defaults to `None` which uses the default attention operation from - `xformers`. - """ - is_custom_diffusion = hasattr(self, "processor") and isinstance( - self.processor, - (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), - ) - is_added_kv_processor = hasattr(self, "processor") and isinstance( - self.processor, - ( - AttnAddedKVProcessor, - AttnAddedKVProcessor2_0, - SlicedAttnAddedKVProcessor, - XFormersAttnAddedKVProcessor, - ), - ) - - if use_memory_efficient_attention_xformers: - if is_added_kv_processor and is_custom_diffusion: - raise NotImplementedError( - f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" - ) - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - - if is_custom_diffusion: - processor = CustomDiffusionXFormersAttnProcessor( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - elif is_added_kv_processor: - # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported - # throw warning - logger.info( - "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." - ) - processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) - else: - processor = XFormersAttnProcessor(attention_op=attention_op) - else: - if is_custom_diffusion: - attn_processor_class = ( - CustomDiffusionAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else CustomDiffusionAttnProcessor - ) - processor = attn_processor_class( - train_kv=self.processor.train_kv, - train_q_out=self.processor.train_q_out, - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - ) - processor.load_state_dict(self.processor.state_dict()) - if hasattr(self.processor, "to_k_custom_diffusion"): - processor.to(self.processor.to_k_custom_diffusion.weight.device) - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() - if hasattr(F, "scaled_dot_product_attention") and self.scale_qk - else AttnProcessor() - ) - - self.set_processor(processor) - - def set_attention_slice(self, slice_size: int) -> None: - r""" - Set the slice size for attention computation. - - Args: - slice_size (`int`): - The slice size for attention computation. - """ - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = AttnAddedKVProcessor() - else: - # set attention processor - # We use the AttnProcessor2_0 by default when torch 2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 - processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() - ) - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor") -> None: - r""" - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": - r""" - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, - ) -> torch.Tensor: - r""" - The forward method of the `Attention` class. - - Args: - hidden_states (`torch.Tensor`): - The hidden states of the query. - encoder_hidden_states (`torch.Tensor`, *optional*): - The hidden states of the encoder. - attention_mask (`torch.Tensor`, *optional*): - The attention mask to use. If `None`, no mask is applied. - **cross_attention_kwargs: - Additional keyword arguments to pass along to the cross attention. - - Returns: - `torch.Tensor`: The output of the attention layer. 
- """ - # The `Attention` class can call different attention processors / attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - quiet_attn_parameters = {"ip_adapter_masks"} - unused_kwargs = [ - k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters - ] - if len(unused_kwargs) > 0: - logger.warning( - f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} - - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` - is the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: - r""" - Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is - the number of heads initialized while constructing the `Attention` class. - - Args: - tensor (`torch.Tensor`): The tensor to reshape. - out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is - reshaped to `[batch_size * heads, seq_len, dim // heads]`. - - Returns: - `torch.Tensor`: The reshaped tensor. - """ - head_size = self.heads - if tensor.ndim == 3: - batch_size, seq_len, dim = tensor.shape - extra_dim = 1 - else: - batch_size, extra_dim, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - - if out_dim == 3: - tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) - - return tensor - - def get_attention_scores( - self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None - ) -> torch.Tensor: - r""" - Compute the attention scores. - - Args: - query (`torch.Tensor`): The query tensor. - key (`torch.Tensor`): The key tensor. - attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. - - Returns: - `torch.Tensor`: The attention probabilities/scores. 
- """ - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, - ) - del baddbmm_input - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - del attention_scores - - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask( - self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 - ) -> torch.Tensor: - r""" - Prepare the attention mask for the attention computation. - - Args: - attention_mask (`torch.Tensor`): - The attention mask to prepare. - target_length (`int`): - The target length of the attention mask. This is the length of the attention mask after padding. - batch_size (`int`): - The batch size, which is used to repeat the attention mask. - out_dim (`int`, *optional*, defaults to `3`): - The output dimension of the attention mask. Can be either `3` or `4`. - - Returns: - `torch.Tensor`: The prepared attention mask. - """ - head_size = self.heads - if attention_mask is None: - return attention_mask - - current_length: int = attention_mask.shape[-1] - if current_length != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. - padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: - # we want to instead pad by (0, remaining_length), where remaining_length is: - # remaining_length: int = target_length - current_length - # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if out_dim == 3: - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - elif out_dim == 4: - attention_mask = attention_mask.unsqueeze(1) - attention_mask = attention_mask.repeat_interleave(head_size, dim=1) - - return attention_mask - - def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: - r""" - Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the - `Attention` class. - - Args: - encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. - - Returns: - `torch.Tensor`: The normalized encoder hidden states. - """ - assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" - - if isinstance(self.norm_cross, nn.LayerNorm): - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - elif isinstance(self.norm_cross, nn.GroupNorm): - # Group norm norms along the channels dimension and expects - # input to be in the shape of (N, C, *). 
In this case, we want - # to norm along the hidden dimension, so we need to move - # (batch_size, sequence_length, hidden_size) -> - # (batch_size, hidden_size, sequence_length) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - encoder_hidden_states = self.norm_cross(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - else: - assert False - - return encoder_hidden_states - - @torch.no_grad() - def fuse_projections(self, fuse=True): - device = self.to_q.weight.data.device - dtype = self.to_q.weight.data.dtype - - if not self.is_cross_attention: - # fetch weight matrices. - concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - # create a new single projection layer and copy over the weights. - self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_qkv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) - self.to_qkv.bias.copy_(concatenated_bias) - - else: - concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_kv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) - self.to_kv.bias.copy_(concatenated_bias) - - # handle added projections for SD3 and others. - if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): - concatenated_weights = torch.cat( - [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] - ) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_added_qkv = nn.Linear( - in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype - ) - self.to_added_qkv.weight.copy_(concatenated_weights) - if self.added_proj_bias: - concatenated_bias = torch.cat( - [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] - ) - self.to_added_qkv.bias.copy_(concatenated_bias) - - self.fused_projections = fuse - - -class AttnProcessor: - r""" - Default processor for performing attention-related computations. - """ - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - if len(args) > 0 or kwargs.get("scale", None) is not None: - deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
- deprecate("scale", "1.0.0", deprecation_message) - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - -class StableAudioAttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
- ) - - def apply_partial_rotary_emb( - self, - x: torch.Tensor, - freqs_cis: Tuple[torch.Tensor], - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - rot_dim = freqs_cis[0].shape[-1] - x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] - - x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) - - out = torch.cat((x_rotated, x_unrotated), dim=-1) - return out - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - residual = hidden_states - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - head_dim = query.shape[-1] // attn.heads - kv_heads = key.shape[-1] // head_dim - - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, kv_heads, head_dim) - value = value.view(batch_size, -1, kv_heads, head_dim) - - if kv_heads != attn.heads: - # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
- heads_per_kv_head = attn.heads // kv_heads - key = torch.repeat_interleave(key, heads_per_kv_head, dim=2) - value = torch.repeat_interleave(value, heads_per_kv_head, dim=2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if rotary_emb is not None: - query_dtype = query.dtype - key_dtype = key.dtype - query = query.to(torch.float32) - key = key.to(torch.float32) - - rot_dim = rotary_emb[0].shape[-1] - query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] - query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - - query = torch.cat((query_rotated, query_unrotated), dim=-1) - - if not attn.is_cross_attention: - key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] - key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - - key = torch.cat((key_rotated, key_unrotated), dim=-1) - - query = query.to(query_dtype) - key = key.to(key_dtype) - - if query.device is not torch.device("cpu"): - if attention_mask is not None: - attention_mask=~attention_mask - attention_mask=attention_mask.to(torch.bool) - hidden_states=torch_npu.npu_prompt_flash_attention(query,key,value, - atten_mask=attention_mask, - input_layout='BSND', - scale_value=head_dim**-0.5, - pre_tokens=65535, - next_tokens=65535, - num_heads=attn.heads) - else: - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -AttentionProcessor =AttnProcessor \ No newline at end of file -- Gitee From e91cc73a84a45b13528a5a81edb99f398c944e85 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 20:33:04 +0800 Subject: [PATCH 21/32] add stable_audio --- .../foundation/stable_audio/readme0.md | 117 ++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md new file mode 100644 index 0000000000..43bd08b74a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md @@ -0,0 +1,117 @@ +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [Duo卡](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x 
./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 环境依赖安装 +```bash +pip3 install -r requirements.txt +apt-get update +apt-get install libsndfile1 +``` + + +### 1.4 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.5 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +## 三、Stable-Audio-Open-1.0 使用 + +### 3.1 权重及配置文件说明 +stable-audio-open-1.0权重链接: +```shell +https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main +``` + +### 3.2 单卡功能测试 +设置权重路径 +```shell +model_base = './stable-audio-open-1.0' +``` +执行命令: +```shell +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --device 0 +``` +参数说明: +- --model:模型权重路径。 +- --prompt_file:提示词文件。 +- --num_inference_steps: 语音生成迭代次数。 +- --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 +- --save_dir:生成语音的存放目录。 +- --device:推理设备ID。 + +执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + +### 3.2 模型推理性能 + +性能参考下列数据。 + +| 硬件形态 | 迭代次数 | 平均耗时| +| :------: |:----:|:----:| +| Atlas 800I A2 (32G) | 100 | 10.251 | \ No newline at end of file -- Gitee From fb96f1a924144c7600801957afb52344c5e79323 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Wed, 25 Dec 2024 20:35:32 +0800 Subject: [PATCH 22/32] add stable_audio --- .../foundation/stable_audio/README.md | 279 +++++++----------- .../foundation/stable_audio/readme0.md | 117 -------- 2 files changed, 105 insertions(+), 291 deletions(-) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md index 809973a563..5b7cb35643 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -1,34 +1,4 @@ -# stable-audio-open-1.0模型-diffusers方式推理指导 - -- [概述](#ZH-CN_TOPIC_0000001172161501) - - -- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) - -- [快速上手](#ZH-CN_TOPIC_0000001126281700) - - - [获取源码](#section4622531142816) - - [模型推理](#section741711594517) - -- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) - -# 概述 - - [此处获得](https://huggingface.co/stabilityai/stable-audio-open-1.0) - -- 参考实现: - ```bash - # StableAudioOpen1.0 - 
https://huggingface.co/stabilityai/stable-audio-open-1.0 - ``` - -- 设备支持: -Atlas 800I A2推理设备:支持的卡数为1 -Atlas 300I Duo推理卡:支持的卡数为1 - -# 推理环境准备 - -- 该模型需要以下插件与驱动 +## 一、准备运行环境 **表 1** 版本配套表 @@ -37,150 +7,111 @@ Atlas 300I Duo推理卡:支持的卡数为1 | Python | 3.10.2 | - | | torch | 2.1.0 | - | -该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 - -# 快速上手 -## 获取源码 -1. 安装依赖。 - ```bash - pip3 install -r requirements.txt - apt-get update - apt-get install libsndfile1 - ``` - -2. 安装mindie包 - - ```bash - # 安装mindie - source /usr/local/Ascend/ascend-toolkit/set_env.sh - chmod +x ./Ascend-mindie_xxx.run - ./Ascend-mindie_xxx.run --install - source /usr/local/Ascend/mindie/set_env.sh - ``` - -3. 代码修改 - -- 执行命令: - ```bash - python3 diffusers_aie_patch.py - python3 brownian_interval_patch.py - ``` - -4. MindieTorch配套Torch_NPU使用 - - MindieTorch采用dlopen的方式动态加载Torch_NPU,需要手动编译libtorch_npu_bridge.so,并将其放在libtorch_aie.so同一路径下,或者将其路径设置到LD_LIBRARY_PATH环境变量中,具体参考: - ```bash - https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html - ``` - -## 模型推理 - -1. 模型转换。 - - 1. 提前下载权重,放到代码同级目录下。 - - ```bash - # 需要使用 git-lfs (https://git-lfs.com) - git lfs install - - # 下载stable-audio-open-1.0权重 - git clone https://huggingface.co/stabilityai/stable-audio-open-1.0 - ``` - - 2. 导出pt模型并进行编译。 - - (1) 设置模型权重的路径 - ```bash - # stable-audio-open-1.0 (执行时下载权重) - model_base="stabilityai/stable-audio-open-1.0" - - # stable-audio-open-1.0 (使用上一步下载的权重) - model_base="./stable-audio-open-1.0" - ``` - - (2) 执行命令查看芯片名称($\{chip\_name\})。 - - ``` - npu-smi info - ``` - - (3) 执行export命令 - - ```bash - python3 export_ts.py --model ${model_base} --output_dir ./models --soc Ascend${chip_name} --device 0 - ``` - - 参数说明: - - --model:模型权重路径 - - --output_dir: 存放导出模型的路径 - - --soc:处理器型号。 - - --device:推理设备ID - - 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 - -2. 开始推理验证。 - - 1. 开启cpu高性能模式 - ```bash - echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor - sysctl -w vm.swappiness=0 - sysctl -w kernel.numa_balancing=0 - ``` - - 2. 安装绑核工具 - ```bash - apt-get update - apt-get install numactl - ``` - 查询卡的NUMA node - ```shell - lspci -vs bus-id - ``` - bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 - - 可通过lscpu获得NUMA node对应的CPU核数 - ```shell - NUMA node0: 0-23 - NUMA node1: 24-47 - NUMA node2: 48-71 - NUMA node3: 72-95 - ``` - 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 - - 3. 
执行推理脚本。 - ```bash - numactl -C 0-23 python3 stable_audio_open_aie_pipeline.py \ - --model ${model_base} \ - --output_dir ./models \ - --prompt_file ./prompts.txt \ - --num_inference_steps 100 \ - --audio_end_in_s 10 10 47 \ - --num_waveforms_per_prompt 1 \ - --guidance_scale 7 \ - --save_dir ./results \ - --device 0 - ``` - - 参数说明: - - --model:模型权重路径。 - - --output_dir:存放导出模型的目录。 - - --prompt_file:提示词文件。 - - --num_inference_steps: 语音生成迭代次数。 - - --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 - - --num_waveforms_per_prompt:一个提示词生成的语音数量。 - - --guidance_scale:音频生成质量与准确度系数。 - - --save_dir:生成语音的存放目录。 - - --device:推理设备ID。 - - 执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 - - - -# 模型推理性能&精度 -性能参考下列数据。 +### 1.1 获取CANN&MindIE安装包&环境准备 +- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [Duo卡](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + +### 1.3 环境依赖安装 +```bash +pip3 install -r requirements.txt +apt-get update +apt-get install libsndfile1 +``` + + +### 1.4 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.5 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell + git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +``` + +## 三、Stable-Audio-Open-1.0 使用 + +### 3.1 权重及配置文件说明 +stable-audio-open-1.0权重链接: +```shell +https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main +``` + +### 3.2 单卡功能测试 +设置权重路径 +```shell +model_base = './stable-audio-open-1.0' +``` +执行命令: +```shell +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --device 0 +``` +参数说明: +- --model:模型权重路径。 +- --prompt_file:提示词文件。 +- --num_inference_steps: 语音生成迭代次数。 +- --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 +- --save_dir:生成语音的存放目录。 +- --device:推理设备ID。 + +执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + +### 3.2 模型推理性能 -### Stable-Audio-Open-1.0 +性能参考下列数据。 | 硬件形态 | 迭代次数 | 平均耗时| | :------: |:----:|:----:| -| 
Atlas 800I A2 (32G) | 100 | 5.895s | \ No newline at end of file +| Atlas 800I A2 (32G) | 100 | 10.201s | \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md deleted file mode 100644 index 43bd08b74a..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/readme0.md +++ /dev/null @@ -1,117 +0,0 @@ -## 一、准备运行环境 - - **表 1** 版本配套表 - - | 配套 | 版本 | 环境准备指导 | - | ----- | ----- |-----| - | Python | 3.10.2 | - | - | torch | 2.1.0 | - | - -### 1.1 获取CANN&MindIE安装包&环境准备 -- [800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) -- [Duo卡](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17) -- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) - -### 1.2 CANN安装 -```shell -# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 -chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run -chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run -# 校验软件包安装文件的一致性和完整性 -./Ascend-cann-toolkit_{version}_linux-{arch}.run --check -./Ascend-cann-kernels-{soc}_{version}_linux.run --check -# 安装 -./Ascend-cann-toolkit_{version}_linux-{arch}.run --install -./Ascend-cann-kernels-{soc}_{version}_linux.run --install - -# 设置环境变量 -source /usr/local/Ascend/ascend-toolkit/set_env.sh -``` - -### 1.3 环境依赖安装 -```bash -pip3 install -r requirements.txt -apt-get update -apt-get install libsndfile1 -``` - - -### 1.4 MindIE安装 -```shell -# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 -chmod +x ./Ascend-mindie_${version}_linux-${arch}.run -./Ascend-mindie_${version}_linux-${arch}.run --check - -# 方式一:默认路径安装 -./Ascend-mindie_${version}_linux-${arch}.run --install -# 设置环境变量 -cd /usr/local/Ascend/mindie && source set_env.sh - -# 方式二:指定路径安装 -./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} -# 设置环境变量 -cd ${AieInstallPath}/mindie && source set_env.sh -``` - -### 1.5 Torch_npu安装 -安装pytorch框架 版本2.1.0 -[安装包下载](https://download.pytorch.org/whl/cpu/torch/) - -使用pip安装 -```shell -# {version}表示软件版本号,{arch}表示CPU架构。 -pip install torch-${version}-cp310-cp310-linux_${arch}.whl -``` -下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz -```shell -tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz -# 解压后,会有whl包 -pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl -``` -## 二、下载本仓库 - -### 2.1 下载到本地 -```shell - git clone https://gitee.com/ascend/ModelZoo-PyTorch.git -``` - -## 三、Stable-Audio-Open-1.0 使用 - -### 3.1 权重及配置文件说明 -stable-audio-open-1.0权重链接: -```shell -https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main -``` - -### 3.2 单卡功能测试 -设置权重路径 -```shell -model_base = './stable-audio-open-1.0' -``` -执行命令: -```shell -python3 inference_stableaudio.py \ - --model ${model_base} \ - --prompt_file ./prompts/prompts.txt \ - --num_inference_steps 100 \ - --audio_end_in_s 10 10 47 \ - --save_dir ./results \ - --device 0 -``` -参数说明: -- --model:模型权重路径。 -- --prompt_file:提示词文件。 -- --num_inference_steps: 语音生成迭代次数。 -- --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 -- --save_dir:生成语音的存放目录。 -- --device:推理设备ID。 - -执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 - -### 3.2 模型推理性能 - -性能参考下列数据。 - -| 硬件形态 | 迭代次数 | 平均耗时| -| :------: |:----:|:----:| -| Atlas 800I A2 (32G) | 100 | 10.251 | \ No newline at end of file -- Gitee From 486805ad3ee606c5850a920a94b5313b3696e338 Mon Sep 17 00:00:00 2001 From: zhoufan2956 
Date: Wed, 25 Dec 2024 20:38:13 +0800 Subject: [PATCH 23/32] add stable_audio --- MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md index 5b7cb35643..6f1dadd075 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -72,7 +72,7 @@ pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl ### 2.1 下载到本地 ```shell - git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +git clone https://gitee.com/ascend/ModelZoo-PyTorch.git ``` ## 三、Stable-Audio-Open-1.0 使用 -- Gitee From c6842a9230cb3f687ff52ee7176f2dddaece611c Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 26 Dec 2024 09:43:50 +0800 Subject: [PATCH 24/32] add stable_audio --- .../built-in/foundation/stable_audio/README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md index 6f1dadd075..a10d9288e4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -28,15 +28,8 @@ chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run source /usr/local/Ascend/ascend-toolkit/set_env.sh ``` -### 1.3 环境依赖安装 -```bash -pip3 install -r requirements.txt -apt-get update -apt-get install libsndfile1 -``` - -### 1.4 MindIE安装 +### 1.3 MindIE安装 ```shell # 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 chmod +x ./Ascend-mindie_${version}_linux-${arch}.run @@ -53,7 +46,7 @@ cd /usr/local/Ascend/mindie && source set_env.sh cd ${AieInstallPath}/mindie && source set_env.sh ``` -### 1.5 Torch_npu安装 +### 1.4 Torch_npu安装 安装pytorch框架 版本2.1.0 [安装包下载](https://download.pytorch.org/whl/cpu/torch/) @@ -75,6 +68,13 @@ pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl git clone https://gitee.com/ascend/ModelZoo-PyTorch.git ``` +### 2.2 环境依赖安装 +```bash +pip3 install -r requirements.txt +apt-get update +apt-get install libsndfile1 +``` + ## 三、Stable-Audio-Open-1.0 使用 ### 3.1 权重及配置文件说明 -- Gitee From c813bf912ccc2ec1f56dd2cc6bee78c2e2e6d226 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Fri, 27 Dec 2024 21:06:39 +0800 Subject: [PATCH 25/32] add stable_audio --- MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md index a10d9288e4..c96386f720 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -86,7 +86,7 @@ https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main ### 3.2 单卡功能测试 设置权重路径 ```shell -model_base = './stable-audio-open-1.0' +model_base='./stable-audio-open-1.0' ``` 执行命令: ```shell -- Gitee From 5cb951a9ce8770d6112175bd6790934b4b5bdb53 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:08:58 +0800 Subject: [PATCH 26/32] add stable_audio --- .../stableaudio/layers/attention_processor.py | 16 ++++------------ .../pipeline/pipeline_stable_audio.py | 3 +++ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index e4e2104bd9..d412839105 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -856,26 +856,18 @@ class StableAudioAttnProcessor2_0: # Apply RoPE if needed if rotary_emb is not None: - query_dtype = query.dtype - key_dtype = key.dtype - query = query.to(torch.float32) - key = key.to(torch.float32) - rot_dim = rotary_emb[0].shape[-1] + cos = rotary_emb[0][None, :, None, :] + sin = rotary_emb[1][None, :, None, :] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] - query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - + query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, cos, sin) query = torch.cat((query_rotated, query_unrotated), dim=-1) if not attn.is_cross_attention: key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] - key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) - + key_rotated = torch_npu.npu_rotary_mul(key_to_rotate, cos, sin) key = torch.cat((key_rotated, key_unrotated), dim=-1) - query = query.to(query_dtype) - key = key.to(key_dtype) - if query.device is not torch.device("cpu"): if attention_mask is not None: attention_mask=~attention_mask diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py index 0a4698ec09..68a580449b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -693,6 +693,9 @@ class StableAudioPipeline(DiffusionPipeline): use_real=True, repeat_interleave_real=False, ) + cos = rotary_embedding[0].to(latents.device).to(torch.float16) + sin = rotary_embedding[1].to(latents.device).to(torch.float16) + rotary_embedding = (cos, sin) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order -- Gitee From b7f780730b9744cb8437951f2cb1b427739920b5 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:25:15 +0800 Subject: [PATCH 27/32] add stable_audio --- MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md | 2 +- .../stable_audio/stableaudio/layers/attention_processor.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md index c96386f720..21b83e2e89 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/README.md @@ -114,4 +114,4 @@ python3 inference_stableaudio.py \ | 硬件形态 | 迭代次数 | 平均耗时| | :------: |:----:|:----:| -| Atlas 800I A2 (32G) | 100 | 10.201s | \ No newline at end of file +| Atlas 800I A2 (32G) | 100 | 7.376s | \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index d412839105..977cdae7e3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -857,8 +857,6 @@ class StableAudioAttnProcessor2_0: # Apply RoPE if needed if rotary_emb is not None: rot_dim = rotary_emb[0].shape[-1] - cos = rotary_emb[0][None, :, None, :] - sin = rotary_emb[1][None, :, None, :] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, cos, sin) query = torch.cat((query_rotated, query_unrotated), dim=-1) -- Gitee From c5fe917d4d61182e2420bba7977422f425d318f0 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:27:12 +0800 Subject: [PATCH 28/32] add stable_audio --- .../stable_audio/stableaudio/layers/attention_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index 977cdae7e3..ac3d427ba3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -858,12 +858,12 @@ class StableAudioAttnProcessor2_0: if rotary_emb is not None: rot_dim = rotary_emb[0].shape[-1] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] - query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, cos, sin) + query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, rotary_emb[0], rotary_emb[1]) query = torch.cat((query_rotated, query_unrotated), dim=-1) if not attn.is_cross_attention: key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] - key_rotated = torch_npu.npu_rotary_mul(key_to_rotate, cos, sin) + key_rotated = torch_npu.npu_rotary_mul(key_to_rotate, rotary_emb[0], rotary_emb[1]) key = torch.cat((key_rotated, key_unrotated), dim=-1) if query.device is not torch.device("cpu"): -- Gitee From ed16a4fc72575cf1a26a736afa31c02a5efa4d43 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:29:13 +0800 Subject: [PATCH 29/32] add 
stable_audio --- .../stable_audio/stableaudio/layers/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index ac3d427ba3..1303250abf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -856,6 +856,7 @@ class StableAudioAttnProcessor2_0: # Apply RoPE if needed if rotary_emb is not None: + print("rotary_emb[0]",rotary_emb[0].shape) rot_dim = rotary_emb[0].shape[-1] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, rotary_emb[0], rotary_emb[1]) -- Gitee From 9dc0bf263d2f1b4934b1dfa7735955145fb7d7e0 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:30:22 +0800 Subject: [PATCH 30/32] add stable_audio --- .../stableaudio/pipeline/pipeline_stable_audio.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py index 68a580449b..14c09d34f5 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -693,9 +693,11 @@ class StableAudioPipeline(DiffusionPipeline): use_real=True, repeat_interleave_real=False, ) - cos = rotary_embedding[0].to(latents.device).to(torch.float16) - sin = rotary_embedding[1].to(latents.device).to(torch.float16) + cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) + sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) rotary_embedding = (cos, sin) + print("rotary_emb[0]",rotary_emb[0].shape) + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order -- Gitee From 33044c002d36e81d729e325147ea53490b327b45 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:31:43 +0800 Subject: [PATCH 31/32] add stable_audio --- .../stable_audio/stableaudio/pipeline/pipeline_stable_audio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py index 14c09d34f5..05f842387c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -696,7 +696,7 @@ class StableAudioPipeline(DiffusionPipeline): cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) rotary_embedding = (cos, sin) - print("rotary_emb[0]",rotary_emb[0].shape) + print("rotary_emb[0]",rotary_embedding[0].shape) # 8. 
Denoising loop -- Gitee From b4cf1ee33cc668364bf9115d790a07583fa56e29 Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Thu, 2 Jan 2025 15:32:41 +0800 Subject: [PATCH 32/32] add stable_audio --- .../stable_audio/stableaudio/layers/attention_processor.py | 1 - .../stable_audio/stableaudio/pipeline/pipeline_stable_audio.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py index 1303250abf..ac3d427ba3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/layers/attention_processor.py @@ -856,7 +856,6 @@ class StableAudioAttnProcessor2_0: # Apply RoPE if needed if rotary_emb is not None: - print("rotary_emb[0]",rotary_emb[0].shape) rot_dim = rotary_emb[0].shape[-1] query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] query_rotated = torch_npu.npu_rotary_mul(query_to_rotate, rotary_emb[0], rotary_emb[1]) diff --git a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py index 05f842387c..d1823c905e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py +++ b/MindIE/MindIE-Torch/built-in/foundation/stable_audio/stableaudio/pipeline/pipeline_stable_audio.py @@ -696,8 +696,6 @@ class StableAudioPipeline(DiffusionPipeline): cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) rotary_embedding = (cos, sin) - print("rotary_emb[0]",rotary_embedding[0].shape) - # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order -- Gitee
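
The RoPE rework in patches 26–32 above replaces `apply_rotary_emb` with `torch_npu.npu_rotary_mul` inside `StableAudioAttnProcessor2_0` and precomputes the `cos`/`sin` tables once in the pipeline, reshaped to `[1, S, 1, rot_dim]` and cast to float16 before the denoising loop. For readers following those hunks without NPU hardware, below is a minimal CPU reference sketch of the partial rotary embedding being accelerated. It assumes the rotate-half convention (consistent with `use_real_unbind_dim=-2` in the original `apply_rotary_emb` call) and uses hypothetical tensor shapes and a toy frequency table; it illustrates the computation only and is not the shipped implementation.

```python
# Minimal CPU reference for the partial RoPE applied in StableAudioAttnProcessor2_0.
# Assumptions (not taken from the patches): B=2, S=16, N=8 heads, head_dim=64,
# rotary dim rot_dim=32, rotate-half convention. torch_npu.npu_rotary_mul is
# expected to compute x * cos + rotate_half(x) * sin on NPU for the same layout.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and swap them with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [B, S, N, D]; cos/sin: [1, S, 1, rot_dim], broadcast over batch and heads.
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    rot_dim = cos.shape[-1]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return torch.cat((x_rot, x_pass), dim=-1)


if __name__ == "__main__":
    B, S, N, D, rot_dim = 2, 16, 8, 64, 32
    q = torch.randn(B, S, N, D)
    # Toy cos/sin tables; the pipeline precomputes them once and reuses them per step.
    pos = torch.arange(S, dtype=torch.float32)[:, None]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
    angles = pos * inv_freq  # [S, rot_dim // 2]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None, :, None, :]
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None, :, None, :]
    out = apply_partial_rope(q, cos, sin)
    print(out.shape)  # torch.Size([2, 16, 8, 64])
```

Moving the `cos`/`sin` preparation out of the attention processor also removes the per-step float32 casts that the original processor performed around `apply_rotary_emb`, which is why the query/key dtype conversions disappear from the patched hunks.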