From 3486cc1b1983ed7732a959227d02fbd1165932e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cmazhixin00=E2=80=9D?=
Date: Wed, 24 Sep 2025 15:43:01 +0800
Subject: [PATCH] test

---
 .../cogview3plus/layers/__init__.py           |  3 +-
 .../cogview3plus/layers/linear.py             | 97 +++++++++++++++++++
 .../models/attention_processor.py             |  2 +-
 3 files changed, 100 insertions(+), 2 deletions(-)
 create mode 100644 MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py

diff --git a/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py
index 4d25f1e889..602ad432a0 100644
--- a/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py
+++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/__init__.py
@@ -1,2 +1,3 @@
 from .normalization import CogView3PlusAdaLayerNormZeroTextImage, AdaLayerNormContinuous
-from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
\ No newline at end of file
+from .embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from .linear import QKVLinear
\ No newline at end of file
diff --git a/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py
new file mode 100644
index 0000000000..c871bb867f
--- /dev/null
+++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/layers/linear.py
@@ -0,0 +1,97 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+
+
+class QKVLinear(nn.Module):
+    def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None,
+                 device=None, dtype=None):
+        super(QKVLinear, self).__init__()
+        self.attention_dim = attention_dim
+        self.hidden_size = hidden_size
+
+        self.cross_attention_dim = cross_attention_dim
+        self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size
+        self.qkv_bias = qkv_bias
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+
+        if cross_attention_dim is None:
+            self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs))
+            if self.qkv_bias:
+                self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs))
+        else:
+            self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs))
+            self.kv_weight = nn.Parameter(
+                torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs))
+
+            if self.qkv_bias:
+                self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs))
+                self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs))
+
+    def forward(self, hidden_states, encoder_hidden_states=None):
+
+        if self.cross_attention_dim is None:
+            if not self.qkv_bias:
+                qkv = torch.matmul(hidden_states, self.weight)
+            else:
+                qkv = torch.addmm(
+                    self.bias,
+                    hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+                    self.weight,
+                    beta=1,
+                    alpha=1
+                )
+
+            batch, seqlen, _ = hidden_states.shape
+            qkv_shape = (batch, seqlen, 3, -1)
+            qkv = qkv.view(qkv_shape)
+            q, k, v = qkv.unbind(2)
+
+        else:
+            if not self.qkv_bias:
+                q = torch.matmul(hidden_states, self.q_weight)
+                kv = torch.matmul(encoder_hidden_states, self.kv_weight)
+            else:
+                q = torch.addmm(
+                    self.q_bias,
+                    hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+                    self.q_weight,
+                    beta=1,
+                    alpha=1
+                )
+                kv = torch.addmm(
+                    self.kv_bias,
+                    encoder_hidden_states.view(
+                        encoder_hidden_states.size(0) * encoder_hidden_states.size(1),
+                        encoder_hidden_states.size(2)),
+                    self.kv_weight,
+                    beta=1,
+                    alpha=1
+                )
+
+            batch, q_seqlen, _ = hidden_states.shape
+            q = q.view(batch, q_seqlen, -1)
+
+            batch, kv_seqlen, _ = encoder_hidden_states.shape
+            kv_shape = (batch, kv_seqlen, 2, -1)
+
+            kv = kv.view(kv_shape)
+            k, v = kv.unbind(2)
+
+        return q, k, v
diff --git a/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py
index dbef911628..7beecd074c 100644
--- a/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py
+++ b/MindIE/MultiModal/CogView3-Plus-3B/cogview3plus/models/attention_processor.py
@@ -23,7 +23,7 @@
 import torch_npu
 
 from diffusers.utils import logging
 from diffusers.utils.torch_utils import maybe_allow_in_graph
-from mindiesd.layers.linear import QKVLinear
+from ..layers import QKVLinear
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-- 
Gitee
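
Note (not part of the patch): a minimal usage sketch of the QKVLinear layer added above. QKVLinear fuses the attention projections into fewer matmuls: for self-attention it holds one [attention_dim, 3 * hidden_size] weight split into q/k/v after the GEMM, and for cross-attention a separate q projection plus a fused [cross_attention_dim, 2 * cross_hidden_size] kv projection. Weights are stored [in_features, out_features], so matmul/addmm needs no transpose. The dimensions below (1920, 256, 77, 4096) are illustrative assumptions, not values from the patch, and the parameters are left at their uninitialized torch.empty values purely to exercise the shapes; in the real pipeline they are loaded from the model checkpoint.

    import torch
    from cogview3plus.layers import QKVLinear

    # Self-attention: one fused projection, unbound into q/k/v after the GEMM.
    self_attn_qkv = QKVLinear(attention_dim=1920, hidden_size=1920)
    x = torch.randn(2, 256, 1920)      # (batch, seq_len, attention_dim)
    q, k, v = self_attn_qkv(x)         # each (2, 256, 1920)

    # Cross-attention: q from hidden_states, fused k/v from encoder_hidden_states.
    cross_attn_qkv = QKVLinear(attention_dim=1920, hidden_size=1920,
                               cross_attention_dim=4096)
    ctx = torch.randn(2, 77, 4096)     # (batch, kv_seq_len, cross_attention_dim)
    q, k, v = cross_attn_qkv(x, ctx)   # q (2, 256, 1920); k and v (2, 77, 1920)

One fused projection means one kernel launch and one weight tensor instead of three, which is the usual motivation for this kind of layer on NPU/GPU inference backends.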