From ded0e650003d7c7e9912d4167ec54f0daf5fb1b9 Mon Sep 17 00:00:00 2001 From: wenziyi2025 <1747295943@qq.com> Date: Wed, 10 Dec 2025 21:07:28 +0800 Subject: [PATCH] =?UTF-8?q?Fix:=20=E9=80=82=E9=85=8D=20Ascend=20910B=20?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=EF=BC=8C=E4=BF=AE=E5=A4=8D=20L2Normalize=20?= =?UTF-8?q?=E7=B2=BE=E5=BA=A6=E5=B9=B6=E7=A7=BB=E9=99=A4=E6=97=A0=E5=85=B3?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../applications/model_cards/ColabDesign.md | 59 +++ .../models/esm_if1/module/features.py | 373 +++++++++++++ .../models/megafold/module/structure.py | 265 ++++++++++ .../multimer/module/multimer_structure.py | 252 +++++++++ .../pipeline/models/rasp/module/structure.py | 498 ++++++++++++++++++ 5 files changed, 1447 insertions(+) create mode 100644 MindSPONGE/applications/model_cards/ColabDesign.md create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/megafold/module/structure.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/multimer/module/multimer_structure.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/rasp/module/structure.py diff --git a/MindSPONGE/applications/model_cards/ColabDesign.md b/MindSPONGE/applications/model_cards/ColabDesign.md new file mode 100644 index 000000000..4ab9de3c1 --- /dev/null +++ b/MindSPONGE/applications/model_cards/ColabDesign.md @@ -0,0 +1,59 @@ +# ColabDesign + +## 模型介绍 + +对于一个骨架结构位置坐标已知但氨基酸种类未知的蛋白质,假定它的长度为n,该序列共有20的n次方种可能性,然而自然界中现存的蛋白质样本只占这庞大集合中的一小部分,难以通过遍历的方式筛选到合理的氨基酸序列。因此,蛋白质设计任务即为通过计算的方式,找到可以形成该pdb结构的蛋白质氨基酸序列。 + +ColabDesign是蛋白质设计模型,通过输入蛋白质骨架坐标的pdb文件,基于蛋白质结构预测模型来预测整个蛋白质序列,被称为Hallucination and Inpainting。 + +通常设计具有某种特定功能的蛋白质共需要2个步骤: + +- 识别特定功能的可能活性位点的几何形状与氨基酸种类,如酶的活性位点,蛋白抑制剂等。 +- 设计一个包含这些特定活性位点的氨基酸序列,并折叠成对应三维结构。 + +步骤2为ColabDesign主要解决的问题,用固定位点或者骨架作为输入,产生完整序列。 + 
+最早对该项目进行探索的方法为[trDesign](https://www.biorxiv.org/content/10.1101/2020.07.23.218917v1.abstract),使用了trDesign和Rosetta结合的方式。之后[Hallucination](https://www.nature.com/articles/s41586-021-04184-w)基于trDesign,借鉴了DeepDream模型,以Hallucination+trDesign的设计方式进一步提升了效果。在融合功能[Motif](https://www.biorxiv.org/content/10.1101/2020.11.29.402743v1.abstract)设计方法出现之后,将trDesign和Hallucination相结合,解决了对预先生成的scaffold数据库的依赖问题。 + +在这之后,ColabDesign以RoseTTAFold为核心进行实验,在AlphaFold2上进行交叉验证,基于这两个模型的Hallucination被称为“RFdesign”和“AFdesign”。RoseTTAFold显式地利用了三维结构坐标,相比trRosetta只利用二维特征信息,它有着更多地信息来定义各类loss,解决更多之前不可解决的问题,大幅提升了实验精度。 + +![ColabDesign](../../docs/modelcards/ColabDesign.png) + +A图为Free hallucination,将序列传入trRosetta或者RoseTTAFold预测3D结构,使用MCMC迭代优化loss函数来产生序列。B图为Constrained hallucination,使用与A图相同的方式,但是loss函数除了结构信息之外还包含了Motif重述和其他特定任务信息。C图为缺失信息恢复任务,通过输入部分序列或者部分结构信息来补齐完整序列或结构。D图为可以通过约束幻觉和相应的损失函数,即本文的方法来解决的设计问题。E图为本文方法概览,本文中的蛋白质设计挑战为多种场景下的缺失信息恢复任务。 + +## 使用限制 + +该Pipeline中的ColabDesign与最初的ColabDesign不同,没有基于RoseTTAFold和AlphaFold 2,而是基于MEGA-Protein实现了Hallucination和fixbb两个功能。 +支持mindspore2.7.1 + +该模型目前只支持推理,即输入蛋白质pdb文件,输出设计后地氨基酸序列。暂未提供模型训练方法与训练数据集。 + +## 如何使用 + +可使用PDB文件作为模型推理输入,样例代码如下所示: +先提前下载 https://gitee.com/mindspore/mindscience/blob/r0.7/MindSPONGE/applications/model_configs/ColabDesign/predict_256.yaml + +```bash +from mindsponge import PipeLine +import os +config_path = os.path.abspath("predict_256.yaml") +pipe = PipeLine(name = "ColabDesign") +pipe.set_device_id(0) +pipe.initialize(config_path=config_path) +pipe.model.from_pretrained() +res = pipe.predict({YOUR_PDB_PATH}) +print(res) +``` + +## 引用 + +```bash +@article{wang2021deep, + title={Deep learning methods for designing proteins scaffolding functional sites}, + author={Wang, Jue and Lisanza, Sidney and Juergens, David and Tischer, Doug and Anishchenko, Ivan and Baek, Minkyung and Watson, Joseph L and Chun, Jung Ho and Milles, Lukas F and Dauparas, Justas and others}, + journal={BioRxiv}, + pages={2021--11}, + year={2021}, + publisher={Cold Spring 
Harbor Laboratory} +} +``` diff --git a/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py new file mode 100644 index 000000000..68f8b63ea --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/esm_if1/module/features.py @@ -0,0 +1,373 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Feature extraction""" + +import math +import numpy as np +import mindspore as ms +import mindspore.ops as ops +import mindspore.nn as nn +from mindspore import context +# pylint: disable=relative-beyond-top-level +from .basic_modules import GVP, LayerNorm, Dense +from .util import normalize, norm, nan_to_num, rbf, flatten_graph, ms_transpose, ms_padding_without_val + + +class GVPInputFeaturizer(nn.Cell): + """Input feature extraction for GVP""" + + @staticmethod + def get_node_features(coords, coord_mask, with_coord_mask=True): + """Get node features""" + node_scalar_features = GVPInputFeaturizer._dihedrals(coords) + if with_coord_mask: + coord_mask = ops.ExpandDims()(ops.Cast()(coord_mask, ms.float32), -1) + node_scalar_features = ops.Concat(axis=-1)([node_scalar_features, coord_mask]) + x_ca = coords[:, :, 1] + orientations = GVPInputFeaturizer._orientations(x_ca) + sidechains = GVPInputFeaturizer._sidechains(coords) + node_vector_features = 
ops.Concat(axis=-2)([orientations, ops.ExpandDims()(sidechains, -2)]) + return node_scalar_features, node_vector_features + + @staticmethod + def _orientations(x): + + forward = normalize(x[:, 1:] - x[:, :-1]) + backward = normalize(x[:, :-1] - x[:, 1:]) + forward = ops.concat((forward, ops.Zeros()((forward.shape[0], 1, forward.shape[2]), ms.float32)), 1) + backward = ops.concat((ops.Zeros()((backward.shape[0], 1, backward.shape[2]), ms.float32), backward), 1) + + output = ops.Concat(axis=-2)([ops.ExpandDims()(forward, -2), ops.ExpandDims()(backward, -2)]) + return output + + @staticmethod + def _sidechains(x): + n, origin, c = x[:, :, 0], x[:, :, 1], x[:, :, 2] + c, n = normalize(c - origin), normalize(n - origin) + bisector = normalize(c + n) + perp = normalize(ms.numpy.cross(c, n)) + vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) + return vec + + @staticmethod + def _dihedrals(x, eps=1e-7): + """Dihedron""" + + y = x[:, :, :3].reshape((x.shape[0], (x.shape[1] * x.shape[2]), x.shape[3])) + bsz = x.shape[0] + dx = y[:, 1:] - y[:, :-1] + u = normalize(dx, dim=-1) + u_2 = u[:, :-2] + u_1 = u[:, 1:-1] + u_0 = u[:, 2:] + + # Backbone normals + n_2 = normalize(ms.numpy.cross(u_2, u_1), dim=-1) + n_1 = normalize(ms.numpy.cross(u_1, u_0), dim=-1) + + # Angle between normals + cosd = ops.ReduceSum()(n_2 * n_1, -1) + + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = ops.Reshape()(d, (bsz, -1, 3)) + # Lift angle representations to the circle + d_features = ops.Concat(axis=-1)([ops.Cos()(d), ops.Sin()(d)]) + return d_features + + @staticmethod + def _positional_embeddings(edge_index, + num_embeddings=None, + num_positional_embeddings=16): + """Positional embeddings""" + + 
num_embeddings = num_embeddings or num_positional_embeddings or [] + d = edge_index[0] - edge_index[1] + + frequency = ops.Exp()( + ms.numpy.arange(0, num_embeddings, 2, dtype=ms.float32) + * -(np.log(10000.0) / num_embeddings) + ) + angles = ops.ExpandDims()(d, -1) * frequency + e = ops.Concat(-1)((ops.Cos()(angles), ops.Sin()(angles))) + return e + + @staticmethod + def _dist(x, coord_mask, padding_mask, top_k_neighbors): + """ Pairwise euclidean distances """ + bsz, maxlen = x.shape[0], x.shape[1] + coord_mask = ops.Cast()(coord_mask, ms.float32) + coord_mask_2d = ops.ExpandDims()(coord_mask, 1) * ops.ExpandDims()(coord_mask, 2) + residue_mask = ~padding_mask + residue_mask = ops.Cast()(residue_mask, ms.float32) + residue_mask_2d = ops.ExpandDims()(residue_mask, 1) * ops.ExpandDims()(residue_mask, 2) + dx = ops.ExpandDims()(x, 1) - ops.ExpandDims()(x, 2) + d = coord_mask_2d * norm(dx, dim=-1) + + # sorting preference: first those with coords, then among the residues that + # exist but are masked use distance in sequence as tie breaker, and then the + # residues that came from padding are last + seqpos = ms.numpy.arange(maxlen) + seqpos_1 = ops.ExpandDims()(seqpos, 1) + seqpos_0 = ops.ExpandDims()(seqpos, 0) + d_seq = ops.Abs()(seqpos_1 - seqpos_0) + if bsz != 1: + d_seq = ms.numpy.tile(d_seq, (bsz, 1, 1)) + coord_mask_2d = ops.Cast()(coord_mask_2d, ms.bool_) + residue_mask_2d = ops.Cast()(residue_mask_2d, ms.bool_) + verse_coord_mask_2d = ops.Cast()(~coord_mask_2d, ms.float32) + verse_residue_mask_2d = ops.Cast()(~residue_mask_2d, ms.float32) + d_adjust = nan_to_num(d) + (verse_coord_mask_2d) * (1e8 + d_seq * 1e6) + ( + verse_residue_mask_2d) * (1e10) + + if top_k_neighbors == -1: + d_neighbors = d_adjust / 1e4 + e_idx = seqpos.repeat( + *d_neighbors.shape[:-1], 1) + else: + d_adjust = d_adjust / 1e4 + if context.get_context("device_target") == "GPU": + d_neighbors, e_idx = ops.Sort(axis=-1, descending=True)(d_adjust) + else: + d_neighbors, e_idx = 
ops.TopK(sorted=True)(d_adjust, d_adjust.shape[-1]) + d_neighbors, e_idx = d_neighbors[..., ::-1], e_idx[..., ::-1] + d_neighbors, e_idx = d_neighbors[:, :, 0:int(min(top_k_neighbors, x.shape[1]))], \ + e_idx[:, :, 0:int(min(top_k_neighbors, x.shape[1]))] + d_neighbors = ms.Tensor(d_neighbors, ms.float32)*1e4 + coord_mask_neighbors = (d_neighbors < 5e7) + residue_mask_neighbors = (d_neighbors < 5e9) + output = [d_neighbors, e_idx, coord_mask_neighbors, residue_mask_neighbors] + return output + + +class Normalize(nn.Cell): + """Normalization""" + + def __init__(self, features, epsilon=1e-6): + super(Normalize, self).__init__() + self.gain = ms.Parameter(ops.Ones()(features, ms.float32)) + self.bias = ms.Parameter(ops.Zeros()(features, ms.float32)) + self.epsilon = epsilon + + def construct(self, x, dim=-1): + """Normalization construction""" + + mu = x.mean(dim, keep_dims=True) + sigma = ops.Sqrt()(x.var(dim, keepdims=True) + self.epsilon) + gain = self.gain + bias = self.bias + # Reshape + if dim != -1: + shape = [1] * len(mu.size()) + shape[dim] = self.gain.size()[0] + gain = gain.view(shape) + bias = bias.view(shape) + return gain * (x - mu) / (sigma + self.epsilon) + bias + + +class DihedralFeatures(nn.Cell): + """Dihedral features""" + + def __init__(self, node_embed_dim): + """ Embed dihedral angle features. 
""" + super(DihedralFeatures, self).__init__() + # 3 dihedral angles; sin and cos of each angle + node_in = 6 + # Normalization and embedding + self.node_embedding = Dense(node_in, node_embed_dim, has_bias=True) + self.norm_nodes = Normalize(node_embed_dim) + + @staticmethod + def _dihedrals(x, eps=1e-7, return_angles=False): + """Dihedron in DihedralFeatures""" + + # First 3 coordinates are N, CA, C + x = x[:, :, :3, :].reshape(x.shape[0], 3 * x.shape[1], 3) + + # Shifted slices of unit vectors + dx = x[:, 1:, :] - x[:, :-1, :] + u = ops.L2Normalize(axis=-1,epsilon=1e-6)(dx) + u_2 = u[:, :-2, :] + u_1 = u[:, 1:-1, :] + u_0 = u[:, 2:, :] + # Backbone normals + n_2 = ops.L2Normalize(axis=-1,epsilon=1e-6)(ms.numpy.cross(u_2, u_1)) + n_1 = ops.L2Normalize(axis=-1,epsilon=1e-6)(ms.numpy.cross(u_1, u_0)) + + # Angle between normals + cosd = (n_2 * n_1).sum(-1) + min_value = ms.Tensor((-1 + eps), ms.float32) + max_value = ms.Tensor((1 - eps), ms.float32) + cosd = ops.clip_by_value(cosd, clip_value_min=min_value, clip_value_max=max_value) + d = ops.Sign()((u_2 * n_1).sum(-1)) * ops.ACos()(cosd) + + # This scheme will remove phi[0], psi[-1], omega[-1] + d = ms_padding_without_val(d, [1, 2]) + d = d.view((d.shape[0], int(d.shape[1] / 3), 3)) + phi, psi, omega = ops.Unstack(axis=-1)(d) + + if return_angles: + return phi, psi, omega + + # Lift angle representations to the circle + d_features = ops.Concat(axis=2)((ops.Cos()(d), ops.Sin()(d))) + return d_features + + def construct(self, x): + """ Featurize coordinates as an attributed graph """ + v = self._dihedrals(x) + v = self.node_embedding(v) + v = self.norm_nodes(v) + return v + + +class GVPGraphEmbedding(GVPInputFeaturizer): + """GVP graph embedding""" + + def __init__(self, args): + super().__init__() + self.top_k_neighbors = args.top_k_neighbors + self.num_positional_embeddings = 16 + self.remove_edges_without_coords = True + node_input_dim = (7, 3) + edge_input_dim = (34, 1) + node_hidden_dim = 
(args.node_hidden_dim_scalar, + args.node_hidden_dim_vector) + edge_hidden_dim = (args.edge_hidden_dim_scalar, + args.edge_hidden_dim_vector) + self.embed_node = nn.SequentialCell( + [GVP(node_input_dim, node_hidden_dim, activations=(None, None)), + LayerNorm(node_hidden_dim, eps=1e-4)] + ) + self.embed_edge = nn.SequentialCell( + [GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)), + LayerNorm(edge_hidden_dim, eps=1e-4)] + ) + self.embed_confidence = Dense(16, args.node_hidden_dim_scalar) + + def construct(self, coords, coord_mask, padding_mask, confidence): + """GVP graph embedding construction""" + + node_features = self.get_node_features(coords, coord_mask) + + edge_features, edge_index = self.get_edge_features( + coords, coord_mask, padding_mask) + node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features) + edge_embeddings = self.embed_edge(edge_features) + + rbf_rep = rbf(confidence, 0., 1.) + + node_embeddings = ( + node_embeddings_scalar + self.embed_confidence(rbf_rep), + node_embeddings_vector + ) + + + node_embeddings, edge_embeddings, edge_index = flatten_graph( + node_embeddings, edge_embeddings, edge_index) + return node_embeddings, edge_embeddings, edge_index + + def get_edge_features(self, coords, coord_mask, padding_mask): + """Get edge features""" + + x_ca = coords[:, :, 1] + + # Get distances to the top k neighbors + e_dist, e_idx, e_coord_mask, e_residue_mask = GVPInputFeaturizer._dist( + x_ca, coord_mask, padding_mask, self.top_k_neighbors) + # Flatten the graph to be batch size 1 for torch_geometric package + dest = e_idx + e_idx_b, e_idx_l, k = e_idx.shape[:3] + + src = ms.numpy.arange(e_idx_l).view((1, e_idx_l, 1)) + src = ops.BroadcastTo((e_idx_b, e_idx_l, k))(src) + + + edge_index = ops.Stack(axis=0)([src, dest]) + + edge_index = edge_index.reshape((edge_index.shape[0], edge_index.shape[1], + (edge_index.shape[2] * edge_index.shape[3]))) + + # After flattening, [B, E] + e_dist = 
e_dist.reshape((e_dist.shape[0], (e_dist.shape[1] * e_dist.shape[2]))) + + e_coord_mask = e_coord_mask.reshape((e_coord_mask.shape[0], (e_coord_mask.shape[1] * e_coord_mask.shape[2]))) + e_coord_mask = ops.ExpandDims()(e_coord_mask, -1) + e_residue_mask = e_residue_mask.reshape((e_residue_mask.shape[0], + (e_residue_mask.shape[1] * e_residue_mask.shape[2]))) + + # Calculate relative positional embeddings and distance RBF + pos_embeddings = GVPInputFeaturizer._positional_embeddings( + edge_index, + num_positional_embeddings=self.num_positional_embeddings, + ) + d_rbf = rbf(e_dist, 0., 20.) + + # Calculate relative orientation + x_src = ops.ExpandDims()(x_ca, 2) + x_src = ops.BroadcastTo((-1, -1, k, -1))(x_src) + x_src = x_src.reshape((x_src.shape[0], (x_src.shape[1] * x_src.shape[2]), x_src.shape[3])) + + a = ops.ExpandDims()(edge_index[1, :, :], -1) + a = ops.BroadcastTo((e_idx_b, e_idx_l * k, 3))(a) + x_dest = ops.GatherD()( + x_ca, + 1, + a + ) + coord_mask_src = ops.ExpandDims()(coord_mask, 2) + coord_mask_src = ops.BroadcastTo((-1, -1, k))(coord_mask_src) + coord_mask_src = coord_mask_src.reshape((coord_mask_src.shape[0], + (coord_mask_src.shape[1] * coord_mask_src.shape[2]))) + + b = ops.BroadcastTo((e_idx_b, e_idx_l * k))(edge_index[1, :, :]) + + coord_mask_dest = ops.GatherD()( + coord_mask, + 1, + b + ) + e_vectors = x_src - x_dest + # For the ones without coordinates, substitute in the average vector + e_coord_mask_fl = ops.Cast()(e_coord_mask, ms.float32) + e_vector_mean_top = ops.ReduceSum(keep_dims=True)(e_vectors * e_coord_mask_fl, axis=1) + e_vector_mean_bottom = ops.ReduceSum(keep_dims=True)(e_coord_mask_fl, axis=1) + e_vector_mean = e_vector_mean_top / e_vector_mean_bottom + e_vectors_factor1 = e_vectors * e_coord_mask_fl + e_vectors_factor2 = e_vector_mean * ~(e_coord_mask) + e_vectors = e_vectors_factor1 + e_vectors_factor2 + # Normalize and remove nans + edge_s = ops.Concat(axis=-1)([d_rbf, pos_embeddings]) + edge_v = 
ops.ExpandDims()(normalize(e_vectors), -2) + edge_s, edge_v = map(nan_to_num, (edge_s, edge_v)) + # Also add indications of whether the coordinates are present + + edge_s = ops.Concat(axis=-1)([ + edge_s, + ops.ExpandDims()((~coord_mask_src).astype(np.float32), -1), + ops.ExpandDims()((~coord_mask_dest).astype(np.float32), -1)]) + e_residue_mask = ops.Cast()(e_residue_mask, ms.bool_) + fill_value = ms.Tensor(-1, dtype=edge_index.dtype) + edge_index = edge_index.masked_fill(~e_residue_mask, fill_value) + + if self.remove_edges_without_coords: + edge_index = ops.masked_fill(edge_index, ~e_coord_mask.squeeze(-1), fill_value) + + return (edge_s, edge_v), ms_transpose(edge_index, 0, 1) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/megafold/module/structure.py b/MindSPONGE/src/mindsponge/pipeline/models/megafold/module/structure.py new file mode 100644 index 000000000..14b1a759e --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/megafold/module/structure.py @@ -0,0 +1,265 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""structure module""" +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.numpy as mnp +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.ops import functional as F +from mindsponge.cell import InvariantPointAttention +from mindsponge.common import residue_constants +from mindsponge.cell.initializer import lecun_init +from mindsponge.common.utils import torsion_angles_to_frames,\ + frames_and_literature_positions_to_atom14_pos, atom14_to_atom37 +from mindsponge.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale,\ + vecs_to_tensor, vecs_expand_dims, rots_expand_dims + + +class MultiRigidSidechain(nn.Cell): + """Class to make side chain atoms.""" + + def __init__(self, config, single_repr_dim): + super().__init__() + self.config = config + num_channel = self.config.num_channel + self.input_projection = nn.Dense(single_repr_dim, num_channel, + weight_init=lecun_init(single_repr_dim)) + self.input_projection_1 = nn.Dense(single_repr_dim, num_channel, + weight_init=lecun_init(single_repr_dim)) + self.relu = nn.ReLU() + self.resblock1 = nn.Dense(num_channel, num_channel, + weight_init=lecun_init(num_channel, + initializer_name='relu')) + self.resblock2 = nn.Dense(num_channel, num_channel, weight_init='zeros') + self.resblock1_1 = nn.Dense(num_channel, num_channel, + weight_init=lecun_init(num_channel, initializer_name='relu')) + self.resblock2_1 = nn.Dense(num_channel, num_channel, weight_init='zeros') + self.unnormalized_angles = nn.Dense(num_channel, 14, + weight_init=lecun_init(num_channel)) + self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group) + self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions) + self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask) + 
self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame) + self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-6) + + def construct(self, rotation, translation, act, initial_act, aatype): + """Predict side chains using rotation and translation representations. + + Args: + rotation: The rotation matrices. + translation: A translation matrices. + act: updated pair activations from structure module + initial_act: initial act representations (input of structure module) + aatype: Amino acid type representations + + Returns: + angles, positions and new frames + """ + + act1 = self.input_projection(self.relu(act)) + init_act1 = self.input_projection_1(self.relu(initial_act)) + # Sum the activation list (equivalent to concat then Linear). + act = act1 + init_act1 + + # Mapping with some residual blocks. + # resblock1 + old_act = act + act = self.resblock1(self.relu(act)) + act = self.resblock2(self.relu(act)) + act += old_act + # resblock2 + old_act = act + act = self.resblock1_1(self.relu(act)) + act = self.resblock2_1(self.relu(act)) + act += old_act + + # Map activations to torsion angles. Shape: (num_res, 14). 
+ num_res = act.shape[0] + unnormalized_angles = self.unnormalized_angles(self.relu(act)) + + unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2]) + angles = self.l2_normalize(unnormalized_angles) + + backb_to_global = ((rotation[0], rotation[1], rotation[2], + rotation[3], rotation[4], rotation[5], + rotation[6], rotation[7], rotation[8]), + (translation[0], translation[1], translation[2])) + + all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles, + self.restype_rigid_group_default_frame) + + pred_positions = \ + frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, + self.restype_atom14_to_rigid_group, + self.restype_atom14_rigid_group_positions, + self.restype_atom14_mask) + + atom_pos = pred_positions + frames = all_frames_to_global + res = (angles, unnormalized_angles, atom_pos, frames) + return res + + +class FoldIteration(nn.Cell): + """A single iteration of the main structure module loop.""" + + def __init__(self, config, pair_dim, single_repr_dim): + super().__init__() + self.config = config + self.drop_out = nn.Dropout(p=0.1) + num_channel = self.config.num_channel + self.attention_layer_norm = nn.LayerNorm([num_channel,], epsilon=1e-5) + self.transition_layer_norm = nn.LayerNorm([num_channel,], epsilon=1e-5) + self.transition = nn.Dense(num_channel, config.num_channel, + weight_init=lecun_init(num_channel, initializer_name='relu')) + self.transition_1 = nn.Dense(num_channel, num_channel, + weight_init=lecun_init(num_channel, initializer_name='relu')) + self.transition_2 = nn.Dense(num_channel, num_channel, weight_init='zeros') + self.relu = nn.ReLU() + self.affine_update = nn.Dense(num_channel, 6, weight_init='zeros') + self.attention_module = InvariantPointAttention(self.config.num_head, + self.config.num_scalar_qk, + self.config.num_scalar_v, + self.config.num_point_v, + self.config.num_point_qk, + num_channel, + pair_dim) + self.mu_side_chain = 
MultiRigidSidechain(self.config.sidechain, single_repr_dim) + + def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, + initial_act, aatype): + """construct""" + attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation) + act += attn + act = self.drop_out(act) + act = self.attention_layer_norm(act) + # Transition + input_act = act + act = self.transition(act) + act = self.relu(act) + act = self.transition_1(act) + act = self.relu(act) + act = self.transition_2(act) + act += input_act + act = self.drop_out(act) + act = self.transition_layer_norm(act) + # This block corresponds to Jumper et al. (2021) Alg. 23 "Backbone update" + # Affine update + affine_update = self.affine_update(act) + quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, + affine_update) + translation1 = vecs_scale(translation, 10.0) + rotation1 = rotation + angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \ + self.mu_side_chain(rotation1, translation1, act, initial_act, aatype) + affine_output = quaternion_to_tensor(quaternion, translation) + quaternion = F.stop_gradient(quaternion) + rotation = F.stop_gradient(rotation) + res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, \ + unnormalized_angles_sin_cos, atom_pos, frames) + return res + + +class StructureModule(nn.Cell): + """StructureModule as a network head.""" + + def __init__(self, config, single_repr_dim, pair_dim, seq_length): + super(StructureModule, self).__init__() + self.config = config.structure_module + self.seq_length = seq_length + self.fold_iteration = FoldIteration(self.config, pair_dim, single_repr_dim) + self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5) + self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel, + weight_init=lecun_init(single_repr_dim)) + self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5) + self.num_layer = 
self.config.num_layer + self.indice0 = Tensor( + np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32")) + self.traj_w = Tensor(np.array([1.] * 4 + [self.config.position_scale] * 3), mstype.float32) + + def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, + atom37_atom_exists=None): + """construct""" + sequence_mask = seq_mask[:, None] + act = self.single_layer_norm(single) + initial_act = act + act = self.initial_projection(act) + quaternion, rotation, translation = initial_affine(self.seq_length) + act_2d = self.pair_layer_norm(pair) + + atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new,\ + sidechain_frames, act_iter = \ + self.iteration_operation(act, act_2d, sequence_mask, quaternion, + rotation, translation, initial_act, aatype) + atom14_pred_positions = vecs_to_tensor(atom_pos)[-1] + sidechain_atom_pos = atom_pos + + atom37_pred_positions = atom14_to_atom37(atom14_pred_positions, + residx_atom37_to_atom14, + atom37_atom_exists, + self.indice0) + + structure_traj = affine_output_new * self.traj_w + final_affines = affine_output_new[-1] + final_atom_positions = atom37_pred_positions + final_atom_mask = atom37_atom_exists + rp_structure_module = act_iter + res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions,\ + final_affines, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, \ + sidechain_atom_pos, structure_traj) + return res + + def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, + initial_act, aatype): + """iteration_operation""" + affine_init = () + angles_sin_cos_init = () + um_angles_sin_cos_init = () + atom_pos_batch = () + frames_batch = () + + for _ in range(self.num_layer): + act, quaternion, translation, rotation, affine_output, angles_sin_cos, \ + unnormalized_angles_sin_cos, atom_pos, frames = \ + self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, + translation, 
initial_act, aatype) + + affine_init = affine_init + (affine_output[None, ...],) + angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],) + um_angles_sin_cos_init = um_angles_sin_cos_init + \ + (unnormalized_angles_sin_cos[None, ...],) + atom_pos_batch += (mnp.concatenate(vecs_expand_dims(atom_pos, 0), + axis=0)[:, None, ...],) + frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) + + vecs_expand_dims(frames[1], 0), + axis=0)[:, None, ...],) + affine_output_new = mnp.concatenate(affine_init, axis=0) + angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0) + um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0) + frames_new = mnp.concatenate(frames_batch, axis=1) + atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1) + res = (atom_pos_new, affine_output_new, angles_sin_cos_new, \ + um_angles_sin_cos_new, frames_new, act) + return res diff --git a/MindSPONGE/src/mindsponge/pipeline/models/multimer/module/multimer_structure.py b/MindSPONGE/src/mindsponge/pipeline/models/multimer/module/multimer_structure.py new file mode 100644 index 000000000..4dc9eb893 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/multimer/module/multimer_structure.py @@ -0,0 +1,252 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""structure module"""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.ops import functional as F
from mindsponge.common import residue_constants
from mindsponge.cell.initializer import lecun_init
from mindsponge.common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \
    atom14_to_atom37
from mindsponge.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale,\
    vecs_to_tensor, vecs_expand_dims, rots_expand_dims
from .multimer_block import MultimerInvariantPointAttention


class MultiRigidSidechain(nn.Cell):
    """Class to make side chain atoms.

    Maps the (updated and initial) single representations to seven torsion
    angles per residue, then builds all-atom (atom14) positions and rigid-group
    frames from those angles via the literature geometry tables in
    ``residue_constants``.
    """

    def __init__(self, config, single_repr_dim):
        super().__init__()
        self.config = config
        self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel,
                                         weight_init=lecun_init(single_repr_dim))
        self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel,
                                           weight_init=lecun_init(single_repr_dim))
        self.relu = nn.ReLU()
        self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                  weight_init=lecun_init(self.config.num_channel,
                                                         initializer_name='relu'))
        # Second layer of each residual block starts at zero so the block is an
        # identity map at initialization.
        self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                    weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        # 14 = 7 torsion angles x (sin, cos) per residue.
        self.unnormalized_angles = nn.Dense(self.config.num_channel, 14,
                                            weight_init=lecun_init(self.config.num_channel))
        # Constant literature tables mapping residue types to rigid groups /
        # ideal atom positions.
        self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
        self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
        self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
        self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
        # Normalizes each (sin, cos) pair onto the unit circle.
        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-6)

    def construct(self, rotation, translation, act, initial_act, aatype):
        """Predict side chains using rotation and translation representations.

        Args:
            rotation: The backbone rotation, a 9-tuple of per-residue tensors
                (row-major 3x3 entries).
            translation: The backbone translation, a 3-tuple of per-residue tensors.
            act: updated single activations from the structure module.
            initial_act: initial single activations (input of structure module).
            aatype: Amino acid type representations.

        Returns:
            Tuple of (normalized angles, unnormalized angles, atom14 positions,
            all frames to global).
        """

        act1 = self.input_projection(self.relu(act))
        init_act1 = self.input_projection_1(self.relu(initial_act))
        # Sum the activation list (equivalent to concat then Linear).
        act = act1 + init_act1

        # Mapping with some residual blocks.
        # resblock1
        old_act = act
        act = self.resblock1(self.relu(act))
        act = self.resblock2(self.relu(act))
        act += old_act
        # resblock2
        old_act = act
        act = self.resblock1_1(self.relu(act))
        act = self.resblock2_1(self.relu(act))
        act += old_act

        # Map activations to torsion angles. Shape: (num_res, 14).
        num_res = act.shape[0]
        unnormalized_angles = self.unnormalized_angles(self.relu(act))

        # Interpret as 7 angles x (sin, cos) and project onto the unit circle.
        unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
        angles = self.l2_normalize(unnormalized_angles)

        # Repack the backbone rigid transform as ((9 rot entries), (3 trans entries)).
        backb_to_global = ((rotation[0], rotation[1], rotation[2],
                            rotation[3], rotation[4], rotation[5],
                            rotation[6], rotation[7], rotation[8]),
                           (translation[0], translation[1], translation[2]))

        all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
                                                        self.restype_rigid_group_default_frame)

        pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
                                                                       self.restype_atom14_to_rigid_group,
                                                                       self.restype_atom14_rigid_group_positions,
                                                                       self.restype_atom14_mask)

        atom_pos = pred_positions
        frames = all_frames_to_global
        res = (angles, unnormalized_angles, atom_pos, frames)
        return res


class MultimerFoldIteration(nn.Cell):
    """A single iteration of the main structure module loop.

    One pass of invariant point attention + transition over the single
    activations, followed by a backbone affine update and side-chain
    construction (Jumper et al. 2021, Alg. 20/23).
    """

    def __init__(self, config, pair_dim, single_repr_dim):
        super().__init__()
        self.config = config
        self.drop_out = nn.Dropout(p=0.1)
        self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
        self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
        self.transition = nn.Dense(self.config.num_channel, config.num_channel,
                                   weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                     weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        self.relu = nn.ReLU()
        # 6 = 3 quaternion-update components + 3 translation components.
        self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros')
        self.attention_module = MultimerInvariantPointAttention(self.config.num_head,
                                                                self.config.num_scalar_qk,
                                                                self.config.num_scalar_v,
                                                                self.config.num_point_v,
                                                                self.config.num_point_qk,
                                                                self.config.num_channel,
                                                                pair_dim)
        self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim)

    def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype):
        """Run one fold iteration; returns updated act, frames and atom positions."""
        attn = self.attention_module(act, static_feat_2d, sequence_mask, rotation, translation)
        act += attn
        act = self.drop_out(act)
        act = self.attention_layer_norm(act)
        # Transition
        input_act = act
        act = self.transition(act)
        act = self.relu(act)
        act = self.transition_1(act)
        act = self.relu(act)
        act = self.transition_2(act)

        act += input_act
        act = self.drop_out(act)
        act = self.transition_layer_norm(act)
        # This block corresponds to
        # Jumper et al. (2021) Alg. 23 "Backbone update"
        # Affine update
        affine_update = self.affine_update(act)
        quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update)
        # Scale translation into Angstrom units before building side chains.
        # NOTE(review): the scale here is 20.0 while the rasp variant uses
        # 10.0 — presumably matches each model's position_scale; confirm.
        translation1 = vecs_scale(translation, 20.0)
        rotation1 = rotation
        angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \
            self.mu_side_chain(rotation1, translation1, act, initial_act, aatype)
        affine_output = quaternion_to_tensor(quaternion, translation)
        # Stop rotation/quaternion gradients between iterations (as in AF2).
        quaternion = F.stop_gradient(quaternion)
        rotation = F.stop_gradient(rotation)
        res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
               atom_pos, frames)
        return res


class MultimerStructureModule(nn.Cell):
    """StructureModule as a network head.

    Iterates MultimerFoldIteration ``num_layer`` times, then converts the
    final atom14 coordinates to atom37 layout and stacks per-iteration
    trajectories for the loss.
    """

    def __init__(self, config, single_repr_dim, pair_dim):
        super(MultimerStructureModule, self).__init__()
        self.config = config.model.structure_module
        self.seq_length = config.seq_length
        self.fold_iteration = MultimerFoldIteration(self.config, pair_dim, single_repr_dim)
        self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
        self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel,
                                           weight_init=lecun_init(single_repr_dim))
        self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
        self.num_layer = self.config.num_layer
        # Gather index used by atom14_to_atom37: residue index repeated over 37 atoms.
        self.indice0 = Tensor(
            np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))
        # Per-channel weights for the affine trajectory: quaternion part kept
        # at 1, translation part scaled by position_scale.
        self.traj_w = Tensor(np.array([1.] * 4 + [self.config.position_scale] * 3), mstype.float32)

    def construct(self, single, pair, seq_mask, aatype, residx_atom37_to_atom14=None, atom37_atom_exists=None):
        """Predict final atom positions from single/pair representations."""
        sequence_mask = seq_mask[:, None]
        act = self.single_layer_norm(single)
        initial_act = act
        act = self.initial_projection(act)
        # Identity rigid transforms for every residue as the starting frames.
        quaternion, rotation, translation = initial_affine(self.seq_length)
        act_2d = self.pair_layer_norm(pair)

        # folder iteration
        atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \
            self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
        # Keep only the last iteration's atom14 positions ([-1]).
        atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]
        sidechain_atom_pos = atom_pos

        atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
                                                 residx_atom37_to_atom14,
                                                 atom37_atom_exists,
                                                 self.indice0)
        structure_traj = affine_output_new * self.traj_w
        final_affines = affine_output_new[-1]
        final_atom_positions = atom37_pred_positions
        final_atom_mask = atom37_atom_exists
        rp_structure_module = act_iter
        res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \
               angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj)
        return res

    def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
                            aatype):
        """Run num_layer fold iterations, stacking per-iteration outputs."""
        affine_init = ()
        angles_sin_cos_init = ()
        um_angles_sin_cos_init = ()
        atom_pos_batch = ()
        frames_batch = ()

        for _ in range(self.num_layer):
            act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
                atom_pos, frames = \
                self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype)
            # Accumulate each iteration's outputs with a leading batch axis.
            affine_init = affine_init + (affine_output[None, ...],)
            angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
            um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
            atom_pos_batch += (mnp.concatenate(vecs_expand_dims(atom_pos, 0), axis=0)[:, None, ...],)
            frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) +
                                             vecs_expand_dims(frames[1], 0), axis=0)[:, None, ...],)
        affine_output_new = mnp.concatenate(affine_init, axis=0)
        angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
        um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)
        frames_new = mnp.concatenate(frames_batch, axis=1)
        atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1)
        res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act)
        return res
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""structure module"""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore.ops import functional as F
import mindsponge.common.residue_constants as residue_constants
from mindsponge.cell.initializer import lecun_init
from mindsponge.common.utils import torsion_angles_to_frames, frames_and_literature_positions_to_atom14_pos, \
    atom14_to_atom37, pseudo_beta_fn
from mindsponge.common.geometry import initial_affine, quaternion_to_tensor, pre_compose, vecs_scale, \
    vecs_to_tensor, vecs_expand_dims, rots_expand_dims, apply_to_point, invert_point


class InvariantPointContactAttention(nn.Cell):
    r"""
    Invariant point attention module with an extra contact-information bias.

    This module is used to update the sequence representation, which is the
    first input -- inputs_1d, adding location information to the sequence
    representation.

    The attention consists of three parts, namely, q, k, v obtained from the
    sequence representation, q', k', v' obtained by the interaction between
    the sequence representation and the rigid body group, and b, which is the
    bias, obtained from the pair representation (the second input -- inputs_2d).
    On top of the standard formulation, a contact bias derived from
    ``contact_act`` (masked by ``contact_info_mask``) is added to the
    attention logits.

    .. math::
        a_{ij} = \mathrm{softmax}\left(w_l\left(c_1 q_i^{T} k_j + b_{ij}
            - c_2 \sum_p \left\| T_i \circ q_i'^{p} - T_j \circ k_j'^{p} \right\|^2
            \right)\right)

    where i and j represent the ith and jth amino acids in the sequence,
    respectively, and T is the rotation and translation in the input.

    Reference: Jumper et al. (2021) Suppl. Alg. 22 "InvariantPointAttention".

    Args:
        num_head (int): The number of the heads.
        num_scalar_qk (int): The number of the scalar query/key.
        num_scalar_v (int): The number of the scalar value.
        num_point_v (int): The number of the point value.
        num_point_qk (int): The number of the point query/key.
        num_channel (int): The number of the channel.
        pair_dim (int): The last dimension length of pair.

    Inputs:
        - **inputs_1d** (Tensor) - The first row of msa representation which is the output of evoformer module,
          also called the sequence representation, shape :math:`[N_{res}, num\_channel]`.
        - **inputs_2d** (Tensor) - The pair representation which is the output of evoformer module,
          shape :math:`[N_{res}, N_{res}, pair\_dim]`.
        - **mask** (Tensor) - A mask that determines which elements of inputs_1d are involved in the
          attention calculation, shape :math:`[N_{res}, 1]`.
        - **rotation** (tuple) - A rotation term in a rigid body group T(r,t),
          a tuple of length 9; the shape of each element in the tuple is :math:`[N_{res}]`.
        - **translation** (tuple) - A translation term in a rigid body group T(r,t),
          a tuple of length 3; the shape of each element in the tuple is :math:`[N_{res}]`.
        - **contact_act** (Tensor) - Contact features projected to a per-head bias
          (last dimension 32, see ``contact_layer``).
        - **contact_info_mask** (Tensor) - Pairwise mask applied to the contact bias.

    Outputs:
        Tensor, the update of inputs_1d, shape :math:`[N_{res}, num\_channel]`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> from mindsponge.cell import InvariantPointContactAttention
        >>> from mindspore import dtype as mstype
        >>> from mindspore import Tensor
        >>> import mindspore.context as context
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> model = InvariantPointContactAttention(num_head=12, num_scalar_qk=16, num_scalar_v=16,
        ...                                        num_point_v=8, num_point_qk=4,
        ...                                        num_channel=384, pair_dim=128)
        >>> inputs_1d = Tensor(np.ones((256, 384)), mstype.float32)
        >>> inputs_2d = Tensor(np.ones((256, 256, 128)), mstype.float32)
        >>> mask = Tensor(np.ones((256, 1)), mstype.float32)
        >>> rotation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(9)])
        >>> translation = tuple([Tensor(np.ones(256), mstype.float16) for _ in range(3)])
        >>> attn_out = model(inputs_1d, inputs_2d, mask, rotation, translation)
        >>> print(attn_out.shape)
        (256, 384)
    """

    def __init__(self, num_head, num_scalar_qk, num_scalar_v, num_point_v, num_point_qk, num_channel, pair_dim):
        super(InvariantPointContactAttention, self).__init__()

        self._dist_epsilon = 1e-8
        self.num_head = num_head
        self.num_scalar_qk = num_scalar_qk
        self.num_scalar_v = num_scalar_v
        self.num_point_v = num_point_v
        self.num_point_qk = num_point_qk
        self.num_channel = num_channel
        # Output projection input width: scalar values + point values
        # (3 coords + 1 norm = 4 per point) + attended pair features.
        self.projection_num = self.num_head * self.num_scalar_v + self.num_head * self.num_point_v * 4 + \
                              self.num_head * pair_dim
        self.q_scalar = nn.Dense(self.num_channel, self.num_head * self.num_scalar_qk,
                                 weight_init=lecun_init(self.num_channel))
        self.kv_scalar = nn.Dense(self.num_channel, self.num_head * (self.num_scalar_qk + self.num_scalar_v),
                                  weight_init=lecun_init(self.num_channel))
        self.q_point_local = nn.Dense(self.num_channel, self.num_head * 3 * self.num_point_qk,
                                      weight_init=lecun_init(self.num_channel)
                                      )
        self.kv_point_local = nn.Dense(self.num_channel, self.num_head * 3 * (self.num_point_qk + self.num_point_v),
                                       weight_init=lecun_init(self.num_channel))
        # Projects 32-dim contact features to one bias per attention head.
        self.contact_layer = nn.Dense(32, self.num_head)
        self.soft_max = nn.Softmax()
        self.soft_plus = ops.Softplus()
        # NOTE(review): size 12 is hard-coded — presumably equals num_head; confirm.
        self.trainable_point_weights = Parameter(Tensor(np.ones((12,)), mstype.float32), name="trainable_point_weights")
        self.attention_2d = nn.Dense(pair_dim, self.num_head, weight_init=lecun_init(pair_dim))
        self.output_projection = nn.Dense(self.projection_num, self.num_channel, weight_init='zeros'
                                          )
        # Fixed logit weights (AF2 IPA variance-balancing constants).
        self.scalar_weights = Tensor(np.sqrt(1.0 / (3 * 16)).astype(np.float32))
        self.point_weights = Tensor(np.sqrt(1.0 / (3 * 18)).astype(np.float32))
        self.attention_2d_weights = Tensor(np.sqrt(1.0 / 3).astype(np.float32))

    def construct(self, inputs_1d, inputs_2d, mask, rotation, translation, contact_act=None, contact_info_mask=None):
        '''construct'''
        num_residues, _ = inputs_1d.shape

        # Improve readability by removing a large number of 'self's.
        num_head = self.num_head
        num_scalar_qk = self.num_scalar_qk
        num_point_qk = self.num_point_qk
        num_scalar_v = self.num_scalar_v
        num_point_v = self.num_point_v

        # Construct scalar queries, shape [num_residues, num_head, num_scalar_qk].
        q_scalar = self.q_scalar(inputs_1d)
        q_scalar = mnp.reshape(q_scalar, [num_residues, num_head, num_scalar_qk])

        # Construct scalar keys/values in one projection, then split.
        kv_scalar = self.kv_scalar(inputs_1d)
        kv_scalar = mnp.reshape(kv_scalar, [num_residues, num_head, num_scalar_v + num_scalar_qk])
        k_scalar, v_scalar = mnp.split(kv_scalar, [num_scalar_qk], axis=-1)

        # Construct query points.
        # First construct query points in local frame.
        q_point_local = self.q_point_local(inputs_1d)

        q_point_local = mnp.split(q_point_local, 3, axis=-1)
        q_point_local = (ops.Squeeze()(q_point_local[0]), ops.Squeeze()(q_point_local[1]),
                         ops.Squeeze()(q_point_local[2]))
        # Project query points into global frame.
        q_point_global = apply_to_point(rotation, translation, q_point_local, 1)

        # Reshape query point for later use.
        q_point0 = mnp.reshape(q_point_global[0], (num_residues, num_head, num_point_qk))
        q_point1 = mnp.reshape(q_point_global[1], (num_residues, num_head, num_point_qk))
        q_point2 = mnp.reshape(q_point_global[2], (num_residues, num_head, num_point_qk))

        # Construct key and value points.
        # Key points have shape [num_residues, num_head, num_point_qk]
        # Value points have shape [num_residues, num_head, num_point_v]

        # Construct key and value points in local frame.
        kv_point_local = self.kv_point_local(inputs_1d)

        kv_point_local = mnp.split(kv_point_local, 3, axis=-1)
        kv_point_local = (ops.Squeeze()(kv_point_local[0]), ops.Squeeze()(kv_point_local[1]),
                          ops.Squeeze()(kv_point_local[2]))
        # Project key and value points into global frame.
        kv_point_global = apply_to_point(rotation, translation, kv_point_local, 1)

        kv_point_global0 = mnp.reshape(kv_point_global[0], (num_residues, num_head, (num_point_qk + num_point_v)))
        kv_point_global1 = mnp.reshape(kv_point_global[1], (num_residues, num_head, (num_point_qk + num_point_v)))
        kv_point_global2 = mnp.reshape(kv_point_global[2], (num_residues, num_head, (num_point_qk + num_point_v)))

        # Split key and value points.
        k_point0, v_point0 = mnp.split(kv_point_global0, [num_point_qk], axis=-1)
        k_point1, v_point1 = mnp.split(kv_point_global1, [num_point_qk], axis=-1)
        k_point2, v_point2 = mnp.split(kv_point_global2, [num_point_qk], axis=-1)

        # Learnable per-head point weight, kept positive via softplus.
        trainable_point_weights = self.soft_plus(self.trainable_point_weights)
        point_weights = self.point_weights * ops.expand_dims(trainable_point_weights, axis=1)

        v_point = [mnp.swapaxes(v_point0, -2, -3), mnp.swapaxes(v_point1, -2, -3), mnp.swapaxes(v_point2, -2, -3)]
        q_point = [mnp.swapaxes(q_point0, -2, -3), mnp.swapaxes(q_point1, -2, -3), mnp.swapaxes(q_point2, -2, -3)]
        k_point = [mnp.swapaxes(k_point0, -2, -3), mnp.swapaxes(k_point1, -2, -3), mnp.swapaxes(k_point2, -2, -3)]

        # Squared distance between every query point and key point pair.
        dist2 = ops.Square()(ops.expand_dims(q_point[0], 2) - ops.expand_dims(k_point[0], 1)) + \
                ops.Square()(ops.expand_dims(q_point[1], 2) - ops.expand_dims(k_point[1], 1)) + \
                ops.Square()(ops.expand_dims(q_point[2], 2) - ops.expand_dims(k_point[2], 1))

        attn_qk_point = -0.5 * mnp.sum(ops.expand_dims(ops.expand_dims(point_weights, 1), 1) * dist2, axis=-1)

        v = mnp.swapaxes(v_scalar, -2, -3)
        q = mnp.swapaxes(self.scalar_weights * q_scalar, -2, -3)
        k = mnp.swapaxes(k_scalar, -2, -3)
        attn_qk_scalar = ops.matmul(q, mnp.swapaxes(k, -2, -1))
        attn_logits = attn_qk_scalar + attn_qk_point

        # Pair-representation bias, one scalar per head per (i, j).
        attention_2d = self.attention_2d(inputs_2d)
        attention_2d = mnp.transpose(attention_2d, [2, 0, 1])
        attention_2d = self.attention_2d_weights * attention_2d

        attn_logits += attention_2d

        # Contact-information bias added to the attention logits.
        # NOTE(review): contact_act/contact_info_mask default to None but are
        # used unconditionally here — callers must always pass both.
        contact_act = self.contact_layer(contact_act)
        contact_act = ops.Transpose()(contact_act, (2, 0, 1))
        contact_act = contact_act * contact_info_mask[None, :, :]

        attn_logits += contact_act

        # Large negative bias on masked-out pairs before the softmax.
        mask_2d = mask * mnp.swapaxes(mask, -1, -2)
        attn_logits -= 50 * (1. - mask_2d)

        attn = self.soft_max(attn_logits)

        result_scalar = ops.matmul(attn, v)

        # Attention-weighted sum of value points, per coordinate.
        result_point_global = [mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[0][:, None, :, :], axis=-2), -2, -3),
                               mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[1][:, None, :, :], axis=-2), -2, -3),
                               mnp.swapaxes(mnp.sum(attn[:, :, :, None] * v_point[2][:, None, :, :], axis=-2), -2, -3)
                               ]

        result_point_global = [mnp.reshape(result_point_global[0], [num_residues, num_head * num_point_v]),
                               mnp.reshape(result_point_global[1], [num_residues, num_head * num_point_v]),
                               mnp.reshape(result_point_global[2], [num_residues, num_head * num_point_v])]
        result_scalar = mnp.swapaxes(result_scalar, -2, -3)

        result_scalar = mnp.reshape(result_scalar, [num_residues, num_head * num_scalar_v])

        # Map attended points back into each residue's local frame.
        result_point_local = invert_point(result_point_global, rotation, translation, 1)

        output_feature1 = result_scalar
        output_feature20 = result_point_local[0]
        output_feature21 = result_point_local[1]
        output_feature22 = result_point_local[2]

        # Point norms (with epsilon for gradient stability at zero).
        output_feature3 = mnp.sqrt(self._dist_epsilon +
                                   ops.Square()(result_point_local[0]) +
                                   ops.Square()(result_point_local[1]) +
                                   ops.Square()(result_point_local[2]))

        # Attention-weighted pair features.
        result_attention_over_2d = ops.matmul(mnp.swapaxes(attn, 0, 1), inputs_2d)
        num_out = num_head * result_attention_over_2d.shape[-1]
        output_feature4 = mnp.reshape(result_attention_over_2d, [num_residues, num_out])

        final_act = mnp.concatenate([output_feature1, output_feature20, output_feature21,
                                     output_feature22, output_feature3, output_feature4], axis=-1)
        final_result = self.output_projection(final_act)
        return final_result


class MultiRigidSidechain(nn.Cell):
    """Class to make side chain atoms.

    Maps the (updated and initial) single representations to seven torsion
    angles per residue, then builds all-atom (atom14) positions and rigid-group
    frames from those angles via the literature geometry tables.
    """

    def __init__(self, config, single_repr_dim):
        super().__init__()
        self.config = config
        self.input_projection = nn.Dense(single_repr_dim, self.config.num_channel,
                                         weight_init=lecun_init(single_repr_dim))
        self.input_projection_1 = nn.Dense(single_repr_dim, self.config.num_channel,
                                           weight_init=lecun_init(single_repr_dim))
        self.relu = nn.ReLU()
        self.resblock1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                  weight_init=lecun_init(self.config.num_channel,
                                                         initializer_name='relu'))
        # Second layer of each residual block starts at zero so the block is an
        # identity map at initialization.
        self.resblock2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        self.resblock1_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                    weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.resblock2_1 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        # 14 = 7 torsion angles x (sin, cos) per residue.
        self.unnormalized_angles = nn.Dense(self.config.num_channel, 14,
                                            weight_init=lecun_init(self.config.num_channel))
        self.restype_atom14_to_rigid_group = Tensor(residue_constants.restype_atom14_to_rigid_group)
        self.restype_atom14_rigid_group_positions = Tensor(residue_constants.restype_atom14_rigid_group_positions)
        self.restype_atom14_mask = Tensor(residue_constants.restype_atom14_mask)
        self.restype_rigid_group_default_frame = Tensor(residue_constants.restype_rigid_group_default_frame)
        # Normalizes each (sin, cos) pair onto the unit circle.
        self.l2_normalize = ops.L2Normalize(axis=-1, epsilon=1e-6)

    def construct(self, rotation, translation, act, initial_act, aatype):
        """Predict side chains using rotation and translation representations.

        Args:
            rotation: The backbone rotation, a 9-tuple of per-residue tensors.
            translation: The backbone translation, a 3-tuple of per-residue tensors.
            act: updated single activations from the structure module.
            initial_act: initial single activations (input of structure module).
            aatype: Amino acid type representations.

        Returns:
            Tuple of (normalized angles, unnormalized angles, atom14 positions,
            all frames to global).
        """

        act1 = self.input_projection(self.relu(act))
        init_act1 = self.input_projection_1(self.relu(initial_act))
        # Sum the activation list (equivalent to concat then Linear).
        act = act1 + init_act1

        # Mapping with some residual blocks.
        # resblock1
        old_act = act
        act = self.resblock1(self.relu(act))
        act = self.resblock2(self.relu(act))
        act += old_act
        # resblock2
        old_act = act
        act = self.resblock1_1(self.relu(act))
        act = self.resblock2_1(self.relu(act))
        act += old_act

        # Map activations to torsion angles. Shape: (num_res, 14).
        num_res = act.shape[0]
        unnormalized_angles = self.unnormalized_angles(self.relu(act))

        unnormalized_angles = mnp.reshape(unnormalized_angles, [num_res, 7, 2])
        angles = self.l2_normalize(unnormalized_angles)

        # Repack the backbone rigid transform as ((9 rot entries), (3 trans entries)).
        backb_to_global = ((rotation[0], rotation[1], rotation[2],
                            rotation[3], rotation[4], rotation[5],
                            rotation[6], rotation[7], rotation[8]),
                           (translation[0], translation[1], translation[2]))

        all_frames_to_global = torsion_angles_to_frames(aatype, backb_to_global, angles,
                                                        self.restype_rigid_group_default_frame)

        pred_positions = frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global,
                                                                       self.restype_atom14_to_rigid_group,
                                                                       self.restype_atom14_rigid_group_positions,
                                                                       self.restype_atom14_mask)

        atom_pos = pred_positions
        frames = all_frames_to_global
        res = (angles, unnormalized_angles, atom_pos, frames)
        return res


class FoldIteration(nn.Cell):
    """A single iteration of the main structure module loop.

    Same as the standard AF2 fold iteration but with contact-aware attention
    (InvariantPointContactAttention) and contact inputs threaded through.
    """

    def __init__(self, config, pair_dim, single_repr_dim):
        super().__init__()
        self.config = config
        self.drop_out = nn.Dropout(p=0.1)
        self.attention_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
        self.transition_layer_norm = nn.LayerNorm([self.config.num_channel,], epsilon=1e-5)
        self.transition = nn.Dense(self.config.num_channel, config.num_channel,
                                   weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.transition_1 = nn.Dense(self.config.num_channel, self.config.num_channel,
                                     weight_init=lecun_init(self.config.num_channel, initializer_name='relu'))
        self.transition_2 = nn.Dense(self.config.num_channel, self.config.num_channel, weight_init='zeros')
        self.relu = nn.ReLU()
        # 6 = 3 quaternion-update components + 3 translation components.
        self.affine_update = nn.Dense(self.config.num_channel, 6, weight_init='zeros')
        self.attention_module = InvariantPointContactAttention(self.config.num_head,
                                                               self.config.num_scalar_qk,
                                                               self.config.num_scalar_v,
                                                               self.config.num_point_v,
                                                               self.config.num_point_qk,
                                                               self.config.num_channel,
                                                               pair_dim)
        self.mu_side_chain = MultiRigidSidechain(self.config.sidechain, single_repr_dim)
        # Debug print op (unused in construct; kept as in original).
        self.print = ops.Print()

    def construct(self, act, static_feat_2d, sequence_mask, quaternion, rotation, \
                  translation, initial_act, aatype, contact_act2, contact_info_mask2):
        """Run one fold iteration with contact inputs; returns updated state."""
        attn = self.attention_module(act, static_feat_2d, sequence_mask, \
                                     rotation, translation, contact_act2, contact_info_mask2)
        act += attn
        act = self.drop_out(act)
        act = self.attention_layer_norm(act)
        # Transition
        input_act = act
        act = self.transition(act)
        act = self.relu(act)
        act = self.transition_1(act)
        act = self.relu(act)
        act = self.transition_2(act)

        act += input_act
        act = self.drop_out(act)
        act = self.transition_layer_norm(act)

        # This block corresponds to
        # Jumper et al. (2021) Alg. 23 "Backbone update"
        # Affine update
        affine_update = self.affine_update(act)
        quaternion, rotation, translation = pre_compose(quaternion, rotation, translation, affine_update)
        # Scale translation before building side chains.
        # NOTE(review): the scale here is 10.0 while the multimer variant uses
        # 20.0 — presumably matches each model's position_scale; confirm.
        translation1 = vecs_scale(translation, 10.0)
        rotation1 = rotation
        angles_sin_cos, unnormalized_angles_sin_cos, atom_pos, frames = \
            self.mu_side_chain(rotation1, translation1, act, initial_act, aatype)

        affine_output = quaternion_to_tensor(quaternion, translation)
        # Stop rotation/quaternion gradients between iterations (as in AF2).
        quaternion = F.stop_gradient(quaternion)
        rotation = F.stop_gradient(rotation)
        res = (act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
               atom_pos, frames)
        return res


class ContactStructureModule(nn.Cell):
    """StructureModule as a network head.

    Iterates FoldIteration ``num_layer`` times with contact inputs, then
    converts atom14 to atom37 coordinates. When ``use_sumcons`` is set, it
    additionally applies a pairwise contact-constraint correction that pulls
    residue pairs flagged by ``contact_info_mask2`` toward each other when
    their pseudo-beta distance exceeds 8.10 Angstrom.
    """

    def __init__(self, config, single_repr_dim, pair_dim):
        super(ContactStructureModule, self).__init__()
        self.config = config.model.structure_module
        self.seq_length = config.seq_length
        self.fold_iteration = FoldIteration(self.config, pair_dim, single_repr_dim)
        self.single_layer_norm = nn.LayerNorm([single_repr_dim,], epsilon=1e-5)
        self.initial_projection = nn.Dense(single_repr_dim, self.config.num_channel,
                                           weight_init=lecun_init(single_repr_dim))
        self.pair_layer_norm = nn.LayerNorm([pair_dim,], epsilon=1e-5)
        self.num_layer = self.config.num_layer
        # Gather index used by atom14_to_atom37: residue index repeated over 37 atoms.
        self.indice0 = Tensor(
            np.arange(self.seq_length).reshape((-1, 1, 1)).repeat(37, axis=1).astype("int32"))
        # Per-channel weights for the affine trajectory: quaternion part kept
        # at 1, translation part scaled by position_scale.
        self.traj_w = Tensor(np.array([1.] * 4 + [self.config.position_scale] * 3), mstype.float32)
        # Enables the post-hoc contact-constraint position correction below.
        self.use_sumcons = True

    def construct(self, single, pair, seq_mask, aatype, contact_act2, contact_info_mask2, residx_atom37_to_atom14=None,
                  atom37_atom_exists=None):
        """Predict final atom positions from single/pair representations and contacts."""
        sequence_mask = seq_mask[:, None]
        act = self.single_layer_norm(single)
        initial_act = act
        act = self.initial_projection(act)
        # Identity rigid transforms for every residue as the starting frames.
        quaternion, rotation, translation = initial_affine(self.seq_length)
        act_2d = self.pair_layer_norm(pair)
        # folder iteration
        atom_pos, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, act_iter = \
            self.iteration_operation(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype,
                                     contact_act2, contact_info_mask2)
        # Keep only the last iteration's atom14 positions ([-1]).
        atom14_pred_positions = vecs_to_tensor(atom_pos)[-1]
        sidechain_atom_pos = atom_pos

        atom37_pred_positions = atom14_to_atom37(atom14_pred_positions,
                                                 residx_atom37_to_atom14,
                                                 atom37_atom_exists,
                                                 self.indice0)

        structure_traj = affine_output_new * self.traj_w
        final_affines = affine_output_new[-1]
        final_atom_positions = atom37_pred_positions
        final_atom_mask = atom37_atom_exists
        rp_structure_module = act_iter
        if self.use_sumcons:
            # Contact-constraint correction: for masked pairs farther apart
            # than 8.10 A (pseudo-beta distance), translate each residue's
            # atoms halfway toward satisfying the contact.
            pseudo_beta_pred = pseudo_beta_fn(aatype, atom37_pred_positions, None)
            coord_diffs = pseudo_beta_pred[None] - pseudo_beta_pred[:, None]
            # epsilon under the sqrt keeps the gradient finite at zero distance.
            distance = ops.Sqrt()(ops.ReduceSum()(ops.Square()(coord_diffs), -1) + 1e-8)
            scale = (8.10 / distance - 1) * contact_info_mask2 * (distance > 8.10)
            contact_translation_2 = scale[:, :, None] * coord_diffs / 2
            contact_translation = ops.ReduceSum(keep_dims=True)(contact_translation_2, 1)
            atom14_pred_positions = atom14_pred_positions - contact_translation
            final_atom_positions = final_atom_positions - contact_translation
        res = (final_atom_positions, final_atom_mask, rp_structure_module, atom14_pred_positions, final_affines, \
               angles_sin_cos_new, um_angles_sin_cos_new, sidechain_frames, sidechain_atom_pos, structure_traj)
        return res

    def iteration_operation(self, act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act,
                            aatype, contact_act2, contact_info_mask2):
        """Run num_layer fold iterations, stacking per-iteration outputs."""
        affine_init = ()
        angles_sin_cos_init = ()
        um_angles_sin_cos_init = ()
        atom_pos_batch = ()
        frames_batch = ()

        for _ in range(self.num_layer):
            act, quaternion, translation, rotation, affine_output, angles_sin_cos, unnormalized_angles_sin_cos, \
                atom_pos, frames = \
                self.fold_iteration(act, act_2d, sequence_mask, quaternion, rotation, translation, initial_act, aatype,
                                    contact_act2, contact_info_mask2)

            # Accumulate each iteration's outputs with a leading batch axis.
            affine_init = affine_init + (affine_output[None, ...],)
            angles_sin_cos_init = angles_sin_cos_init + (angles_sin_cos[None, ...],)
            um_angles_sin_cos_init = um_angles_sin_cos_init + (unnormalized_angles_sin_cos[None, ...],)
            atom_pos_batch += (mnp.concatenate(vecs_expand_dims(atom_pos, 0), axis=0)[:, None, ...],)
            frames_batch += (mnp.concatenate(rots_expand_dims(frames[0], 0) +
                                             vecs_expand_dims(frames[1], 0), axis=0)[:, None, ...],)
        affine_output_new = mnp.concatenate(affine_init, axis=0)
        angles_sin_cos_new = mnp.concatenate(angles_sin_cos_init, axis=0)
        um_angles_sin_cos_new = mnp.concatenate(um_angles_sin_cos_init, axis=0)
        frames_new = mnp.concatenate(frames_batch, axis=1)
        atom_pos_new = mnp.concatenate(atom_pos_batch, axis=1)
        res = (atom_pos_new, affine_output_new, angles_sin_cos_new, um_angles_sin_cos_new, frames_new, act)
        return res